Repository: snowflakedb/gosnowflake Branch: master Commit: a0b59e44724f Files: 394 Total size: 2.4 MB Directory structure: gitextract_osgz45y5/ ├── .cursor/ │ └── rules/ │ ├── overall-guidelines.mdc │ └── testing.mdc ├── .github/ │ ├── CODEOWNERS │ ├── ISSUE_TEMPLATE/ │ │ ├── BUG_REPORT.md │ │ └── FEATURE_REQUEST.md │ ├── ISSUE_TEMPLATE.md │ ├── PULL_REQUEST_TEMPLATE.md │ ├── repo_meta.yaml │ ├── secret_scanning.yml │ └── workflows/ │ ├── build-test.yml │ ├── changelog.yml │ ├── cla_bot.yml │ ├── jira_close.yml │ ├── jira_comment.yml │ ├── jira_issue.yml │ ├── parameters/ │ │ └── public/ │ │ ├── rsa_key_golang_aws.p8.gpg │ │ ├── rsa_key_golang_azure.p8.gpg │ │ └── rsa_key_golang_gcp.p8.gpg │ ├── parameters_aws_auth_tests.json.gpg │ ├── parameters_aws_golang.json.gpg │ ├── parameters_azure_golang.json.gpg │ ├── parameters_gcp_golang.json.gpg │ ├── rsa-2048-private-key.p8.gpg │ ├── rsa_keys/ │ │ ├── rsa_key.p8.gpg │ │ └── rsa_key_invalid.p8.gpg │ └── semgrep.yml ├── .gitignore ├── .golangci.yml ├── .pre-commit-config.yaml ├── .windsurf/ │ └── rules/ │ └── go.md ├── CHANGELOG.md ├── CONTRIBUTING.md ├── Jenkinsfile ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── aaa_test.go ├── arrow_chunk.go ├── arrow_stream.go ├── arrow_test.go ├── arrowbatches/ │ ├── batches.go │ ├── batches_test.go │ ├── context.go │ ├── converter.go │ ├── converter_test.go │ └── schema.go ├── assert_test.go ├── async.go ├── async_test.go ├── auth.go ├── auth_generic_test_methods_test.go ├── auth_oauth.go ├── auth_oauth_test.go ├── auth_test.go ├── auth_wif.go ├── auth_wif_test.go ├── auth_with_external_browser_test.go ├── auth_with_keypair_test.go ├── auth_with_mfa_test.go ├── auth_with_oauth_okta_authorization_code_test.go ├── auth_with_oauth_okta_client_credentials_test.go ├── auth_with_oauth_snowflake_authorization_code_test.go ├── auth_with_oauth_snowflake_authorization_code_wildcards_test.go ├── auth_with_oauth_test.go ├── auth_with_okta_test.go ├── auth_with_pat_test.go ├── 
authexternalbrowser.go ├── authexternalbrowser_test.go ├── authokta.go ├── authokta_test.go ├── azure_storage_client.go ├── azure_storage_client_test.go ├── bind_uploader.go ├── bindings_test.go ├── chunk.go ├── chunk_downloader.go ├── chunk_downloader_test.go ├── chunk_test.go ├── ci/ │ ├── _init.sh │ ├── build.bat │ ├── build.sh │ ├── container/ │ │ ├── test_authentication.sh │ │ └── test_component.sh │ ├── docker/ │ │ └── rockylinux9/ │ │ └── Dockerfile │ ├── gofix.sh │ ├── image/ │ │ ├── Dockerfile │ │ ├── build.sh │ │ ├── scripts/ │ │ │ └── entrypoint.sh │ │ └── update.sh │ ├── scripts/ │ │ ├── .gitignore │ │ ├── README.md │ │ ├── ca.crt │ │ ├── ca.der │ │ ├── ca.key │ │ ├── ca.srl │ │ ├── execute_tests.sh │ │ ├── hang_webserver.py │ │ ├── login_internal_docker.sh │ │ ├── run_wiremock.sh │ │ ├── setup_connection_parameters.sh │ │ ├── setup_gpg.sh │ │ ├── wiremock-ecdsa-pub.key │ │ ├── wiremock-ecdsa.crt │ │ ├── wiremock-ecdsa.csr │ │ ├── wiremock-ecdsa.key │ │ ├── wiremock-ecdsa.p12 │ │ ├── wiremock.crt │ │ ├── wiremock.csr │ │ ├── wiremock.key │ │ ├── wiremock.p12 │ │ └── wiremock.v3.ext │ ├── test.bat │ ├── test.sh │ ├── test_authentication.sh │ ├── test_revocation.sh │ ├── test_rockylinux9.sh │ ├── test_rockylinux9_docker.sh │ ├── test_wif.sh │ └── wif/ │ └── parameters/ │ ├── parameters_wif.json.gpg │ ├── rsa_wif_aws_azure.gpg │ └── rsa_wif_gcp.gpg ├── client.go ├── client_configuration.go ├── client_configuration_test.go ├── client_test.go ├── cmd/ │ ├── arrow/ │ │ ├── .gitignore │ │ ├── Makefile │ │ └── transform_batches_to_rows/ │ │ ├── Makefile │ │ └── transform_batches_to_rows.go │ ├── logger/ │ │ ├── Makefile │ │ └── logger.go │ ├── mfa/ │ │ ├── Makefile │ │ └── mfa.go │ ├── programmatic_access_token/ │ │ ├── .gitignore │ │ ├── Makefile │ │ └── pat.go │ ├── tomlfileconnection/ │ │ ├── .gitignore │ │ └── Makefile │ └── variant/ │ ├── Makefile │ └── insertvariantobject.go ├── codecov.yml ├── connection.go ├── connection_configuration_test.go ├── 
connection_test.go ├── connection_util.go ├── connectivity_diagnosis.go ├── connectivity_diagnosis_test.go ├── connector.go ├── connector_test.go ├── converter.go ├── converter_test.go ├── crl.go ├── crl_test.go ├── ctx_test.go ├── datatype.go ├── datatype_test.go ├── datetime.go ├── datetime_test.go ├── doc.go ├── driver.go ├── driver_ocsp_test.go ├── driver_test.go ├── dsn.go ├── easy_logging.go ├── easy_logging_test.go ├── encrypt_util.go ├── encrypt_util_test.go ├── errors.go ├── errors_test.go ├── file_compression_type.go ├── file_transfer_agent.go ├── file_transfer_agent_test.go ├── file_util.go ├── file_util_test.go ├── function_wrapper_test.go ├── function_wrappers.go ├── gcs_storage_client.go ├── gcs_storage_client_test.go ├── go.mod ├── go.sum ├── gosnowflake.mak ├── heartbeat.go ├── heartbeat_test.go ├── htap.go ├── htap_test.go ├── internal/ │ ├── arrow/ │ │ └── arrow.go │ ├── compilation/ │ │ ├── cgo_disabled.go │ │ ├── cgo_enabled.go │ │ ├── linking_mode.go │ │ ├── minicore_disabled.go │ │ └── minicore_enabled.go │ ├── config/ │ │ ├── assert_test.go │ │ ├── auth_type.go │ │ ├── config.go │ │ ├── config_bool.go │ │ ├── connection_configuration.go │ │ ├── connection_configuration_test.go │ │ ├── crl_mode.go │ │ ├── dsn.go │ │ ├── dsn_test.go │ │ ├── ocsp_mode.go │ │ ├── priv_key.go │ │ ├── tls_config.go │ │ ├── tls_config_test.go │ │ └── token_accessor.go │ ├── errors/ │ │ └── errors.go │ ├── logger/ │ │ ├── accessor.go │ │ ├── accessor_test.go │ │ ├── context.go │ │ ├── easy_logging_support.go │ │ ├── interfaces.go │ │ ├── level_filtering.go │ │ ├── optional_interfaces.go │ │ ├── proxy.go │ │ ├── secret_detector.go │ │ ├── secret_detector_test.go │ │ ├── secret_masking.go │ │ ├── secret_masking_test.go │ │ ├── slog_handler.go │ │ ├── slog_logger.go │ │ └── source_location_test.go │ ├── os/ │ │ ├── libc_info.go │ │ ├── libc_info_linux.go │ │ ├── libc_info_notlinux.go │ │ ├── libc_info_test.go │ │ ├── os_details.go │ │ ├── os_details_linux.go │ │ ├── 
os_details_notlinux.go │ │ ├── os_details_test.go │ │ └── test_data/ │ │ └── sample_os_release │ ├── query/ │ │ ├── response_types.go │ │ └── transform.go │ └── types/ │ └── types.go ├── local_storage_client.go ├── local_storage_client_test.go ├── location.go ├── location_test.go ├── locker.go ├── log.go ├── log_client_test.go ├── log_test.go ├── minicore.go ├── minicore_disabled_test.go ├── minicore_posix.go ├── minicore_provider_darwin_amd64.go ├── minicore_provider_darwin_arm64.go ├── minicore_provider_linux_amd64.go ├── minicore_provider_linux_arm64.go ├── minicore_provider_windows_amd64.go ├── minicore_provider_windows_arm64.go ├── minicore_test.go ├── minicore_windows.go ├── monitoring.go ├── multistatement.go ├── multistatement_test.go ├── ocsp.go ├── ocsp_test.go ├── old_driver_test.go ├── os_specific_posix.go ├── os_specific_windows.go ├── parameters.json.local ├── parameters.json.tmpl ├── permissions_test.go ├── platform_detection.go ├── platform_detection_test.go ├── prepared_statement_test.go ├── priv_key_test.go ├── put_get_test.go ├── put_get_user_stage_test.go ├── put_get_with_aws_test.go ├── query.go ├── restful.go ├── restful_test.go ├── result.go ├── retry.go ├── retry_test.go ├── rows.go ├── rows_test.go ├── s3_storage_client.go ├── s3_storage_client_test.go ├── secret_detector.go ├── secret_detector_test.go ├── secure_storage_manager.go ├── secure_storage_manager_linux.go ├── secure_storage_manager_notlinux.go ├── secure_storage_manager_test.go ├── sflog/ │ ├── interface.go │ ├── levels.go │ └── slog.go ├── sqlstate.go ├── statement.go ├── statement_test.go ├── storage_client.go ├── storage_client_test.go ├── storage_file_util_test.go ├── structured_type.go ├── structured_type_arrow_batches_test.go ├── structured_type_read_test.go ├── structured_type_write_test.go ├── telemetry.go ├── telemetry_test.go ├── test_data/ │ ├── .gitignore │ ├── connections.toml │ ├── multistatements.sql │ ├── multistatements_drop.sql │ ├── orders_100.csv │ ├── 
orders_101.csv │ ├── put_get_1.txt │ ├── snowflake/ │ │ └── session/ │ │ └── token │ ├── userdata1.parquet │ ├── userdata1_orc │ └── wiremock/ │ └── mappings/ │ ├── auth/ │ │ ├── external_browser/ │ │ │ ├── parallel_login_first_fails_then_successful_flow.json │ │ │ ├── parallel_login_successful_flow.json │ │ │ └── successful_flow.json │ │ ├── mfa/ │ │ │ ├── parallel_login_first_fails_then_successful_flow.json │ │ │ └── parallel_login_successful_flow.json │ │ ├── oauth2/ │ │ │ ├── authorization_code/ │ │ │ │ ├── error_from_idp.json │ │ │ │ ├── invalid_code.json │ │ │ │ ├── successful_flow.json │ │ │ │ ├── successful_flow_with_offline_access.json │ │ │ │ └── successful_flow_with_single_use_refresh_token.json │ │ │ ├── client_credentials/ │ │ │ │ ├── invalid_client.json │ │ │ │ └── successful_flow.json │ │ │ ├── login_request.json │ │ │ ├── login_request_with_expired_access_token.json │ │ │ └── refresh_token/ │ │ │ ├── invalid_refresh_token.json │ │ │ ├── successful_flow.json │ │ │ └── successful_flow_without_new_refresh_token.json │ │ ├── password/ │ │ │ ├── invalid_host.json │ │ │ ├── invalid_password.json │ │ │ ├── invalid_user.json │ │ │ ├── successful_flow.json │ │ │ └── successful_flow_with_telemetry.json │ │ ├── pat/ │ │ │ ├── invalid_token.json │ │ │ ├── reading_fresh_token.json │ │ │ └── successful_flow.json │ │ └── wif/ │ │ ├── azure/ │ │ │ ├── http_error.json │ │ │ ├── missing_issuer_claim.json │ │ │ ├── missing_sub_claim.json │ │ │ ├── non_json_response.json │ │ │ ├── successful_flow_azure_functions.json │ │ │ ├── successful_flow_azure_functions_custom_entra_resource.json │ │ │ ├── successful_flow_azure_functions_no_client_id.json │ │ │ ├── successful_flow_azure_functions_v2_issuer.json │ │ │ ├── successful_flow_basic.json │ │ │ ├── successful_flow_v2_issuer.json │ │ │ └── unparsable_token.json │ │ └── gcp/ │ │ ├── http_error.json │ │ ├── missing_issuer_claim.json │ │ ├── missing_sub_claim.json │ │ ├── successful_flow.json │ │ ├── 
successful_impersionation_flow.json │ │ └── unparsable_token.json │ ├── close_session.json │ ├── hang.json │ ├── minicore/ │ │ └── auth/ │ │ ├── disabled_flow.json │ │ ├── successful_flow.json │ │ └── successful_flow_linux.json │ ├── ocsp/ │ │ ├── auth_failure.json │ │ ├── malformed.json │ │ └── unauthorized.json │ ├── platform_detection/ │ │ ├── aws_ec2_instance_success.json │ │ ├── aws_identity_success.json │ │ ├── azure_managed_identity_success.json │ │ ├── azure_vm_success.json │ │ ├── gce_identity_success.json │ │ ├── gce_vm_success.json │ │ └── timeout_response.json │ ├── query/ │ │ ├── long_running_query.json │ │ ├── query_by_id_timeout.json │ │ ├── query_execution.json │ │ ├── query_monitoring.json │ │ ├── query_monitoring_error.json │ │ ├── query_monitoring_malformed.json │ │ └── query_monitoring_running.json │ ├── retry/ │ │ └── redirection_retry_workflow.json │ ├── select1.json │ └── telemetry/ │ ├── custom_telemetry.json │ └── telemetry.json ├── test_utils_test.go ├── tls_config.go ├── tls_config_test.go ├── transaction.go ├── transaction_test.go ├── transport.go ├── transport_test.go ├── url_util.go ├── util.go ├── util_test.go ├── uuid.go ├── value_awaiter.go ├── version.go └── wiremock_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .cursor/rules/overall-guidelines.mdc ================================================ --- alwaysApply: true --- # Cursor Rules for Go Snowflake Driver ## General Development Standards ### Code Quality - Follow Go formatting standards (use `gofmt`) - Use meaningful variable and function names - Include error handling for all operations that can fail - Write comprehensive documentation for public APIs ### Project Structure - Place test files in the same package as the code being tested - Use `test_data/` directory for test fixtures and sample data - Group related functionality in logical packages 
### Testing - Test files should be named `*_test.go` - **For test-specific rules, see `testing.mdc`** - Write both positive and negative test cases - Use table-driven tests for testing multiple scenarios ### Code Review Guidelines - Ensure code follows Go best practices - Verify comprehensive test coverage - Check that error messages are descriptive and helpful for debugging - Validate that public APIs are properly documented ================================================ FILE: .cursor/rules/testing.mdc ================================================ --- alwaysApply: true --- # Cursor Rules for Go Test Files This file automatically applies when working on `*_test.go` files. ## Testing Standards ### Assertion Helper Usage - **ALWAYS** Attempt to use assertion helpers from `assert_test.go` instead of direct `t.Fatal`, `t.Fatalf`, `t.Error`, or `t.Errorf` calls. Where it makes sense, add new assertion helpers. - **NEVER** write manual if-then-fatal patterns in test functions when a suitable assertion helper exists. 
#### Common Assertion Patterns: **Error Checking:** ```go // ❌ WRONG if err != nil { t.Fatalf("Unexpected error: %v", err) } // ✅ CORRECT assertNilF(t, err, "Unexpected error") ``` **Nil Checking:** ```go // ❌ WRONG if obj == nil { t.Fatal("Expected non-nil object") } // ✅ CORRECT assertNotNilF(t, obj, "Expected non-nil object") ``` **Equality Checking:** ```go // ❌ WRONG if actual != expected { t.Fatalf("Expected %v, got %v", expected, actual) } // ✅ CORRECT assertEqualF(t, actual, expected, "Values should match") ``` **Error Message Validation:** ```go // ❌ WRONG if err.Error() != expectedMsg { t.Fatalf("Expected error: %s, got: %s", expectedMsg, err.Error()) } // ✅ CORRECT assertEqualF(t, err.Error(), expectedMsg, "Error message should match") ``` **Boolean Assertions:** ```go // ❌ WRONG if !condition { t.Fatal("Condition should be true") } // ✅ CORRECT assertTrueF(t, condition, "Condition should be true") ``` #### Helper Function Reference: Always examine `assert_test.go` for the latest set of helpers. Consider these existing examples below. 
- `assertNilF/E(t, value, description)` - Assert value is nil - `assertNotNilF/E(t, value, description)` - Assert value is not nil - `assertEqualF/E(t, actual, expected, description)` - Assert equality - `assertNotEqualF/E(t, actual, expected, description)` - Assert inequality - `assertTrueF/E(t, value, description)` - Assert boolean is true - `assertFalseF/E(t, value, description)` - Assert boolean is false - `assertStringContainsF/E(t, str, substring, description)` - Assert string contains substring - `assertErrIsF/E(t, actual, expected, description)` - Assert error matches expected error #### When to Use F vs E: - Use `F` suffix (Fatal) for critical failures that should stop the test immediately as well as for preconditions - Use `E` suffix (Error) for non-critical failures that allow the test to continue ## Code Review Guidelines: - Flag any direct use of `t.Fatal*` or `t.Error*` in new code - Ensure all test functions use appropriate assertion helpers - Verify that error messages are descriptive and helpful for debugging - Check that tests are comprehensive and cover edge cases ================================================ FILE: .github/CODEOWNERS ================================================ * @snowflakedb/Client /transport.go @snowflakedb/pki-oversight @snowflakedb/Client /crl.go @snowflakedb/pki-oversight @snowflakedb/Client /ocsp.go @snowflakedb/pki-oversight @snowflakedb/Client # GitHub Advanced Security Secret Scanning config /.github/secret_scanning.yml @snowflakedb/prodsec-security-manager-write ================================================ FILE: .github/ISSUE_TEMPLATE/BUG_REPORT.md ================================================ --- name: Bug Report 🐞 about: Something isn't working as expected? Here is the right place to report. 
labels: bug --- :exclamation: If you need **urgent assistance** then [file a case with Snowflake Support](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge). Otherwise continue here. Please answer these questions before submitting your issue. In order to accurately debug the issue this information is required. Thanks! 1. What version of GO driver are you using? 2. What operating system and processor architecture are you using? 3. What version of GO are you using? run `go version` in your console 4. Server version: E.g. 1.90.1 You may get the server version by running a query: ``` SELECT CURRENT_VERSION(); ``` 5. What did you do? If possible, provide a recipe for reproducing the error. A complete runnable program is good. 6. What did you expect to see? What should have happened and what happened instead? 7. Can you set logging to DEBUG and collect the logs? https://community.snowflake.com/s/article/How-to-generate-log-file-on-Snowflake-connectors Before sharing any information, please be sure to review the log and remove any sensitive information. ================================================ FILE: .github/ISSUE_TEMPLATE/FEATURE_REQUEST.md ================================================ --- name: Feature Request 💡 about: Suggest a new idea for the project. labels: feature --- ## What is the current behavior? ## What is the desired behavior? ## How would this improve `gosnowflake`? ## References, Other Background ================================================ FILE: .github/ISSUE_TEMPLATE.md ================================================ ### Issue description Tell us what should happen and what happens instead ### Example code ```go If possible, please enter some example code here to reproduce the issue. ``` ### Error log ``` If you have an error log, please paste it here. ``` Add `glog` option to your application to collect log files. 
### Configuration *Driver version (or git SHA):* *Go version:* run `go version` in your console *Server version:* E.g. 1.90.1 You may get the server version by running a query: ``` SELECT CURRENT_VERSION(); ``` *Client OS:* E.g. Debian 8.1 (Jessie), Windows 10 ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ ### Description SNOW-XXX Please explain the changes you made here. ### Checklist - [ ] Added proper logging (if possible) - [ ] Created tests which fail without the change (if possible) - [ ] Extended the README / documentation, if necessary ================================================ FILE: .github/repo_meta.yaml ================================================ point_of_contact: @snowflakedb/client production: true code_owners_file_present: false jira_area: Developer Platform ================================================ FILE: .github/secret_scanning.yml ================================================ paths-ignore: - "**/test_data/**" ================================================ FILE: .github/workflows/build-test.yml ================================================ name: Build and Test permissions: contents: read on: push: branches: - master tags: - v* pull_request: schedule: - cron: '7 3 * * *' workflow_dispatch: inputs: goTestParams: default: description: 'Parameters passed to go test' sequentialTests: type: boolean default: false description: 'Run tests sequentially (no buffering, slower)' concurrency: # older builds for the same pull request numer or branch should be cancelled group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true jobs: lint: runs-on: ubuntu-latest name: Check linter steps: - uses: actions/checkout@v4 - name: Setup go uses: actions/setup-go@v5 with: go-version: '1.26' - name: golangci-lint uses: golangci/golangci-lint-action@v7 with: version: v2.11 - name: Format, Lint shell: bash run: 
./ci/build.sh - name: Run go fix across all platforms and tags shell: bash run: ./ci/gofix.sh build-test-linux: runs-on: ubuntu-latest strategy: fail-fast: false matrix: cloud: [ 'AWS', 'AZURE', 'GCP' ] go: [ '1.24', '1.25', '1.26' ] name: ${{ matrix.cloud }} Go ${{ matrix.go }} on Ubuntu steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 # for wiremock with: java-version: 17 distribution: 'temurin' - name: Setup go uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Test shell: bash env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }} CLOUD_PROVIDER: ${{ matrix.cloud }} GORACE: history_size=7 GO_TEST_PARAMS: ${{ inputs.goTestParams }} SEQUENTIAL_TESTS: ${{ inputs.sequentialTests }} WIREMOCK_PORT: 14335 WIREMOCK_HTTPS_PORT: 13567 run: ./ci/test.sh - name: Upload test results to Codecov if: ${{!cancelled()}} uses: codecov/test-results-action@v1 with: token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }} - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }} build-test-linux-no-home: runs-on: ubuntu-latest name: Ubuntu - no HOME steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 # for wiremock with: java-version: 17 distribution: 'temurin' - name: Setup go uses: actions/setup-go@v5 with: go-version: '1.25' - name: Test shell: bash env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }} CLOUD_PROVIDER: AWS GORACE: history_size=7 GO_TEST_PARAMS: ${{ inputs.goTestParams }} SEQUENTIAL_TESTS: ${{ inputs.sequentialTests }} WIREMOCK_PORT: 14335 WIREMOCK_HTTPS_PORT: 13567 HOME_EMPTY: "yes" run: ./ci/test.sh build-test-mac: runs-on: macos-latest strategy: fail-fast: false matrix: cloud: [ 'AWS', 'AZURE', 'GCP' ] go: [ '1.24', '1.25', '1.26' ] name: ${{ matrix.cloud }} Go ${{ matrix.go }} on Mac steps: - uses: actions/checkout@v4 - uses: 
actions/setup-java@v4 # for wiremock with: java-version: 17 distribution: 'temurin' - name: Setup go uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Test shell: bash env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }} CLOUD_PROVIDER: ${{ matrix.cloud }} GO_TEST_PARAMS: ${{ inputs.goTestParams }} WIREMOCK_PORT: 14335 WIREMOCK_HTTPS_PORT: 13567 run: ./ci/test.sh - name: Upload test results to Codecov if: ${{!cancelled()}} uses: codecov/test-results-action@v1 with: token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }} - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }} build-test-mac-no-home: runs-on: macos-latest name: Mac - no HOME steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 # for wiremock with: java-version: 17 distribution: 'temurin' - name: Setup go uses: actions/setup-go@v5 with: go-version: '1.25' - name: Test shell: bash env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }} CLOUD_PROVIDER: AWS GO_TEST_PARAMS: ${{ inputs.goTestParams }} WIREMOCK_PORT: 14335 WIREMOCK_HTTPS_PORT: 13567 HOME_EMPTY: "yes" run: ./ci/test.sh build-test-windows: runs-on: windows-latest strategy: fail-fast: false matrix: cloud: [ 'AWS', 'AZURE', 'GCP' ] go: [ '1.24', '1.25', '1.26' ] name: ${{ matrix.cloud }} Go ${{ matrix.go }} on Windows steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 # for wiremock with: java-version: 17 distribution: 'temurin' - name: Setup go uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - uses: actions/setup-python@v5 with: python-version: '3.x' architecture: 'x64' - name: Test shell: cmd env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }} CLOUD_PROVIDER: ${{ matrix.cloud }} GO_TEST_PARAMS: ${{ inputs.goTestParams }} 
SEQUENTIAL_TESTS: ${{ inputs.sequentialTests }} WIREMOCK_PORT: 14335 WIREMOCK_HTTPS_PORT: 13567 run: ci\\test.bat - name: Upload test results to Codecov if: ${{!cancelled()}} uses: codecov/test-results-action@v1 with: token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }} - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }} fipsOnly: runs-on: ubuntu-latest strategy: fail-fast: false name: FIPS only mode steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 # for wiremock with: java-version: 17 distribution: 'temurin' - name: Setup go uses: actions/setup-go@v5 with: go-version: '1.25' - name: Test shell: bash env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }} CLOUD_PROVIDER: AWS GORACE: history_size=7 GO_TEST_PARAMS: ${{ inputs.goTestParams }} TEST_GODEBUG: fips140=only SEQUENTIAL_TESTS: ${{ inputs.sequentialTests }} WIREMOCK_PORT: 14335 WIREMOCK_HTTPS_PORT: 13567 run: ./ci/test.sh - name: Upload test results to Codecov if: ${{!cancelled()}} uses: codecov/test-results-action@v1 with: token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }} - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }} build-test-linux-minicore-disabled: runs-on: ubuntu-latest name: Ubuntu - minicore disabled steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 # for wiremock with: java-version: 17 distribution: 'temurin' - name: Setup go uses: actions/setup-go@v5 with: go-version: '1.25' - name: Test shell: bash env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }} CLOUD_PROVIDER: AWS GORACE: history_size=7 GO_TEST_PARAMS: ${{ inputs.goTestParams }} -tags=minicore_disabled WIREMOCK_PORT: 14335 WIREMOCK_HTTPS_PORT: 13567 run: ./ci/test.sh build-test-mac-minicore-disabled: runs-on: macos-latest name: 
Mac - minicore disabled steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 # for wiremock with: java-version: 17 distribution: 'temurin' - name: Setup go uses: actions/setup-go@v5 with: go-version: '1.25' - name: Test shell: bash env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }} CLOUD_PROVIDER: AWS GO_TEST_PARAMS: ${{ inputs.goTestParams }} -tags=minicore_disabled WIREMOCK_PORT: 14335 WIREMOCK_HTTPS_PORT: 13567 run: ./ci/test.sh build-test-windows-minicore-disabled: runs-on: windows-latest name: Windows - minicore disabled steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 # for wiremock with: java-version: 17 distribution: 'temurin' - name: Setup go uses: actions/setup-go@v5 with: go-version: '1.25' - uses: actions/setup-python@v5 with: python-version: '3.x' architecture: 'x64' - name: Test shell: cmd env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }} CLOUD_PROVIDER: AWS GO_TEST_PARAMS: ${{ inputs.goTestParams }} -tags=minicore_disabled WIREMOCK_PORT: 14335 WIREMOCK_HTTPS_PORT: 13567 run: ci\\test.bat ecc: runs-on: ubuntu-latest strategy: fail-fast: false name: Elliptic curves check steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 # for wiremock with: java-version: 17 distribution: 'temurin' - name: Setup go uses: actions/setup-go@v5 with: go-version: '1.25' - name: Test shell: bash env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }} CLOUD_PROVIDER: AWS GORACE: history_size=7 GO_TEST_PARAMS: ${{ inputs.goTestParams }} -run TestQueryViaHttps WIREMOCK_PORT: 14335 WIREMOCK_HTTPS_PORT: 13567 WIREMOCK_ENABLE_ECDSA: true run: ./ci/test.sh build-test-rockylinux9: runs-on: ubuntu-latest strategy: fail-fast: false matrix: cloud_go: - cloud: 'AWS' go: '1.24.2' - cloud: 'AZURE' go: '1.25.0' - cloud: 'GCP' go: 
'1.26.0' name: ${{ matrix.cloud_go.cloud }} Go ${{ matrix.cloud_go.go }} on Rocky Linux 9 steps: - uses: actions/checkout@v4 - name: Test shell: bash env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }} CLOUD_PROVIDER: ${{ matrix.cloud_go.cloud }} GORACE: history_size=7 GO_TEST_PARAMS: ${{ inputs.goTestParams }} SEQUENTIAL_TESTS: ${{ inputs.sequentialTests }} WIREMOCK_PORT: 14335 WIREMOCK_HTTPS_PORT: 13567 run: ./ci/test_rockylinux9_docker.sh ${{ matrix.cloud_go.go }} build-test-ubuntu-arm: runs-on: ubuntu-24.04-arm strategy: fail-fast: false matrix: cloud_go: - cloud: 'AWS' go: '1.24' - cloud: 'AZURE' go: '1.25' - cloud: 'GCP' go: '1.26' name: ${{ matrix.cloud_go.cloud }} Go ${{ matrix.cloud_go.go }} on Ubuntu ARM steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 # for wiremock with: java-version: 17 distribution: 'temurin' - name: Setup go uses: actions/setup-go@v5 with: go-version: ${{ matrix.cloud_go.go }} - name: Test shell: bash env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }} CLOUD_PROVIDER: ${{ matrix.cloud_go.cloud }} GORACE: history_size=7 GO_TEST_PARAMS: ${{ inputs.goTestParams }} WIREMOCK_PORT: 14335 WIREMOCK_HTTPS_PORT: 13567 run: ./ci/test.sh - name: Upload test results to Codecov if: ${{!cancelled()}} uses: codecov/test-results-action@v1 with: token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }} - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }} build-test-windows-arm: runs-on: windows-11-arm strategy: fail-fast: false matrix: cloud_go: - cloud: 'AWS' go: '1.24' - cloud: 'AZURE' go: '1.25' - cloud: 'GCP' go: '1.26' name: ${{ matrix.cloud_go.cloud }} Go ${{ matrix.cloud_go.go }} on Windows ARM steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 # for wiremock with: java-version: 21 distribution: 'temurin' - name: Setup 
go uses: actions/setup-go@v5 with: go-version: ${{ matrix.cloud_go.go }} - uses: actions/setup-python@v5 with: python-version: '3.x' architecture: 'x64' - name: Test shell: cmd env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }} CLOUD_PROVIDER: ${{ matrix.cloud_go.cloud }} GO_TEST_PARAMS: ${{ inputs.goTestParams }} WIREMOCK_PORT: 14335 WIREMOCK_HTTPS_PORT: 13567 run: ci\\test.bat - name: Upload test results to Codecov if: ${{!cancelled()}} uses: codecov/test-results-action@v1 with: token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }} - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }} ================================================ FILE: .github/workflows/changelog.yml ================================================ name: Changelog Check on: pull_request: types: [opened, synchronize, labeled, unlabeled] jobs: check_change_log: runs-on: ubuntu-latest if: ${{!contains(github.event.pull_request.labels.*.name, 'NO-CHANGELOG-UPDATES')}} steps: - name: Checkout uses: actions/checkout@v3 with: fetch-depth: 0 - name: Ensure CHANGELOG.md is updated run: git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -wq "CHANGELOG.md" ================================================ FILE: .github/workflows/cla_bot.yml ================================================ name: "CLA Assistant" on: issue_comment: types: [created] pull_request_target: types: [opened,closed,synchronize] jobs: CLAAssistant: runs-on: ubuntu-latest permissions: actions: write contents: write pull-requests: write statuses: write steps: - name: "CLA Assistant" if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' uses: contributor-assistant/github-action/@master env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 
PERSONAL_ACCESS_TOKEN : ${{ secrets.CLA_BOT_TOKEN }} with: path-to-signatures: 'signatures/version1.json' path-to-document: 'https://github.com/snowflakedb/CLA/blob/main/README.md' branch: 'main' allowlist: 'dependabot[bot],github-actions,Jenkins User,_jenkins,sfc-gh-snyk-sca-sa,snyk-bot' remote-organization-name: 'snowflake-eng' remote-repository-name: 'cla-db' ================================================ FILE: .github/workflows/jira_close.yml ================================================ name: Jira closure on: issues: types: [closed, deleted] jobs: close-issue: runs-on: ubuntu-latest steps: - name: Extract issue from title id: extract env: TITLE: "${{ github.event.issue.title }}" run: | jira=$(echo -n $TITLE | awk '{print $1}' | sed -e 's/://') echo ::set-output name=jira::$jira - name: Close Jira Issue if: startsWith(steps.extract.outputs.jira, 'SNOW-') env: ISSUE_KEY: ${{ steps.extract.outputs.jira }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} run: | JIRA_API_URL="${JIRA_BASE_URL}/rest/api/2/issue/${ISSUE_KEY}/transitions" curl -X POST \ --url "$JIRA_API_URL" \ --user "${JIRA_USER_EMAIL}:${JIRA_API_TOKEN}" \ --header "Content-Type: application/json" \ --data "{ \"update\": { \"comment\": [ { \"add\": { \"body\": \"Closed on GitHub\" } } ] }, \"fields\": { \"customfield_12860\": { \"id\": \"11506\" }, \"customfield_10800\": { \"id\": \"-1\" }, \"customfield_12500\": { \"id\": \"11302\" }, \"customfield_12400\": { \"id\": \"-1\" }, \"resolution\": { \"name\": \"Done\" } }, \"transition\": { \"id\": \"71\" } }" ================================================ FILE: .github/workflows/jira_comment.yml ================================================ name: Jira comment on: issue_comment: types: [created] jobs: comment-issue: runs-on: ubuntu-latest steps: - name: Jira login uses: atlassian/gajira-login@master env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} 
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} - name: Extract issue from title id: extract env: TITLE: "${{ github.event.issue.title }}" run: | jira=$(echo -n $TITLE | awk '{print $1}' | sed -e 's/://') echo ::set-output name=jira::$jira - name: Comment on issue uses: atlassian/gajira-comment@master if: startsWith(steps.extract.outputs.jira, 'SNOW-') && github.event.comment.user.login != 'codecov[bot]' with: issue: "${{ steps.extract.outputs.jira }}" comment: "${{ github.event.comment.user.login }} commented:\n\n${{ github.event.comment.body }}\n\n${{ github.event.comment.html_url }}" ================================================ FILE: .github/workflows/jira_issue.yml ================================================ name: Jira creation on: issues: types: [opened] issue_comment: types: [created] jobs: create-issue: runs-on: ubuntu-latest permissions: issues: write if: ((github.event_name == 'issue_comment' && github.event.comment.body == 'recreate jira' && github.event.comment.user.login == 'sfc-gh-mkeller') || (github.event_name == 'issues' && github.event.pull_request.user.login != 'whitesource-for-github-com[bot]')) steps: - name: Create JIRA Ticket id: create env: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} ISSUE_TITLE: ${{ github.event.issue.title }} ISSUE_BODY: ${{ github.event.issue.body }} ISSUE_URL: ${{ github.event.issue.html_url }} run: | # debug #set -x TMP_BODY=$(mktemp) trap "rm -f $TMP_BODY" EXIT # Escape special characters in title and body TITLE=$(echo "${ISSUE_TITLE//`/\\`}" | sed 's/"/\\"/g' | sed "s/'/\\\'/g") echo "${ISSUE_BODY//`/\\`}" | sed 's/"/\\"/g' | sed "s/'/\\\'/g" > $TMP_BODY echo -e "\n\n_Created from GitHub Action_ for $ISSUE_URL" >> $TMP_BODY BODY=$(cat "$TMP_BODY") PAYLOAD=$(jq -n \ --arg issuetitle "$TITLE" \ --arg issuebody "$BODY" \ '{ fields: { project: { key: "SNOW" }, issuetype: { name: 
"Bug" }, summary: $issuetitle, description: $issuebody, customfield_11401: { id: "14723" }, assignee: { id: "712020:e527ae71-55cc-4e02-9217-1ca4ca8028a2" }, components: [{ id: "19286" }], labels: ["oss"], priority: { id: "10001" } } }') # Create JIRA issue using REST API RESPONSE=$(curl -s -X POST \ -H "Content-Type: application/json" \ -H "Accept: application/json" \ -u "$JIRA_USER_EMAIL:$JIRA_API_TOKEN" \ "$JIRA_BASE_URL/rest/api/2/issue" \ -d "$PAYLOAD") # Extract JIRA issue key from response JIRA_KEY=$(echo "$RESPONSE" | jq -r '.key') if [ "$JIRA_KEY" = "null" ] || [ -z "$JIRA_KEY" ]; then echo "Failed to create JIRA issue" echo "Response: $RESPONSE" echo "Request payload: $PAYLOAD" exit 1 fi echo "Created JIRA issue: $JIRA_KEY" echo "jira_key=$JIRA_KEY" >> $GITHUB_OUTPUT - name: Update GitHub Issue env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} REPOSITORY: ${{ github.repository }} ISSUE_NUMBER: ${{ github.event.issue.number }} JIRA_KEY: ${{ steps.create.outputs.jira_key }} ISSUE_TITLE: ${{ github.event.issue.title }} run: | TITLE=$(echo "${ISSUE_TITLE//`/\\`}" | sed 's/"/\\"/g' | sed "s/'/\\\'/g") PAYLOAD=$(jq -n \ --arg issuetitle "$TITLE" \ --arg jirakey "$JIRA_KEY" \ '{ title: ($jirakey + ": " + $issuetitle) }') # Update Github issue title with jira id curl -s \ -X PATCH \ -H "Authorization: Bearer $GITHUB_TOKEN" \ -H "Accept: application/vnd.github+json" \ -H "X-GitHub-Api-Version: 2022-11-28" \ "https://api.github.com/repos/$REPOSITORY/issues/$ISSUE_NUMBER" \ -d "$PAYLOAD" if [ "$?" != 0 ]; then echo "Failed to update GH issue. 
Payload was:" echo "$PAYLOAD" exit 1 fi ================================================ FILE: .github/workflows/semgrep.yml ================================================ name: Run semgrep checks on: pull_request: branches: [main, master] permissions: contents: read jobs: run-semgrep-reusable-workflow: uses: snowflakedb/reusable-workflows/.github/workflows/semgrep-v2.yml@main secrets: token: ${{ secrets.SEMGREP_APP_TOKEN }} ================================================ FILE: .gitignore ================================================ *.DS_Store .idea/ .vscode/ parameters*.json parameters*.bat *.p8 coverage.txt fuzz-*/ /select1 /selectmany /verifycert wss-golang-agent.config wss-unified-agent.jar whitesource/ *.swp cp.out __debug_bin* test-output.txt test-report.junit.xml # exclude vendor vendor # SSH private key for WIF tests ci/wif/parameters/rsa_wif_aws_azure ci/wif/parameters/rsa_wif_gcp ================================================ FILE: .golangci.yml ================================================ version: "2" run: tests: true linters: exclusions: rules: - path: "_test.go" linters: - errcheck - path: "cmd/" linters: - errcheck - path: "_test.go" linters: - staticcheck text: "implement StmtQueryContext" - path: "_test.go" linters: - staticcheck text: "implement StmtExecContext" - linters: - staticcheck text: "QF1001" - linters: - staticcheck text: "SA1019: .+\\.(LoginTimeout|RequestTimeout|JWTExpireTimeout|ClientTimeout|JWTClientTimeout|ExternalBrowserTimeout|CloudStorageTimeout|Tracing) is deprecated" ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: git@github.com:snowflakedb/casec_precommit.git # SSH # - repo: https://github.com/snowflakedb/casec_precommit.git # HTTPS rev: v1.5 hooks: - id: snapps-secret-scanner ================================================ FILE: .windsurf/rules/go.md ================================================ --- trigger: glob 
description: globs: **/*.go --- # Go files rules ## General 1. Unless it's necessary or told otherwise, try reusing existing files, both for implementation and tests. 2. If possible, try running relevant tests. ## Tests 1. Create a test file with the name same as prod code file by default. 2. For assertions use our test helpers defined in assert_test.go. ## Logging 1. Add reasonable logging - don't repeat logs, but add them when it's meaningful. 2. Always consider log levels. ================================================ FILE: CHANGELOG.md ================================================ # Changelog ## Upcoming release Bug fixes: - Fixed empty `Account` when connecting with programmatic `Config` and `database/sql.Connector` by deriving `Account` from the first DNS label of `Host` in `FillMissingConfigParameters` when `Host` matches the Snowflake hostname pattern (snowflakedb/gosnowflake#1772). ## 2.0.1 Bug fixes: - Fixed default `CrlDownloadMaxSize` to be 20MB instead of 200MB, as the previous value was set too high and could cause out-of-memory issues (snowflakedb/gosnowflake#1735). - Replaced global `paramsMutex` with per-connection `syncParams` to encapsulate parameter synchronization and avoid cross-connection contention (snowflakedb/gosnowflake#1747). - `Config.Params` map is not modified anymore, to avoid changing parameter values across connections of the same connection pool (snowflakedb/gosnowflake#1747). - Set `BlobContentMD5` on Azure uploads so that multi-part uploads have the blob content-MD5 property populated (snowflakedb/gosnowflake#1757). - Fixed 403 errors from Google/GCP/GCS PUT queries on versioned stages (snowflakedb/gosnowflake#1760). - Fixed not updating query context cache for failed queries (snowflakedb/gosnowflake#1763). Internal changes: - Moved configuration to a dedicated internal package (snowflakedb/gosnowflake#1720). - Modernized Go syntax idioms throughout the codebase. 
- Added libc family, version and dynamic linking marker to client environment telemetry (snowflakedb/gosnowflake#1750). - Bumped a few libraries to fix vulnerabilities (snowflakedb/gosnowflake#1751, snowflakedb/gosnowflake#1756). - Depointerised query context cache in `snowflakeConn` (snowflakedb/gosnowflake#1763). ## 2.0.0 Breaking changes: - Removed `RaisePutGetError` from `SnowflakeFileTransferOptions` - current behaviour is aligned to always raise errors for PUT/GET operations (snowflakedb/gosnowflake#1690). - Removed `GetFileToStream` from `SnowflakeFileTransferOptions` - using `WithFileGetStream` automatically enables file streaming for GETs (snowflakedb/gosnowflake#1690). - Renamed `WithFileStream` to `WithFilePutStream` for consistency (snowflakedb/gosnowflake#1690). - `Array` function now returns error for unsupported types (snowflakedb/gosnowflake#1693). - `WithMultiStatement` does not return error anymore (snowflakedb/gosnowflake#1693). - `WithOriginalTimestamp` is removed, use `WithArrowBatchesTimestampOption(UseOriginalTimestamp)` instead (snowflakedb/gosnowflake#1693). - `WithMapValuesNullable` and `WithArrayValuesNullable` combined into one option `WithEmbeddedValuesNullable` (snowflakedb/gosnowflake#1693). - Hid streaming chunk downloader. It will be removed completely in the future (snowflakedb/gosnowflake#1696). - Maximum number of chunk download goroutines is now configured with `CLIENT_PREFETCH_THREADS` session parameter (snowflakedb/gosnowflake#1696) and default to 4. - Fixed typo in `GOSNOWFLAKE_SKIP_REGISTRATION` env variable (snowflakedb/gosnowflake#1696). - Removed `ClientIP` field from `Config` struct. This field was never used and is not needed for any functionality (snowflakedb/gosnowflake#1692). - Unexported MfaToken and IdToken (snowflakedb/gosnowflake#1692). - Removed `InsecureMode` field from `Config` struct. Use `DisableOCSPChecks` instead (snowflakedb/gosnowflake#1692). 
- Renamed `KeepSessionAlive` field in `Config` struct to `ServerSessionKeepAlive` to align with the remaining drivers (snowflakedb/gosnowflake#1692). - Removed `DisableTelemetry` field from `Config` struct. Use `CLIENT_TELEMETRY_ENABLED` session parameter instead (snowflakedb/gosnowflake#1692). - Removed stream chunk downloader. Use a regular, default downloader instead. (snowflakedb/gosnowflake#1702). - Removed `SnowflakeTransport`. Use `Config.Transporter` or simply register your own TLS config with `RegisterTLSConfig` if you just need a custom root certificates set (snowflakedb/gosnowflake#1703). - Arrow batches changes (snowflakedb/gosnowflake#1706): - Arrow batches have been extracted to a separate package. It should significantly drop the compilation size for those who don't need arrow batches (~34MB -> ~18MB). - Removed `GetArrowBatches` from `SnowflakeRows` and `SnowflakeResult`. Use `arrowbatches.GetArrowBatches(rows.(SnowflakeRows))` instead. - Migrated functions: - `sf.WithArrowBatchesTimestampOption` -> `arrowbatches.WithTimestampOption` - `sf.WithArrowBatchesUtf8Validation` -> `arrowbatches.WithUtf8Validation` - `sf.ArrowSnowflakeTimestampToTime` -> `arrowbatches.ArrowSnowflakeTimestampToTime` - Logging changes (snowflakedb/gosnowflake#1710): - Removed Logrus logger and migrated to slog. - Simplified `SFLogger` interface. - Added `SFSlogLogger` interface for setting custom slog handler. Bug fixes: - The query `context.Context` is now propagated to cloud storage operations for PUT and GET queries, allowing for better cancellation handling (snowflakedb/gosnowflake#1690). New features: - Added support for Go 1.26, dropped support for Go 1.23 (snowflakedb/gosnowflake#1707). - Added support for FIPS-only mode (snowflakedb/gosnowflake#1496). Bug fixes: - Added panic recovery block for stage file upload and download operations (snowflakedb/gosnowflake#1687). 
- Fixed WIF metadata request from Azure container, manifested with HTTP 400 error (snowflakedb/gosnowflake#1701). - Fixed SAML authentication port validation bypass in `isPrefixEqual` where the second URL's port was never checked (snowflakedb/gosnowflake#1712). - Fixed a race condition in OCSP cache clearer (snowflakedb/gosnowflake#1704). - The query `context.Context` is now propagated to cloud storage operations for PUT and GET queries, allowing for better cancellation handling (snowflakedb/gosnowflake#1690). - Fixed `tokenFilePath` DSN parameter triggering false validation error claiming both `token` and `tokenFilePath` were specified when only `tokenFilePath` was provided in the DSN string (snowflakedb/gosnowflake#1715). - Fixed minicore crash (SIGFPE) on fully statically linked Linux binaries by detecting static linking via ELF PT_INTERP inspection and skipping `dlopen` gracefully (snowflakedb/gosnowflake#1721). Internal changes: - Moved configuration to a dedicated internal package (snowflakedb/gosnowflake#1720). ## 1.19.0 New features: - Added ability to disable minicore loading at compile time (snowflakedb/gosnowflake#1679). - Exposed `tokenFilePath` in `Config` (snowflakedb/gosnowflake#1666). - `tokenFilePath` is now read for every new connection (snowflakedb/gosnowflake#1666). - Added support for identity impersonation when using workload identity federation (snowflakedb/gosnowflake#1652, snowflakedb/gosnowflake#1660). Bug fixes: - Fixed getting file from an unencrypted stage (snowflakedb/gosnowflake#1672). - Fixed minicore file name gathering in client environment (snowflakedb/gosnowflake#1661). - Fixed file descriptor leaks in cloud storage calls (snowflakedb/gosnowflake#1682) - Fixed path escaping for GCS urls (snowflakedb/gosnowflake#1678). Internal changes: - Improved Linux telemetry gathering (snowflakedb/gosnowflake#1677). - Improved some logs returned from cloud storage clients (snowflakedb/gosnowflake#1665). 
## 1.18.1 Bug fixes: - Handle HTTP307 & 308 in drivers to achieve better resiliency to backend errors (snowflakedb/gosnowflake#1616). - Create temp directory only if needed during file transfer (snowflakedb/gosnowflake#1647) - Fix unnecessary user expansion for file paths (snowflakedb/gosnowflake#1646). Internal changes: - Remove spammy "telemetry disabled" log messages (snowflakedb/gosnowflake#1638). - Introduced shared library ([source code](https://github.com/snowflakedb/universal-driver/tree/main/sf_mini_core)) for extended telemetry to identify and prepare testing platform for native rust extensions (snowflakedb/gosnowflake#1629) ## 1.18.0 New features: - Added validation of CRL `NextUpdate` for freshly downloaded CRLs (snowflakedb/gosnowflake#1617) - Exposed function to send arbitrary telemetry data (snowflakedb/gosnowflake#1627) - Added logging of query text and parameters (snowflakedb/gosnowflake#1625) Bug fixes: - Fixed a data race error in tests caused by platform_detection init() function (snowflakedb/gosnowflake#1618) - Make secrets detector initialization thread safe and more maintainable (snowflakedb/gosnowflake#1621) Internal changes: - Added ISA to login request telemetry (snowflakedb/gosnowflake#1620) ## 1.17.1 - Fix unsafe reflection of nil pointer on DECFLOAT func in bind uploader (snowflakedb/gosnowflake#1604). 
- Added temporary download files cleanup (snowflakedb/gosnowflake#1577) - Marked fields as deprecated (snowflakedb/gosnowflake#1556) - Exposed `QueryStatus` from `SnowflakeResult` and `SnowflakeRows` in `GetStatus()` function (snowflakedb/gosnowflake#1556) - Split timeout settings into separate groups based on target service types (snowflakedb/gosnowflake#1531) - Added small clarification in oauth.go example on token escaping (snowflakedb/gosnowflake#1574) - Ensured proper permissions for CRL cache directory (snowflakedb/gosnowflake#1588) - Added `CrlDownloadMaxSize` to limit the size of CRL downloads (snowflakedb/gosnowflake#1588) - Added platform telemetry to login requests. Can be disabled with `SNOWFLAKE_DISABLE_PLATFORM_DETECTION` environment variable (snowflakedb/gosnowflake#1601) - Bypassed proxy settings for WIF metadata requests (snowflakedb/gosnowflake#1593) - Fixed a bug where GCP PUT/GET operations would fail when the connection context was cancelled (snowflakedb/gosnowflake#1584) - Fixed nil pointer dereference while calling long-running queries (snowflakedb/gosnowflake#1592) (snowflakedb/gosnowflake#1596) - Moved keyring-based secure storage manager into separate file to avoid the need to initialize keyring on Linux. 
(snowflakedb/gosnowflake#1595) - Enabling official support for RHEL9 by testing and enabling CI/CD checks for Rocky Linux in CICD, (snowflakedb/gosnowflake#1597) - Improve logging (snowflakedb/gosnowflake#1570) ## 1.17.0 - Added ability to configure OCSP per connection (snowflakedb/gosnowflake#1528) - Added `DECFLOAT` support, see details in `doc.go` (snowflakedb/gosnowflake#1504, snowflakedb/gosnowflake#1506) - Added support for Go 1.25, dropped support for Go 1.22 (snowflakedb/gosnowflake#1544) - Added proxy options to connection parameters (snowflakedb/gosnowflake#1511) - Added `client_session_keep_alive_heartbeat_frequency` connection param (snowflakedb/gosnowflake#1576) - Added support for multi-part downloads for S3, Azure and GCP (snowflakedb/gosnowflake#1549) - Added `singleAuthenticationPrompt` to control whether only one authentication should be performed at the same time for authentications that need human interactions (like MFA or OAuth authorization code). Default is true. (snowflakedb/gosnowflake#1561) - Fixed missing `DisableTelemetry` option in connection parameters (snowflakedb/gosnowflake#1520) - Fixed multistatements in large result sets (snowflakedb/gosnowflake#1539, snowflakedb/gosnowflake#1543, snowflakedb/gosnowflake#1547) - Fixed unnecessary retries when context is cancelled (snowflakedb/gosnowflake#1540) - Fixed regression in TOML connection file (snowflakedb/gosnowflake#1530) ## Prior Releases Release notes available at https://docs.snowflake.com/en/release-notes/clients-drivers/golang ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing Guidelines ## Reporting Issues Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/snowflakedb/gosnowflake/issues?state=open) or was [recently closed](https://github.com/snowflakedb/gosnowflake/issues?direction=desc&page=1&sort=updated&state=closed). 
## Contributing Code By contributing to this project, you share your code under the Apache License 2, as specified in the LICENSE file. ### Code Review Everyone is invited to review and comment on pull requests. If it looks fine to you, comment with "LGTM" (Looks good to me). If changes are required, notice the reviewers with "PTAL" (Please take another look) after committing the fixes. Before merging the Pull Request, at least one Snowflake team member must have commented with "LGTM". ================================================ FILE: Jenkinsfile ================================================ @Library('pipeline-utils') import com.snowflake.DevEnvUtils import groovy.json.JsonOutput timestamps { node('high-memory-node') { stage('checkout') { scmInfo = checkout scm println("${scmInfo}") env.GIT_BRANCH = scmInfo.GIT_BRANCH env.GIT_COMMIT = scmInfo.GIT_COMMIT } params = [ string(name: 'svn_revision', value: 'temptest-deployed'), string(name: 'branch', value: 'main'), string(name: 'client_git_commit', value: scmInfo.GIT_COMMIT), string(name: 'client_git_branch', value: scmInfo.GIT_BRANCH), string(name: 'TARGET_DOCKER_TEST_IMAGE', value: 'go-chainguard-go1_24'), string(name: 'parent_job', value: env.JOB_NAME), string(name: 'parent_build_number', value: env.BUILD_NUMBER) ] stage('Authenticate Artifactory') { script { new DevEnvUtils().withSfCli { sh "sf artifact oci auth" } } } parallel( 'Test': { stage('Test') { build job: 'RT-LanguageGo-PC', parameters: params } }, 'Test Authentication': { stage('Test Authentication') { withCredentials([ string(credentialsId: 'sfctest0-parameters-secret', variable: 'PARAMETERS_SECRET') ]) { sh '''\ |#!/bin/bash -e |$WORKSPACE/ci/test_authentication.sh '''.stripMargin() } } }, 'Test WIF Auth': { stage('Test WIF Auth') { withCredentials([ string(credentialsId: 'sfctest0-parameters-secret', variable: 'PARAMETERS_SECRET'), ]) { sh '''\ |#!/bin/bash -e |$WORKSPACE/ci/test_wif.sh '''.stripMargin() } } }, 'Test Revocation Validation': { 
stage('Test Revocation Validation') { withCredentials([ usernamePassword(credentialsId: 'jenkins-snowflakedb-github-app', usernameVariable: 'GITHUB_USER', passwordVariable: 'GITHUB_TOKEN') ]) { try { sh '''\ |#!/bin/bash -e |chmod +x $WORKSPACE/ci/test_revocation.sh |$WORKSPACE/ci/test_revocation.sh '''.stripMargin() } finally { archiveArtifacts artifacts: 'revocation-results.json,revocation-report.html', allowEmptyArchive: true publishHTML(target: [ allowMissing: true, alwaysLinkToLastBuild: true, keepAll: true, reportDir: '.', reportFiles: 'revocation-report.html', reportName: 'Revocation Validation Report' ]) } } } } ) } } pipeline { agent { label 'high-memory-node' } options { timestamps() } environment { COMMIT_SHA_LONG = sh(returnStdout: true, script: "echo \$(git rev-parse " + "HEAD)").trim() // environment variables for semgrep_agent (for findings / analytics page) // remove .git at the end // remove SCM URL + .git at the end BASELINE_BRANCH = "${env.CHANGE_TARGET}" } stages { stage('Checkout') { steps { checkout scm } } } } def wgetUpdateGithub(String state, String folder, String targetUrl, String seconds) { def ghURL = "https://api.github.com/repos/snowflakedb/gosnowflake/statuses/$COMMIT_SHA_LONG" def data = JsonOutput.toJson([state: "${state}", context: "jenkins/${folder}",target_url: "${targetUrl}"]) sh "wget ${ghURL} --spider -q --header='Authorization: token $GIT_PASSWORD' --post-data='${data}'" } ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
================================================ FILE: Makefile ================================================ NAME:=gosnowflake VERSION:=$(shell git describe --tags --abbrev=0) REVISION:=$(shell git rev-parse --short HEAD) COVFLAGS:= ## Run fmt, lint and test all: fmt lint cov include gosnowflake.mak ## Run tests test_setup: test_teardown python3 ci/scripts/hang_webserver.py 12345 & test_teardown: pkill -9 hang_webserver || true test: deps test_setup ./ci/scripts/execute_tests.sh ## Run Coverage tests cov: make test COVFLAGS="-coverprofile=coverage.txt -covermode=atomic" ## Lint lint: clint ## Format source codes fmt: cfmt @for c in $$(ls cmd); do \ (cd cmd/$$c; make fmt); \ done ## Install sample programs install: for c in $$(ls cmd); do \ (cd cmd/$$c; GOBIN=$$GOPATH/bin go install $$c.go); \ done ## Build fuzz tests fuzz-build: for c in $$(ls | grep -E "fuzz-*"); do \ (cd $$c; make fuzz-build); \ done ## Run fuzz-dsn fuzz-dsn: (cd fuzz-dsn; go-fuzz -bin=./dsn-fuzz.zip -workdir=.) .PHONY: setup deps update test lint help fuzz-dsn ================================================ FILE: README.md ================================================ ## Migrating to v2 **Version 2.0.0 of the Go Snowflake Driver was released on March 3rd, 2026.** This major version includes breaking changes that require code updates when migrating from v1.x. ### Key Changes and Migration Steps #### 1. Update Import Paths Update your `go.mod` to use v2: ```sh go get -u github.com/snowflakedb/gosnowflake/v2 ``` Update imports in your code: ```go // Old (v1) import "github.com/snowflakedb/gosnowflake" // New (v2) import "github.com/snowflakedb/gosnowflake/v2" ``` #### 2. Arrow Batches Moved to Separate Package The public Arrow batches API now lives in `github.com/snowflakedb/gosnowflake/v2/arrowbatches`. Importing that sub-package pulls in the additional Arrow compute dependency only for applications that use Arrow batches directly. 
**Migration:** ```go import ( "context" "database/sql/driver" sf "github.com/snowflakedb/gosnowflake/v2" "github.com/snowflakedb/gosnowflake/v2/arrowbatches" ) ctx := arrowbatches.WithArrowBatches(context.Background()) var rows driver.Rows err := conn.Raw(func(x any) error { rows, err = x.(driver.QueryerContext).QueryContext(ctx, query, nil) return err }) if err != nil { // handle error } batches, err := arrowbatches.GetArrowBatches(rows.(sf.SnowflakeRows)) if err != nil { // handle error } ``` **Optional helper mapping:** - `sf.WithArrowBatchesTimestampOption` → `arrowbatches.WithTimestampOption` - `sf.WithArrowBatchesUtf8Validation` → `arrowbatches.WithUtf8Validation` - `sf.ArrowSnowflakeTimestampToTime` → `arrowbatches.ArrowSnowflakeTimestampToTime` - `sf.WithOriginalTimestamp` → `arrowbatches.WithTimestampOption(ctx, arrowbatches.UseOriginalTimestamp)` #### 3. Configuration Struct Changes **Renamed fields:** ```go // Old (v1) config := &gosnowflake.Config{ KeepSessionAlive: true, InsecureMode: true, DisableTelemetry: true, } // New (v2) config := &gosnowflake.Config{ ServerSessionKeepAlive: true, // Renamed for consistency with other drivers DisableOCSPChecks: true, // Replaces InsecureMode // DisableTelemetry removed - use CLIENT_TELEMETRY_ENABLED session parameter } ``` **Removed fields:** - `ClientIP` - No longer used - `MfaToken` and `IdToken` - Now unexported - `DisableTelemetry` - Use `CLIENT_TELEMETRY_ENABLED` session parameter instead #### 4. Logger Changes The built-in logger is now based on Go's standard `log/slog`: ```go logger := gosnowflake.GetLogger() _ = logger.SetLogLevel("debug") ``` For custom logging, continue implementing `SFLogger`. If you want to customize the built-in slog handler, type-assert `GetLogger()` to `SFSlogLogger` and call `SetHandler`. #### 5. 
File Transfer Changes **Configuration options:** ```go // Old (v1) options := &gosnowflake.SnowflakeFileTransferOptions{ RaisePutGetError: true, GetFileToStream: true, } ctx = gosnowflake.WithFileStream(ctx, stream) // New (v2) // RaisePutGetError removed - errors always raised // GetFileToStream removed - use WithFileGetStream instead ctx = gosnowflake.WithFilePutStream(ctx, stream) // Renamed from WithFileStream ctx = gosnowflake.WithFileGetStream(ctx, stream) // For GET operations ``` #### 6. Context and Function Changes ```go // Old (v1) ctx, err := gosnowflake.WithMultiStatement(ctx, 0) if err != nil { // handle error } // New (v2) ctx = gosnowflake.WithMultiStatement(ctx, 0) // No error returned ``` ```go // Old (v1) values := gosnowflake.Array(data) // New (v2) values, err := gosnowflake.Array(data) // Now returns error for unsupported types if err != nil { // handle error } ``` #### 7. Nullable Options Combined ```go // Old (v1) ctx = gosnowflake.WithMapValuesNullable(ctx) ctx = gosnowflake.WithArrayValuesNullable(ctx) // New (v2) ctx = gosnowflake.WithEmbeddedValuesNullable(ctx) // Handles both maps and arrays ``` #### 8. Session Parameter Changes **Chunk download workers:** ```go // Old (v1) gosnowflake.MaxChunkDownloadWorkers = 10 // Global variable // New (v2) // Configure via CLIENT_PREFETCH_THREADS session parameter. // NOTE: The default is 4. db.Exec("ALTER SESSION SET CLIENT_PREFETCH_THREADS = 10") ``` #### 9. Transport Configuration ```go import "crypto/tls" // Old (v1) gosnowflake.SnowflakeTransport = yourTransport // New (v2) config := &gosnowflake.Config{ Transporter: yourCustomTransport, } // Or, if you only need custom TLS settings/certificates: tlsConfig := &tls.Config{ // ... } _ = gosnowflake.RegisterTLSConfig("custom", tlsConfig) config.TLSConfigName = "custom" ``` #### 10. 
Environment Variable Fix

If you use the skip registration environment variable:

```sh
# Old (v1)
GOSNOWFLAKE_SKIP_REGISTERATION=true  # Note the typo

# New (v2)
GOSNOWFLAKE_SKIP_REGISTRATION=true   # Typo fixed
```

### Additional Resources

- Full list of changes: See [CHANGELOG.md](./CHANGELOG.md)
- Questions or issues: [GitHub Issues](https://github.com/snowflakedb/gosnowflake/issues)

## Support

For official support and urgent, production-impacting issues, please [contact Snowflake Support](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge).

# Go Snowflake Driver

Coverage

This topic provides instructions for installing, running, and modifying the Go Snowflake Driver. The driver supports Go's [database/sql](https://golang.org/pkg/database/sql/) package.

# Prerequisites

The following software packages are required to use the Go Snowflake Driver.

## Go

The latest driver requires the [Go language](https://golang.org/) 1.24 or higher. The supported operating systems are 64-bits Linux, Mac OS, and Windows, but you may run the driver on other platforms if the Go language works correctly on those platforms.

# Installation

If you don't have a project initialized, set it up.

```sh
go mod init example.com/snowflake
```

Get the Gosnowflake source code, if not installed.

```sh
go get -u github.com/snowflakedb/gosnowflake/v2
```

# Docs

For detailed documentation and basic usage examples, please see the documentation at [godoc.org](https://godoc.org/github.com/snowflakedb/gosnowflake/v2).

## Notes

This driver currently does not support GCP regional endpoints. Please ensure that any workloads using this driver do not require support for regional endpoints on GCP. If you have questions about this, please contact Snowflake Support.
The driver uses a Rust library called `sf_mini_core`; you can find its source code [here](https://github.com/snowflakedb/universal-driver/tree/main/sf_mini_core)

# Sample Programs

Snowflake provides a set of sample programs to test with. Set the environment variable ``$GOPATH`` to the top directory of your workspace, e.g., ``~/go`` and make certain to include ``$GOPATH/bin`` in the environment variable ``$PATH``. Run the ``make`` command to build all sample programs.

```sh
make install
```

In the following example, the program ``select1.go`` is built and installed in ``$GOPATH/bin`` and can be run from the command line:

```sh
SNOWFLAKE_TEST_ACCOUNT= \
SNOWFLAKE_TEST_USER= \
SNOWFLAKE_TEST_PASSWORD= \
select1
Congrats! You have successfully run SELECT 1 with Snowflake DB!
```

# Development

The developer notes are hosted with the source code on [GitHub](https://github.com/snowflakedb/gosnowflake).

## Testing Code

Set the Snowflake connection info in ``parameters.json``:

```json
{
    "testconnection": {
        "SNOWFLAKE_TEST_USER": "",
        "SNOWFLAKE_TEST_PASSWORD": "",
        "SNOWFLAKE_TEST_ACCOUNT": "",
        "SNOWFLAKE_TEST_WAREHOUSE": "",
        "SNOWFLAKE_TEST_DATABASE": "",
        "SNOWFLAKE_TEST_SCHEMA": "",
        "SNOWFLAKE_TEST_ROLE": "",
        "SNOWFLAKE_TEST_DEBUG": "false"
    }
}
```

Install [jq](https://stedolan.github.io/jq) so that the parameters can get parsed correctly, and run ``make test`` in your Go development environment:

```sh
make test
```

### Setting debug mode during tests

This is for debugging large SQL statements (greater than 300 characters). If you want to enable debug mode, set `SNOWFLAKE_TEST_DEBUG` to `true` in `parameters.json`, or export it in your shell instance.

## Customizing Logging Tags

If you would like to ensure that certain tags are always present in the logs, `RegisterClientLogContextHook` can be used in your init function. See example below.
```go import "github.com/snowflakedb/gosnowflake/v2" func init() { // each time the logger is used, the logs will contain a REQUEST_ID field with requestID the value extracted // from the context gosnowflake.RegisterClientLogContextHook("REQUEST_ID", func(ctx context.Context) interface{} { return requestIdFromContext(ctx) }) } ``` ## Setting Log Level If you want to change the log level, `SetLogLevel` can be used in your init function like this: ```go import "github.com/snowflakedb/gosnowflake/v2" func init() { // The following line changes the log level to debug _ = gosnowflake.GetLogger().SetLogLevel("debug") } ``` The following is a list of options you can pass in to set the level from least to most verbose: - `"OFF"` - `"fatal"` - `"error"` - `"warn"` - `"info"` - `"debug"` - `"trace"` ## Capturing Code Coverage Configure your testing environment as described above and run ``make cov``. The coverage percentage will be printed on the console when the testing completes. ```sh make cov ``` For more detailed analysis, results are printed to ``coverage.txt`` in the project directory. To read the coverage report, run: ```sh go tool cover -html=coverage.txt ``` ## Submitting Pull Requests You may use your preferred editor to edit the driver code. Make certain to run ``make fmt lint`` before submitting any pull request to Snowflake. This command formats your source code according to the standard Go style and detects any coding style issues. ================================================ FILE: SECURITY.md ================================================ # Security Policy Please refer to the Snowflake [HackerOne program](https://hackerone.com/snowflake?type=team) for our security policies and for reporting any security vulnerabilities. 
For other security related questions and concerns, please contact the Snowflake security team at security@snowflake.com

================================================
FILE: aaa_test.go
================================================
package gosnowflake

import (
	"testing"
)

// TestShowServerVersion is a connectivity smoke test: it runs
// SELECT CURRENT_VERSION() against a live connection and prints the
// server version it reports.
func TestShowServerVersion(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		rows := dbt.mustQuery("SELECT CURRENT_VERSION()")
		defer func() {
			assertNilF(t, rows.Close())
		}()
		var version string
		rows.Next()
		assertNilF(t, rows.Scan(&version))
		println(version)
	})
}

================================================
FILE: arrow_chunk.go
================================================
package gosnowflake

import (
	"bytes"
	"context"
	"encoding/base64"
	"github.com/snowflakedb/gosnowflake/v2/internal/query"
	"time"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/ipc"
	"github.com/apache/arrow-go/v18/arrow/memory"
)

// arrowResultChunk wraps the Arrow IPC reader for one result chunk together
// with the decoding state it needs (session location, allocator, and a
// running count of decoded rows).
type arrowResultChunk struct {
	reader    *ipc.Reader
	rowCount  int            // rows decoded so far by decodeArrowChunk
	loc       *time.Location // location used when decoding time-typed columns
	allocator memory.Allocator
}

// decodeArrowChunk drains the IPC reader and converts each record into
// row-major chunkRowType values using the Snowflake column metadata in
// rowType. highPrec selects high-precision decoding for fixed-point columns.
// The reader is released before returning.
func (arc *arrowResultChunk) decodeArrowChunk(ctx context.Context, rowType []query.ExecResponseRowType, highPrec bool, params *syncParams) ([]chunkRowType, error) {
	defer arc.reader.Release()
	logger.Debug("Arrow Decoder")
	var chunkRows []chunkRowType
	for arc.reader.Next() {
		record := arc.reader.Record()
		start := len(chunkRows)
		numRows := int(record.NumRows())
		logger.Debugf("rows in current record: %v", numRows)
		columns := record.Columns()
		// Grow the row-major output for this record, then allocate each
		// row's value slice before filling it column by column.
		chunkRows = append(chunkRows, make([]chunkRowType, numRows)...)
		for i := start; i < start+numRows; i++ {
			chunkRows[i].ArrowRow = make([]snowflakeValue, len(columns))
		}
		for colIdx, col := range columns {
			values := make([]snowflakeValue, numRows)
			if err := arrowToValues(ctx, values, rowType[colIdx], col, arc.loc, highPrec, params); err != nil {
				return nil, err
			}
			// Transpose the decoded column into the row-major result.
			for i := range values {
				chunkRows[start+i].ArrowRow[colIdx] = values[i]
			}
		}
		arc.rowCount += numRows
	}
	logger.Debugf("The number of chunk rows: %v", len(chunkRows))
	return chunkRows, arc.reader.Err()
}

// decodeArrowBatchRaw reads raw (untransformed) arrow records from the IPC reader.
// The records are not transformed with arrow-compute; the arrowbatches sub-package
// handles transformation when the user calls ArrowBatch.Fetch().
func (arc *arrowResultChunk) decodeArrowBatchRaw() (*[]arrow.Record, error) {
	var records []arrow.Record
	defer arc.reader.Release()
	for arc.reader.Next() {
		record := arc.reader.Record()
		record.Retain() // keep the record alive after the reader is released
		records = append(records, record)
	}
	return &records, arc.reader.Err()
}

// Build arrow chunk based on RowSet of base64
func buildFirstArrowChunk(rowsetBase64 string, loc *time.Location, alloc memory.Allocator) (arrowResultChunk, error) {
	rowSetBytes, err := base64.StdEncoding.DecodeString(rowsetBase64)
	if err != nil {
		return arrowResultChunk{}, err
	}
	rr, err := ipc.NewReader(bytes.NewReader(rowSetBytes), ipc.WithAllocator(alloc))
	if err != nil {
		return arrowResultChunk{}, err
	}
	return arrowResultChunk{rr, 0, loc, alloc}, nil
}

================================================
FILE: arrow_stream.go
================================================
package gosnowflake

import (
	"bufio"
	"bytes"
	"compress/gzip"
	"context"
	"encoding/base64"
	"fmt"
	"io"
	"maps"
	"net/http"
	"strconv"
	"time"

	"github.com/apache/arrow-go/v18/arrow/ipc"
	"github.com/snowflakedb/gosnowflake/v2/internal/query"
)

// ArrowStreamLoader is a convenience interface for downloading
// Snowflake results via multiple Arrow Record Batch streams.
//
// Some queries from Snowflake do not return Arrow data regardless
// of the settings, such as "SHOW WAREHOUSES". In these cases,
// you'll find TotalRows() > 0 but GetBatches returns no batches
// and no errors. In this case, the data is accessible via JSONData
// with the actual types matching up to the metadata in RowTypes.
type ArrowStreamLoader interface {
	GetBatches() ([]ArrowStreamBatch, error)
	NextResultSet(ctx context.Context) error
	TotalRows() int64
	RowTypes() []query.ExecResponseRowType
	Location() *time.Location
	JSONData() [][]*string
}

// ArrowStreamBatch is a type describing a potentially yet-to-be-downloaded
// Arrow IPC stream. Call GetStream to download and retrieve an io.Reader
// that can be used with ipc.NewReader to get record batch results.
type ArrowStreamBatch struct {
	idx     int   // index into scd.ChunkMetas for this batch
	numrows int64 // row count reported by the chunk metadata
	scd     *snowflakeArrowStreamChunkDownloader
	Loc     *time.Location
	rr      io.ReadCloser // lazily populated on first GetStream call
}

// NumRows returns the total number of rows that the metadata stated should
// be in this stream of record batches.
func (asb *ArrowStreamBatch) NumRows() int64 { return asb.numrows }

// GetStream returns a stream of bytes consisting of an Arrow IPC Record
// batch stream. Close should be called on the returned stream when done
// to ensure no leaked memory.
func (asb *ArrowStreamBatch) GetStream(ctx context.Context) (io.ReadCloser, error) {
	if asb.rr == nil {
		// First call: download the chunk and cache the resulting reader.
		if err := asb.downloadChunkStreamHelper(ctx); err != nil {
			return nil, err
		}
	}
	return asb.rr, nil
}

// streamWrapReader wraps an io.Reader so that Close closes the underlying body.
type streamWrapReader struct { io.Reader wrapped io.ReadCloser } func (w *streamWrapReader) Close() error { if cl, ok := w.Reader.(io.ReadCloser); ok { if err := cl.Close(); err != nil { return err } } return w.wrapped.Close() } func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) error { headers := make(map[string]string) if len(asb.scd.ChunkHeader) > 0 { maps.Copy(headers, asb.scd.ChunkHeader) } else { headers[headerSseCAlgorithm] = headerSseCAes headers[headerSseCKey] = asb.scd.Qrmk } resp, err := asb.scd.FuncGet(ctx, asb.scd.sc, asb.scd.ChunkMetas[asb.idx].URL, headers, asb.scd.sc.rest.RequestTimeout) if err != nil { return err } if resp.StatusCode != http.StatusOK { defer func() { _ = resp.Body.Close() }() b, err := io.ReadAll(resp.Body) if err != nil { return err } _ = b return &SnowflakeError{ Number: ErrFailedToGetChunk, SQLState: SQLStateConnectionFailure, Message: fmt.Sprintf("failed to get chunk. idx: %v", asb.idx), MessageArgs: []any{asb.idx}, } } defer func() { if asb.rr == nil { _ = resp.Body.Close() } }() bufStream := bufio.NewReader(resp.Body) gzipMagic, err := bufStream.Peek(2) if err != nil { return err } if gzipMagic[0] == 0x1f && gzipMagic[1] == 0x8b { bufStream0, err := gzip.NewReader(bufStream) if err != nil { return err } asb.rr = &streamWrapReader{Reader: bufStream0, wrapped: resp.Body} } else { asb.rr = &streamWrapReader{Reader: bufStream, wrapped: resp.Body} } return nil } type snowflakeArrowStreamChunkDownloader struct { sc *snowflakeConn ChunkMetas []query.ExecResponseChunk Total int64 Qrmk string ChunkHeader map[string]string FuncGet func(context.Context, *snowflakeConn, string, map[string]string, time.Duration) (*http.Response, error) RowSet rowSetType resultIDs []string } func (scd *snowflakeArrowStreamChunkDownloader) Location() *time.Location { if scd.sc != nil { return getCurrentLocation(&scd.sc.syncParams) } return nil } func (scd *snowflakeArrowStreamChunkDownloader) TotalRows() int64 { return scd.Total } 
// RowTypes returns the Snowflake column metadata for the current result set.
func (scd *snowflakeArrowStreamChunkDownloader) RowTypes() []query.ExecResponseRowType {
	return scd.RowSet.RowType
}

// JSONData returns the inline JSON rows for results that are not Arrow-encoded.
func (scd *snowflakeArrowStreamChunkDownloader) JSONData() [][]*string {
	return scd.RowSet.JSON
}

// maybeFirstBatch decodes the inline base64 row set (if any) and validates
// that it is a readable Arrow IPC stream. It returns nil bytes when there is
// no inline row set, and an error when the inline data is malformed.
func (scd *snowflakeArrowStreamChunkDownloader) maybeFirstBatch() ([]byte, error) {
	if scd.RowSet.RowSetBase64 == "" {
		return nil, nil
	}
	rowSetBytes, err := base64.StdEncoding.DecodeString(scd.RowSet.RowSetBase64)
	if err != nil {
		logger.Warnf("skipping first batch as it is not a valid base64 response. %v", err)
		return nil, err
	}
	// Validate the bytes by constructing (and immediately releasing) an IPC reader.
	rr, err := ipc.NewReader(bytes.NewReader(rowSetBytes))
	if err != nil {
		logger.Warnf("skipping first batch as it is not a valid IPC stream. %v", err)
		return nil, err
	}
	rr.Release()
	return rowSetBytes, nil
}

// GetBatches returns one ArrowStreamBatch per remote chunk, preceded by a
// batch for the inline first row set when present. Only the inline batch has
// its stream pre-populated; remote batches download lazily via GetStream.
func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamBatch, err error) {
	chunkMetaLen := len(scd.ChunkMetas)
	loc := scd.Location()
	// Reserve one extra capacity slot in case an inline first batch exists.
	out = make([]ArrowStreamBatch, chunkMetaLen, chunkMetaLen+1)
	toFill := out
	rowSetBytes, err := scd.maybeFirstBatch()
	if err != nil {
		return nil, err
	}
	if len(rowSetBytes) > 0 {
		// Prepend the inline batch; remote chunks fill the remaining slots.
		out = out[:chunkMetaLen+1]
		out[0] = ArrowStreamBatch{
			scd: scd,
			Loc: loc,
			rr:  io.NopCloser(bytes.NewReader(rowSetBytes)),
		}
		toFill = out[1:]
	}
	var totalCounted int64
	for i := range toFill {
		toFill[i] = ArrowStreamBatch{
			idx:     i,
			numrows: int64(scd.ChunkMetas[i].RowCount),
			Loc:     loc,
			scd:     scd,
		}
		totalCounted += int64(scd.ChunkMetas[i].RowCount)
	}
	if len(rowSetBytes) > 0 {
		// The inline batch accounts for the rows not covered by chunk metadata.
		out[0].numrows = scd.Total - totalCounted
	}
	return
}

// NextResultSet advances to the next pending result ID, refreshing the chunk
// metadata and row set from the query result endpoint. It returns io.EOF when
// no further result sets remain.
func (scd *snowflakeArrowStreamChunkDownloader) NextResultSet(ctx context.Context) error {
	if !scd.hasNextResultSet() {
		return io.EOF
	}
	// Pop the next result ID from the queue.
	resultID := scd.resultIDs[0]
	scd.resultIDs = scd.resultIDs[1:]
	resultPath := fmt.Sprintf(urlQueriesResultFmt, resultID)
	resp, err := scd.sc.getQueryResultResp(ctx, resultPath)
	if err != nil {
		return err
	}
	if !resp.Success {
		code, err := strconv.Atoi(resp.Code)
		if err != nil {
			logger.WithContext(ctx).Errorf("error while parsing code: %v", err)
		}
		return exceptionTelemetry(&SnowflakeError{
			Number:   code,
			SQLState: resp.Data.SQLState,
			Message:  resp.Message,
			QueryID:  resp.Data.QueryID,
		}, scd.sc)
	}
	// Swap in the new result set's chunk metadata and inline row set.
	scd.ChunkMetas = resp.Data.Chunks
	scd.Total = resp.Data.Total
	scd.Qrmk = resp.Data.Qrmk
	scd.ChunkHeader = resp.Data.ChunkHeaders
	scd.RowSet = rowSetType{
		RowType:      resp.Data.RowType,
		JSON:         resp.Data.RowSet,
		RowSetBase64: resp.Data.RowSetBase64,
	}
	return nil
}

// hasNextResultSet reports whether more result IDs remain to be fetched.
func (scd *snowflakeArrowStreamChunkDownloader) hasNextResultSet() bool {
	return len(scd.resultIDs) > 0
}

================================================
FILE: arrow_test.go
================================================
package gosnowflake

import (
	"bytes"
	"context"
	"fmt"
	"math/big"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/apache/arrow-go/v18/arrow/memory"
	ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"

	"database/sql/driver"
)

func TestArrowBatchDataProvider(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		ctx := ia.EnableArrowBatches(context.Background())
		query := "select '0.1':: DECIMAL(38, 19) as c"
		var rows driver.Rows
		var err error
		err = dbt.conn.Raw(func(x any) error {
			queryer, implementsQueryContext := x.(driver.QueryerContext)
			assertTrueF(t, implementsQueryContext, "snowflake connection driver does not implement queryerContext")
			rows, err = queryer.QueryContext(ctx, query, nil)
			return err
		})
		assertNilF(t, err, "error running select query")
		sfRows, isSfRows := rows.(SnowflakeRows)
		assertTrueF(t, isSfRows, "rows should be snowflakeRows")
		provider, isProvider := sfRows.(ia.BatchDataProvider)
		assertTrueF(t, isProvider, "rows should implement BatchDataProvider")
		info, err := provider.GetArrowBatches()
		assertNilF(t, err, "error getting arrow batch data")
		assertNotEqualF(t, len(info.Batches), 0, "should have at least one batch")
		// Verify raw records are available for the first batch
		batch := info.Batches[0]
		assertNotNilF(t, batch.Records, "first batch should have pre-decoded records")
		records := *batch.Records
		assertNotEqualF(t,
len(records), 0, "should have at least one record") // Verify column 0 has data (raw decimal value) strVal := records[0].Column(0).ValueStr(0) assertTrueF(t, len(strVal) > 0, fmt.Sprintf("column should have a value, got: %s", strVal)) }) } func TestArrowBigInt(t *testing.T) { runDBTest(t, func(dbt *DBTest) { testcases := []struct { num string prec int sc int }{ {"10000000000000000000000000000000000000", 38, 0}, {"-10000000000000000000000000000000000000", 38, 0}, {"12345678901234567890123456789012345678", 38, 0}, // #pragma: allowlist secret {"-12345678901234567890123456789012345678", 38, 0}, {"99999999999999999999999999999999999999", 38, 0}, {"-99999999999999999999999999999999999999", 38, 0}, } for _, tc := range testcases { rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()), fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) if !rows.Next() { dbt.Error("failed to query") } defer rows.Close() var v *big.Int if err := rows.Scan(&v); err != nil { dbt.Errorf("failed to scan. %#v", err) } b, ok := new(big.Int).SetString(tc.num, 10) if !ok { dbt.Errorf("failed to convert %v big.Int.", tc.num) } if v.Cmp(b) != 0 { dbt.Errorf("big.Int value mismatch: expected %v, got %v", b, v) } } }) } func TestArrowBigFloat(t *testing.T) { runDBTest(t, func(dbt *DBTest) { testcases := []struct { num string prec int sc int }{ {"1.23", 30, 2}, {"1.0000000000000000000000000000000000000", 38, 37}, {"-1.0000000000000000000000000000000000000", 38, 37}, {"1.2345678901234567890123456789012345678", 38, 37}, {"-1.2345678901234567890123456789012345678", 38, 37}, {"9.9999999999999999999999999999999999999", 38, 37}, {"-9.9999999999999999999999999999999999999", 38, 37}, } for _, tc := range testcases { rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()), fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) if !rows.Next() { dbt.Error("failed to query") } defer rows.Close() var v *big.Float if err := rows.Scan(&v); err != nil { dbt.Errorf("failed to scan. 
%#v", err) } prec := v.Prec() b, ok := new(big.Float).SetPrec(prec).SetString(tc.num) if !ok { dbt.Errorf("failed to convert %v to big.Float.", tc.num) } if v.Cmp(b) != 0 { dbt.Errorf("big.Float value mismatch: expected %v, got %v", b, v) } } }) } func TestArrowIntPrecision(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.mustExec(forceJSON) intTestcases := []struct { num string prec int sc int }{ {"10000000000000000000000000000000000000", 38, 0}, {"-10000000000000000000000000000000000000", 38, 0}, {"12345678901234567890123456789012345678", 38, 0}, // pragma: allowlist secret {"-12345678901234567890123456789012345678", 38, 0}, {"99999999999999999999999999999999999999", 38, 0}, {"-99999999999999999999999999999999999999", 38, 0}, } t.Run("arrow_disabled_scan_int64", func(t *testing.T) { for _, tc := range intTestcases { rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) defer rows.Close() if !rows.Next() { t.Error("failed to query") } var v int64 if err := rows.Scan(&v); err == nil { t.Error("should fail to scan") } } }) t.Run("arrow_disabled_scan_string", func(t *testing.T) { for _, tc := range intTestcases { rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) defer rows.Close() if !rows.Next() { t.Error("failed to query") } var v string if err := rows.Scan(&v); err != nil { t.Errorf("failed to scan. %#v", err) } if v != tc.num { t.Errorf("string value mismatch: expected %v, got %v", tc.num, v) } } }) dbt.mustExec(forceARROW) t.Run("arrow_enabled_scan_big_int", func(t *testing.T) { for _, tc := range intTestcases { rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) defer rows.Close() if !rows.Next() { t.Error("failed to query") } var v string if err := rows.Scan(&v); err != nil { t.Errorf("failed to scan. 
%#v", err) } if !strings.EqualFold(v, tc.num) { t.Errorf("int value mismatch: expected %v, got %v", tc.num, v) } } }) t.Run("arrow_high_precision_enabled_scan_big_int", func(t *testing.T) { for _, tc := range intTestcases { rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()), fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) defer rows.Close() if !rows.Next() { t.Error("failed to query") } var v *big.Int if err := rows.Scan(&v); err != nil { t.Errorf("failed to scan. %#v", err) } b, ok := new(big.Int).SetString(tc.num, 10) if !ok { t.Errorf("failed to convert %v big.Int.", tc.num) } if v.Cmp(b) != 0 { t.Errorf("big.Int value mismatch: expected %v, got %v", b, v) } } }) }) } // TestArrowFloatPrecision tests the different variable types allowed in the // rows.Scan() method. Note that for lower precision types we do not attempt // to check the value as precision could be lost. func TestArrowFloatPrecision(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.mustExec(forceJSON) fltTestcases := []struct { num string prec int sc int }{ {"1.23", 30, 2}, {"1.0000000000000000000000000000000000000", 38, 37}, {"-1.0000000000000000000000000000000000000", 38, 37}, {"1.2345678901234567890123456789012345678", 38, 37}, {"-1.2345678901234567890123456789012345678", 38, 37}, {"9.9999999999999999999999999999999999999", 38, 37}, {"-9.9999999999999999999999999999999999999", 38, 37}, } t.Run("arrow_disabled_scan_float64", func(t *testing.T) { for _, tc := range fltTestcases { rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) defer rows.Close() if !rows.Next() { t.Error("failed to query") } var v float64 if err := rows.Scan(&v); err != nil { t.Errorf("failed to scan. 
%#v", err) } } }) t.Run("arrow_disabled_scan_float32", func(t *testing.T) { for _, tc := range fltTestcases { rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) defer rows.Close() if !rows.Next() { t.Error("failed to query") } var v float32 if err := rows.Scan(&v); err != nil { t.Errorf("failed to scan. %#v", err) } } }) t.Run("arrow_disabled_scan_string", func(t *testing.T) { for _, tc := range fltTestcases { rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) defer rows.Close() if !rows.Next() { t.Error("failed to query") } var v string if err := rows.Scan(&v); err != nil { t.Errorf("failed to scan. %#v", err) } if !strings.EqualFold(v, tc.num) { t.Errorf("int value mismatch: expected %v, got %v", tc.num, v) } } }) dbt.mustExec(forceARROW) t.Run("arrow_enabled_scan_float64", func(t *testing.T) { for _, tc := range fltTestcases { rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) defer rows.Close() if !rows.Next() { t.Error("failed to query") } var v float64 if err := rows.Scan(&v); err != nil { t.Errorf("failed to scan. %#v", err) } } }) t.Run("arrow_enabled_scan_float32", func(t *testing.T) { for _, tc := range fltTestcases { rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) defer rows.Close() if !rows.Next() { t.Error("failed to query") } var v float32 if err := rows.Scan(&v); err != nil { t.Errorf("failed to scan. %#v", err) } } }) t.Run("arrow_enabled_scan_string", func(t *testing.T) { for _, tc := range fltTestcases { rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) defer rows.Close() if !rows.Next() { t.Error("failed to query") } var v string if err := rows.Scan(&v); err != nil { t.Errorf("failed to scan. 
%#v", err) } if v != tc.num { t.Errorf("string value mismatch: expected %v, got %v", tc.num, v) } } }) t.Run("arrow_high_precision_enabled_scan_big_float", func(t *testing.T) { for _, tc := range fltTestcases { rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()), fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) defer rows.Close() if !rows.Next() { t.Error("failed to query") } var v *big.Float if err := rows.Scan(&v); err != nil { t.Errorf("failed to scan. %#v", err) } prec := v.Prec() b, ok := new(big.Float).SetPrec(prec).SetString(tc.num) if !ok { t.Errorf("failed to convert %v to big.Float.", tc.num) } if v.Cmp(b) != 0 { t.Errorf("big.Float value mismatch: expected %v, got %v", b, v) } } }) }) } func TestArrowTimePrecision(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE t (col5 TIME(5), col6 TIME(6), col7 TIME(7), col8 TIME(8));") defer dbt.mustExec("DROP TABLE IF EXISTS t") dbt.mustExec("INSERT INTO t VALUES ('23:59:59.99999', '23:59:59.999999', '23:59:59.9999999', '23:59:59.99999999');") rows := dbt.mustQuery("select * from t") defer rows.Close() var c5, c6, c7, c8 time.Time for rows.Next() { if err := rows.Scan(&c5, &c6, &c7, &c8); err != nil { t.Errorf("values were not scanned: %v", err) } } nano := 999999990 expected := time.Time{}.Add(23*time.Hour + 59*time.Minute + 59*time.Second + 99*time.Millisecond) if c8.Unix() != expected.Unix() || c8.Nanosecond() != nano { t.Errorf("the value did not match. expected: %v, got: %v", expected, c8) } if c7.Unix() != expected.Unix() || c7.Nanosecond() != nano-(nano%1e2) { t.Errorf("the value did not match. expected: %v, got: %v", expected, c7) } if c6.Unix() != expected.Unix() || c6.Nanosecond() != nano-(nano%1e3) { t.Errorf("the value did not match. expected: %v, got: %v", expected, c6) } if c5.Unix() != expected.Unix() || c5.Nanosecond() != nano-(nano%1e4) { t.Errorf("the value did not match. 
expected: %v, got: %v", expected, c5) } dbt.mustExec(`CREATE TABLE t_ntz ( col1 TIMESTAMP_NTZ(1), col2 TIMESTAMP_NTZ(2), col3 TIMESTAMP_NTZ(3), col4 TIMESTAMP_NTZ(4), col5 TIMESTAMP_NTZ(5), col6 TIMESTAMP_NTZ(6), col7 TIMESTAMP_NTZ(7), col8 TIMESTAMP_NTZ(8) );`) defer dbt.mustExec("DROP TABLE IF EXISTS t_ntz") dbt.mustExec(`INSERT INTO t_ntz VALUES ( '9999-12-31T23:59:59.9', '9999-12-31T23:59:59.99', '9999-12-31T23:59:59.999', '9999-12-31T23:59:59.9999', '9999-12-31T23:59:59.99999', '9999-12-31T23:59:59.999999', '9999-12-31T23:59:59.9999999', '9999-12-31T23:59:59.99999999' );`) rows2 := dbt.mustQuery("select * from t_ntz") defer rows2.Close() var c1, c2, c3, c4 time.Time for rows2.Next() { if err := rows2.Scan(&c1, &c2, &c3, &c4, &c5, &c6, &c7, &c8); err != nil { t.Errorf("values were not scanned: %v", err) } } expected = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) if c8.Unix() != expected.Unix() || c8.Nanosecond() != nano { t.Errorf("the value did not match. expected: %v, got: %v", expected, c8) } if c7.Unix() != expected.Unix() || c7.Nanosecond() != nano-(nano%1e2) { t.Errorf("the value did not match. expected: %v, got: %v", expected, c7) } if c6.Unix() != expected.Unix() || c6.Nanosecond() != nano-(nano%1e3) { t.Errorf("the value did not match. expected: %v, got: %v", expected, c6) } if c5.Unix() != expected.Unix() || c5.Nanosecond() != nano-(nano%1e4) { t.Errorf("the value did not match. expected: %v, got: %v", expected, c5) } if c4.Unix() != expected.Unix() || c4.Nanosecond() != nano-(nano%1e5) { t.Errorf("the value did not match. expected: %v, got: %v", expected, c4) } if c3.Unix() != expected.Unix() || c3.Nanosecond() != nano-(nano%1e6) { t.Errorf("the value did not match. expected: %v, got: %v", expected, c3) } if c2.Unix() != expected.Unix() || c2.Nanosecond() != nano-(nano%1e7) { t.Errorf("the value did not match. 
expected: %v, got: %v", expected, c2) } if c1.Unix() != expected.Unix() || c1.Nanosecond() != nano-(nano%1e8) { t.Errorf("the value did not match. expected: %v, got: %v", expected, c1) } }) } func TestArrowVariousTypes(t *testing.T) { runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext( WithHigherPrecision(context.Background()), selectVariousTypes) defer rows.Close() if !rows.Next() { dbt.Error("failed to query") } cc, err := rows.Columns() if err != nil { dbt.Errorf("columns: %v", cc) } ct, err := rows.ColumnTypes() if err != nil { dbt.Errorf("column types: %v", ct) } var v1 *big.Float var v2, v2a int var v3 string var v4 float64 var v5 []byte var v6 bool if err = rows.Scan(&v1, &v2, &v2a, &v3, &v4, &v5, &v6); err != nil { dbt.Errorf("failed to scan: %#v", err) } if v1.Cmp(big.NewFloat(1.0)) != 0 { dbt.Errorf("failed to scan. %#v", *v1) } if ct[0].Name() != "C1" || ct[1].Name() != "C2" || ct[2].Name() != "C2A" || ct[3].Name() != "C3" || ct[4].Name() != "C4" || ct[5].Name() != "C5" || ct[6].Name() != "C6" { dbt.Errorf("failed to get column names: %#v", ct) } if ct[0].ScanType() != reflect.TypeFor[*big.Float]() { dbt.Errorf("failed to get scan type. expected: %v, got: %v", reflect.TypeFor[float64](), ct[0].ScanType()) } if ct[1].ScanType() != reflect.TypeFor[int64]() { dbt.Errorf("failed to get scan type. expected: %v, got: %v", reflect.TypeFor[int64](), ct[1].ScanType()) } if ct[2].ScanType() != reflect.TypeFor[*big.Int]() { dbt.Errorf("failed to get scan type. expected: %v, got: %v", reflect.TypeFor[*big.Int](), ct[2].ScanType()) } var pr, sc int64 var cLen int64 pr, sc = dbt.mustDecimalSize(ct[0]) if pr != 30 || sc != 2 { dbt.Errorf("failed to get precision and scale. %#v", ct[0]) } dbt.mustFailLength(ct[0]) if canNull := dbt.mustNullable(ct[0]); canNull { dbt.Errorf("failed to get nullable. %#v", ct[0]) } if cLen != 0 { dbt.Errorf("failed to get length. %#v", ct[0]) } if v2 != 2 { dbt.Errorf("failed to scan. 
%#v", v2) } pr, sc = dbt.mustDecimalSize(ct[1]) if pr != 18 || sc != 0 { dbt.Errorf("failed to get precision and scale. %#v", ct[1]) } dbt.mustFailLength(ct[1]) if canNull := dbt.mustNullable(ct[1]); canNull { dbt.Errorf("failed to get nullable. %#v", ct[1]) } if v2a != 22 { dbt.Errorf("failed to scan. %#v", v2a) } dbt.mustFailLength(ct[2]) if canNull := dbt.mustNullable(ct[2]); canNull { dbt.Errorf("failed to get nullable. %#v", ct[2]) } if v3 != "t3" { dbt.Errorf("failed to scan. %#v", v3) } dbt.mustFailDecimalSize(ct[3]) if cLen = dbt.mustLength(ct[3]); cLen != 2 { dbt.Errorf("failed to get length. %#v", ct[3]) } if canNull := dbt.mustNullable(ct[3]); canNull { dbt.Errorf("failed to get nullable. %#v", ct[3]) } if v4 != 4.2 { dbt.Errorf("failed to scan. %#v", v4) } dbt.mustFailDecimalSize(ct[4]) dbt.mustFailLength(ct[4]) if canNull := dbt.mustNullable(ct[4]); canNull { dbt.Errorf("failed to get nullable. %#v", ct[4]) } if !bytes.Equal(v5, []byte{0xab, 0xcd}) { dbt.Errorf("failed to scan. %#v", v5) } dbt.mustFailDecimalSize(ct[5]) if cLen = dbt.mustLength(ct[5]); cLen != 8388608 { // BINARY dbt.Errorf("failed to get length. %#v", ct[5]) } if canNull := dbt.mustNullable(ct[5]); canNull { dbt.Errorf("failed to get nullable. %#v", ct[5]) } if !v6 { dbt.Errorf("failed to scan. 
%#v", v6) } dbt.mustFailDecimalSize(ct[6]) dbt.mustFailLength(ct[6]) }) } func TestArrowMemoryCleanedUp(t *testing.T) { mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) defer mem.AssertSize(t, 0) runDBTest(t, func(dbt *DBTest) { ctx := WithArrowAllocator( context.Background(), mem, ) rows := dbt.mustQueryContext(ctx, "select 1 UNION select 2 ORDER BY 1") defer rows.Close() var v int assertTrueF(t, rows.Next()) assertNilF(t, rows.Scan(&v)) assertEqualE(t, v, 1) assertTrueF(t, rows.Next()) assertNilF(t, rows.Scan(&v)) assertEqualE(t, v, 2) assertFalseE(t, rows.Next()) }) } ================================================ FILE: arrowbatches/batches.go ================================================ package arrowbatches import ( "cmp" "context" "github.com/snowflakedb/gosnowflake/v2/internal/query" "github.com/snowflakedb/gosnowflake/v2/internal/types" "time" sf "github.com/snowflakedb/gosnowflake/v2" ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/memory" ) // ArrowBatch represents a chunk of data retrievable in arrow.Record format. type ArrowBatch struct { raw ia.BatchRaw rowTypes []query.ExecResponseRowType allocator memory.Allocator ctx context.Context } // WithContext sets the context for subsequent Fetch calls on this batch. func (rb *ArrowBatch) WithContext(ctx context.Context) *ArrowBatch { rb.ctx = ctx return rb } // Fetch returns an array of arrow.Record representing this batch's data. // Records are transformed from Snowflake's internal format to standard Arrow types. 
func (rb *ArrowBatch) Fetch() (*[]arrow.Record, error) {
	var rawRecords *[]arrow.Record
	// Use the batch's stored context (set via WithContext) if present.
	ctx := cmp.Or(rb.ctx, context.Background())
	if rb.raw.Records != nil {
		// Raw records were already materialized (e.g. by a prior download).
		rawRecords = rb.raw.Records
	} else if rb.raw.Download != nil {
		// Lazily download the chunk; cache the raw records and row count so a
		// failed transform below can still release them.
		recs, rowCount, err := rb.raw.Download(ctx)
		if err != nil {
			return nil, err
		}
		rawRecords = recs
		rb.raw.Records = recs
		rb.raw.RowCount = rowCount
	}
	// No raw data at all (nil Download and nil cache) yields an empty, non-nil slice.
	if rawRecords == nil || len(*rawRecords) == 0 {
		empty := make([]arrow.Record, 0)
		return &empty, nil
	}
	var transformed []arrow.Record
	for i, rec := range *rawRecords {
		newRec, err := arrowToRecord(ctx, rec, rb.allocator, rb.rowTypes, rb.raw.Location)
		if err != nil {
			// On failure, release everything we own: all already-transformed
			// records plus the not-yet-consumed raw records from index i on
			// (rec itself has not been released yet and is included in [i:]).
			for _, t := range transformed {
				t.Release()
			}
			for _, r := range (*rawRecords)[i:] {
				r.Release()
			}
			rb.raw.Records = nil
			return nil, err
		}
		transformed = append(transformed, newRec)
		// Each raw record is consumed (released) as soon as it is transformed.
		rec.Release()
	}
	// Ownership of the transformed records passes to the caller; the raw cache
	// is cleared so a second Fetch does not double-release.
	rb.raw.Records = nil
	rb.raw.RowCount = countArrowBatchRows(&transformed)
	return &transformed, nil
}

// GetRowCount returns the number of rows in this batch.
func (rb *ArrowBatch) GetRowCount() int {
	return rb.raw.RowCount
}

// GetLocation returns the timezone location for this batch.
func (rb *ArrowBatch) GetLocation() *time.Location {
	return rb.raw.Location
}

// GetRowTypes returns the column metadata for this batch.
func (rb *ArrowBatch) GetRowTypes() []query.ExecResponseRowType {
	return rb.rowTypes
}

// ArrowSnowflakeTimestampToTime converts an original Snowflake timestamp to time.Time.
// colIdx selects the column (its scale and DB type are read from rowTypes);
// recIdx selects the row within rec.
func (rb *ArrowBatch) ArrowSnowflakeTimestampToTime(rec arrow.Record, colIdx int, recIdx int) *time.Time {
	scale := int(rb.rowTypes[colIdx].Scale)
	dbType := rb.rowTypes[colIdx].Type
	return ArrowSnowflakeTimestampToTime(rec.Column(colIdx), types.GetSnowflakeType(dbType), scale, recIdx, rb.raw.Location)
}

// GetArrowBatches retrieves arrow batches from SnowflakeRows.
// The rows must have been queried with arrowbatches.WithArrowBatches(ctx).
func GetArrowBatches(rows sf.SnowflakeRows) ([]*ArrowBatch, error) {
	// The driver's rows must expose the internal batch provider; otherwise the
	// query was not run with arrow batches enabled (or returned JSON).
	provider, ok := rows.(ia.BatchDataProvider)
	if !ok {
		return nil, &sf.SnowflakeError{
			Number:  sf.ErrNotImplemented,
			Message: "rows do not support arrow batch data",
		}
	}
	info, err := provider.GetArrowBatches()
	if err != nil {
		return nil, err
	}
	// Wrap each raw batch; all batches share the row types, allocator, and
	// context captured at query time.
	batches := make([]*ArrowBatch, len(info.Batches))
	for i, raw := range info.Batches {
		batches[i] = &ArrowBatch{
			raw:       raw,
			rowTypes:  info.RowTypes,
			allocator: info.Allocator,
			ctx:       info.Ctx,
		}
	}
	return batches, nil
}

// countArrowBatchRows sums NumRows over all records in recs.
func countArrowBatchRows(recs *[]arrow.Record) (cnt int) {
	for _, r := range *recs {
		cnt += int(r.NumRows())
	}
	return
}

// GetAllocator returns the memory allocator for this batch.
func (rb *ArrowBatch) GetAllocator() memory.Allocator {
	return rb.allocator
}

================================================
FILE: arrowbatches/batches_test.go
================================================
package arrowbatches

import (
	"context"
	"crypto/rsa"
	"crypto/x509"
	"database/sql"
	"database/sql/driver"
	"encoding/pem"
	"errors"
	"fmt"
	"math"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/memory"
	sf "github.com/snowflakedb/gosnowflake/v2"
	ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"
)

// testConn holds a reusable database connection for running multiple queries.
type testConn struct {
	db   *sql.DB
	conn *sql.Conn
}

// repoRoot walks up from the current working directory to find the directory
// containing go.mod, which is the repository root.
func repoRoot(t *testing.T) string { t.Helper() dir, err := os.Getwd() if err != nil { t.Fatalf("failed to get working directory: %v", err) } for { if _, err = os.Stat(filepath.Join(dir, "go.mod")); err == nil { return dir } if !os.IsNotExist(err) { t.Fatalf("failed to stat go.mod in %q: %v", dir, err) } parent := filepath.Dir(dir) if parent == dir { t.Fatal("could not find repository root (no go.mod found)") } dir = parent } } // readPrivateKey reads an RSA private key from a PEM file. If the path is // relative it is resolved against the repository root so that tests in // sub-packages work with repo-root-relative paths. func readPrivateKey(t *testing.T, path string) *rsa.PrivateKey { t.Helper() if !filepath.IsAbs(path) { path = filepath.Join(repoRoot(t), path) } data, err := os.ReadFile(path) if err != nil { t.Fatalf("failed to read private key file %q: %v", path, err) } block, _ := pem.Decode(data) if block == nil { t.Fatalf("failed to decode PEM block from %q", path) } key, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { t.Fatalf("failed to parse private key from %q: %v", path, err) } rsaKey, ok := key.(*rsa.PrivateKey) if !ok { t.Fatalf("private key in %q is not RSA (got %T)", path, key) } return rsaKey } func testConfig(t *testing.T) *sf.Config { t.Helper() configParams := []*sf.ConfigParam{ {Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true}, {Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true}, {Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false}, {Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false}, {Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false}, {Name: "Warehouse", EnvName: "SNOWFLAKE_TEST_WAREHOUSE", FailOnMissing: false}, } isJWT := os.Getenv("SNOWFLAKE_TEST_AUTHENTICATOR") == "SNOWFLAKE_JWT" if !isJWT { configParams = append(configParams, &sf.ConfigParam{Name: "Password", EnvName: "SNOWFLAKE_TEST_PASSWORD", FailOnMissing: true}, ) } cfg, 
err := sf.GetConfigFromEnv(configParams) if err != nil { t.Fatalf("failed to get config from environment: %v", err) } if isJWT { privKeyPath := os.Getenv("SNOWFLAKE_TEST_PRIVATE_KEY") if privKeyPath == "" { t.Fatal("SNOWFLAKE_TEST_PRIVATE_KEY must be set for JWT authentication") } cfg.PrivateKey = readPrivateKey(t, privKeyPath) cfg.Authenticator = sf.AuthTypeJwt } tz := "UTC" if cfg.Params == nil { cfg.Params = make(map[string]*string) } cfg.Params["timezone"] = &tz return cfg } func openTestConn(ctx context.Context, t *testing.T) *testConn { t.Helper() cfg := testConfig(t) dsn, err := sf.DSN(cfg) if err != nil { t.Fatalf("failed to create DSN: %v", err) } db, err := sql.Open("snowflake", dsn) if err != nil { t.Fatalf("failed to open db: %v", err) } conn, err := db.Conn(ctx) if err != nil { db.Close() t.Fatalf("failed to get connection: %v", err) } return &testConn{db: db, conn: conn} } func (tc *testConn) close() { tc.conn.Close() tc.db.Close() } // queryRows executes a query on the existing connection and returns // SnowflakeRows plus a function to close just the rows. func (tc *testConn) queryRows(ctx context.Context, t *testing.T, query string) (sf.SnowflakeRows, func()) { t.Helper() var rows driver.Rows var err error err = tc.conn.Raw(func(x any) error { queryer, ok := x.(driver.QueryerContext) if !ok { return fmt.Errorf("connection does not implement QueryerContext") } rows, err = queryer.QueryContext(ctx, query, nil) return err }) if err != nil { t.Fatalf("failed to execute query: %v", err) } sfRows, ok := rows.(sf.SnowflakeRows) if !ok { rows.Close() t.Fatalf("rows do not implement SnowflakeRows") } return sfRows, func() { rows.Close() } } // queryRawRows is a convenience wrapper that opens a new connection, // runs a single query, and returns SnowflakeRows with a full cleanup. 
// queryRawRows opens a fresh connection, runs query, and returns the rows
// plus a cleanup that closes both the rows and the connection.
func queryRawRows(ctx context.Context, t *testing.T, query string) (sf.SnowflakeRows, func()) {
	t.Helper()
	tc := openTestConn(ctx, t)
	sfRows, closeRows := tc.queryRows(ctx, t, query)
	return sfRows, func() {
		closeRows()
		tc.close()
	}
}

// TestGetArrowBatches verifies the basic happy path: a two-column,
// one-row query surfaces as at least one batch with one fetchable record.
func TestGetArrowBatches(t *testing.T) {
	ctx := WithArrowBatches(context.Background())
	sfRows, cleanup := queryRawRows(ctx, t, "SELECT 1 AS num, 'hello' AS str")
	defer cleanup()
	batches, err := GetArrowBatches(sfRows)
	if err != nil {
		t.Fatalf("GetArrowBatches failed: %v", err)
	}
	if len(batches) == 0 {
		t.Fatal("expected at least one batch")
	}
	records, err := batches[0].Fetch()
	if err != nil {
		t.Fatalf("Fetch failed: %v", err)
	}
	if records == nil || len(*records) == 0 {
		t.Fatal("expected at least one record")
	}
	rec := (*records)[0]
	defer rec.Release()
	if rec.NumCols() != 2 {
		t.Fatalf("expected 2 columns, got %d", rec.NumCols())
	}
	if rec.NumRows() != 1 {
		t.Fatalf("expected 1 row, got %d", rec.NumRows())
	}
}

// TestGetArrowBatchesHighPrecision checks that with higher precision enabled
// a DECIMAL(38, 19) value is left as the raw unscaled integer (0.1 * 10^19).
func TestGetArrowBatchesHighPrecision(t *testing.T) {
	ctx := sf.WithHigherPrecision(WithArrowBatches(context.Background()))
	sfRows, cleanup := queryRawRows(ctx, t, "SELECT '0.1'::DECIMAL(38, 19) AS c")
	defer cleanup()
	batches, err := GetArrowBatches(sfRows)
	if err != nil {
		t.Fatalf("GetArrowBatches failed: %v", err)
	}
	if len(batches) == 0 {
		t.Fatal("expected at least one batch")
	}
	records, err := batches[0].Fetch()
	if err != nil {
		t.Fatalf("Fetch failed: %v", err)
	}
	if records == nil || len(*records) == 0 {
		t.Fatal("expected at least one record")
	}
	rec := (*records)[0]
	defer rec.Release()
	strVal := rec.Column(0).ValueStr(0)
	expected := "1000000000000000000"
	if strVal != expected {
		t.Fatalf("expected %q, got %q", expected, strVal)
	}
}

// TestGetArrowBatchesLargeResultSet fetches a multi-chunk result set
// concurrently from a worker pool and verifies the total row count, while a
// checked allocator asserts every record was released.
func TestGetArrowBatchesLargeResultSet(t *testing.T) {
	numrows := 3000
	pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
	defer pool.AssertSize(t, 0)
	ctx := sf.WithArrowAllocator(WithArrowBatches(context.Background()), pool)
	query := fmt.Sprintf("SELECT SEQ8(), RANDSTR(1000, RANDOM()) FROM TABLE(GENERATOR(ROWCOUNT=>%v))", numrows)
	sfRows, cleanup := queryRawRows(ctx, t, query)
	defer cleanup()
	batches, err := GetArrowBatches(sfRows)
	if err != nil {
		t.Fatalf("GetArrowBatches failed: %v", err)
	}
	if len(batches) == 0 {
		t.Fatal("expected at least one batch")
	}
	maxWorkers := 10
	// Mutex-guarded counter shared by the workers.
	type count struct {
		mu  sync.Mutex
		val int
	}
	cnt := &count{}
	var wg sync.WaitGroup
	// Buffered so the producer loop below never blocks.
	work := make(chan int, len(batches))
	for range maxWorkers {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := range work {
				recs, fetchErr := batches[i].Fetch()
				if fetchErr != nil {
					t.Errorf("Fetch failed for batch %d: %v", i, fetchErr)
					return
				}
				for _, r := range *recs {
					cnt.mu.Lock()
					cnt.val += int(r.NumRows())
					cnt.mu.Unlock()
					r.Release()
				}
			}
		}()
	}
	for i := range batches {
		work <- i
	}
	close(work)
	wg.Wait()
	if cnt.val != numrows {
		t.Fatalf("row count mismatch: expected %d, got %d", numrows, cnt.val)
	}
}

// TestGetArrowBatchesWithTimestampOption verifies that a high-precision
// timestamp survives fetching when the original timestamp format is kept.
func TestGetArrowBatchesWithTimestampOption(t *testing.T) {
	ctx := WithTimestampOption(
		WithArrowBatches(context.Background()),
		UseOriginalTimestamp,
	)
	sfRows, cleanup := queryRawRows(ctx, t, "SELECT TO_TIMESTAMP_NTZ('2024-01-15 13:45:30.123456789') AS ts")
	defer cleanup()
	batches, err := GetArrowBatches(sfRows)
	if err != nil {
		t.Fatalf("GetArrowBatches failed: %v", err)
	}
	if len(batches) == 0 {
		t.Fatal("expected at least one batch")
	}
	records, err := batches[0].Fetch()
	if err != nil {
		t.Fatalf("Fetch failed: %v", err)
	}
	if records == nil || len(*records) == 0 {
		t.Fatal("expected at least one record")
	}
	rec := (*records)[0]
	defer rec.Release()
	if rec.NumRows() != 1 {
		t.Fatalf("expected 1 row, got %d", rec.NumRows())
	}
	if rec.NumCols() != 1 {
		t.Fatalf("expected 1 column, got %d", rec.NumCols())
	}
}

// TestGetArrowBatchesJSONResponseError forces the session into JSON result
// format and checks that GetArrowBatches reports the dedicated error code.
func TestGetArrowBatchesJSONResponseError(t *testing.T) {
	ctx := WithArrowBatches(context.Background())
	cfg := testConfig(t)
	dsn, err := sf.DSN(cfg)
	if err != nil {
		t.Fatalf("failed to create DSN: %v", err)
	}
	db, err := sql.Open("snowflake", dsn)
	if err != nil {
		t.Fatalf("failed to open db: %v", err)
	}
	defer db.Close()
	conn, err := db.Conn(ctx)
	if err != nil {
		t.Fatalf("failed to get connection: %v", err)
	}
	defer conn.Close()
	// Switch this session to JSON results so arrow batches are unavailable.
	_, err = conn.ExecContext(ctx, "ALTER SESSION SET GO_QUERY_RESULT_FORMAT = json")
	if err != nil {
		t.Fatalf("failed to set JSON format: %v", err)
	}
	var rows driver.Rows
	err = conn.Raw(func(x any) error {
		queryer, ok := x.(driver.QueryerContext)
		if !ok {
			return fmt.Errorf("connection does not implement QueryerContext")
		}
		rows, err = queryer.QueryContext(ctx, "SELECT 'hello'", nil)
		return err
	})
	if err != nil {
		t.Fatalf("failed to execute query: %v", err)
	}
	defer rows.Close()
	sfRows, ok := rows.(sf.SnowflakeRows)
	if !ok {
		t.Fatal("rows do not implement SnowflakeRows")
	}
	_, err = GetArrowBatches(sfRows)
	if err == nil {
		t.Fatal("expected error when using arrow batches with JSON response")
	}
	var se *sf.SnowflakeError
	if !errors.As(err, &se) {
		t.Fatalf("expected SnowflakeError, got %T: %v", err, err)
	}
	if se.Number != sf.ErrNonArrowResponseInArrowBatches {
		t.Fatalf("expected error code %d, got %d", sf.ErrNonArrowResponseInArrowBatches, se.Number)
	}
}

// TestTimestampConversionDistantDates tests all 10 timestamp scales (0-9)
// because each scale exercises a mathematically distinct code path in
// extractEpoch/extractFraction (converter.go). Past bugs have been
// scale-specific: SNOW-526255 (time scale for arrow) and SNOW-2091309
// (precision loss at scale 0). Do not reduce the scale range.
func TestTimestampConversionDistantDates(t *testing.T) { timestamps := [2]string{ "9999-12-12 23:59:59.999999999", "0001-01-01 00:00:00.000000000", } tsTypes := [3]string{"TIMESTAMP_NTZ", "TIMESTAMP_LTZ", "TIMESTAMP_TZ"} precisions := []struct { name string option ia.TimestampOption unit arrow.TimeUnit expectError bool }{ {"second", UseSecondTimestamp, arrow.Second, false}, {"millisecond", UseMillisecondTimestamp, arrow.Millisecond, false}, {"microsecond", UseMicrosecondTimestamp, arrow.Microsecond, false}, {"nanosecond", UseNanosecondTimestamp, arrow.Nanosecond, true}, } for _, prec := range precisions { t.Run(prec.name, func(t *testing.T) { t.Parallel() pool := memory.NewCheckedAllocator(memory.DefaultAllocator) defer pool.AssertSize(t, 0) ctx := sf.WithArrowAllocator( WithTimestampOption(WithArrowBatches(context.Background()), prec.option), pool, ) tc := openTestConn(ctx, t) defer tc.close() for _, tsStr := range timestamps { for _, tp := range tsTypes { for scale := 0; scale <= 9; scale++ { t.Run(tp+"("+strconv.Itoa(scale)+")_"+tsStr, func(t *testing.T) { query := fmt.Sprintf("SELECT '%s'::%s(%v)", tsStr, tp, scale) sfRows, closeRows := tc.queryRows(ctx, t, query) defer closeRows() batches, err := GetArrowBatches(sfRows) if err != nil { t.Fatalf("GetArrowBatches failed: %v", err) } if len(batches) == 0 { t.Fatal("expected at least one batch") } records, err := batches[0].Fetch() if prec.expectError { expectedError := "Cannot convert timestamp" if err == nil { t.Fatalf("no error, expected: %v", expectedError) } if !strings.Contains(err.Error(), expectedError) { t.Fatalf("improper error, expected: %v, got: %v", expectedError, err.Error()) } return } if err != nil { t.Fatalf("Fetch failed: %v", err) } if records == nil || len(*records) == 0 { t.Fatal("expected at least one record") } rec := (*records)[0] defer rec.Release() actual := rec.Column(0).(*array.Timestamp).TimestampValues()[0] actualYear := actual.ToTime(prec.unit).Year() ts, err := 
time.Parse("2006-01-02 15:04:05", tsStr) if err != nil { t.Fatalf("failed to parse time: %v", err) } exp := ts.Truncate(time.Duration(math.Pow10(9 - scale))) if actualYear != exp.Year() { t.Fatalf("unexpected year, expected: %v, got: %v", exp.Year(), actualYear) } }) } } } }) } } // TestTimestampConversionWithOriginalTimestamp tests all 10 timestamp scales // (0-9) because each scale exercises a mathematically distinct code path in // extractEpoch/extractFraction. See TestTimestampConversionDistantDates for // rationale on why the full scale range is required. func TestTimestampConversionWithOriginalTimestamp(t *testing.T) { timestamps := [3]string{ "2000-10-10 10:10:10.123456789", "9999-12-12 23:59:59.999999999", "0001-01-01 00:00:00.000000000", } tsTypes := [3]string{"TIMESTAMP_NTZ", "TIMESTAMP_LTZ", "TIMESTAMP_TZ"} pool := memory.NewCheckedAllocator(memory.DefaultAllocator) defer pool.AssertSize(t, 0) ctx := sf.WithArrowAllocator( WithTimestampOption(WithArrowBatches(context.Background()), UseOriginalTimestamp), pool, ) tc := openTestConn(ctx, t) defer tc.close() for _, tsStr := range timestamps { ts, err := time.Parse("2006-01-02 15:04:05", tsStr) if err != nil { t.Fatalf("failed to parse time: %v", err) } for _, tp := range tsTypes { t.Run(tp+"_"+tsStr, func(t *testing.T) { // Batch all 10 scales into a single multi-column query to reduce round trips. 
var cols []string for scale := 0; scale <= 9; scale++ { cols = append(cols, fmt.Sprintf("'%s'::%s(%v)", tsStr, tp, scale)) } query := "SELECT " + strings.Join(cols, ", ") sfRows, closeRows := tc.queryRows(ctx, t, query) defer closeRows() batches, err := GetArrowBatches(sfRows) if err != nil { t.Fatalf("GetArrowBatches failed: %v", err) } if len(batches) != 1 { t.Fatalf("expected 1 batch, got %d", len(batches)) } records, err := batches[0].Fetch() if err != nil { t.Fatalf("Fetch failed: %v", err) } if records == nil || len(*records) == 0 { t.Fatal("expected at least one record") } for scale := 0; scale <= 9; scale++ { exp := ts.Truncate(time.Duration(math.Pow10(9 - scale))) for _, r := range *records { defer r.Release() act := batches[0].ArrowSnowflakeTimestampToTime(r, scale, 0) if act == nil { t.Fatalf("scale %d: unexpected nil, expected: %v", scale, exp) } else if !exp.Equal(*act) { t.Fatalf("scale %d: unexpected result, expected: %v, got: %v", scale, exp, *act) } } } }) } } } ================================================ FILE: arrowbatches/context.go ================================================ package arrowbatches import ( "context" ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow" ) // Timestamp option constants. const ( UseNanosecondTimestamp = ia.UseNanosecondTimestamp UseMicrosecondTimestamp = ia.UseMicrosecondTimestamp UseMillisecondTimestamp = ia.UseMillisecondTimestamp UseSecondTimestamp = ia.UseSecondTimestamp UseOriginalTimestamp = ia.UseOriginalTimestamp ) // WithArrowBatches returns a context that enables arrow batch mode for queries. func WithArrowBatches(ctx context.Context) context.Context { return ia.EnableArrowBatches(ctx) } // WithTimestampOption returns a context that sets the timestamp conversion option // for arrow batches. 
func WithTimestampOption(ctx context.Context, option ia.TimestampOption) context.Context {
	// Thin wrapper over the internal arrow package's context helper.
	return ia.WithTimestampOption(ctx, option)
}

// WithUtf8Validation returns a context that enables UTF-8 validation for
// string columns in arrow batches.
func WithUtf8Validation(ctx context.Context) context.Context {
	return ia.EnableUtf8Validation(ctx)
}

================================================
FILE: arrowbatches/converter.go
================================================
package arrowbatches

import (
	"context"
	"fmt"
	"math"
	"math/big"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/snowflakedb/gosnowflake/v2/internal/query"
	"github.com/snowflakedb/gosnowflake/v2/internal/types"

	sf "github.com/snowflakedb/gosnowflake/v2"
	ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/compute"
	"github.com/apache/arrow-go/v18/arrow/memory"
)

// arrowToRecord transforms a raw arrow.Record from Snowflake into a record
// with standard Arrow types (e.g., converting struct-based timestamps to
// arrow.Timestamp, decimal128 to int64/float64, etc.)
func arrowToRecord(ctx context.Context, record arrow.Record, pool memory.Allocator, rowType []query.ExecResponseRowType, loc *time.Location) (arrow.Record, error) {
	timestampOption := ia.GetTimestampOption(ctx)
	higherPrecision := ia.HigherPrecisionEnabled(ctx)
	// Build the target schema first; per-column conversion must match it.
	s, err := recordToSchema(record.Schema(), rowType, loc, timestampOption, higherPrecision)
	if err != nil {
		return nil, err
	}
	var cols []arrow.Array
	numRows := record.NumRows()
	// Ensure compute kernels allocate from the caller-supplied pool.
	ctxAlloc := compute.WithAllocator(ctx, pool)
	for i, col := range record.Columns() {
		fieldMetadata := rowType[i].ToFieldMetadata()
		newCol, err := arrowToRecordSingleColumn(ctxAlloc, s.Field(i), col, fieldMetadata, higherPrecision, timestampOption, pool, loc, numRows)
		if err != nil {
			return nil, err
		}
		cols = append(cols, newCol)
		// NewRecord below retains each column; our reference is dropped at return.
		defer newCol.Release()
	}
	newRecord := array.NewRecord(s, cols, numRows)
	return newRecord, nil
}

// arrowToRecordSingleColumn converts one column (recursively, for nested
// object/array/map types) from Snowflake's wire representation into the
// standard Arrow type chosen by the schema. The returned array carries one
// reference owned by the caller (either a Retain on the input or a freshly
// built array).
//
// NOTE(review): for nested element/key/value columns, numRows is still the
// parent record's row count, while the child array may have a different
// length — looks safe only because the timestamp/string loops are not reached
// for nested physical layouts in practice; confirm against recordToSchema.
func arrowToRecordSingleColumn(ctx context.Context, field arrow.Field, col arrow.Array, fieldMetadata query.FieldMetadata, higherPrecisionEnabled bool, timestampOption ia.TimestampOption, pool memory.Allocator, loc *time.Location, numRows int64) (arrow.Array, error) {
	var err error
	newCol := col
	snowflakeType := types.GetSnowflakeType(fieldMetadata.Type)
	switch snowflakeType {
	case types.FixedType:
		if higherPrecisionEnabled {
			// Keep raw decimals/ints untouched; caller wants full precision.
			col.Retain()
		} else if col.DataType().ID() == arrow.DECIMAL || col.DataType().ID() == arrow.DECIMAL256 {
			// Cast decimal to int64 (scale 0) or float64 (scaled) via compute.
			var toType arrow.DataType
			if fieldMetadata.Scale == 0 {
				toType = arrow.PrimitiveTypes.Int64
			} else {
				toType = arrow.PrimitiveTypes.Float64
			}
			newCol, err = compute.CastArray(ctx, col, compute.UnsafeCastOptions(toType))
			if err != nil {
				return nil, err
			}
		} else if fieldMetadata.Scale != 0 && col.DataType().ID() != arrow.INT64 {
			// Small scaled integers: divide by 10^scale to get float values.
			result, err := compute.Divide(ctx, compute.ArithmeticOptions{NoCheckOverflow: true},
				&compute.ArrayDatum{Value: newCol.Data()},
				compute.NewDatum(math.Pow10(int(fieldMetadata.Scale))))
			if err != nil {
				return nil, err
			}
			defer result.Release()
			newCol = result.(*compute.ArrayDatum).MakeArray()
		} else if fieldMetadata.Scale != 0 && col.DataType().ID() == arrow.INT64 {
			// int64 with scale: go through big.Float per value to avoid the
			// precision loss of a direct float64 division.
			values := col.(*array.Int64).Int64Values()
			floatValues := make([]float64, len(values))
			for i, val := range values {
				floatValues[i], _ = intToBigFloat(val, int64(fieldMetadata.Scale)).Float64()
			}
			builder := array.NewFloat64Builder(pool)
			builder.AppendValues(floatValues, nil)
			newCol = builder.NewArray()
			builder.Release()
		} else {
			col.Retain()
		}
	case types.TimeType:
		newCol, err = compute.CastArray(ctx, col, compute.SafeCastOptions(arrow.FixedWidthTypes.Time64ns))
		if err != nil {
			return nil, err
		}
	case types.TimestampNtzType, types.TimestampLtzType, types.TimestampTzType:
		if timestampOption == ia.UseOriginalTimestamp {
			// Leave Snowflake's native struct/int representation as-is.
			col.Retain()
		} else {
			var unit arrow.TimeUnit
			switch timestampOption {
			case ia.UseMicrosecondTimestamp:
				unit = arrow.Microsecond
			case ia.UseMillisecondTimestamp:
				unit = arrow.Millisecond
			case ia.UseSecondTimestamp:
				unit = arrow.Second
			case ia.UseNanosecondTimestamp:
				unit = arrow.Nanosecond
			}
			var tb *array.TimestampBuilder
			if snowflakeType == types.TimestampLtzType {
				// LTZ carries the session timezone in the arrow type.
				tb = array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: unit, TimeZone: loc.String()})
			} else {
				tb = array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: unit})
			}
			defer tb.Release()
			for i := 0; i < int(numRows); i++ {
				ts := ArrowSnowflakeTimestampToTime(col, snowflakeType, int(fieldMetadata.Scale), i, loc)
				if ts != nil {
					var ar arrow.Timestamp
					switch timestampOption {
					case ia.UseMicrosecondTimestamp:
						ar = arrow.Timestamp(ts.UnixMicro())
					case ia.UseMillisecondTimestamp:
						ar = arrow.Timestamp(ts.UnixMilli())
					case ia.UseSecondTimestamp:
						ar = arrow.Timestamp(ts.Unix())
					case ia.UseNanosecondTimestamp:
						ar = arrow.Timestamp(ts.UnixNano())
						// UnixNano silently overflows outside ~1677-2262; detect
						// it by comparing the year before and after conversion.
						if ts.UTC().Year() != ar.ToTime(arrow.Nanosecond).Year() {
							return nil, &sf.SnowflakeError{
								Number:   sf.ErrTooHighTimestampPrecision,
								SQLState: sf.SQLStateInvalidDataTimeFormat,
								Message:  fmt.Sprintf("Cannot convert timestamp %v in column %v to Arrow.Timestamp data type due to too high precision. Please use context with WithOriginalTimestamp.", ts.UTC(), fieldMetadata.Name),
							}
						}
					}
					tb.Append(ar)
				} else {
					tb.AppendNull()
				}
			}
			newCol = tb.NewArray()
		}
	case types.TextType:
		if stringCol, ok := col.(*array.String); ok {
			newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows)
		}
	case types.ObjectType:
		if structCol, ok := col.(*array.Struct); ok {
			// Recursively convert each struct field, then rebuild the struct
			// with the original null bitmap.
			var internalCols []arrow.Array
			for i := 0; i < structCol.NumField(); i++ {
				internalCol := structCol.Field(i)
				newInternalCol, err := arrowToRecordSingleColumn(ctx, field.Type.(*arrow.StructType).Field(i), internalCol, fieldMetadata.Fields[i], higherPrecisionEnabled, timestampOption, pool, loc, numRows)
				if err != nil {
					return nil, err
				}
				internalCols = append(internalCols, newInternalCol)
				defer newInternalCol.Release()
			}
			var fieldNames []string
			for _, f := range field.Type.(*arrow.StructType).Fields() {
				fieldNames = append(fieldNames, f.Name)
			}
			nullBitmap := memory.NewBufferBytes(structCol.NullBitmapBytes())
			numberOfNulls := structCol.NullN()
			return array.NewStructArrayWithNulls(internalCols, fieldNames, nullBitmap, numberOfNulls, 0)
		} else if stringCol, ok := col.(*array.String); ok {
			// Semi-structured values may arrive as JSON strings instead.
			newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows)
		}
	case types.ArrayType:
		if listCol, ok := col.(*array.List); ok {
			// Convert the element column, then reattach the original list
			// offsets/validity buffers around it.
			newCol, err = arrowToRecordSingleColumn(ctx, field.Type.(*arrow.ListType).ElemField(), listCol.ListValues(), fieldMetadata.Fields[0], higherPrecisionEnabled, timestampOption, pool, loc, numRows)
			if err != nil {
				return nil, err
			}
			defer newCol.Release()
			newData := array.NewData(arrow.ListOf(newCol.DataType()), listCol.Len(), listCol.Data().Buffers(), []arrow.ArrayData{newCol.Data()}, listCol.NullN(), 0)
			defer newData.Release()
			return array.NewListData(newData), nil
		} else if stringCol, ok := col.(*array.String); ok {
			newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows)
		}
	case types.MapType:
		if mapCol, ok := col.(*array.Map); ok {
			// Convert keys and items separately, rebuild the {k, v} struct,
			// then reattach the original map offsets/validity buffers.
			keyCol, err := arrowToRecordSingleColumn(ctx, field.Type.(*arrow.MapType).KeyField(), mapCol.Keys(), fieldMetadata.Fields[0], higherPrecisionEnabled, timestampOption, pool, loc, numRows)
			if err != nil {
				return nil, err
			}
			defer keyCol.Release()
			valueCol, err := arrowToRecordSingleColumn(ctx, field.Type.(*arrow.MapType).ItemField(), mapCol.Items(), fieldMetadata.Fields[1], higherPrecisionEnabled, timestampOption, pool, loc, numRows)
			if err != nil {
				return nil, err
			}
			defer valueCol.Release()
			structArr, err := array.NewStructArray([]arrow.Array{keyCol, valueCol}, []string{"k", "v"})
			if err != nil {
				return nil, err
			}
			defer structArr.Release()
			newData := array.NewData(arrow.MapOf(keyCol.DataType(), valueCol.DataType()), mapCol.Len(), mapCol.Data().Buffers(), []arrow.ArrayData{structArr.Data()}, mapCol.NullN(), 0)
			defer newData.Release()
			return array.NewMapData(newData), nil
		} else if stringCol, ok := col.(*array.String); ok {
			newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows)
		}
	default:
		// No conversion needed; pass the column through with a new reference.
		col.Retain()
	}
	return newCol, nil
}

// arrowStringRecordToColumn returns stringCol either passed through (with an
// extra reference) or, when UTF-8 validation is enabled, rebuilt with invalid
// byte sequences replaced by the Unicode replacement character.
func arrowStringRecordToColumn(
	ctx context.Context,
	stringCol *array.String,
	mem memory.Allocator,
	numRows int64,
) arrow.Array {
	if ia.Utf8ValidationEnabled(ctx) && stringCol.DataType().ID() == arrow.STRING {
		tb := array.NewStringBuilder(mem)
		defer tb.Release()
		for i := 0; i < int(numRows); i++ {
			if stringCol.IsValid(i) {
				stringValue := stringCol.Value(i)
				if !utf8.ValidString(stringValue) {
					stringValue = strings.ToValidUTF8(stringValue, "�")
				}
				tb.Append(stringValue)
			} else {
				tb.AppendNull()
			}
		}
		arr := tb.NewArray()
		return arr
	}
	stringCol.Retain()
	return stringCol
}

// intToBigFloat returns val / 10^scale as an arbitrary-precision float.
func intToBigFloat(val int64, scale int64) *big.Float {
	f := new(big.Float).SetInt64(val)
	s := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(scale), nil))
	return new(big.Float).Quo(f, s)
}

// ArrowSnowflakeTimestampToTime converts original timestamp returned by Snowflake to time.Time.
func ArrowSnowflakeTimestampToTime( column arrow.Array, sfType types.SnowflakeType, scale int, recIdx int, loc *time.Location) *time.Time { if column.IsNull(recIdx) { return nil } var ret time.Time switch sfType { case types.TimestampNtzType: if column.DataType().ID() == arrow.STRUCT { structData := column.(*array.Struct) epoch := structData.Field(0).(*array.Int64).Int64Values() fraction := structData.Field(1).(*array.Int32).Int32Values() ret = time.Unix(epoch[recIdx], int64(fraction[recIdx])).UTC() } else { intData := column.(*array.Int64) value := intData.Value(recIdx) epoch := extractEpoch(value, scale) fraction := extractFraction(value, scale) ret = time.Unix(epoch, fraction).UTC() } case types.TimestampLtzType: if column.DataType().ID() == arrow.STRUCT { structData := column.(*array.Struct) epoch := structData.Field(0).(*array.Int64).Int64Values() fraction := structData.Field(1).(*array.Int32).Int32Values() ret = time.Unix(epoch[recIdx], int64(fraction[recIdx])).In(loc) } else { intData := column.(*array.Int64) value := intData.Value(recIdx) epoch := extractEpoch(value, scale) fraction := extractFraction(value, scale) ret = time.Unix(epoch, fraction).In(loc) } case types.TimestampTzType: structData := column.(*array.Struct) if structData.NumField() == 2 { value := structData.Field(0).(*array.Int64).Int64Values() timezone := structData.Field(1).(*array.Int32).Int32Values() epoch := extractEpoch(value[recIdx], scale) fraction := extractFraction(value[recIdx], scale) locTz := sf.Location(int(timezone[recIdx]) - 1440) ret = time.Unix(epoch, fraction).In(locTz) } else { epoch := structData.Field(0).(*array.Int64).Int64Values() fraction := structData.Field(1).(*array.Int32).Int32Values() timezone := structData.Field(2).(*array.Int32).Int32Values() locTz := sf.Location(int(timezone[recIdx]) - 1440) ret = time.Unix(epoch[recIdx], int64(fraction[recIdx])).In(locTz) } } return &ret } func extractEpoch(value int64, scale int) int64 { return value / 
int64(math.Pow10(scale)) } func extractFraction(value int64, scale int) int64 { return (value % int64(math.Pow10(scale))) * int64(math.Pow10(9-scale)) } ================================================ FILE: arrowbatches/converter_test.go ================================================ package arrowbatches import ( "context" "fmt" "github.com/snowflakedb/gosnowflake/v2/internal/query" "github.com/snowflakedb/gosnowflake/v2/internal/types" "math/big" "strings" "testing" "time" ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/decimal128" "github.com/apache/arrow-go/v18/arrow/memory" ) var decimalShift = new(big.Int).Exp(big.NewInt(2), big.NewInt(64), nil) func stringIntToDecimal(src string) (decimal128.Num, bool) { b, ok := new(big.Int).SetString(src, 10) if !ok { return decimal128.Num{}, ok } var high, low big.Int high.QuoRem(b, decimalShift, &low) return decimal128.New(high.Int64(), low.Uint64()), true } func decimalToBigInt(num decimal128.Num) *big.Int { high := new(big.Int).SetInt64(num.HighBits()) low := new(big.Int).SetUint64(num.LowBits()) return new(big.Int).Add(new(big.Int).Mul(high, decimalShift), low) } func TestArrowToRecord(t *testing.T) { pool := memory.NewCheckedAllocator(memory.NewGoAllocator()) defer pool.AssertSize(t, 0) var valids []bool localTime := time.Date(2019, 1, 1, 1, 17, 31, 123456789, time.FixedZone("-08:00", -8*3600)) localTimeFarIntoFuture := time.Date(9000, 2, 6, 14, 17, 31, 123456789, time.FixedZone("-08:00", -8*3600)) epochField := arrow.Field{Name: "epoch", Type: &arrow.Int64Type{}} timezoneField := arrow.Field{Name: "timezone", Type: &arrow.Int32Type{}} fractionField := arrow.Field{Name: "fraction", Type: &arrow.Int32Type{}} timestampTzStructWithoutFraction := arrow.StructOf(epochField, timezoneField) timestampTzStructWithFraction := arrow.StructOf(epochField, fractionField, timezoneField) 
timestampNtzStruct := arrow.StructOf(epochField, fractionField) timestampLtzStruct := arrow.StructOf(epochField, fractionField) type testObj struct { field1 int field2 string } for _, tc := range []struct { logical string physical string sc *arrow.Schema rowType query.ExecResponseRowType values any expected any error string arrowBatchesTimestampOption ia.TimestampOption enableArrowBatchesUtf8Validation bool withHigherPrecision bool nrows int builder array.Builder append func(b array.Builder, vs any) compare func(src any, expected any, rec arrow.Record) int }{ { logical: "fixed", physical: "number", sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), values: []int64{1, 2}, nrows: 2, builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) }, }, { logical: "fixed", physical: "int64", sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Decimal128Type{Precision: 38, Scale: 0}}}, nil), values: []string{"10000000000000000000000000000000000000", "-12345678901234567890123456789012345678"}, nrows: 2, builder: array.NewDecimal128Builder(pool, &arrow.Decimal128Type{Precision: 38, Scale: 0}), append: func(b array.Builder, vs any) { for _, s := range vs.([]string) { num, ok := stringIntToDecimal(s) if !ok { t.Fatalf("failed to convert to Int64") } b.(*array.Decimal128Builder).Append(num) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]string) for i, dec := range convertedRec.Column(0).(*array.Int64).Int64Values() { num, ok := stringIntToDecimal(srcvs[i]) if !ok { return i } srcDec := decimalToBigInt(num).Int64() if srcDec != dec { return i } } return -1 }, }, { logical: "fixed", physical: "number(38,0)", sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Decimal128Type{Precision: 38, Scale: 0}}}, nil), values: []string{"10000000000000000000000000000000000000", "-12345678901234567890123456789012345678"}, withHigherPrecision: true, nrows: 2, 
builder: array.NewDecimal128Builder(pool, &arrow.Decimal128Type{Precision: 38, Scale: 0}), append: func(b array.Builder, vs any) { for _, s := range vs.([]string) { num, ok := stringIntToDecimal(s) if !ok { t.Fatalf("failed to convert to Int64") } b.(*array.Decimal128Builder).Append(num) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]string) for i, dec := range convertedRec.Column(0).(*array.Decimal128).Values() { srcDec, ok := stringIntToDecimal(srcvs[i]) if !ok { return i } if srcDec != dec { return i } } return -1 }, }, { logical: "fixed", physical: "float64", rowType: query.ExecResponseRowType{Scale: 37}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Decimal128Type{Precision: 38, Scale: 37}}}, nil), values: []string{"1.2345678901234567890123456789012345678", "-9.999999999999999"}, nrows: 2, builder: array.NewDecimal128Builder(pool, &arrow.Decimal128Type{Precision: 38, Scale: 37}), append: func(b array.Builder, vs any) { for _, s := range vs.([]string) { num, err := decimal128.FromString(s, 38, 37) if err != nil { t.Fatalf("failed to convert to decimal: %s", err) } b.(*array.Decimal128Builder).Append(num) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]string) for i, dec := range convertedRec.Column(0).(*array.Float64).Float64Values() { num, err := decimal128.FromString(srcvs[i], 38, 37) if err != nil { return i } srcDec := num.ToFloat64(37) if srcDec != dec { return i } } return -1 }, }, { logical: "fixed", physical: "number(38,37)", rowType: query.ExecResponseRowType{Scale: 37}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Decimal128Type{Precision: 38, Scale: 37}}}, nil), values: []string{"1.2345678901234567890123456789012345678", "-9.999999999999999"}, withHigherPrecision: true, nrows: 2, builder: array.NewDecimal128Builder(pool, &arrow.Decimal128Type{Precision: 38, Scale: 37}), append: func(b array.Builder, vs any) { for _, s := range vs.([]string) { num, err := 
decimal128.FromString(s, 38, 37) if err != nil { t.Fatalf("failed to convert to decimal: %s", err) } b.(*array.Decimal128Builder).Append(num) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]string) for i, dec := range convertedRec.Column(0).(*array.Decimal128).Values() { srcDec, err := decimal128.FromString(srcvs[i], 38, 37) if err != nil { return i } if srcDec != dec { return i } } return -1 }, }, { logical: "fixed", physical: "int8", sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int8Type{}}}, nil), values: []int8{1, 2}, nrows: 2, builder: array.NewInt8Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int8Builder).AppendValues(vs.([]int8), valids) }, }, { logical: "fixed", physical: "int16", sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int16Type{}}}, nil), values: []int16{1, 2}, nrows: 2, builder: array.NewInt16Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int16Builder).AppendValues(vs.([]int16), valids) }, }, { logical: "fixed", physical: "int32", sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int32Type{}}}, nil), values: []int32{1, 2}, nrows: 2, builder: array.NewInt32Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int32Builder).AppendValues(vs.([]int32), valids) }, }, { logical: "fixed", physical: "int64", sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), values: []int64{1, 2}, nrows: 2, builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) }, }, { logical: "fixed", physical: "float8", rowType: query.ExecResponseRowType{Scale: 1}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int8Type{}}}, nil), values: []int8{10, 16}, nrows: 2, builder: array.NewInt8Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int8Builder).AppendValues(vs.([]int8), valids) }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]int8) for i, 
f := range convertedRec.Column(0).(*array.Float64).Float64Values() { rawFloat, _ := intToBigFloat(int64(srcvs[i]), 1).Float64() if rawFloat != f { return i } } return -1 }, }, { logical: "fixed", physical: "int8", rowType: query.ExecResponseRowType{Scale: 1}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int8Type{}}}, nil), values: []int8{10, 16}, withHigherPrecision: true, nrows: 2, builder: array.NewInt8Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int8Builder).AppendValues(vs.([]int8), valids) }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]int8) for i, f := range convertedRec.Column(0).(*array.Int8).Int8Values() { if srcvs[i] != f { return i } } return -1 }, }, { logical: "fixed", physical: "float16", rowType: query.ExecResponseRowType{Scale: 1}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int16Type{}}}, nil), values: []int16{20, 26}, nrows: 2, builder: array.NewInt16Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int16Builder).AppendValues(vs.([]int16), valids) }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]int16) for i, f := range convertedRec.Column(0).(*array.Float64).Float64Values() { rawFloat, _ := intToBigFloat(int64(srcvs[i]), 1).Float64() if rawFloat != f { return i } } return -1 }, }, { logical: "fixed", physical: "int16", rowType: query.ExecResponseRowType{Scale: 1}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int16Type{}}}, nil), values: []int16{20, 26}, withHigherPrecision: true, nrows: 2, builder: array.NewInt16Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int16Builder).AppendValues(vs.([]int16), valids) }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]int16) for i, f := range convertedRec.Column(0).(*array.Int16).Int16Values() { if srcvs[i] != f { return i } } return -1 }, }, { logical: "fixed", physical: "float32", rowType: query.ExecResponseRowType{Scale: 2}, 
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int32Type{}}}, nil), values: []int32{200, 265}, nrows: 2, builder: array.NewInt32Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int32Builder).AppendValues(vs.([]int32), valids) }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]int32) for i, f := range convertedRec.Column(0).(*array.Float64).Float64Values() { rawFloat, _ := intToBigFloat(int64(srcvs[i]), 2).Float64() if rawFloat != f { return i } } return -1 }, }, { logical: "fixed", physical: "int32", rowType: query.ExecResponseRowType{Scale: 2}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int32Type{}}}, nil), values: []int32{200, 265}, withHigherPrecision: true, nrows: 2, builder: array.NewInt32Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int32Builder).AppendValues(vs.([]int32), valids) }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]int32) for i, f := range convertedRec.Column(0).(*array.Int32).Int32Values() { if srcvs[i] != f { return i } } return -1 }, }, { logical: "fixed", physical: "float64", rowType: query.ExecResponseRowType{Scale: 5}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), values: []int64{12345, 234567}, nrows: 2, builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]int64) for i, f := range convertedRec.Column(0).(*array.Float64).Float64Values() { rawFloat, _ := intToBigFloat(srcvs[i], 5).Float64() if rawFloat != f { return i } } return -1 }, }, { logical: "fixed", physical: "int64", rowType: query.ExecResponseRowType{Scale: 5}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), values: []int64{12345, 234567}, withHigherPrecision: true, nrows: 2, builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs any) 
{ b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]int64) for i, f := range convertedRec.Column(0).(*array.Int64).Int64Values() { if srcvs[i] != f { return i } } return -1 }, }, { logical: "boolean", sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.BooleanType{}}}, nil), values: []bool{true, false}, nrows: 2, builder: array.NewBooleanBuilder(pool), append: func(b array.Builder, vs any) { b.(*array.BooleanBuilder).AppendValues(vs.([]bool), valids) }, }, { logical: "real", physical: "float", sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Float64Type{}}}, nil), values: []float64{1, 2}, nrows: 2, builder: array.NewFloat64Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Float64Builder).AppendValues(vs.([]float64), valids) }, }, { logical: "text", physical: "string", sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}}}, nil), values: []string{"foo", "bar"}, nrows: 2, builder: array.NewStringBuilder(pool), append: func(b array.Builder, vs any) { b.(*array.StringBuilder).AppendValues(vs.([]string), valids) }, }, { logical: "text", physical: "string with invalid utf8", sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}}}, nil), rowType: query.ExecResponseRowType{Type: "TEXT"}, values: []string{"\xFF", "bar", "baz\xFF\xFF"}, expected: []string{"�", "bar", "baz��"}, enableArrowBatchesUtf8Validation: true, nrows: 2, builder: array.NewStringBuilder(pool), append: func(b array.Builder, vs any) { b.(*array.StringBuilder).AppendValues(vs.([]string), valids) }, compare: func(src any, expected any, convertedRec arrow.Record) int { arr := convertedRec.Column(0).(*array.String) for i := 0; i < arr.Len(); i++ { if expected.([]string)[i] != arr.Value(i) { return i } } return -1 }, }, { logical: "binary", sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.BinaryType{}}}, nil), values: [][]byte{[]byte("foo"), []byte("bar")}, nrows: 2, builder: 
array.NewBinaryBuilder(pool, arrow.BinaryTypes.Binary), append: func(b array.Builder, vs any) { b.(*array.BinaryBuilder).AppendValues(vs.([][]byte), valids) }, }, { logical: "date", sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Date32Type{}}}, nil), values: []time.Time{time.Now(), localTime}, nrows: 2, builder: array.NewDate32Builder(pool), append: func(b array.Builder, vs any) { for _, d := range vs.([]time.Time) { b.(*array.Date32Builder).Append(arrow.Date32(d.Unix())) } }, }, { logical: "time", sc: arrow.NewSchema([]arrow.Field{{Type: arrow.FixedWidthTypes.Time64ns}}, nil), values: []time.Time{time.Now(), time.Now()}, nrows: 2, builder: array.NewTime64Builder(pool, arrow.FixedWidthTypes.Time64ns.(*arrow.Time64Type)), append: func(b array.Builder, vs any) { for _, t := range vs.([]time.Time) { b.(*array.Time64Builder).Append(arrow.Time64(t.UnixNano())) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) arr := convertedRec.Column(0).(*array.Time64) for i := 0; i < arr.Len(); i++ { if srcvs[i].UnixNano() != int64(arr.Value(i)) { return i } } return -1 }, }, { logical: "timestamp_ntz", physical: "int64", values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)}, nrows: 2, rowType: query.ExecResponseRowType{Scale: 3}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs any) { for _, t := range vs.([]time.Time) { b.(*array.Int64Builder).Append(t.UnixMilli()) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) { return i } } return -1 }, }, { logical: "timestamp_ntz", physical: "struct", values: []time.Time{time.Now(), localTime}, nrows: 2, rowType: query.ExecResponseRowType{Scale: 9}, sc: 
arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil), builder: array.NewStructBuilder(pool, timestampNtzStruct), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) { return i } } return -1 }, }, { logical: "timestamp_ntz", physical: "struct", values: []time.Time{time.Now().Truncate(time.Microsecond), localTime.Truncate(time.Microsecond)}, arrowBatchesTimestampOption: ia.UseMicrosecondTimestamp, nrows: 2, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil), builder: array.NewStructBuilder(pool, timestampNtzStruct), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Microsecond)) { return i } } return -1 }, }, { logical: "timestamp_ntz", physical: "struct", values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)}, arrowBatchesTimestampOption: ia.UseMillisecondTimestamp, nrows: 2, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil), builder: array.NewStructBuilder(pool, timestampNtzStruct), 
append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Millisecond)) { return i } } return -1 }, }, { logical: "timestamp_ntz", physical: "struct", values: []time.Time{time.Now().Truncate(time.Second), localTime.Truncate(time.Second)}, arrowBatchesTimestampOption: ia.UseSecondTimestamp, nrows: 2, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil), builder: array.NewStructBuilder(pool, timestampNtzStruct), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Second)) { return i } } return -1 }, }, { logical: "timestamp_ntz", physical: "error", values: []time.Time{localTimeFarIntoFuture}, error: "Cannot convert timestamp", nrows: 1, rowType: query.ExecResponseRowType{Scale: 3}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs any) { for _, t := range vs.([]time.Time) { b.(*array.Int64Builder).Append(t.UnixMilli()) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { return 0 }, }, { logical: "timestamp_ntz", 
physical: "int64 with original timestamp", values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond), localTimeFarIntoFuture.Truncate(time.Millisecond)}, arrowBatchesTimestampOption: ia.UseOriginalTimestamp, nrows: 3, rowType: query.ExecResponseRowType{Scale: 3}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs any) { for _, t := range vs.([]time.Time) { b.(*array.Int64Builder).Append(t.UnixMilli()) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i := 0; i < convertedRec.Column(0).Len(); i++ { ts := ArrowSnowflakeTimestampToTime(convertedRec.Column(0), types.GetSnowflakeType("timestamp_ntz"), 3, i, nil) if !srcvs[i].Equal(*ts) { return i } } return -1 }, }, { logical: "timestamp_ntz", physical: "struct with original timestamp", values: []time.Time{time.Now(), localTime, localTimeFarIntoFuture}, arrowBatchesTimestampOption: ia.UseOriginalTimestamp, nrows: 3, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil), builder: array.NewStructBuilder(pool, timestampNtzStruct), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i := 0; i < convertedRec.Column(0).Len(); i++ { ts := ArrowSnowflakeTimestampToTime(convertedRec.Column(0), types.GetSnowflakeType("timestamp_ntz"), 9, i, nil) if !srcvs[i].Equal(*ts) { return i } } return -1 }, }, { logical: "timestamp_ltz", physical: "int64", values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)}, nrows: 
2, rowType: query.ExecResponseRowType{Scale: 3}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs any) { for _, t := range vs.([]time.Time) { b.(*array.Int64Builder).Append(t.UnixMilli()) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) { return i } } return -1 }, }, { logical: "timestamp_ltz", physical: "struct", values: []time.Time{time.Now(), localTime}, nrows: 2, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil), builder: array.NewStructBuilder(pool, timestampNtzStruct), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) { return i } } return -1 }, }, { logical: "timestamp_ltz", physical: "struct", values: []time.Time{time.Now().Truncate(time.Microsecond), localTime.Truncate(time.Microsecond)}, arrowBatchesTimestampOption: ia.UseMicrosecondTimestamp, nrows: 2, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil), builder: array.NewStructBuilder(pool, timestampNtzStruct), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) 
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Microsecond)) { return i } } return -1 }, }, { logical: "timestamp_ltz", physical: "struct", values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)}, arrowBatchesTimestampOption: ia.UseMillisecondTimestamp, nrows: 2, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil), builder: array.NewStructBuilder(pool, timestampNtzStruct), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Millisecond)) { return i } } return -1 }, }, { logical: "timestamp_ltz", physical: "struct", values: []time.Time{time.Now().Truncate(time.Second), localTime.Truncate(time.Second)}, arrowBatchesTimestampOption: ia.UseSecondTimestamp, nrows: 2, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil), builder: array.NewStructBuilder(pool, timestampNtzStruct), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) } }, compare: func(src any, expected any, convertedRec arrow.Record) 
int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Second)) { return i } } return -1 }, }, { logical: "timestamp_ltz", physical: "error", values: []time.Time{localTimeFarIntoFuture}, error: "Cannot convert timestamp", nrows: 1, rowType: query.ExecResponseRowType{Scale: 3}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs any) { for _, t := range vs.([]time.Time) { b.(*array.Int64Builder).Append(t.UnixMilli()) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { return 0 }, }, { logical: "timestamp_ltz", physical: "int64 with original timestamp", values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond), localTimeFarIntoFuture.Truncate(time.Millisecond)}, arrowBatchesTimestampOption: ia.UseOriginalTimestamp, nrows: 3, rowType: query.ExecResponseRowType{Scale: 3}, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil), builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs any) { for _, t := range vs.([]time.Time) { b.(*array.Int64Builder).Append(t.UnixMilli()) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i := 0; i < convertedRec.Column(0).Len(); i++ { ts := ArrowSnowflakeTimestampToTime(convertedRec.Column(0), types.GetSnowflakeType("timestamp_ltz"), 3, i, localTime.Location()) if !srcvs[i].Equal(*ts) { return i } } return -1 }, }, { logical: "timestamp_ltz", physical: "struct with original timestamp", values: []time.Time{time.Now(), localTime, localTimeFarIntoFuture}, arrowBatchesTimestampOption: ia.UseOriginalTimestamp, nrows: 3, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampLtzStruct}}, nil), builder: array.NewStructBuilder(pool, timestampLtzStruct), append: func(b array.Builder, vs 
any) { sb := b.(*array.StructBuilder) valids = []bool{true, true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i := 0; i < convertedRec.Column(0).Len(); i++ { ts := ArrowSnowflakeTimestampToTime(convertedRec.Column(0), types.GetSnowflakeType("timestamp_ltz"), 9, i, localTime.Location()) if !srcvs[i].Equal(*ts) { return i } } return -1 }, }, { logical: "timestamp_tz", physical: "struct2", values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)}, nrows: 2, rowType: query.ExecResponseRowType{Scale: 3}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithoutFraction}}, nil), builder: array.NewStructBuilder(pool, timestampTzStructWithoutFraction), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.UnixMilli()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(0)) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) { return i } } return -1 }, }, { logical: "timestamp_tz", physical: "struct3", values: []time.Time{time.Now(), localTime}, nrows: 2, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithFraction}}, nil), builder: array.NewStructBuilder(pool, timestampTzStructWithFraction), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { 
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) sb.FieldBuilder(2).(*array.Int32Builder).Append(int32(0)) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) { return i } } return -1 }, }, { logical: "timestamp_tz", physical: "struct3", values: []time.Time{time.Now().Truncate(time.Microsecond), localTime.Truncate(time.Microsecond)}, arrowBatchesTimestampOption: ia.UseMicrosecondTimestamp, nrows: 2, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithFraction}}, nil), builder: array.NewStructBuilder(pool, timestampTzStructWithFraction), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) sb.FieldBuilder(2).(*array.Int32Builder).Append(int32(0)) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Microsecond)) { return i } } return -1 }, }, { logical: "timestamp_tz", physical: "struct3", values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)}, arrowBatchesTimestampOption: ia.UseMillisecondTimestamp, nrows: 2, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithFraction}}, nil), builder: array.NewStructBuilder(pool, timestampTzStructWithFraction), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for 
_, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) sb.FieldBuilder(2).(*array.Int32Builder).Append(int32(0)) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Millisecond)) { return i } } return -1 }, }, { logical: "timestamp_tz", physical: "struct3", values: []time.Time{time.Now().Truncate(time.Second), localTime.Truncate(time.Second)}, arrowBatchesTimestampOption: ia.UseSecondTimestamp, nrows: 2, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithFraction}}, nil), builder: array.NewStructBuilder(pool, timestampTzStructWithFraction), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) sb.FieldBuilder(2).(*array.Int32Builder).Append(int32(0)) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() { if !srcvs[i].Equal(t.ToTime(arrow.Second)) { return i } } return -1 }, }, { logical: "timestamp_tz", physical: "struct2 with original timestamp", values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond), localTimeFarIntoFuture.Truncate(time.Millisecond)}, arrowBatchesTimestampOption: ia.UseOriginalTimestamp, nrows: 3, rowType: query.ExecResponseRowType{Scale: 3}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithoutFraction}}, nil), builder: array.NewStructBuilder(pool, timestampTzStructWithoutFraction), append: func(b array.Builder, vs 
any) { sb := b.(*array.StructBuilder) valids = []bool{true, true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.UnixMilli()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(0)) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i := 0; i < convertedRec.Column(0).Len(); i++ { ts := ArrowSnowflakeTimestampToTime(convertedRec.Column(0), types.GetSnowflakeType("timestamp_tz"), 3, i, nil) if !srcvs[i].Equal(*ts) { return i } } return -1 }, }, { logical: "timestamp_tz", physical: "struct3 with original timestamp", values: []time.Time{time.Now(), localTime, localTimeFarIntoFuture}, arrowBatchesTimestampOption: ia.UseOriginalTimestamp, nrows: 3, rowType: query.ExecResponseRowType{Scale: 9}, sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithFraction}}, nil), builder: array.NewStructBuilder(pool, timestampTzStructWithFraction), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond())) sb.FieldBuilder(2).(*array.Int32Builder).Append(int32(0)) } }, compare: func(src any, expected any, convertedRec arrow.Record) int { srcvs := src.([]time.Time) for i := 0; i < convertedRec.Column(0).Len(); i++ { ts := ArrowSnowflakeTimestampToTime(convertedRec.Column(0), types.GetSnowflakeType("timestamp_tz"), 9, i, nil) if !srcvs[i].Equal(*ts) { return i } } return -1 }, }, { logical: "array", values: [][]string{{"foo", "bar"}, {"baz", "quz", "quux"}}, nrows: 2, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}}}, nil), builder: array.NewStringBuilder(pool), append: func(b array.Builder, vs any) { for _, a := range vs.([][]string) { b.(*array.StringBuilder).Append(fmt.Sprint(a)) } }, }, { logical: "object", 
values: []testObj{{0, "foo"}, {1, "bar"}}, nrows: 2, sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}}}, nil), builder: array.NewStringBuilder(pool), append: func(b array.Builder, vs any) { for _, o := range vs.([]testObj) { b.(*array.StringBuilder).Append(fmt.Sprint(o)) } }, }, } { testName := tc.logical if tc.physical != "" { testName += " " + tc.physical } t.Run(testName, func(t *testing.T) { scope := memory.NewCheckedAllocatorScope(pool) defer scope.CheckSize(t) b := tc.builder defer b.Release() tc.append(b, tc.values) arr := b.NewArray() defer arr.Release() rawRec := array.NewRecord(tc.sc, []arrow.Array{arr}, int64(tc.nrows)) defer rawRec.Release() meta := tc.rowType meta.Type = tc.logical ctx := context.Background() switch tc.arrowBatchesTimestampOption { case ia.UseOriginalTimestamp: ctx = ia.WithTimestampOption(ctx, ia.UseOriginalTimestamp) case ia.UseSecondTimestamp: ctx = ia.WithTimestampOption(ctx, ia.UseSecondTimestamp) case ia.UseMillisecondTimestamp: ctx = ia.WithTimestampOption(ctx, ia.UseMillisecondTimestamp) case ia.UseMicrosecondTimestamp: ctx = ia.WithTimestampOption(ctx, ia.UseMicrosecondTimestamp) default: ctx = ia.WithTimestampOption(ctx, ia.UseNanosecondTimestamp) } if tc.enableArrowBatchesUtf8Validation { ctx = ia.EnableUtf8Validation(ctx) } if tc.withHigherPrecision { ctx = ia.WithHigherPrecision(ctx) } transformedRec, err := arrowToRecord(ctx, rawRec, pool, []query.ExecResponseRowType{meta}, localTime.Location()) if err != nil { if tc.error == "" || !strings.Contains(err.Error(), tc.error) { t.Fatalf("error: %s", err) } } else { defer transformedRec.Release() if tc.error != "" { t.Fatalf("expected error: %s", tc.error) } if tc.compare != nil { idx := tc.compare(tc.values, tc.expected, transformedRec) if idx != -1 { t.Fatalf("error: column array value mismatch at index %v", idx) } } else { for i, c := range transformedRec.Columns() { rawCol := rawRec.Column(i) if rawCol != c { t.Fatalf("error: expected column %s, got column %s", 
rawCol, c)
}
}
}
}
})
}
}


================================================
FILE: arrowbatches/schema.go
================================================

package arrowbatches

import (
	"github.com/snowflakedb/gosnowflake/v2/internal/query"
	"github.com/snowflakedb/gosnowflake/v2/internal/types"
	"time"

	ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"

	"github.com/apache/arrow-go/v18/arrow"
)

// recordToSchema converts a record's server-side Arrow schema into the schema
// the converted (client-facing) record will use, rewriting each field's type
// according to its Snowflake row-type metadata, the requested timestamp unit,
// and the higher-precision setting. The original schema metadata is preserved.
func recordToSchema(sc *arrow.Schema, rowType []query.ExecResponseRowType, loc *time.Location, timestampOption ia.TimestampOption, withHigherPrecision bool) (*arrow.Schema, error) {
	fields := recordToSchemaRecursive(sc.Fields(), rowType, loc, timestampOption, withHigherPrecision)
	meta := sc.Metadata()
	return arrow.NewSchema(fields, &meta), nil
}

// recordToSchemaRecursive maps each input field to its converted form.
// rowType[i] is assumed to describe inFields[i] (parallel slices); when a
// field needs no conversion, the original arrow.Field is reused as-is.
func recordToSchemaRecursive(inFields []arrow.Field, rowType []query.ExecResponseRowType, loc *time.Location, timestampOption ia.TimestampOption, withHigherPrecision bool) []arrow.Field {
	var outFields []arrow.Field
	for i, f := range inFields {
		fieldMetadata := rowType[i].ToFieldMetadata()
		converted, t := recordToSchemaSingleField(fieldMetadata, f, withHigherPrecision, timestampOption, loc)
		newField := f
		if converted {
			// Rebuild the field with the converted data type, keeping
			// name, nullability and per-field metadata unchanged.
			newField = arrow.Field{
				Name:     f.Name,
				Type:     t,
				Nullable: f.Nullable,
				Metadata: f.Metadata,
			}
		}
		outFields = append(outFields, newField)
	}
	return outFields
}

// recordToSchemaSingleField computes the target Arrow data type for one field.
// It returns (true, newType) when the field's type must be rewritten, and
// (false, originalType) when the physical type is already what the caller
// should expose (e.g. higher-precision DECIMAL kept as-is, or original
// timestamps requested). Structs, lists and maps are converted recursively.
func recordToSchemaSingleField(fieldMetadata query.FieldMetadata, f arrow.Field, withHigherPrecision bool, timestampOption ia.TimestampOption, loc *time.Location) (bool, arrow.DataType) {
	t := f.Type
	converted := true
	switch types.GetSnowflakeType(fieldMetadata.Type) {
	case types.FixedType:
		switch f.Type.ID() {
		case arrow.DECIMAL:
			// DECIMAL columns collapse to int64 (scale 0) or float64 unless
			// the caller asked to keep the full decimal precision.
			if withHigherPrecision {
				converted = false
			} else if fieldMetadata.Scale == 0 {
				t = &arrow.Int64Type{}
			} else {
				t = &arrow.Float64Type{}
			}
		default:
			// Non-decimal fixed types only need conversion when a non-zero
			// scale means the raw integer encodes a fractional value.
			if withHigherPrecision {
				converted = false
			} else if fieldMetadata.Scale != 0 {
				t = &arrow.Float64Type{}
			} else {
				converted = false
			}
		}
	case types.TimeType:
		t = &arrow.Time64Type{Unit: arrow.Nanosecond}
	case types.TimestampNtzType, types.TimestampTzType:
		// NTZ/TZ timestamps carry no session time zone in the output type.
		switch timestampOption {
		case ia.UseOriginalTimestamp:
			converted = false
		case ia.UseMicrosecondTimestamp:
			t = &arrow.TimestampType{Unit: arrow.Microsecond}
		case ia.UseMillisecondTimestamp:
			t = &arrow.TimestampType{Unit: arrow.Millisecond}
		case ia.UseSecondTimestamp:
			t = &arrow.TimestampType{Unit: arrow.Second}
		default:
			t = &arrow.TimestampType{Unit: arrow.Nanosecond}
		}
	case types.TimestampLtzType:
		// LTZ timestamps are tagged with the provided location's name.
		switch timestampOption {
		case ia.UseOriginalTimestamp:
			converted = false
		case ia.UseMicrosecondTimestamp:
			t = &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: loc.String()}
		case ia.UseMillisecondTimestamp:
			t = &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: loc.String()}
		case ia.UseSecondTimestamp:
			t = &arrow.TimestampType{Unit: arrow.Second, TimeZone: loc.String()}
		default:
			t = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: loc.String()}
		}
	case types.ObjectType:
		converted = false
		if f.Type.ID() == arrow.STRUCT {
			// Recurse into each struct member; the struct as a whole counts
			// as converted if any member was converted.
			var internalFields []arrow.Field
			for idx, internalField := range f.Type.(*arrow.StructType).Fields() {
				internalConverted, convertedDataType := recordToSchemaSingleField(fieldMetadata.Fields[idx], internalField, withHigherPrecision, timestampOption, loc)
				converted = converted || internalConverted
				if internalConverted {
					newInternalField := arrow.Field{
						Name:     internalField.Name,
						Type:     convertedDataType,
						Metadata: internalField.Metadata,
						Nullable: internalField.Nullable,
					}
					internalFields = append(internalFields, newInternalField)
				} else {
					internalFields = append(internalFields, internalField)
				}
			}
			t = arrow.StructOf(internalFields...)
		}
	case types.ArrayType:
		if _, ok := f.Type.(*arrow.ListType); ok {
			// NOTE(review): this `:=` declares a new `converted` that shadows
			// the outer one, so the function still reports converted=true
			// even when the element type was unchanged; looks harmless (the
			// caller then rebuilds an identical field) but worth confirming.
			converted, dataType := recordToSchemaSingleField(fieldMetadata.Fields[0], f.Type.(*arrow.ListType).ElemField(), withHigherPrecision, timestampOption, loc)
			if converted {
				t = arrow.ListOf(dataType)
			}
		} else {
			t = f.Type
		}
	case types.MapType:
		// Convert key and value types independently; the map is rebuilt if
		// either side changed. NOTE(review): the *arrow.MapType assertions
		// panic if the physical type is not an Arrow map — presumably
		// guaranteed upstream; verify against callers.
		convertedKey, keyDataType := recordToSchemaSingleField(fieldMetadata.Fields[0], f.Type.(*arrow.MapType).KeyField(), withHigherPrecision, timestampOption, loc)
		convertedValue, valueDataType := recordToSchemaSingleField(fieldMetadata.Fields[1], f.Type.(*arrow.MapType).ItemField(), withHigherPrecision, timestampOption, loc)
		converted = convertedKey || convertedValue
		if converted {
			t = arrow.MapOf(keyDataType, valueDataType)
		}
	default:
		converted = false
	}
	return converted, t
}


================================================
FILE: assert_test.go
================================================

package gosnowflake

import (
	"bytes"
	"errors"
	"fmt"
	"math"
	"reflect"
	"regexp"
	"slices"
	"strings"
	"testing"
	"time"
)

// assert*E helpers report failures via t.Error (the test continues);
// assert*F helpers report via t.Fatal (the test stops). Each delegates the
// actual check to a validate* function that returns "" on success.

// assertNilE reports a non-fatal error when actual is not nil.
func assertNilE(t *testing.T, actual any, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateNil(actual, descriptions...))
}

// assertNilF fails the test immediately when actual is not nil.
func assertNilF(t *testing.T, actual any, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateNil(actual, descriptions...))
}

// assertNotNilE reports a non-fatal error when actual is nil.
func assertNotNilE(t *testing.T, actual any, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateNotNil(actual, descriptions...))
}

// assertNotNilF fails the test immediately when actual is nil.
func assertNotNilF(t *testing.T, actual any, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateNotNil(actual, descriptions...))
}

// assertErrIsF fails the test immediately unless errors.Is(actual, expected).
func assertErrIsF(t *testing.T, actual, expected error, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateErrIs(actual, expected, descriptions...))
}

// assertErrIsE reports a non-fatal error unless errors.Is(actual, expected).
func assertErrIsE(t *testing.T, actual, expected error, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateErrIs(actual, expected, descriptions...))
}

// assertErrorsAsF fails the test immediately unless errors.As(err, target).
func assertErrorsAsF(t *testing.T, err error, target any, descriptions ...string) {
t.Helper()
	fatalOnNonEmpty(t, validateErrorsAs(err, target, descriptions...))
}

// assertEqualE reports a non-fatal error when actual != expected (shallow ==).
func assertEqualE(t *testing.T, actual any, expected any, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEqual(actual, expected, descriptions...))
}

// assertEqualF fails the test immediately when actual != expected (shallow ==).
func assertEqualF(t *testing.T, actual any, expected any, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateEqual(actual, expected, descriptions...))
}

// assertEqualIgnoringWhitespaceE compares the two strings with all whitespace
// runs removed; reports a non-fatal error on mismatch.
func assertEqualIgnoringWhitespaceE(t *testing.T, actual string, expected string, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEqualIgnoringWhitespace(actual, expected, descriptions...))
}

// assertEqualEpsilonE reports a non-fatal error unless |actual-expected| < epsilon.
func assertEqualEpsilonE(t *testing.T, actual, expected, epsilon float64, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEqualEpsilon(actual, expected, epsilon, descriptions...))
}

// assertDeepEqualE reports a non-fatal error unless reflect.DeepEqual holds.
func assertDeepEqualE(t *testing.T, actual any, expected any, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateDeepEqual(actual, expected, descriptions...))
}

// assertNotEqualF fails the test immediately when actual == expected.
func assertNotEqualF(t *testing.T, actual any, expected any, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateNotEqual(actual, expected, descriptions...))
}

// assertNotEqualE reports a non-fatal error when actual == expected.
func assertNotEqualE(t *testing.T, actual any, expected any, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateNotEqual(actual, expected, descriptions...))
}

// assertBytesEqualE reports a non-fatal error unless the byte slices are equal.
func assertBytesEqualE(t *testing.T, actual []byte, expected []byte, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateBytesEqual(actual, expected, descriptions...))
}

// assertTrueF fails the test immediately when actual is false.
func assertTrueF(t *testing.T, actual bool, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateEqual(actual, true, descriptions...))
}

// assertTrueE reports a non-fatal error when actual is false.
func assertTrueE(t *testing.T, actual bool, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEqual(actual, true, descriptions...))
}

// assertFalseF fails the test immediately when actual is true.
func assertFalseF(t *testing.T, actual bool, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateEqual(actual, false, descriptions...))
}

// assertFalseE reports a non-fatal error when actual is true.
func assertFalseE(t *testing.T, actual bool, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEqual(actual, false, descriptions...))
}

// assertStringContainsE reports a non-fatal error unless actual contains the substring.
func assertStringContainsE(t *testing.T, actual string, expectedToContain string, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...))
}

// assertStringContainsF fails the test immediately unless actual contains the substring.
func assertStringContainsF(t *testing.T, actual string, expectedToContain string, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...))
}

// assertEmptyStringE reports a non-fatal error unless actual is "".
func assertEmptyStringE(t *testing.T, actual string, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEmptyString(actual, descriptions...))
}

// assertHasPrefixF fails the test immediately unless actual starts with expectedPrefix.
func assertHasPrefixF(t *testing.T, actual string, expectedPrefix string, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateHasPrefix(actual, expectedPrefix, descriptions...))
}

// assertHasPrefixE reports a non-fatal error unless actual starts with expectedPrefix.
func assertHasPrefixE(t *testing.T, actual string, expectedPrefix string, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateHasPrefix(actual, expectedPrefix, descriptions...))
}

// assertBetweenE checks min < value < max (exclusive); non-fatal on failure.
func assertBetweenE(t *testing.T, value float64, min float64, max float64, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateValueBetween(value, min, max, descriptions...))
}

// assertBetweenInclusiveE checks min <= value <= max; non-fatal on failure.
func assertBetweenInclusiveE(t *testing.T, value float64, min float64, max float64, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateValueBetweenInclusive(value, min, max, descriptions...))
}

// assertEmptyE reports a non-fatal error unless the slice has length 0.
func assertEmptyE[T any](t *testing.T, actual []T, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEmpty(actual, descriptions...))
}

// fatalOnNonEmpty stops the test via t.Fatal when errMsg is non-empty.
func fatalOnNonEmpty(t *testing.T, errMsg string) {
	if errMsg != "" {
		t.Helper()
		t.Fatal(formatErrorMessage(errMsg))
	}
}

// errorOnNonEmpty records the failure via t.Error when errMsg is non-empty.
func errorOnNonEmpty(t *testing.T, errMsg string) {
	if errMsg != "" {
		t.Helper()
		t.Error(formatErrorMessage(errMsg))
	}
}

// formatErrorMessage timestamps the message and masks any secrets in it.
func formatErrorMessage(errMsg string) string {
	return fmt.Sprintf("[%s] %s", time.Now().Format(time.RFC3339Nano), maskSecrets(errMsg))
}
// validate* helpers return "" when the expectation holds and a human-readable
// failure message (with secrets masked via maskSecrets) otherwise.

func validateNil(actual any, descriptions ...string) string {
	if isNil(actual) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to be nil but was not. %s", maskSecrets(fmt.Sprintf("%v", actual)), desc)
}

func validateNotNil(actual any, descriptions ...string) string {
	if !isNil(actual) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected to be not nil but was not. %s", desc)
}

// validateErrIs succeeds when errors.Is(actual, expected); nil errors are
// rendered as the literal string "nil" in the failure message.
func validateErrIs(actual, expected error, descriptions ...string) string {
	if errors.Is(actual, expected) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	actualStr := "nil"
	expectedStr := "nil"
	if actual != nil {
		actualStr = maskSecrets(actual.Error())
	}
	if expected != nil {
		expectedStr = maskSecrets(expected.Error())
	}
	return fmt.Sprintf("expected %v to be %v. %s", actualStr, expectedStr, desc)
}

// validateErrorsAs succeeds when errors.As(err, target); target must be a
// pointer to an error type, per the errors.As contract.
func validateErrorsAs(err error, target any, descriptions ...string) string {
	if errors.As(err, target) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	errStr := "nil"
	if err != nil {
		errStr = maskSecrets(err.Error())
	}
	targetType := reflect.TypeOf(target)
	return fmt.Sprintf("expected error %v to be assignable to %v but was not. %s", errStr, targetType, desc)
}

// validateEqual uses Go's shallow == comparison (interface equality).
func validateEqual(actual any, expected any, descriptions ...string) string {
	if expected == actual {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to be equal to \"%s\" but was not. %s", maskSecrets(fmt.Sprintf("%v", actual)), maskSecrets(fmt.Sprintf("%v", expected)), desc)
}

// removeWhitespaces strips every whitespace run from s.
// NOTE(review): the regexp is recompiled on every call; compiling it once at
// package scope would be cheaper, though for test helpers it hardly matters.
func removeWhitespaces(s string) string {
	pattern, err := regexp.Compile(`\s+`)
	if err != nil {
		panic(err)
	}
	return pattern.ReplaceAllString(s, "")
}

func validateEqualIgnoringWhitespace(actual string, expected string, descriptions ...string) string {
	if removeWhitespaces(expected) == removeWhitespaces(actual) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to be equal to \"%s\" but was not. %s", maskSecrets(actual), maskSecrets(expected), desc)
}

func validateEqualEpsilon(actual, expected, epsilon float64, descriptions ...string) string {
	if math.Abs(actual-expected) < epsilon {
		return ""
	}
	return fmt.Sprintf("expected \"%f\" to be equal to \"%f\" within epsilon \"%f\" but was not. %s", actual, expected, epsilon, joinDescriptions(descriptions...))
}

func validateDeepEqual(actual any, expected any, descriptions ...string) string {
	if reflect.DeepEqual(actual, expected) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to be equal to \"%s\" but was not. %s", maskSecrets(fmt.Sprintf("%v", actual)), maskSecrets(fmt.Sprintf("%v", expected)), desc)
}

func validateNotEqual(actual any, expected any, descriptions ...string) string {
	if expected != actual {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" not to be equal to \"%s\" but they were the same. %s", maskSecrets(fmt.Sprintf("%v", actual)), maskSecrets(fmt.Sprintf("%v", expected)), desc)
}

func validateBytesEqual(actual []byte, expected []byte, descriptions ...string) string {
	if bytes.Equal(actual, expected) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to be equal to \"%s\" but was not. %s", maskSecrets(string(actual)), maskSecrets(string(expected)), desc)
}

func validateStringContains(actual string, expectedToContain string, descriptions ...string) string {
	if strings.Contains(actual, expectedToContain) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to contain \"%s\" but did not. %s", maskSecrets(actual), maskSecrets(expectedToContain), desc)
}

func validateEmptyString(actual string, descriptions ...string) string {
	if actual == "" {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to be empty, but was not. %s", maskSecrets(actual), desc)
}

func validateHasPrefix(actual string, expectedPrefix string, descriptions ...string) string {
	if strings.HasPrefix(actual, expectedPrefix) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to start with \"%s\" but did not. %s", maskSecrets(actual), maskSecrets(expectedPrefix), desc)
}

// validateValueBetween checks min < value < max (strictly exclusive bounds).
func validateValueBetween(value float64, min float64, max float64, descriptions ...string) string {
	if value > min && value < max {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" should be between \"%s\" and \"%s\" but did not. %s", fmt.Sprintf("%f", value), fmt.Sprintf("%f", min), fmt.Sprintf("%f", max), desc)
}

// validateValueBetweenInclusive checks min <= value <= max.
func validateValueBetweenInclusive(value float64, min float64, max float64, descriptions ...string) string {
	if value >= min && value <= max {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" should be between \"%s\" and \"%s\" inclusively but did not. %s", fmt.Sprintf("%f", value), fmt.Sprintf("%f", min), fmt.Sprintf("%f", max), desc)
}

func validateEmpty[T any](value []T, descriptions ...string) string {
	if len(value) == 0 {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%v\" to be empty. %s", maskSecrets(fmt.Sprintf("%v", value)), desc)
}

func joinDescriptions(descriptions ...string) string {
	return strings.Join(descriptions, " ")
}

// isNil reports whether value is nil, including the typed-nil case: an
// interface holding a nil pointer/slice/map/interface/func compares non-nil
// with ==, so reflection is used to inspect the underlying value.
func isNil(value any) bool {
	if value == nil {
		return true
	}
	val := reflect.ValueOf(value)
	return slices.Contains([]reflect.Kind{reflect.Pointer, reflect.Slice, reflect.Map, reflect.Interface, reflect.Func}, val.Kind()) && val.IsNil()
}


================================================
FILE: async.go
================================================

package gosnowflake

import (
	"context"
	"fmt"
	"net/url"
	"strconv"
	"time"
)

// processAsync prepares a placeholder result (for exec) or rows (for query)
// object for an asynchronous statement, then spawns a goroutine that fetches
// the real result via getAsync. For any other result type the response is
// returned unchanged.
func (sr *snowflakeRestful) processAsync(
	ctx context.Context,
	respd *execResponse,
	headers map[string]string,
	timeout time.Duration,
	cfg *Config) (*execResponse, error) {
	// placeholder object to return to user while retrieving results
	rows := new(snowflakeRows)
	res := new(snowflakeResult)
	switch resType := getResultType(ctx); resType {
	case execResultType:
		res.queryID = respd.Data.QueryID
		res.status = QueryStatusInProgress
		res.errChannel = make(chan error)
		respd.Data.AsyncResult = res
	case queryResultType:
		rows.queryID = respd.Data.QueryID
		rows.status = QueryStatusInProgress
		rows.errChannel = make(chan error)
		rows.ctx = ctx
		respd.Data.AsyncRows = rows
	default:
		return respd, nil
	}
	// spawn goroutine to retrieve asynchronous results
	go GoroutineWrapper(
		ctx,
		func() {
			err := sr.getAsync(ctx, headers, sr.getFullURL(respd.Data.GetResultURL, nil), timeout, res, rows, cfg)
			if err != nil {
				logger.WithContext(ctx).Errorf("error while calling getAsync. %v", err)
			}
		},
	)
	return respd, nil
}

// getAsync polls the result URL for a previously submitted asynchronous
// statement and delivers the outcome through the placeholder's error channel:
// nil on success, a *SnowflakeError on failure. The channel is closed on
// return, which is what unblocks consumers waiting on the placeholder.
func (sr *snowflakeRestful) getAsync(
	ctx context.Context,
	headers map[string]string,
	URL *url.URL,
	timeout time.Duration,
	res *snowflakeResult,
	rows *snowflakeRows,
	cfg *Config) error {
	resType := getResultType(ctx)
	var errChannel chan error
	sfError := &SnowflakeError{
		Number: ErrAsync,
	}
	// Pick the channel/queryID from whichever placeholder is in play.
	if resType == execResultType {
		errChannel = res.errChannel
		sfError.QueryID = res.queryID
	} else {
		errChannel = rows.errChannel
		sfError.QueryID = rows.queryID
	}
	defer close(errChannel)
	// Refresh the auth header with the current session token before polling.
	token, _, _ := sr.TokenAccessor.GetTokens()
	headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
	respd, err := getQueryResultWithRetriesForAsyncMode(ctx, sr, URL, headers, timeout)
	if err != nil {
		logger.WithContext(ctx).Errorf("error: %v", err)
		sfError.Message = err.Error()
		errChannel <- sfError
		return err
	}
	sc := &snowflakeConn{rest: sr, cfg: cfg, currentTimeProvider: defaultTimeProvider}
	if respd.Success {
		if resType == execResultType {
			res.insertID = -1
			if isDml(respd.Data.StatementTypeID) {
				res.affectedRows, err = updateRows(respd.Data)
				if err != nil {
					return err
				}
			} else if isMultiStmt(&respd.Data) {
				r, err := sc.handleMultiExec(ctx, respd.Data)
				if err != nil {
					res.errChannel <- err
					return err
				}
				res.affectedRows, err = r.RowsAffected()
				if err != nil {
					res.errChannel <- err
					return err
				}
			}
			res.queryID = respd.Data.QueryID
			res.errChannel <- nil // mark exec status complete
		} else {
			rows.sc = sc
			rows.queryID = respd.Data.QueryID
			if isMultiStmt(&respd.Data) {
				if err = sc.handleMultiQuery(ctx, respd.Data, rows); err != nil {
					rows.errChannel <- err
					return err
				}
			} else {
				rows.addDownloader(populateChunkDownloader(ctx, sc, respd.Data))
			}
			if err = rows.ChunkDownloader.start(); err != nil {
				rows.errChannel <- err
				return err
			}
			rows.errChannel <- nil // mark query status complete
		}
	} else {
		// Server reported failure: translate the response code into a
		// SnowflakeError (-1 when the code is missing or unparseable).
		var code int
		if respd.Code != "" {
			code, err = strconv.Atoi(respd.Code)
			if err != nil {
				code = -1
			}
		} else {
			code = -1
		}
		errChannel <- &SnowflakeError{
			Number:   code,
			SQLState: respd.Data.SQLState,
			Message:  respd.Message,
			QueryID:  respd.Data.QueryID,
		}
	}
	return nil
}

// getQueryResultWithRetriesForAsyncMode repeatedly hits the /results endpoint
// until the query is no longer in progress, renewing an expired session token
// (at most twice) and backing off between in-progress polls.
func getQueryResultWithRetriesForAsyncMode(
	ctx context.Context,
	sr *snowflakeRestful,
	URL *url.URL,
	headers map[string]string,
	timeout time.Duration) (respd *execResponse, err error) {
	retry := 0
	// Multipliers of the 500ms base sleep: caps the backoff at 5 seconds.
	retryPattern := []int32{1, 1, 2, 3, 4, 8, 10}
	retryPatternIndex := 0
	retryCountForSessionRenewal := 0
	for {
		logger.WithContext(ctx).Debugf("Retry count for get query result request in async mode: %v", retry)
		respd, err = getExecResponse(ctx, sr, URL, headers, timeout)
		if err != nil {
			return respd, err
		}
		if respd.Code == sessionExpiredCode {
			// Update the session token in the header and retry
			token, _, _ := sr.TokenAccessor.GetTokens()
			if token != "" && headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, token) {
				headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
				logger.WithContext(ctx).Debug("Session token has been updated.")
				retry++
				continue
			}
			// Renew the session token
			if err = sr.renewExpiredSessionToken(ctx, timeout, token); err != nil {
				logger.WithContext(ctx).Errorf("failed to renew session token. err: %v", err)
				return respd, err
			}
			retryCountForSessionRenewal++
			// If this is the first response, go back to retry the query
			// since it failed due to session expiration
			logger.WithContext(ctx).Debugf("retry count for session renewal: %v", retryCountForSessionRenewal)
			if retryCountForSessionRenewal < 2 {
				retry++
				continue
			} else {
				logger.WithContext(ctx).Errorf("failed to get query result with the renewed session token. err: %v", err)
				return respd, err
			}
		} else if respd.Code != queryInProgressAsyncCode {
			// If the query takes longer than 45 seconds to complete the results are not returned.
			// If the query is still in progress after 45 seconds, retry the request to the /results endpoint.
			// For all other scenarios continue processing results response
			break
		} else {
			// Sleep before retrying get result request.
Exponential backoff up to 5 seconds. // Once 5 second backoff is reached it will keep retrying with this sleeptime. sleepTime := time.Millisecond * time.Duration(500*retryPattern[retryPatternIndex]) logger.WithContext(ctx).Debugf("Query execution still in progress. Response code: %v, message: %v Sleep for %v ms", respd.Code, respd.Message, sleepTime) time.Sleep(sleepTime) retry++ if retryPatternIndex < len(retryPattern)-1 { retryPatternIndex++ } } } if len(respd.Data.RowType) > 0 { logger.Infof("[Server Response Validation]: RowType: %s, QueryResultFormat: %s", respd.Data.RowType[0].Name, respd.Data.QueryResultFormat) } return respd, nil } ================================================ FILE: async_test.go ================================================ package gosnowflake import ( "context" "database/sql" "fmt" "testing" ) func TestAsyncMode(t *testing.T) { ctx := WithAsyncMode(context.Background()) numrows := 100000 cnt := 0 var idx int var v string runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows)) defer rows.Close() // Next() will block and wait until results are available for rows.Next() { if err := rows.Scan(&idx, &v); err != nil { t.Fatal(err) } cnt++ } logger.Infof("NextResultSet: %v", rows.NextResultSet()) if cnt != numrows { t.Errorf("number of rows didn't match. 
expected: %v, got: %v", numrows, cnt) } dbt.mustExec("create or replace table test_async_exec (value boolean)") res := dbt.mustExecContext(ctx, "insert into test_async_exec values (true)") count, err := res.RowsAffected() if err != nil { t.Fatalf("res.RowsAffected() returned error: %v", err) } if count != 1 { t.Fatalf("expected 1 affected row, got %d", count) } }) } func TestAsyncModePing(t *testing.T) { ctx := WithAsyncMode(context.Background()) runDBTest(t, func(dbt *DBTest) { defer func() { if r := recover(); r != nil { t.Fatalf("panic during ping: %v", r) } }() err := dbt.conn.PingContext(ctx) if err != nil { t.Fatal(err) } }) } func TestAsyncModeMultiStatement(t *testing.T) { withMultiStmtCtx := WithMultiStatement(context.Background(), 6) ctx := WithAsyncMode(withMultiStmtCtx) multiStmtQuery := "begin;\n" + "delete from test_multi_statement_async;\n" + "insert into test_multi_statement_async values (1, 'a'), (2, 'b');\n" + "select 1;\n" + "select 2;\n" + "rollback;" runDBTest(t, func(dbt *DBTest) { dbt.mustExec("drop table if exists test_multi_statement_async") dbt.mustExec(`create or replace table test_multi_statement_async( c1 number, c2 string) as select 10, 'z'`) defer dbt.mustExec("drop table if exists test_multi_statement_async") res := dbt.mustExecContext(ctx, multiStmtQuery) count, err := res.RowsAffected() if err != nil { t.Fatalf("res.RowsAffected() returned error: %v", err) } if count != 3 { t.Fatalf("expected 3 affected rows, got %d", count) } }) } func TestAsyncModeCancel(t *testing.T) { withCancelCtx, cancel := context.WithCancel(context.Background()) ctx := WithAsyncMode(withCancelCtx) numrows := 100000 runDBTest(t, func(dbt *DBTest) { dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows)) cancel() }) } func TestAsyncQueryFail(t *testing.T) { ctx := WithAsyncMode(context.Background()) runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(ctx, "selectt 1") defer rows.Close() if rows.Next() { t.Fatal("should have no rows 
available") } else { if err := rows.Err(); err == nil { t.Fatal("should return a syntax error") } } }) } // TestMultipleAsyncQueries validates that shorter async queries return before // longer ones. The TIMELIMIT values (30 and 10) must have sufficient separation // to avoid flaky ordering. Do not reduce these values significantly. func TestMultipleAsyncQueries(t *testing.T) { ctx := WithAsyncMode(context.Background()) s1 := "foo" s2 := "bar" ch1 := make(chan string) ch2 := make(chan string) db := openDB(t) runDBTest(t, func(dbt *DBTest) { rows1, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s1, 30)) if err != nil { t.Fatalf("can't read rows1: %v", err) } defer rows1.Close() rows2, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s2, 10)) if err != nil { t.Fatalf("can't read rows2: %v", err) } defer rows2.Close() go retrieveRows(rows1, ch1) go retrieveRows(rows2, ch2) select { case res := <-ch1: t.Fatalf("value %v should not have been called earlier.", res) case res := <-ch2: if res != s2 { t.Fatalf("query failed. expected: %v, got: %v", s2, res) } } }) } func retrieveRows(rows *sql.Rows, ch chan string) { var s string for rows.Next() { if err := rows.Scan(&s); err != nil { ch <- err.Error() close(ch) return } } ch <- s close(ch) } // TestLongRunningAsyncQuery validates the retry logic for async queries that // exceed Snowflake's 45-second threshold. After 45 seconds, the /results // endpoint returns "query in progress" (code 333334) and the driver must retry. // The 50-second wait MUST exceed 45 seconds to exercise this code path. 
func TestLongRunningAsyncQuery(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		ctx := WithMultiStatement(context.Background(), 0)
		query := "CALL SYSTEM$WAIT(50, 'SECONDS');use snowflake_sample_data"
		rows := dbt.mustQueryContext(WithAsyncMode(ctx), query)
		defer rows.Close()
		var v string
		i := 0
		// Walk every result set of the multi-statement batch, checking each
		// single-row result against the expected message in order.
		for {
			for rows.Next() {
				err := rows.Scan(&v)
				if err != nil {
					t.Fatalf("failed to get result. err: %v", err)
				}
				if v == "" {
					t.Fatal("should have returned a result")
				}
				results := []string{"waited 50 seconds", "Statement executed successfully."}
				if v != results[i] {
					t.Fatalf("unexpected result returned. expected: %v, but got: %v", results[i], v)
				}
				i++
			}
			if !rows.NextResultSet() {
				break
			}
		}
	})
}

// TestLongRunningAsyncQueryFetchResultByID obtains the query ID of a running
// async statement via WithQueryIDChan (without waiting for completion) and
// then fetches its result through WithFetchResultByID.
func TestLongRunningAsyncQueryFetchResultByID(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		queryIDChan := make(chan string, 1)
		ctx := WithAsyncMode(context.Background())
		ctx = WithQueryIDChan(ctx, queryIDChan)
		// Run a long running query asynchronously
		go dbt.mustExecContext(ctx, "CALL SYSTEM$WAIT(50, 'SECONDS')")
		// Get the query ID without waiting for the query to finish
		queryID := <-queryIDChan
		assertNotNilF(t, queryID, "expected a nonempty query ID")
		ctx = WithFetchResultByID(ctx, queryID)
		rows := dbt.mustQueryContext(ctx, "")
		defer rows.Close()
		var v string
		assertTrueF(t, rows.Next())
		err := rows.Scan(&v)
		assertNilF(t, err, fmt.Sprintf("failed to get result. err: %v", err))
		assertNotNilF(t, v, "should have returned a result")
		expected := "waited 50 seconds"
		if v != expected {
			t.Fatalf("unexpected result returned. expected: %v, but got: %v", expected, v)
		}
		assertFalseF(t, rows.NextResultSet())
	})
}

================================================
FILE: auth.go
================================================
package gosnowflake

import (
	"context"
	"crypto/sha256"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"runtime"
	"slices"
	"strconv"
	"strings"
	"time"

	sferrors "github.com/snowflakedb/gosnowflake/v2/internal/errors"

	"github.com/golang-jwt/jwt/v5"

	"github.com/snowflakedb/gosnowflake/v2/internal/compilation"
	sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config"
	internalos "github.com/snowflakedb/gosnowflake/v2/internal/os"
)

const (
	clientType = "Go"
)

// Session-parameter and authenticator names exchanged with the server.
const (
	clientStoreTemporaryCredential = "CLIENT_STORE_TEMPORARY_CREDENTIAL"
	clientRequestMfaToken          = "CLIENT_REQUEST_MFA_TOKEN"
	idTokenAuthenticator           = "ID_TOKEN"
)

// AuthType indicates the type of authentication in Snowflake
type AuthType = sfconfig.AuthType

const (
	// AuthTypeSnowflake is the general username password authentication
	AuthTypeSnowflake = sfconfig.AuthTypeSnowflake
	// AuthTypeOAuth is the OAuth authentication
	AuthTypeOAuth = sfconfig.AuthTypeOAuth
	// AuthTypeExternalBrowser is to use a browser to access an Fed and perform SSO authentication
	AuthTypeExternalBrowser = sfconfig.AuthTypeExternalBrowser
	// AuthTypeOkta is to use a native okta URL to perform SSO authentication on Okta
	AuthTypeOkta = sfconfig.AuthTypeOkta
	// AuthTypeJwt is to use Jwt to perform authentication
	AuthTypeJwt = sfconfig.AuthTypeJwt
	// AuthTypeTokenAccessor is to use the provided token accessor and bypass authentication
	AuthTypeTokenAccessor = sfconfig.AuthTypeTokenAccessor
	// AuthTypeUsernamePasswordMFA is to use username and password with mfa
	AuthTypeUsernamePasswordMFA = sfconfig.AuthTypeUsernamePasswordMFA
	// AuthTypePat is to use programmatic access token
	AuthTypePat = sfconfig.AuthTypePat
	// AuthTypeOAuthAuthorizationCode is to use browser-based OAuth2 flow
	AuthTypeOAuthAuthorizationCode = sfconfig.AuthTypeOAuthAuthorizationCode
	// AuthTypeOAuthClientCredentials is to use non-interactive OAuth2 flow
	AuthTypeOAuthClientCredentials = sfconfig.AuthTypeOAuthClientCredentials
	// AuthTypeWorkloadIdentityFederation is to use CSP identity for authentication
	AuthTypeWorkloadIdentityFederation = sfconfig.AuthTypeWorkloadIdentityFederation
)

// isOauthNativeFlow reports whether authType is one of the two native OAuth2
// flows (authorization code or client credentials).
func isOauthNativeFlow(authType AuthType) bool {
	return authType == AuthTypeOAuthAuthorizationCode || authType == AuthTypeOAuthClientCredentials
}

// Server error codes that indicate the cached OAuth access token should be
// discarded and refreshed.
var refreshOAuthTokenErrorCodes = []string{
	strconv.Itoa(ErrMissingAccessATokenButRefreshTokenPresent),
	invalidOAuthAccessTokenCode,
	expiredOAuthAccessTokenCode,
}

// userAgent shows up in User-Agent HTTP header
var userAgent = fmt.Sprintf("%v/%v (%v-%v) %v/%v",
	clientType,
	SnowflakeGoDriverVersion,
	runtime.GOOS,
	runtime.GOARCH,
	runtime.Compiler,
	runtime.Version())

// authRequestClientEnvironment describes the client environment reported to
// the server in the login request (JSON field names are server-defined).
type authRequestClientEnvironment struct {
	Application             string            `json:"APPLICATION"`
	ApplicationPath         string            `json:"APPLICATION_PATH"`
	Os                      string            `json:"OS"`
	OsVersion               string            `json:"OS_VERSION"`
	OsDetails               map[string]string `json:"OS_DETAILS,omitempty"`
	Isa                     string            `json:"ISA,omitempty"`
	OCSPMode                string            `json:"OCSP_MODE"`
	GoVersion               string            `json:"GO_VERSION"`
	OAuthType               string            `json:"OAUTH_TYPE,omitempty"`
	CertRevocationCheckMode string            `json:"CERT_REVOCATION_CHECK_MODE,omitempty"`
	Platform                []string          `json:"PLATFORM,omitempty"`
	CoreVersion             string            `json:"CORE_VERSION,omitempty"`
	CoreLoadError           string            `json:"CORE_LOAD_ERROR,omitempty"`
	CoreFileName            string            `json:"CORE_FILE_NAME,omitempty"`
	CgoEnabled              bool              `json:"CGO_ENABLED,omitempty"`
	LinkingMode             string            `json:"LINKING_MODE,omitempty"`
	LibcFamily              string            `json:"LIBC_FAMILY,omitempty"`
	LibcVersion             string            `json:"LIBC_VERSION,omitempty"`
}

// authRequestData is the payload of the login request; which fields are set
// depends on the authenticator (see createRequestBody).
type authRequestData struct {
	ClientAppID             string                       `json:"CLIENT_APP_ID"`
	ClientAppVersion        string                       `json:"CLIENT_APP_VERSION"`
	SvnRevision             string                       `json:"SVN_REVISION"`
	AccountName             string                       `json:"ACCOUNT_NAME"`
	LoginName               string                       `json:"LOGIN_NAME,omitempty"`
	Password                string                       `json:"PASSWORD,omitempty"`
	RawSAMLResponse         string                       `json:"RAW_SAML_RESPONSE,omitempty"`
	ExtAuthnDuoMethod       string                       `json:"EXT_AUTHN_DUO_METHOD,omitempty"`
	Passcode                string                       `json:"PASSCODE,omitempty"`
	Authenticator           string                       `json:"AUTHENTICATOR,omitempty"`
	SessionParameters       map[string]any               `json:"SESSION_PARAMETERS,omitempty"`
	ClientEnvironment       authRequestClientEnvironment `json:"CLIENT_ENVIRONMENT"`
	BrowserModeRedirectPort string                       `json:"BROWSER_MODE_REDIRECT_PORT,omitempty"`
	ProofKey                string                       `json:"PROOF_KEY,omitempty"`
	Token                   string                       `json:"TOKEN,omitempty"`
	Provider                string                       `json:"PROVIDER,omitempty"`
}

// authRequest wraps the login payload in the server-expected envelope.
type authRequest struct {
	Data authRequestData `json:"data"`
}

// nameValueParameter is a generic name/value pair returned by the server.
type nameValueParameter struct {
	Name  string `json:"name"`
	Value any    `json:"value"`
}

// authResponseSessionInfo carries the session context chosen by the server.
type authResponseSessionInfo struct {
	DatabaseName  string `json:"databaseName"`
	SchemaName    string `json:"schemaName"`
	WarehouseName string `json:"warehouseName"`
	RoleName      string `json:"roleName"`
}

// authResponseMain is the data section of a successful login response.
type authResponseMain struct {
	Token               string                  `json:"token,omitempty"`
	Validity            time.Duration           `json:"validityInSeconds,omitempty"`
	MasterToken         string                  `json:"masterToken,omitempty"`
	MasterValidity      time.Duration           `json:"masterValidityInSeconds"`
	MfaToken            string                  `json:"mfaToken,omitempty"`
	MfaTokenValidity    time.Duration           `json:"mfaTokenValidityInSeconds"`
	IDToken             string                  `json:"idToken,omitempty"`
	IDTokenValidity     time.Duration           `json:"idTokenValidityInSeconds"`
	DisplayUserName     string                  `json:"displayUserName"`
	ServerVersion       string                  `json:"serverVersion"`
	FirstLogin          bool                    `json:"firstLogin"`
	RemMeToken          string                  `json:"remMeToken"`
	RemMeValidity       time.Duration           `json:"remMeValidityInSeconds"`
	HealthCheckInterval time.Duration           `json:"healthCheckInterval"`
	NewClientForUpgrade string                  `json:"newClientForUpgrade"`
	SessionID           int64                   `json:"sessionId"`
	Parameters          []nameValueParameter    `json:"parameters"`
	SessionInfo         authResponseSessionInfo `json:"sessionInfo"`
	TokenURL            string                  `json:"tokenUrl,omitempty"`
	SSOURL              string                  `json:"ssoUrl,omitempty"`
	ProofKey            string                  `json:"proofKey,omitempty"`
}

// authResponse is the full login response envelope.
type authResponse struct {
	Data    authResponseMain `json:"data"`
	Message string           `json:"message"`
	Code    string           `json:"code"`
	Success bool             `json:"success"`
}

// postAuth sends the login request built by bodyCreator to the login endpoint
// and decodes the response. Known HTTP failure statuses are mapped to specific
// SnowflakeError codes; anything else becomes ErrFailedToAuth.
func postAuth(
	ctx context.Context,
	sr *snowflakeRestful,
	client *http.Client,
	params *url.Values,
	headers map[string]string,
	bodyCreator bodyCreatorType,
	timeout time.Duration) (
	data *authResponse, err error) {
	params.Set(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
	params.Set(requestGUIDKey, NewUUID().String())
	fullURL := sr.getFullURL(loginRequestPath, params)
	logger.WithContext(ctx).Infof("full URL: %v", fullURL)
	resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, sr.MaxRetryCount)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := resp.Body.Close(); closeErr != nil {
			logger.WithContext(ctx).Errorf("failed to close HTTP response body for %v. err: %v", fullURL, closeErr)
		}
	}()
	if resp.StatusCode == http.StatusOK {
		var respd authResponse
		err = json.NewDecoder(resp.Body).Decode(&respd)
		if err != nil {
			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
			return nil, err
		}
		return &respd, nil
	}
	switch resp.StatusCode {
	case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
		// service availability or connectivity issue. Most likely server side issue.
		return nil, &SnowflakeError{
			Number:      ErrCodeServiceUnavailable,
			SQLState:    SQLStateConnectionWasNotEstablished,
			Message:     sferrors.ErrMsgServiceUnavailable,
			MessageArgs: []any{resp.StatusCode, fullURL},
		}
	case http.StatusUnauthorized, http.StatusForbidden:
		// failed to connect to db. account name may be wrong
		return nil, &SnowflakeError{
			Number:      ErrCodeFailedToConnect,
			SQLState:    SQLStateConnectionRejected,
			Message:     sferrors.ErrMsgFailedToConnect,
			MessageArgs: []any{resp.StatusCode, fullURL},
		}
	}
	b, err := io.ReadAll(resp.Body)
	if err != nil {
		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
		return nil, err
	}
	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
	return nil, &SnowflakeError{
		Number:      ErrFailedToAuth,
		SQLState:    SQLStateConnectionRejected,
		Message:     sferrors.ErrMsgFailedToAuth,
		MessageArgs: []any{resp.StatusCode, fullURL},
	}
}

// Generates a map of headers needed to authenticate
// with Snowflake.
func getHeaders() map[string]string {
	headers := make(map[string]string)
	headers[httpHeaderContentType] = headerContentTypeApplicationJSON
	headers[httpHeaderAccept] = headerAcceptTypeApplicationSnowflake
	headers[httpClientAppID] = clientType
	headers[httpClientAppVersion] = SnowflakeGoDriverVersion
	headers[httpHeaderUserAgent] = userAgent
	return headers
}

// Used to authenticate the user with Snowflake.
// authenticate performs the login request for the configured authenticator.
// For AuthTypeTokenAccessor it short-circuits and reuses existing tokens.
// On success it stores the returned session tokens on the rest client and
// persists MFA/ID tokens when the corresponding session parameters are set;
// on failure it clears tokens and deletes any cached credentials.
func authenticate(
	ctx context.Context,
	sc *snowflakeConn,
	samlResponse []byte,
	proofKey []byte,
) (resp *authResponseMain, err error) {
	if sc.cfg.Authenticator == AuthTypeTokenAccessor {
		logger.WithContext(ctx).Info("Bypass authentication using existing token from token accessor")
		sessionInfo := authResponseSessionInfo{
			DatabaseName:  sc.cfg.Database,
			SchemaName:    sc.cfg.Schema,
			WarehouseName: sc.cfg.Warehouse,
			RoleName:      sc.cfg.Role,
		}
		token, masterToken, sessionID := sc.cfg.TokenAccessor.GetTokens()
		return &authResponseMain{
			Token:       token,
			MasterToken: masterToken,
			SessionID:   sessionID,
			SessionInfo: sessionInfo,
		}, nil
	}
	headers := getHeaders()
	// Get the current application path
	applicationPath, err := os.Executable()
	if err != nil {
		logger.WithContext(ctx).Warnf("Failed to get executable path: %v", err)
		applicationPath = "unknown"
	}
	oauthType := ""
	switch sc.cfg.Authenticator {
	case AuthTypeOAuthAuthorizationCode:
		oauthType = "OAUTH_AUTHORIZATION_CODE"
	case AuthTypeOAuthClientCredentials:
		oauthType = "OAUTH_CLIENT_CREDENTIALS"
	}
	clientEnvironment := newAuthRequestClientEnvironment()
	clientEnvironment.Application = sc.cfg.Application
	clientEnvironment.ApplicationPath = applicationPath
	clientEnvironment.OAuthType = oauthType
	clientEnvironment.CertRevocationCheckMode = sc.cfg.CertRevocationCheckMode.String()
	clientEnvironment.Platform = getDetectedPlatforms()
	sessionParameters := make(map[string]any)
	for k, v := range sc.syncParams.All() {
		// upper casing to normalize keys
		sessionParameters[strings.ToUpper(k)] = v
	}
	sessionParameters[sessionClientValidateDefaultParameters] = sc.cfg.ValidateDefaultParameters != ConfigBoolFalse
	if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue {
		sessionParameters[clientRequestMfaToken] = true
	}
	if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
		sessionParameters[clientStoreTemporaryCredential] = true
	}
	// Body is created lazily so interactive flows (browser/OAuth) run inside
	// the posting function's retry scope.
	bodyCreator := func() ([]byte, error) {
		return createRequestBody(sc, sessionParameters, clientEnvironment, proofKey, samlResponse)
	}
	params := &url.Values{}
	if sc.cfg.Database != "" {
		params.Add("databaseName", sc.cfg.Database)
	}
	if sc.cfg.Schema != "" {
		params.Add("schemaName", sc.cfg.Schema)
	}
	if sc.cfg.Warehouse != "" {
		params.Add("warehouse", sc.cfg.Warehouse)
	}
	if sc.cfg.Role != "" {
		params.Add("roleName", sc.cfg.Role)
	}
	logger.WithContext(ctx).Infof("Information for Auth: Host: %v, User: %v, Authenticator: %v, Params: %v, Protocol: %v, Port: %v, LoginTimeout: %v",
		sc.rest.Host, sc.cfg.User, sc.cfg.Authenticator.String(), params, sc.rest.Protocol, sc.rest.Port, sc.rest.LoginTimeout)
	respd, err := sc.rest.FuncPostAuth(ctx, sc.rest, sc.rest.getClientFor(sc.cfg.Authenticator), params, headers, bodyCreator, sc.rest.LoginTimeout)
	if err != nil {
		return nil, err
	}
	if !respd.Success {
		logger.WithContext(ctx).Error("Authentication FAILED")
		sc.rest.TokenAccessor.SetTokens("", "", -1)
		// Drop cached credentials matching the attempted flow so stale tokens
		// are not reused on the next attempt.
		if sessionParameters[clientRequestMfaToken] == true {
			credentialsStorage.deleteCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
		}
		if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser {
			credentialsStorage.deleteCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
		}
		if sessionParameters[clientStoreTemporaryCredential] == true && isOauthNativeFlow(sc.cfg.Authenticator) {
			credentialsStorage.deleteCredential(newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
		}
		code, err := strconv.Atoi(respd.Code)
		if err != nil {
			return nil, err
		}
		return nil, exceptionTelemetry(&SnowflakeError{
			Number:   code,
			SQLState: SQLStateConnectionRejected,
			Message:  respd.Message,
		}, sc)
	}
	logger.WithContext(ctx).Info("Authentication SUCCESS")
	sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID)
	if sessionParameters[clientRequestMfaToken] == true {
		token := respd.Data.MfaToken
		credentialsStorage.setCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token)
	}
	if sessionParameters[clientStoreTemporaryCredential] == true {
		token := respd.Data.IDToken
		credentialsStorage.setCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token)
	}
	return &respd.Data, nil
}

// newAuthRequestClientEnvironment builds the client-environment telemetry for
// the login request, including minicore version/load status and libc info.
func newAuthRequestClientEnvironment() authRequestClientEnvironment {
	var coreVersion string
	var coreLoadError string
	// Try to get minicore version, but don't block if it's not loaded yet
	if !compilation.MinicoreEnabled {
		logger.Trace("minicore disabled at compile time")
		coreLoadError = "Minicore is disabled at compile time (built with -tags minicore_disabled)"
	} else if strings.EqualFold(os.Getenv(disableMinicoreEnv), "true") {
		logger.Trace("minicore loading disabled")
		coreLoadError = "Minicore is disabled with SF_DISABLE_MINICORE env variable"
	} else if mc := getMiniCore(); mc != nil {
		var err error
		coreVersion, err = mc.FullVersion()
		if err != nil {
			logger.Debugf("Minicore loading failed. %v", err)
			var mcErr *miniCoreError
			if errors.As(err, &mcErr) {
				coreLoadError = fmt.Sprintf("Failed to load binary: %v", mcErr.errorType)
			} else {
				coreLoadError = "Failed to load binary: unknown"
			}
		}
	} else {
		// Minicore not loaded yet - this is expected during startup
		coreVersion = ""
		coreLoadError = "Minicore is still loading"
		logger.Debugf("Minicore not yet loaded for client environment telemetry")
	}
	libcInfo := internalos.GetLibcInfo()
	linkingMode, err := compilation.CheckDynamicLinking()
	if err != nil {
		logger.Debugf("cannot determine if app is dynamically linked: %v", err)
	}
	return authRequestClientEnvironment{
		Os:            runtime.GOOS,
		OsVersion:     osVersion,
		OsDetails:     internalos.GetOsDetails(),
		Isa:           runtime.GOARCH,
		GoVersion:     runtime.Version(),
		CoreVersion:   coreVersion,
		CoreFileName:  getMiniCoreFileName(),
		CoreLoadError: coreLoadError,
		CgoEnabled:    compilation.CgoEnabled,
		LinkingMode:   linkingMode.String(),
		LibcFamily:    libcInfo.Family,
		LibcVersion:   libcInfo.Version,
	}
}

// createRequestBody marshals the authenticator-specific login payload.
// Interactive/token-fetching flows (Okta SAML, JWT signing, OAuth flows, WIF
// attestation) are executed here so the token lands in the request body.
func createRequestBody(sc *snowflakeConn, sessionParameters map[string]any,
	clientEnvironment authRequestClientEnvironment,
	proofKey []byte,
	samlResponse []byte,
) ([]byte, error) {
	requestMain := authRequestData{
		ClientAppID:       clientType,
		ClientAppVersion:  SnowflakeGoDriverVersion,
		AccountName:       sc.cfg.Account,
		SessionParameters: sessionParameters,
		ClientEnvironment: clientEnvironment,
	}
	switch sc.cfg.Authenticator {
	case AuthTypeExternalBrowser:
		// Prefer a cached ID token over a fresh browser SSO round trip.
		if sc.idToken != "" {
			requestMain.Authenticator = idTokenAuthenticator
			requestMain.Token = sc.idToken
			requestMain.LoginName = sc.cfg.User
		} else {
			requestMain.ProofKey = string(proofKey)
			requestMain.Token = string(samlResponse)
			requestMain.LoginName = sc.cfg.User
			requestMain.Authenticator = AuthTypeExternalBrowser.String()
		}
	case AuthTypeOAuth:
		requestMain.LoginName = sc.cfg.User
		requestMain.Authenticator = AuthTypeOAuth.String()
		var err error
		if requestMain.Token, err = sfconfig.GetToken(sc.cfg); err != nil {
			return nil, fmt.Errorf("failed to get OAuth token: %w", err)
		}
	case AuthTypeOkta:
		samlResponse, err := authenticateBySAML(
			sc.ctx,
			sc.rest,
			sc.cfg.OktaURL,
			sc.cfg.Application,
			sc.cfg.Account,
			sc.cfg.User,
			sc.cfg.Password,
			sc.cfg.DisableSamlURLCheck)
		if err != nil {
			return nil, err
		}
		requestMain.RawSAMLResponse = string(samlResponse)
	case AuthTypeJwt:
		requestMain.Authenticator = AuthTypeJwt.String()
		jwtTokenString, err := prepareJWTToken(sc.cfg)
		if err != nil {
			return nil, err
		}
		requestMain.Token = jwtTokenString
	case AuthTypePat:
		logger.WithContext(sc.ctx).Info("Programmatic access token")
		requestMain.Authenticator = AuthTypePat.String()
		requestMain.LoginName = sc.cfg.User
		var err error
		if requestMain.Token, err = sfconfig.GetToken(sc.cfg); err != nil {
			return nil, fmt.Errorf("failed to get PAT token: %w", err)
		}
	case AuthTypeSnowflake:
		logger.WithContext(sc.ctx).Debug("Username and password")
		requestMain.LoginName = sc.cfg.User
		requestMain.Password = sc.cfg.Password
		switch {
		case sc.cfg.PasscodeInPassword:
			requestMain.ExtAuthnDuoMethod = "passcode"
		case sc.cfg.Passcode != "":
			requestMain.Passcode = sc.cfg.Passcode
			requestMain.ExtAuthnDuoMethod = "passcode"
		}
	case AuthTypeUsernamePasswordMFA:
		logger.WithContext(sc.ctx).Debug("Username and password MFA")
		requestMain.LoginName = sc.cfg.User
		requestMain.Password = sc.cfg.Password
		switch {
		case sc.mfaToken != "":
			requestMain.Token = sc.mfaToken
		case sc.cfg.PasscodeInPassword:
			requestMain.ExtAuthnDuoMethod = "passcode"
		case sc.cfg.Passcode != "":
			requestMain.Passcode = sc.cfg.Passcode
			requestMain.ExtAuthnDuoMethod = "passcode"
		}
	case AuthTypeOAuthAuthorizationCode:
		logger.WithContext(sc.ctx).Debug("OAuth authorization code")
		token, err := authenticateByAuthorizationCode(sc)
		if err != nil {
			return nil, err
		}
		requestMain.LoginName = sc.cfg.User
		requestMain.Token = token
	case AuthTypeOAuthClientCredentials:
		logger.WithContext(sc.ctx).Debug("OAuth client credentials")
		oauthClient, err := newOauthClient(sc.ctx, sc.cfg, sc)
		if err != nil {
			return nil, err
		}
		token, err := oauthClient.authenticateByOAuthClientCredentials()
		if err != nil {
			return nil, err
		}
		requestMain.LoginName = sc.cfg.User
		requestMain.Token = token
	case AuthTypeWorkloadIdentityFederation:
		logger.WithContext(sc.ctx).Debug("Workload Identity Federation")
		wifAttestationProvider := createWifAttestationProvider(sc.ctx, sc.cfg, sc.telemetry)
		wifAttestation, err := wifAttestationProvider.getAttestation(sc.cfg.WorkloadIdentityProvider)
		if err != nil {
			return nil, err
		}
		if wifAttestation == nil {
			return nil, errors.New("workload identity federation attestation is not available, please check your configuration")
		}
		requestMain.Authenticator = AuthTypeWorkloadIdentityFederation.String()
		requestMain.Token = wifAttestation.Credential
		requestMain.Provider = wifAttestation.ProviderType
	}
	logger.WithContext(sc.ctx).Debugf("Request body is created for the authentication. Authenticator: %s, User: %s, Account: %s",
		sc.cfg.Authenticator.String(), sc.cfg.User, sc.cfg.Account)
	authRequest := authRequest{
		Data: requestMain,
	}
	jsonBody, err := json.Marshal(authRequest)
	if err != nil {
		logger.WithContext(sc.ctx).Errorf("Failed to marshal JSON. err: %v", err)
		return nil, err
	}
	return jsonBody, nil
}

// oauthLockKey keys the per-(token URL, user, flow) lock used to serialize
// concurrent OAuth logins/refreshes.
type oauthLockKey struct {
	tokenRequestURL string
	user            string
	flowType        string
}

// newOAuthAuthorizationCodeLockKey returns the lock key for the
// authorization-code flow.
func newOAuthAuthorizationCodeLockKey(tokenRequestURL, user string) *oauthLockKey {
	return &oauthLockKey{
		tokenRequestURL: tokenRequestURL,
		user:            user,
		flowType:        "authorization_code",
	}
}

// newRefreshTokenLockKey returns the lock key for the refresh-token flow.
func newRefreshTokenLockKey(tokenRequestURL, user string) *oauthLockKey {
	return &oauthLockKey{
		tokenRequestURL: tokenRequestURL,
		user:            user,
		flowType:        "refresh_token",
	}
}

// lockID renders the key as a single pipe-separated identifier.
func (o *oauthLockKey) lockID() string {
	return o.tokenRequestURL + "|" + o.user + "|" + o.flowType
}

// authenticateByAuthorizationCode runs the OAuth authorization-code flow.
// When parallel login is eligible, concurrent callers wait on a shared value
// awaiter and reuse a cached access token if one appears; otherwise the full
// interactive flow runs directly.
func authenticateByAuthorizationCode(sc *snowflakeConn) (string, error) {
	oauthClient, err := newOauthClient(sc.ctx, sc.cfg, sc)
	if err != nil {
		return "", err
	}
	if !isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientStoreTemporaryCredential) {
		return oauthClient.authenticateByOAuthAuthorizationCode()
	}
	lockKey := newOAuthAuthorizationCodeLockKey(oauthClient.tokenURL(), sc.cfg.User)
	valueAwaiter := valueAwaitHolder.get(lockKey)
	defer valueAwaiter.resumeOne()
	token, err := awaitValue(valueAwaiter, func() (string, error) {
		return credentialsStorage.getCredential(newOAuthAccessTokenSpec(oauthClient.tokenURL(), sc.cfg.User)), nil
	}, func(s string, err error) bool {
		return s != ""
	}, func() string {
		return ""
	})
	if err != nil || token != "" {
		return token, err
	}
	token, err = oauthClient.authenticateByOAuthAuthorizationCode()
	if err != nil {
		return "", err
	}
	valueAwaiter.done()
	return token, err
}

// Generate a JWT token in string given the configuration
func prepareJWTToken(config *Config) (string, error) {
	if config.PrivateKey == nil {
		return "", errors.New("trying to use keypair authentication, but PrivateKey was not provided in the driver config")
	}
	logger.Debug("preparing JWT for keypair authentication")
	// The issuer embeds a SHA-256 fingerprint of the public key, as required
	// by Snowflake key-pair authentication.
	pubBytes, err := x509.MarshalPKIXPublicKey(config.PrivateKey.Public())
	if err != nil {
		return "", err
	}
	hash := sha256.Sum256(pubBytes)
	accountName :=
		sfconfig.ExtractAccountName(config.Account)
	userName := strings.ToUpper(config.User)
	issueAtTime := time.Now().UTC()
	jwtClaims := jwt.MapClaims{
		"iss": fmt.Sprintf("%s.%s.%s", accountName, userName, "SHA256:"+base64.StdEncoding.EncodeToString(hash[:])),
		"sub": fmt.Sprintf("%s.%s", accountName, userName),
		"iat": issueAtTime.Unix(),
		// nbf is pinned to a fixed past date; presumably to tolerate clock
		// skew between client and server — TODO confirm.
		"nbf": time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix(),
		"exp": issueAtTime.Add(config.JWTExpireTimeout).Unix(),
	}
	token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwtClaims)
	tokenString, err := token.SignedString(config.PrivateKey)
	if err != nil {
		return "", err
	}
	logger.Debugf("successfully generated JWT with following claims: %v", jwtClaims)
	return tokenString, err
}

// tokenLockKey keys the per-(host, user, token type) lock used to serialize
// concurrent MFA/ID token lookups during login.
type tokenLockKey struct {
	snowflakeHost string
	user          string
	tokenType     string
}

// newMfaTokenLockKey returns the lock key for cached MFA tokens.
func newMfaTokenLockKey(snowflakeHost, user string) *tokenLockKey {
	return &tokenLockKey{
		snowflakeHost: snowflakeHost,
		user:          user,
		tokenType:     "MFA",
	}
}

// newIDTokenLockKey returns the lock key for cached ID tokens.
func newIDTokenLockKey(snowflakeHost, user string) *tokenLockKey {
	return &tokenLockKey{
		snowflakeHost: snowflakeHost,
		user:          user,
		tokenType:     "ID",
	}
}

// lockID renders the key as a single pipe-separated identifier.
func (m *tokenLockKey) lockID() string {
	return m.snowflakeHost + "|" + m.user + "|" + m.tokenType
}

// authenticateWithConfig drives the full login for the connection: it applies
// platform defaults for credential caching, loads cached ID/MFA tokens (with
// parallel-login coordination when eligible), runs any interactive pre-step
// (external browser SSO), calls authenticate, and finally applies the
// returned session parameters to the connection.
func authenticateWithConfig(sc *snowflakeConn) error {
	var authData *authResponseMain
	var samlResponse []byte
	var proofKey []byte
	var err error
	mfaTokenLockKey := newMfaTokenLockKey(sc.cfg.Host, sc.cfg.User)
	idTokenLockKey := newIDTokenLockKey(sc.cfg.Host, sc.cfg.User)
	if sc.cfg.Authenticator == AuthTypeExternalBrowser || sc.cfg.Authenticator == AuthTypeOAuthAuthorizationCode || sc.cfg.Authenticator == AuthTypeOAuthClientCredentials {
		// On Windows/macOS, temporary-credential storage defaults to on.
		if (runtime.GOOS == "windows" || runtime.GOOS == "darwin") && sc.cfg.ClientStoreTemporaryCredential == sfconfig.BoolNotSet {
			sc.cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
		}
		if sc.cfg.Authenticator == AuthTypeExternalBrowser {
			if isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientStoreTemporaryCredential) {
				valueAwaiter := valueAwaitHolder.get(idTokenLockKey)
				defer valueAwaiter.resumeOne()
				sc.idToken, _ = awaitValue(valueAwaiter, func() (string, error) {
					credential := credentialsStorage.getCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
					return credential, nil
				}, func(s string, err error) bool {
					return s != ""
				}, func() string {
					return ""
				})
			} else if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
				sc.idToken = credentialsStorage.getCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
			}
		}
		// Disable console login by default
		if sc.cfg.DisableConsoleLogin == sfconfig.BoolNotSet {
			sc.cfg.DisableConsoleLogin = ConfigBoolTrue
		}
	}
	if sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA {
		// On Windows/macOS, MFA token caching defaults to on.
		if (runtime.GOOS == "windows" || runtime.GOOS == "darwin") && sc.cfg.ClientRequestMfaToken == sfconfig.BoolNotSet {
			sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
		}
		if isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientRequestMfaToken) {
			valueAwaiter := valueAwaitHolder.get(mfaTokenLockKey)
			defer valueAwaiter.resumeOne()
			sc.mfaToken, _ = awaitValue(valueAwaiter, func() (string, error) {
				credential := credentialsStorage.getCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
				return credential, nil
			}, func(s string, err error) bool {
				return s != ""
			}, func() string {
				return ""
			})
		} else if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue {
			sc.mfaToken = credentialsStorage.getCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
		}
	}
	logger.WithContext(sc.ctx).Infof("Authenticating via %v", sc.cfg.Authenticator.String())
	switch sc.cfg.Authenticator {
	case AuthTypeExternalBrowser:
		// Only open the browser when no cached ID token is available.
		if sc.idToken == "" {
			samlResponse, proofKey, err = authenticateByExternalBrowser(
				sc.ctx,
				sc.rest,
				sc.cfg.Authenticator.String(),
				sc.cfg.Application,
				sc.cfg.Account,
				sc.cfg.User,
				sc.cfg.ExternalBrowserTimeout,
				sc.cfg.DisableConsoleLogin)
			if err != nil {
				sc.cleanup()
				return err
			}
		}
	}
	authData, err = authenticate(
		sc.ctx,
		sc,
		samlResponse,
		proofKey)
	if err != nil {
		var se *SnowflakeError
		if errors.As(err, &se) &&
slices.Contains(refreshOAuthTokenErrorCodes, strconv.Itoa(se.Number)) { credentialsStorage.deleteCredential(newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) if sc.cfg.Authenticator == AuthTypeOAuthAuthorizationCode { doRefreshTokenWithLock(sc) } // if refreshing succeeds for authorization code, we will take a token from cache // if it fails, we will just run the full flow authData, err = authenticate(sc.ctx, sc, nil, nil) } if err != nil { sc.cleanup() return err } } if sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA && isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientRequestMfaToken) { valueAwaiter := valueAwaitHolder.get(mfaTokenLockKey) valueAwaiter.done() } if sc.cfg.Authenticator == AuthTypeExternalBrowser && isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientStoreTemporaryCredential) { valueAwaiter := valueAwaitHolder.get(idTokenLockKey) valueAwaiter.done() } sc.populateSessionParameters(authData.Parameters) sc.configureTelemetry() sc.ctx = context.WithValue(sc.ctx, SFSessionIDKey, authData.SessionID) return nil } func doRefreshTokenWithLock(sc *snowflakeConn) { if oauthClient, err := newOauthClient(sc.ctx, sc.cfg, sc); err != nil { logger.Warnf("failed to create oauth client. %v", err) } else { lockKey := newRefreshTokenLockKey(oauthClient.tokenURL(), sc.cfg.User) if _, err = getValueWithLock(chooseLockerForAuth(sc.cfg), lockKey, func() (string, error) { if err = oauthClient.refreshToken(); err != nil { logger.Warnf("cannot refresh token. %v", err) credentialsStorage.deleteCredential(newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User)) return "", err } return "", nil }); err != nil { logger.Warnf("failed to refresh token with lock. 
%v", err) } } } func chooseLockerForAuth(cfg *Config) locker { if cfg.SingleAuthenticationPrompt == ConfigBoolFalse { return noopLocker } if cfg.User == "" { return noopLocker } return exclusiveLocker } func isEligibleForParallelLogin(cfg *Config, cacheEnabled ConfigBool) bool { return cfg.SingleAuthenticationPrompt != ConfigBoolFalse && cfg.User != "" && cacheEnabled == ConfigBoolTrue } ================================================ FILE: auth_generic_test_methods_test.go ================================================ package gosnowflake import ( "fmt" "os" "testing" ) func getAuthTestConfigFromEnv() (*Config, error) { return GetConfigFromEnv([]*ConfigParam{ {Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true}, {Name: "User", EnvName: "SNOWFLAKE_AUTH_TEST_OKTA_USER", FailOnMissing: true}, {Name: "Password", EnvName: "SNOWFLAKE_AUTH_TEST_OKTA_PASS", FailOnMissing: true}, {Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false}, {Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false}, {Name: "Protocol", EnvName: "SNOWFLAKE_AUTH_TEST_PROTOCOL", FailOnMissing: false}, {Name: "Role", EnvName: "SNOWFLAKE_TEST_ROLE", FailOnMissing: false}, {Name: "Warehouse", EnvName: "SNOWFLAKE_TEST_WAREHOUSE", FailOnMissing: false}, }) } func getAuthTestsConfig(t *testing.T, authMethod AuthType) (*Config, error) { cfg, err := getAuthTestConfigFromEnv() assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err)) cfg.Authenticator = authMethod return cfg, nil } func isTestRunningInDockerContainer() bool { return os.Getenv("AUTHENTICATION_TESTS_ENV") == "docker" } ================================================ FILE: auth_oauth.go ================================================ package gosnowflake import ( "bufio" "bytes" "cmp" "context" "encoding/json" "errors" "fmt" "html" "io" "net" "net/http" "net/url" "strconv" "strings" "time" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" ) const ( oauthSuccessHTML = ` OAuth 
for Snowflake OAuth authentication completed successfully. `
	// localApplicationClientCredentials is the default client id/secret pair
	// used when Snowflake itself acts as the IdP and no client was configured.
	localApplicationClientCredentials = "LOCAL_APPLICATION"
)

// defaultAuthorizationCodeProviderFactory opens a real browser; tests swap it
// for a non-interactive provider.
var defaultAuthorizationCodeProviderFactory = func() authorizationCodeProvider {
	return &browserBasedAuthorizationCodeProvider{}
}

// oauthClient drives the OAuth2 flows (authorization code, client
// credentials, refresh) against Snowflake or an external IdP.
type oauthClient struct {
	ctx    context.Context
	cfg    *Config
	client *http.Client

	// port is the fixed local callback port taken from OauthRedirectURI
	// (0 lets the OS choose one).
	port int
	// redirectURITemplate builds the loopback redirect URI when no explicit
	// OauthRedirectURI is configured.
	redirectURITemplate string

	authorizationCodeProviderFactory func() authorizationCodeProvider
}

// newOauthClient builds an oauthClient for cfg, deriving the callback port
// from OauthRedirectURI when present and preparing an HTTP client whose
// transport honors the connection's telemetry/revocation settings.
func newOauthClient(ctx context.Context, cfg *Config, sc *snowflakeConn) (*oauthClient, error) {
	port := 0
	if cfg.OauthRedirectURI != "" {
		logger.Debugf("Using oauthRedirectUri from config: %v", cfg.OauthRedirectURI)
		uri, err := url.Parse(cfg.OauthRedirectURI)
		if err != nil {
			return nil, err
		}
		portStr := uri.Port()
		if portStr != "" {
			if port, err = strconv.Atoi(portStr); err != nil {
				return nil, err
			}
		}
	}
	redirectURITemplate := ""
	if cfg.OauthRedirectURI == "" {
		redirectURITemplate = "http://127.0.0.1:%v"
	}
	logger.Debugf("Redirect URI template: %v, port: %v", redirectURITemplate, port)
	transport, err := newTransportFactory(cfg, sc.telemetry).createTransport(transportConfigFor(transportTypeOAuth))
	if err != nil {
		return nil, err
	}
	client := &http.Client{
		Transport: transport,
	}
	return &oauthClient{
		ctx:                              context.WithValue(ctx, oauth2.HTTPClient, client),
		cfg:                              cfg,
		client:                           client,
		port:                             port,
		redirectURITemplate:              redirectURITemplate,
		authorizationCodeProviderFactory: defaultAuthorizationCodeProviderFactory,
	}, nil
}

// oauthBrowserResult carries the outcome of the browser-based flow back from
// its goroutine.
type oauthBrowserResult struct {
	accessToken  string
	refreshToken string
	err          error
}

// authenticateByOAuthAuthorizationCode runs the interactive authorization
// code flow. When credential caching is enabled it first consults the cache:
// a cached access token short-circuits the flow, while a refresh token
// without an access token is reported via ErrMissingAccessATokenButRefreshTokenPresent
// so the caller can run the refresh flow instead. The browser round trip is
// bounded by cfg.ExternalBrowserTimeout.
func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode() (string, error) {
	accessTokenSpec := oauthClient.accessTokenSpec()
	if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
		if accessToken := credentialsStorage.getCredential(accessTokenSpec); accessToken != "" {
			logger.Debugf("Access token retrieved from cache")
			return accessToken, nil
		}
		if refreshToken := credentialsStorage.getCredential(oauthClient.refreshTokenSpec()); refreshToken != "" {
			return "", &SnowflakeError{Number: ErrMissingAccessATokenButRefreshTokenPresent}
		}
	}
	logger.Debugf("Access token not present in cache, running full auth code flow")
	resultChan := make(chan oauthBrowserResult, 1)
	tcpListener, callbackPort, err := oauthClient.setupListener()
	if err != nil {
		return "", err
	}
	defer func() {
		logger.Debug("Closing tcp listener")
		if err := tcpListener.Close(); err != nil {
			logger.Warnf("error while closing TCP listener. %v", err)
		}
	}()
	go GoroutineWrapper(oauthClient.ctx, func() {
		resultChan <- oauthClient.doAuthenticateByOAuthAuthorizationCode(tcpListener, callbackPort)
	})
	select {
	case <-time.After(oauthClient.cfg.ExternalBrowserTimeout):
		return "", errors.New("authentication via browser timed out")
	case result := <-resultChan:
		// BUGFIX: only persist tokens for a successful exchange; previously a
		// failed flow overwrote the cache with empty access/refresh tokens.
		if result.err == nil && oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
			logger.Debug("saving oauth access token in cache")
			credentialsStorage.setCredential(oauthClient.accessTokenSpec(), result.accessToken)
			credentialsStorage.setCredential(oauthClient.refreshTokenSpec(), result.refreshToken)
		}
		return result.accessToken, result.err
	}
}

// doAuthenticateByOAuthAuthorizationCode performs the actual browser round
// trip: it serves the loopback callback on tcpListener, opens the
// authorization URL and exchanges the returned code for tokens.
func (oauthClient *oauthClient) doAuthenticateByOAuthAuthorizationCode(tcpListener *net.TCPListener, callbackPort int) oauthBrowserResult {
	authCodeProvider := oauthClient.authorizationCodeProviderFactory()
	successChan := make(chan []byte)
	errChan := make(chan error)
	responseBodyChan := make(chan string, 2)
	closeListenerChan := make(chan bool, 2)
	defer func() {
		closeListenerChan <- true
		close(successChan)
		close(errChan)
		close(responseBodyChan)
		close(closeListenerChan)
	}()
	logger.Debugf("opening socket on port %v", callbackPort)
	defer func(tcpListener *net.TCPListener) {
		// Wait for the socket handler to finish before tearing down channels.
		<-closeListenerChan
	}(tcpListener)
	go handleOAuthSocket(tcpListener, successChan, errChan, responseBodyChan, closeListenerChan)
	oauth2cfg := oauthClient.buildAuthorizationCodeConfig(callbackPort)
	codeVerifier := 
authCodeProvider.createCodeVerifier() state := authCodeProvider.createState() authorizationURL := oauth2cfg.AuthCodeURL(state, oauth2.S256ChallengeOption(codeVerifier)) if err := authCodeProvider.run(authorizationURL); err != nil { responseBodyChan <- err.Error() closeListenerChan <- true return oauthBrowserResult{"", "", err} } err := <-errChan if err != nil { responseBodyChan <- err.Error() return oauthBrowserResult{"", "", err} } codeReqBytes := <-successChan codeReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(codeReqBytes))) if err != nil { responseBodyChan <- err.Error() return oauthBrowserResult{"", "", err} } logger.Debugf("Received authorization code from %v", oauthClient.authorizationURL()) tokenResponse, err := oauthClient.exchangeAccessToken(codeReq, state, oauth2cfg, codeVerifier, responseBodyChan) if err != nil { return oauthBrowserResult{"", "", err} } logger.Debugf("Received token from %v", oauthClient.tokenURL()) return oauthBrowserResult{tokenResponse.AccessToken, tokenResponse.RefreshToken, err} } func (oauthClient *oauthClient) setupListener() (*net.TCPListener, int, error) { tcpListener, err := createLocalTCPListener(oauthClient.port) if err != nil { return nil, 0, err } callbackPort := tcpListener.Addr().(*net.TCPAddr).Port logger.Debugf("oauthClient.port: %v, callbackPort: %v", oauthClient.port, callbackPort) return tcpListener, callbackPort, nil } func (oauthClient *oauthClient) exchangeAccessToken(codeReq *http.Request, state string, oauth2cfg *oauth2.Config, codeVerifier string, responseBodyChan chan string) (*oauth2.Token, error) { queryParams := codeReq.URL.Query() errorMsg := queryParams.Get("error") if errorMsg != "" { errorDesc := queryParams.Get("error_description") errMsg := fmt.Sprintf("error while getting authentication from oauth: %v. 
Details: %v", errorMsg, errorDesc) responseBodyChan <- html.EscapeString(errMsg) return nil, errors.New(errMsg) } receivedState := queryParams.Get("state") if state != receivedState { errMsg := "invalid oauth state received" responseBodyChan <- errMsg return nil, errors.New(errMsg) } code := queryParams.Get("code") opts := []oauth2.AuthCodeOption{oauth2.VerifierOption(codeVerifier)} if oauthClient.cfg.EnableSingleUseRefreshTokens { opts = append(opts, oauth2.SetAuthURLParam("enable_single_use_refresh_tokens", "true")) } token, err := oauth2cfg.Exchange(oauthClient.ctx, code, opts...) if err != nil { responseBodyChan <- err.Error() return nil, err } responseBodyChan <- oauthSuccessHTML return token, nil } func (oauthClient *oauthClient) buildAuthorizationCodeConfig(callbackPort int) *oauth2.Config { clientID, clientSecret := oauthClient.cfg.OauthClientID, oauthClient.cfg.OauthClientSecret if oauthClient.eligibleForDefaultClientCredentials() { clientID, clientSecret = localApplicationClientCredentials, localApplicationClientCredentials } oauthClient.logIfHTTPInUse(oauthClient.authorizationURL()) oauthClient.logIfHTTPInUse(oauthClient.tokenURL()) return &oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, RedirectURL: oauthClient.buildRedirectURI(callbackPort), Scopes: oauthClient.buildScopes(), Endpoint: oauth2.Endpoint{ AuthURL: oauthClient.authorizationURL(), TokenURL: oauthClient.tokenURL(), AuthStyle: oauth2.AuthStyleInHeader, }, } } func (oauthClient *oauthClient) eligibleForDefaultClientCredentials() bool { return oauthClient.cfg.OauthClientID == "" && oauthClient.cfg.OauthClientSecret == "" && oauthClient.isSnowflakeAsIDP() } func (oauthClient *oauthClient) isSnowflakeAsIDP() bool { return (oauthClient.cfg.OauthAuthorizationURL == "" || strings.Contains(oauthClient.cfg.OauthAuthorizationURL, oauthClient.cfg.Host)) && (oauthClient.cfg.OauthTokenRequestURL == "" || strings.Contains(oauthClient.cfg.OauthTokenRequestURL, oauthClient.cfg.Host)) } func 
(oauthClient *oauthClient) authorizationURL() string {
	// An explicitly configured endpoint wins over the Snowflake default.
	return cmp.Or(oauthClient.cfg.OauthAuthorizationURL, oauthClient.defaultAuthorizationURL())
}

// defaultAuthorizationURL is Snowflake's own authorize endpoint for this
// connection's protocol/host/port.
func (oauthClient *oauthClient) defaultAuthorizationURL() string {
	return fmt.Sprintf("%v://%v:%v/oauth/authorize", oauthClient.cfg.Protocol, oauthClient.cfg.Host, oauthClient.cfg.Port)
}

// tokenURL returns the configured token endpoint, or Snowflake's default.
func (oauthClient *oauthClient) tokenURL() string {
	return cmp.Or(oauthClient.cfg.OauthTokenRequestURL, oauthClient.defaultTokenURL())
}

// defaultTokenURL is Snowflake's own token-request endpoint for this
// connection's protocol/host/port.
func (oauthClient *oauthClient) defaultTokenURL() string {
	return fmt.Sprintf("%v://%v:%v/oauth/token-request", oauthClient.cfg.Protocol, oauthClient.cfg.Host, oauthClient.cfg.Port)
}

// buildRedirectURI returns the configured redirect URI, or a loopback URI on
// the given port when none was configured.
func (oauthClient *oauthClient) buildRedirectURI(port int) string {
	if uri := oauthClient.cfg.OauthRedirectURI; uri != "" {
		return uri
	}
	return fmt.Sprintf(oauthClient.redirectURITemplate, port)
}

// buildScopes splits the configured space-separated scope list (trimming each
// entry), defaulting to a single session:role:<role> scope when unset.
func (oauthClient *oauthClient) buildScopes() []string {
	if oauthClient.cfg.OauthScope == "" {
		return []string{"session:role:" + oauthClient.cfg.Role}
	}
	scopes := strings.Split(oauthClient.cfg.OauthScope, " ")
	for i := range scopes {
		scopes[i] = strings.TrimSpace(scopes[i])
	}
	return scopes
}

// handleOAuthSocket accepts a single connection on the callback listener,
// reads the redirect request, hands it to the flow goroutine via successChan,
// then writes the final HTML response back to the browser.
func handleOAuthSocket(tcpListener *net.TCPListener, successChan chan []byte, errChan chan error, responseBodyChan chan string, closeListenerChan chan bool) {
	conn, err := tcpListener.AcceptTCP()
	if err != nil {
		logger.Warnf("error creating socket. %v", err)
		return
	}
	defer func() {
		if err := conn.Close(); err != nil {
			logger.Warnf("error while closing connection (%v -> %v). 
%v", conn.LocalAddr(), conn.RemoteAddr(), err)
		}
	}()
	var buf [bufSize]byte
	codeResp := bytes.NewBuffer(nil)
	for {
		readBytes, err := conn.Read(buf[:])
		if err == io.EOF {
			break
		}
		if err != nil {
			errChan <- err
			return
		}
		codeResp.Write(buf[0:readBytes])
		// A short read is taken as the end of the request.
		if readBytes < bufSize {
			break
		}
	}
	errChan <- nil
	successChan <- codeResp.Bytes()
	// Block until the flow decides what HTML the user should see.
	responseBody := <-responseBodyChan
	respToBrowser, err := buildResponse(responseBody)
	if err != nil {
		logger.Warnf("cannot create response to browser. %v", err)
	}
	if _, err = conn.Write(respToBrowser.Bytes()); err != nil {
		logger.Warnf("cannot write response to browser. %v", err)
	}
	closeListenerChan <- true
}

// authorizationCodeProvider abstracts the interactive part of the
// authorization code flow so tests can avoid opening a real browser.
type authorizationCodeProvider interface {
	run(authorizationURL string) error
	createState() string
	createCodeVerifier() string
}

// browserBasedAuthorizationCodeProvider is the production implementation:
// it opens the system browser and generates random state/PKCE values.
type browserBasedAuthorizationCodeProvider struct {
}

func (provider *browserBasedAuthorizationCodeProvider) run(authorizationURL string) error {
	return openBrowser(authorizationURL)
}

func (provider *browserBasedAuthorizationCodeProvider) createState() string {
	return NewUUID().String()
}

func (provider *browserBasedAuthorizationCodeProvider) createCodeVerifier() string {
	return oauth2.GenerateVerifier()
}

// authenticateByOAuthClientCredentials runs the non-interactive client
// credentials flow, consulting and updating the credential cache when
// ClientStoreTemporaryCredential is enabled.
func (oauthClient *oauthClient) authenticateByOAuthClientCredentials() (string, error) {
	accessTokenSpec := oauthClient.accessTokenSpec()
	useCache := oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue
	if useCache {
		if cached := credentialsStorage.getCredential(accessTokenSpec); cached != "" {
			return cached, nil
		}
	}
	oauth2Cfg, err := oauthClient.buildClientCredentialsConfig()
	if err != nil {
		return "", err
	}
	token, err := oauth2Cfg.Token(oauthClient.ctx)
	if err != nil {
		return "", err
	}
	if useCache {
		credentialsStorage.setCredential(accessTokenSpec, token.AccessToken)
	}
	return token.AccessToken, nil
}

// buildClientCredentialsConfig validates that a token endpoint is configured
// and assembles the x/oauth2 clientcredentials config.
func (oauthClient *oauthClient) buildClientCredentialsConfig() (*clientcredentials.Config, error) {
	if oauthClient.cfg.OauthTokenRequestURL == "" {
		return nil, errors.New("client credentials flow requires tokenRequestURL")
	}
	return &clientcredentials.Config{
		ClientID:     oauthClient.cfg.OauthClientID,
		ClientSecret: oauthClient.cfg.OauthClientSecret,
		TokenURL:     oauthClient.cfg.OauthTokenRequestURL,
		Scopes:       oauthClient.buildScopes(),
	}, nil
}

// refreshToken exchanges a cached refresh token for a fresh access token and
// updates the credential cache. It is a silent no-op (nil) when caching is
// disabled or no refresh token is stored, so callers fall back to the full flow.
func (oauthClient *oauthClient) refreshToken() error {
	if oauthClient.cfg.ClientStoreTemporaryCredential != ConfigBoolTrue {
		logger.Debug("credentials storage is disabled, cannot use refresh tokens")
		return nil
	}
	refreshTokenSpec := newOAuthRefreshTokenSpec(oauthClient.cfg.OauthTokenRequestURL, oauthClient.cfg.User)
	refreshToken := credentialsStorage.getCredential(refreshTokenSpec)
	if refreshToken == "" {
		logger.Debug("no refresh token in cache, full flow must be run")
		return nil
	}
	body := url.Values{}
	body.Add("grant_type", "refresh_token")
	body.Add("refresh_token", refreshToken)
	body.Add("scope", strings.Join(oauthClient.buildScopes(), " "))
	req, err := http.NewRequest("POST", oauthClient.tokenURL(), strings.NewReader(body.Encode()))
	if err != nil {
		return err
	}
	req.SetBasicAuth(oauthClient.cfg.OauthClientID, oauthClient.cfg.OauthClientSecret)
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
	resp, err := oauthClient.client.Do(req)
	if err != nil {
		return err
	}
	defer func() {
		if err := resp.Body.Close(); err != nil {
			logger.Warnf("error while closing response body for %v. 
%v", req.URL, err)
		}
	}()
	if resp.StatusCode != 200 {
		// Non-OK: surface the server's body as the error and drop the refresh
		// token, which is evidently no longer accepted.
		respBody, err := io.ReadAll(resp.Body)
		if err != nil {
			return err
		}
		credentialsStorage.deleteCredential(refreshTokenSpec)
		return errors.New(string(respBody))
	}
	var tokenResponse tokenExchangeResponseBody
	if err = json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
		return err
	}
	accessTokenSpec := oauthClient.accessTokenSpec()
	credentialsStorage.setCredential(accessTokenSpec, tokenResponse.AccessToken)
	if tokenResponse.RefreshToken != "" {
		// Some IdPs rotate refresh tokens; keep the old one when none returned.
		credentialsStorage.setCredential(refreshTokenSpec, tokenResponse.RefreshToken)
	}
	return nil
}

// tokenExchangeResponseBody models the relevant fields of an OAuth token
// endpoint response.
type tokenExchangeResponseBody struct {
	AccessToken  string `json:"access_token,omitempty"`
	RefreshToken string `json:"refresh_token"`
}

// accessTokenSpec identifies this user's cached OAuth access token.
func (oauthClient *oauthClient) accessTokenSpec() *secureTokenSpec {
	return newOAuthAccessTokenSpec(oauthClient.tokenURL(), oauthClient.cfg.User)
}

// refreshTokenSpec identifies this user's cached OAuth refresh token.
func (oauthClient *oauthClient) refreshTokenSpec() *secureTokenSpec {
	return newOAuthRefreshTokenSpec(oauthClient.tokenURL(), oauthClient.cfg.User)
}

// logIfHTTPInUse warns when an OAuth endpoint URL uses plain HTTP.
func (oauthClient *oauthClient) logIfHTTPInUse(u string) {
	parsed, err := url.Parse(u)
	if err != nil {
		logger.Warnf("Cannot parse URL: %v. 
%v", u, err) return } if parsed.Scheme == "http" { logger.Warnf("OAuth URL uses insecure HTTP protocol: %v", u) } } ================================================ FILE: auth_oauth_test.go ================================================ package gosnowflake import ( "context" "database/sql" "errors" sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config" "io" "net/http" "sync" "testing" "time" "golang.org/x/oauth2" ) func TestUnitOAuthAuthorizationCode(t *testing.T) { skipOnMac(t, "keychain requires password") roundTripper := newCountingRoundTripper(createTestNoRevocationTransport()) httpClient := &http.Client{ Transport: roundTripper, } cfg := &Config{ User: "testUser", Role: "ANALYST", OauthClientID: "testClientId", OauthClientSecret: "testClientSecret", OauthAuthorizationURL: wiremock.baseURL() + "/oauth/authorize", OauthTokenRequestURL: wiremock.baseURL() + "/oauth/token", OauthRedirectURI: "http://localhost:1234/snowflake/oauth-redirect", Transporter: roundTripper, ClientStoreTemporaryCredential: ConfigBoolTrue, ExternalBrowserTimeout: time.Duration(sfconfig.DefaultExternalBrowserTimeout), } client, err := newOauthClient(context.WithValue(context.Background(), oauth2.HTTPClient, httpClient), cfg, &snowflakeConn{}) assertNilF(t, err) accessTokenSpec := newOAuthAccessTokenSpec(wiremock.connectionConfig().OauthTokenRequestURL, wiremock.connectionConfig().User) refreshTokenSpec := newOAuthRefreshTokenSpec(wiremock.connectionConfig().OauthTokenRequestURL, wiremock.connectionConfig().User) t.Run("Success", func(t *testing.T) { credentialsStorage.deleteCredential(accessTokenSpec) credentialsStorage.deleteCredential(refreshTokenSpec) wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json")) authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t} client.authorizationCodeProviderFactory = func() authorizationCodeProvider { return authCodeProvider } token, err := 
client.authenticateByOAuthAuthorizationCode() assertNilF(t, err) assertEqualE(t, token, "access-token-123") time.Sleep(100 * time.Millisecond) authCodeProvider.assertResponseBodyContains("OAuth authentication completed successfully.") }) t.Run("Store access token in cache", func(t *testing.T) { skipOnMissingHome(t) roundTripper.reset() credentialsStorage.deleteCredential(accessTokenSpec) credentialsStorage.deleteCredential(refreshTokenSpec) wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json")) authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t} client.authorizationCodeProviderFactory = func() authorizationCodeProvider { return authCodeProvider } _, err = client.authenticateByOAuthAuthorizationCode() assertNilF(t, err) assertEqualE(t, credentialsStorage.getCredential(accessTokenSpec), "access-token-123") }) t.Run("Use cache for consecutive calls", func(t *testing.T) { skipOnMissingHome(t) roundTripper.reset() credentialsStorage.setCredential(accessTokenSpec, "access-token-123") wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json")) authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t} for range 3 { client, err := newOauthClient(context.WithValue(context.Background(), oauth2.HTTPClient, httpClient), cfg, &snowflakeConn{}) assertNilF(t, err) client.authorizationCodeProviderFactory = func() authorizationCodeProvider { return authCodeProvider } _, err = client.authenticateByOAuthAuthorizationCode() assertNilF(t, err) } assertEqualE(t, authCodeProvider.responseBody, "") assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 0) }) t.Run("InvalidState", func(t *testing.T) { credentialsStorage.deleteCredential(accessTokenSpec) credentialsStorage.deleteCredential(refreshTokenSpec) wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json")) authCodeProvider := 
&nonInteractiveAuthorizationCodeProvider{ tamperWithState: true, t: t, } client.authorizationCodeProviderFactory = func() authorizationCodeProvider { return authCodeProvider } _, err = client.authenticateByOAuthAuthorizationCode() assertEqualE(t, err.Error(), "invalid oauth state received") time.Sleep(100 * time.Millisecond) authCodeProvider.assertResponseBodyContains("invalid oauth state received") }) t.Run("ErrorFromIdPWhileGettingCode", func(t *testing.T) { credentialsStorage.deleteCredential(accessTokenSpec) credentialsStorage.deleteCredential(refreshTokenSpec) wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/error_from_idp.json")) authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t} client.authorizationCodeProviderFactory = func() authorizationCodeProvider { return authCodeProvider } _, err = client.authenticateByOAuthAuthorizationCode() assertEqualE(t, err.Error(), "error while getting authentication from oauth: some error. Details: some error desc") time.Sleep(100 * time.Millisecond) authCodeProvider.assertResponseBodyContains("error while getting authentication from oauth: some error. 
Details: some error desc") }) t.Run("ErrorFromProviderWhileGettingCode", func(t *testing.T) { authCodeProvider := &nonInteractiveAuthorizationCodeProvider{ triggerError: "test error", } client.authorizationCodeProviderFactory = func() authorizationCodeProvider { return authCodeProvider } _, err = client.authenticateByOAuthAuthorizationCode() assertEqualE(t, err.Error(), "test error") }) t.Run("InvalidCode", func(t *testing.T) { credentialsStorage.deleteCredential(accessTokenSpec) credentialsStorage.deleteCredential(refreshTokenSpec) wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/invalid_code.json")) authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t} client.authorizationCodeProviderFactory = func() authorizationCodeProvider { return authCodeProvider } _, err = client.authenticateByOAuthAuthorizationCode() assertNotNilE(t, err) assertEqualE(t, err.(*oauth2.RetrieveError).ErrorCode, "invalid_grant") assertEqualE(t, err.(*oauth2.RetrieveError).ErrorDescription, "The authorization code is invalid or has expired.") time.Sleep(100 * time.Millisecond) authCodeProvider.assertResponseBodyContains("invalid_grant") }) t.Run("timeout", func(t *testing.T) { credentialsStorage.deleteCredential(accessTokenSpec) credentialsStorage.deleteCredential(refreshTokenSpec) wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json")) client.cfg.ExternalBrowserTimeout = 2 * time.Second authCodeProvider := &nonInteractiveAuthorizationCodeProvider{ sleepTime: 3 * time.Second, triggerError: "timed out", t: t, } client.authorizationCodeProviderFactory = func() authorizationCodeProvider { return authCodeProvider } _, err = client.authenticateByOAuthAuthorizationCode() assertNotNilE(t, err) assertStringContainsE(t, err.Error(), "timed out") time.Sleep(2 * time.Second) // awaiting timeout }) } func TestUnitOAuthClientCredentials(t *testing.T) { skipOnMac(t, "keychain requires password") cacheTokenSpec := 
newOAuthAccessTokenSpec(wiremock.connectionConfig().OauthTokenRequestURL, wiremock.connectionConfig().User) crt := newCountingRoundTripper(createTestNoRevocationTransport()) httpClient := http.Client{ Transport: crt, } cfgFactory := func() *Config { return &Config{ User: "testUser", Role: "ANALYST", OauthClientID: "testClientId", OauthClientSecret: "testClientSecret", OauthTokenRequestURL: wiremock.baseURL() + "/oauth/token", Transporter: crt, ClientStoreTemporaryCredential: ConfigBoolTrue, } } client, err := newOauthClient(context.WithValue(context.Background(), oauth2.HTTPClient, httpClient), cfgFactory(), &snowflakeConn{}) assertNilF(t, err) t.Run("success", func(t *testing.T) { credentialsStorage.deleteCredential(cacheTokenSpec) wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json")) token, err := client.authenticateByOAuthClientCredentials() assertNilF(t, err) assertEqualE(t, token, "access-token-123") }) t.Run("should store token in cache", func(t *testing.T) { skipOnMissingHome(t) crt.reset() credentialsStorage.deleteCredential(cacheTokenSpec) wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json")) token, err := client.authenticateByOAuthClientCredentials() assertNilF(t, err) assertEqualE(t, token, "access-token-123") client, err := newOauthClient(context.Background(), cfgFactory(), &snowflakeConn{}) assertNilF(t, err) token, err = client.authenticateByOAuthClientCredentials() assertNilF(t, err) assertEqualE(t, token, "access-token-123") assertEqualE(t, crt.postReqCount[cfgFactory().OauthTokenRequestURL], 1) }) t.Run("consecutive calls should take token from cache", func(t *testing.T) { skipOnMissingHome(t) crt.reset() credentialsStorage.setCredential(cacheTokenSpec, "access-token-123") for range 3 { client, err := newOauthClient(context.Background(), cfgFactory(), &snowflakeConn{}) assertNilF(t, err) token, err := client.authenticateByOAuthClientCredentials() 
assertNilF(t, err) assertEqualE(t, token, "access-token-123") } assertEqualE(t, crt.postReqCount[cfgFactory().OauthTokenRequestURL], 0) }) t.Run("disabling cache", func(t *testing.T) { skipOnMissingHome(t) cfg := cfgFactory() cfg.ClientStoreTemporaryCredential = ConfigBoolFalse credentialsStorage.deleteCredential(cacheTokenSpec) wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json")) client, err := newOauthClient(context.Background(), cfg, &snowflakeConn{}) assertNilF(t, err) token, err := client.authenticateByOAuthClientCredentials() assertNilF(t, err) assertEqualE(t, token, "access-token-123") client, err = newOauthClient(context.Background(), cfg, &snowflakeConn{}) assertNilF(t, err) token, err = client.authenticateByOAuthClientCredentials() assertNilF(t, err) assertEqualE(t, token, "access-token-123") assertEqualE(t, crt.postReqCount[cfg.OauthTokenRequestURL], 2) }) t.Run("invalid_client", func(t *testing.T) { credentialsStorage.deleteCredential(cacheTokenSpec) wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/client_credentials/invalid_client.json")) _, err = client.authenticateByOAuthClientCredentials() assertNotNilF(t, err) oauth2Err := err.(*oauth2.RetrieveError) assertEqualE(t, oauth2Err.ErrorCode, "invalid_client") assertEqualE(t, oauth2Err.ErrorDescription, "The client secret supplied for a confidential client is invalid.") }) } func TestAuthorizationCodeFlow(t *testing.T) { if runningOnGithubAction() && runningOnLinux() { t.Skip("Github blocks writing to file system") } skipOnMac(t, "keychain requires password") currentDefaultAuthorizationCodeProviderFactory := defaultAuthorizationCodeProviderFactory defer func() { defaultAuthorizationCodeProviderFactory = currentDefaultAuthorizationCodeProviderFactory }() defaultAuthorizationCodeProviderFactory = func() authorizationCodeProvider { return &nonInteractiveAuthorizationCodeProvider{ t: t, mu: sync.Mutex{}, } } roundTripper := 
newCountingRoundTripper(createTestNoRevocationTransport()) t.Run("successful flow", func(t *testing.T) { wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) cfg := wiremock.connectionConfig() cfg.Role = "ANALYST" cfg.Authenticator = AuthTypeOAuthAuthorizationCode cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect" cfg.Transporter = roundTripper oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User) oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User) credentialsStorage.deleteCredential(oauthAccessTokenSpec) credentialsStorage.deleteCredential(oauthRefreshTokenSpec) connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) runSmokeQuery(t, db) }) t.Run("successful flow with multiple threads", func(t *testing.T) { for _, singleAuthenticationPrompt := range []ConfigBool{ConfigBoolFalse, ConfigBoolTrue, configBoolNotSet} { t.Run("singleAuthenticationPrompt="+singleAuthenticationPrompt.String(), func(t *testing.T) { currentDefaultAuthorizationCodeProviderFactory := defaultAuthorizationCodeProviderFactory defer func() { defaultAuthorizationCodeProviderFactory = currentDefaultAuthorizationCodeProviderFactory }() defaultAuthorizationCodeProviderFactory = func() authorizationCodeProvider { return &nonInteractiveAuthorizationCodeProvider{ t: t, mu: sync.Mutex{}, sleepTime: 500 * time.Millisecond, } } roundTripper.reset() wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json"), newWiremockMapping("close_session.json")) cfg := wiremock.connectionConfig() cfg.Role = "ANALYST" cfg.Authenticator = AuthTypeOAuthAuthorizationCode cfg.Transporter = roundTripper cfg.SingleAuthenticationPrompt = 
singleAuthenticationPrompt oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User) oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User) credentialsStorage.deleteCredential(oauthAccessTokenSpec) credentialsStorage.deleteCredential(oauthRefreshTokenSpec) connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) initPoolWithSize(t, db, 20) println(roundTripper.postReqCount[cfg.OauthTokenRequestURL]) if singleAuthenticationPrompt == ConfigBoolFalse { assertTrueE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL] > 1) } else { assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) } }) } }) t.Run("successful flow with single-use refresh token enabled", func(t *testing.T) { wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow_with_single_use_refresh_token.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) cfg := wiremock.connectionConfig() cfg.Role = "ANALYST" cfg.Authenticator = AuthTypeOAuthAuthorizationCode cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect" cfg.Transporter = roundTripper cfg.EnableSingleUseRefreshTokens = true oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User) oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User) credentialsStorage.deleteCredential(oauthAccessTokenSpec) credentialsStorage.deleteCredential(oauthRefreshTokenSpec) connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) runSmokeQuery(t, db) }) t.Run("should use cached access token", func(t *testing.T) { roundTripper.reset() wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) cfg := wiremock.connectionConfig() cfg.Role = "ANALYST" cfg.Authenticator = 
AuthTypeOAuthAuthorizationCode cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect" cfg.Transporter = roundTripper oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User) oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User) credentialsStorage.deleteCredential(oauthAccessTokenSpec) credentialsStorage.deleteCredential(oauthRefreshTokenSpec) connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) conn1, err := db.Conn(context.Background()) assertNilF(t, err) defer conn1.Close() conn2, err := db.Conn(context.Background()) assertNilF(t, err) defer conn2.Close() runSmokeQueryWithConn(t, conn1) runSmokeQueryWithConn(t, conn2) assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) }) t.Run("should update cache with new token when the old one expired if refresh token is missing", func(t *testing.T) { roundTripper.reset() wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"), newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) cfg := wiremock.connectionConfig() cfg.Role = "ANALYST" cfg.Authenticator = AuthTypeOAuthAuthorizationCode cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect" cfg.Transporter = roundTripper oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User) oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User) credentialsStorage.setCredential(oauthAccessTokenSpec, "expired-token") credentialsStorage.deleteCredential(oauthRefreshTokenSpec) connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) runSmokeQuery(t, db) assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123") }) 
t.Run("if access token is missing and refresh token is present, should run refresh token flow", func(t *testing.T) { roundTripper.reset() cfg := wiremock.connectionConfig() cfg.OauthScope = "session:role:ANALYST offline_access" cfg.Authenticator = AuthTypeOAuthAuthorizationCode cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect" cfg.Transporter = roundTripper oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User) oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User) credentialsStorage.deleteCredential(oauthAccessTokenSpec) credentialsStorage.setCredential(oauthRefreshTokenSpec, "refresh-token-123") wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"), newWiremockMapping("auth/oauth2/refresh_token/successful_flow.json"), newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) runSmokeQuery(t, db) assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) // only refresh token assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123") assertEqualE(t, credentialsStorage.getCredential(oauthRefreshTokenSpec), "refresh-token-123a") }) t.Run("if access token is expired and refresh token is present, should run refresh token flow", func(t *testing.T) { roundTripper.reset() cfg := wiremock.connectionConfig() cfg.OauthScope = "session:role:ANALYST offline_access" cfg.Authenticator = AuthTypeOAuthAuthorizationCode cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect" cfg.Transporter = roundTripper oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User) oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User) 
credentialsStorage.setCredential(oauthAccessTokenSpec, "expired-token") credentialsStorage.setCredential(oauthRefreshTokenSpec, "refresh-token-123") wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"), newWiremockMapping("auth/oauth2/refresh_token/successful_flow.json"), newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) runSmokeQuery(t, db) assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) // only refresh token assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123") assertEqualE(t, credentialsStorage.getCredential(oauthRefreshTokenSpec), "refresh-token-123a") }) t.Run("if new refresh token is not returned, should keep old one", func(t *testing.T) { roundTripper.reset() cfg := wiremock.connectionConfig() cfg.OauthScope = "session:role:ANALYST offline_access" cfg.Authenticator = AuthTypeOAuthAuthorizationCode cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect" cfg.Transporter = roundTripper oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User) oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User) credentialsStorage.setCredential(oauthAccessTokenSpec, "expired-token") credentialsStorage.setCredential(oauthRefreshTokenSpec, "refresh-token-123") wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"), newWiremockMapping("auth/oauth2/refresh_token/successful_flow_without_new_refresh_token.json"), newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) 
runSmokeQuery(t, db) assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) // only refresh token assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123") assertEqualE(t, credentialsStorage.getCredential(oauthRefreshTokenSpec), "refresh-token-123") }) t.Run("if refreshing token failed, run normal flow", func(t *testing.T) { roundTripper.reset() cfg := wiremock.connectionConfig() cfg.OauthScope = "session:role:ANALYST offline_access" cfg.Authenticator = AuthTypeOAuthAuthorizationCode cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect" cfg.Transporter = roundTripper oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User) oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User) credentialsStorage.setCredential(oauthAccessTokenSpec, "expired-token") credentialsStorage.setCredential(oauthRefreshTokenSpec, "expired-refresh-token") wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"), newWiremockMapping("auth/oauth2/refresh_token/invalid_refresh_token.json"), newWiremockMapping("auth/oauth2/authorization_code/successful_flow_with_offline_access.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) runSmokeQuery(t, db) assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 2) // only refresh token fails, then authorization code assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123") assertEqualE(t, credentialsStorage.getCredential(oauthRefreshTokenSpec), "refresh-token-123") }) t.Run("if secure storage is disabled, run normal flow", func(t *testing.T) { roundTripper.reset() cfg := wiremock.connectionConfig() cfg.OauthScope = "session:role:ANALYST offline_access" cfg.Authenticator = AuthTypeOAuthAuthorizationCode 
cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect" cfg.Transporter = roundTripper cfg.ClientStoreTemporaryCredential = ConfigBoolFalse oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User) oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User) credentialsStorage.setCredential(oauthAccessTokenSpec, "old-access-token") credentialsStorage.setCredential(oauthRefreshTokenSpec, "old-refresh-token") wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow_with_offline_access.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) runSmokeQuery(t, db) assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) // only access token token assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "old-access-token") assertEqualE(t, credentialsStorage.getCredential(oauthRefreshTokenSpec), "old-refresh-token") }) } func TestClientCredentialsFlow(t *testing.T) { if runningOnGithubAction() && runningOnLinux() { t.Skip("Github blocks writing to file system") } skipOnMac(t, "keychain requires password") currentDefaultAuthorizationCodeProviderFactory := defaultAuthorizationCodeProviderFactory defer func() { defaultAuthorizationCodeProviderFactory = currentDefaultAuthorizationCodeProviderFactory }() defaultAuthorizationCodeProviderFactory = func() authorizationCodeProvider { return &nonInteractiveAuthorizationCodeProvider{ t: t, mu: sync.Mutex{}, } } roundTripper := newCountingRoundTripper(createTestNoRevocationTransport()) cfg := wiremock.connectionConfig() cfg.Role = "ANALYST" cfg.Authenticator = AuthTypeOAuthClientCredentials cfg.Transporter = roundTripper oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User) oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, 
cfg.User) t.Run("successful flow", func(t *testing.T) { credentialsStorage.deleteCredential(oauthAccessTokenSpec) wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) runSmokeQuery(t, db) }) t.Run("should use cached access token", func(t *testing.T) { roundTripper.reset() wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) credentialsStorage.deleteCredential(oauthAccessTokenSpec) connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) conn1, err := db.Conn(context.Background()) assertNilF(t, err) defer conn1.Close() conn2, err := db.Conn(context.Background()) assertNilF(t, err) defer conn2.Close() runSmokeQueryWithConn(t, conn1) runSmokeQueryWithConn(t, conn2) assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) }) t.Run("should update cache with new token when the old one expired", func(t *testing.T) { roundTripper.reset() wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"), newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) credentialsStorage.setCredential(oauthAccessTokenSpec, "expired-token") connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) runSmokeQuery(t, db) assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123") }) t.Run("should not use refresh token, but ask for fresh access token", func(t *testing.T) { roundTripper.reset() wiremock.registerMappings(t, 
newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"), newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) credentialsStorage.setCredential(oauthAccessTokenSpec, "expired-token") credentialsStorage.setCredential(oauthRefreshTokenSpec, "refresh-token-123") connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) runSmokeQuery(t, db) assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123") assertEqualE(t, credentialsStorage.getCredential(oauthRefreshTokenSpec), "refresh-token-123") }) t.Run("should not use access token if token cache is disabled", func(t *testing.T) { roundTripper.reset() wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"), newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json"), newWiremockMapping("auth/oauth2/login_request.json"), newWiremockMapping("select1.json")) credentialsStorage.setCredential(oauthAccessTokenSpec, "access-token-123") cfg.ClientStoreTemporaryCredential = ConfigBoolFalse connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) runSmokeQuery(t, db) assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123") }) } func TestEligibleForDefaultClientCredentials(t *testing.T) { tests := []struct { name string oauthClient *oauthClient expected bool }{ { name: "Client credentials not supplied and Snowflake as IdP", oauthClient: &oauthClient{ cfg: &Config{ Host: "example.snowflakecomputing.com", OauthClientID: "", OauthClientSecret: "", OauthAuthorizationURL: "https://example.snowflakecomputing.com/oauth/authorize", OauthTokenRequestURL: 
"https://example.snowflakecomputing.com/oauth/token", }, }, expected: true, }, { name: "Client credentials not supplied and empty URLs (defaults to Snowflake)", oauthClient: &oauthClient{ cfg: &Config{ Host: "example.snowflakecomputing.com", OauthClientID: "", OauthClientSecret: "", OauthAuthorizationURL: "", OauthTokenRequestURL: "", }, }, expected: true, }, { name: "Client credentials supplied", oauthClient: &oauthClient{ cfg: &Config{ Host: "example.snowflakecomputing.com", OauthClientID: "testClientID", OauthClientSecret: "testClientSecret", OauthAuthorizationURL: "https://example.snowflakecomputing.com/oauth/authorize", OauthTokenRequestURL: "https://example.snowflakecomputing.com/oauth/token", }, }, expected: false, }, { name: "Only client ID supplied", oauthClient: &oauthClient{ cfg: &Config{ Host: "example.snowflakecomputing.com", OauthClientID: "testClientID", OauthClientSecret: "", OauthAuthorizationURL: "https://example.snowflakecomputing.com/oauth/authorize", OauthTokenRequestURL: "https://example.snowflakecomputing.com/oauth/token", }, }, expected: false, }, { name: "Non-Snowflake IdP", oauthClient: &oauthClient{ cfg: &Config{ Host: "example.snowflakecomputing.com", OauthClientID: "", OauthClientSecret: "", OauthAuthorizationURL: "https://example.com/oauth/authorize", OauthTokenRequestURL: "https://example.com/oauth/token", }, }, expected: false, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { result := test.oauthClient.eligibleForDefaultClientCredentials() if result != test.expected { t.Errorf("expected %v, got %v", test.expected, result) } }) } } type nonInteractiveAuthorizationCodeProvider struct { t *testing.T tamperWithState bool triggerError string responseBody string mu sync.Mutex sleepTime time.Duration } func (provider *nonInteractiveAuthorizationCodeProvider) run(authorizationURL string) error { if provider.sleepTime != 0 { time.Sleep(provider.sleepTime) if provider.triggerError != "" { return 
errors.New(provider.triggerError) } } if provider.triggerError != "" { return errors.New(provider.triggerError) } go func() { resp, err := http.Get(authorizationURL) assertNilF(provider.t, err) assertEqualE(provider.t, resp.StatusCode, http.StatusOK) respBody, err := io.ReadAll(resp.Body) assertNilF(provider.t, err) provider.mu.Lock() defer provider.mu.Unlock() provider.responseBody = string(respBody) }() return nil } func (provider *nonInteractiveAuthorizationCodeProvider) createState() string { if provider.tamperWithState { return "invalidState" } return "testState" } func (provider *nonInteractiveAuthorizationCodeProvider) createCodeVerifier() string { return "testCodeVerifier" } func (provider *nonInteractiveAuthorizationCodeProvider) assertResponseBodyContains(str string) { provider.mu.Lock() defer provider.mu.Unlock() assertStringContainsE(provider.t, provider.responseBody, str) } ================================================ FILE: auth_test.go ================================================ package gosnowflake import ( "cmp" "context" "crypto/rand" "crypto/rsa" "database/sql" "encoding/json" "errors" "fmt" sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config" "net/http" "net/url" "os" "path/filepath" "runtime" "testing" "time" "github.com/golang-jwt/jwt/v5" ) func TestUnitPostAuth(t *testing.T) { sr := &snowflakeRestful{ TokenAccessor: getSimpleTokenAccessor(), FuncAuthPost: postAuthTestAfterRenew, } var err error bodyCreator := func() ([]byte, error) { return []byte{0x12, 0x34}, nil } _, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) if err != nil { t.Fatalf("err: %v", err) } sr.FuncAuthPost = postAuthTestError _, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) if err == nil { t.Fatal("should have failed to auth for unknown reason") } sr.FuncAuthPost = postAuthTestAppBadGatewayError _, err = postAuth(context.Background(), 
	sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
	if err == nil {
		t.Fatal("should have failed to auth for unknown reason")
	}
	sr.FuncAuthPost = postAuthTestAppForbiddenError
	_, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
	if err == nil {
		t.Fatal("should have failed to auth for unknown reason")
	}
	sr.FuncAuthPost = postAuthTestAppUnexpectedError
	_, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
	if err == nil {
		t.Fatal("should have failed to auth for unknown reason")
	}
}

// postAuthFailServiceIssue is a FuncPostAuth stub that simulates a service outage.
func postAuthFailServiceIssue(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
	return nil, &SnowflakeError{
		Number: ErrCodeServiceUnavailable,
	}
}

// postAuthFailWrongAccount is a FuncPostAuth stub that simulates a connection
// failure for a wrong account.
func postAuthFailWrongAccount(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
	return nil, &SnowflakeError{
		Number: ErrCodeFailedToConnect,
	}
}

// postAuthFailUnknown is a FuncPostAuth stub that simulates a generic auth failure.
func postAuthFailUnknown(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
	return nil, &SnowflakeError{
		Number: ErrFailedToAuth,
	}
}

// postAuthSuccessWithErrorCode returns an unsuccessful auth response carrying a
// numeric error code in the payload (no transport error).
func postAuthSuccessWithErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
	return &authResponse{
		Success: false,
		Code:    "98765",
		Message: "wrong!",
	}, nil
}

// postAuthSuccessWithInvalidErrorCode returns an unsuccessful auth response
// whose error code is not parseable as a number.
func postAuthSuccessWithInvalidErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
	return &authResponse{
		Success: false,
		Code:    "abcdef",
		Message: "wrong!",
	}, nil
}

// postAuthSuccess returns a canned successful auth response.
func postAuthSuccess(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}

// postAuthCheckSAMLResponse fails unless the generated request body carries a
// non-empty raw SAML response.
func postAuthCheckSAMLResponse(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	jsonBody, err := bodyCreator()
	if err != nil {
		return nil, err
	}
	if err = json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.RawSAMLResponse == "" {
		return nil, errors.New("SAML response is empty")
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}

// Checks that the request body generated when authenticating with OAuth
// contains all the necessary values.
func postAuthCheckOAuth(
	_ context.Context,
	_ *snowflakeRestful,
	_ *http.Client,
	_ *url.Values,
	_ map[string]string,
	bodyCreator bodyCreatorType,
	_ time.Duration,
) (*authResponse, error) {
	var ar authRequest
	jsonBody, _ := bodyCreator()
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.Authenticator != AuthTypeOAuth.String() {
		return nil, errors.New("Authenticator is not OAUTH")
	}
	if ar.Data.Token == "" {
		return nil, errors.New("Token is empty")
	}
	if ar.Data.LoginName == "" {
		return nil, errors.New("Login name is empty")
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}

// postAuthCheckPasscode verifies the MFA passcode is sent as a separate field
// with the "passcode" duo method.
func postAuthCheckPasscode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	jsonBody, _ := bodyCreator()
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.Passcode != "987654321" || ar.Data.ExtAuthnDuoMethod != "passcode" {
		return nil, fmt.Errorf("passcode didn't match. expected: 987654321, got: %v, duo: %v", ar.Data.Passcode, ar.Data.ExtAuthnDuoMethod)
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}

// postAuthCheckPasscodeInPassword verifies that with PasscodeInPassword the
// passcode field itself is left empty.
func postAuthCheckPasscodeInPassword(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	jsonBody, _ := bodyCreator()
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.Passcode != "" || ar.Data.ExtAuthnDuoMethod != "passcode" {
		return nil, fmt.Errorf("passcode must be empty, got: %v, duo: %v", ar.Data.Passcode, ar.Data.ExtAuthnDuoMethod)
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}

// postAuthCheckUsernamePasswordMfa requires CLIENT_REQUEST_MFA_TOKEN to be set
// and hands back a cacheable MFA token.
func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	jsonBody, _ := bodyCreator()
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.SessionParameters["CLIENT_REQUEST_MFA_TOKEN"] != true {
		return nil, fmt.Errorf("expected client_request_mfa_token to be true but was %v", ar.Data.SessionParameters["CLIENT_REQUEST_MFA_TOKEN"])
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			MfaToken:    "mockedMfaToken",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}

// postAuthCheckUsernamePasswordMfaToken requires the cached MFA token to be
// replayed in the request.
func postAuthCheckUsernamePasswordMfaToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	jsonBody, _ := bodyCreator()
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.Token != "mockedMfaToken" {
		return nil, fmt.Errorf("unexpected mfa token: %v", ar.Data.Token)
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			MfaToken:    "mockedMfaToken",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}

// postAuthCheckUsernamePasswordMfaFailed replays the MFA token check but then
// reports an authentication failure (code 260008).
func postAuthCheckUsernamePasswordMfaFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	jsonBody, _ := bodyCreator()
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.Token != "mockedMfaToken" {
		return nil, fmt.Errorf("unexpected mfa token: %v", ar.Data.Token)
	}
	return &authResponse{
		Success: false,
		Data:    authResponseMain{},
		Message: "auth failed",
		Code:    "260008",
	}, nil
}

// postAuthCheckExternalBrowser requires CLIENT_STORE_TEMPORARY_CREDENTIAL and
// returns an ID token for caching.
func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	jsonBody, _ := bodyCreator()
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"] != true {
		return nil, fmt.Errorf("expected client_store_temporary_credential to be true but was %v", ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"])
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			IDToken:     "mockedIDToken",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}

// postAuthCheckExternalBrowserToken requires the cached ID token to be replayed.
func postAuthCheckExternalBrowserToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	jsonBody, _ := bodyCreator()
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.Token != "mockedIDToken" {
		return nil, fmt.Errorf("unexpected mfatoken: %v", ar.Data.Token)
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			IDToken:     "mockedIDToken",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}

// postAuthCheckExternalBrowserFailed checks the temporary-credential parameter
// but then reports an authentication failure (code 260008).
func postAuthCheckExternalBrowserFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	jsonBody, _ := bodyCreator()
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"] != true {
		return nil, fmt.Errorf("expected client_store_temporary_credential to be true but was %v", ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"])
	}
	return &authResponse{
		Success: false,
		Data:    authResponseMain{},
		Message: "auth failed",
		Code:    "260008",
	}, nil
}

// restfulTestWrapper gives mock FuncPostAuth implementations access to *testing.T.
type restfulTestWrapper struct {
	t *testing.T
}

// postAuthOktaWithNewToken exercises the retry path: the fake client 429s three
// times before succeeding, then a canned success response is returned.
func (rtw restfulTestWrapper) postAuthOktaWithNewToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	cfg := &Config{
		Authenticator: AuthTypeOkta,
	}
	// Retry 3 times and success
	client := &fakeHTTPClient{
		cnt:        3,
		success:    true,
		statusCode: 429,
		t:          rtw.t,
	}
	urlPtr, err := url.Parse("https://fakeaccountretrylogin.snowflakecomputing.com:443/login-request?request_guid=testguid")
	if err != nil {
		return &authResponse{}, err
	}
	body := func() ([]byte, error) {
		jsonBody, _ := bodyCreator()
		if err := json.Unmarshal(jsonBody, &ar); err != nil {
			return nil, err
		}
		return jsonBody, err
	}
	_, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, defaultTimeProvider, cfg).doPost().setBodyCreator(body).execute()
	if err != nil {
		return &authResponse{}, err
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			MfaToken:    "mockedMfaToken",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}

// getDefaultSnowflakeConn builds a minimal snowflakeConn with a populated
// Config and disabled telemetry for unit tests.
func getDefaultSnowflakeConn() *snowflakeConn {
	sc := &snowflakeConn{
		rest: &snowflakeRestful{
			TokenAccessor: getSimpleTokenAccessor(),
		},
		cfg: &Config{
			Account:            "a",
			User:               "u",
			Password:           "p",
			Database:           "d",
			Schema:             "s",
			Warehouse:          "w",
			Role:               "r",
			Region:             "",
			PasscodeInPassword: false,
			Passcode:           "",
			Application:        "testapp",
		},
		telemetry: &snowflakeTelemetry{enabled: false},
	}
	return sc
}

// TestUnitAuthenticateWithTokenAccessor verifies that AuthTypeTokenAccessor
// reuses the tokens already held by the accessor and never calls FuncPostAuth.
func TestUnitAuthenticateWithTokenAccessor(t *testing.T) {
	expectedSessionID := int64(123)
	expectedMasterToken := "master_token"
	expectedToken := "auth_token"
	ta := getSimpleTokenAccessor()
	ta.SetTokens(expectedToken, expectedMasterToken, expectedSessionID)
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeTokenAccessor
	sc.cfg.TokenAccessor = ta
	sr := &snowflakeRestful{
		FuncPostAuth:  postAuthFailServiceIssue,
		TokenAccessor: ta,
	}
	sc.rest = sr
	// FuncPostAuth is set to fail, but AuthTypeTokenAccessor should not even make a call to FuncPostAuth
	resp, err := authenticate(context.Background(), sc, []byte{}, []byte{})
	if err != nil {
		t.Fatalf("should not have failed, err %v", err)
	}
	if resp.SessionID != expectedSessionID {
		t.Fatalf("Expected session id %v but got %v", expectedSessionID, resp.SessionID)
	}
	if resp.Token != expectedToken {
		t.Fatalf("Expected token %v but got %v", expectedToken, resp.Token)
	}
	if resp.MasterToken != expectedMasterToken {
		t.Fatalf("Expected master token %v but got %v", expectedMasterToken, resp.MasterToken)
	}
	// The session info should mirror the configured database/warehouse/role/schema.
	if resp.SessionInfo.DatabaseName != sc.cfg.Database {
		t.Fatalf("Expected database %v but got %v", sc.cfg.Database, resp.SessionInfo.DatabaseName)
	}
	if resp.SessionInfo.WarehouseName != sc.cfg.Warehouse {
		t.Fatalf("Expected warehouse %v but got %v", sc.cfg.Warehouse, resp.SessionInfo.WarehouseName)
	}
	if resp.SessionInfo.RoleName != sc.cfg.Role {
		t.Fatalf("Expected role %v but got %v", sc.cfg.Role, resp.SessionInfo.RoleName)
	}
	if resp.SessionInfo.SchemaName != sc.cfg.Schema {
		t.Fatalf("Expected schema %v but got %v", sc.cfg.Schema, resp.SessionInfo.SchemaName)
	}
}

// TestUnitAuthenticate walks authenticate() through every failure stub, checks
// that a failed auth resets the token accessor, and finally verifies a success.
func TestUnitAuthenticate(t *testing.T) {
	var err error
	var driverErr *SnowflakeError
	var ok bool
	ta := getSimpleTokenAccessor()
	sc := getDefaultSnowflakeConn()
	sr := &snowflakeRestful{
		FuncPostAuth:  postAuthFailServiceIssue,
		TokenAccessor: ta,
	}
	sc.rest = sr
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err == nil {
		t.Fatal("should have failed.")
	}
	driverErr, ok = err.(*SnowflakeError)
	if !ok || driverErr.Number != ErrCodeServiceUnavailable {
		t.Fatalf("Snowflake error is expected. err: %v", driverErr)
	}
	sr.FuncPostAuth = postAuthFailWrongAccount
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err == nil {
		t.Fatal("should have failed.")
	}
	driverErr, ok = err.(*SnowflakeError)
	if !ok || driverErr.Number != ErrCodeFailedToConnect {
		t.Fatalf("Snowflake error is expected. err: %v", driverErr)
	}
	sr.FuncPostAuth = postAuthFailUnknown
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err == nil {
		t.Fatal("should have failed.")
	}
	driverErr, ok = err.(*SnowflakeError)
	if !ok || driverErr.Number != ErrFailedToAuth {
		t.Fatalf("Snowflake error is expected. err: %v", driverErr)
	}
	// A payload-level failure (numeric error code) must reset any stored tokens.
	ta.SetTokens("bad-token", "bad-master-token", 1)
	sr.FuncPostAuth = postAuthSuccessWithErrorCode
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err == nil {
		t.Fatal("should have failed.")
	}
	newToken, newMasterToken, newSessionID := ta.GetTokens()
	if newToken != "" || newMasterToken != "" || newSessionID != -1 {
		t.Fatalf("failed auth should have reset tokens: %v %v %v", newToken, newMasterToken, newSessionID)
	}
	driverErr, ok = err.(*SnowflakeError)
	if !ok || driverErr.Number != 98765 {
		t.Fatalf("Snowflake error is expected. err: %v", driverErr)
	}
	// Same reset behavior when the error code is not parseable.
	ta.SetTokens("bad-token", "bad-master-token", 1)
	sr.FuncPostAuth = postAuthSuccessWithInvalidErrorCode
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err == nil {
		t.Fatal("should have failed.")
	}
	oldToken, oldMasterToken, oldSessionID := ta.GetTokens()
	if oldToken != "" || oldMasterToken != "" || oldSessionID != -1 {
		t.Fatalf("failed auth should have reset tokens: %v %v %v", oldToken, oldMasterToken, oldSessionID)
	}
	// Successful auth must install fresh tokens into the accessor.
	sr.FuncPostAuth = postAuthSuccess
	var resp *authResponseMain
	resp, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err != nil {
		t.Fatalf("failed to auth. err: %v", err)
	}
	if resp.SessionInfo.DatabaseName != "dbn" {
		t.Fatalf("failed to get response from auth")
	}
	newToken, newMasterToken, newSessionID = ta.GetTokens()
	if newToken == oldToken {
		t.Fatalf("new token was not set: %v", newToken)
	}
	if newMasterToken == oldMasterToken {
		t.Fatalf("new master token was not set: %v", newMasterToken)
	}
	if newSessionID == oldSessionID {
		t.Fatalf("new session id was not set: %v", newSessionID)
	}
}

// TestUnitAuthenticateSaml drives the Okta/SAML path end to end through mocked
// SAML, OKTA and SSO handlers, verifying the SAML response reaches login.
func TestUnitAuthenticateSaml(t *testing.T) {
	var err error
	sr := &snowflakeRestful{
		Protocol:         "https",
		Host:             "abc.com",
		Port:             443,
		FuncPostAuthSAML: postAuthSAMLAuthSuccess,
		FuncPostAuthOKTA: postAuthOKTASuccess,
		FuncGetSSO:       getSSOSuccess,
		FuncPostAuth:     postAuthCheckSAMLResponse,
		TokenAccessor:    getSimpleTokenAccessor(),
	}
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeOkta
	sc.cfg.OktaURL = &url.URL{
		Scheme: "https",
		Host:   "abc.com",
	}
	sc.rest = sr
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	assertNilF(t, err, "failed to run.")
}

// Unit test for OAuth.
func TestUnitAuthenticateOAuth(t *testing.T) { var err error sr := &snowflakeRestful{ FuncPostAuth: postAuthCheckOAuth, TokenAccessor: getSimpleTokenAccessor(), } sc := getDefaultSnowflakeConn() sc.cfg.Token = "oauthToken" sc.cfg.Authenticator = AuthTypeOAuth sc.rest = sr _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("failed to run. err: %v", err) } } func TestUnitAuthenticatePasscode(t *testing.T) { var err error sr := &snowflakeRestful{ FuncPostAuth: postAuthCheckPasscode, TokenAccessor: getSimpleTokenAccessor(), } sc := getDefaultSnowflakeConn() sc.cfg.Passcode = "987654321" sc.rest = sr _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("failed to run. err: %v", err) } sr.FuncPostAuth = postAuthCheckPasscodeInPassword sc.rest = sr sc.cfg.PasscodeInPassword = true _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("failed to run. err: %v", err) } } // Test JWT function in the local environment against the validation function in go func TestUnitAuthenticateJWT(t *testing.T) { var err error // Generate a fresh private key for this unit test only localTestKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("Failed to generate test private key: %s", err.Error()) } // Create custom JWT verification function that uses the local key postAuthCheckLocalJWTToken := func(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { var ar authRequest jsonBody, _ := bodyCreator() if err := json.Unmarshal(jsonBody, &ar); err != nil { return nil, err } if ar.Data.Authenticator != AuthTypeJwt.String() { return nil, errors.New("Authenticator is not JWT") } tokenString := ar.Data.Token // Validate token using the local test key's public key _, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) { if _, ok := 
token.Method.(*jwt.SigningMethodRSA); !ok { return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) } return localTestKey.Public(), nil // Use local key for verification }) if err != nil { return nil, err } return &authResponse{ Success: true, Data: authResponseMain{ Token: "t", MasterToken: "m", SessionInfo: authResponseSessionInfo{ DatabaseName: "dbn", }, }, }, nil } sr := &snowflakeRestful{ FuncPostAuth: postAuthCheckLocalJWTToken, // Use local verification function TokenAccessor: getSimpleTokenAccessor(), } sc := getDefaultSnowflakeConn() sc.cfg.Authenticator = AuthTypeJwt sc.cfg.JWTExpireTimeout = time.Duration(sfconfig.DefaultJWTTimeout) sc.cfg.PrivateKey = localTestKey sc.rest = sr // A valid JWT token should pass if _, err = authenticate(context.Background(), sc, []byte{}, []byte{}); err != nil { t.Fatalf("failed to run. err: %v", err) } // An invalid JWT token should not pass invalidPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Error(err) } sc.cfg.PrivateKey = invalidPrivateKey if _, err = authenticate(context.Background(), sc, []byte{}, []byte{}); err == nil { t.Fatalf("invalid token passed") } } func TestUnitAuthenticateUsernamePasswordMfa(t *testing.T) { var err error sr := &snowflakeRestful{ FuncPostAuth: postAuthCheckUsernamePasswordMfa, TokenAccessor: getSimpleTokenAccessor(), } sc := getDefaultSnowflakeConn() sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA sc.cfg.ClientRequestMfaToken = ConfigBoolTrue sc.rest = sr _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("failed to run. err: %v", err) } sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaToken sc.mfaToken = "mockedMfaToken" _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("failed to run. 
err: %v", err) } sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaFailed _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err == nil { t.Fatal("should have failed") } } func TestUnitAuthenticateWithConfigMFA(t *testing.T) { var err error sr := &snowflakeRestful{ FuncPostAuth: postAuthCheckUsernamePasswordMfa, TokenAccessor: getSimpleTokenAccessor(), } sc := getDefaultSnowflakeConn() sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA sc.cfg.ClientRequestMfaToken = ConfigBoolTrue sc.rest = sr sc.ctx = context.Background() err = authenticateWithConfig(sc) if err != nil { t.Fatalf("failed to run. err: %v", err) } } // This test creates two groups of scenarios: // a) singleAuthenticationPrompt=true - in this case, we start authenticating threads at once, // but due to locking mechanism only one should reach wiremock without MFA token. // b) singleAuthenticationPrompt=false - in this case, there is no locking, so all threads should rush, // but on Wiremock only first will be served with correct response (simulating a user confirming MFA only once). // The remaining threads should return error. 
// TestMfaParallelLogin opens 20 connections concurrently against Wiremock.
// With SingleAuthenticationPrompt=true all connections succeed (one login is
// shared); with false only the first login is served a correct response, so
// the other 19 fail.
func TestMfaParallelLogin(t *testing.T) {
	skipOnMissingHome(t)
	skipOnMac(t, "interactive keyring access not available on macOS runners")
	cfg := wiremock.connectionConfig()
	tokenSpec := newMfaTokenSpec(cfg.Host, cfg.User)
	for _, singleAuthenticationPrompt := range []ConfigBool{ConfigBoolTrue, ConfigBoolFalse} {
		t.Run("starts without mfa token, singleAuthenticationPrompt="+singleAuthenticationPrompt.String(), func(t *testing.T) {
			wiremock.registerMappings(t,
				newWiremockMapping("auth/mfa/parallel_login_successful_flow.json"),
				newWiremockMapping("select1.json"),
				newWiremockMapping("close_session.json"))
			cfg := wiremock.connectionConfig()
			cfg.Authenticator = AuthTypeUsernamePasswordMFA
			cfg.SingleAuthenticationPrompt = singleAuthenticationPrompt
			cfg.ClientRequestMfaToken = ConfigBoolTrue
			connector := NewConnector(SnowflakeDriver{}, *cfg)
			db := sql.OpenDB(connector)
			defer db.Close()
			// Start from a clean slate: no cached MFA token.
			credentialsStorage.deleteCredential(tokenSpec)
			errs := initPoolWithSizeAndReturnErrors(db, 20)
			if singleAuthenticationPrompt == ConfigBoolTrue {
				assertEqualE(t, len(errs), 0)
			} else {
				// all but the one that actually retrieves the MFA token should fail
				assertEqualE(t, len(errs), 19)
			}
		})
		t.Run("starts without mfa token, first attempt fails, singleAuthenticationPrompt="+singleAuthenticationPrompt.String(), func(t *testing.T) {
			wiremock.registerMappings(t,
				newWiremockMapping("auth/mfa/parallel_login_first_fails_then_successful_flow.json"),
				newWiremockMapping("select1.json"),
				newWiremockMapping("close_session.json"))
			cfg := wiremock.connectionConfig()
			cfg.Authenticator = AuthTypeUsernamePasswordMFA
			cfg.SingleAuthenticationPrompt = singleAuthenticationPrompt
			cfg.ClientRequestMfaToken = ConfigBoolTrue
			credentialsStorage.deleteCredential(tokenSpec)
			connector := NewConnector(SnowflakeDriver{}, *cfg)
			db := sql.OpenDB(connector)
			defer db.Close()
			errs := initPoolWithSizeAndReturnErrors(db, 20)
			if singleAuthenticationPrompt == ConfigBoolTrue {
				// Only the deliberately failing first attempt errors out.
				assertEqualF(t, len(errs), 1)
				assertStringContainsE(t, errs[0].Error(), "MFA with TOTP is required")
			} else {
				assertEqualE(t, len(errs), 19)
			}
		})
	}
}

// TestUnitAuthenticateWithConfigOkta runs the mocked Okta flow through
// authenticateWithConfig, then flips the SAML step to an error and checks
// that the error is propagated verbatim.
func TestUnitAuthenticateWithConfigOkta(t *testing.T) {
	var err error
	sr := &snowflakeRestful{
		Protocol:         "https",
		Host:             "abc.com",
		Port:             443,
		FuncPostAuthSAML: postAuthSAMLAuthSuccess,
		FuncPostAuthOKTA: postAuthOKTASuccess,
		FuncGetSSO:       getSSOSuccess,
		FuncPostAuth:     postAuthCheckSAMLResponse,
		TokenAccessor:    getSimpleTokenAccessor(),
	}
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeOkta
	sc.cfg.OktaURL = &url.URL{
		Scheme: "https",
		Host:   "abc.com",
	}
	sc.rest = sr
	sc.ctx = context.Background()
	err = authenticateWithConfig(sc)
	assertNilE(t, err, "expected to have no error.")
	sr.FuncPostAuthSAML = postAuthSAMLError
	err = authenticateWithConfig(sc)
	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
	assertEqualE(t, err.Error(), "failed to get SAML response")
}

// TestUnitAuthenticateWithExternalBrowserParallel covers ID-token caching for
// the external-browser authenticator against Wiremock, replacing the
// interactive SAML provider with a non-interactive test double.
func TestUnitAuthenticateWithExternalBrowserParallel(t *testing.T) {
	skipOnMissingHome(t)
	skipOnMac(t, "interactive keyring access not available on macOS runners")
	t.Run("no ID token cached", func(t *testing.T) {
		origSamlResponseProvider := defaultSamlResponseProvider
		defer func() {
			defaultSamlResponseProvider = origSamlResponseProvider
		}()
		defaultSamlResponseProvider = func() samlResponseProvider {
			return &nonInteractiveSamlResponseProvider{t: t}
		}
		wiremock.registerMappings(t,
			newWiremockMapping("auth/external_browser/successful_flow.json"),
			newWiremockMapping("select1.json"),
			newWiremockMapping("close_session.json"))
		cfg := wiremock.connectionConfig()
		cfg.Authenticator = AuthTypeExternalBrowser
		cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		credentialsStorage.deleteCredential(newIDTokenSpec(cfg.Host, cfg.User))
		db := sql.OpenDB(connector)
		defer db.Close()
		runSmokeQuery(t, db)
		// The login must have persisted the ID token handed out by Wiremock.
		assertEqualE(t, credentialsStorage.getCredential(newIDTokenSpec(cfg.Host, cfg.User)), "test-id-token")
	})
	t.Run("ID token cached", func(t *testing.T) {
		wiremock.registerMappings(t,
			newWiremockMapping("auth/external_browser/successful_flow.json"),
			newWiremockMapping("select1.json"),
			newWiremockMapping("close_session.json"))
		cfg := wiremock.connectionConfig()
		cfg.Authenticator = AuthTypeExternalBrowser
		cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		// Pre-seed the cache so no browser interaction is needed at all.
		credentialsStorage.setCredential(newIDTokenSpec(cfg.Host, cfg.User), "test-id-token")
		db := sql.OpenDB(connector)
		defer db.Close()
		runSmokeQuery(t, db)
	})
	t.Run("first connection retrieves ID token, second request uses cached ID token", func(t *testing.T) {
		origSamlResponseProvider := defaultSamlResponseProvider
		defer func() {
			defaultSamlResponseProvider = origSamlResponseProvider
		}()
		defaultSamlResponseProvider = func() samlResponseProvider {
			return &nonInteractiveSamlResponseProvider{t: t}
		}
		wiremock.registerMappings(t,
			newWiremockMapping("auth/external_browser/parallel_login_successful_flow.json"),
			newWiremockMapping("select1.json"),
			newWiremockMapping("close_session.json"))
		cfg := wiremock.connectionConfig()
		cfg.Authenticator = AuthTypeExternalBrowser
		cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		credentialsStorage.deleteCredential(newIDTokenSpec(cfg.Host, cfg.User))
		db := sql.OpenDB(connector)
		defer db.Close()
		conn1, err := db.Conn(context.Background())
		assertNilF(t, err)
		defer conn1.Close()
		runSmokeQueryWithConn(t, conn1)
		// Second connection should reuse the token cached by the first.
		conn2, err := db.Conn(context.Background())
		assertNilF(t, err)
		defer conn2.Close()
		runSmokeQueryWithConn(t, conn2)
	})
	t.Run("first connection retrieves ID token, remaining ones wait and reuse", func(t *testing.T) {
		origSamlResponseProvider := defaultSamlResponseProvider
		defer func() {
			defaultSamlResponseProvider = origSamlResponseProvider
		}()
		defaultSamlResponseProvider = func() samlResponseProvider {
			return &nonInteractiveSamlResponseProvider{t: t}
		}
		wiremock.registerMappings(t,
			newWiremockMapping("auth/external_browser/parallel_login_successful_flow.json"),
			newWiremockMapping("select1.json"),
			newWiremockMapping("close_session.json"))
		cfg := wiremock.connectionConfig()
		cfg.Authenticator = AuthTypeExternalBrowser
		cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		credentialsStorage.deleteCredential(newIDTokenSpec(cfg.Host, cfg.User))
		db := sql.OpenDB(connector)
		defer db.Close()
		// All 20 parallel connections should succeed via the shared token.
		errs := initPoolWithSizeAndReturnErrors(db, 20)
		assertEqualE(t, len(errs), 0)
	})
	t.Run("first connection fails, second retrieves ID token, remaining ones wait and reuse", func(t *testing.T) {
		origSamlResponseProvider := defaultSamlResponseProvider
		defer func() {
			defaultSamlResponseProvider = origSamlResponseProvider
		}()
		defaultSamlResponseProvider = func() samlResponseProvider {
			return &nonInteractiveSamlResponseProvider{t: t}
		}
		wiremock.registerMappings(t,
			newWiremockMapping("auth/external_browser/parallel_login_first_fails_then_successful_flow.json"),
			newWiremockMapping("select1.json"),
			newWiremockMapping("close_session.json"))
		cfg := wiremock.connectionConfig()
		cfg.Authenticator = AuthTypeExternalBrowser
		cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		credentialsStorage.deleteCredential(newIDTokenSpec(cfg.Host, cfg.User))
		db := sql.OpenDB(connector)
		defer db.Close()
		// Exactly the deliberately failing first attempt errors out.
		errs := initPoolWithSizeAndReturnErrors(db, 20)
		assertEqualE(t, len(errs), 1)
	})
}

// TestUnitAuthenticateWithConfigExternalBrowserWithFailedSAMLResponse checks
// that a failing SAML step propagates its error through
// authenticateWithConfig for the external-browser authenticator.
func TestUnitAuthenticateWithConfigExternalBrowserWithFailedSAMLResponse(t *testing.T) {
	var err error
	sr := &snowflakeRestful{
		FuncPostAuthSAML: postAuthSAMLError,
		TokenAccessor:    getSimpleTokenAccessor(),
	}
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeExternalBrowser
	sc.cfg.ExternalBrowserTimeout = time.Duration(sfconfig.DefaultExternalBrowserTimeout)
	sc.rest = sr
	sc.ctx = context.Background()
	err = authenticateWithConfig(sc)
	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
	assertEqualE(t, err.Error(), "failed to get SAML response")
}

// TestUnitAuthenticateExternalBrowser covers three external-browser cases via
// mocks: initial login, re-login with a cached ID token, and a failing login.
func TestUnitAuthenticateExternalBrowser(t *testing.T) {
	var err error
	sr := &snowflakeRestful{
		FuncPostAuth:  postAuthCheckExternalBrowser,
		TokenAccessor: getSimpleTokenAccessor(),
	}
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeExternalBrowser
	sc.cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
	sc.rest = sr
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err != nil {
		t.Fatalf("failed to run. err: %v", err)
	}
	sr.FuncPostAuth = postAuthCheckExternalBrowserToken
	sc.idToken = "mockedIDToken"
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err != nil {
		t.Fatalf("failed to run. err: %v", err)
	}
	sr.FuncPostAuth = postAuthCheckExternalBrowserFailed
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err == nil {
		t.Fatal("should have failed")
	}
}

// To run this test you need to set environment variables in parameters.json to a user with MFA authentication enabled
// Set any other snowflake_test variables needed for database, schema, role for this user
func TestUsernamePasswordMfaCaching(t *testing.T) {
	t.Skip("manual test for MFA token caching")
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatal("Failed to parse dsn")
	}
	// connect with MFA authentication
	user := os.Getenv("SNOWFLAKE_TEST_MFA_USER")
	password := os.Getenv("SNOWFLAKE_TEST_MFA_PASSWORD")
	config.User = user
	config.Password = password
	config.Authenticator = AuthTypeUsernamePasswordMFA
	if runtime.GOOS == "linux" {
		config.ClientRequestMfaToken = ConfigBoolTrue
	}
	connector := NewConnector(SnowflakeDriver{}, *config)
	db := sql.OpenDB(connector)
	for range 3 {
		// should only be prompted to authenticate first time around.
		_, err := db.Query("select current_user()")
		if err != nil {
			t.Fatal(err)
		}
	}
}

// TestUsernamePasswordMfaCachingWithPasscode is the same manual scenario as
// above but with an explicit DUO passcode in the config.
func TestUsernamePasswordMfaCachingWithPasscode(t *testing.T) {
	t.Skip("manual test for MFA token caching")
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatal("Failed to parse dsn")
	}
	// connect with MFA authentication
	user := os.Getenv("SNOWFLAKE_TEST_MFA_USER")
	password := os.Getenv("SNOWFLAKE_TEST_MFA_PASSWORD")
	config.User = user
	config.Password = password
	config.Passcode = "" // fill with your passcode from DUO app
	config.Authenticator = AuthTypeUsernamePasswordMFA
	if runtime.GOOS == "linux" {
		config.ClientRequestMfaToken = ConfigBoolTrue
	}
	connector := NewConnector(SnowflakeDriver{}, *config)
	db := sql.OpenDB(connector)
	for range 3 {
		// should only be prompted to authenticate first time around.
		_, err := db.Query("select current_user()")
		if err != nil {
			t.Fatal(err)
		}
	}
}

// TestUsernamePasswordMfaCachingWithPasscodeInPassword appends the DUO
// passcode to the password and sets PasscodeInPassword.
func TestUsernamePasswordMfaCachingWithPasscodeInPassword(t *testing.T) {
	t.Skip("manual test for MFA token caching")
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatal("Failed to parse dsn")
	}
	// connect with MFA authentication
	user := os.Getenv("SNOWFLAKE_TEST_MFA_USER")
	password := os.Getenv("SNOWFLAKE_TEST_MFA_PASSWORD")
	config.User = user
	config.Password = password + "" // fill with your passcode from DUO app
	config.PasscodeInPassword = true
	connector := NewConnector(SnowflakeDriver{}, *config)
	db := sql.OpenDB(connector)
	for range 3 {
		// should only be prompted to authenticate first time around.
		_, err := db.Query("select current_user()")
		if err != nil {
			t.Fatal(err)
		}
	}
}

// To run this test you need to set environment variables in parameters.json to a user with MFA authentication enabled
// Set any other snowflake_test variables needed for database, schema, role for this user
func TestDisableUsernamePasswordMfaCaching(t *testing.T) {
	t.Skip("manual test for disabling MFA token caching")
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatal("Failed to parse dsn")
	}
	// connect with MFA authentication
	user := os.Getenv("SNOWFLAKE_TEST_MFA_USER")
	password := os.Getenv("SNOWFLAKE_TEST_MFA_PASSWORD")
	config.User = user
	config.Password = password
	config.Authenticator = AuthTypeUsernamePasswordMFA
	// disable MFA token caching
	config.ClientRequestMfaToken = ConfigBoolFalse
	connector := NewConnector(SnowflakeDriver{}, *config)
	db := sql.OpenDB(connector)
	for range 3 {
		// should be prompted to authenticate 3 times.
		_, err := db.Query("select current_user()")
		if err != nil {
			t.Fatal(err)
		}
	}
}

// To run this test you need to set SNOWFLAKE_TEST_EXT_BROWSER_USER environment variable to an external browser user
// Set any other snowflake_test variables needed for database, schema, role for this user
func TestExternalBrowserCaching(t *testing.T) {
	t.Skip("manual test for external browser token caching")
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatal("Failed to parse dsn")
	}
	// connect with external browser authentication
	user := os.Getenv("SNOWFLAKE_TEST_EXT_BROWSER_USER")
	config.User = user
	config.Authenticator = AuthTypeExternalBrowser
	if runtime.GOOS == "linux" {
		config.ClientStoreTemporaryCredential = ConfigBoolTrue
	}
	connector := NewConnector(SnowflakeDriver{}, *config)
	db := sql.OpenDB(connector)
	for range 3 {
		// should only be prompted to authenticate first time around.
		_, err := db.Query("select current_user()")
		if err != nil {
			t.Fatal(err)
		}
	}
}

// To run this test you need to set SNOWFLAKE_TEST_EXT_BROWSER_USER environment variable to an external browser user
// Set any other snowflake_test variables needed for database, schema, role for this user
func TestDisableExternalBrowserCaching(t *testing.T) {
	t.Skip("manual test for disabling external browser token caching")
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatal("Failed to parse dsn")
	}
	// connect with external browser authentication
	user := os.Getenv("SNOWFLAKE_TEST_EXT_BROWSER_USER")
	config.User = user
	config.Authenticator = AuthTypeExternalBrowser
	// disable external browser token caching
	config.ClientStoreTemporaryCredential = ConfigBoolFalse
	connector := NewConnector(SnowflakeDriver{}, *config)
	db := sql.OpenDB(connector)
	for range 3 {
		// should be prompted to authenticate 3 times.
		_, err := db.Query("select current_user()")
		if err != nil {
			t.Fatal(err)
		}
	}
}

// TestOktaRetryWithNewToken drives the mocked Okta flow through a
// FuncPostAuth that exercises token renewal (postAuthOktaWithNewToken) and
// verifies the tokens and session info in the final response.
func TestOktaRetryWithNewToken(t *testing.T) {
	expectedMasterToken := "m"
	expectedToken := "t"
	expectedMfaToken := "mockedMfaToken"
	expectedDatabaseName := "dbn"
	sr := &snowflakeRestful{
		Protocol:         "https",
		Host:             "abc.com",
		Port:             443,
		FuncPostAuthSAML: postAuthSAMLAuthSuccess,
		FuncPostAuthOKTA: postAuthOKTASuccess,
		FuncGetSSO:       getSSOSuccess,
		FuncPostAuth:     restfulTestWrapper{t: t}.postAuthOktaWithNewToken,
		TokenAccessor:    getSimpleTokenAccessor(),
	}
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeOkta
	sc.cfg.OktaURL = &url.URL{
		Scheme: "https",
		Host:   "abc.com",
	}
	sc.rest = sr
	sc.ctx = context.Background()
	authResponse, err := authenticate(context.Background(), sc, []byte{0x12, 0x34}, []byte{0x56, 0x78})
	assertNilF(t, err, "should not have failed to run authenticate()")
	assertEqualF(t, authResponse.MasterToken, expectedMasterToken)
	assertEqualF(t, authResponse.Token, expectedToken)
	assertEqualF(t, authResponse.MfaToken, expectedMfaToken)
	assertEqualF(t, authResponse.SessionInfo.DatabaseName,
		expectedDatabaseName)
}

// TestContextPropagatedToAuthWhenUsingOpen verifies that a short query
// context deadline reaches the authentication step when using sql.Open.
func TestContextPropagatedToAuthWhenUsingOpen(t *testing.T) {
	db, err := sql.Open("snowflake", dsn)
	assertNilF(t, err)
	defer db.Close()
	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
	_, err = db.QueryContext(ctx, "SELECT 1")
	assertNotNilF(t, err)
	assertStringContainsE(t, err.Error(), "context deadline exceeded")
	cancel()
}

// TestContextPropagatedToAuthWhenUsingOpenDB is the same check via
// sql.OpenDB with an explicit connector.
func TestContextPropagatedToAuthWhenUsingOpenDB(t *testing.T) {
	cfg, err := ParseDSN(dsn)
	assertNilF(t, err)
	connector := NewConnector(&SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	defer db.Close()
	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
	_, err = db.QueryContext(ctx, "SELECT 1")
	assertNotNilF(t, err)
	assertStringContainsE(t, err.Error(), "context deadline exceeded")
	cancel()
}

// TestPatSuccessfulFlow authenticates with a programmatic access token
// against Wiremock and runs a simple query.
func TestPatSuccessfulFlow(t *testing.T) {
	cfg := wiremock.connectionConfig()
	cfg.Authenticator = AuthTypePat
	cfg.Token = "some PAT"
	wiremock.registerMappings(t,
		wiremockMapping{filePath: "auth/pat/successful_flow.json"},
		wiremockMapping{filePath: "select1.json"},
	)
	connector := NewConnector(SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	rows, err := db.Query("SELECT 1")
	assertNilF(t, err)
	var v int
	assertTrueE(t, rows.Next())
	assertNilF(t, rows.Scan(&v))
	assertEqualE(t, v, 1)
}

// TestPatTokenRotation checks that the PAT is re-read from TokenFilePath for
// each new connection, so a rotated token file is picked up.
func TestPatTokenRotation(t *testing.T) {
	dir := t.TempDir()
	tokenFilePath := filepath.Join(dir, "tokenFile")
	assertNilF(t, os.WriteFile(tokenFilePath, []byte("some PAT"), 0644))
	cfg := wiremock.connectionConfig()
	cfg.Authenticator = AuthTypePat
	cfg.TokenFilePath = tokenFilePath
	wiremock.registerMappings(t,
		wiremockMapping{filePath: "auth/pat/reading_fresh_token.json"},
	)
	connector := NewConnector(SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	_, err := db.Conn(context.Background())
	assertNilF(t, err)
	// Rotate the token on disk; the next connection must use the new value.
	assertNilF(t, os.WriteFile(tokenFilePath, []byte("some PAT 2"), 0644))
	_, err = db.Conn(context.Background())
	assertNilF(t, err)
}

// TestPatInvalidToken expects the server-side PAT rejection (error 394400)
// to surface as a SnowflakeError.
func TestPatInvalidToken(t *testing.T) {
	wiremock.registerMappings(t,
		wiremockMapping{filePath: "auth/pat/invalid_token.json"},
	)
	cfg := wiremock.connectionConfig()
	cfg.Authenticator = AuthTypePat
	cfg.Token = "some PAT"
	connector := NewConnector(SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	_, err := db.Query("SELECT 1")
	assertNotNilF(t, err)
	var se *SnowflakeError
	assertErrorsAsF(t, err, &se)
	assertEqualE(t, se.Number, 394400)
	assertEqualE(t, se.Message, "Programmatic access token is invalid.")
}

// TestWithOauthAuthorizationCodeFlowManual is a manual end-to-end test of the
// OAuth authorization-code flow for OKTA and SNOWFLAKE IdPs; configuration
// comes from SNOWFLAKE_TEST_OAUTH_* environment variables.
func TestWithOauthAuthorizationCodeFlowManual(t *testing.T) {
	t.Skip("manual test")
	for _, provider := range []string{"OKTA", "SNOWFLAKE"} {
		t.Run(provider, func(t *testing.T) {
			cfg, err := GetConfigFromEnv([]*ConfigParam{
				{Name: "OAuthClientId", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_CLIENT_ID", FailOnMissing: true},
				{Name: "OAuthClientSecret", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_CLIENT_SECRET", FailOnMissing: true},
				{Name: "OAuthAuthorizationURL", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_AUTHORIZATION_URL", FailOnMissing: false},
				{Name: "OAuthTokenRequestURL", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_TOKEN_REQUEST_URL", FailOnMissing: false},
				{Name: "OAuthRedirectURI", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_REDIRECT_URI", FailOnMissing: false},
				{Name: "OAuthScope", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_SCOPE", FailOnMissing: false},
				{Name: "User", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_USER", FailOnMissing: true},
				{Name: "Role", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_ROLE", FailOnMissing: true},
				{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true},
			})
			assertNilF(t, err)
			cfg.Authenticator = AuthTypeOAuthAuthorizationCode
			// Fall back to the account's default token-request endpoint.
			tokenRequestURL := cmp.Or(cfg.OauthTokenRequestURL, fmt.Sprintf("https://%v.snowflakecomputing.com:443/oauth/token-request", cfg.Account))
			credentialsStorage.deleteCredential(newOAuthAccessTokenSpec(tokenRequestURL, cfg.User))
			credentialsStorage.deleteCredential(newOAuthRefreshTokenSpec(tokenRequestURL, cfg.User))
			connector := NewConnector(&SnowflakeDriver{}, *cfg)
			db := sql.OpenDB(connector)
			defer db.Close()
			conn1, err := db.Conn(context.Background())
			assertNilF(t, err)
			defer conn1.Close()
			runSmokeQueryWithConn(t, conn1)
			conn2, err := db.Conn(context.Background())
			assertNilF(t, err)
			defer conn2.Close()
			runSmokeQueryWithConn(t, conn2)
			// Simulate an expired access token; the driver should recover.
			credentialsStorage.setCredential(newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User), "expired-token")
			conn3, err := db.Conn(context.Background())
			assertNilF(t, err)
			defer conn3.Close()
			runSmokeQueryWithConn(t, conn3)
		})
	}
}

// TestWithOAuthClientCredentialsFlowManual is a manual end-to-end test of the
// OAuth client-credentials flow against an Okta token endpoint.
func TestWithOAuthClientCredentialsFlowManual(t *testing.T) {
	t.Skip("manual test")
	cfg, err := GetConfigFromEnv([]*ConfigParam{
		{Name: "OAuthClientId", EnvName: "SNOWFLAKE_TEST_OAUTH_OKTA_CLIENT_ID", FailOnMissing: true},
		{Name: "OAuthClientSecret", EnvName: "SNOWFLAKE_TEST_OAUTH_OKTA_CLIENT_SECRET", FailOnMissing: true},
		{Name: "OAuthTokenRequestURL", EnvName: "SNOWFLAKE_TEST_OAUTH_OKTA_TOKEN_REQUEST_URL", FailOnMissing: true},
		{Name: "Role", EnvName: "SNOWFLAKE_TEST_OAUTH_OKTA_ROLE", FailOnMissing: true},
		{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true},
	})
	assertNilF(t, err)
	cfg.Authenticator = AuthTypeOAuthClientCredentials
	connector := NewConnector(&SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	defer db.Close()
	runSmokeQuery(t, db)
}

================================================
FILE: auth_wif.go
================================================
package gosnowflake

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/credentials"
	"github.com/aws/aws-sdk-go-v2/service/sts"
	"github.com/golang-jwt/jwt/v5"

	sfconfig
"github.com/snowflakedb/gosnowflake/v2/internal/config" ) const ( awsWif wifProviderType = "AWS" gcpWif wifProviderType = "GCP" azureWif wifProviderType = "AZURE" oidcWif wifProviderType = "OIDC" gcpMetadataFlavorHeaderName = "Metadata-Flavor" gcpMetadataFlavor = "Google" defaultMetadataServiceBase = "http://169.254.169.254" defaultGcpIamCredentialsBase = "https://iamcredentials.googleapis.com" snowflakeAudience = "snowflakecomputing.com" ) type wifProviderType string type wifAttestation struct { ProviderType string `json:"providerType"` Credential string `json:"credential"` Metadata map[string]string `json:"metadata"` } type wifAttestationCreator interface { createAttestation() (*wifAttestation, error) } type wifAttestationProvider struct { context context.Context cfg *Config awsCreator wifAttestationCreator gcpCreator wifAttestationCreator azureCreator wifAttestationCreator oidcCreator wifAttestationCreator } func createWifAttestationProvider(ctx context.Context, cfg *Config, telemetry *snowflakeTelemetry) *wifAttestationProvider { return &wifAttestationProvider{ context: ctx, cfg: cfg, awsCreator: &awsIdentityAttestationCreator{ cfg: cfg, attestationServiceFactory: createDefaultAwsAttestationMetadataProvider, ctx: ctx, }, gcpCreator: &gcpIdentityAttestationCreator{ cfg: cfg, telemetry: telemetry, metadataServiceBaseURL: defaultMetadataServiceBase, iamCredentialsURL: defaultGcpIamCredentialsBase, }, azureCreator: &azureIdentityAttestationCreator{ azureAttestationMetadataProvider: &defaultAzureAttestationMetadataProvider{}, cfg: cfg, telemetry: telemetry, workloadIdentityEntraResource: determineEntraResource(cfg), azureMetadataServiceBaseURL: defaultMetadataServiceBase, }, oidcCreator: &oidcIdentityAttestationCreator{token: func() (string, error) { return sfconfig.GetToken(cfg) }}, } } func (p *wifAttestationProvider) getAttestation(identityProvider string) (*wifAttestation, error) { switch strings.ToUpper(identityProvider) { case string(awsWif): return 
p.awsCreator.createAttestation() case string(gcpWif): return p.gcpCreator.createAttestation() case string(azureWif): return p.azureCreator.createAttestation() case string(oidcWif): return p.oidcCreator.createAttestation() default: return nil, fmt.Errorf("unknown WorkloadIdentityProvider specified: %s. Valid values are: %s, %s, %s, %s", identityProvider, awsWif, gcpWif, azureWif, oidcWif) } } type awsAttestastationMetadataProviderFactory func(ctx context.Context, cfg *Config) awsAttestationMetadataProvider type awsIdentityAttestationCreator struct { cfg *Config attestationServiceFactory awsAttestastationMetadataProviderFactory ctx context.Context } type gcpIdentityAttestationCreator struct { cfg *Config telemetry *snowflakeTelemetry metadataServiceBaseURL string iamCredentialsURL string } type oidcIdentityAttestationCreator struct { token func() (string, error) } type awsAttestationMetadataProvider interface { awsCredentials() (aws.Credentials, error) awsCredentialsViaRoleChaining() (aws.Credentials, error) awsRegion() string } type defaultAwsAttestationMetadataProvider struct { ctx context.Context cfg *Config awsCfg aws.Config } func createDefaultAwsAttestationMetadataProvider(ctx context.Context, cfg *Config) awsAttestationMetadataProvider { awsCfg, err := config.LoadDefaultConfig(ctx, config.WithEC2IMDSRegion()) if err != nil { logger.Debugf("Unable to load AWS config: %v", err) return nil } return &defaultAwsAttestationMetadataProvider{ awsCfg: awsCfg, cfg: cfg, ctx: ctx, } } func (s *defaultAwsAttestationMetadataProvider) awsCredentials() (aws.Credentials, error) { return s.awsCfg.Credentials.Retrieve(s.ctx) } func (s *defaultAwsAttestationMetadataProvider) awsCredentialsViaRoleChaining() (aws.Credentials, error) { creds, err := s.awsCredentials() if err != nil { return aws.Credentials{}, err } for _, roleArn := range s.cfg.WorkloadIdentityImpersonationPath { if creds, err = s.assumeRole(creds, roleArn); err != nil { return aws.Credentials{}, err } } return 
creds, nil } func (s *defaultAwsAttestationMetadataProvider) assumeRole(creds aws.Credentials, roleArn string) (aws.Credentials, error) { logger.Debugf("assuming role %v", roleArn) awsCfg := s.awsCfg awsCfg.Credentials = credentials.StaticCredentialsProvider{Value: creds} awsCfg.Region = s.awsRegion() stsClient := sts.NewFromConfig(awsCfg) role, err := stsClient.AssumeRole(s.ctx, &sts.AssumeRoleInput{ RoleArn: aws.String(roleArn), RoleSessionName: aws.String("identity-federation-session"), }) if err != nil { logger.Debugf("failed to assume role %v: %v", roleArn, err) return aws.Credentials{}, err } return aws.Credentials{ AccessKeyID: *role.Credentials.AccessKeyId, SecretAccessKey: *role.Credentials.SecretAccessKey, SessionToken: *role.Credentials.SessionToken, Expires: *role.Credentials.Expiration, }, nil } func (s *defaultAwsAttestationMetadataProvider) awsRegion() string { return s.awsCfg.Region } func (c *awsIdentityAttestationCreator) createAttestation() (*wifAttestation, error) { logger.Debug("Creating AWS identity attestation...") attestationService := c.attestationServiceFactory(c.ctx, c.cfg) if attestationService == nil { return nil, errors.New("AWS attestation service could not be created") } var creds aws.Credentials var err error if len(c.cfg.WorkloadIdentityImpersonationPath) == 0 { if creds, err = attestationService.awsCredentials(); err != nil { logger.Debugf("error while getting for aws credentials. %v", err) return nil, err } } else { if creds, err = attestationService.awsCredentialsViaRoleChaining(); err != nil { logger.Debugf("error while getting for aws credentials via role chaining. 
%v", err) return nil, err } } if creds.AccessKeyID == "" || creds.SecretAccessKey == "" { return nil, fmt.Errorf("no AWS credentials were found") } region := attestationService.awsRegion() if region == "" { return nil, fmt.Errorf("no AWS region was found") } stsHostname := stsHostname(region) req, err := c.createStsRequest(stsHostname) if err != nil { return nil, err } err = c.signRequestWithSigV4(c.ctx, req, creds, region) if err != nil { return nil, err } credential, err := c.createBase64EncodedRequestCredential(req) if err != nil { return nil, err } return &wifAttestation{ ProviderType: string(awsWif), Credential: credential, Metadata: map[string]string{}, }, nil } func stsHostname(region string) string { var domain string if strings.HasPrefix(region, "cn-") { domain = "amazonaws.com.cn" } else { domain = "amazonaws.com" } return fmt.Sprintf("sts.%s.%s", region, domain) } func (c *awsIdentityAttestationCreator) createStsRequest(hostname string) (*http.Request, error) { url := fmt.Sprintf("https://%s?Action=GetCallerIdentity&Version=2011-06-15", hostname) req, err := http.NewRequest("POST", url, nil) if err != nil { return nil, err } req.Header.Set("Host", hostname) req.Header.Set("X-Snowflake-Audience", "snowflakecomputing.com") return req, nil } func (c *awsIdentityAttestationCreator) signRequestWithSigV4(ctx context.Context, req *http.Request, creds aws.Credentials, region string) error { signer := v4.NewSigner() // as per docs of SignHTTP, the payload hash must be present even if the payload is empty payloadHash := hex.EncodeToString(sha256.New().Sum(nil)) return signer.SignHTTP(ctx, creds, req, payloadHash, "sts", region, time.Now()) } func (c *awsIdentityAttestationCreator) createBase64EncodedRequestCredential(req *http.Request) (string, error) { headers := make(map[string]string) for key, values := range req.Header { headers[key] = values[0] } assertion := map[string]any{ "url": req.URL.String(), "method": req.Method, "headers": headers, } assertionJSON, 
err := json.Marshal(assertion)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(assertionJSON), nil
}

// createAttestation builds a GCP workload identity attestation, either from
// the metadata service directly or via service-account impersonation when an
// impersonation path is configured.
func (c *gcpIdentityAttestationCreator) createAttestation() (*wifAttestation, error) {
	logger.Debugf("Creating GCP identity attestation...")
	if len(c.cfg.WorkloadIdentityImpersonationPath) == 0 {
		return c.createGcpIdentityTokenFromMetadataService()
	}
	return c.createGcpIdentityViaImpersonation()
}

// createGcpIdentityTokenFromMetadataService fetches an identity token from
// the GCP metadata service and wraps it, with its "sub" claim, into an
// attestation. The token signature is not verified here.
func (c *gcpIdentityAttestationCreator) createGcpIdentityTokenFromMetadataService() (*wifAttestation, error) {
	req, err := c.createTokenRequest()
	if err != nil {
		return nil, fmt.Errorf("failed to create GCP token request: %w", err)
	}
	token := fetchTokenFromMetadataService(req, c.cfg, c.telemetry)
	if token == "" {
		return nil, fmt.Errorf("no GCP token was found")
	}
	sub, _, err := extractSubIssWithoutVerifyingSignature(token)
	if err != nil {
		return nil, fmt.Errorf("could not extract claims from token: %v", err)
	}
	return &wifAttestation{
		ProviderType: string(gcpWif),
		Credential:   token,
		Metadata:     map[string]string{"sub": sub},
	}, nil
}

// createTokenRequest builds the metadata-service identity request with the
// Snowflake audience and the required Metadata-Flavor header.
func (c *gcpIdentityAttestationCreator) createTokenRequest() (*http.Request, error) {
	uri := fmt.Sprintf("%s/computeMetadata/v1/instance/service-accounts/default/identity?audience=%s", c.metadataServiceBaseURL, snowflakeAudience)
	req, err := http.NewRequest("GET", uri, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP request: %v", err)
	}
	req.Header.Set(gcpMetadataFlavorHeaderName, gcpMetadataFlavor)
	return req, nil
}

// createGcpIdentityViaImpersonation obtains an access token from the metadata
// service, then asks the IAM Credentials API to generate an ID token for the
// last service account in WorkloadIdentityImpersonationPath, treating the
// preceding entries as delegates.
func (c *gcpIdentityAttestationCreator) createGcpIdentityViaImpersonation() (*wifAttestation, error) {
	// initialize transport
	transport, err := newTransportFactory(c.cfg, c.telemetry).createTransport(transportConfigFor(transportTypeWIF))
	if err != nil {
		logger.Debugf("Failed to create HTTP transport: %v", err)
		return nil, err
	}
	client := &http.Client{Transport: transport}
	// fetch access token for impersonation
	accessToken, err := c.fetchServiceToken(client)
	if err != nil {
		return nil, err
	}
	// map paths to full service account paths
	var fullServiceAccountPaths []string
	for _, path := range c.cfg.WorkloadIdentityImpersonationPath {
		fullServiceAccountPaths = append(fullServiceAccountPaths, fmt.Sprintf("projects/-/serviceAccounts/%s", path))
	}
	// last entry is the target; earlier entries form the delegation chain
	targetServiceAccount := fullServiceAccountPaths[len(fullServiceAccountPaths)-1]
	delegates := fullServiceAccountPaths[:len(fullServiceAccountPaths)-1]
	// fetch impersonated token
	impersonationToken, err := c.fetchImpersonatedToken(targetServiceAccount, delegates, accessToken, client)
	if err != nil {
		return nil, err
	}
	// create attestation
	sub, _, err := extractSubIssWithoutVerifyingSignature(impersonationToken)
	if err != nil {
		return nil, fmt.Errorf("could not extract claims from token: %v", err)
	}
	return &wifAttestation{
		ProviderType: string(gcpWif),
		Credential:   impersonationToken,
		Metadata:     map[string]string{"sub": sub},
	}, nil
}

// fetchServiceToken retrieves the default service account's OAuth access
// token from the metadata service; this token authorizes the impersonation
// call below.
func (c *gcpIdentityAttestationCreator) fetchServiceToken(client *http.Client) (string, error) {
	// initialize and do request
	req, err := http.NewRequest("GET", c.metadataServiceBaseURL+"/computeMetadata/v1/instance/service-accounts/default/token", nil)
	if err != nil {
		logger.Debugf("cannot create token request for impersonation. %v", err)
		return "", err
	}
	req.Header.Set(gcpMetadataFlavorHeaderName, gcpMetadataFlavor)
	resp, err := client.Do(req)
	if err != nil {
		logger.Debugf("cannot fetch token for impersonation. %v", err)
		return "", err
	}
	defer func(body io.ReadCloser) {
		if err = body.Close(); err != nil {
			logger.Debugf("cannot close token response body for impersonation. %v", err)
		}
	}(resp.Body)
	// if it is not 200, do not parse the response
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("token response status is %v, not parsing", resp.StatusCode)
	}
	// parse response and extract access token
	accessTokenResponse := struct {
		AccessToken string `json:"access_token"`
	}{}
	if err = json.NewDecoder(resp.Body).Decode(&accessTokenResponse); err != nil {
		logger.Debugf("cannot decode token for impersonation. %v", err)
		return "", err
	}
	accessToken := accessTokenResponse.AccessToken
	return accessToken, nil
}

// fetchImpersonatedToken calls the IAM Credentials generateIdToken endpoint
// for targetServiceAccount (with the given delegates) using accessToken as
// bearer auth, and returns the resulting ID token.
func (c *gcpIdentityAttestationCreator) fetchImpersonatedToken(targetServiceAccount string, delegates []string, accessToken string, client *http.Client) (string, error) {
	// prepare the request
	url := fmt.Sprintf("%v/v1/%v:generateIdToken", c.iamCredentialsURL, targetServiceAccount)
	body := struct {
		Delegates []string `json:"delegates,omitempty"`
		Audience  string   `json:"audience"`
	}{
		Delegates: delegates,
		Audience:  snowflakeAudience,
	}
	payload := new(bytes.Buffer)
	if err := json.NewEncoder(payload).Encode(body); err != nil {
		logger.Debugf("cannot encode impersonation request body. %v", err)
		return "", err
	}
	req, err := http.NewRequest("POST", url, payload)
	if err != nil {
		logger.Debugf("cannot create token request for impersonation. %v", err)
		return "", err
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)
	req.Header.Set("Content-Type", "application/json")
	// send the request
	resp, err := client.Do(req)
	if err != nil {
		logger.Debugf("cannot call impersonation service. %v", err)
		return "", err
	}
	defer func(body io.ReadCloser) {
		if err = body.Close(); err != nil {
			logger.Debugf("cannot close token response body for impersonation. %v", err)
		}
	}(resp.Body)
	// handle the response
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("response status is %v, not parsing", resp.StatusCode)
	}
	tokenResponse := struct {
		Token string `json:"token"`
	}{}
	if err = json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
		logger.Debugf("cannot decode token response. %v", err)
		return "", err
	}
	return tokenResponse.Token, nil
}

// fetchTokenFromMetadataService performs req with a WIF transport and returns
// the raw response body as a string; any failure is logged and reported as an
// empty string. NOTE(review): the status code is not checked here — callers
// treat an empty body as "no token".
func fetchTokenFromMetadataService(req *http.Request, cfg *Config, telemetry *snowflakeTelemetry) string {
	transport, err := newTransportFactory(cfg, telemetry).createTransport(transportConfigFor(transportTypeWIF))
	if err != nil {
		logger.Debugf("Failed to create HTTP transport: %v", err)
		return ""
	}
	client := &http.Client{Transport: transport}
	resp, err := client.Do(req)
	if err != nil {
		logger.Debugf("Metadata server request was not successful: %v", err)
		return ""
	}
	defer func() {
		if err = resp.Body.Close(); err != nil {
			logger.Debugf("Failed to close response body: %v", err)
		}
	}()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		logger.Debugf("Failed to read response body: %v", err)
		return ""
	}
	return string(body)
}

// extractSubIssWithoutVerifyingSignature pulls the "sub" and "iss" claims out
// of a JWT without verifying its signature; both claims must be present and
// be strings, otherwise an error is returned.
func extractSubIssWithoutVerifyingSignature(token string) (subject string, issuer string, err error) {
	claims, err := extractClaimsMap(token)
	if err != nil {
		return "", "", err
	}
	issuerClaim, ok := claims["iss"]
	if !ok {
		return "", "", errors.New("missing issuer claim in JWT token")
	}
	subjectClaim, ok := claims["sub"]
	if !ok {
		return "", "", errors.New("missing sub claim in JWT token")
	}
	subject, ok = subjectClaim.(string)
	if !ok {
		return "", "", errors.New("sub claim is not a string in JWT token")
	}
	issuer, ok = issuerClaim.(string)
	if !ok {
		return "", "", errors.New("iss claim is not a string in JWT token")
	}
	return
}

// extractClaimsMap parses a JWT token and returns its claims as a map.
// It does not verify the token signature.
func extractClaimsMap(token string) (map[string]any, error) {
	// WithoutClaimsValidation skips exp/nbf checks; ParseUnverified skips the
	// signature check — we only need the raw claim values.
	parser := jwt.NewParser(jwt.WithoutClaimsValidation())
	claims := jwt.MapClaims{}
	_, _, err := parser.ParseUnverified(token, claims)
	if err != nil {
		return nil, fmt.Errorf("unable to extract JWT claims from token: %w", err)
	}
	return claims, nil
}

// createAttestation builds an OIDC attestation from a caller-supplied token;
// the token must carry non-empty "sub" and "iss" claims.
func (c *oidcIdentityAttestationCreator) createAttestation() (*wifAttestation, error) {
	logger.Debugf("Creating OIDC identity attestation...")
	token, err := c.token()
	if err != nil {
		return nil, fmt.Errorf("failed to get OIDC token: %w", err)
	}
	if token == "" {
		return nil, fmt.Errorf("no OIDC token was specified")
	}
	sub, iss, err := extractSubIssWithoutVerifyingSignature(token)
	if err != nil {
		return nil, err
	}
	if sub == "" || iss == "" {
		return nil, errors.New("missing sub or iss claim in JWT token")
	}
	return &wifAttestation{
		ProviderType: string(oidcWif),
		Credential:   token,
		Metadata:     map[string]string{"sub": sub},
	}, nil
}

// azureAttestationMetadataProvider defines the interface for Azure attestation services
type azureAttestationMetadataProvider interface {
	identityEndpoint() string
	identityHeader() string
	clientID() string
}

// defaultAzureAttestationMetadataProvider reads the Azure managed-identity
// settings from the standard environment variables.
type defaultAzureAttestationMetadataProvider struct{}

// identityEndpoint returns the IDENTITY_ENDPOINT environment variable.
func (p *defaultAzureAttestationMetadataProvider) identityEndpoint() string {
	return os.Getenv("IDENTITY_ENDPOINT")
}

// identityHeader returns the IDENTITY_HEADER environment variable.
func (p *defaultAzureAttestationMetadataProvider) identityHeader() string {
	return os.Getenv("IDENTITY_HEADER")
}

// clientID returns the MANAGED_IDENTITY_CLIENT_ID environment variable.
func (p *defaultAzureAttestationMetadataProvider) clientID() string {
	return os.Getenv("MANAGED_IDENTITY_CLIENT_ID")
}

// azureIdentityAttestationCreator creates Azure workload identity
// attestations, supporting both Azure VM and Azure Functions environments.
type azureIdentityAttestationCreator struct {
	azureAttestationMetadataProvider azureAttestationMetadataProvider
	cfg                              *Config
	telemetry                        *snowflakeTelemetry
	workloadIdentityEntraResource    string
	azureMetadataServiceBaseURL      string
}

// createAttestation creates an attestation using Azure identity
func (a *azureIdentityAttestationCreator) createAttestation() (*wifAttestation, error) {
	logger.Debug("Creating Azure identity attestation...")
	identityEndpoint := a.azureAttestationMetadataProvider.identityEndpoint()
	var request *http.Request
	var err error
	if identityEndpoint == "" {
		// No identity endpoint set: assume an Azure VM with IMDS available.
		request, err = a.azureVMIdentityRequest()
		if err != nil {
			return nil, fmt.Errorf("failed to create Azure VM identity request: %v", err)
		}
	} else {
		identityHeader := a.azureAttestationMetadataProvider.identityHeader()
		if identityHeader == "" {
			return nil, fmt.Errorf("managed identity is not enabled on this Azure function")
		}
		request, err = a.azureFunctionsIdentityRequest(
			identityEndpoint,
			identityHeader,
			a.azureAttestationMetadataProvider.clientID(),
		)
		if err != nil {
			return nil, fmt.Errorf("failed to create Azure Functions identity request: %v", err)
		}
	}
	tokenJSON := fetchTokenFromMetadataService(request, a.cfg, a.telemetry)
	if tokenJSON == "" {
		return nil, fmt.Errorf("could not fetch Azure token")
	}
	token, err := extractTokenFromJSON(tokenJSON)
	if err != nil {
		return nil, fmt.Errorf("failed to extract token from JSON: %v", err)
	}
	if token == "" {
		return nil, fmt.Errorf("no access token found in Azure response")
	}
	sub, iss, err := extractSubIssWithoutVerifyingSignature(token)
	if err != nil {
		return nil, fmt.Errorf("failed to extract sub and iss claims from token: %v", err)
	}
	if sub == "" || iss == "" {
		return nil, fmt.Errorf("missing sub or iss claim in JWT token")
	}
	return &wifAttestation{
		ProviderType: string(azureWif),
		Credential:   token,
		Metadata:     map[string]string{"sub": sub, "iss": iss},
	}, nil
}

// determineEntraResource returns the Entra resource to request a token for,
// falling back to the built-in Snowflake application ID when none is
// configured.
func determineEntraResource(config *Config) string {
	if config != nil && config.WorkloadIdentityEntraResource != "" {
		return config.WorkloadIdentityEntraResource
	}
	// default resource if none specified
	return "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad"
}

// extractTokenFromJSON unmarshals the "access_token" field out of a token
// endpoint JSON response.
func extractTokenFromJSON(tokenJSON string) (string, error) {
	var response struct {
		AccessToken string `json:"access_token"`
	}
	err := json.Unmarshal([]byte(tokenJSON), &response)
	if err != nil {
		return "", err
	}
	return response.AccessToken, nil
}

// azureFunctionsIdentityRequest builds the Azure Functions managed-identity
// token request, optionally scoped to a specific client_id.
func (a *azureIdentityAttestationCreator) azureFunctionsIdentityRequest(identityEndpoint, identityHeader, managedIdentityClientID string) (*http.Request, error) {
	queryParams := fmt.Sprintf("api-version=2019-08-01&resource=%s", a.workloadIdentityEntraResource)
	if managedIdentityClientID != "" {
		queryParams += fmt.Sprintf("&client_id=%s", managedIdentityClientID)
	}
	url := fmt.Sprintf("%s?%s", identityEndpoint, queryParams)
	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Add("X-IDENTITY-HEADER", identityHeader)
	return req, nil
}

// azureVMIdentityRequest builds the Azure VM IMDS token request with the
// mandatory Metadata header.
func (a *azureIdentityAttestationCreator) azureVMIdentityRequest() (*http.Request, error) {
	urlWithoutQuery := a.azureMetadataServiceBaseURL + "/metadata/identity/oauth2/token?"
	queryParams := fmt.Sprintf("api-version=2018-02-01&resource=%s", a.workloadIdentityEntraResource)
	url := urlWithoutQuery + queryParams
	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Add("Metadata", "true")
	return req, nil
}

================================================ FILE: auth_wif_test.go ================================================

package gosnowflake

import (
	"context"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"strings"
	"testing"

	"github.com/aws/aws-sdk-go-v2/aws"
)

// mockWifAttestationCreator returns a fixed attestation (or a configured
// error) so provider routing can be tested in isolation.
type mockWifAttestationCreator struct {
	providerType wifProviderType
	returnError  error
}

func (m *mockWifAttestationCreator) createAttestation() (*wifAttestation, error) {
	if m.returnError != nil {
		return nil, m.returnError
	}
	return &wifAttestation{
		ProviderType: string(m.providerType),
	}, nil
}

// TestGetAttestation verifies that wifAttestationProvider routes to the
// correct per-cloud creator, propagates creator errors, and rejects unknown
// or empty provider names.
func TestGetAttestation(t *testing.T) {
	awsError := errors.New("aws attestation error")
	gcpError := errors.New("gcp attestation error")
	azureError := errors.New("azure attestation error")
	oidcError := errors.New("oidc attestation error")
	provider := &wifAttestationProvider{
		context:    context.Background(),
		awsCreator: &mockWifAttestationCreator{providerType: awsWif},
		gcpCreator:
&mockWifAttestationCreator{providerType: gcpWif},
		azureCreator: &mockWifAttestationCreator{providerType: azureWif},
		oidcCreator:  &mockWifAttestationCreator{providerType: oidcWif},
	}
	providerWithErrors := &wifAttestationProvider{
		context:      context.Background(),
		awsCreator:   &mockWifAttestationCreator{providerType: awsWif, returnError: awsError},
		gcpCreator:   &mockWifAttestationCreator{providerType: gcpWif, returnError: gcpError},
		azureCreator: &mockWifAttestationCreator{providerType: azureWif, returnError: azureError},
		oidcCreator:  &mockWifAttestationCreator{providerType: oidcWif, returnError: oidcError},
	}
	tests := []struct {
		name             string
		provider         *wifAttestationProvider
		identityProvider string
		expectedResult   *wifAttestation
		expectedError    error
	}{
		{
			name:             "AWS success",
			provider:         provider,
			identityProvider: "AWS",
			expectedResult:   &wifAttestation{ProviderType: string(awsWif)},
			expectedError:    nil,
		},
		{
			name:             "AWS error",
			provider:         providerWithErrors,
			identityProvider: "AWS",
			expectedResult:   nil,
			expectedError:    awsError,
		},
		{
			name:             "GCP success",
			provider:         provider,
			identityProvider: "GCP",
			expectedResult:   &wifAttestation{ProviderType: string(gcpWif)},
			expectedError:    nil,
		},
		{
			name:             "GCP error",
			provider:         providerWithErrors,
			identityProvider: "GCP",
			expectedResult:   nil,
			expectedError:    gcpError,
		},
		{
			name:             "AZURE success",
			provider:         provider,
			identityProvider: "AZURE",
			expectedResult:   &wifAttestation{ProviderType: string(azureWif)},
			expectedError:    nil,
		},
		{
			name:             "AZURE error",
			provider:         providerWithErrors,
			identityProvider: "AZURE",
			expectedResult:   nil,
			expectedError:    azureError,
		},
		{
			name:             "OIDC success",
			provider:         provider,
			identityProvider: "OIDC",
			expectedResult:   &wifAttestation{ProviderType: string(oidcWif)},
			expectedError:    nil,
		},
		{
			name:             "OIDC error",
			provider:         providerWithErrors,
			identityProvider: "OIDC",
			expectedResult:   nil,
			expectedError:    oidcError,
		},
		{
			name:             "Unknown provider",
			provider:         provider,
			identityProvider: "UNKNOWN",
			expectedResult:   nil,
			expectedError:    errors.New("unknown WorkloadIdentityProvider specified: UNKNOWN. Valid values are: AWS, GCP, AZURE, OIDC"),
		},
		{
			name:             "Empty provider",
			provider:         provider,
			identityProvider: "",
			expectedResult:   nil,
			expectedError:    errors.New("unknown WorkloadIdentityProvider specified: . Valid values are: AWS, GCP, AZURE, OIDC"),
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			attestation, err := test.provider.getAttestation(test.identityProvider)
			if test.expectedError != nil {
				assertNilE(t, attestation)
				assertNotNilF(t, err)
				assertEqualE(t, test.expectedError.Error(), err.Error())
			} else if test.expectedResult != nil {
				assertNilE(t, err)
				assertNotNilF(t, attestation)
				assertEqualE(t, test.expectedResult.ProviderType, attestation.ProviderType)
			} else {
				// every case must assert something, otherwise it is a mistake
				t.Fatal("test case must specify either expectedError or expectedResult")
			}
		})
	}
}

// TestAwsIdentityAttestationCreator exercises the AWS attestation flow with a
// mocked metadata provider, covering missing credentials/region, plain
// credentials, role chaining (single, multiple, cross-account, CN regions)
// and role-chaining failures.
func TestAwsIdentityAttestationCreator(t *testing.T) {
	tests := []struct {
		name             string
		config           Config
		attestationSvc   awsAttestationMetadataProvider
		expectedError    error
		expectedProvider string
		expectedStsHost  string
	}{
		{
			name:           "No attestation service",
			attestationSvc: nil,
			expectedError:  fmt.Errorf("AWS attestation service could not be created"),
		},
		{
			name: "No AWS credentials",
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds:  aws.Credentials{},
				region: "us-west-2",
			},
			expectedError: fmt.Errorf("no AWS credentials were found"),
		},
		{
			name: "No AWS region",
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds:  mockCreds,
				region: "",
			},
			expectedError: fmt.Errorf("no AWS region was found"),
		},
		{
			name: "Successful attestation",
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds:  mockCreds,
				region: "us-west-2",
			},
			expectedProvider: "AWS",
			expectedStsHost:  "sts.us-west-2.amazonaws.com",
		},
		{
			name: "Successful attestation for CN region",
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds:  mockCreds,
				region: "cn-northwest-1",
			},
			expectedProvider: "AWS",
			expectedStsHost:  "sts.cn-northwest-1.amazonaws.com.cn",
		},
		{
			name: "Successful attestation with single role chaining",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{"arn:aws:iam::123456789012:role/test-role"},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds: mockCreds,
				chainingCreds: aws.Credentials{
					AccessKeyID:     "chainedAccessKey",
					SecretAccessKey: "chainedSecretKey",
					SessionToken:    "chainedSessionToken",
				},
				region:          "us-east-1",
				useRoleChaining: true,
			},
			expectedProvider: "AWS",
			expectedStsHost:  "sts.us-east-1.amazonaws.com",
		},
		{
			name: "Successful attestation with multiple role chaining",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{
					"arn:aws:iam::123456789012:role/role1",
					"arn:aws:iam::123456789012:role/role2",
					"arn:aws:iam::123456789012:role/role3",
				},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds: mockCreds,
				chainingCreds: aws.Credentials{
					AccessKeyID:     "finalRoleAccessKey",
					SecretAccessKey: "finalRoleSecretKey",
					SessionToken:    "finalRoleSessionToken",
				},
				region:          "us-west-2",
				useRoleChaining: true,
			},
			expectedProvider: "AWS",
			expectedStsHost:  "sts.us-west-2.amazonaws.com",
		},
		{
			name: "Role chaining with no credentials",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{"arn:aws:iam::123456789012:role/test-role"},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds:           aws.Credentials{},
				region:          "us-west-2",
				useRoleChaining: true,
			},
			expectedError: fmt.Errorf("no AWS credentials were found"),
		},
		{
			name: "Role chaining with no region",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{"arn:aws:iam::123456789012:role/test-role"},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds: aws.Credentials{
					AccessKeyID:     "chainedAccessKey",
					SecretAccessKey: "chainedSecretKey",
					SessionToken:    "chainedSessionToken",
				},
				region:          "",
				useRoleChaining: true,
			},
			expectedError: fmt.Errorf("no AWS region was found"),
		},
		{
			name: "Role chaining failure",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{"arn:aws:iam::123456789012:role/test-role"},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds:           mockCreds,
				region:          "us-west-2",
				chainingError:   fmt.Errorf("failed to assume role: AccessDenied"),
				useRoleChaining: true,
			},
			expectedError: fmt.Errorf("failed to assume role: AccessDenied"),
		},
		{
			name: "Cross-account role chaining",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{
					"arn:aws:iam::111111111111:role/cross-account-role",
					"arn:aws:iam::222222222222:role/target-role",
				},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds: mockCreds,
				chainingCreds: aws.Credentials{
					AccessKeyID:     "crossAccountAccessKey",
					SecretAccessKey: "crossAccountSecretKey",
					SessionToken:    "crossAccountSessionToken",
				},
				region:          "us-east-1",
				useRoleChaining: true,
			},
			expectedProvider: "AWS",
			expectedStsHost:  "sts.us-east-1.amazonaws.com",
		},
		{
			name: "Role chaining in CN region",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{"arn:aws-cn:iam::123456789012:role/cn-role"},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds: mockCreds,
				chainingCreds: aws.Credentials{
					AccessKeyID:     "cnRoleAccessKey",
					SecretAccessKey: "cnRoleSecretKey",
					SessionToken:    "cnRoleSessionToken",
				},
				region:          "cn-north-1",
				useRoleChaining: true,
			},
			expectedProvider: "AWS",
			expectedStsHost:  "sts.cn-north-1.amazonaws.com.cn",
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			creator := &awsIdentityAttestationCreator{
				attestationServiceFactory: func(ctx context.Context, cfg *Config) awsAttestationMetadataProvider {
					return test.attestationSvc
				},
				ctx: context.Background(),
				cfg: &test.config,
			}
			attestation, err := creator.createAttestation()
			if test.expectedError != nil {
				assertNilF(t, attestation)
				assertNotNilE(t, err)
				assertEqualE(t, test.expectedError.Error(), err.Error())
			} else {
				assertNilE(t, err)
				assertNotNilE(t, attestation)
				assertNotNilE(t, attestation.Credential)
				assertEqualE(t, test.expectedProvider, attestation.ProviderType)
				// decode the credential and verify the signed STS URL inside
				decoded,
err := base64.StdEncoding.DecodeString(attestation.Credential)
				if err != nil {
					t.Fatalf("Failed to decode credential: %v", err)
				}
				var credentialMap map[string]any
				if err := json.Unmarshal(decoded, &credentialMap); err != nil {
					t.Fatalf("Failed to unmarshal credential JSON: %v", err)
				}
				assertEqualE(t, fmt.Sprintf("https://%s?Action=GetCallerIdentity&Version=2011-06-15", test.expectedStsHost), credentialMap["url"])
			}
		})
	}
}

// mockAwsAttestationMetadataProvider is a configurable test double for the
// AWS metadata provider: it serves fixed credentials, region, and optional
// role-chaining results.
type mockAwsAttestationMetadataProvider struct {
	creds           aws.Credentials
	region          string
	chainingCreds   aws.Credentials
	chainingError   error
	useRoleChaining bool
}

// mockCreds are the baseline credentials served by the mock provider.
var mockCreds = aws.Credentials{
	AccessKeyID:     "mockAccessKey",
	SecretAccessKey: "mockSecretKey",
	SessionToken:    "mockSessionToken",
}

func (m *mockAwsAttestationMetadataProvider) awsCredentials() (aws.Credentials, error) {
	return m.creds, nil
}

func (m *mockAwsAttestationMetadataProvider) awsCredentialsViaRoleChaining() (aws.Credentials, error) {
	if m.chainingError != nil {
		return aws.Credentials{}, m.chainingError
	}
	// fall back to the base creds when no chained creds were configured
	if m.chainingCreds.AccessKeyID != "" {
		return m.chainingCreds, nil
	}
	return m.creds, nil
}

func (m *mockAwsAttestationMetadataProvider) awsRegion() string {
	return m.region
}

// TestGcpIdentityAttestationCreator exercises the GCP attestation flow
// against wiremock stubs: the happy path, the impersonation path, HTTP
// errors, and malformed/incomplete tokens.
func TestGcpIdentityAttestationCreator(t *testing.T) {
	tests := []struct {
		name                string
		wiremockMappingPath string
		config              Config
		expectedError       error
		expectedSub         string
	}{
		{
			name:                "Successful flow",
			wiremockMappingPath: "auth/wif/gcp/successful_flow.json",
			expectedError:       nil,
			expectedSub:         "some-subject",
		},
		{
			name:                "Successful impersonation flow",
			wiremockMappingPath: "auth/wif/gcp/successful_impersionation_flow.json",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{
					"delegate1",
					"delegate2",
					"targetServiceAccount",
				},
			},
			expectedError: nil,
			expectedSub:   "some-impersonated-subject",
		},
		{
			name:                "No GCP credential - http error",
			wiremockMappingPath: "auth/wif/gcp/http_error.json",
			expectedError:       fmt.Errorf("no GCP token was found"),
			expectedSub:         "",
		},
		{
			name:                "missing issuer claim",
			wiremockMappingPath: "auth/wif/gcp/missing_issuer_claim.json",
			expectedError:       fmt.Errorf("could not extract claims from token: missing issuer claim in JWT token"),
			expectedSub:         "",
		},
		{
			name:                "missing sub claim",
			wiremockMappingPath: "auth/wif/gcp/missing_sub_claim.json",
			expectedError:       fmt.Errorf("could not extract claims from token: missing sub claim in JWT token"),
			expectedSub:         "",
		},
		{
			name:                "unparsable token",
			wiremockMappingPath: "auth/wif/gcp/unparsable_token.json",
			expectedError:       fmt.Errorf("could not extract claims from token: unable to extract JWT claims from token: token is malformed: token contains an invalid number of segments"),
			expectedSub:         "",
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			creator := &gcpIdentityAttestationCreator{
				cfg:                    &test.config,
				metadataServiceBaseURL: wiremock.baseURL(),
				iamCredentialsURL:      wiremock.baseURL(),
			}
			wiremock.registerMappings(t, wiremockMapping{filePath: test.wiremockMappingPath})
			attestation, err := creator.createAttestation()
			if test.expectedError != nil {
				assertNilF(t, attestation)
				assertNotNilF(t, err)
				assertEqualE(t, test.expectedError.Error(), err.Error())
			} else {
				assertNilF(t, err)
				assertNotNilF(t, attestation)
				assertEqualE(t, string(gcpWif), attestation.ProviderType)
				assertEqualE(t, test.expectedSub, attestation.Metadata["sub"])
			}
		})
	}
}

// TestOidcIdentityAttestationCreator checks OIDC attestation creation for a
// valid token and for tokens that are empty, unparsable, or missing claims.
// The token constants below are unsigned/test-only JWTs.
func TestOidcIdentityAttestationCreator(t *testing.T) {
	const (
		/*
		 * {
		 *   "sub": "some-subject",
		 *   "iat": 1743761213,
		 *   "exp": 1743764813,
		 *   "aud": "www.example.com"
		 * }
		 */
		missingIssuerClaimToken = "eyJ0eXAiOiJhdCtqd3QiLCJhbGciOiJFUzI1NiIsImtpZCI6ImU2M2I5NzA1OTRiY2NmZTAxMDlkOTg4OWM2MDk3OWEwIn0.eyJzdWIiOiJzb21lLXN1YmplY3QiLCJpYXQiOjE3NDM3NjEyMTMsImV4cCI6MTc0Mzc2NDgxMywiYXVkIjoid3d3LmV4YW1wbGUuY29tIn0.H6sN6kjA82EuijFcv-yCJTqau5qvVTCsk0ZQ4gvFQMkB7c71XPs4lkwTa7ZlNNlx9e6TpN1CVGnpCIRDDAZaDw" // pragma: allowlist secret
		/*
		 * {
		 *   "iss": "https://accounts.google.com",
		 *   "iat": 1743761213,
		 *   "exp": 1743764813,
		 *   "aud": "www.example.com"
		 * }
		 */
		missingSubClaimToken = "eyJ0eXAiOiJhdCtqd3QiLCJhbGciOiJFUzI1NiIsImtpZCI6ImU2M2I5NzA1OTRiY2NmZTAxMDlkOTg4OWM2MDk3OWEwIn0.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJpYXQiOjE3NDM3NjEyMTMsImV4cCI6MTc0Mzc2NDgxMywiYXVkIjoid3d3LmV4YW1wbGUuY29tIn0.w0njdpfWFETVK8Ktq9GdvuKRQJjvhOplcSyvQ_zHHwBUSMapqO1bjEWBx5VhGkdECZIGS1VY7db_IOqT45yOMA" // pragma: allowlist secret
		/*
		 * {
		 *   "iss": "https://oidc.eks.us-east-2.amazonaws.com/id/3B869BC5D12CEB5515358621D8085D58",
		 *   "iat": 1743692017,
		 *   "exp": 1775228014,
		 *   "aud": "www.example.com",
		 *   "sub": "system:serviceaccount:poc-namespace:oidc-sa"
		 * }
		 */
		validToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJodHRwczovL29pZGMuZWtzLnVzLWVhc3QtMi5hbWF6b25hd3MuY29tL2lkLzNCODY5QkM1RDEyQ0VCNTUxNTM1ODYyMUQ4MDg1RDU4IiwiaWF0IjoxNzQ0Mjg3ODc4LCJleHAiOjE3NzU4MjM4NzgsImF1ZCI6Ind3dy5leGFtcGxlLmNvbSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpwb2MtbmFtZXNwYWNlOm9pZGMtc2EifQ.a8H6KRIF1XmM8lkqL6kR8ccInr7wAzQrbKd3ZHFgiEg" // pragma: allowlist secret
		unparsableToken = "unparsable_token"
		emptyToken      = ""
	)
	type testCase struct {
		name          string
		token         string
		expectedError error
		expectedSub   string
	}
	tests := []testCase{
		{
			name:          "no token input",
			token:         emptyToken,
			expectedError: fmt.Errorf("no OIDC token was specified"),
		},
		{
			name:          "valid token returns proper attestation",
			token:         validToken,
			expectedError: nil,
			expectedSub:   "system:serviceaccount:poc-namespace:oidc-sa",
		},
		{
			name:          "missing issuer returns error",
			token:         missingIssuerClaimToken,
			expectedError: errors.New("missing issuer claim in JWT token"),
		},
		{
			name:          "missing sub returns error",
			token:         missingSubClaimToken,
			expectedError: errors.New("missing sub claim in JWT token"),
		},
		{
			name:          "unparsable token returns error",
			token:         unparsableToken,
			expectedError: errors.New("unable to extract JWT claims from token: token is malformed: token contains an invalid number of segments"),
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			creator := &oidcIdentityAttestationCreator{token:
func() (string, error) { return test.token, nil }} attestation, err := creator.createAttestation() if test.expectedError != nil { assertNotNilE(t, err) assertNilF(t, attestation) assertEqualE(t, test.expectedError.Error(), err.Error()) } else { assertNilE(t, err) assertNotNilE(t, attestation) assertEqualE(t, string(oidcWif), attestation.ProviderType) assertEqualE(t, test.expectedSub, attestation.Metadata["sub"]) } }) } } func TestAzureIdentityAttestationCreator(t *testing.T) { tests := []struct { name string wiremockMappingPath string metadataProvider *mockAzureAttestationMetadataProvider cfg *Config expectedIss string expectedError error }{ /* * { * "iss": "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/", * "sub": "77213E30-E8CB-4595-B1B6-5F050E8308FD" * } */ { name: "Successful flow", wiremockMappingPath: "auth/wif/azure/successful_flow_basic.json", metadataProvider: azureVMMetadataProvider(), expectedIss: "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/", expectedError: nil, }, /* * { * "iss": "https://login.microsoftonline.com/fa15d692-e9c7-4460-a743-29f29522229/", * "sub": "77213E30-E8CB-4595-B1B6-5F050E8308FD" * } */ { name: "Successful flow v2 issuer", wiremockMappingPath: "auth/wif/azure/successful_flow_v2_issuer.json", metadataProvider: azureVMMetadataProvider(), expectedIss: "https://login.microsoftonline.com/fa15d692-e9c7-4460-a743-29f29522229/", expectedError: nil, }, /* * { * "iss": "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/", * "sub": "77213E30-E8CB-4595-B1B6-5F050E8308FD" * } */ { name: "Successful flow azure functions", wiremockMappingPath: "auth/wif/azure/successful_flow_azure_functions.json", metadataProvider: azureFunctionsMetadataProvider(), expectedIss: "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/", expectedError: nil, }, /* * { * "iss": "https://login.microsoftonline.com/fa15d692-e9c7-4460-a743-29f29522229/", * "sub": "77213E30-E8CB-4595-B1B6-5F050E8308FD" * } */ { name: 
"Successful flow azure functions v2 issuer", wiremockMappingPath: "auth/wif/azure/successful_flow_azure_functions_v2_issuer.json", metadataProvider: azureFunctionsMetadataProvider(), expectedIss: "https://login.microsoftonline.com/fa15d692-e9c7-4460-a743-29f29522229/", expectedError: nil, }, /* * { * "iss": "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/", * "sub": "77213E30-E8CB-4595-B1B6-5F050E8308FD" * } */ { name: "Successful flow azure functions no client ID", wiremockMappingPath: "auth/wif/azure/successful_flow_azure_functions_no_client_id.json", metadataProvider: &mockAzureAttestationMetadataProvider{ identityEndpointValue: wiremock.baseURL() + "/metadata/identity/endpoint/from/env", identityHeaderValue: "some-identity-header-from-env", clientIDValue: "", }, expectedIss: "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/", expectedError: nil, }, /* * { * "iss": "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/", * "sub": "77213E30-E8CB-4595-B1B6-5F050E8308FD" * } */ { name: "Successful flow azure functions custom Entra resource", wiremockMappingPath: "auth/wif/azure/successful_flow_azure_functions_custom_entra_resource.json", metadataProvider: azureFunctionsMetadataProvider(), cfg: &Config{WorkloadIdentityEntraResource: "api://1111111-2222-3333-44444-55555555"}, expectedIss: "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/", expectedError: nil, }, { name: "Non-json response", wiremockMappingPath: "auth/wif/azure/non_json_response.json", metadataProvider: azureVMMetadataProvider(), expectedError: fmt.Errorf("failed to extract token from JSON: invalid character 'o' in literal null (expecting 'u')"), }, { name: "Identity endpoint but no identity header", metadataProvider: &mockAzureAttestationMetadataProvider{ identityEndpointValue: wiremock.baseURL() + "/metadata/identity/endpoint/from/env", identityHeaderValue: "", clientIDValue: "managed-client-id-from-env", }, expectedError: fmt.Errorf("managed identity is 
not enabled on this Azure function"), }, { name: "Unparsable token", wiremockMappingPath: "auth/wif/azure/unparsable_token.json", metadataProvider: azureVMMetadataProvider(), expectedError: fmt.Errorf("failed to extract sub and iss claims from token: unable to extract JWT claims from token: token is malformed: token contains an invalid number of segments"), }, { name: "HTTP error", metadataProvider: azureVMMetadataProvider(), wiremockMappingPath: "auth/wif/azure/http_error.json", expectedError: fmt.Errorf("could not fetch Azure token"), }, { name: "Missing sub or iss claim", wiremockMappingPath: "auth/wif/azure/missing_issuer_claim.json", metadataProvider: azureVMMetadataProvider(), expectedError: fmt.Errorf("failed to extract sub and iss claims from token: missing issuer claim in JWT token"), }, { name: "Missing sub claim", wiremockMappingPath: "auth/wif/azure/missing_sub_claim.json", metadataProvider: azureVMMetadataProvider(), expectedError: fmt.Errorf("failed to extract sub and iss claims from token: missing sub claim in JWT token"), }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { if test.wiremockMappingPath != "" { wiremock.registerMappings(t, wiremockMapping{filePath: test.wiremockMappingPath}) } creator := &azureIdentityAttestationCreator{ cfg: test.cfg, azureMetadataServiceBaseURL: wiremock.baseURL(), azureAttestationMetadataProvider: test.metadataProvider, workloadIdentityEntraResource: determineEntraResource(test.cfg), } attestation, err := creator.createAttestation() if test.expectedError != nil { assertNilF(t, attestation) assertNotNilE(t, err) assertEqualE(t, test.expectedError.Error(), err.Error()) } else { assertNilF(t, err) assertNotNilF(t, attestation) assertEqualE(t, string(azureWif), attestation.ProviderType) assertEqualE(t, test.expectedIss, attestation.Metadata["iss"]) assertEqualE(t, "77213E30-E8CB-4595-B1B6-5F050E8308FD", attestation.Metadata["sub"]) } }) } } type mockAzureAttestationMetadataProvider struct { 
	// Values returned by the mock metadata provider below; set per-test.
	identityEndpointValue string
	identityHeaderValue   string
	clientIDValue         string
}

// identityEndpoint returns the configured mock identity endpoint URL.
func (m *mockAzureAttestationMetadataProvider) identityEndpoint() string {
	return m.identityEndpointValue
}

// identityHeader returns the configured mock identity header value.
func (m *mockAzureAttestationMetadataProvider) identityHeader() string {
	return m.identityHeaderValue
}

// clientID returns the configured mock managed-identity client ID.
func (m *mockAzureAttestationMetadataProvider) clientID() string {
	return m.clientIDValue
}

// azureFunctionsMetadataProvider builds a mock provider that mimics an Azure
// Functions environment: endpoint, identity header, and client ID all present.
func azureFunctionsMetadataProvider() *mockAzureAttestationMetadataProvider {
	return &mockAzureAttestationMetadataProvider{
		identityEndpointValue: wiremock.baseURL() + "/metadata/identity/endpoint/from/env",
		identityHeaderValue:   "some-identity-header-from-env",
		clientIDValue:         "managed-client-id-from-env",
	}
}

// azureVMMetadataProvider builds a mock provider that mimics an Azure VM:
// all env-derived values empty, so the IMDS fallback path is exercised.
func azureVMMetadataProvider() *mockAzureAttestationMetadataProvider {
	return &mockAzureAttestationMetadataProvider{
		identityEndpointValue: "",
		identityHeaderValue:   "",
		clientIDValue:         "",
	}
}

// Running this test locally:
// * Push branch to repository
// * Set PARAMETERS_SECRET
// * Run ci/test_wif.sh
func TestWorkloadIdentityAuthOnCloudVM(t *testing.T) {
	account := os.Getenv("SNOWFLAKE_TEST_WIF_ACCOUNT")
	host := os.Getenv("SNOWFLAKE_TEST_WIF_HOST")
	provider := os.Getenv("SNOWFLAKE_TEST_WIF_PROVIDER")
	// NOTE(review): builtin println writes to stderr; t.Log would be more conventional here.
	println("provider = " + provider)
	if account == "" || host == "" || provider == "" {
		t.Skip("Test can run only on cloud VM with env variables set")
	}
	testCases := []struct {
		name             string
		skip             func() (bool, string) // optional per-case skip predicate
		setupCfg         func(*testing.T, *Config)
		expectedUsername string
	}{
		{
			name: "provider=" + provider,
			setupCfg: func(_ *testing.T, config *Config) {
				if provider != "GCP+OIDC" {
					config.WorkloadIdentityProvider = provider
				} else {
					// GCP+OIDC: fetch an ID token from the GCP metadata server
					// and pass it as a plain OIDC token.
					config.WorkloadIdentityProvider = "OIDC"
					config.Token = func() string {
						cmd := exec.Command("wget", "-O", "-", "--header=Metadata-Flavor: Google", "http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience=snowflakecomputing.com")
						output, err := cmd.Output()
						if err != nil {
							t.Fatalf("error executing GCP metadata request: %v", err)
						}
						token := strings.TrimSpace(string(output))
						if token == "" {
							t.Fatal("failed to retrieve GCP access token: empty response")
						}
						return token
					}()
				}
			},
			expectedUsername: os.Getenv("SNOWFLAKE_TEST_WIF_USERNAME"),
		},
		{
			name: "provider=" + provider + ",impersonation",
			skip: func() (bool, string) {
				if provider != "AWS" && provider != "GCP" {
					return true, "Impersonation is supported only on AWS and GCP"
				}
				return false, ""
			},
			setupCfg: func(t *testing.T, config *Config) {
				config.WorkloadIdentityProvider = provider
				impersonationPath := os.Getenv("SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH")
				assertNotEqualF(t, impersonationPath, "", "SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH is not set")
				config.WorkloadIdentityImpersonationPath = strings.Split(impersonationPath, ",")
				assertNotEqualF(t, os.Getenv("SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION"), "", "SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION is not set")
			},
			expectedUsername: os.Getenv("SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION"),
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			if tc.skip != nil {
				if skip, msg := tc.skip(); skip {
					t.Skip(msg)
				}
			}
			config := &Config{
				Account:       account,
				Host:          host,
				Authenticator: AuthTypeWorkloadIdentityFederation,
			}
			tc.setupCfg(t, config)
			connector := NewConnector(SnowflakeDriver{}, *config)
			db := sql.OpenDB(connector)
			defer db.Close()
			// Verify that the authenticated session resolves to the expected user.
			currentUser := runSelectCurrentUser(t, db)
			assertEqualE(t, currentUser, tc.expectedUsername)
		})
	}
}

================================================
FILE: auth_with_external_browser_test.go
================================================

package gosnowflake

import (
	"context"
	"database/sql"
	"fmt"
	"log"
	"os/exec"
	"sync"
	"testing"
	"time"
)

// TestExternalBrowserSuccessful drives a scripted browser (one goroutine)
// while connecting with external-browser auth (second goroutine).
func TestExternalBrowserSuccessful(t *testing.T) {
	cfg := setupExternalBrowserTest(t)
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.Success, cfg.User, cfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t,
cfg)
		assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
	}()
	wg.Wait()
}

// TestExternalBrowserFailed supplies wrong credentials to the scripted browser
// and expects the driver to time out after 10 seconds.
func TestExternalBrowserFailed(t *testing.T) {
	cfg := setupExternalBrowserTest(t)
	cfg.ExternalBrowserTimeout = time.Duration(10) * time.Second
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.Fail, "FakeAccount", "NotARealPassword")
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		assertNotNilF(t, err)
		assertEqualE(t, err.Error(), "authentication timed out")
	}()
	wg.Wait()
}

// TestExternalBrowserTimeout uses a 1-second browser timeout so the auth
// flow cannot complete, and expects a timeout error.
func TestExternalBrowserTimeout(t *testing.T) {
	cfg := setupExternalBrowserTest(t)
	cfg.ExternalBrowserTimeout = time.Duration(1) * time.Second
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.Timeout, cfg.User, cfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		assertNotNilF(t, err)
		assertEqualE(t, err.Error(), "authentication timed out")
	}()
	wg.Wait()
}

// TestExternalBrowserMismatchUser authenticates in the browser as the real
// user while the driver config names a different user; expects error 390191.
func TestExternalBrowserMismatchUser(t *testing.T) {
	cfg := setupExternalBrowserTest(t)
	correctUsername := cfg.User
	cfg.User = "fakeAccount"
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.Success, correctUsername, cfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		var snowflakeErr *SnowflakeError
		assertErrorsAsF(t, err, &snowflakeErr)
		assertEqualE(t, snowflakeErr.Number, 390191, fmt.Sprintf("Expected 390191, but got %v", snowflakeErr.Number))
	}()
	wg.Wait()
}

// TestClientStoreCredentials exercises the locally cached ID token:
// obtain it once via the browser, then reconnect with and without caching.
func TestClientStoreCredentials(t *testing.T) {
	cfg := setupExternalBrowserTest(t)
	cfg.ClientStoreTemporaryCredential = 1
	cfg.ExternalBrowserTimeout = time.Duration(10) * time.Second
	t.Run("Obtains the ID token from the server and saves it on the local storage", func(t *testing.T) {
		cleanupBrowserProcesses(t)
		var wg sync.WaitGroup
		wg.Add(2)
		go func() {
			defer wg.Done()
			provideExternalBrowserCredentials(t, externalBrowserType.Success, cfg.User, cfg.Password)
		}()
		go func() {
			defer wg.Done()
			err := verifyConnectionToSnowflakeAuthTests(t, cfg)
			assertNilE(t, err, fmt.Sprintf("Connection failed: err %v", err))
		}()
		wg.Wait()
	})
	t.Run("Verify validation of ID token if option enabled", func(t *testing.T) {
		cleanupBrowserProcesses(t)
		cfg.ClientStoreTemporaryCredential = 1
		// With caching enabled no browser is available, so success here
		// proves the cached ID token was used.
		db := getDbHandlerFromConfig(t, cfg)
		conn, err := db.Conn(context.Background())
		assertNilE(t, err, fmt.Sprintf("Failed to connect to Snowflake. err: %v", err))
		defer conn.Close()
		rows, err := conn.QueryContext(context.Background(), "SELECT 1")
		assertNilE(t, err, fmt.Sprintf("Failed to run a query. err: %v", err))
		rows.Close()
	})
	t.Run("Verify validation of idToken if option disabled", func(t *testing.T) {
		cleanupBrowserProcesses(t)
		cfg.ClientStoreTemporaryCredential = 0
		// Caching disabled and no browser scripted: the connection must time out.
		db := getDbHandlerFromConfig(t, cfg)
		_, err := db.Conn(context.Background())
		assertNotNilF(t, err)
		assertEqualE(t, err.Error(), "authentication timed out", fmt.Sprintf("Expected timeout, but got %v", err))
	})
}

// ExternalBrowserProcessResult enumerates the scenario names understood by
// the scripted-browser helper (provideBrowserCredentials.js).
type ExternalBrowserProcessResult struct {
	Success               string
	Fail                  string
	Timeout               string
	OauthOktaSuccess      string
	OauthSnowflakeSuccess string
}

var externalBrowserType = ExternalBrowserProcessResult{
	Success:               "success",
	Fail:                  "fail",
	Timeout:               "timeout",
	OauthOktaSuccess:      "externalOauthOktaSuccess",
	OauthSnowflakeSuccess: "internalOauthSnowflakeSuccess",
}

// cleanupBrowserProcesses kills any scripted browser processes left over from
// a previous test; only effective inside the test Docker container.
func cleanupBrowserProcesses(t *testing.T) {
	if isTestRunningInDockerContainer() {
		const cleanBrowserProcessesPath = "/externalbrowser/cleanBrowserProcesses.js"
		_, err := exec.Command("node", cleanBrowserProcessesPath).CombinedOutput()
		assertNilE(t, err, fmt.Sprintf("failed to execute command: %v", err))
	}
}

// provideExternalBrowserCredentials runs the node helper that drives the
// browser login for the given scenario; no-op outside the Docker container.
func provideExternalBrowserCredentials(t *testing.T, ExternalBrowserProcess string, user string, password string) {
	if isTestRunningInDockerContainer() {
		const provideBrowserCredentialsPath = "/externalbrowser/provideBrowserCredentials.js"
		output,
			err := exec.Command("node", provideBrowserCredentialsPath, ExternalBrowserProcess, user, password).CombinedOutput()
		log.Printf("Output: %s\n", output)
		assertNilE(t, err, fmt.Sprintf("failed to execute command: %v", err))
	}
}

// verifyConnectionToSnowflakeAuthTests opens a connection for cfg and runs
// "SELECT 1"; returns the query error (nil on success).
func verifyConnectionToSnowflakeAuthTests(t *testing.T, cfg *Config) (err error) {
	dsn, err := DSN(cfg)
	assertNilE(t, err, "failed to create DSN from Config")
	db, err := sql.Open("snowflake", dsn)
	assertNilE(t, err, "failed to open Snowflake DB connection")
	defer db.Close()
	rows, err := db.Query("SELECT 1")
	if err != nil {
		log.Printf("failed to run a query. 'SELECT 1', err: %v", err)
		return err
	}
	defer rows.Close()
	assertTrueE(t, rows.Next(), "failed to get result", "There were no results for query: ")
	return err
}

// setupExternalBrowserTest skips unless auth tests are enabled, clears stale
// browser processes, and returns an external-browser Config.
func setupExternalBrowserTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping External Browser tests")
	cleanupBrowserProcesses(t)
	cfg, err := getAuthTestsConfig(t, AuthTypeExternalBrowser)
	assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err))
	return cfg
}

================================================
FILE: auth_with_keypair_test.go
================================================

package gosnowflake

import (
	"crypto/rsa"
	"fmt"
	"golang.org/x/crypto/ssh"
	"os"
	"testing"
)

// TestKeypairSuccessful connects with a valid RSA key pair (JWT auth).
func TestKeypairSuccessful(t *testing.T) {
	cfg := setupKeyPairTest(t)
	cfg.PrivateKey = loadRsaPrivateKeyForKeyPair(t, "SNOWFLAKE_AUTH_TEST_PRIVATE_KEY_PATH")
	err := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNilE(t, err, fmt.Sprintf("failed to connect. err: %v", err))
}

// TestKeypairInvalidKey connects with an unregistered key; expects 390144.
func TestKeypairInvalidKey(t *testing.T) {
	cfg := setupKeyPairTest(t)
	cfg.PrivateKey = loadRsaPrivateKeyForKeyPair(t, "SNOWFLAKE_AUTH_TEST_INVALID_PRIVATE_KEY_PATH")
	err := verifyConnectionToSnowflakeAuthTests(t, cfg)
	var snowflakeErr *SnowflakeError
	assertErrorsAsF(t, err, &snowflakeErr)
	assertEqualE(t, snowflakeErr.Number, 390144, fmt.Sprintf("Expected 390144, but got %v", snowflakeErr.Number))
}

// setupKeyPairTest returns a JWT-auth Config, skipping when auth tests are off.
func setupKeyPairTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping KeyPair tests")
	cfg, err := getAuthTestsConfig(t, AuthTypeJwt)
	assertEqualE(t, err, nil, fmt.Sprintf("failed to get config: %v", err))
	return cfg
}

// loadRsaPrivateKeyForKeyPair reads the PEM file whose path is in envName and
// parses it into an *rsa.PrivateKey (fatal on any failure).
func loadRsaPrivateKeyForKeyPair(t *testing.T, envName string) *rsa.PrivateKey {
	filePath, err := GetFromEnv(envName, true)
	assertNilF(t, err, fmt.Sprintf("failed to get env: %v", err))
	bytes, err := os.ReadFile(filePath)
	assertNilF(t, err, fmt.Sprintf("failed to read file: %v", err))
	key, err := ssh.ParseRawPrivateKey(bytes)
	assertNilF(t, err, fmt.Sprintf("failed to parse private key: %v", err))
	return key.(*rsa.PrivateKey)
}

================================================
FILE: auth_with_mfa_test.go
================================================

package gosnowflake

import (
	"errors"
	"fmt"
	"log"
	"os/exec"
	"strings"
	"testing"
)

// TestMfaSuccessful authenticates with a TOTP passcode, then reconnects with
// no passcode to prove the MFA token was cached.
func TestMfaSuccessful(t *testing.T) {
	cfg := setupMfaTest(t)
	// Enable MFA token caching
	cfg.ClientRequestMfaToken = ConfigBoolTrue
	//Provide your own TOTP code/codes here, to test manually
	//totpKeys := []string{"222222", "333333", "444444"}
	totpKeys := getTOPTcodes(t)
	verifyConnectionToSnowflakeUsingTotpCodes(t, cfg, totpKeys)
	log.Printf("Testing MFA token caching with second connection...")
	// Clear the passcode to force use of cached MFA token
	cfg.Passcode = ""
	// Attempt to connect using cached MFA token
	cacheErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNilF(t, cacheErr, "Failed to connect with cached MFA token")
}

// setupMfaTest builds an MFA Config with the dedicated MFA user/password.
func setupMfaTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping MFA tests")
	cfg, err := getAuthTestsConfig(t, AuthTypeUsernamePasswordMFA)
	assertNilF(t, err, "failed to get config")
	cfg.User, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_MFA_USER", true)
	assertNilF(t, err, "failed to get MFA user from environment")
	cfg.Password, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_MFA_PASSWORD", true)
	assertNilF(t, err, "failed to get MFA password from environment")
	return cfg
}

// getTOPTcodes generates TOTP codes via the container helper; returns an
// empty slice outside Docker.
// NOTE(review): "TOPT" looks like a typo for "TOTP" — renaming would need a
// coordinated change with the caller above.
func getTOPTcodes(t *testing.T) []string {
	if isTestRunningInDockerContainer() {
		const provideTotpPath = "/externalbrowser/totpGenerator.js"
		output, err := exec.Command("node", provideTotpPath).CombinedOutput()
		assertNilF(t, err, fmt.Sprintf("failed to execute command: %v", err))
		totpCodes := strings.Fields(string(output))
		return totpCodes
	}
	return []string{}
}

// verifyConnectionToSnowflakeUsingTotpCodes tries each TOTP code in turn,
// retrying only on the MFA-specific error codes 394633/394507.
func verifyConnectionToSnowflakeUsingTotpCodes(t *testing.T, cfg *Config, totpKeys []string) {
	if len(totpKeys) == 0 {
		t.Fatalf("no TOTP codes provided")
	}
	var lastError error
	for i, totpKey := range totpKeys {
		cfg.Passcode = totpKey
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		if err == nil {
			return
		}
		lastError = err
		errorMsg := err.Error()
		log.Printf("TOTP code %d failed: %v", i+1, errorMsg)
		var snowflakeErr *SnowflakeError
		if errors.As(err, &snowflakeErr) && (snowflakeErr.Number == 394633 || snowflakeErr.Number == 394507) {
			log.Printf("MFA error detected (%d), trying next code...", snowflakeErr.Number)
			continue
		} else {
			log.Printf("Non-MFA error detected: %v", errorMsg)
			break
		}
	}
	assertNilF(t, lastError, "failed to connect with any TOTP code")
}

================================================
FILE: auth_with_oauth_okta_authorization_code_test.go
================================================

package gosnowflake

import (
	"fmt"
	"sync"
	"testing"
	"time"
)

// TestOauthOktaAuthorizationCodeSuccessful runs the scripted-browser Okta
// OAuth flow in parallel with a driver connection.
func TestOauthOktaAuthorizationCodeSuccessful(t *testing.T) {
	cfg := setupOauthOktaAuthorizationCodeTest(t)
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthOktaSuccess, cfg.User,
cfg.Password) }() go func() { defer wg.Done() err := verifyConnectionToSnowflakeAuthTests(t, cfg) assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err)) }() wg.Wait() } func TestOauthOktaAuthorizationCodeMismatchedUsername(t *testing.T) { cfg := setupOauthOktaAuthorizationCodeTest(t) user := cfg.User var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() provideExternalBrowserCredentials(t, externalBrowserType.OauthOktaSuccess, user, cfg.Password) }() go func() { defer wg.Done() cfg.User = "fakeUser@snowflake.com" err := verifyConnectionToSnowflakeAuthTests(t, cfg) var snowflakeErr *SnowflakeError assertErrorsAsF(t, err, &snowflakeErr) assertEqualE(t, snowflakeErr.Number, 390309, fmt.Sprintf("Expected 390309, but got %v", snowflakeErr.Number)) }() wg.Wait() } func TestOauthOktaAuthorizationCodeOktaTimeout(t *testing.T) { cfg := setupOauthOktaAuthorizationCodeTest(t) cfg.ExternalBrowserTimeout = time.Duration(1) * time.Second err := verifyConnectionToSnowflakeAuthTests(t, cfg) assertNotNilF(t, err, "should failed due to timeout") assertEqualE(t, err.Error(), "authentication via browser timed out", fmt.Sprintf("Expecteed timeout, but got %v", err)) } func TestOauthOktaAuthorizationCodeUsingTokenCache(t *testing.T) { cfg := setupOauthOktaAuthorizationCodeTest(t) cfg.ClientStoreTemporaryCredential = 1 var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() provideExternalBrowserCredentials(t, externalBrowserType.OauthOktaSuccess, cfg.User, cfg.Password) }() go func() { defer wg.Done() err := verifyConnectionToSnowflakeAuthTests(t, cfg) assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err)) }() wg.Wait() cleanupBrowserProcesses(t) cfg.ExternalBrowserTimeout = time.Duration(1) * time.Second err := verifyConnectionToSnowflakeAuthTests(t, cfg) assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err)) } func setupOauthOktaAuthorizationCodeTest(t *testing.T) *Config { skipAuthTests(t, "Skipping Okta Authorization Code tests") cfg, 
	err := getAuthTestsConfig(t, AuthTypeOAuthAuthorizationCode)
	assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err))
	cleanupBrowserProcesses(t)
	cfg.OauthClientID, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.OauthClientSecret, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_SECRET", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.OauthRedirectURI, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_REDIRECT_URI", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.OauthAuthorizationURL, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_AUTH_URL", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.OauthTokenRequestURL, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_TOKEN", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.Role, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_ROLE", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	return cfg
}

================================================
FILE: auth_with_oauth_okta_client_credentials_test.go
================================================

package gosnowflake

import (
	"fmt"
	"strings"
	"testing"
)

// TestOauthOktaClientCredentialsSuccessful connects via the OAuth
// client-credentials grant (no browser involved).
func TestOauthOktaClientCredentialsSuccessful(t *testing.T) {
	cfg := setupOauthOktaClientCredentialsTest(t)
	err := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNilE(t, err, fmt.Sprintf("failed to connect. err: %v", err))
}

// TestOauthOktaClientCredentialsMismatchedUsername expects 390309 when the
// configured user does not match the token subject.
func TestOauthOktaClientCredentialsMismatchedUsername(t *testing.T) {
	cfg := setupOauthOktaClientCredentialsTest(t)
	cfg.User = "invalidUser"
	err := verifyConnectionToSnowflakeAuthTests(t, cfg)
	var snowflakeErr *SnowflakeError
	assertErrorsAsF(t, err, &snowflakeErr)
	assertEqualE(t, snowflakeErr.Number, 390309, fmt.Sprintf("Expected 390309, but got %v", snowflakeErr.Number))
}

// TestOauthOktaClientCredentialsUnauthorized uses a bogus client ID; the IdP
// rejects the token request with an "invalid_client" error.
func TestOauthOktaClientCredentialsUnauthorized(t *testing.T) {
	cfg := setupOauthOktaClientCredentialsTest(t)
	cfg.OauthClientID = "invalidClientID"
	err := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNotNilF(t, err, "Expected an error but got nil")
	assertTrueF(t, strings.Contains(err.Error(), "invalid_client"), fmt.Sprintf("Expected error to contain 'invalid_client', but got: %v", err.Error()))
}

// setupOauthOktaClientCredentialsTest builds a client-credentials Config;
// the client ID doubles as the Snowflake user name.
func setupOauthOktaClientCredentialsTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping Okta Client Credentials tests")
	cfg, err := getAuthTestsConfig(t, AuthTypeOAuthClientCredentials)
	assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err))
	cfg.OauthClientID, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.OauthClientSecret, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_SECRET", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.OauthTokenRequestURL, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_TOKEN", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.User, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.Role, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_ROLE", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	return cfg
}

================================================
FILE: auth_with_oauth_snowflake_authorization_code_test.go
================================================

package gosnowflake

import (
	"fmt"
	"sync"
	"testing"
	"time"
)

// TestOauthSnowflakeAuthorizationCodeSuccessful runs the scripted browser for
// the Snowflake-internal OAuth flow alongside a driver connection.
func TestOauthSnowflakeAuthorizationCodeSuccessful(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
	}()
	wg.Wait()
}

// TestOauthSnowflakeAuthorizationCodeMismatchedUsername expects 390309 when
// the config user differs from the authenticated browser user.
func TestOauthSnowflakeAuthorizationCodeMismatchedUsername(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		cfg.User = "fakeUser@snowflake.com"
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		var snowflakeErr *SnowflakeError
		assertErrorsAsF(t, err, &snowflakeErr)
		assertEqualE(t, snowflakeErr.Number, 390309, fmt.Sprintf("Expected 390309, but got %v", snowflakeErr.Number))
	}()
	wg.Wait()
}

// TestOauthSnowflakeAuthorizationCodeTimeout: 1-second browser timeout with no
// scripted browser — the flow must time out.
// NOTE(review): "should failed" / "Expecteed" are typos in the assertion messages.
func TestOauthSnowflakeAuthorizationCodeTimeout(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeTest(t)
	cfg.ExternalBrowserTimeout = time.Duration(1) * time.Second
	err := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNotNilF(t, err, "should failed due to timeout")
	assertEqualE(t, err.Error(), "authentication via browser timed out", fmt.Sprintf("Expecteed timeout, but got %v", err))
}

// TestOauthSnowflakeAuthorizationCodeUsingTokenCache: log in once, then
// reconnect without a browser — success proves the cached token was used.
func TestOauthSnowflakeAuthorizationCodeUsingTokenCache(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	cfg.ClientStoreTemporaryCredential = 1
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
	}()
	wg.Wait()
	cleanupBrowserProcesses(t)
	cfg.ExternalBrowserTimeout = time.Duration(1) * time.Second
	err = verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
}

// TestOauthSnowflakeAuthorizationCodeWithoutTokenCache disables caching (2);
// the second, browserless connection must then time out.
func TestOauthSnowflakeAuthorizationCodeWithoutTokenCache(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	cfg.ClientStoreTemporaryCredential = 2
	var wg sync.WaitGroup
	cfg.DisableQueryContextCache = true
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
	}()
	wg.Wait()
	cleanupBrowserProcesses(t)
	cfg.ExternalBrowserTimeout = time.Duration(1) * time.Second
	err = verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNotNilF(t, err, "Expected an error but got nil")
	assertEqualE(t, err.Error(), "authentication via browser timed out", fmt.Sprintf("Expecteed timeout, but got %v", err))
}

// setupOauthSnowflakeAuthorizationCodeTest builds a Config for the
// Snowflake-internal OAuth authorization-code flow from env variables.
func setupOauthSnowflakeAuthorizationCodeTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping Snowflake Authorization Code tests")
	cfg, err := getAuthTestsConfig(t, AuthTypeOAuthAuthorizationCode)
	assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err))
	cleanupBrowserProcesses(t)
	cfg.OauthClientID, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_CLIENT_ID", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.OauthClientSecret, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_CLIENT_SECRET", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.OauthRedirectURI, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_REDIRECT_URI", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.User, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.Role, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_ROLE", true)
	assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err))
	cfg.ClientStoreTemporaryCredential = 2
	return cfg
}

// getOauthSnowflakeAuthorizationCodeTestCredentials returns the user/password
// the scripted browser should type, read from env.
func getOauthSnowflakeAuthorizationCodeTestCredentials() (*Config, error) {
	return GetConfigFromEnv([]*ConfigParam{
		{Name: "User", EnvName: "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID", FailOnMissing: true},
		{Name: "Password", EnvName: "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_USER_PASSWORD", FailOnMissing: true},
	})
}

================================================
FILE: auth_with_oauth_snowflake_authorization_code_wildcards_test.go
================================================

package gosnowflake

import (
	"fmt"
	"sync"
	"testing"
	"time"
)

// TestOauthSnowflakeAuthorizationCodeWildcardsSuccessful: same flow as above
// but with the wildcard-redirect OAuth client.
func TestOauthSnowflakeAuthorizationCodeWildcardsSuccessful(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeWildcardsTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
	}()
	wg.Wait()
}

// TestOauthSnowflakeAuthorizationCodeWildcardsMismatchedUsername expects 390309.
func TestOauthSnowflakeAuthorizationCodeWildcardsMismatchedUsername(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeWildcardsTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		cfg.User = "fakeUser@snowflake.com"
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		var snowflakeErr *SnowflakeError
		assertErrorsAsF(t, err, &snowflakeErr)
		assertEqualE(t, snowflakeErr.Number, 390309, fmt.Sprintf("Expected 390309, but got %v", snowflakeErr.Number))
	}()
	wg.Wait()
}

// TestOauthSnowflakeAuthorizationWildcardsCodeTimeout: browserless 1-second
// timeout must fail.
// NOTE(review): "should failed" / "Expecteed" are typos in the assertion messages.
func TestOauthSnowflakeAuthorizationWildcardsCodeTimeout(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeWildcardsTest(t)
	cfg.ExternalBrowserTimeout = time.Duration(1) * time.Second
	err := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNotNilF(t, err, "should failed due to timeout")
	assertEqualE(t, err.Error(), "authentication via browser timed out", fmt.Sprintf("Expecteed timeout, but got %v", err))
}

// TestOauthSnowflakeAuthorizationCodeWildcardsWithoutTokenCache: with caching
// disabled the second, browserless connection must time out.
func TestOauthSnowflakeAuthorizationCodeWildcardsWithoutTokenCache(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeWildcardsTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	cfg.ClientStoreTemporaryCredential = 2
	var wg sync.WaitGroup
	cfg.DisableQueryContextCache = true
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		assertNilE(t, err,
fmt.Sprintf("Connection failed due to %v", err)) }() wg.Wait() cleanupBrowserProcesses(t) cfg.ExternalBrowserTimeout = time.Duration(1) * time.Second err = verifyConnectionToSnowflakeAuthTests(t, cfg) assertNotNilF(t, err, "Expected an error but got nil") assertEqualE(t, err.Error(), "authentication via browser timed out", fmt.Sprintf("Expecteed timeout, but got %v", err)) } func setupOauthSnowflakeAuthorizationCodeWildcardsTest(t *testing.T) *Config { skipAuthTests(t, "Skipping Snowflake Authorization Code tests") cfg, err := getAuthTestsConfig(t, AuthTypeOAuthAuthorizationCode) assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err)) cleanupBrowserProcesses(t) cfg.OauthClientID, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_WILDCARDS_CLIENT_ID", true) assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err)) cfg.OauthClientSecret, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_WILDCARDS_CLIENT_SECRET", true) assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err)) cfg.User, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID", true) assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err)) cfg.Role, err = GetFromEnv("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_ROLE", true) assertNilF(t, err, fmt.Sprintf("failed to setup config: %v", err)) cfg.ClientStoreTemporaryCredential = 2 return cfg } ================================================ FILE: auth_with_oauth_test.go ================================================ package gosnowflake import ( "encoding/json" "fmt" "net/http" "net/url" "strings" "testing" ) func TestOauthSuccessful(t *testing.T) { cfg := setupOauthTest(t) token, err := getOauthTestToken(t, cfg) assertNilE(t, err, fmt.Sprintf("failed to get token. err: %v", err)) cfg.Token = token err = verifyConnectionToSnowflakeAuthTests(t, cfg) assertNilE(t, err, fmt.Sprintf("failed to connect. 
err: %v", err)) } func TestOauthInvalidToken(t *testing.T) { cfg := setupOauthTest(t) cfg.Token = "invalid_token" err := verifyConnectionToSnowflakeAuthTests(t, cfg) var snowflakeErr *SnowflakeError assertErrorsAsF(t, err, &snowflakeErr) assertEqualE(t, snowflakeErr.Number, 390303, fmt.Sprintf("Expected 390303, but got %v", snowflakeErr.Number)) } func TestOauthMismatchedUser(t *testing.T) { cfg := setupOauthTest(t) token, err := getOauthTestToken(t, cfg) assertNilE(t, err, fmt.Sprintf("failed to get token. err: %v", err)) cfg.Token = token cfg.User = "fakeaccount" err = verifyConnectionToSnowflakeAuthTests(t, cfg) var snowflakeErr *SnowflakeError assertErrorsAsF(t, err, &snowflakeErr) assertEqualE(t, snowflakeErr.Number, 390309, fmt.Sprintf("Expected 390309, but got %v", snowflakeErr.Number)) } func setupOauthTest(t *testing.T) *Config { skipAuthTests(t, "Skipping OAuth tests") cfg, err := getAuthTestsConfig(t, AuthTypeOAuth) assertNilF(t, err, fmt.Sprintf("failed to connect. err: %v", err)) return cfg } func getOauthTestToken(t *testing.T, cfg *Config) (string, error) { client := &http.Client{} authURL, err := GetFromEnv("SNOWFLAKE_AUTH_TEST_OAUTH_URL", true) assertNilF(t, err, "SNOWFLAKE_AUTH_TEST_OAUTH_URL is not set") oauthClientID, err := GetFromEnv("SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_ID", true) assertNilF(t, err, "SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_ID is not set") oauthClientSecret, err := GetFromEnv("SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_SECRET", true) assertNilF(t, err, "SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_SECRET is not set") inputData := formData(cfg) req, err := http.NewRequest("POST", authURL, strings.NewReader(inputData.Encode())) assertNilF(t, err, fmt.Sprintf("Request failed %v", err)) req.Header.Set("Content-Type", "application/x-www-form-urlencoded;charset=UTF-8") req.SetBasicAuth(oauthClientID, oauthClientSecret) resp, err := client.Do(req) assertNilF(t, err, fmt.Sprintf("Response failed %v", err)) if resp.StatusCode != http.StatusOK { return "", 
fmt.Errorf("failed to get access token, status code: %d", resp.StatusCode) } defer resp.Body.Close() var response OAuthTokenResponse if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { return "", fmt.Errorf("failed to decode response: %v", err) } return response.Token, err } func formData(cfg *Config) url.Values { data := url.Values{} data.Set("username", cfg.User) data.Set("password", cfg.Password) data.Set("grant_type", "password") data.Set("scope", fmt.Sprintf("session:role:%s", strings.ToLower(cfg.Role))) return data } type OAuthTokenResponse struct { Type string `json:"token_type"` Expiration int `json:"expires_in"` Token string `json:"access_token"` Scope string `json:"scope"` } ================================================ FILE: auth_with_okta_test.go ================================================ package gosnowflake import ( "fmt" "net/url" "testing" ) func TestOktaSuccessful(t *testing.T) { cfg := setupOktaTest(t) err := verifyConnectionToSnowflakeAuthTests(t, cfg) assertNilE(t, err, fmt.Sprintf("failed to connect. 
err: %v", err)) } func TestOktaWrongCredentials(t *testing.T) { cfg := setupOktaTest(t) cfg.Password = "fakePassword" err := verifyConnectionToSnowflakeAuthTests(t, cfg) var snowflakeErr *SnowflakeError assertErrorsAsF(t, err, &snowflakeErr) assertEqualE(t, snowflakeErr.Number, 261006, fmt.Sprintf("Expected 261006, but got %v", snowflakeErr.Number)) } func TestOktaWrongAuthenticator(t *testing.T) { cfg := setupOktaTest(t) invalidAddress, err := url.Parse("https://fake-account-0000.okta.com") assertNilF(t, err, fmt.Sprintf("failed to parse: %v", err)) cfg.OktaURL = invalidAddress err = verifyConnectionToSnowflakeAuthTests(t, cfg) var snowflakeErr *SnowflakeError assertErrorsAsF(t, err, &snowflakeErr) assertEqualE(t, snowflakeErr.Number, 390139, fmt.Sprintf("Expected 390139, but got %v", snowflakeErr.Number)) } func setupOktaTest(t *testing.T) *Config { skipAuthTests(t, "Skipping Okta tests") urlEnv, err := GetFromEnv("SNOWFLAKE_AUTH_TEST_OKTA_AUTH", true) assertNilF(t, err, fmt.Sprintf("failed to get env: %v", err)) cfg, err := getAuthTestsConfig(t, AuthTypeOkta) assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err)) cfg.OktaURL, err = url.Parse(urlEnv) assertNilF(t, err, fmt.Sprintf("failed to parse: %v", err)) return cfg } ================================================ FILE: auth_with_pat_test.go ================================================ package gosnowflake import ( "database/sql" "fmt" "log" "strings" "testing" "time" ) type PatToken struct { Name string Value string } func TestEndToEndPatSuccessful(t *testing.T) { cfg := setupEndToEndPatTest(t) patToken := createEndToEndPatToken(t) defer removeEndToEndPatToken(t, patToken.Name) cfg.Token = patToken.Value err := verifyConnectionToSnowflakeAuthTests(t, cfg) assertNilE(t, err, fmt.Sprintf("failed to connect. 
err: %v", err)) } func TestEndToEndPatMismatchedUser(t *testing.T) { cfg := setupEndToEndPatTest(t) patToken := createEndToEndPatToken(t) defer removeEndToEndPatToken(t, patToken.Name) cfg.Token = patToken.Value cfg.User = "invalidUser" err := verifyConnectionToSnowflakeAuthTests(t, cfg) var snowflakeErr *SnowflakeError assertErrorsAsF(t, err, &snowflakeErr) assertEqualE(t, snowflakeErr.Number, 394400, fmt.Sprintf("Expected 394400, but got %v", snowflakeErr.Number)) } func TestEndToEndPatInvalidToken(t *testing.T) { cfg := setupEndToEndPatTest(t) cfg.Token = "invalidToken" err := verifyConnectionToSnowflakeAuthTests(t, cfg) var snowflakeErr *SnowflakeError assertErrorsAsF(t, err, &snowflakeErr) assertEqualE(t, snowflakeErr.Number, 394400, fmt.Sprintf("Expected 394400, but got %v", snowflakeErr.Number)) } func setupEndToEndPatTest(t *testing.T) *Config { skipAuthTests(t, "Skipping PAT tests") cfg, err := getAuthTestsConfig(t, AuthTypePat) assertNilF(t, err, fmt.Sprintf("failed to parse: %v", err)) return cfg } func getEndToEndPatSetupCommandVariables() (*Config, error) { return GetConfigFromEnv([]*ConfigParam{ {Name: "User", EnvName: "SNOWFLAKE_AUTH_TEST_SNOWFLAKE_USER", FailOnMissing: true}, {Name: "Role", EnvName: "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_ROLE", FailOnMissing: true}, }) } func createEndToEndPatToken(t *testing.T) *PatToken { cfg := setupOktaTest(t) patTokenName := fmt.Sprintf("PAT_GOLANG_%s", strings.ReplaceAll(time.Now().Format("20060102150405.000"), ".", "")) patCommandVariables, err := getEndToEndPatSetupCommandVariables() assertNilE(t, err, "failed to get PAT command variables") query := fmt.Sprintf( "alter user %s add programmatic access token %s ROLE_RESTRICTION = '%s' DAYS_TO_EXPIRY=1;", patCommandVariables.User, patTokenName, patCommandVariables.Role, ) patToken, err := connectUsingOktaConnectionAndExecuteCustomCommand(t, cfg, query, true) assertNilE(t, err, "failed to create PAT command variables") return patToken } func 
removeEndToEndPatToken(t *testing.T, patTokenName string) { cfg := setupOktaTest(t) cfg.Role = "analyst" patCommandVariables, err := getEndToEndPatSetupCommandVariables() assertNilE(t, err, "failed to get PAT command variables") query := fmt.Sprintf( "alter user %s remove programmatic access token %s;", patCommandVariables.User, patTokenName, ) _, err = connectUsingOktaConnectionAndExecuteCustomCommand(t, cfg, query, false) assertNilE(t, err, "failed to remove PAT command variables") } func connectUsingOktaConnectionAndExecuteCustomCommand(t *testing.T, cfg *Config, query string, returnToken bool) (*PatToken, error) { dsn, err := DSN(cfg) assertNilE(t, err, "failed to create DSN from Config") db, err := sql.Open("snowflake", dsn) assertNilE(t, err, "failed to open Snowflake DB connection") defer db.Close() rows, err := db.Query(query) if err != nil { log.Printf("failed to run a query: %v, err: %v", query, err) return nil, err } var patTokenName, patTokenValue string if returnToken && rows.Next() { if err := rows.Scan(&patTokenName, &patTokenValue); err != nil { t.Fatalf("failed to scan token: %v", err) } return &PatToken{Name: patTokenName, Value: patTokenValue}, nil } if returnToken { t.Fatalf("no results found for query: %s", query) } return nil, err } ================================================ FILE: authexternalbrowser.go ================================================ package gosnowflake import ( "bytes" "context" "encoding/base64" "encoding/json" "errors" "fmt" errors2 "github.com/snowflakedb/gosnowflake/v2/internal/errors" "io" "log" "net" "net/http" "net/url" "strconv" "strings" "time" "github.com/pkg/browser" ) const ( samlSuccessHTML = ` SAML Response for Snowflake Your identity was confirmed and propagated to Snowflake %v. You can close this window now and go back where you started from. ` bufSize = 8192 ) // Builds a response to show to the user after successfully // getting a response from Snowflake. 
func buildResponse(body string) (bytes.Buffer, error) { t := &http.Response{ Status: "200 OK", StatusCode: 200, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, Body: io.NopCloser(bytes.NewBufferString(body)), ContentLength: int64(len(body)), Request: nil, Header: make(http.Header), } var b bytes.Buffer err := t.Write(&b) return b, err } // This opens a socket that listens on all available unicast // and any anycast IP addresses locally. By specifying "0", we are // able to bind to a free port. func createLocalTCPListener(port int) (*net.TCPListener, error) { logger.Debugf("creating local TCP listener on port %v", port) allAddressesListener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%v", port)) if err != nil { logger.Warnf("error while setting up 0.0.0.0 listener: %v", err) return nil, err } logger.Debug("Closing 0.0.0.0 tcp listener") if err := allAddressesListener.Close(); err != nil { logger.Errorf("error while closing TCP listener. %v", err) return nil, err } l, err := net.Listen("tcp", fmt.Sprintf("localhost:%v", port)) if err != nil { logger.Warnf("error while setting up listener: %v", err) return nil, err } tcpListener, ok := l.(*net.TCPListener) if !ok { return nil, fmt.Errorf("failed to assert type as *net.TCPListener") } return tcpListener, nil } // Opens a browser window (or new tab) with the configured login Url. // This can / will fail if running inside a shell with no display, ie // ssh'ing into a box attempting to authenticate via external browser. func openBrowser(browserURL string) error { parsedURL, err := url.ParseRequestURI(browserURL) if err != nil { logger.Errorf("error parsing url %v, err: %v", browserURL, err) return err } if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { return fmt.Errorf("invalid browser URL: %v", browserURL) } err = browser.OpenURL(browserURL) if err != nil { logger.Errorf("failed to open a browser. err: %v", err) return err } return nil } // Gets the IDP Url and Proof Key from Snowflake. 
// Note: FuncPostAuthSaml will return a fully qualified error if // there is something wrong getting data from Snowflake. func getIdpURLProofKey( ctx context.Context, sr *snowflakeRestful, authenticator string, application string, account string, user string, callbackPort int) (string, string, error) { headers := make(map[string]string) headers[httpHeaderContentType] = headerContentTypeApplicationJSON headers[httpHeaderAccept] = headerContentTypeApplicationJSON headers[httpHeaderUserAgent] = userAgent clientEnvironment := newAuthRequestClientEnvironment() clientEnvironment.Application = application requestMain := authRequestData{ ClientAppID: clientType, ClientAppVersion: SnowflakeGoDriverVersion, AccountName: account, LoginName: user, ClientEnvironment: clientEnvironment, Authenticator: authenticator, BrowserModeRedirectPort: strconv.Itoa(callbackPort), } authRequest := authRequest{ Data: requestMain, } jsonBody, err := json.Marshal(authRequest) if err != nil { logger.WithContext(ctx).Errorf("failed to serialize json. 
err: %v", err) return "", "", err } respd, err := sr.FuncPostAuthSAML(ctx, sr, headers, jsonBody, sr.LoginTimeout) if err != nil { return "", "", err } if !respd.Success { logger.WithContext(ctx).Error("Authentication FAILED") sr.TokenAccessor.SetTokens("", "", -1) code, err := strconv.Atoi(respd.Code) if err != nil { return "", "", err } return "", "", &SnowflakeError{ Number: code, SQLState: SQLStateConnectionRejected, Message: respd.Message, } } return respd.Data.SSOURL, respd.Data.ProofKey, nil } // Gets the login URL for multiple SAML func getLoginURL(sr *snowflakeRestful, user string, callbackPort int) (string, string, error) { proofKey := generateProofKey() params := &url.Values{} params.Add("login_name", user) params.Add("browser_mode_redirect_port", strconv.Itoa(callbackPort)) params.Add("proof_key", proofKey) url := sr.getFullURL(consoleLoginRequestPath, params) return url.String(), proofKey, nil } func generateProofKey() string { randomness := getSecureRandom(32) return base64.StdEncoding.WithPadding(base64.StdPadding).EncodeToString(randomness) } // The response returned from Snowflake looks like so: // GET /?token=encodedSamlToken // Host: localhost:54001 // Connection: keep-alive // Upgrade-Insecure-Requests: 1 // User-Agent: userAgentStr // Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8 // Referer: https://myaccount.snowflakecomputing.com/fed/login // Accept-Encoding: gzip, deflate, br // Accept-Language: en-US,en;q=0.9 // This extracts the token portion of the response. func getTokenFromResponse(response string) (string, error) { start := "GET /?token=" arr := strings.Split(response, "\r\n") if !strings.HasPrefix(arr[0], start) { logger.Errorf("response is malformed. 
") return "", &SnowflakeError{ Number: ErrFailedToParseResponse, SQLState: SQLStateConnectionRejected, Message: errors2.ErrMsgFailedToParseResponse, MessageArgs: []any{response}, } } token := strings.TrimPrefix(arr[0], start) token = strings.Split(token, " ")[0] return token, nil } type authenticateByExternalBrowserResult struct { escapedSamlResponse []byte proofKey []byte err error } func authenticateByExternalBrowser(ctx context.Context, sr *snowflakeRestful, authenticator string, application string, account string, user string, externalBrowserTimeout time.Duration, disableConsoleLogin ConfigBool) ([]byte, []byte, error) { resultChan := make(chan authenticateByExternalBrowserResult, 1) go GoroutineWrapper( ctx, func() { resultChan <- doAuthenticateByExternalBrowser(ctx, sr, authenticator, application, account, user, disableConsoleLogin) }, ) select { case <-time.After(externalBrowserTimeout): return nil, nil, errors.New("authentication timed out") case result := <-resultChan: return result.escapedSamlResponse, result.proofKey, result.err } } // Authentication by an external browser takes place via the following: // - the golang snowflake driver communicates to Snowflake that the user wishes to // authenticate via external browser // - snowflake sends back the IDP Url configured at the Snowflake side for the // provided account, or use the multiple SAML way via console login // - the default browser is opened to that URL // - user authenticates at the IDP, and is redirected to Snowflake // - Snowflake directs the user back to the driver // - authenticate is complete! 
func doAuthenticateByExternalBrowser(ctx context.Context, sr *snowflakeRestful, authenticator string, application string, account string, user string, disableConsoleLogin ConfigBool) authenticateByExternalBrowserResult { l, err := createLocalTCPListener(0) if err != nil { return authenticateByExternalBrowserResult{nil, nil, err} } defer func() { if err = l.Close(); err != nil { logger.Errorf("error while closing TCP listener for external browser (%v). %v", l.Addr().String(), err) } }() callbackPort := l.Addr().(*net.TCPAddr).Port var loginURL string var proofKey string if disableConsoleLogin == ConfigBoolTrue { // Gets the IDP URL and Proof Key from Snowflake loginURL, proofKey, err = getIdpURLProofKey(ctx, sr, authenticator, application, account, user, callbackPort) } else { // Multiple SAML way to do authentication via console login loginURL, proofKey, err = getLoginURL(sr, user, callbackPort) } if err != nil { return authenticateByExternalBrowserResult{nil, nil, err} } if err = defaultSamlResponseProvider().run(loginURL); err != nil { return authenticateByExternalBrowserResult{nil, nil, err} } encodedSamlResponseChan := make(chan string) errChan := make(chan error) var encodedSamlResponse string var errFromGoroutine error conn, err := l.Accept() if err != nil { logger.WithContext(ctx).Errorf("unable to accept connection. err: %v", err) log.Fatal(err) } go func(c net.Conn) { var buf bytes.Buffer total := 0 encodedSamlResponse := "" var errAccept error for { b := make([]byte, bufSize) n, err := c.Read(b) if err != nil { if err != io.EOF { logger.WithContext(ctx).Infof("error reading from socket. 
err: %v", err) errAccept = &SnowflakeError{ Number: ErrFailedToGetExternalBrowserResponse, SQLState: SQLStateConnectionRejected, Message: errors2.ErrMsgFailedToGetExternalBrowserResponse, MessageArgs: []any{err}, } } break } total += n buf.Write(b) if n < bufSize { // We successfully read all data s := string(buf.Bytes()[:total]) encodedSamlResponse, errAccept = getTokenFromResponse(s) break } buf.Grow(bufSize) } if encodedSamlResponse != "" { body := fmt.Sprintf(samlSuccessHTML, application) httpResponse, err := buildResponse(body) if err != nil && errAccept == nil { errAccept = err } if _, err = c.Write(httpResponse.Bytes()); err != nil && errAccept == nil { errAccept = err } } if err := c.Close(); err != nil { logger.Warnf("error while closing browser connection. %v", err) } encodedSamlResponseChan <- encodedSamlResponse errChan <- errAccept }(conn) encodedSamlResponse = <-encodedSamlResponseChan errFromGoroutine = <-errChan if errFromGoroutine != nil { return authenticateByExternalBrowserResult{nil, nil, errFromGoroutine} } escapedSamlResponse, err := url.QueryUnescape(encodedSamlResponse) if err != nil { logger.WithContext(ctx).Errorf("unable to unescape saml response. 
err: %v", err) return authenticateByExternalBrowserResult{nil, nil, err} } return authenticateByExternalBrowserResult{[]byte(escapedSamlResponse), []byte(proofKey), nil} } type samlResponseProvider interface { run(url string) error } type externalBrowserSamlResponseProvider struct { } func (e externalBrowserSamlResponseProvider) run(url string) error { return openBrowser(url) } var defaultSamlResponseProvider = func() samlResponseProvider { return &externalBrowserSamlResponseProvider{} } ================================================ FILE: authexternalbrowser_test.go ================================================ package gosnowflake import ( "context" "errors" "fmt" sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config" "net/http" "net/url" "strings" "testing" "time" ) func TestGetTokenFromResponseFail(t *testing.T) { response := "GET /?fakeToken=fakeEncodedSamlToken HTTP/1.1\r\n" + "Host: localhost:54001\r\n" + "Connection: keep-alive\r\n" + "Upgrade-Insecure-Requests: 1\r\n" + "User-Agent: userAgentStr\r\n" + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8\r\n" + "Referer: https://myaccount.snowflakecomputing.com/fed/login\r\n" + "Accept-Encoding: gzip, deflate, br\r\n" + "Accept-Language: en-US,en;q=0.9\r\n\r\n" _, err := getTokenFromResponse(response) if err == nil { t.Errorf("Should have failed parsing the malformed response.") } } func TestGetTokenFromResponse(t *testing.T) { response := "GET /?token=GETtokenFromResponse HTTP/1.1\r\n" + "Host: localhost:54001\r\n" + "Connection: keep-alive\r\n" + "Upgrade-Insecure-Requests: 1\r\n" + "User-Agent: userAgentStr\r\n" + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8\r\n" + "Referer: https://myaccount.snowflakecomputing.com/fed/login\r\n" + "Accept-Encoding: gzip, deflate, br\r\n" + "Accept-Language: en-US,en;q=0.9\r\n\r\n" expected := "GETtokenFromResponse" token, err := getTokenFromResponse(response) if 
err != nil { t.Errorf("Failed to get the token. Err: %#v", err) } if token != expected { t.Errorf("Expected: %s, found: %s", expected, token) } } func TestBuildResponse(t *testing.T) { resp, err := buildResponse(fmt.Sprintf(samlSuccessHTML, "Go")) assertNilF(t, err) bytes := resp.Bytes() respStr := string(bytes[:]) if !strings.Contains(respStr, "Your identity was confirmed and propagated to Snowflake Go.\nYou can close this window now and go back where you started from.") { t.Fatalf("failed to build response") } } func postAuthExternalBrowserError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{}, errors.New("failed to get SAML response") } func postAuthExternalBrowserErrorDelayed(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { time.Sleep(2 * time.Second) return &authResponse{}, errors.New("failed to get SAML response") } func postAuthExternalBrowserFail(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: false, Message: "external browser auth failed", }, nil } func postAuthExternalBrowserFailWithCode(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: false, Message: "failed to connect to db", Code: "260008", }, nil } func TestUnitAuthenticateByExternalBrowser(t *testing.T) { authenticator := "externalbrowser" application := "testapp" account := "testaccount" user := "u" timeout := sfconfig.DefaultExternalBrowserTimeout sr := &snowflakeRestful{ Protocol: "https", Host: "abc.com", Port: 443, FuncPostAuthSAML: postAuthExternalBrowserError, TokenAccessor: getSimpleTokenAccessor(), } _, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, timeout, ConfigBoolTrue) if err == nil { 
t.Fatal("should have failed.") } sr.FuncPostAuthSAML = postAuthExternalBrowserFail _, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, timeout, ConfigBoolTrue) if err == nil { t.Fatal("should have failed.") } sr.FuncPostAuthSAML = postAuthExternalBrowserFailWithCode _, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, timeout, ConfigBoolTrue) if err == nil { t.Fatal("should have failed.") } driverErr, ok := err.(*SnowflakeError) if !ok { t.Fatalf("should be snowflake error. err: %v", err) } if driverErr.Number != ErrCodeFailedToConnect { t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeFailedToConnect, driverErr.Number) } } func TestAuthenticationTimeout(t *testing.T) { authenticator := "externalbrowser" application := "testapp" account := "testaccount" user := "u" timeout := 1 * time.Second sr := &snowflakeRestful{ Protocol: "https", Host: "abc.com", Port: 443, FuncPostAuthSAML: postAuthExternalBrowserErrorDelayed, TokenAccessor: getSimpleTokenAccessor(), } _, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, timeout, ConfigBoolTrue) assertEqualE(t, err.Error(), "authentication timed out", err.Error()) } func Test_createLocalTCPListener(t *testing.T) { listener, err := createLocalTCPListener(0) if err != nil { t.Fatalf("createLocalTCPListener() failed: %v", err) } if listener == nil { t.Fatal("createLocalTCPListener() returned nil listener") } // Close the listener after the test. 
defer listener.Close() } func TestUnitGetLoginURL(t *testing.T) { expectedScheme := "https" expectedHost := "abc.com:443" user := "u" callbackPort := 123 sr := &snowflakeRestful{ Protocol: "https", Host: "abc.com", Port: 443, TokenAccessor: getSimpleTokenAccessor(), } loginURL, proofKey, err := getLoginURL(sr, user, callbackPort) assertNilF(t, err, "failed to get login URL") assertNotNilF(t, len(proofKey), "proofKey should be non-empty string") urlPtr, err := url.Parse(loginURL) assertNilF(t, err, "failed to parse the login URL") assertEqualF(t, urlPtr.Scheme, expectedScheme) assertEqualF(t, urlPtr.Host, expectedHost) assertEqualF(t, urlPtr.Path, consoleLoginRequestPath) assertStringContainsF(t, urlPtr.RawQuery, "login_name") assertStringContainsF(t, urlPtr.RawQuery, "browser_mode_redirect_port") assertStringContainsF(t, urlPtr.RawQuery, "proof_key") } type nonInteractiveSamlResponseProvider struct { t *testing.T } func (provider *nonInteractiveSamlResponseProvider) run(url string) error { go func() { resp, err := http.Get(url) assertNilF(provider.t, err) assertEqualE(provider.t, resp.StatusCode, http.StatusOK) }() return nil } ================================================ FILE: authokta.go ================================================ package gosnowflake import ( "bytes" "context" "encoding/json" "fmt" "github.com/snowflakedb/gosnowflake/v2/internal/errors" "html" "io" "net/http" "net/url" "strconv" "time" ) type authOKTARequest struct { Username string `json:"username"` Password string `json:"password"` } type authOKTAResponse struct { CookieToken string `json:"cookieToken"` SessionToken string `json:"sessionToken"` } /* authenticateBySAML authenticates a user by SAML SAML Authentication 1. query GS to obtain IDP token and SSO url 2. IMPORTANT Client side validation: validate both token url and sso url contains same prefix (protocol + host + port) as the given authenticator url. 
Explanation: This provides a way for the user to 'authenticate' the IDP it is sending his/her credentials to. Without such a check, the user could be coerced to provide credentials to an IDP impersonator. 3. query IDP token url to authenticate and retrieve access token 4. given access token, query IDP URL snowflake app to get SAML response 5. IMPORTANT Client side validation: validate the post back url come back with the SAML response contains the same prefix as the Snowflake's server url, which is the intended destination url to Snowflake. Explanation: This emulates the behavior of IDP initiated login flow in the user browser where the IDP instructs the browser to POST the SAML assertion to the specific SP endpoint. This is critical in preventing a SAML assertion issued to one SP from being sent to another SP. */ func authenticateBySAML( ctx context.Context, sr *snowflakeRestful, oktaURL *url.URL, application string, account string, user string, password string, disableSamlURLCheck ConfigBool, ) (samlResponse []byte, err error) { logger.WithContext(ctx).Info("step 1: query GS to obtain IDP token and SSO url") headers := make(map[string]string) headers[httpHeaderContentType] = headerContentTypeApplicationJSON headers[httpHeaderAccept] = headerContentTypeApplicationJSON headers[httpHeaderUserAgent] = userAgent clientEnvironment := newAuthRequestClientEnvironment() clientEnvironment.Application = application requestMain := authRequestData{ ClientAppID: clientType, ClientAppVersion: SnowflakeGoDriverVersion, AccountName: account, ClientEnvironment: clientEnvironment, Authenticator: oktaURL.String(), } authRequest := authRequest{ Data: requestMain, } params := &url.Values{} jsonBody, err := json.Marshal(authRequest) if err != nil { return nil, err } logger.WithContext(ctx).Infof("PARAMS for Auth: %v, %v", params, sr) respd, err := sr.FuncPostAuthSAML(ctx, sr, headers, jsonBody, sr.LoginTimeout) if err != nil { return nil, err } if !respd.Success { 
logger.WithContext(ctx).Error("Authentication FAILED") sr.TokenAccessor.SetTokens("", "", -1) code, err := strconv.Atoi(respd.Code) if err != nil { return nil, err } return nil, &SnowflakeError{ Number: code, SQLState: SQLStateConnectionRejected, Message: respd.Message, } } logger.WithContext(ctx).Info("step 2: validate Token and SSO URL has the same prefix as oktaURL") var tokenURL *url.URL var ssoURL *url.URL if tokenURL, err = url.Parse(respd.Data.TokenURL); err != nil { return nil, fmt.Errorf("failed to parse token URL. %v", respd.Data.TokenURL) } if ssoURL, err = url.Parse(respd.Data.SSOURL); err != nil { return nil, fmt.Errorf("failed to parse SSO URL. %v", respd.Data.SSOURL) } if !isPrefixEqual(oktaURL, ssoURL) || !isPrefixEqual(oktaURL, tokenURL) { return nil, &SnowflakeError{ Number: ErrCodeIdpConnectionError, SQLState: SQLStateConnectionRejected, Message: errors.ErrMsgIdpConnectionError, MessageArgs: []any{oktaURL, respd.Data.TokenURL, respd.Data.SSOURL}, } } logger.WithContext(ctx).Info("step 3: query IDP token url to authenticate and retrieve access token") jsonBody, err = json.Marshal(authOKTARequest{ Username: user, Password: password, }) if err != nil { return nil, err } respa, err := sr.FuncPostAuthOKTA(ctx, sr, headers, jsonBody, respd.Data.TokenURL, sr.LoginTimeout) if err != nil { return nil, err } logger.WithContext(ctx).Info("step 4: query IDP URL snowflake app to get SAML response") params = &url.Values{} params.Add("RelayState", "/some/deep/link") var oneTimeToken string if respa.SessionToken != "" { oneTimeToken = respa.SessionToken } else { oneTimeToken = respa.CookieToken } params.Add("onetimetoken", oneTimeToken) headers = make(map[string]string) headers[httpHeaderAccept] = "*/*" bd, err := sr.FuncGetSSO(ctx, sr, params, headers, respd.Data.SSOURL, sr.LoginTimeout) if err != nil { return nil, err } if disableSamlURLCheck == ConfigBoolFalse { logger.WithContext(ctx).Info("step 5: validate post_back_url matches Snowflake URL") tgtURL, err 
:= postBackURL(bd) if err != nil { return nil, err } fullURL := sr.getURL() logger.WithContext(ctx).Infof("tgtURL: %v, origURL: %v", tgtURL, fullURL) if !isPrefixEqual(tgtURL, fullURL) { return nil, &SnowflakeError{ Number: ErrCodeSSOURLNotMatch, SQLState: SQLStateConnectionRejected, Message: errors.ErrMsgSSOURLNotMatch, MessageArgs: []any{tgtURL, fullURL}, } } } return bd, nil } func postBackURL(htmlData []byte) (url *url.URL, err error) { idx0 := bytes.Index(htmlData, []byte("
` pbURL, err := postBackURL([]byte(c)) if err != nil { t.Fatalf("failed to get URL. err: %v, %v", err, c) } if pbURL.String() != "https://abc.com/" { t.Errorf("failed to get URL. got: %v, %v", pbURL, c) } c = `` _, err = postBackURL([]byte(c)) if err == nil { t.Fatalf("should have failed") } c = `
` _, err = postBackURL([]byte(c)) if err == nil { t.Fatalf("should have failed") } c = `")}, }, nil } func TestUnitPostAuthSAML(t *testing.T) { sr := &snowflakeRestful{ FuncPost: postTestError, TokenAccessor: getSimpleTokenAccessor(), } var err error _, err = postAuthSAML(context.Background(), sr, make(map[string]string), []byte{}, 0) if err == nil { t.Fatal("should have failed.") } sr.FuncPost = postTestAppBadGatewayError _, err = postAuthSAML(context.Background(), sr, make(map[string]string), []byte{}, 0) if err == nil { t.Fatal("should have failed.") } sr.FuncPost = postTestSuccessButInvalidJSON _, err = postAuthSAML(context.Background(), sr, make(map[string]string), []byte{0x12, 0x34}, 0) if err == nil { t.Fatalf("should have failed to post") } } func TestUnitPostAuthOKTA(t *testing.T) { sr := &snowflakeRestful{ FuncPost: postTestError, TokenAccessor: getSimpleTokenAccessor(), } var err error _, err = postAuthOKTA(context.Background(), sr, make(map[string]string), []byte{}, "hahah", 0) if err == nil { t.Fatal("should have failed.") } sr.FuncPost = postTestAppBadGatewayError _, err = postAuthOKTA(context.Background(), sr, make(map[string]string), []byte{}, "hahah", 0) if err == nil { t.Fatal("should have failed.") } sr.FuncPost = postTestSuccessButInvalidJSON _, err = postAuthOKTA(context.Background(), sr, make(map[string]string), []byte{0x12, 0x34}, "haha", 0) if err == nil { t.Fatal("should have failed to run post request after the renewal") } } func TestUnitGetSSO(t *testing.T) { sr := &snowflakeRestful{ FuncGet: getTestError, TokenAccessor: getSimpleTokenAccessor(), } var err error _, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "hahah", 0) if err == nil { t.Fatal("should have failed.") } sr.FuncGet = getTestAppBadGatewayError _, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "hahah", 0) if err == nil { t.Fatal("should have failed.") } sr.FuncGet = getTestHTMLSuccess _, err = 
getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "hahah", 0) if err != nil { t.Fatalf("failed to get HTML content. err: %v", err) } _, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "invalid!@url$%^", 0) if err == nil { t.Fatal("should have failed to parse URL.") } } func postAuthSAMLError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{}, errors.New("failed to get SAML response") } func postAuthSAMLAuthFail(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: false, Message: "SAML auth failed", }, nil } func postAuthSAMLAuthFailWithCode(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: false, Code: strconv.Itoa(ErrCodeIdpConnectionError), Message: "SAML auth failed", }, nil } func postAuthSAMLAuthSuccessButInvalidURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: true, Message: "", Data: authResponseMain{ TokenURL: "https://1abc.com/token", SSOURL: "https://2abc.com/sso", }, }, nil } func postAuthSAMLAuthSuccessButInvalidTokenURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: true, Message: "", Data: authResponseMain{ TokenURL: "invalid!@url$%^", SSOURL: "https://abc.com/sso", }, }, nil } func postAuthSAMLAuthSuccessButInvalidSSOURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: true, Message: "", Data: authResponseMain{ TokenURL: "https://abc.com/token", SSOURL: "invalid!@url$%^", }, }, nil } func postAuthSAMLAuthSuccess(_ context.Context, _ 
*snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: true, Message: "", Data: authResponseMain{ TokenURL: "https://abc.com/token", SSOURL: "https://abc.com/sso", }, }, nil } func postAuthOKTAError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ string, _ time.Duration) (*authOKTAResponse, error) { return &authOKTAResponse{}, errors.New("failed to get SAML response") } func postAuthOKTASuccess(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ string, _ time.Duration) (*authOKTAResponse, error) { return &authOKTAResponse{}, nil } func getSSOError(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) { return []byte{}, errors.New("failed to get SSO html") } func getSSOSuccessButInvalidURL(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) { return []byte(``), nil } func getSSOSuccess(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) { return []byte(`
`), nil } func getSSOSuccessButWrongPrefixURL(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) { return []byte(`
), nil
}

// TestUnitAuthenticateBySAML drives authenticateBySAML through every failure
// mode of the SAML/Okta flow by swapping the mocked round-trip funcs on the
// snowflakeRestful struct, then through the success path, and finally through
// the SSO-URL-prefix-mismatch check.
func TestUnitAuthenticateBySAML(t *testing.T) {
	authenticator := &url.URL{
		Scheme: "https",
		Host:   "abc.com",
	}
	application := "testapp"
	account := "testaccount"
	user := "u"
	password := "p"
	sr := &snowflakeRestful{
		Protocol:         "https",
		Host:             "abc.com",
		Port:             443,
		FuncPostAuthSAML: postAuthSAMLError,
		TokenAccessor:    getSimpleTokenAccessor(),
	}
	var err error
	// transport-level failure from the auth-SAML endpoint
	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
	assertEqualE(t, err.Error(), "failed to get SAML response")
	// auth failure with an empty (unparseable) error code
	sr.FuncPostAuthSAML = postAuthSAMLAuthFail
	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
	assertEqualE(t, err.Error(), "strconv.Atoi: parsing \"\": invalid syntax")
	// auth failure carrying a numeric code -> SnowflakeError
	sr.FuncPostAuthSAML = postAuthSAMLAuthFailWithCode
	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
	driverErr, ok := err.(*SnowflakeError)
	assertTrueF(t, ok, "should be a SnowflakeError")
	assertEqualE(t, driverErr.Number, ErrCodeIdpConnectionError)
	sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL
	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
	driverErr, ok = err.(*SnowflakeError)
	assertTrueF(t, ok, "should be a SnowflakeError")
	assertEqualE(t, driverErr.Number, ErrCodeIdpConnectionError)
	sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidTokenURL
	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
	assertEqualE(t, err.Error(), "failed to parse token URL. invalid!@url$%^")
	sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidSSOURL
	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
	assertEqualE(t, err.Error(), "failed to parse SSO URL. invalid!@url$%^")
	// SAML step succeeds; now fail the Okta auth step
	sr.FuncPostAuthSAML = postAuthSAMLAuthSuccess
	sr.FuncPostAuthOKTA = postAuthOKTAError
	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
	assertNotNilF(t, err, "should have failed at FuncPostAuthOKTA.")
	assertEqualE(t, err.Error(), "failed to get SAML response")
	// Okta auth succeeds; now fail the SSO HTML fetch
	sr.FuncPostAuthOKTA = postAuthOKTASuccess
	sr.FuncGetSSO = getSSOError
	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
	assertNotNilF(t, err, "should have failed at FuncGetSSO.")
	assertEqualE(t, err.Error(), "failed to get SSO html")
	sr.FuncGetSSO = getSSOSuccessButInvalidURL
	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
	assertNotNilF(t, err, "should have failed at FuncGetSSO.")
	assertHasPrefixE(t, err.Error(), "failed to find action field in HTML response")
	// fully successful flow
	sr.FuncGetSSO = getSSOSuccess
	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
	assertNilF(t, err, "should have succeeded at FuncGetSSO.")
	// SSO URL whose prefix does not match the expected one is rejected
	sr.FuncGetSSO = getSSOSuccessButWrongPrefixURL
	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
	assertNotNilF(t, err, "should have failed at FuncGetSSO.")
	driverErr, ok = err.(*SnowflakeError)
	assertTrueF(t, ok, "should be a SnowflakeError")
	assertEqualE(t, driverErr.Number, ErrCodeSSOURLNotMatch)
}

// TestDisableSamlURLCheck verifies that the SSO-URL-prefix check can be
// turned off (ConfigBoolTrue) and stays on by default (ConfigBoolFalse).
func TestDisableSamlURLCheck(t *testing.T) {
	authenticator := &url.URL{
		Scheme: "https",
		Host: 
"abc.com", } application := "testapp" account := "testaccount" user := "u" password := "p" sr := &snowflakeRestful{ Protocol: "https", Host: "abc.com", Port: 443, FuncPostAuthSAML: postAuthSAMLAuthSuccess, FuncPostAuthOKTA: postAuthOKTASuccess, FuncGetSSO: getSSOSuccessButWrongPrefixURL, TokenAccessor: getSimpleTokenAccessor(), } var err error // Test for disabled SAML URL check _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolTrue) assertNilF(t, err, "SAML URL check should have disabled.") // Test for enabled SAML URL check _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse) assertNotNilF(t, err, "should have failed at FuncGetSSO.") driverErr, ok := err.(*SnowflakeError) assertTrueF(t, ok, "should be a SnowflakeError") assertEqualE(t, driverErr.Number, ErrCodeSSOURLNotMatch) } ================================================ FILE: azure_storage_client.go ================================================ package gosnowflake import ( "bytes" "cmp" "context" "crypto/md5" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "os" "strings" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" ) type snowflakeAzureClient struct { cfg *Config telemetry *snowflakeTelemetry } type azureLocation struct { containerName string path string } type azureAPI interface { UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) UploadFile(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) DownloadFile(ctx context.Context, file *os.File, o 
*blob.DownloadFileOptions) (int64, error)
	DownloadStream(ctx context.Context, o *blob.DownloadStreamOptions) (azblob.DownloadStreamResponse, error)
	GetProperties(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
}

// createClient builds an azblob client for the stage described by info; the
// account URL embeds the SAS token from info.Creds. The boolean parameter
// (S3 transfer-acceleration flag) is unused for Azure.
func (util *snowflakeAzureClient) createClient(info *execResponseStageInfo, _ bool, telemetry *snowflakeTelemetry) (cloudClient, error) {
	sasToken := info.Creds.AzureSasToken
	u, err := url.Parse(fmt.Sprintf("https://%s.%s/%s%s", info.StorageAccount, info.EndPoint, info.Path, sasToken))
	if err != nil {
		return nil, err
	}
	transport, err := newTransportFactory(util.cfg, telemetry).createTransport(transportConfigFor(transportTypeCloudProvider))
	if err != nil {
		return nil, err
	}
	client, err := azblob.NewClientWithNoCredential(u.String(), &azblob.ClientOptions{
		ClientOptions: azcore.ClientOptions{
			// retry aggressively; individual attempts are cheap HTTPS calls
			Retry: policy.RetryOptions{
				MaxRetries: 60,
				RetryDelay: 2 * time.Second,
			},
			Transport: &http.Client{
				Transport: transport,
			},
		},
	})
	if err != nil {
		return nil, err
	}
	return client, nil
}

// cloudUtil implementation
// getFileHeader fetches blob properties for filename under the stage location
// and converts them into a fileHeader (digest + encryption metadata). It also
// updates meta.resStatus: notFoundFile on BlobNotFound, renewToken on 403,
// errStatus on any other failure, uploaded on success.
func (util *snowflakeAzureClient) getFileHeader(ctx context.Context, meta *fileMetadata, filename string) (*fileHeader, error) {
	client, ok := meta.client.(*azblob.Client)
	if !ok {
		return nil, errors.New("failed to parse client to azblob.Client")
	}
	azureLoc, err := util.extractContainerNameAndPath(meta.stageInfo.Location)
	if err != nil {
		return nil, err
	}
	path := azureLoc.path + strings.TrimLeft(filename, "/")
	containerClient, err := createContainerClient(client.URL(), util.cfg, util.telemetry)
	if err != nil {
		return nil, &SnowflakeError{
			Message: "failed to create container client",
		}
	}
	var blobClient azureAPI
	blobClient = containerClient.NewBlockBlobClient(path)
	// for testing only
	if meta.mockAzureClient != nil {
		blobClient = meta.mockAzureClient
	}
	resp, err := withCloudStorageTimeout(ctx, util.cfg, func(ctx context.Context) (blob.GetPropertiesResponse, error) {
		return blobClient.GetProperties(ctx, &blob.GetPropertiesOptions{
			AccessConditions: &blob.AccessConditions{},
			CPKInfo:          &blob.CPKInfo{},
		})
	})
	if err != nil {
		var se *azcore.ResponseError
		if errors.As(err, &se) {
			if se.ErrorCode == string(bloberror.BlobNotFound) {
				meta.resStatus = notFoundFile
				return nil, errors.New("could not find file")
			} else if se.StatusCode == 403 {
				meta.resStatus = renewToken
				return nil, errors.New("received 403, attempting to renew")
			}
		}
		meta.resStatus = errStatus
		meta.lastError = err
		return nil, fmt.Errorf("unexpected error while retrieving file header from azure. %w", err)
	}
	meta.resStatus = uploaded
	// Azure metadata keys can vary in case; normalize before lookups.
	metadata := withLowerKeys(resp.Metadata)
	var encData encryptionData
	_, ok = metadata["encryptiondata"]
	if ok {
		if err = json.Unmarshal([]byte(*metadata["encryptiondata"]), &encData); err != nil {
			return nil, err
		}
	}
	matdesc, ok := metadata["matdesc"]
	if !ok {
		// matdesc is not in response, use empty string
		matdesc = new(string)
	}
	encryptionMetadata := encryptMetadata{
		encData.WrappedContentKey.EncryptionKey,
		encData.ContentEncryptionIV,
		*matdesc,
	}
	digest, ok := metadata["sfcdigest"]
	if !ok {
		// sfcdigest is not in response, use empty string
		digest = new(string)
	}
	return &fileHeader{
		*digest,
		// NOTE(review): this is the number of metadata entries, not the blob's
		// content length — confirm callers do not use it as a byte size.
		int64(len(metadata)),
		&encryptionMetadata,
	}, nil
}

// cloudUtil implementation
// uploadFile uploads dataFile (or meta.srcStream when set) to the stage as
// meta.dstFileName, attaching sfcdigest / encryptiondata / matdesc metadata
// and a Content-MD5 header. On a 403 whose body indicates an expired SAS
// token it marks renewToken; on other azcore.ResponseErrors it marks
// needRetry; anything else is errStatus.
func (util *snowflakeAzureClient) uploadFile(
	ctx context.Context,
	dataFile string,
	meta *fileMetadata,
	maxConcurrency int,
	multiPartThreshold int64) error {
	azureMeta := map[string]*string{
		"sfcdigest": &meta.sha256Digest,
	}
	if meta.encryptMeta != nil {
		// Mirror the client-side-encryption metadata layout used by other
		// Snowflake drivers.
		ed := &encryptionData{
			EncryptionMode: "FullBlob",
			WrappedContentKey: contentKey{
				"symmKey1",
				meta.encryptMeta.key,
				"AES_CBC_256",
			},
			EncryptionAgent: encryptionAgent{
				"1.0",
				"AES_CBC_128",
			},
			ContentEncryptionIV: meta.encryptMeta.iv,
			KeyWrappingMetadata: keyMetadata{
				"Java 5.3.0",
			},
		}
		metadata, err := json.Marshal(ed)
		if err != nil {
			return err
		}
		encryptionMetadata := string(metadata)
		azureMeta["encryptiondata"] = &encryptionMetadata
		azureMeta["matdesc"] = &meta.encryptMeta.matdesc
	}
	azureLoc, err := util.extractContainerNameAndPath(meta.stageInfo.Location)
	if err != nil {
		return err
	}
	path := azureLoc.path + strings.TrimLeft(meta.dstFileName, "/")
	client, ok := meta.client.(*azblob.Client)
	if !ok {
		return &SnowflakeError{
			Message: "failed to cast to azure client",
		}
	}
	containerClient, err := createContainerClient(client.URL(), util.cfg, util.telemetry)
	if err != nil {
		return &SnowflakeError{
			Message: "failed to create container client",
		}
	}
	var blobClient azureAPI
	blobClient = containerClient.NewBlockBlobClient(path)
	// for testing only
	if meta.mockAzureClient != nil {
		blobClient = meta.mockAzureClient
	}
	if meta.srcStream != nil {
		// in-memory upload path: prefer the processed (realSrcStream) bytes
		uploadSrc := cmp.Or(meta.realSrcStream, meta.srcStream)
		data := uploadSrc.Bytes()
		contentMD5 := md5.Sum(data)
		_, err = withCloudStorageTimeout(ctx, util.cfg, func(ctx context.Context) (azblob.UploadStreamResponse, error) {
			return blobClient.UploadStream(ctx, bytes.NewReader(data), &azblob.UploadStreamOptions{
				BlockSize: int64(len(data)),
				Metadata:  azureMeta,
				HTTPHeaders: &blob.HTTPHeaders{
					BlobContentMD5: contentMD5[:],
				},
			})
		})
	} else {
		var f *os.File
		f, err = os.Open(dataFile)
		if err != nil {
			return fmt.Errorf("failed to open file: %w", err)
		}
		defer func() {
			if err = f.Close(); err != nil {
				logger.Warnf("Failed to close the %v file: %v", dataFile, err)
			}
		}()
		// pre-compute MD5 (rewinds f) since Azure does not do it for block blobs
		var contentMD5 []byte
		contentMD5, err = computeMD5ForFile(f)
		if err != nil {
			return fmt.Errorf("failed to compute MD5: %w", err)
		}
		contentType := "application/octet-stream"
		contentEncoding := "utf-8"
		blobOptions := &azblob.UploadFileOptions{
			HTTPHeaders: &blob.HTTPHeaders{
				BlobContentType:     &contentType,
				BlobContentEncoding: &contentEncoding,
				BlobContentMD5:      contentMD5,
			},
			Metadata:    azureMeta,
			Concurrency: uint16(maxConcurrency),
		}
		if meta.options.putAzureCallback != nil {
			blobOptions.Progress = meta.options.putAzureCallback.call
		}
		_, err = withCloudStorageTimeout(ctx, util.cfg, func(ctx context.Context) (azblob.UploadFileResponse, error) {
			return blobClient.UploadFile(ctx, f, blobOptions)
		})
	}
	if err != nil {
		var se *azcore.ResponseError
		if errors.As(err, &se) {
			if se.StatusCode == 403 && util.detectAzureTokenExpireError(se.RawResponse) {
				meta.resStatus = renewToken
			} else {
				meta.resStatus = needRetry
				meta.lastError = err
			}
			return err
		}
		meta.resStatus = errStatus
		return err
	}
	meta.dstFileSize = meta.uploadSize
	meta.resStatus = uploaded
	return nil
}

// cloudUtil implementation
// nativeDownloadFile downloads meta.srcFileName from the stage either into
// meta.dstStream (streaming GET) or into fullDstFileName on local disk.
func (util *snowflakeAzureClient) nativeDownloadFile(
	ctx context.Context,
	meta *fileMetadata,
	fullDstFileName string,
	maxConcurrency int64,
	partSize int64) error {
	azureLoc, err := util.extractContainerNameAndPath(meta.stageInfo.Location)
	if err != nil {
		return err
	}
	path := azureLoc.path + strings.TrimLeft(meta.srcFileName, "/")
	logger.Debugf("AZURE CLIENT: Send Get Request to the bucket: %v, file: %v", meta.stageInfo.Location, meta.srcFileName)
	client, ok := meta.client.(*azblob.Client)
	if !ok {
		return &SnowflakeError{
			Message: "failed to cast to azure client",
		}
	}
	containerClient, err := createContainerClient(client.URL(), util.cfg, util.telemetry)
	if err != nil {
		return &SnowflakeError{
			Message: "failed to create container client",
		}
	}
	var blobClient azureAPI
	blobClient = containerClient.NewBlockBlobClient(path)
	// for testing only
	if meta.mockAzureClient != nil {
		blobClient = meta.mockAzureClient
	}
	if isFileGetStream(ctx) {
		blobDownloadResponse, err := withCloudStorageTimeout(ctx, util.cfg, func(ctx context.Context) (azblob.DownloadStreamResponse, error) {
			return blobClient.DownloadStream(ctx, &azblob.DownloadStreamOptions{})
		})
		if err != nil {
			return err
		}
		// NOTE(review): the retry reader uses context.Background() rather than
		// ctx — presumably so reads are not cut off by the cloud-storage
		// timeout that wrapped the initial request; confirm this is intended.
		retryReader := blobDownloadResponse.NewRetryReader(context.Background(), &azblob.RetryReaderOptions{})
		defer func() {
			if err = retryReader.Close(); err != nil {
				logger.Warnf("failed to close the Azure reader: %v", err)
			}
		}()
		_, err = meta.dstStream.ReadFrom(retryReader)
		if err != nil {
			return err
		}
	} else {
		f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
		if err != nil {
			return fmt.Errorf("failed to open file: %w", err)
		}
		defer func() {
			if err = f.Close(); err != nil {
				logger.Warnf("failed to close the %v file: %v", fullDstFileName, err)
			}
		}()
		_, err = withCloudStorageTimeout(ctx, util.cfg, func(ctx context.Context) (any, error) {
			return blobClient.DownloadFile(
				ctx, f,
				&azblob.DownloadFileOptions{
					Concurrency: uint16(maxConcurrency),
					BlockSize:   int64Max(partSize, blob.DefaultDownloadBlockSize),
				})
		})
		if err != nil {
			return err
		}
	}
	meta.resStatus = downloaded
	return nil
}

// extractContainerNameAndPath splits a stage location such as
// "container/dir/sub/" into the container name and the remaining blob path,
// forcing a trailing "/" onto a non-empty path.
func (util *snowflakeAzureClient) extractContainerNameAndPath(location string) (*azureLocation, error) {
	stageLocation, err := expandUser(location)
	if err != nil {
		return nil, err
	}
	containerName := stageLocation
	path := ""
	if strings.Contains(stageLocation, "/") {
		containerName = stageLocation[:strings.Index(stageLocation, "/")]
		path = stageLocation[strings.Index(stageLocation, "/")+1:]
		if path != "" && !strings.HasSuffix(path, "/") {
			path += "/"
		}
	}
	return &azureLocation{containerName, path}, nil
}

// detectAzureTokenExpireError reports whether a 403 response body carries one
// of the messages Azure uses for an expired/invalid SAS token. It consumes
// resp.Body.
func (util *snowflakeAzureClient) detectAzureTokenExpireError(resp *http.Response) bool {
	if resp.StatusCode != 403 {
		return false
	}
	azureErr, err := io.ReadAll(resp.Body)
	if err != nil {
		return false
	}
	errStr := string(azureErr)
	return strings.Contains(errStr, "Signature not valid in the specified time frame") ||
		strings.Contains(errStr, "Server failed to authenticate the request")
}

// computeMD5ForFile reads a file to compute its MD5 digest, then seeks back to
// the start so the file can be read again for upload. Azure does not compute
// Content-MD5 for multi-part (block blob) uploads, so we must provide it.
// computeMD5ForFile streams f through an MD5 hash and rewinds f to the start
// so the caller can re-read it for the actual upload.
func computeMD5ForFile(f *os.File) ([]byte, error) {
	h := md5.New()
	if _, err := io.Copy(h, f); err != nil {
		return nil, err
	}
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		return nil, err
	}
	return h.Sum(nil), nil
}

// createContainerClient builds a container-scoped client for clientURL using
// the driver's shared transport factory (cfg supplies proxy/TLS settings).
func createContainerClient(clientURL string, cfg *Config, telemetry *snowflakeTelemetry) (*container.Client, error) {
	transport, err := newTransportFactory(cfg, telemetry).createTransport(transportConfigFor(transportTypeCloudProvider))
	if err != nil {
		return nil, err
	}
	return container.NewClientWithNoCredential(clientURL, &container.ClientOptions{ClientOptions: azcore.ClientOptions{
		Transport: &http.Client{
			Transport: transport,
		},
	}})
}

================================================ FILE: azure_storage_client_test.go ================================================

package gosnowflake

import (
	"bytes"
	"context"
	"crypto/md5"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"os"
	"path"
	"testing"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
)

// TestExtractContainerNameAndPath checks container/path splitting for a range
// of stage locations, including repeated slashes.
func TestExtractContainerNameAndPath(t *testing.T) {
	azureUtil := new(snowflakeAzureClient)
	testcases := []tcBucketPath{
		{"sfc-eng-regression/test_sub_dir/", "sfc-eng-regression", "test_sub_dir/"},
		{"sfc-eng-regression/dir/test_stg/test_sub_dir/", "sfc-eng-regression", "dir/test_stg/test_sub_dir/"},
		{"sfc-eng-regression/", "sfc-eng-regression", ""},
		{"sfc-eng-regression//", "sfc-eng-regression", "/"},
		{"sfc-eng-regression///", "sfc-eng-regression", "//"},
	}
	for _, test := range testcases {
		t.Run(test.in, func(t *testing.T) {
			azureLoc, err := azureUtil.extractContainerNameAndPath(test.in)
			if err != nil {
				t.Error(err)
			}
			if azureLoc.containerName != test.bucket {
				t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.bucket, azureLoc.containerName)
			}
			if azureLoc.path != test.path {
				t.Errorf("failed. 
in: %v, expected: %v, got: %v", test.in, test.path, azureLoc.path)
			}
		})
	}
}

// TestUnitDetectAzureTokenExpireError covers both recognized token-expiry
// messages, an unrecognized body, a non-matching message, and a non-403
// status code.
func TestUnitDetectAzureTokenExpireError(t *testing.T) {
	azureUtil := new(snowflakeAzureClient)
	dd := &execResponseData{}
	invalidSig := &execResponse{
		Data:    *dd,
		Message: "Signature not valid in the specified time frame",
		Code:    "403",
		Success: true,
	}
	ba, err := json.Marshal(invalidSig)
	if err != nil {
		panic(err)
	}
	resp := &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}}
	if !azureUtil.detectAzureTokenExpireError(resp) {
		t.Fatal("expected token expired")
	}
	invalidAuth := &execResponse{
		Data:    *dd,
		Message: "Server failed to authenticate the request",
		Code:    "403",
		Success: true,
	}
	ba, err = json.Marshal(invalidAuth)
	if err != nil {
		panic(err)
	}
	resp = &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}}
	if !azureUtil.detectAzureTokenExpireError(resp) {
		t.Fatal("expected token expired")
	}
	// body with no recognizable message -> not treated as token expiry
	resp = &http.Response{
		StatusCode: http.StatusForbidden,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
	}
	if azureUtil.detectAzureTokenExpireError(resp) {
		t.Fatal("invalid body")
	}
	invalidMessage := &execResponse{
		Data:    *dd,
		Message: "unauthorized",
		Code:    "403",
		Success: true,
	}
	ba, err = json.Marshal(invalidMessage)
	if err != nil {
		panic(err)
	}
	resp = &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}}
	if azureUtil.detectAzureTokenExpireError(resp) {
		t.Fatal("incorrect message")
	}
	resp = &http.Response{
		StatusCode: http.StatusOK,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}}}
	if azureUtil.detectAzureTokenExpireError(resp) {
		t.Fatal("status code is success. 
expected false.") } } type azureObjectAPIMock struct { UploadStreamFunc func(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) UploadFileFunc func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) DownloadFileFunc func(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error) DownloadStreamFunc func(ctx context.Context, o *blob.DownloadStreamOptions) (azblob.DownloadStreamResponse, error) GetPropertiesFunc func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) } func (c *azureObjectAPIMock) UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) { return c.UploadStreamFunc(ctx, body, o) } func (c *azureObjectAPIMock) UploadFile(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) { return c.UploadFileFunc(ctx, file, o) } func (c *azureObjectAPIMock) GetProperties(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) { return c.GetPropertiesFunc(ctx, o) } func (c *azureObjectAPIMock) DownloadFile(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error) { return c.DownloadFileFunc(ctx, file, o) } func (c *azureObjectAPIMock) DownloadStream(ctx context.Context, o *blob.DownloadStreamOptions) (azblob.DownloadStreamResponse, error) { return c.DownloadStreamFunc(ctx, o) } func TestUploadFileWithAzureUploadFailedError(t *testing.T) { info := execResponseStageInfo{ Location: "azblob/storage/users/456/", LocationType: "AZURE", } initialParallel := int64(100) dir, err := os.Getwd() if err != nil { t.Error(err) } encMat := snowflakeFileEncryption{ QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", QueryID: "01abc874-0406-1bf0-0000-53b10668e056", SMKID: 92019681909886, } azureCli, err := new(snowflakeAzureClient).createClient(&info, false, 
&snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "AZURE", noSleepingTime: true, parallel: initialParallel, client: azureCli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), encryptionMaterial: &encMat, encryptMeta: testEncryptionMeta(), overwrite: true, dstCompressionType: compressionTypes["GZIP"], options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockAzureClient: &azureObjectAPIMock{ UploadFileFunc: func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) { return azblob.UploadFileResponse{}, errors.New("unexpected error uploading file") }, }, sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName fi, err := os.Stat(uploadMeta.srcFileName) if err != nil { t.Error(err) } uploadMeta.uploadSize = fi.Size() err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta) if err == nil { t.Fatal("should have failed") } } func TestUploadStreamWithAzureUploadFailedError(t *testing.T) { info := execResponseStageInfo{ Location: "azblob/storage/users/456/", LocationType: "AZURE", } initialParallel := int64(100) src := []byte{65, 66, 67} encMat := snowflakeFileEncryption{ QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", QueryID: "01abc874-0406-1bf0-0000-53b10668e056", SMKID: 92019681909886, } azureCli, err := new(snowflakeAzureClient).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "AZURE", noSleepingTime: true, parallel: initialParallel, client: azureCli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcStream: bytes.NewBuffer(src), encryptionMaterial: &encMat, encryptMeta: testEncryptionMeta(), overwrite: true, 
dstCompressionType: compressionTypes["GZIP"], options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockAzureClient: &azureObjectAPIMock{ UploadStreamFunc: func(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) { return azblob.UploadStreamResponse{}, errors.New("unexpected error uploading file") }, }, sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } uploadMeta.realSrcStream = uploadMeta.srcStream err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta) if err == nil { t.Fatal("should have failed") } } func TestUploadFileWithAzureUploadTokenExpired(t *testing.T) { info := execResponseStageInfo{ Location: "azblob/storage/users/456/", LocationType: "AZURE", } initialParallel := int64(100) dir, err := os.Getwd() if err != nil { t.Error(err) } dd := &execResponseData{} invalidSig := &execResponse{ Data: *dd, Message: "Signature not valid in the specified time frame", Code: "403", Success: true, } ba, err := json.Marshal(invalidSig) if err != nil { panic(err) } azureCli, err := new(snowflakeAzureClient).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "AZURE", noSleepingTime: true, parallel: initialParallel, client: azureCli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), encryptMeta: testEncryptionMeta(), overwrite: true, dstCompressionType: compressionTypes["GZIP"], options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockAzureClient: &azureObjectAPIMock{ UploadFileFunc: func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) { return azblob.UploadFileResponse{}, &azcore.ResponseError{ ErrorCode: "12345", StatusCode: 403, RawResponse: &http.Response{StatusCode: 
http.StatusForbidden, Body: &fakeResponseBody{body: ba}}, } }, }, sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName fi, err := os.Stat(uploadMeta.srcFileName) if err != nil { t.Error(err) } uploadMeta.uploadSize = fi.Size() err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta) if err != nil { t.Fatal(err) } if uploadMeta.resStatus != renewToken { t.Fatalf("expected %v result status, got: %v", renewToken, uploadMeta.resStatus) } } func TestUploadFileWithAzureUploadNeedsRetry(t *testing.T) { info := execResponseStageInfo{ Location: "azblob/storage/users/456/", LocationType: "AZURE", } initialParallel := int64(100) dir, err := os.Getwd() if err != nil { t.Error(err) } dd := &execResponseData{} invalidSig := &execResponse{ Data: *dd, Message: "Server Error", Code: "500", Success: true, } ba, err := json.Marshal(invalidSig) if err != nil { panic(err) } azureCli, err := new(snowflakeAzureClient).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "AZURE", noSleepingTime: false, parallel: initialParallel, client: azureCli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), encryptMeta: testEncryptionMeta(), overwrite: true, dstCompressionType: compressionTypes["GZIP"], options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockAzureClient: &azureObjectAPIMock{ UploadFileFunc: func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) { return azblob.UploadFileResponse{}, &azcore.ResponseError{ ErrorCode: "12345", StatusCode: 500, RawResponse: &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}}, } }, }, sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } 
uploadMeta.realSrcFileName = uploadMeta.srcFileName fi, err := os.Stat(uploadMeta.srcFileName) if err != nil { t.Error(err) } uploadMeta.uploadSize = fi.Size() err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta) if err == nil { t.Fatal("should have raised an error") } if uploadMeta.resStatus != needRetry { t.Fatalf("expected %v result status, got: %v", needRetry, uploadMeta.resStatus) } } func TestDownloadOneFileToAzureFailed(t *testing.T) { info := execResponseStageInfo{ Location: "azblob/rwyitestacco/users/1234/", LocationType: "AZURE", } dir, err := os.Getwd() if err != nil { t.Error(err) } azureCli, err := new(snowflakeAzureClient).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } downloadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "AZURE", noSleepingTime: true, client: azureCli, stageInfo: &info, dstFileName: "data1.txt.gz", overwrite: true, srcFileName: "data1.txt.gz", localLocation: dir, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockAzureClient: &azureObjectAPIMock{ DownloadFileFunc: func(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error) { return 0, errors.New("unexpected error uploading file") }, GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) { return blob.GetPropertiesResponse{}, nil }, }, sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } err = new(remoteStorageUtil).downloadOneFile(context.Background(), &downloadMeta) if err == nil { t.Error("should have raised an error") } } func TestGetFileHeaderErrorStatus(t *testing.T) { ctx := context.Background() info := execResponseStageInfo{ Location: "azblob/teststage/users/34/", LocationType: "AZURE", } azureCli, err := new(snowflakeAzureClient).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } meta := fileMetadata{ client: azureCli, stageInfo: 
&info, mockAzureClient: &azureObjectAPIMock{ GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) { return blob.GetPropertiesResponse{}, errors.New("failed to retrieve headers") }, }, sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } if header, err := (&snowflakeAzureClient{cfg: &Config{}}).getFileHeader(ctx, &meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } if meta.resStatus != errStatus { t.Fatalf("expected %v result status, got: %v", errStatus, meta.resStatus) } dd := &execResponseData{} invalidSig := &execResponse{ Data: *dd, Message: "Not Found", Code: "404", Success: true, } ba, err := json.Marshal(invalidSig) if err != nil { panic(err) } meta = fileMetadata{ client: azureCli, stageInfo: &info, mockAzureClient: &azureObjectAPIMock{ GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) { return blob.GetPropertiesResponse{}, &azcore.ResponseError{ ErrorCode: "BlobNotFound", StatusCode: 404, RawResponse: &http.Response{StatusCode: http.StatusNotFound, Body: &fakeResponseBody{body: ba}}, } }, }, sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } if header, err := (&snowflakeAzureClient{cfg: &Config{}}).getFileHeader(ctx, &meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } if meta.resStatus != notFoundFile { t.Fatalf("expected %v result status, got: %v", errStatus, meta.resStatus) } invalidSig = &execResponse{ Data: *dd, Message: "Unauthorized", Code: "403", Success: true, } ba, err = json.Marshal(invalidSig) if err != nil { panic(err) } meta.mockAzureClient = &azureObjectAPIMock{ GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) { return blob.GetPropertiesResponse{}, &azcore.ResponseError{ StatusCode: 403, RawResponse: 
&http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}}, } }, } if header, err := (&snowflakeAzureClient{cfg: &Config{}}).getFileHeader(ctx, &meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } if meta.resStatus != renewToken { t.Fatalf("expected %v result status, got: %v", renewToken, meta.resStatus) } } func TestUploadFileToAzureClientCastFail(t *testing.T) { info := execResponseStageInfo{ Location: "azblob/rwyi-testacco/users/9220/", LocationType: "AZURE", } dir, err := os.Getwd() if err != nil { t.Error(err) } s3Cli, err := new(snowflakeS3Client).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "AZURE", noSleepingTime: false, client: s3Cli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), encryptMeta: testEncryptionMeta(), overwrite: true, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName fi, err := os.Stat(uploadMeta.srcFileName) if err != nil { t.Error(err) } uploadMeta.uploadSize = fi.Size() err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta) if err == nil { t.Fatal("should have failed") } } func TestUploadFileToAzureSetsBlobContentMD5(t *testing.T) { info := execResponseStageInfo{ Location: "azblob/storage/users/456/", LocationType: "AZURE", } dir, err := os.Getwd() if err != nil { t.Fatal(err) } azureCli, err := new(snowflakeAzureClient).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Fatal(err) } srcFile := path.Join(dir, "/test_data/put_get_1.txt") srcContent, err := os.ReadFile(srcFile) if err != nil { t.Fatal(err) } expectedMD5 := md5.Sum(srcContent) var capturedMD5 []byte uploadMeta := 
fileMetadata{
	name:              "data1.txt.gz",
	stageLocationType: "AZURE",
	noSleepingTime:    true,
	parallel:          1,
	client:            azureCli,
	sha256Digest:      "123456789abcdef",
	stageInfo:         &info,
	dstFileName:       "data1.txt.gz",
	srcFileName:       srcFile,
	// test-only encryption material; not a real production key
	encryptionMaterial: &snowflakeFileEncryption{QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", QueryID: "01abc874-0406-1bf0-0000-53b10668e056", SMKID: 92019681909886},
	encryptMeta:        testEncryptionMeta(),
	overwrite:          true,
	dstCompressionType: compressionTypes["GZIP"],
	options:            &SnowflakeFileTransferOptions{MultiPartThreshold: multiPartThreshold},
	// mock Azure client captures the Content-MD5 header the driver attaches
	mockAzureClient: &azureObjectAPIMock{
		UploadFileFunc: func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) {
			if o.HTTPHeaders != nil {
				capturedMD5 = o.HTTPHeaders.BlobContentMD5
			}
			return azblob.UploadFileResponse{}, nil
		},
	},
	sfa: &snowflakeFileTransferAgent{sc: &snowflakeConn{cfg: &Config{}}},
}
uploadMeta.realSrcFileName = uploadMeta.srcFileName
// uploadSize must reflect the actual on-disk size of the source file
fi, err := os.Stat(uploadMeta.srcFileName)
if err != nil {
	t.Fatal(err)
}
uploadMeta.uploadSize = fi.Size()
err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta)
if err != nil {
	t.Fatal(err)
}
if capturedMD5 == nil {
	t.Fatal("expected BlobContentMD5 to be set, got nil")
}
if !bytes.Equal(capturedMD5, expectedMD5[:]) {
	t.Fatalf("BlobContentMD5 mismatch: got %x, want %x", capturedMD5, expectedMD5[:])
}
}

// TestUploadStreamToAzureSetsBlobContentMD5 verifies that an in-memory stream
// upload to Azure sets the BlobContentMD5 HTTP header to the MD5 digest of
// the source bytes, using a mocked Azure client to capture the header.
func TestUploadStreamToAzureSetsBlobContentMD5(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "azblob/storage/users/456/",
		LocationType: "AZURE",
	}
	azureCli, err := new(snowflakeAzureClient).createClient(&info, false, &snowflakeTelemetry{})
	if err != nil {
		t.Fatal(err)
	}
	src := []byte{65, 66, 67} // ASCII "ABC"
	expectedMD5 := md5.Sum(src)
	var capturedMD5 []byte
	uploadMeta := fileMetadata{
		name:              "data1.txt.gz",
		stageLocationType: "AZURE",
		noSleepingTime:    true,
		parallel:          1,
		client:            azureCli,
		sha256Digest:      "123456789abcdef",
		stageInfo:         &info,
		dstFileName:       "data1.txt.gz",
		srcStream:         bytes.NewBuffer(src),
		// test-only encryption material; not a real production key
		encryptionMaterial: &snowflakeFileEncryption{QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", QueryID: "01abc874-0406-1bf0-0000-53b10668e056", SMKID: 92019681909886},
		encryptMeta:        testEncryptionMeta(),
		overwrite:          true,
		dstCompressionType: compressionTypes["GZIP"],
		options:            &SnowflakeFileTransferOptions{MultiPartThreshold: multiPartThreshold},
		// mock captures the Content-MD5 header for the stream-upload path
		mockAzureClient: &azureObjectAPIMock{
			UploadStreamFunc: func(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) {
				if o.HTTPHeaders != nil {
					capturedMD5 = o.HTTPHeaders.BlobContentMD5
				}
				return azblob.UploadStreamResponse{}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{sc: &snowflakeConn{cfg: &Config{}}},
	}
	uploadMeta.realSrcStream = uploadMeta.srcStream
	err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta)
	if err != nil {
		t.Fatal(err)
	}
	if capturedMD5 == nil {
		t.Fatal("expected BlobContentMD5 to be set, got nil")
	}
	if !bytes.Equal(capturedMD5, expectedMD5[:]) {
		t.Fatalf("BlobContentMD5 mismatch: got %x, want %x", capturedMD5, expectedMD5[:])
	}
}

// TestAzureGetHeaderClientCastFail ensures getFileHeader fails cleanly when
// the metadata carries a client of the wrong cloud type (an S3 client is
// deliberately supplied so the Azure-side type assertion fails).
func TestAzureGetHeaderClientCastFail(t *testing.T) {
	ctx := context.Background()
	info := execResponseStageInfo{
		Location:     "azblob/rwyi-testacco/users/9220/",
		LocationType: "AZURE",
	}
	// wrong client type on purpose: S3 client handed to the Azure path
	s3Cli, err := new(snowflakeS3Client).createClient(&info, false, &snowflakeTelemetry{})
	if err != nil {
		t.Error(err)
	}
	meta := fileMetadata{
		client:    s3Cli,
		stageInfo: &execResponseStageInfo{Location: ""},
		mockAzureClient: &azureObjectAPIMock{
			GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) {
				return blob.GetPropertiesResponse{}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
	}
	_, err = new(snowflakeAzureClient).getFileHeader(ctx, &meta, "file.txt")
	if err == nil {
		t.Fatal("should have failed")
	}
}

================================================
FILE: bind_uploader.go
================================================
package
gosnowflake

import (
	"bytes"
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"github.com/snowflakedb/gosnowflake/v2/internal/errors"
	"github.com/snowflakedb/gosnowflake/v2/internal/query"
	"github.com/snowflakedb/gosnowflake/v2/internal/types"
	"math/big"
	"reflect"
	"strconv"
	"strings"
)

const (
	// bindStageName is the temporary stage used for stage-based (bulk) binds.
	bindStageName            = "SYSTEM$BIND"
	createTemporaryStageStmt = "CREATE OR REPLACE TEMPORARY STAGE " + bindStageName + " file_format=" + "(type=csv field_optionally_enclosed_by='\"')"
	// size (in bytes) of max input stream (10MB default) as per JDBC specs
	inputStreamBufferSize = 1024 * 1024 * 10
)

// bindUploader serializes bind values to CSV and uploads them to the bind
// stage so large array binds do not have to travel inline with the query.
type bindUploader struct {
	ctx            context.Context // request-scoped context used for every exec call
	sc             *snowflakeConn  // connection used to run the PUT and DDL statements
	stagePath      string          // full stage path (stage name + request ID)
	fileCount      int             // number of staged chunk files produced so far
	arrayBindStage string          // non-empty once the temporary stage has been created
}

// bindingSchema describes the server-side type of a structured bind value.
type bindingSchema struct {
	Typ      string                `json:"type"`
	Nullable bool                  `json:"nullable"`
	Fields   []query.FieldMetadata `json:"fields"`
}

// bindingValue carries one serialized bind value plus optional format/schema.
type bindingValue struct {
	value  *string        // serialized value; nil represents SQL NULL
	format string         // optional format hint
	schema *bindingSchema // optional schema for structured types
}

// upload serializes the bindings into CSV rows and uploads them to the bind
// stage in chunks of at most inputStreamBufferSize bytes. It returns the
// execResponse of the last PUT command.
func (bu *bindUploader) upload(bindings []driver.NamedValue) (*execResponse, error) {
	bindingRows, err := bu.buildRowsAsBytes(bindings)
	if err != nil {
		return nil, err
	}
	startIdx, numBytes, rowNum := 0, 0, 0
	bu.fileCount = 0
	var data *execResponse
	for rowNum < len(bindingRows) {
		// accumulate rows until the 10MB chunk limit (or end of input) is hit
		for numBytes < inputStreamBufferSize && rowNum < len(bindingRows) {
			numBytes += len(bindingRows[rowNum])
			rowNum++
		}
		// concatenate all byte arrays into 1 and put into input stream
		var b bytes.Buffer
		b.Grow(numBytes)
		for i := startIdx; i < rowNum; i++ {
			b.Write(bindingRows[i])
		}
		bu.fileCount++
		data, err = bu.uploadStreamInternal(&b, bu.fileCount, true)
		if err != nil {
			return nil, err
		}
		startIdx = rowNum
		numBytes = 0
	}
	return data, nil
}

// uploadStreamInternal PUTs one in-memory CSV chunk to the bind stage as file
// number dstFileName, creating the temporary stage first if needed.
func (bu *bindUploader) uploadStreamInternal(
	inputStream *bytes.Buffer,
	dstFileName int,
	compressData bool) (
	*execResponse, error) {
	if err := bu.createStageIfNeeded(); err != nil {
		return nil, err
	}
	stageName := bu.stagePath
	if stageName == "" {
		return nil, exceptionTelemetry(&SnowflakeError{
			Number:  ErrBindUpload,
			Message: "stage name is null",
		}, bu.sc)
	}
	// use a placeholder for source file
	putCommand := fmt.Sprintf("put 'file:///tmp/placeholder/%v' '%v' overwrite=true", dstFileName, stageName)
	// for Windows queries
	putCommand = strings.ReplaceAll(putCommand, "\\", "\\\\")
	// prepare context for PUT command
	ctx := WithFilePutStream(bu.ctx, inputStream)
	ctx = WithFileTransferOptions(ctx, &SnowflakeFileTransferOptions{
		compressSourceFromStream: compressData})
	return bu.sc.exec(ctx, putCommand, false, true, false, []driver.NamedValue{})
}

// createStageIfNeeded lazily creates the SYSTEM$BIND temporary stage once per
// uploader. On creation failure it sets the session's array-bind threshold
// parameter to 0 (disabling stage binding) before returning the error.
func (bu *bindUploader) createStageIfNeeded() error {
	if bu.arrayBindStage != "" {
		return nil // stage already created
	}
	data, err := bu.sc.exec(bu.ctx, createTemporaryStageStmt, false, false, false, []driver.NamedValue{})
	if err != nil {
		// fall back to inline binding for the remainder of the session
		newThreshold := "0"
		bu.sc.syncParams.set(sessionArrayBindStageThreshold, &newThreshold)
		return err
	}
	if !data.Success {
		code, err := strconv.Atoi(data.Code)
		if err != nil {
			return err
		}
		return exceptionTelemetry(&SnowflakeError{
			Number:   code,
			SQLState: data.Data.SQLState,
			Message:  data.Message,
			QueryID:  data.Data.QueryID,
		}, bu.sc)
	}
	bu.arrayBindStage = bindStageName
	return nil
}

// transpose the columns to rows and write them to a list of bytes
func (bu *bindUploader) buildRowsAsBytes(columns []driver.NamedValue) ([][]byte, error) {
	numColumns := len(columns)
	if columns[0].Value == nil {
		return nil, exceptionTelemetry(&SnowflakeError{
			Number:  ErrBindSerialization,
			Message: "no binds found in the first column",
		}, bu.sc)
	}
	// the first column determines the expected row count for all columns
	_, column, err := snowflakeArrayToString(&columns[0], true)
	if err != nil {
		return nil, err
	}
	numRows := len(column)
	csvRows := make([][]byte, 0)
	rows := make([][]any, 0)
	for range numRows {
		rows = append(rows, make([]any, numColumns))
	}
	for rowIdx := range numRows {
		if column[rowIdx] == nil {
			rows[rowIdx][0] = column[rowIdx]
		} else {
			rows[rowIdx][0] = *column[rowIdx]
		}
	}
	for colIdx := 1; colIdx < numColumns; colIdx++ {
		_, column, err = snowflakeArrayToString(&columns[colIdx], true)
		if err != nil {
			return nil, err
		}
		iNumRows := len(column)
		if iNumRows != numRows {
			// every bound array must have the same number of rows
			return nil, exceptionTelemetry(&SnowflakeError{
				Number:      ErrBindSerialization,
				Message:     errors.ErrMsgBindColumnMismatch,
				MessageArgs: []any{colIdx, iNumRows, numRows},
			}, bu.sc)
		}
		for rowIdx := range numRows { // length of column = number of rows
			if column[rowIdx] == nil {
				rows[rowIdx][colIdx] = column[rowIdx]
			} else {
				rows[rowIdx][colIdx] = *column[rowIdx]
			}
		}
	}
	for _, row := range rows {
		csvRows = append(csvRows, bu.createCSVRecord(row))
	}
	return csvRows, nil
}

// createCSVRecord renders one transposed row as a single CSV line, newline
// terminated. Non-string, non-nil values cannot be serialized; they are only
// logged at debug level and produce an empty CSV field.
func (bu *bindUploader) createCSVRecord(data []any) []byte {
	var b strings.Builder
	b.Grow(1024)
	for i := range data {
		if i > 0 {
			b.WriteString(",")
		}
		value, ok := data[i].(string)
		if ok {
			b.WriteString(escapeForCSV(value))
		} else if !reflect.ValueOf(data[i]).IsNil() {
			logger.WithContext(bu.ctx).Debugf("Cannot convert value to string in createCSVRecord. value: %v", data[i])
		}
	}
	b.WriteString("\n")
	return []byte(b.String())
}

// processBindings populates req with either inline bind values or, when the
// session's array-bind threshold is met for a non-describe array bind, the
// path of a stage to which the binds were uploaded.
func (sc *snowflakeConn) processBindings(
	ctx context.Context,
	bindings []driver.NamedValue,
	describeOnly bool,
	requestID UUID,
	req *execRequest) error {
	arrayBindThreshold := sc.getArrayBindStageThreshold()
	numBinds, err := arrayBindValueCount(bindings)
	if err != nil {
		return err
	}
	// stage binding only for real (non-describe) array binds above threshold
	if 0 < arrayBindThreshold && arrayBindThreshold <= numBinds && !describeOnly && isArrayBind(bindings) {
		uploader := bindUploader{
			sc:        sc,
			ctx:       ctx,
			stagePath: "@" + bindStageName + "/" + requestID.String(),
		}
		_, err := uploader.upload(bindings)
		if err != nil {
			return err
		}
		req.Bindings = nil
		req.BindStage = uploader.stagePath
	} else {
		req.Bindings, err = getBindValues(bindings, &sc.syncParams)
		if err != nil {
			return err
		}
		req.BindStage = ""
	}
	return nil
}

// getBindValues serializes bindings into the wire-format map keyed by bind
// name (or 1-based position). DataType sentinel values (ChangeType) only
// switch the current timestamp mode and do not emit a parameter themselves.
func getBindValues(bindings []driver.NamedValue, params *syncParams) (map[string]execBindParameter, error) {
	tsmode := types.TimestampNtzType
	idx := 1
	var err error
	bindValues := make(map[string]execBindParameter, len(bindings))
	for _, binding := range bindings {
		if tnt, ok := binding.Value.(TypedNullTime); ok {
			// typed NULL time: derive mode from the wrapper, bind inner value
			tsmode = convertTzTypeToSnowflakeType(tnt.TzType)
			binding.Value = tnt.Time
		}
		t := goTypeToSnowflake(binding.Value, tsmode)
		if t == types.ChangeType {
			// sentinel: only updates the timestamp mode, no parameter emitted
			tsmode, err = dataTypeMode(binding.Value)
			if err != nil {
				return nil, err
			}
		} else {
			var val any
			var bv bindingValue
			if t == types.SliceType {
				// retrieve array binding data
				t, val, err = snowflakeArrayToString(&binding, false)
				if err != nil {
					return nil, err
				}
			} else {
				bv, err = valueToString(binding.Value, tsmode, params)
				val = bv.value
				if err != nil {
					return nil, err
				}
			}
			// normalize pseudo-types to the wire types the server expects
			switch t {
			case types.NullType, types.UnSupportedType:
				t = types.TextType
			case types.NilObjectType, types.MapType, types.NilMapType:
				t = types.ObjectType
			case types.NilArrayType:
				t = types.ArrayType
			}
			bindValues[bindingName(binding, idx)] = execBindParameter{
				Type:   t.String(),
				Value:  val,
				Format: bv.format,
				Schema: bv.schema,
			}
			idx++
		}
	}
	return bindValues, nil
}

// bindingName returns the explicit bind name, or the 1-based position when
// the binding is positional.
func bindingName(nv driver.NamedValue, idx int) string {
	if nv.Name != "" {
		return nv.Name
	}
	return strconv.Itoa(idx)
}

// arrayBindValueCount returns the total number of scalar values across all
// array bindings (columns * rows), or 0 when the bindings are not array binds.
func arrayBindValueCount(bindValues []driver.NamedValue) (int, error) {
	if !isArrayBind(bindValues) {
		return 0, nil
	}
	// all columns share the first column's length
	_, arr, err := snowflakeArrayToString(&bindValues[0], false)
	if err != nil {
		return 0, err
	}
	return len(bindValues) * len(arr), nil
}

// isArrayBind reports whether every binding is an array-bind value.
func isArrayBind(bindings []driver.NamedValue) bool {
	if len(bindings) == 0 {
		return false
	}
	for _, binding := range bindings {
		if supported := supportedArrayBind(&binding); !supported {
			return false
		}
	}
	return true
}

// supportedArrayBind reports whether nv holds one of the typed array wrapper
// types (or an equivalent internal encoding) usable for array binding.
func supportedArrayBind(nv *driver.NamedValue) bool {
	switch reflect.TypeOf(nv.Value) {
	case reflect.TypeFor[*intArray](), reflect.TypeFor[*int32Array](),
		reflect.TypeFor[*int64Array](), reflect.TypeFor[*float64Array](),
		reflect.TypeFor[*float32Array](), reflect.TypeFor[*decfloatArray](),
		reflect.TypeFor[*boolArray](), reflect.TypeFor[*stringArray](),
		reflect.TypeFor[*byteArray](), reflect.TypeFor[*timestampNtzArray](),
		reflect.TypeFor[*timestampLtzArray](), reflect.TypeFor[*timestampTzArray](),
		reflect.TypeFor[*dateArray](), reflect.TypeFor[*timeArray]():
		return true
	case reflect.TypeFor[[]uint8](): // internal binding ts mode
		val, ok := nv.Value.([]uint8)
		if !ok {
			return ok
		}
		if len(val) == 0 {
			return true // for null binds
		}
		// first byte encodes a SnowflakeType when it lies in the valid range
		if types.FixedType <= types.SnowflakeType(val[0]) && types.SnowflakeType(val[0]) <= types.UnSupportedType {
			return true
		}
		return false
	default:
		// Support for bulk array binding insertion using []interface{}
		if isInterfaceArrayBinding(nv.Value) {
			return true
		}
		return false
	}
}

// supportedDecfloatBind reports whether nv holds a big.Float, by value or
// behind a pointer (bound as DECFLOAT).
func supportedDecfloatBind(nv *driver.NamedValue) bool {
	if nv.Value == nil {
		return false
	}
	val := reflect.Indirect(reflect.ValueOf(nv.Value))
	if !val.IsValid() {
		return false
	}
	return val.Type() == reflect.TypeFor[big.Float]()
}

// supportedNullBind reports whether nv holds one of the database/sql typed
// NULL wrapper types this driver understands.
func supportedNullBind(nv *driver.NamedValue) bool {
	switch reflect.TypeOf(nv.Value) {
	case reflect.TypeFor[sql.NullString](), reflect.TypeFor[sql.NullInt64](),
		reflect.TypeFor[sql.NullBool](), reflect.TypeFor[sql.NullFloat64](),
		reflect.TypeFor[TypedNullTime]():
		return true
	}
	return false
}

// supportedStructuredObjectWriterBind reports whether nv holds a
// StructuredObjectWriter implementation or a reflect.Type (nil object bind).
func supportedStructuredObjectWriterBind(nv *driver.NamedValue) bool {
	if _, ok := nv.Value.(StructuredObjectWriter); ok {
		return true
	}
	_, ok := nv.Value.(reflect.Type)
	return ok
}

// supportedStructuredArrayBind reports whether nv holds a slice or array.
func supportedStructuredArrayBind(nv *driver.NamedValue) bool {
	typ := reflect.TypeOf(nv.Value)
	return typ != nil && (typ.Kind() == reflect.Array || typ.Kind() == reflect.Slice)
}

// supportedStructuredMapBind reports whether nv holds a map or the
// NilMapTypes sentinel used for typed NULL maps.
func supportedStructuredMapBind(nv *driver.NamedValue) bool {
	typ := reflect.TypeOf(nv.Value)
	return typ != nil && (typ.Kind() == reflect.Map || typ == reflect.TypeFor[NilMapTypes]())
}

================================================
FILE: bindings_test.go
================================================
package gosnowflake

import (
	"bytes"
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"log"
	"math"
	"math/big"
	"math/rand"
	"reflect"
	"slices"
	"strconv"
	"strings"
	"testing"
	"time"
)

const (
	createTableSQL = `create or replace table test_prep_statement(c1 INTEGER, c2 FLOAT, c3 BOOLEAN, c4
STRING, C5 BINARY, C6 TIMESTAMP_NTZ, C7 TIMESTAMP_LTZ, C8 TIMESTAMP_TZ, C9 DATE, C10 TIME)`
	deleteTableSQL = "drop table if exists TEST_PREP_STATEMENT"
	insertSQL      = "insert into TEST_PREP_STATEMENT values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
	selectAllSQL   = "select * from TEST_PREP_STATEMENT ORDER BY 1"

	createTableSQLBulkArray = `create or replace table test_bulk_array(c1 INTEGER, c2 FLOAT, c3 BOOLEAN, c4 STRING, C5 BINARY, C6 INTEGER)`
	deleteTableSQLBulkArray = "drop table if exists test_bulk_array"
	insertSQLBulkArray      = "insert into test_bulk_array values(?, ?, ?, ?, ?, ?)"
	selectAllSQLBulkArray   = "select * from test_bulk_array ORDER BY 1"

	createTableSQLBulkArrayDateTimeTimestamp = `create or replace table test_bulk_array_DateTimeTimestamp( C1 TIMESTAMP_NTZ, C2 TIMESTAMP_LTZ, C3 TIMESTAMP_TZ, C4 DATE, C5 TIME)`
	deleteTableSQLBulkArrayDateTimeTimestamp = "drop table if exists test_bulk_array_DateTimeTimestamp"
	insertSQLBulkArrayDateTimeTimestamp      = "insert into test_bulk_array_DateTimeTimestamp values(?, ?, ?, ?, ?)"
	selectAllSQLBulkArrayDateTimeTimestamp   = "select * from test_bulk_array_DateTimeTimestamp ORDER BY 1"

	// session toggles for the large-object (LOB) size feature
	enableFeatureMaxLOBSize      = "ALTER SESSION SET FEATURE_INCREASED_MAX_LOB_SIZE_IN_MEMORY='ENABLED'"
	unsetFeatureMaxLOBSize       = "ALTER SESSION UNSET FEATURE_INCREASED_MAX_LOB_SIZE_IN_MEMORY"
	enableLargeVarcharAndBinary  = "ALTER SESSION SET ENABLE_LARGE_VARCHAR_AND_BINARY_IN_RESULT=TRUE"
	disableLargeVarcharAndBinary = "ALTER SESSION SET ENABLE_LARGE_VARCHAR_AND_BINARY_IN_RESULT=FALSE"
	unsetLargeVarcharAndBinary   = "ALTER SESSION UNSET ENABLE_LARGE_VARCHAR_AND_BINARY_IN_RESULT"

	smallSize = 16 * 1024 * 1024 // 16 MB - right at LOB threshold
	largeSize = 64 * 1024 * 1024 // 64 MB - well above LOB threshold
	// range to use for generating random numbers
	lobRandomRange = 100000
)

// TestBindingFloat64 binds a float64 into FLOAT and DOUBLE columns and
// verifies the same value is read back.
func TestBindingFloat64(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		types := [2]string{"FLOAT", "DOUBLE"}
		expected := 42.23
		var out float64
		var rows *RowsExtended
		for _, v := range types {
			t.Run(v, func(t *testing.T) {
				dbt.mustExec(fmt.Sprintf("CREATE OR REPLACE TABLE test (id int, value %v)", v))
				dbt.mustExec("INSERT INTO test VALUES (1, ?)", expected)
				rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1)
				defer func() {
					assertNilF(t, rows.Close())
				}()
				if rows.Next() {
					assertNilF(t, rows.Scan(&out))
					if expected != out {
						dbt.Errorf("%s: %g != %g", v, expected, out)
					}
				} else {
					dbt.Errorf("%s: no data", v)
				}
			})
		}
		dbt.mustExec("DROP TABLE IF EXISTS test")
	})
}

// TestBindingUint64 tests uint64 binding. Should fail as unit64 is not a
// supported binding value by Go's sql package.
func TestBindingUint64(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		expected := uint64(18446744073709551615)
		dbt.mustExec("CREATE OR REPLACE TABLE test (id int, value INTEGER)")
		if _, err := dbt.exec("INSERT INTO test VALUES (1, ?)", expected); err == nil {
			dbt.Fatal("should fail as uint64 values with high bit set are not supported.")
		} else {
			logger.Infof("expected err: %v", err)
		}
		dbt.mustExec("DROP TABLE IF EXISTS test")
	})
}

// TestBindingDateTimeTimestamp round-trips a time.Time through NTZ/LTZ
// timestamp, DATE and TIME columns using explicit DataType* hints, and also
// checks the reported column metadata.
func TestBindingDateTimeTimestamp(t *testing.T) {
	createDSN(PSTLocation)
	runDBTest(t, func(dbt *DBTest) {
		expected := time.Now()
		dbt.mustExec(
			"CREATE OR REPLACE TABLE tztest (id int, ntz timestamp_ntz, ltz timestamp_ltz, dt date, tm time)")
		stmt, err := dbt.prepare("INSERT INTO tztest(id,ntz,ltz,dt,tm) VALUES(1,?,?,?,?)")
		if err != nil {
			dbt.Fatal(err.Error())
		}
		defer stmt.Close()
		if _, err = stmt.Exec(
			DataTypeTimestampNtz, expected,
			DataTypeTimestampLtz, expected,
			DataTypeDate, expected,
			DataTypeTime, expected); err != nil {
			dbt.Fatal(err)
		}
		rows := dbt.mustQuery("SELECT ntz,ltz,dt,tm FROM tztest WHERE id=?", 1)
		defer rows.Close()
		var ntz, vltz, dt, tm time.Time
		columnTypes, err := rows.ColumnTypes()
		if err != nil {
			dbt.Errorf("column type error. err: %v", err)
		}
		if columnTypes[0].Name() != "NTZ" {
			dbt.Errorf("expected column name: %v, got: %v", "TEST", columnTypes[0])
		}
		canNull := dbt.mustNullable(columnTypes[0])
		if !canNull {
			dbt.Errorf("expected nullable: %v, got: %v", true, canNull)
		}
		if columnTypes[0].DatabaseTypeName() != "TIMESTAMP_NTZ" {
			dbt.Errorf("expected database type: %v, got: %v", "TIMESTAMP_NTZ", columnTypes[0].DatabaseTypeName())
		}
		dbt.mustFailDecimalSize(columnTypes[0])
		dbt.mustFailLength(columnTypes[0])
		cols, err := rows.Columns()
		if err != nil {
			dbt.Errorf("failed to get columns. err: %v", err)
		}
		if len(cols) != 4 || cols[0] != "NTZ" || cols[1] != "LTZ" || cols[2] != "DT" || cols[3] != "TM" {
			dbt.Errorf("failed to get columns. got: %v", cols)
		}
		if rows.Next() {
			assertNilF(t, rows.Scan(&ntz, &vltz, &dt, &tm))
			if expected.UnixNano() != ntz.UnixNano() {
				dbt.Errorf("returned TIMESTAMP_NTZ value didn't match. expected: %v:%v, got: %v:%v",
					expected.UnixNano(), expected, ntz.UnixNano(), ntz)
			}
			if expected.UnixNano() != vltz.UnixNano() {
				dbt.Errorf("returned TIMESTAMP_LTZ value didn't match. expected: %v:%v, got: %v:%v",
					expected.UnixNano(), expected, vltz.UnixNano(), vltz)
			}
			if expected.Year() != dt.Year() || expected.Month() != dt.Month() || expected.Day() != dt.Day() {
				dbt.Errorf("returned DATE value didn't match. expected: %v:%v, got: %v:%v",
					expected.Unix()*1000, expected, dt.Unix()*1000, dt)
			}
			if expected.Hour() != tm.Hour() || expected.Minute() != tm.Minute() ||
				expected.Second() != tm.Second() || expected.Nanosecond() != tm.Nanosecond() {
				dbt.Errorf("returned TIME value didn't match. expected: %v:%v, got: %v:%v",
					expected.UnixNano(), expected, tm.UnixNano(), tm)
			}
		} else {
			dbt.Error("no data")
		}
		dbt.mustExec("DROP TABLE tztest")
	})
	createDSN("UTC") // restore the default DSN timezone
}

// TestBindingBinary round-trips a []byte value through a BINARY column.
func TestBindingBinary(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("CREATE OR REPLACE TABLE bintest (id int, b binary)")
		var b = []byte{0x01, 0x02, 0x03}
		dbt.mustExec("INSERT INTO bintest(id,b) VALUES(1, ?)", DataTypeBinary, b)
		rows := dbt.mustQuery("SELECT b FROM bintest WHERE id=?", 1)
		defer rows.Close()
		if rows.Next() {
			var rb []byte
			if err := rows.Scan(&rb); err != nil {
				dbt.Errorf("failed to scan data. err: %v", err)
			}
			if !bytes.Equal(b, rb) {
				dbt.Errorf("failed to match data. expected: %v, got: %v", b, rb)
			}
		} else {
			dbt.Errorf("no data")
		}
		dbt.mustExec("DROP TABLE bintest")
	})
}

// TestBindingTimestampTZ round-trips a time.Time through a TIMESTAMP_TZ
// column using the DataTypeTimestampTz hint.
func TestBindingTimestampTZ(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		expected := time.Now()
		dbt.mustExec("CREATE OR REPLACE TABLE tztest (id int, tz timestamp_tz)")
		stmt, err := dbt.prepare("INSERT INTO tztest(id,tz) VALUES(1, ?)")
		if err != nil {
			dbt.Fatal(err.Error())
		}
		defer func() {
			assertNilF(t, stmt.Close())
		}()
		if _, err = stmt.Exec(DataTypeTimestampTz, expected); err != nil {
			dbt.Fatal(err)
		}
		rows := dbt.mustQuery("SELECT tz FROM tztest WHERE id=?", 1)
		defer func() {
			assertNilF(t, rows.Close())
		}()
		var v time.Time
		if rows.Next() {
			assertNilF(t, rows.Scan(&v))
			if expected.UnixNano() != v.UnixNano() {
				dbt.Errorf("returned value didn't match. expected: %v:%v, got: %v:%v",
					expected.UnixNano(), expected, v.UnixNano(), v)
			}
		} else {
			dbt.Error("no data")
		}
		dbt.mustExec("DROP TABLE tztest")
	})
}

// SNOW-755844: Test the use of a pointer *time.Time type in user-defined structures to perform updates/inserts
func TestBindingTimePtrInStruct(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		type timePtrStruct struct {
			id      *int
			timeVal *time.Time
		}
		expectedID := 1
		expectedTime := time.Now()
		testStruct := timePtrStruct{id: &expectedID, timeVal: &expectedTime}
		dbt.mustExec("CREATE OR REPLACE TABLE timeStructTest (id int, tz timestamp_tz)")
		runInsertQuery := false
		// first iteration inserts, second updates and re-verifies
		for range 2 {
			if !runInsertQuery {
				_, err := dbt.exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal)
				if err != nil {
					dbt.Fatal(err.Error())
				}
				runInsertQuery = true
			} else {
				// Update row with a new time value
				expectedTime = time.Now().Add(1)
				testStruct.timeVal = &expectedTime
				_, err := dbt.exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id)
				if err != nil {
					dbt.Fatal(err.Error())
				}
			}
			rows := dbt.mustQuery("SELECT tz FROM timeStructTest WHERE id=?", &expectedID)
			defer func() {
				assertNilF(t, rows.Close())
			}()
			var v time.Time
			if rows.Next() {
				assertNilF(t, rows.Scan(&v))
				if expectedTime.UnixNano() != v.UnixNano() {
					dbt.Errorf("returned value didn't match. expected: %v:%v, got: %v:%v",
						expectedTime.UnixNano(), expectedTime, v.UnixNano(), v)
				}
			} else {
				dbt.Error("no data")
			}
		}
		dbt.mustExec("DROP TABLE timeStructTest")
	})
}

// SNOW-755844: Test the use of a time.Time type in user-defined structures to perform updates/inserts
func TestBindingTimeInStruct(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		type timeStruct struct {
			id      int
			timeVal time.Time
		}
		expectedID := 1
		expectedTime := time.Now()
		testStruct := timeStruct{id: expectedID, timeVal: expectedTime}
		dbt.mustExec("CREATE OR REPLACE TABLE timeStructTest (id int, tz timestamp_tz)")
		runInsertQuery := false
		// first iteration inserts, second updates and re-verifies
		for range 2 {
			if !runInsertQuery {
				_, err := dbt.exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal)
				if err != nil {
					dbt.Fatal(err.Error())
				}
				runInsertQuery = true
			} else {
				// Update row with a new time value
				expectedTime = time.Now().Add(1)
				testStruct.timeVal = expectedTime
				_, err := dbt.exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id)
				if err != nil {
					dbt.Fatal(err.Error())
				}
			}
			rows := dbt.mustQuery("SELECT tz FROM timeStructTest WHERE id=?", &expectedID)
			defer func() {
				assertNilF(t, rows.Close())
			}()
			var v time.Time
			if rows.Next() {
				assertNilF(t, rows.Scan(&v))
				if expectedTime.UnixNano() != v.UnixNano() {
					dbt.Errorf("returned value didn't match. expected: %v:%v, got: %v:%v",
						expectedTime.UnixNano(), expectedTime, v.UnixNano(), v)
				}
			} else {
				dbt.Error("no data")
			}
		}
		dbt.mustExec("DROP TABLE timeStructTest")
	})
}

// TestBindingInterface scans a mixed-type row into any values with higher
// precision enabled, expecting big.Float/big.Int for numeric columns.
func TestBindingInterface(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		rows := dbt.mustQueryContext(
			WithHigherPrecision(context.Background()), selectVariousTypes)
		defer func() {
			assertNilF(t, rows.Close())
		}()
		if !rows.Next() {
			dbt.Error("failed to query")
		}
		var v1, v2, v2a, v3, v4, v5, v6 any
		if err := rows.Scan(&v1, &v2, &v2a, &v3, &v4, &v5, &v6); err != nil {
			dbt.Errorf("failed to scan: %#v", err)
		}
		if s1, ok := v1.(*big.Float); !ok || s1.Cmp(big.NewFloat(1.0)) != 0 {
			dbt.Fatalf("failed to fetch. ok: %v, value: %v", ok, v1)
		}
		if s2, ok := v2.(int64); !ok || s2 != 2 {
			dbt.Fatalf("failed to fetch. ok: %v, value: %v", ok, v2)
		}
		if s2a, ok := v2a.(*big.Int); !ok || big.NewInt(22).Cmp(s2a) != 0 {
			dbt.Fatalf("failed to fetch. ok: %v, value: %v", ok, v2a)
		}
		if s3, ok := v3.(string); !ok || s3 != "t3" {
			dbt.Fatalf("failed to fetch. ok: %v, value: %v", ok, v3)
		}
		if s4, ok := v4.(float64); !ok || s4 != 4.2 {
			dbt.Fatalf("failed to fetch. ok: %v, value: %v", ok, v4)
		}
	})
}

// TestBindingInterfaceString scans the same row without higher precision,
// expecting string representations for the numeric columns.
func TestBindingInterfaceString(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		rows := dbt.mustQuery(selectVariousTypes)
		defer func() {
			assertNilF(t, rows.Close())
		}()
		if !rows.Next() {
			dbt.Error("failed to query")
		}
		var v1, v2, v2a, v3, v4, v5, v6 any
		if err := rows.Scan(&v1, &v2, &v2a, &v3, &v4, &v5, &v6); err != nil {
			dbt.Errorf("failed to scan: %#v", err)
		}
		if s, ok := v1.(string); !ok {
			dbt.Error("failed to convert to string")
		} else if d, err := strconv.ParseFloat(s, 64); err != nil {
			dbt.Errorf("failed to convert to float. value: %v, err: %v", v1, err)
		} else if d != 1.00 {
			dbt.Errorf("failed to fetch. expected: 1.00, value: %v", v1)
		}
		if s, ok := v2.(string); !ok || s != "2" {
			dbt.Fatalf("failed to fetch. ok: %v, value: %v", ok, v2)
		}
		if s, ok := v2a.(string); !ok || s != "22" {
			dbt.Fatalf("failed to fetch. ok: %v, value: %v", ok, v2a)
		}
		if s, ok := v3.(string); !ok || s != "t3" {
			dbt.Fatalf("failed to fetch. ok: %v, value: %v", ok, v3)
		}
	})
}

// TestBulkArrayBindingUUID bulk-inserts 100K generated UUIDs through an
// interface-array binding and reads them back in sorted order.
func TestBulkArrayBindingUUID(t *testing.T) {
	max := math.Pow10(5) // 100K because my power is maximum
	expectedUuids := make([]any, int(max))
	createTable := "CREATE OR REPLACE TABLE TEST_PREP_STATEMENT (uuid VARCHAR)"
	insert := "INSERT INTO TEST_PREP_STATEMENT (uuid) VALUES (?)"
	for i := range expectedUuids {
		expectedUuids[i] = newTestUUID()
	}
	// pre-sort so expectations line up with the ORDER BY uuid query below
	slices.SortStableFunc(expectedUuids, func(i, j any) int {
		return strings.Compare(i.(testUUID).String(), j.(testUUID).String())
	})
	runDBTest(t, func(dbt *DBTest) {
		var rows *RowsExtended
		t.Cleanup(func() {
			if rows != nil {
				assertNilF(t, rows.Close())
			}
			_, err := dbt.exec(deleteTableSQL)
			if err != nil {
				t.Logf("failed to drop table. err: %s", err)
			}
		})
		dbt.mustExec(createTable)
		array, err := Array(&expectedUuids)
		assertNilF(t, err)
		res := dbt.mustExec(insert, array)
		affected, err := res.RowsAffected()
		if err != nil {
			t.Fatalf("failed to get affected rows. err: %s", err)
		} else if affected != int64(max) {
			t.Fatalf("failed to insert all rows. expected: %f.0, got: %v", max, affected)
		}
		rows = dbt.mustQuery("SELECT * FROM TEST_PREP_STATEMENT ORDER BY uuid")
		if rows == nil {
			t.Fatal("failed to query")
		}
		if rows.Err() != nil {
			t.Fatalf("failed to query.
err: %s", rows.Err()) } var actual = make([]testUUID, len(expectedUuids)) for i := 0; rows.Next(); i++ { var ( out testUUID ) if err := rows.Scan(&out); err != nil { t.Fatal(err) } actual[i] = out } for i := range expectedUuids { assertEqualE(t, actual[i], expectedUuids[i]) } }) } func TestBulkArrayBindingInterfaceNil(t *testing.T) { nilArray := make([]any, 1) runDBTest(t, func(dbt *DBTest) { dbt.mustExec(createTableSQL) defer dbt.mustExec(deleteTableSQL) dbt.mustExec(insertSQL, mustArray(&nilArray), mustArray(&nilArray), mustArray(&nilArray), mustArray(&nilArray), mustArray(&nilArray), mustArray(&nilArray, TimestampNTZType), mustArray(&nilArray, TimestampTZType), mustArray(&nilArray, TimestampTZType), mustArray(&nilArray, DateType), mustArray(&nilArray, TimeType)) rows := dbt.mustQuery(selectAllSQL) defer func() { assertNilF(t, rows.Close()) }() var v0 sql.NullInt32 var v1 sql.NullFloat64 var v2 sql.NullBool var v3 sql.NullString var v4 []byte var v5, v6, v7, v8, v9 sql.NullTime cnt := 0 for i := 0; rows.Next(); i++ { if err := rows.Scan(&v0, &v1, &v2, &v3, &v4, &v5, &v6, &v7, &v8, &v9); err != nil { t.Fatal(err) } if v0.Valid { t.Fatalf("failed to fetch the sql.NullInt32 column v0. expected %v, got: %v", nilArray[i], v0) } if v1.Valid { t.Fatalf("failed to fetch the sql.NullFloat64 column v1. expected %v, got: %v", nilArray[i], v1) } if v2.Valid { t.Fatalf("failed to fetch the sql.NullBool column v2. expected %v, got: %v", nilArray[i], v2) } if v3.Valid { t.Fatalf("failed to fetch the sql.NullString column v3. expected %v, got: %v", nilArray[i], v3) } if v4 != nil { t.Fatalf("failed to fetch the []byte column v4. expected %v, got: %v", nilArray[i], v4) } if v5.Valid { t.Fatalf("failed to fetch the sql.NullTime column v5. expected %v, got: %v", nilArray[i], v5) } if v6.Valid { t.Fatalf("failed to fetch the sql.NullTime column v6. expected %v, got: %v", nilArray[i], v6) } if v7.Valid { t.Fatalf("failed to fetch the sql.NullTime column v7. 
expected %v, got: %v", nilArray[i], v7) } if v8.Valid { t.Fatalf("failed to fetch the sql.NullTime column v8. expected %v, got: %v", nilArray[i], v8) } if v9.Valid { t.Fatalf("failed to fetch the sql.NullTime column v9. expected %v, got: %v", nilArray[i], v9) } cnt++ } if cnt != len(nilArray) { t.Fatal("failed to query") } }) } func TestBulkArrayBindingInterface(t *testing.T) { intArray := make([]any, 3) intArray[0] = int32(100) intArray[1] = int32(200) fltArray := make([]any, 3) fltArray[0] = float64(0.1) fltArray[2] = float64(5.678) boolArray := make([]any, 3) boolArray[1] = false boolArray[2] = true strArray := make([]any, 3) strArray[2] = "test3" byteArray := make([]any, 3) byteArray[0] = []byte{0x01, 0x02, 0x03} byteArray[2] = []byte{0x07, 0x08, 0x09} int64Array := make([]any, 3) int64Array[0] = int64(100) int64Array[1] = int64(200) runDBTest(t, func(dbt *DBTest) { dbt.mustExec(createTableSQLBulkArray) defer dbt.mustExec(deleteTableSQLBulkArray) dbt.mustExec(insertSQLBulkArray, mustArray(&intArray), mustArray(&fltArray), mustArray(&boolArray), mustArray(&strArray), mustArray(&byteArray), mustArray(&int64Array)) rows := dbt.mustQuery(selectAllSQLBulkArray) defer func() { assertNilF(t, rows.Close()) }() var v0 sql.NullInt32 var v1 sql.NullFloat64 var v2 sql.NullBool var v3 sql.NullString var v4 []byte var v5 sql.NullInt64 cnt := 0 for i := 0; rows.Next(); i++ { if err := rows.Scan(&v0, &v1, &v2, &v3, &v4, &v5); err != nil { t.Fatal(err) } if v0.Valid { if v0.Int32 != intArray[i] { t.Fatalf("failed to fetch the sql.NullInt32 column v0. expected %v, got: %v", intArray[i], v0.Int32) } } else if intArray[i] != nil { t.Fatalf("failed to fetch the sql.NullInt32 column v0. expected %v, got: %v", intArray[i], v0) } if v1.Valid { if v1.Float64 != fltArray[i] { t.Fatalf("failed to fetch the sql.NullFloat64 column v1. expected %v, got: %v", fltArray[i], v1.Float64) } } else if fltArray[i] != nil { t.Fatalf("failed to fetch the sql.NullFloat64 column v1. 
expected %v, got: %v", fltArray[i], v1) } if v2.Valid { if v2.Bool != boolArray[i] { t.Fatalf("failed to fetch the sql.NullBool column v2. expected %v, got: %v", boolArray[i], v2.Bool) } } else if boolArray[i] != nil { t.Fatalf("failed to fetch the sql.NullBool column v2. expected %v, got: %v", boolArray[i], v2) } if v3.Valid { if v3.String != strArray[i] { t.Fatalf("failed to fetch the sql.NullString column v3. expected %v, got: %v", strArray[i], v3.String) } } else if strArray[i] != nil { t.Fatalf("failed to fetch the sql.NullString column v3. expected %v, got: %v", strArray[i], v3) } if byteArray[i] != nil { if !bytes.Equal(v4, byteArray[i].([]byte)) { t.Fatalf("failed to fetch the []byte column v4. expected %v, got: %v", byteArray[i], v4) } } else if v4 != nil { t.Fatalf("failed to fetch the []byte column v4. expected %v, got: %v", byteArray[i], v4) } if v5.Valid { if v5.Int64 != int64Array[i] { t.Fatalf("failed to fetch the sql.NullInt64 column v5. expected %v, got: %v", int64Array[i], v5.Int64) } } else if int64Array[i] != nil { t.Fatalf("failed to fetch the sql.NullInt64 column v5. 
expected %v, got: %v", int64Array[i], v5)
			}
			cnt++
		}
		if cnt != len(intArray) {
			t.Fatal("failed to query")
		}
	})
}

// TestBulkArrayBindingInterfaceDateTimeTimestamp verifies bulk array binding
// of []any slices holding time.Time values (each slice deliberately leaves
// one slot nil) against TIMESTAMP_NTZ/LTZ/TZ, DATE and TIME columns, scanning
// the results back through sql.NullTime. The DSN is switched to the PST
// location for the duration of the test and restored to UTC at the end.
func TestBulkArrayBindingInterfaceDateTimeTimestamp(t *testing.T) {
	tz := time.Now()
	createDSN(PSTLocation)
	now := time.Now()
	loc, err := time.LoadLocation(PSTLocation)
	if err != nil {
		t.Error(err)
	}
	// The Add(n) offsets are in nanoseconds, making every element distinct;
	// the slot left unset in each slice stays nil to exercise NULL binding.
	ntzArray := make([]any, 3)
	ntzArray[0] = now
	ntzArray[1] = now.Add(1)
	ltzArray := make([]any, 3)
	ltzArray[1] = now.Add(2).In(loc)
	ltzArray[2] = now.Add(3).In(loc)
	tzArray := make([]any, 3)
	tzArray[0] = tz.Add(4).In(loc)
	tzArray[2] = tz.Add(5).In(loc)
	dtArray := make([]any, 3)
	dtArray[0] = tz.Add(6).In(loc)
	dtArray[1] = now.Add(7).In(loc)
	tmArray := make([]any, 3)
	tmArray[1] = now.Add(8).In(loc)
	tmArray[2] = now.Add(9).In(loc)
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec(createTableSQLBulkArrayDateTimeTimestamp)
		defer dbt.mustExec(deleteTableSQLBulkArrayDateTimeTimestamp)
		dbt.mustExec(insertSQLBulkArrayDateTimeTimestamp,
			mustArray(&ntzArray, TimestampNTZType), mustArray(&ltzArray, TimestampLTZType),
			mustArray(&tzArray, TimestampTZType), mustArray(&dtArray, DateType),
			mustArray(&tmArray, TimeType))
		rows := dbt.mustQuery(selectAllSQLBulkArrayDateTimeTimestamp)
		defer func() {
			assertNilF(t, rows.Close())
		}()
		var v0, v1, v2, v3, v4 sql.NullTime
		cnt := 0
		for i := 0; rows.Next(); i++ {
			if err := rows.Scan(&v0, &v1, &v2, &v3, &v4); err != nil {
				t.Fatal(err)
			}
			// Timestamps compare by UnixNano; a NULL result is only
			// acceptable where the corresponding input slot was nil.
			if v0.Valid {
				if v0.Time.UnixNano() != ntzArray[i].(time.Time).UnixNano() {
					t.Fatalf("failed to fetch the column v0. expected %v, got: %v", ntzArray[i], v0)
				}
			} else if ntzArray[i] != nil {
				t.Fatalf("failed to fetch the column v0. expected %v, got: %v", ntzArray[i], v0)
			}
			if v1.Valid {
				if v1.Time.UnixNano() != ltzArray[i].(time.Time).UnixNano() {
					t.Fatalf("failed to fetch the column v1. expected %v, got: %v", ltzArray[i], v1)
				}
			} else if ltzArray[i] != nil {
				t.Fatalf("failed to fetch the column v1. expected %v, got: %v", ltzArray[i], v1)
			}
			if v2.Valid {
				if v2.Time.UnixNano() != tzArray[i].(time.Time).UnixNano() {
					t.Fatalf("failed to fetch the column v2. expected %v, got: %v", tzArray[i], v2)
				}
			} else if tzArray[i] != nil {
				t.Fatalf("failed to fetch the column v2. expected %v, got: %v", tzArray[i], v2)
			}
			// DATE compares only the calendar day, TIME only the wall clock.
			if v3.Valid {
				if v3.Time.Year() != dtArray[i].(time.Time).Year() ||
					v3.Time.Month() != dtArray[i].(time.Time).Month() ||
					v3.Time.Day() != dtArray[i].(time.Time).Day() {
					t.Fatalf("failed to fetch the column v3. expected %v, got: %v", dtArray[i], v3)
				}
			} else if dtArray[i] != nil {
				t.Fatalf("failed to fetch the column v3. expected %v, got: %v", dtArray[i], v3)
			}
			if v4.Valid {
				if v4.Time.Hour() != tmArray[i].(time.Time).Hour() ||
					v4.Time.Minute() != tmArray[i].(time.Time).Minute() ||
					v4.Time.Second() != tmArray[i].(time.Time).Second() {
					t.Fatalf("failed to fetch the column v4. expected %v, got: %v", tmArray[i], v4)
				}
			} else if tmArray[i] != nil {
				t.Fatalf("failed to fetch the column v4. expected %v, got: %v", tmArray[i], v4)
			}
			cnt++
		}
		if cnt != len(ntzArray) {
			t.Fatal("failed to query")
		}
	})
	createDSN("UTC")
}

// TestBindingArray tests basic array binding via the usage of the Array
// function that converts the passed Golang slice to a Snowflake array type
func TestBindingArray(t *testing.T) {
	testBindingArray(t, false)
}

// TestBindingBulkArray tests bulk array binding via the usage of the Array
// function that converts the passed Golang slice to a Snowflake array type
func TestBindingBulkArray(t *testing.T) {
	if runningOnGithubAction() {
		t.Skip("client_stage_array_binding_threshold value is internal")
	}
	testBindingArray(t, true)
}

// testBindingArray inserts one row per element of several typed slices and
// verifies the round trip. When bulk is true the stage array binding
// threshold is lowered to 1 so the driver takes the staged (bulk) path.
func testBindingArray(t *testing.T, bulk bool) {
	tz := time.Now()
	createDSN(PSTLocation)
	intArray := []int{1, 2, 3}
	fltArray := []float64{0.1, 2.34, 5.678}
	boolArray := []bool{true, false, true}
	strArray := []string{"test1", "test2", "test3"}
	byteArray := [][]byte{{0x01, 0x02, 0x03}, {0x04, 0x05, 0x06}, {0x07, 0x08, 0x09}}
	now := time.Now()
	loc, err := time.LoadLocation(PSTLocation)
	if err != nil {
		t.Error(err)
	}
	ntzArray := []time.Time{now, now.Add(1), now.Add(2)}
	ltzArray := []time.Time{now.Add(3).In(loc), now.Add(4).In(loc), now.Add(5).In(loc)}
	tzArray := []time.Time{tz.Add(6).In(loc), tz.Add(7).In(loc), tz.Add(8).In(loc)}
	dtArray := []time.Time{now.Add(9), now.Add(10), now.Add(11)}
	tmArray := []time.Time{now.Add(12), now.Add(13), now.Add(14)}
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec(createTableSQL)
		defer dbt.mustExec(deleteTableSQL)
		if bulk {
			// Force the staged binding path regardless of array size.
			if _, err := dbt.exec("ALTER SESSION SET CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = 1"); err != nil {
				t.Error(err)
			}
		}
		dbt.mustExec(insertSQL, mustArray(&intArray), mustArray(&fltArray),
			mustArray(&boolArray), mustArray(&strArray), mustArray(&byteArray),
			mustArray(&ntzArray, TimestampNTZType), mustArray(&ltzArray, TimestampLTZType),
			mustArray(&tzArray, TimestampTZType), mustArray(&dtArray, DateType),
			mustArray(&tmArray, TimeType))
		rows := dbt.mustQuery(selectAllSQL)
		defer
func() {
			assertNilF(t, rows.Close())
		}()
		var v0 int
		var v1 float64
		var v2 bool
		var v3 string
		var v4 []byte
		var v5, v6, v7, v8, v9 time.Time
		cnt := 0
		for i := 0; rows.Next(); i++ {
			if err := rows.Scan(&v0, &v1, &v2, &v3, &v4, &v5, &v6, &v7, &v8, &v9); err != nil {
				t.Fatal(err)
			}
			if v0 != intArray[i] {
				t.Fatalf("failed to fetch. expected %v, got: %v", intArray[i], v0)
			}
			if v1 != fltArray[i] {
				t.Fatalf("failed to fetch. expected %v, got: %v", fltArray[i], v1)
			}
			if v2 != boolArray[i] {
				t.Fatalf("failed to fetch. expected %v, got: %v", boolArray[i], v2)
			}
			if v3 != strArray[i] {
				t.Fatalf("failed to fetch. expected %v, got: %v", strArray[i], v3)
			}
			if !bytes.Equal(v4, byteArray[i]) {
				t.Fatalf("failed to fetch. expected %v, got: %v", byteArray[i], v4)
			}
			// Timestamp columns compare by UnixNano.
			if v5.UnixNano() != ntzArray[i].UnixNano() {
				t.Fatalf("failed to fetch. expected %v, got: %v", ntzArray[i], v5)
			}
			if v6.UnixNano() != ltzArray[i].UnixNano() {
				t.Fatalf("failed to fetch. expected %v, got: %v", ltzArray[i], v6)
			}
			if v7.UnixNano() != tzArray[i].UnixNano() {
				t.Fatalf("failed to fetch. expected %v, got: %v", tzArray[i], v7)
			}
			// DATE keeps only the calendar day; TIME only the wall clock.
			if v8.Year() != dtArray[i].Year() || v8.Month() != dtArray[i].Month() || v8.Day() != dtArray[i].Day() {
				t.Fatalf("failed to fetch. expected %v, got: %v", dtArray[i], v8)
			}
			if v9.Hour() != tmArray[i].Hour() || v9.Minute() != tmArray[i].Minute() || v9.Second() != tmArray[i].Second() {
				t.Fatalf("failed to fetch. expected %v, got: %v", tmArray[i], v9)
			}
			cnt++
		}
		if cnt != len(intArray) {
			t.Fatal("failed to query")
		}
	})
	createDSN("UTC")
}

// TestBulkArrayBinding inserts 100k rows of mixed types through array binding
// and verifies every row round-trips, ordered by the integer column.
func TestBulkArrayBinding(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec(fmt.Sprintf("create or replace table %v (c1 integer, c2 string, c3 timestamp_ltz, c4 timestamp_tz, c5 timestamp_ntz, c6 date, c7 time, c8 binary)", dbname))
		now := time.Now()
		someTime := time.Date(1, time.January, 1, 12, 34, 56, 123456789, time.UTC)
		someDate := time.Date(2024, time.March, 18, 0, 0, 0, 0, time.UTC)
		someBinary := []byte{0x01, 0x02, 0x03}
		numRows := 100000
		intArr := make([]int, numRows)
		strArr := make([]string, numRows)
		ltzArr := make([]time.Time, numRows)
		tzArr := make([]time.Time, numRows)
		ntzArr := make([]time.Time, numRows)
		dateArr := make([]time.Time, numRows)
		timeArr := make([]time.Time, numRows)
		binArr := make([][]byte, numRows)
		for i := range numRows {
			intArr[i] = i
			strArr[i] = "test" + strconv.Itoa(i)
			ltzArr[i] = now
			tzArr[i] = now.Add(time.Hour).UTC()
			ntzArr[i] = now.Add(2 * time.Hour)
			dateArr[i] = someDate
			timeArr[i] = someTime
			binArr[i] = someBinary
		}
		dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?)", dbname),
			mustArray(&intArr), mustArray(&strArr), mustArray(&ltzArr, TimestampLTZType),
			mustArray(&tzArr, TimestampTZType), mustArray(&ntzArr, TimestampNTZType),
			mustArray(&dateArr, DateType), mustArray(&timeArr, TimeType), mustArray(&binArr))
		rows := dbt.mustQuery("select * from " + dbname + " order by c1")
		defer func() {
			assertNilF(t, rows.Close())
		}()
		cnt := 0
		var i int
		var s string
		var ltz, tz, ntz, date, tt time.Time
		var b []byte
		for rows.Next() {
			if err := rows.Scan(&i, &s, &ltz, &tz, &ntz, &date, &tt, &b); err != nil {
				t.Fatal(err)
			}
			assertEqualE(t, i, cnt)
			assertEqualE(t, "test"+strconv.Itoa(cnt), s)
			// Compare in UTC so the session/display timezone is irrelevant.
			assertEqualE(t, ltz.UTC(), now.UTC())
			assertEqualE(t, tz.UTC(), now.Add(time.Hour).UTC())
			assertEqualE(t, ntz.UTC(), now.Add(2*time.Hour).UTC())
			assertEqualE(t, date, someDate)
			assertEqualE(t, tt, someTime)
assertBytesEqualE(t, b, someBinary) cnt++ } if cnt != numRows { t.Fatalf("expected %v rows, got %v", numRows, cnt) } }) } func TestSupportedDecfloatBind(t *testing.T) { t.Run("dont panic on nil UUID", func(t *testing.T) { defer func() { if r := recover(); r != nil { t.Errorf("expected not to panic, but did panic") } }() var nilUUID *UUID nv := driver.NamedValue{Value: nilUUID} shouldBind := supportedDecfloatBind(&nv) // should not panic and return false assertFalseE(t, shouldBind, "expected not to support binding nil *UUID") }) t.Run("dont panic on nil pointer array", func(t *testing.T) { defer func() { if r := recover(); r != nil { t.Errorf("expected not to panic, but did panic") } }() var nilArray *[]string nv := driver.NamedValue{Value: nilArray} shouldBind := supportedDecfloatBind(&nv) // should not panic and return false assertFalseE(t, shouldBind, "expected not to support binding nil []string") }) t.Run("dont panic on nil pointer", func(t *testing.T) { defer func() { if r := recover(); r != nil { t.Errorf("expected not to panic, but did panic") } }() var nilTime *time.Time nv := driver.NamedValue{Value: nilTime} shouldBind := supportedDecfloatBind(&nv) // should not panic and return false assertFalseE(t, shouldBind, "expected not to support binding nil *time.Time") }) t.Run("dont panic on nil *big.Float", func(t *testing.T) { defer func() { if r := recover(); r != nil { t.Errorf("expected not to panic, but did panic") } }() var nilBigFloat *big.Float nv := driver.NamedValue{Value: nilBigFloat} shouldBind := supportedDecfloatBind(&nv) // should not panic and return false assertFalseE(t, shouldBind, "expected not to support binding nil *big.Float") }) t.Run("Is Valid for big.Float", func(t *testing.T) { val := big.NewFloat(123.456) nv := driver.NamedValue{Value: val} shouldBind := supportedDecfloatBind(&nv) assertTrueE(t, shouldBind, "expected to support binding big.Float") }) t.Run("Is Not Valid for other types", func(t *testing.T) { val := 123.456 // float64 
nv := driver.NamedValue{Value: val}
		shouldBind := supportedDecfloatBind(&nv)
		assertFalseE(t, shouldBind, "expected not to support binding float64")
	})
}

// TestBindingsWithSameValue inserts the same generated data through three
// paths — []any array binding, typed-slice array binding, and staged binding
// (CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = 1) — then asserts all three tables
// contain identical rows.
func TestBindingsWithSameValue(t *testing.T) {
	arrayInsertTable := "test_array_binding_insert"
	stageBindingTable := "test_stage_binding_insert"
	interfaceArrayTable := "test_interface_binding_insert"
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec(fmt.Sprintf("create or replace table %v (c1 integer, c2 string, c3 timestamp_ltz, c4 timestamp_tz, c5 timestamp_ntz, c6 date, c7 time, c9 boolean, c10 double)", arrayInsertTable))
		dbt.mustExec(fmt.Sprintf("create or replace table %v (c1 integer, c2 string, c3 timestamp_ltz, c4 timestamp_tz, c5 timestamp_ntz, c6 date, c7 time, c9 boolean, c10 double)", stageBindingTable))
		dbt.mustExec(fmt.Sprintf("create or replace table %v (c1 integer, c2 string, c3 timestamp_ltz, c4 timestamp_tz, c5 timestamp_ntz, c6 date, c7 time, c9 boolean, c10 double)", interfaceArrayTable))
		defer func() {
			dbt.mustExec(fmt.Sprintf("drop table if exists %v", arrayInsertTable))
			dbt.mustExec(fmt.Sprintf("drop table if exists %v", stageBindingTable))
			dbt.mustExec(fmt.Sprintf("drop table if exists %v", interfaceArrayTable))
		}()
		numRows := 5
		intArr := make([]int, numRows)
		strArr := make([]string, numRows)
		timeArr := make([]time.Time, numRows)
		boolArr := make([]bool, numRows)
		doubleArr := make([]float64, numRows)
		intAnyArr := make([]any, numRows)
		strAnyArr := make([]any, numRows)
		timeAnyArr := make([]any, numRows)
		// NOTE(review): boolAnyArr/doubleAnyArr are typed slices (not []any)
		// and are filled below but never bound — only boolArr/doubleArr are
		// passed to the inserts.
		boolAnyArr := make([]bool, numRows)
		doubleAnyArr := make([]float64, numRows)
		for i := range numRows {
			intArr[i] = i
			intAnyArr[i] = i
			double := rand.Float64()
			doubleArr[i] = double
			doubleAnyArr[i] = double
			strArr[i] = "test" + strconv.Itoa(i)
			strAnyArr[i] = "test" + strconv.Itoa(i)
			b := getRandomBool()
			boolArr[i] = b
			boolAnyArr[i] = b
			date := getRandomDate()
			timeArr[i] = date
			timeAnyArr[i] = date
		}
		dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?, ?)", interfaceArrayTable),
			mustArray(&intAnyArr), mustArray(&strAnyArr), mustArray(&timeAnyArr, TimestampLTZType),
			mustArray(&timeAnyArr, TimestampTZType), mustArray(&timeAnyArr, TimestampNTZType),
			mustArray(&timeAnyArr, DateType), mustArray(&timeAnyArr, TimeType),
			mustArray(&boolArr), mustArray(&doubleArr))
		dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?, ?)", arrayInsertTable),
			mustArray(&intArr), mustArray(&strArr), mustArray(&timeArr, TimestampLTZType),
			mustArray(&timeArr, TimestampTZType), mustArray(&timeArr, TimestampNTZType),
			mustArray(&timeArr, DateType), mustArray(&timeArr, TimeType),
			mustArray(&boolArr), mustArray(&doubleArr))
		// Third insert goes through the staged path.
		dbt.mustExec("ALTER SESSION SET CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = 1")
		dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?, ?)", stageBindingTable),
			mustArray(&intArr), mustArray(&strArr), mustArray(&timeArr, TimestampLTZType),
			mustArray(&timeArr, TimestampTZType), mustArray(&timeArr, TimestampNTZType),
			mustArray(&timeArr, DateType), mustArray(&timeArr, TimeType),
			mustArray(&boolArr), mustArray(&doubleArr))
		insertRows := dbt.mustQuery("select * from " + arrayInsertTable + " order by c1")
		bindingRows := dbt.mustQuery("select * from " + stageBindingTable + " order by c1")
		interfaceRows := dbt.mustQuery("select * from " + interfaceArrayTable + " order by c1")
		defer func() {
			assertNilF(t, insertRows.Close())
			assertNilF(t, bindingRows.Close())
			assertNilF(t, interfaceRows.Close())
		}()
		// Prefixes: none = array-insert table, b = staged table, i = []any table.
		var i, bi, ii int
		var s, bs, is string
		var ltz, bltz, iltz, itz, btz, tz, intz, ntz, bntz, iDate, date, bDate, itt, tt, btt time.Time
		var b, bb, ib bool
		var d, bd, id float64
		timeFormat := "15:04:05"
		for k := range numRows {
			assertTrueF(t, insertRows.Next())
			assertNilF(t, insertRows.Scan(&i, &s, &ltz, &tz, &ntz, &date, &tt, &b, &d))
			assertTrueF(t, bindingRows.Next())
			assertNilF(t, bindingRows.Scan(&bi, &bs, &bltz, &btz, &bntz, &bDate, &btt, &bb, &bd))
			assertTrueF(t, interfaceRows.Next())
			assertNilF(t, interfaceRows.Scan(&ii, &is, &iltz,
				&itz, &intz, &iDate, &itt, &ib, &id))
			assertEqualE(t, k, i)
			assertEqualE(t, k, bi)
			assertEqualE(t, k, ii)
			assertEqualE(t, "test"+strconv.Itoa(k), s)
			assertEqualE(t, "test"+strconv.Itoa(k), bs)
			assertEqualE(t, "test"+strconv.Itoa(k), is)
			utcTime := timeArr[k].UTC()
			assertEqualE(t, ltz.UTC(), utcTime)
			assertEqualE(t, bltz.UTC(), utcTime)
			assertEqualE(t, iltz.UTC(), utcTime)
			assertEqualE(t, tz.UTC(), utcTime)
			assertEqualE(t, btz.UTC(), utcTime)
			assertEqualE(t, itz.UTC(), utcTime)
			assertEqualE(t, ntz.UTC(), utcTime)
			assertEqualE(t, bntz.UTC(), utcTime)
			assertEqualE(t, intz.UTC(), utcTime)
			testingDate := timeArr[k].Truncate(24 * time.Hour)
			assertEqualE(t, date, testingDate)
			assertEqualE(t, bDate, testingDate)
			assertEqualE(t, iDate, testingDate)
			testingTime := timeArr[k].Format(timeFormat)
			assertEqualE(t, tt.Format(timeFormat), testingTime)
			assertEqualE(t, btt.Format(timeFormat), testingTime)
			assertEqualE(t, itt.Format(timeFormat), testingTime)
			assertEqualE(t, b, boolArr[k])
			assertEqualE(t, bb, boolArr[k])
			assertEqualE(t, ib, boolArr[k])
			assertEqualE(t, d, doubleArr[k])
			assertEqualE(t, bd, doubleArr[k])
			assertEqualE(t, id, doubleArr[k])
		}
	})
}

// TestBulkArrayBindingTimeWithPrecision round-trips TIME columns declared
// with precisions 0/3/6/9 through bulk array binding.
func TestBulkArrayBindingTimeWithPrecision(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec(fmt.Sprintf("create or replace table %v (s time(0), ms time(3), us time(6), ns time(9))", dbname))
		someTimeWithSeconds := time.Date(1, time.January, 1, 1, 1, 1, 0, time.UTC)
		someTimeWithMilliseconds := time.Date(1, time.January, 1, 2, 2, 2, 123000000, time.UTC)
		someTimeWithMicroseconds := time.Date(1, time.January, 1, 3, 3, 3, 123456000, time.UTC)
		someTimeWithNanoseconds := time.Date(1, time.January, 1, 4, 4, 4, 123456789, time.UTC)
		numRows := 100000
		secondsArr := make([]time.Time, numRows)
		millisecondsArr := make([]time.Time, numRows)
		microsecondsArr := make([]time.Time, numRows)
		nanosecondsArr := make([]time.Time, numRows)
		for i := range numRows {
			secondsArr[i] = someTimeWithSeconds
			millisecondsArr[i] =
someTimeWithMilliseconds microsecondsArr[i] = someTimeWithMicroseconds nanosecondsArr[i] = someTimeWithNanoseconds } dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?)", dbname), mustArray(&secondsArr, TimeType), mustArray(&millisecondsArr, TimeType), mustArray(µsecondsArr, TimeType), mustArray(&nanosecondsArr, TimeType)) rows := dbt.mustQuery("select * from " + dbname) defer func() { assertNilF(t, rows.Close()) }() cnt := 0 var s, ms, us, ns time.Time for rows.Next() { if err := rows.Scan(&s, &ms, &us, &ns); err != nil { t.Fatal(err) } assertEqualE(t, s, someTimeWithSeconds) assertEqualE(t, ms, someTimeWithMilliseconds) assertEqualE(t, us, someTimeWithMicroseconds) assertEqualE(t, ns, someTimeWithNanoseconds) cnt++ } if cnt != numRows { t.Fatalf("expected %v rows, got %v", numRows, cnt) } }) } func TestBulkArrayMultiPartBinding(t *testing.T) { rowCount := 1000000 // large enough to be partitioned into multiple files randomIter := rand.Intn(3) + 2 randomStrings := make([]string, rowCount) str := randomString(30) for i := range rowCount { randomStrings[i] = str } tempTableName := fmt.Sprintf("test_table_%v", randomString(5)) ctx := context.Background() runDBTest(t, func(dbt *DBTest) { dbt.mustExec(fmt.Sprintf("CREATE TABLE %s (C VARCHAR(64) NOT NULL)", tempTableName)) defer dbt.mustExec("drop table " + tempTableName) for range randomIter { dbt.mustExecContext(ctx, fmt.Sprintf("INSERT INTO %s VALUES (?)", tempTableName), mustArray(&randomStrings)) rows := dbt.mustQuery("select count(*) from " + tempTableName) defer func() { assertNilF(t, rows.Close()) }() if rows.Next() { var count int if err := rows.Scan(&count); err != nil { t.Error(err) } } } rows := dbt.mustQuery("select count(*) from " + tempTableName) defer func() { assertNilF(t, rows.Close()) }() if rows.Next() { var count int if err := rows.Scan(&count); err != nil { t.Error(err) } if count != randomIter*rowCount { t.Errorf("expected %v rows, got %v rows intead", randomIter*rowCount, count) } } }) } 
func TestBulkArrayMultiPartBindingInt(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.mustExec("create or replace table binding_test (c1 integer)") startNum := 1000000 endNum := 3000000 numRows := endNum - startNum intArr := make([]int, numRows) for i := startNum; i < endNum; i++ { intArr[i-startNum] = i } _, err := dbt.exec("insert into binding_test values (?)", mustArray(&intArr)) if err != nil { t.Errorf("Should have succeeded to insert. err: %v", err) } rows := dbt.mustQuery("select * from binding_test order by c1") defer func() { assertNilF(t, rows.Close()) }() cnt := startNum var i int for rows.Next() { if err := rows.Scan(&i); err != nil { t.Fatal(err) } if i != cnt { t.Errorf("expected: %v, got: %v", cnt, i) } cnt++ } if cnt != endNum { t.Fatalf("expected %v rows, got %v", numRows, cnt-startNum) } dbt.mustExec("DROP TABLE binding_test") }) } func TestBulkArrayMultiPartBindingWithNull(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.mustExec("create or replace table binding_test (c1 integer, c2 string)") startNum := 1000000 endNum := 2000000 numRows := endNum - startNum // Define the integer and string arrays intArr := make([]any, numRows) stringArr := make([]any, numRows) for i := startNum; i < endNum; i++ { intArr[i-startNum] = i stringArr[i-startNum] = fmt.Sprint(i) } // Set some of the rows to NULL intArr[numRows-1] = nil intArr[numRows-2] = nil intArr[numRows-3] = nil stringArr[1] = nil stringArr[2] = nil stringArr[3] = nil _, err := dbt.exec("insert into binding_test values (?, ?)", mustArray(&intArr), mustArray(&stringArr)) if err != nil { t.Errorf("Should have succeeded to insert. 
err: %v", err) } rows := dbt.mustQuery("select * from binding_test order by c1,c2") defer func() { assertNilF(t, rows.Close()) }() cnt := startNum var i sql.NullInt32 var s sql.NullString for rows.Next() { if err := rows.Scan(&i, &s); err != nil { t.Fatal(err) } // Verify integer column c1 if i.Valid { if int(i.Int32) != intArr[cnt-startNum] { t.Fatalf("expected: %v, got: %v", cnt, int(i.Int32)) } } else if !(cnt == startNum+numRows-1 || cnt == startNum+numRows-2 || cnt == startNum+numRows-3) { t.Fatalf("expected NULL in column c1 at index: %v", cnt-startNum) } // Verify string column c2 if s.Valid { if s.String != stringArr[cnt-startNum] { t.Fatalf("expected: %v, got: %v", cnt, s.String) } } else if !(cnt == startNum+1 || cnt == startNum+2 || cnt == startNum+3) { t.Fatalf("expected NULL in column c2 at index: %v", cnt-startNum) } cnt++ } if cnt != endNum { t.Fatalf("expected %v rows, got %v", numRows, cnt-startNum) } dbt.mustExec("DROP TABLE binding_test") }) } func TestFunctionParameters(t *testing.T) { testcases := []struct { testDesc string paramType string input any nullResult bool }{ {"textAndNullStringResultInNull", "text", sql.NullString{}, true}, {"numberAndNullInt64ResultInNull", "number", sql.NullInt64{}, true}, {"floatAndNullFloat64ResultInNull", "float", sql.NullFloat64{}, true}, {"booleanAndAndNullBoolResultInNull", "boolean", sql.NullBool{}, true}, {"dateAndTypedNullTimeResultInNull", "date", TypedNullTime{sql.NullTime{}, DateType}, true}, {"datetimeAndTypedNullTimeResultInNull", "datetime", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, {"timeAndTypedNullTimeResultInNull", "time", TypedNullTime{sql.NullTime{}, TimeType}, true}, {"timestampAndTypedNullTimeResultInNull", "timestamp", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, {"timestamp_ntzAndTypedNullTimeResultInNull", "timestamp_ntz", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, {"timestamp_ltzAndTypedNullTimeResultInNull", "timestamp_ltz", 
TypedNullTime{sql.NullTime{}, TimestampLTZType}, true}, {"timestamp_tzAndTypedNullTimeResultInNull", "timestamp_tz", TypedNullTime{sql.NullTime{}, TimestampTZType}, true}, {"textAndStringResultInNotNull", "text", "string", false}, {"numberAndIntegerResultInNotNull", "number", 123, false}, {"floatAndFloatResultInNotNull", "float", 123.01, false}, {"booleanAndBooleanResultInNotNull", "boolean", true, false}, {"dateAndTimeResultInNotNull", "date", time.Now(), false}, {"datetimeAndTimeResultInNotNull", "datetime", time.Now(), false}, {"timeAndTimeResultInNotNull", "time", time.Now(), false}, {"timestampAndTimeResultInNotNull", "timestamp", time.Now(), false}, {"timestamp_ntzAndTimeResultInNotNull", "timestamp_ntz", time.Now(), false}, {"timestamp_ltzAndTimeResultInNotNull", "timestamp_ltz", time.Now(), false}, {"timestamp_tzAndTimeResultInNotNull", "timestamp_tz", time.Now(), false}, } runDBTest(t, func(dbt *DBTest) { _, err := dbt.exec("ALTER SESSION SET BIND_NULL_VALUE_USE_NULL_DATATYPE=false") if err != nil { log.Println(err) } for _, tc := range testcases { t.Run(tc.testDesc, func(t *testing.T) { query := fmt.Sprintf(` CREATE OR REPLACE FUNCTION NULLPARAMFUNCTION("param1" %v) RETURNS TABLE("r1" %v) LANGUAGE SQL AS 'select param1';`, tc.paramType, tc.paramType) dbt.mustExec(query) rows, err := dbt.query("select * from table(NULLPARAMFUNCTION(?))", tc.input) if err != nil { t.Fatal(err) } defer func() { assertNilF(t, rows.Close()) }() if rows.Err() != nil { t.Fatal(err) } if !rows.Next() { t.Fatal("no rows fetched") } var r1 any err = rows.Scan(&r1) if err != nil { t.Fatal(err) } if tc.nullResult && r1 != nil { t.Fatalf("the result for %v is of type %v but should be null", tc.paramType, reflect.TypeOf(r1)) } if !tc.nullResult && r1 == nil { t.Fatalf("the result for %v should not be null", tc.paramType) } }) } }) } // TestVariousBindingModes tests 24 parameter types × 3 binding modes. 
// Subtests share a hardcoded table name (BINDING_MODES) via CREATE OR REPLACE,
// so they CANNOT run in parallel — concurrent subtests would overwrite each
// other's tables. Making this parallel-safe would require unique table names
// per subtest.
func TestVariousBindingModes(t *testing.T) {
	testcases := []struct {
		testDesc  string
		paramType string
		input     any
		isNil     bool
	}{
		{"textAndString", "text", "string", false},
		{"numberAndInteger", "number", 123, false},
		{"floatAndFloat", "float", 123.01, false},
		{"booleanAndBoolean", "boolean", true, false},
		{"dateAndTime", "date", time.Now().Truncate(24 * time.Hour), false},
		{"datetimeAndTime", "datetime", time.Now(), false},
		{"timeAndTime", "time", "12:34:56", false},
		{"timestampAndTime", "timestamp", time.Now(), false},
		{"timestamp_ntzAndTime", "timestamp_ntz", time.Now(), false},
		{"timestamp_ltzAndTime", "timestamp_ltz", time.Now(), false},
		{"timestamp_tzAndTime", "timestamp_tz", time.Now(), false},
		{"textAndNullString", "text", sql.NullString{}, true},
		{"numberAndNullInt64", "number", sql.NullInt64{}, true},
		{"floatAndNullFloat64", "float", sql.NullFloat64{}, true},
		{"booleanAndAndNullBool", "boolean", sql.NullBool{}, true},
		{"dateAndTypedNullTime", "date", TypedNullTime{sql.NullTime{}, DateType}, true},
		{"datetimeAndTypedNullTime", "datetime", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true},
		{"timeAndTypedNullTime", "time", TypedNullTime{sql.NullTime{}, TimeType}, true},
		{"timestampAndTypedNullTime", "timestamp", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true},
		{"timestamp_ntzAndTypedNullTime", "timestamp_ntz", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true},
		{"timestamp_ltzAndTypedNullTime", "timestamp_ltz", TypedNullTime{sql.NullTime{}, TimestampLTZType}, true},
		{"timestamp_tzAndTypedNullTime", "timestamp_tz", TypedNullTime{sql.NullTime{}, TimestampTZType}, true},
		{"LOBSmallSize", fmt.Sprintf("varchar(%v)", smallSize), fastStringGeneration(smallSize), false},
		{"LOBLargeSize",
			fmt.Sprintf("varchar(%v)", largeSize), fastStringGeneration(largeSize), false},
	}
	// One entry per placeholder style: positional "?", ordinal ":1", and
	// named ":param" (the named form wraps the value in sql.Named).
	// NOTE(review): the query field is declared but never populated or read.
	bindingModes := []struct {
		param     string
		query     string
		transform func(any) any
	}{
		{
			param:     "?",
			transform: func(v any) any { return v },
		},
		{
			param:     ":1",
			transform: func(v any) any { return v },
		},
		{
			param:     ":param",
			transform: func(v any) any { return sql.Named("param", v) },
		},
	}
	runDBTest(t, func(dbt *DBTest) {
		for _, tc := range testcases {
			// TODO SNOW-1264687
			if strings.Contains(tc.testDesc, "LOB") {
				skipOnJenkins(t, "skipped until SNOW-1264687 is fixed")
			}
			for _, bindingMode := range bindingModes {
				t.Run(tc.testDesc+" "+bindingMode.param, func(t *testing.T) {
					query := fmt.Sprintf(`CREATE OR REPLACE TABLE BINDING_MODES(param1 %v)`, tc.paramType)
					dbt.mustExec(query)
					if _, err := dbt.exec(fmt.Sprintf("INSERT INTO BINDING_MODES VALUES (%v)", bindingMode.param), bindingMode.transform(tc.input)); err != nil {
						t.Fatal(err)
					}
					// NULL inputs can't be matched with "=", so select IS NULL.
					if tc.isNil {
						query = "SELECT * FROM BINDING_MODES WHERE param1 IS NULL"
					} else {
						query = fmt.Sprintf("SELECT * FROM BINDING_MODES WHERE param1 = %v", bindingMode.param)
					}
					rows, err := dbt.query(query, bindingMode.transform(tc.input))
					if err != nil {
						t.Fatal(err)
					}
					defer func() {
						assertNilF(t, rows.Close())
					}()
					if !rows.Next() {
						t.Fatal("Expected to return a row")
					}
				})
			}
		}
	})
}

// skipMaxLobSizeTestOnGithubActions skips the test when running on GitHub
// Actions, where the max LOB size parameters are unavailable.
func skipMaxLobSizeTestOnGithubActions(t *testing.T) {
	if runningOnGithubAction() {
		t.Skip("Max Lob Size parameters are not available on GH Actions")
	}
}

func TestLOBRetrievalWithArrow(t *testing.T) {
	testLOBRetrieval(t, true)
}

func TestLOBRetrievalWithJSON(t *testing.T) {
	testLOBRetrieval(t, false)
}

// testLOBRetrieval fetches randstr results of small and large sizes in the
// chosen result format (Arrow or JSON) and checks the retrieved length.
func testLOBRetrieval(t *testing.T, useArrowFormat bool) {
	runDBTest(t, func(dbt *DBTest) {
		if useArrowFormat {
			dbt.mustExec(forceARROW)
		} else {
			dbt.mustExec(forceJSON)
		}
		var res string
		testSizes := [2]int{smallSize, largeSize}
		for _, testSize := range testSizes {
			t.Run(fmt.Sprintf("testLOB_%v_useArrowFormat=%v", strconv.Itoa(testSize), strconv.FormatBool(useArrowFormat)), func(t *testing.T)
{
				rows, err := dbt.query(fmt.Sprintf("SELECT randstr(%v, 124)", testSize))
				assertNilF(t, err)
				defer func() {
					assertNilF(t, rows.Close())
				}()
				assertTrueF(t, rows.Next(), fmt.Sprintf("no rows returned for the LOB size %v", testSize))
				// retrieve the result
				err = rows.Scan(&res)
				assertNilF(t, err)
				// verify the length of the result
				assertEqualF(t, len(res), testSize)
			})
		}
	})
}

// TestMaxLobSize checks that a 20MB randstr fails while large varchar/binary
// support is disabled and succeeds once it is enabled.
func TestMaxLobSize(t *testing.T) {
	skipMaxLobSizeTestOnGithubActions(t)
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec(enableFeatureMaxLOBSize)
		defer dbt.mustExec(unsetLargeVarcharAndBinary)
		t.Run("Max Lob Size disabled", func(t *testing.T) {
			dbt.mustExec(disableLargeVarcharAndBinary)
			_, err := dbt.query("select randstr(20000000, random())")
			assertNotNilF(t, err)
			assertStringContainsF(t, err.Error(), "Actual length 20000000 exceeds supported length")
		})
		t.Run("Max Lob Size enabled", func(t *testing.T) {
			dbt.mustExec(enableLargeVarcharAndBinary)
			rows, err := dbt.query("select randstr(20000000, random())")
			assertNilF(t, err)
			defer func() {
				assertNilF(t, rows.Close())
			}()
		})
	})
}

func TestInsertLobDataWithLiteralArrow(t *testing.T) {
	// TODO SNOW-1264687
	skipOnJenkins(t, "skipped until SNOW-1264687 is fixed")
	testInsertLOBData(t, true, true)
}

func TestInsertLobDataWithLiteralJSON(t *testing.T) {
	// TODO SNOW-1264687
	skipOnJenkins(t, "skipped until SNOW-1264687 is fixed")
	testInsertLOBData(t, false, true)
}

func TestInsertLobDataWithBindingsArrow(t *testing.T) {
	// TODO SNOW-1264687
	skipOnJenkins(t, "skipped until SNOW-1264687 is fixed")
	testInsertLOBData(t, true, false)
}

func TestInsertLobDataWithBindingsJSON(t *testing.T) {
	// TODO SNOW-1264687
	skipOnJenkins(t, "skipped until SNOW-1264687 is fixed")
	testInsertLOBData(t, false, false)
}

// testInsertLOBData inserts LOB-sized varchar data (as a literal or via
// bindings, in Arrow or JSON result format) and verifies the data plus the
// column metadata (name, scan type and declared length) on retrieval.
func testInsertLOBData(t *testing.T, useArrowFormat bool, isLiteral bool) {
	expectedNumCols := 3
	columnMeta := []struct {
		columnName string
		columnType reflect.Type
	}{
		{"C1", reflect.TypeFor[string]()},
		{"C2", reflect.TypeFor[string]()},
		{"C3", reflect.TypeFor[string]()},
	}
	testCases := []struct {
		testDesc string
		c1Size   int
		c2Size   int
		c3Size   int
	}{
		{"testLOBInsertSmallSize", smallSize, smallSize, lobRandomRange},
		{"testLOBInsertLargeSize", largeSize, smallSize, lobRandomRange},
	}
	runDBTest(t, func(dbt *DBTest) {
		var c1 string
		var c2 string
		var c3 int
		dbt.mustExec(enableFeatureMaxLOBSize)
		if useArrowFormat {
			dbt.mustExec(forceARROW)
		} else {
			dbt.mustExec(forceJSON)
		}
		for _, tc := range testCases {
			t.Run(tc.testDesc, func(t *testing.T) {
				c1Data := fastStringGeneration(tc.c1Size)
				c2Data := fastStringGeneration(tc.c2Size)
				c3Data := rand.Intn(tc.c3Size)
				dbt.mustExec(fmt.Sprintf("CREATE OR REPLACE TABLE lob_test_table (c1 varchar(%v), c2 varchar(%v), c3 int)", tc.c1Size, tc.c2Size))
				if isLiteral {
					dbt.mustExec(fmt.Sprintf("INSERT INTO lob_test_table VALUES ('%s', '%s', %v)", c1Data, c2Data, c3Data))
				} else {
					dbt.mustExec("INSERT INTO lob_test_table VALUES (?, ?, ?)", c1Data, c2Data, c3Data)
				}
				rows, err := dbt.query("SELECT * FROM lob_test_table")
				assertNilF(t, err)
				defer func() {
					assertNilF(t, rows.Close())
				}()
				assertTrueF(t, rows.Next(), fmt.Sprintf("%s: no rows returned", tc.testDesc))
				err = rows.Scan(&c1, &c2, &c3)
				assertNilF(t, err)
				// check the number of columns
				columnTypes, err := rows.ColumnTypes()
				assertNilF(t, err)
				assertEqualF(t, len(columnTypes), expectedNumCols)
				// verify the column metadata: name, type and length
				for colIdx := range expectedNumCols {
					colName := columnTypes[colIdx].Name()
					assertEqualF(t, colName, columnMeta[colIdx].columnName)
					colType := columnTypes[colIdx].ScanType()
					assertEqualF(t, colType, columnMeta[colIdx].columnType)
					colLength, ok := columnTypes[colIdx].Length()
					switch colIdx {
					case 0:
						assertTrueF(t, ok)
						assertEqualF(t, colLength, int64(tc.c1Size))
						// verify the data
						assertEqualF(t, c1, c1Data)
					case 1:
						assertTrueF(t, ok)
						assertEqualF(t, colLength, int64(tc.c2Size))
						// verify the data
						assertEqualF(t, c2, c2Data)
					case 2:
						// the int column reports no length
						assertFalseF(t, ok)
						// verify the data
						assertEqualF(t, c3,
c3Data) } } }) dbt.mustExec("DROP TABLE IF EXISTS lob_test_table") } dbt.mustExec(unsetFeatureMaxLOBSize) }) } func fastStringGeneration(size int) string { if size <= 0 { return "" } pattern := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" patternLen := len(pattern) if size <= patternLen { return pattern[:size] } fullRepeats := size / patternLen remainder := size % patternLen var result strings.Builder result.Grow(size) fullPattern := strings.Repeat(pattern, fullRepeats) result.WriteString(fullPattern) if remainder > 0 { result.WriteString(pattern[:remainder]) } return result.String() } func getRandomDate() time.Time { return time.Date(rand.Intn(1582)+1, time.January, rand.Intn(40), rand.Intn(40), rand.Intn(40), rand.Intn(40), rand.Intn(40), time.UTC) } func getRandomBool() bool { return rand.Int63n(time.Now().Unix())%2 == 0 } ================================================ FILE: chunk.go ================================================ package gosnowflake import ( "bytes" "fmt" "io" "unicode" "unicode/utf16" "unicode/utf8" ) const ( defaultChunkBufferSize int64 = 8 << 10 // 8k defaultStringBufferSize int64 = 512 ) type largeChunkDecoder struct { r io.Reader rows int // hint for number of rows cells int // hint for number of cells/row rem int // bytes remaining in rbuf ptr int // position in rbuf rbuf []byte sbuf *bytes.Buffer // buffer for decodeString ioError error } func decodeLargeChunk(r io.Reader, rowCount int, cellCount int) ([][]*string, error) { logger.Info("custom JSON Decoder") lcd := largeChunkDecoder{ r, rowCount, cellCount, 0, 0, make([]byte, defaultChunkBufferSize), bytes.NewBuffer(make([]byte, defaultStringBufferSize)), nil, } rows, err := lcd.decode() if lcd.ioError != nil && lcd.ioError != io.EOF { return nil, lcd.ioError } else if err != nil { return nil, err } return rows, nil } func (lcd *largeChunkDecoder) mkError(s string) error { return fmt.Errorf("corrupt chunk: %s", s) } func (lcd *largeChunkDecoder) decode() 
([][]*string, error) { if lcd.nextByteNonWhitespace() != '[' { return nil, lcd.mkError("expected chunk to begin with '['") } rows := make([][]*string, 0, lcd.rows) if lcd.nextByteNonWhitespace() == ']' { return rows, nil // special case of an empty chunk } lcd.rewind(1) OuterLoop: for { row, err := lcd.decodeRow() if err != nil { return nil, err } rows = append(rows, row) switch lcd.nextByteNonWhitespace() { case ',': continue // more elements in the array case ']': return rows, nil // we've scanned the whole chunk default: break OuterLoop } } return nil, lcd.mkError("invalid row boundary") } func (lcd *largeChunkDecoder) decodeRow() ([]*string, error) { if lcd.nextByteNonWhitespace() != '[' { return nil, lcd.mkError("expected row to begin with '['") } row := make([]*string, 0, lcd.cells) if lcd.nextByteNonWhitespace() == ']' { return row, nil // special case of an empty row } lcd.rewind(1) OuterLoop: for { cell, err := lcd.decodeCell() if err != nil { return nil, err } row = append(row, cell) switch lcd.nextByteNonWhitespace() { case ',': continue // more elements in the array case ']': return row, nil // we've scanned the whole row default: break OuterLoop } } return nil, lcd.mkError("invalid cell boundary") } func (lcd *largeChunkDecoder) decodeCell() (*string, error) { c := lcd.nextByteNonWhitespace() switch c { case '"': s, err := lcd.decodeString() return &s, err case 'n': if lcd.nextByte() == 'u' && lcd.nextByte() == 'l' && lcd.nextByte() == 'l' { return nil, nil } } return nil, lcd.mkError("cell begins with unexpected byte") } // TODO we can optimize this further by optimistically searching // the read buffer for the next string. 
// decodeEscaped consumes the remainder of a JSON backslash escape (the
// leading '\' has already been read by decodeString) and appends the
// decoded character(s) to lcd.sbuf. It returns an error for anything
// that is not a recognized single-character escape or a valid \uXXXX
// sequence.
//
// Note: besides the standard JSON escapes, a single-quote escape (\')
// is also accepted here.
func (lcd *largeChunkDecoder) decodeEscaped() error {
	// NOTE if you make changes here, ensure this
	// variable does not escape to the heap
	c := lcd.nextByte()
	switch c {
	case '"', '\\', '/', '\'':
		lcd.sbuf.WriteByte(c)
	case 'b':
		lcd.sbuf.WriteByte('\b')
	case 'f':
		lcd.sbuf.WriteByte('\f')
	case 'n':
		lcd.sbuf.WriteByte('\n')
	case 'r':
		lcd.sbuf.WriteByte('\r')
	case 't':
		lcd.sbuf.WriteByte('\t')
	case 'u':
		rr := lcd.getu4()
		if rr < 0 {
			return lcd.mkError("invalid escape sequence")
		}
		if utf16.IsSurrogate(rr) {
			// A surrogate code unit must be paired with a second
			// \uXXXX escape to form a full code point.
			rr1, size := lcd.getu4WithPrefix()
			if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar {
				// A valid pair; consume.
				lcd.sbuf.WriteRune(dec)
				break
			}
			// Invalid surrogate; fall back to replacement rune and
			// un-read the bytes consumed while probing for the pair.
			lcd.rewind(size)
			rr = unicode.ReplacementChar
		}
		lcd.sbuf.WriteRune(rr)
	default:
		return lcd.mkError("invalid escape sequence: " + string(c))
	}
	return nil
}
If this // is not a valid rune, then we need to roll back to // where we were before we began consuming bytes ptr := lcd.ptr if lcd.nextByte() != '\\' { return -1, lcd.ptr - ptr } if lcd.nextByte() != 'u' { return -1, lcd.ptr - ptr } r := lcd.getu4() return r, lcd.ptr - ptr } func (lcd *largeChunkDecoder) getu4() rune { var r rune for range 4 { c := lcd.nextByte() switch { case '0' <= c && c <= '9': c = c - '0' case 'a' <= c && c <= 'f': c = c - 'a' + 10 case 'A' <= c && c <= 'F': c = c - 'A' + 10 default: return -1 } r = r*16 + rune(c) } return r } func (lcd *largeChunkDecoder) nextByteNonWhitespace() byte { for { c := lcd.nextByte() switch c { case ' ', '\t', '\n', '\r': continue default: return c } } } func (lcd *largeChunkDecoder) rewind(n int) { lcd.ptr -= n lcd.rem += n } func (lcd *largeChunkDecoder) nextByte() byte { if lcd.rem == 0 { if lcd.ioError != nil { return 0 } lcd.ptr = 0 lcd.rem = lcd.fillBuffer(lcd.rbuf) if lcd.rem == 0 { return 0 } } b := lcd.rbuf[lcd.ptr] lcd.ptr++ lcd.rem-- return b } func (lcd *largeChunkDecoder) ensureBytes(n int) { if lcd.rem <= n { rbuf := make([]byte, defaultChunkBufferSize) // NOTE when the buffer reads from the stream, there's no // guarantee that it will actually be filled. As such we // must use (ptr+rem) to compute the end of the slice. 
off := copy(rbuf, lcd.rbuf[lcd.ptr:lcd.ptr+lcd.rem]) add := lcd.fillBuffer(rbuf[off:]) lcd.ptr = 0 lcd.rem += add lcd.rbuf = rbuf } } func (lcd *largeChunkDecoder) fillBuffer(b []byte) int { n, err := lcd.r.Read(b) if err != nil && err != io.EOF { lcd.ioError = err return 0 } else if n <= 0 { lcd.ioError = io.EOF return 0 } return n } ================================================ FILE: chunk_downloader.go ================================================ package gosnowflake import ( "bufio" "compress/gzip" "context" "encoding/json" "errors" "fmt" errors2 "github.com/snowflakedb/gosnowflake/v2/internal/errors" "github.com/snowflakedb/gosnowflake/v2/internal/query" "io" "net/http" "net/url" "strconv" "sync" "time" ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/ipc" "github.com/apache/arrow-go/v18/arrow/memory" ) var ( errNoConnection = errors.New("failed to retrieve connection") ) type chunkDownloader interface { totalUncompressedSize() (acc int64) start() error next() (chunkRowType, error) reset() getChunkMetas() []query.ExecResponseChunk getQueryResultFormat() resultFormat getRowType() []query.ExecResponseRowType setNextChunkDownloader(downloader chunkDownloader) getNextChunkDownloader() chunkDownloader getRawArrowBatches() []*rawArrowBatchData } type snowflakeChunkDownloader struct { sc *snowflakeConn ctx context.Context pool memory.Allocator Total int64 TotalRowIndex int64 CellCount int CurrentChunk []chunkRowType CurrentChunkIndex int CurrentChunkSize int CurrentIndex int ChunkHeader map[string]string ChunkMetas []query.ExecResponseChunk Chunks map[int][]chunkRowType ChunksChan chan int ChunksError chan *chunkError ChunksErrorCounter int ChunksFinalErrors []*chunkError ChunksMutex *sync.Mutex DoneDownloadCond *sync.Cond firstBatchRaw *rawArrowBatchData NextDownloader chunkDownloader Qrmk string QueryResultFormat string rawBatches []*rawArrowBatchData RowSet rowSetType 
// start initializes the downloader state and kicks off chunk fetching.
// Arrow-batch consumers are delegated to startArrowBatches. Otherwise
// the inline first row set is decoded (JSON rows, or the base64-encoded
// first arrow chunk when the result format is arrow), and if the
// response references additional chunks, a bounded pool of download
// workers is started, fed by a channel of chunk indices.
func (scd *snowflakeChunkDownloader) start() error {
	if usesArrowBatches(scd.ctx) && scd.getQueryResultFormat() == arrowFormat {
		return scd.startArrowBatches()
	}
	scd.CurrentChunkSize = len(scd.RowSet.JSON) // cache the size
	scd.CurrentIndex = -1                       // initial chunks idx
	scd.CurrentChunkIndex = -1                  // initial chunk

	// the first row set arrives inline with the query response
	scd.CurrentChunk = make([]chunkRowType, scd.CurrentChunkSize)
	populateJSONRowSet(scd.CurrentChunk, scd.RowSet.JSON)

	if scd.getQueryResultFormat() == arrowFormat && scd.RowSet.RowSetBase64 != "" {
		params, err := scd.getConfigParams()
		if err != nil {
			return fmt.Errorf("getting config params: %w", err)
		}
		// if the rowsetbase64 retrieved from the server is empty, move on to downloading chunks
		loc := getCurrentLocation(params)
		firstArrowChunk, err := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool)
		if err != nil {
			return fmt.Errorf("building first arrow chunk: %w", err)
		}
		higherPrecision := higherPrecisionEnabled(scd.ctx)
		scd.CurrentChunk, err = firstArrowChunk.decodeArrowChunk(scd.ctx, scd.RowSet.RowType, higherPrecision, params)
		scd.CurrentChunkSize = firstArrowChunk.rowCount
		if err != nil {
			return fmt.Errorf("decoding arrow chunk: %w", err)
		}
	}

	// start downloading chunks if exists
	chunkMetaLen := len(scd.ChunkMetas)
	if chunkMetaLen > 0 {
		// worker count comes from the CLIENT_PREFETCH_THREADS session
		// parameter; any unparsable or non-positive value falls back to
		// the default
		chunkDownloadWorkers := defaultMaxChunkDownloadWorkers
		chunkDownloadWorkersStr, ok := scd.sc.syncParams.get(clientPrefetchThreadsKey)
		if ok {
			var err error
			chunkDownloadWorkers, err = strconv.Atoi(*chunkDownloadWorkersStr)
			if err != nil {
				logger.Warnf("invalid value for CLIENT_PREFETCH_THREADS: %v", *chunkDownloadWorkersStr)
				chunkDownloadWorkers = defaultMaxChunkDownloadWorkers
			}
		}
		if chunkDownloadWorkers <= 0 {
			logger.Warnf("invalid value for CLIENT_PREFETCH_THREADS: %v. It should be a positive integer. Defaulting to %v", chunkDownloadWorkers, defaultMaxChunkDownloadWorkers)
			chunkDownloadWorkers = defaultMaxChunkDownloadWorkers
		}
		logger.WithContext(scd.ctx).Debugf("chunkDownloadWorkers: %v", chunkDownloadWorkers)
		logger.WithContext(scd.ctx).Debugf("chunks: %v, total bytes: %d", chunkMetaLen, scd.totalUncompressedSize())
		scd.ChunksMutex = &sync.Mutex{}
		scd.DoneDownloadCond = sync.NewCond(scd.ChunksMutex)
		scd.Chunks = make(map[int][]chunkRowType)
		scd.ChunksChan = make(chan int, chunkMetaLen)
		scd.ChunksError = make(chan *chunkError, chunkDownloadWorkers)
		// enqueue every chunk index up front, then launch at most
		// chunkDownloadWorkers concurrent downloads; next() schedules
		// the rest as chunks are consumed
		for i := range chunkMetaLen {
			chunk := scd.ChunkMetas[i]
			logger.WithContext(scd.ctx).Debugf("Result Format: %v, add chunk to channel ChunksChan: %v, URL: %v, RowCount: %v, UncompressedSize: %v, ChunkResultFormat: %v",
				scd.getQueryResultFormat(), i+1, chunk.URL, chunk.RowCount, chunk.UncompressedSize, scd.QueryResultFormat)
			scd.ChunksChan <- i
		}
		for i := 0; i < intMin(chunkDownloadWorkers, chunkMetaLen); i++ {
			scd.schedule()
		}
	}
	return nil
}
// next returns the next row of the result set, transparently crossing
// chunk boundaries. When the current chunk is exhausted it waits on
// DoneDownloadCond for the next chunk to be downloaded (retrying
// failed downloads via checkErrorRetry), releases the previous chunk
// for GC, and schedules a further download to keep the pipeline full.
// io.EOF is returned once every chunk has been consumed.
func (scd *snowflakeChunkDownloader) next() (chunkRowType, error) {
	for {
		scd.CurrentIndex++
		if scd.CurrentIndex < scd.CurrentChunkSize {
			return scd.CurrentChunk[scd.CurrentIndex], nil
		}
		scd.CurrentChunkIndex++ // next chunk
		scd.CurrentIndex = -1   // reset
		if scd.CurrentChunkIndex >= len(scd.ChunkMetas) {
			break
		}

		scd.ChunksMutex.Lock()
		if scd.CurrentChunkIndex > 0 {
			scd.Chunks[scd.CurrentChunkIndex-1] = nil // detach the previously used chunk
		}
		// block until the download goroutine has stored this chunk
		for scd.Chunks[scd.CurrentChunkIndex] == nil {
			logger.WithContext(scd.ctx).Debugf("waiting for chunk idx: %v/%v",
				scd.CurrentChunkIndex+1, len(scd.ChunkMetas))
			if err := scd.checkErrorRetry(); err != nil {
				scd.ChunksMutex.Unlock()
				return chunkRowType{}, fmt.Errorf("checking for error: %w", err)
			}
			// wait for chunk downloader goroutine to broadcast the event,
			// 1) one chunk download finishes or 2) an error occurs.
			scd.DoneDownloadCond.Wait()
		}
		logger.WithContext(scd.ctx).Debugf("ready: chunk %v", scd.CurrentChunkIndex+1)
		scd.CurrentChunk = scd.Chunks[scd.CurrentChunkIndex]
		scd.ChunksMutex.Unlock()
		scd.CurrentChunkSize = len(scd.CurrentChunk)

		// kick off the next download
		scd.schedule()
	}

	logger.WithContext(scd.ctx).Debugf("no more data")
	if len(scd.ChunkMetas) > 0 {
		close(scd.ChunksError)
		close(scd.ChunksChan)
	}
	return chunkRowType{}, io.EOF
}
func (scd *snowflakeChunkDownloader) releaseRawArrowBatches() { releaseRecords := func(raw *rawArrowBatchData) { if raw == nil || raw.records == nil { return } for _, rec := range *raw.records { rec.Release() } raw.records = nil } releaseRecords(scd.firstBatchRaw) for _, raw := range scd.rawBatches { releaseRecords(raw) } } func (scd *snowflakeChunkDownloader) getConfigParams() (*syncParams, error) { if scd.sc == nil || scd.sc.cfg == nil { return nil, errNoConnection } return &scd.sc.syncParams, nil } func getChunk( ctx context.Context, sc *snowflakeConn, fullURL string, headers map[string]string, timeout time.Duration) ( *http.Response, error, ) { u, err := url.Parse(fullURL) if err != nil { return nil, fmt.Errorf("failed to parse URL: %w", err) } return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout, sc.rest.MaxRetryCount, sc.currentTimeProvider, sc.cfg).execute() } func (scd *snowflakeChunkDownloader) startArrowBatches() error { var loc *time.Location params, err := scd.getConfigParams() if err != nil { return fmt.Errorf("getting config params: %w", err) } loc = getCurrentLocation(params) if scd.RowSet.RowSetBase64 != "" { firstArrowChunk, err := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool) if err != nil { return fmt.Errorf("building first arrow chunk: %w", err) } scd.firstBatchRaw = &rawArrowBatchData{ loc: loc, } if firstArrowChunk.allocator != nil { scd.firstBatchRaw.records, err = firstArrowChunk.decodeArrowBatchRaw() if err != nil { return fmt.Errorf("decoding arrow batch: %w", err) } scd.firstBatchRaw.rowCount = countRawArrowBatchRows(scd.firstBatchRaw.records) } } chunkMetaLen := len(scd.ChunkMetas) scd.rawBatches = make([]*rawArrowBatchData, chunkMetaLen) for i := range scd.rawBatches { scd.rawBatches[i] = &rawArrowBatchData{ loc: loc, rowCount: scd.ChunkMetas[i].RowCount, } scd.CurrentChunkIndex++ } return nil } /* largeResultSetReader is a reader that wraps the large result set with leading and tailing brackets. 
*/ type largeResultSetReader struct { status int body io.Reader } func (r *largeResultSetReader) Read(p []byte) (n int, err error) { if r.status == 0 { p[0] = 0x5b // initial 0x5b ([) r.status = 1 return 1, nil } if r.status == 1 { var len int len, err = r.body.Read(p) if err == io.EOF { r.status = 2 return len, nil } if err != nil { return 0, fmt.Errorf("reading body: %w", err) } return len, nil } if r.status == 2 { p[0] = 0x5d // tail 0x5d (]) r.status = 3 return 1, nil } // ensure no data and EOF return 0, io.EOF } func downloadChunk(ctx context.Context, scd *snowflakeChunkDownloader, idx int) { logger.WithContext(ctx).Infof("download start chunk: %v", idx+1) defer scd.DoneDownloadCond.Broadcast() timer := time.Now() if err := scd.FuncDownloadHelper(ctx, scd, idx); err != nil { logger.WithContext(ctx).Errorf( "failed to extract HTTP response body. URL: %v, err: %v", scd.ChunkMetas[idx].URL, err) scd.ChunksError <- &chunkError{Index: idx, Error: err} } else if errors.Is(scd.ctx.Err(), context.Canceled) || errors.Is(scd.ctx.Err(), context.DeadlineExceeded) { scd.ChunksError <- &chunkError{Index: idx, Error: scd.ctx.Err()} } elapsedTime := time.Since(timer).String() logger.Debugf("“Processed %v chunk %v out of %v. It took %v ms. 
// decodeChunk materializes a downloaded chunk body into rows and
// stores them in scd.Chunks under the chunk index. It sniffs the first
// two bytes for the gzip magic number and uncompresses transparently.
// JSON chunks are wrapped in brackets via largeResultSetReader so the
// stream parses as one array; arrow chunks are read through an IPC
// reader, and when arrow batches were requested the raw records are
// stashed in scd.rawBatches instead of being converted to rows.
func decodeChunk(ctx context.Context, scd *snowflakeChunkDownloader, idx int, bufStream *bufio.Reader) error {
	gzipMagic, err := bufStream.Peek(2)
	if err != nil {
		return fmt.Errorf("peeking for gzip magic bytes: %w", err)
	}
	start := time.Now()
	var source io.Reader
	if gzipMagic[0] == 0x1f && gzipMagic[1] == 0x8b {
		// detects and uncompresses Gzip format data
		bufStream0, err := gzip.NewReader(bufStream)
		if err != nil {
			return fmt.Errorf("creating gzip reader: %w", err)
		}
		defer func() {
			if err = bufStream0.Close(); err != nil {
				logger.Warnf("decodeChunk: closing gzip reader: %v", err)
			}
		}()
		source = bufStream0
	} else {
		source = bufStream
	}
	// wrap the body with '[' and ']' so it decodes as one JSON array
	st := &largeResultSetReader{
		status: 0,
		body:   source,
	}
	var respd []chunkRowType
	if scd.getQueryResultFormat() != arrowFormat {
		var decRespd [][]*string
		if !customJSONDecoderEnabled {
			dec := json.NewDecoder(st)
			for {
				if err := dec.Decode(&decRespd); err == io.EOF {
					break
				} else if err != nil {
					return fmt.Errorf("decoding json: %w", err)
				}
			}
		} else {
			// custom streaming decoder (see chunk.go), fed row/cell
			// count hints for pre-sizing
			decRespd, err = decodeLargeChunk(st, scd.ChunkMetas[idx].RowCount, scd.CellCount)
			if err != nil {
				return fmt.Errorf("decoding large chunk: %w", err)
			}
		}
		respd = make([]chunkRowType, len(decRespd))
		populateJSONRowSet(respd, decRespd)
	} else {
		// NOTE: the arrow path reads `source` directly, without the
		// bracket wrapper
		ipcReader, err := ipc.NewReader(source, ipc.WithAllocator(scd.pool))
		if err != nil {
			return fmt.Errorf("creating ipc reader: %w", err)
		}
		var loc *time.Location
		params, err := scd.getConfigParams()
		if err != nil {
			return fmt.Errorf("getting config params: %w", err)
		}
		loc = getCurrentLocation(params)
		arc := arrowResultChunk{
			ipcReader,
			0,
			loc,
			scd.pool,
		}
		if usesArrowBatches(scd.ctx) {
			// keep raw records; conversion happens lazily on the
			// consumer side
			var err error
			scd.rawBatches[idx].records, err = arc.decodeArrowBatchRaw()
			if err != nil {
				return fmt.Errorf("decoding Arrow batch: %w", err)
			}
			scd.rawBatches[idx].rowCount = countRawArrowBatchRows(scd.rawBatches[idx].records)
			return nil
		}
		highPrec := higherPrecisionEnabled(scd.ctx)
		respd, err = arc.decodeArrowChunk(ctx, scd.RowSet.RowType, highPrec, params)
		if err != nil {
			return fmt.Errorf("decoding arrow chunk: %w", err)
		}
	}
	logger.WithContext(scd.ctx).Debugf(
		"decoded %d rows w/ %d bytes in %s (chunk %v)",
		scd.ChunkMetas[idx].RowCount,
		scd.ChunkMetas[idx].UncompressedSize,
		time.Since(start),
		idx+1,
	)
	// publish the decoded chunk; next() waits on this map entry
	scd.ChunksMutex.Lock()
	defer scd.ChunksMutex.Unlock()
	scd.Chunks[idx] = respd
	return nil
}
to dst's chunkRowType struct's RowSet field for i, row := range src { dst[i].RowSet = row } } func countRawArrowBatchRows(recs *[]arrow.Record) (cnt int) { if recs == nil { return 0 } for _, r := range *recs { cnt += int(r.NumRows()) } return } func getAllocator(ctx context.Context) memory.Allocator { pool, ok := ctx.Value(arrowAlloc).(memory.Allocator) if !ok { return memory.DefaultAllocator } return pool } func usesArrowBatches(ctx context.Context) bool { return ia.BatchesEnabled(ctx) } ================================================ FILE: chunk_downloader_test.go ================================================ package gosnowflake import ( "context" "database/sql/driver" "testing" ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow" ) func TestChunkDownloaderDoesNotStartWhenArrowParsingCausesError(t *testing.T) { tcs := []string{ "invalid base64", "aW52YWxpZCBhcnJvdw==", // valid base64, but invalid arrow } for _, tc := range tcs { t.Run(tc, func(t *testing.T) { scd := snowflakeChunkDownloader{ ctx: context.Background(), QueryResultFormat: "arrow", RowSet: rowSetType{ RowSetBase64: tc, }, } err := scd.start() assertNotNilF(t, err) }) } } func TestWithArrowBatchesWhenQueryReturnsNoRowsWhenUsingNativeGoSQLInterface(t *testing.T) { runDBTest(t, func(dbt *DBTest) { var rows driver.Rows var err error err = dbt.conn.Raw(func(x any) error { rows, err = x.(driver.QueryerContext).QueryContext(ia.EnableArrowBatches(context.Background()), "SELECT 1 WHERE 0 = 1", nil) return err }) assertNilF(t, err) rows.Close() }) } func TestWithArrowBatchesWhenQueryReturnsRowsAndReadingRows(t *testing.T) { runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(ia.EnableArrowBatches(context.Background()), "SELECT 1") defer rows.Close() assertFalseF(t, rows.Next()) }) } func TestWithArrowBatchesWhenQueryReturnsNoRowsAndReadingRows(t *testing.T) { runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(ia.EnableArrowBatches(context.Background()), "SELECT 1 WHERE 1 = 0") 
defer rows.Close() assertFalseF(t, rows.Next()) }) } func TestWithArrowBatchesWhenQueryReturnsNoRowsAndReadingArrowBatchData(t *testing.T) { runDBTest(t, func(dbt *DBTest) { var rows driver.Rows var err error err = dbt.conn.Raw(func(x any) error { rows, err = x.(driver.QueryerContext).QueryContext(ia.EnableArrowBatches(context.Background()), "SELECT 1 WHERE 1 = 0", nil) return err }) assertNilF(t, err) defer rows.Close() provider := rows.(SnowflakeRows).(ia.BatchDataProvider) info, err := provider.GetArrowBatches() assertNilF(t, err) assertEmptyE(t, info.Batches) }) } ================================================ FILE: chunk_test.go ================================================ package gosnowflake import ( "bytes" "context" "database/sql/driver" "encoding/json" "errors" "fmt" errors2 "github.com/snowflakedb/gosnowflake/v2/internal/errors" "io" "strings" "sync" "testing" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/memory" ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow" ) func TestBadChunkData(t *testing.T) { testDecodeErr(t, "") testDecodeErr(t, "null") testDecodeErr(t, "42") testDecodeErr(t, "\"null\"") testDecodeErr(t, "{}") testDecodeErr(t, "[[]") testDecodeErr(t, "[null]") testDecodeErr(t, `[[hello world]]`) testDecodeErr(t, `[[""hello world""]]`) testDecodeErr(t, `[["\"hello world""]]`) testDecodeErr(t, `[[""hello world\""]]`) testDecodeErr(t, `[["hello world`) testDecodeErr(t, `[["hello world"`) testDecodeErr(t, `[["hello world"]`) testDecodeErr(t, `[["\uQQQQ"]]`) for b := range byte(' ') { testDecodeErr(t, string([]byte{ '[', '[', '"', b, '"', ']', ']', })) } } func TestValidChunkData(t *testing.T) { testDecodeOk(t, "[]") testDecodeOk(t, "[ ]") testDecodeOk(t, "[[]]") testDecodeOk(t, "[ [ ] ]") testDecodeOk(t, "[[],[],[],[]]") testDecodeOk(t, "[[] , [] , [], [] ]") testDecodeOk(t, "[[null]]") testDecodeOk(t, "[[\n\t\r null]]") testDecodeOk(t, "[[null,null]]") 
testDecodeOk(t, "[[ null , null ]]") testDecodeOk(t, "[[null],[null],[null]]") testDecodeOk(t, "[[null],[ null ] , [null]]") testDecodeOk(t, `[[""]]`) testDecodeOk(t, `[["false"]]`) testDecodeOk(t, `[["true"]]`) testDecodeOk(t, `[["42"]]`) testDecodeOk(t, `[[""]]`) testDecodeOk(t, `[["hello"]]`) testDecodeOk(t, `[["hello world"]]`) testDecodeOk(t, `[["/ ' \\ \b \t \n \f \r \""]]`) testDecodeOk(t, `[["❄"]]`) testDecodeOk(t, `[["\u2744"]]`) testDecodeOk(t, `[["\uFfFc"]]`) // consume replacement chars testDecodeOk(t, `[["\ufffd"]]`) // consume replacement chars testDecodeOk(t, `[["\u0000"]]`) // yes, this is valid testDecodeOk(t, `[["\uD834\uDD1E"]]`) // surrogate pair testDecodeOk(t, `[["\uD834\u0000"]]`) // corrupt surrogate pair testDecodeOk(t, `[["$"]]`) // "$" testDecodeOk(t, `[["\u0024"]]`) // "$" testDecodeOk(t, `[["\uC2A2"]]`) // "¢" testDecodeOk(t, `[["¢"]]`) // "¢" testDecodeOk(t, `[["\u00E2\u82AC"]]`) // "€" testDecodeOk(t, `[["€"]]`) // "€" testDecodeOk(t, `[["\uF090\u8D88"]]`) // "𐍈" testDecodeOk(t, `[["𐍈"]]`) // "𐍈" } func TestSmallBufferChunkData(t *testing.T) { r := strings.NewReader(`[ [null,"hello world"], ["foo bar", null], [null, null] , ["foo bar", "hello world" ] ]`) lcd := largeChunkDecoder{ r, 0, 0, 0, 0, make([]byte, 1), bytes.NewBuffer(make([]byte, defaultStringBufferSize)), nil, } if _, err := lcd.decode(); err != nil { t.Fatalf("failed with small buffer: %s", err) } } func TestEnsureBytes(t *testing.T) { // the content here doesn't matter r := strings.NewReader("0123456789") lcd := largeChunkDecoder{ r, 0, 0, 3, 8189, make([]byte, 8192), bytes.NewBuffer(make([]byte, defaultStringBufferSize)), nil, } lcd.ensureBytes(4) // we expect the new remainder to be 3 + 10 (length of r) if lcd.rem != 13 { t.Fatalf("buffer was not refilled correctly") } } func testDecodeOk(t *testing.T, s string) { var rows [][]*string if err := json.Unmarshal([]byte(s), &rows); err != nil { t.Fatalf("test case is not valid json / [][]*string: %s", s) } // NOTE we parse 
and stringify the expected result to // remove superficial differences, like whitespace expect, err := json.Marshal(rows) if err != nil { t.Fatalf("unreachable: %s", err) } rows, err = decodeLargeChunk(strings.NewReader(s), 0, 0) if err != nil { t.Fatalf("expected decode to succeed: %s", err) } actual, err := json.Marshal(rows) if err != nil { t.Fatalf("json marshal failed: %s", err) } if string(actual) != string(expect) { t.Fatalf(` result did not match expected result expect=%s bytes=(%v) acutal=%s bytes=(%v)`, string(expect), expect, string(actual), actual, ) } } func testDecodeErr(t *testing.T, s string) { if _, err := decodeLargeChunk(strings.NewReader(s), 0, 0); err == nil { t.Fatalf("expected decode to fail for input: %s", s) } } func TestEnableArrowBatches(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { ctx := ia.EnableArrowBatches(sct.sc.ctx) numrows := 3000 // approximately 6 ArrowBatch objects pool := memory.NewCheckedAllocator(memory.DefaultAllocator) defer pool.AssertSize(t, 0) ctx = WithArrowAllocator(ctx, pool) query := fmt.Sprintf(selectRandomGenerator, numrows) rows := sct.mustQueryContext(ctx, query, []driver.NamedValue{}) defer rows.Close() // getting result batches via raw bridge info, err := rows.(*snowflakeRows).GetArrowBatches() if err != nil { t.Error(err) } batches := info.Batches numBatches := len(batches) maxWorkers := 10 // enough for 3000 rows type count struct { m sync.Mutex recVal int metaVal int } cnt := count{recVal: 0} var wg sync.WaitGroup chunks := make(chan int, numBatches) for w := 1; w <= maxWorkers; w++ { wg.Add(1) go func(wg *sync.WaitGroup, chunks <-chan int) { defer wg.Done() for i := range chunks { batch := batches[i] var recs *[]arrow.Record if batch.Records != nil { recs = batch.Records } else if batch.Download != nil { var downloadErr error recs, _, downloadErr = batch.Download(context.Background()) if downloadErr != nil { t.Error(downloadErr) } } if recs != nil { for _, r := range *recs { cnt.m.Lock() 
cnt.recVal += int(r.NumRows()) cnt.m.Unlock() r.Release() } } cnt.m.Lock() cnt.metaVal += batch.RowCount cnt.m.Unlock() } }(&wg, chunks) } for j := range numBatches { chunks <- j } close(chunks) wg.Wait() if cnt.recVal != numrows { t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, cnt.recVal) } if cnt.metaVal != numrows { t.Errorf("number of rows from arrow batch metadata didn't match. expected: %v, got: %v", numrows, cnt.metaVal) } }) } func TestWithArrowBatchesAsync(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { ctx := WithAsyncMode(sct.sc.ctx) ctx = ia.EnableArrowBatches(ctx) numrows := 50000 pool := memory.NewCheckedAllocator(memory.DefaultAllocator) defer pool.AssertSize(t, 0) ctx = WithArrowAllocator(ctx, pool) query := fmt.Sprintf(selectRandomGenerator, numrows) rows := sct.mustQueryContext(ctx, query, []driver.NamedValue{}) defer rows.Close() info, err := rows.(*snowflakeRows).GetArrowBatches() if err != nil { t.Error(err) } batches := info.Batches numBatches := len(batches) maxWorkers := 10 type count struct { m sync.Mutex recVal int metaVal int } cnt := count{recVal: 0} var wg sync.WaitGroup chunks := make(chan int, numBatches) for w := 1; w <= maxWorkers; w++ { wg.Add(1) go func(wg *sync.WaitGroup, chunks <-chan int) { defer wg.Done() for i := range chunks { batch := batches[i] var recs *[]arrow.Record if batch.Records != nil { recs = batch.Records } else if batch.Download != nil { var downloadErr error recs, _, downloadErr = batch.Download(context.Background()) if downloadErr != nil { t.Error(downloadErr) } } if recs != nil { for _, r := range *recs { cnt.m.Lock() cnt.recVal += int(r.NumRows()) cnt.m.Unlock() r.Release() } } cnt.m.Lock() cnt.metaVal += batch.RowCount cnt.m.Unlock() } }(&wg, chunks) } for j := range numBatches { chunks <- j } close(chunks) wg.Wait() if cnt.recVal != numrows { t.Errorf("number of rows from records didn't match. 
// testWithArrowBatchesButReturningJSON verifies that requesting arrow
// batches on a query the server answered in JSON yields a descriptive
// SnowflakeError from GetArrowBatches, while the result remains
// readable through the regular row interface.
func testWithArrowBatchesButReturningJSON(t *testing.T, async bool) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		requestID := NewUUID()
		pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
		defer pool.AssertSize(t, 0)
		ctx := WithArrowAllocator(context.Background(), pool)
		ctx = ia.EnableArrowBatches(ctx)
		ctx = WithRequestID(ctx, requestID)
		if async {
			ctx = WithAsyncMode(ctx)
		}
		// force the server to respond in JSON despite the arrow request
		sct.mustExec(forceJSON, nil)
		rows := sct.mustQueryContext(ctx, "SELECT 'hello'", nil)
		defer rows.Close()
		_, err := rows.(ia.BatchDataProvider).GetArrowBatches()
		assertNotNilF(t, err)
		var se *SnowflakeError
		assertTrueE(t, errors.As(err, &se))
		assertEqualE(t, se.Message, errors2.ErrMsgNonArrowResponseInArrowBatches)
		assertEqualE(t, se.Number, ErrNonArrowResponseInArrowBatches)
		// the data is still accessible as ordinary rows
		v := make([]driver.Value, 1)
		assertNilE(t, rows.Next(v))
		assertEqualE(t, v[0], "hello")
	})
}
SELECT 'ghi' UNION SELECT 'jkl' ORDER BY 1", nil) defer driverRows.Close() sfRows := driverRows.(SnowflakeRows) expectedResults := [][]string{{"abc", "def"}, {"ghi", "jkl"}} resultSetIdx := 0 for hasNextResultSet := true; hasNextResultSet; hasNextResultSet = sfRows.NextResultSet() != io.EOF { info, err := driverRows.(ia.BatchDataProvider).GetArrowBatches() assertNilF(t, err) assertEqualF(t, len(info.Batches), 1) batch := info.Batches[0] assertNotNilF(t, batch.Records) records := *batch.Records assertEqualF(t, len(records), 1) record := records[0] defer record.Release() assertEqualF(t, record.Column(0).(*array.String).Value(0), expectedResults[resultSetIdx][0]) assertEqualF(t, record.Column(0).(*array.String).Value(1), expectedResults[resultSetIdx][1]) resultSetIdx++ } assertEqualF(t, resultSetIdx, len(expectedResults)) err := sfRows.NextResultSet() assertErrIsE(t, err, io.EOF) }) } func TestWithArrowBatchesMultistatementWithJSONResponse(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { sct.mustExec(forceJSON, nil) pool := memory.NewCheckedAllocator(memory.DefaultAllocator) defer pool.AssertSize(t, 0) ctx := WithMultiStatement(ia.EnableArrowBatches(WithArrowAllocator(context.Background(), pool)), 2) driverRows := sct.mustQueryContext(ctx, "SELECT 'abc' UNION SELECT 'def' ORDER BY 1; SELECT 'ghi' UNION SELECT 'jkl' ORDER BY 1", nil) defer driverRows.Close() sfRows := driverRows.(SnowflakeRows) resultSetIdx := 0 for hasNextResultSet := true; hasNextResultSet; hasNextResultSet = sfRows.NextResultSet() != io.EOF { _, err := driverRows.(ia.BatchDataProvider).GetArrowBatches() assertNotNilF(t, err) var se *SnowflakeError assertTrueF(t, errors.As(err, &se)) assertEqualE(t, se.Number, ErrNonArrowResponseInArrowBatches) assertEqualE(t, se.Message, errors2.ErrMsgNonArrowResponseInArrowBatches) resultSetIdx++ } assertEqualF(t, resultSetIdx, 2) err := sfRows.NextResultSet() assertErrIsE(t, err, io.EOF) }) } func TestWithArrowBatchesMultistatementWithLargeResultSet(t 
*testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { sct.mustExec("ALTER SESSION SET ENABLE_FIX_1758055_ADD_ARROW_SUPPORT_FOR_MULTI_STMTS = true", nil) pool := memory.NewCheckedAllocator(memory.DefaultAllocator) defer pool.AssertSize(t, 0) ctx := WithMultiStatement(ia.EnableArrowBatches(WithArrowAllocator(context.Background(), pool)), 2) driverRows := sct.mustQueryContext(ctx, "SELECT 'abc' FROM TABLE(GENERATOR(ROWCOUNT => 1000000)); SELECT 'abc' FROM TABLE(GENERATOR(ROWCOUNT => 1000000))", nil) defer driverRows.Close() sfRows := driverRows.(SnowflakeRows) rowCount := 0 for hasNextResultSet := true; hasNextResultSet; hasNextResultSet = sfRows.NextResultSet() != io.EOF { info, err := driverRows.(ia.BatchDataProvider).GetArrowBatches() assertNilF(t, err) assertTrueF(t, len(info.Batches) > 1) for _, batch := range info.Batches { var recs *[]arrow.Record if batch.Records != nil { recs = batch.Records } else if batch.Download != nil { recs, _, err = batch.Download(context.Background()) assertNilF(t, err) } if recs != nil { for _, record := range *recs { defer record.Release() for i := 0; i < int(record.NumRows()); i++ { assertEqualF(t, record.Column(0).(*array.String).Value(i), "abc") rowCount++ } } } } } err := sfRows.NextResultSet() assertErrIsE(t, err, io.EOF) }) } func TestQueryArrowStream(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { numrows := 50000 query := fmt.Sprintf(selectRandomGenerator, numrows) loader, err := sct.sc.QueryArrowStream(sct.sc.ctx, query) assertNilF(t, err) if loader.TotalRows() != int64(numrows) { t.Errorf("total numrows did not match expected, wanted %v, got %v", numrows, loader.TotalRows()) } batches, err := loader.GetBatches() assertNilF(t, err) assertTrueF(t, len(batches) > 0, "should have at least one batch") assertTrueF(t, len(loader.RowTypes()) > 0, "should have row types") }) } func TestQueryArrowStreamDescribeOnly(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { numrows := 50000 query := 
fmt.Sprintf(selectRandomGenerator, numrows) loader, err := sct.sc.QueryArrowStream(WithDescribeOnly(sct.sc.ctx), query) assertNilF(t, err, "failed to run query") if loader.TotalRows() != 0 { t.Errorf("total numrows did not match expected, wanted 0, got %v", loader.TotalRows()) } if len(loader.RowTypes()) != 2 { t.Errorf("rowTypes length did not match expected, wanted 2, got %v", len(loader.RowTypes())) } }) } func TestRetainChunkWOHighPrecision(t *testing.T) { runDBTest(t, func(dbt *DBTest) { var rows driver.Rows var err error err = dbt.conn.Raw(func(connection any) error { rows, err = connection.(driver.QueryerContext).QueryContext(ia.EnableArrowBatches(context.Background()), "select 0", nil) return err }) assertNilF(t, err, "error running select 0 query") info, err := rows.(ia.BatchDataProvider).GetArrowBatches() assertNilF(t, err, "error getting arrow batch data") assertEqualF(t, len(info.Batches), 1, "should have one batch") records := info.Batches[0].Records assertNotNilF(t, records, "records should not be nil") numRecords := len(*records) assertEqualF(t, numRecords, 1, "should have exactly one record") record := (*records)[0] assertEqualF(t, len(record.Columns()), 1, "should have exactly one column") column := record.Column(0).(*array.Int8) row := column.Len() assertEqualF(t, row, 1, "should have exactly one row") int8Val := column.Value(0) assertEqualF(t, int8Val, int8(0), "value of cell should be 0") }) } func TestQueryArrowStreamMultiStatement(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { sct.mustExec("ALTER SESSION SET ENABLE_FIX_1758055_ADD_ARROW_SUPPORT_FOR_MULTI_STMTS = true", nil) ctx := WithMultiStatement(ia.EnableArrowBatches(sct.sc.ctx), 2) loader, err := sct.sc.QueryArrowStream(ctx, "SELECT 'abc'; SELECT 'abc' UNION SELECT 'def' ORDER BY 1") assertNilF(t, err) assertTrueF(t, len(loader.RowTypes()) > 0, "should have row types") assertTrueF(t, loader.TotalRows() > 0, "should have total rows") }) } func 
TestQueryArrowStreamMultiStatementForJSONData(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		ctx := WithMultiStatement(ia.EnableArrowBatches(sct.sc.ctx), 2)
		loader, err := sct.sc.QueryArrowStream(ctx, "SELECT 'abc'; SELECT 'abc'")
		assertNilF(t, err)
		assertTrueF(t, loader.TotalRows() > 0, "should return data")
	})
}

================================================
FILE: ci/_init.sh
================================================
#!/usr/bin/env bash
# NOTE(review): shebang fixed — the original read `#!/usr/bin/env -e`, which names no
# interpreter at all (`-e` is not a program). This file is sourced by sibling CI
# scripts, so the shebang is informational, but it should still be valid.
export PLATFORM=$(echo $(uname) | tr '[:upper:]' '[:lower:]')

# Use the internal Docker Registry
export INTERNAL_REPO=artifactory.ci1.us-west-2.aws-dev.app.snowflake.com/internal-production-docker-snowflake-virtual
export DOCKER_REGISTRY_NAME=$INTERNAL_REPO/docker
export WORKSPACE=${WORKSPACE:-/tmp}

export DRIVER_NAME=go

TEST_IMAGE_VERSION=1

declare -A TEST_IMAGE_NAMES=(
    [$DRIVER_NAME-chainguard-go1_24]=$DOCKER_REGISTRY_NAME/client-$DRIVER_NAME-chainguard-go1.24-test:$TEST_IMAGE_VERSION
)
export TEST_IMAGE_NAMES

================================================
FILE: ci/build.bat
================================================
REM Format and Lint Golang driver
@echo off
setlocal EnableDelayedExpansion

echo [INFO] Download tools
where golint
IF !ERRORLEVEL! NEQ 0 go install golang.org/x/lint/golint@latest
where make2help
IF !ERRORLEVEL! NEQ 0 go install github.com/Songmu/make2help/cmd/make2help@latest

echo [INFO] Go mod
go mod tidy
go mod vendor

FOR /F "tokens=1" %%a IN ('go list ./...') DO (
    echo [INFO] Verifying %%a
    go vet %%a
    golint -set_exit_status %%a
)

================================================
FILE: ci/build.sh
================================================
#!/bin/bash
#
# Format, lint and WhiteSource scan Golang driver
#
set -e
set -o pipefail
CI_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd $CI_DIR/..
make fmt lint ================================================ FILE: ci/container/test_authentication.sh ================================================ #!/bin/bash -e set -o pipefail export AUTH_PARAMETER_FILE=./.github/workflows/parameters_aws_auth_tests.json eval $(jq -r '.authtestparams | to_entries | map("export \(.key)=\(.value|tostring)")|.[]' $AUTH_PARAMETER_FILE) export SNOWFLAKE_AUTH_TEST_PRIVATE_KEY_PATH=./.github/workflows/rsa_keys/rsa_key.p8 export SNOWFLAKE_AUTH_TEST_INVALID_PRIVATE_KEY_PATH=./.github/workflows/rsa_keys/rsa_key_invalid.p8 export RUN_AUTH_TESTS=true export AUTHENTICATION_TESTS_ENV="docker" export RUN_AUTH_TESTS=true export AUTHENTICATION_TESTS_ENV="docker" go test -v -run TestExternalBrowser* go test -v -run TestClientStoreCredentials go test -v -run TestOkta* go test -v -run TestOauth* go test -v -run TestKeypair* go test -v -run TestEndToEndPat* go test -v -run TestMfaSuccessful ================================================ FILE: ci/container/test_component.sh ================================================ #!/bin/bash set -e set -o pipefail CI_SCRIPTS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" TOPDIR=$(cd $CI_SCRIPTS_DIR/../.. && pwd) cd $TOPDIR cp parameters.json.local parameters.json make test ================================================ FILE: ci/docker/rockylinux9/Dockerfile ================================================ ARG BASE_IMAGE=rockylinux:9 FROM $BASE_IMAGE ARG TARGETARCH # Update all packages first (including glibc) to get latest versions RUN dnf update -y && dnf clean all # Install glibc-devel - it should match the updated glibc version # If there's still a mismatch, try installing an older compatible version RUN dnf install -y --allowerasing --nobest glibc-devel || \ (echo "Direct install failed, checking available versions..." 
&& \ dnf list available glibc-devel | head -5 && \ CURRENT_GLIBC=$(rpm -q glibc --qf '%{VERSION}-%{RELEASE}\n') && \ echo "Current glibc: $CURRENT_GLIBC" && \ dnf install -y --allowerasing --nobest glibc-devel || true) && \ dnf clean all # Install minimal required packages + gcc for CGO (race detection) RUN dnf install -y --allowerasing --nobest \ gcc \ java-11-openjdk \ python3 \ curl \ wget \ jq \ tar \ gzip \ procps-ng \ && dnf clean all # Set Java 11 as the default using environment variables ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk ENV PATH="${JAVA_HOME}/bin:${PATH}" # Accept full Go version as build argument (e.g., GO_VERSION=1.24.2) ARG GO_VERSION # Download and install Go version RUN GOARCH=${TARGETARCH} && \ GO_VERSION_SHORT=$(echo ${GO_VERSION} | cut -d. -f1,2) && \ echo "Installing Go ${GO_VERSION} for ${GOARCH}..." && \ wget -q https://golang.org/dl/go${GO_VERSION}.linux-${GOARCH}.tar.gz -O /tmp/go.tar.gz && \ mkdir -p /usr/local/go${GO_VERSION_SHORT} && \ tar -C /usr/local/go${GO_VERSION_SHORT} --strip-components=1 -xzf /tmp/go.tar.gz && \ rm /tmp/go.tar.gz && \ # Create wrapper script for short version (e.g., go1.24) \ echo "#!/bin/bash" > /usr/local/bin/go${GO_VERSION_SHORT} && \ echo "export GOROOT=/usr/local/go${GO_VERSION_SHORT}" >> /usr/local/bin/go${GO_VERSION_SHORT} && \ echo 'exec $GOROOT/bin/go "$@"' >> /usr/local/bin/go${GO_VERSION_SHORT} && \ chmod +x /usr/local/bin/go${GO_VERSION_SHORT} # Ensure /usr/local/bin is in PATH (should be by default, but making sure) ENV PATH="/usr/local/bin:${PATH}" # Accept user ID as build argument to match host permissions ARG USER_ID=1001 ARG GROUP_ID=1001 # Create user for proper permission testing # Always create "user" user - use requested IDs if available, otherwise auto-assign RUN if ! getent group user >/dev/null 2>&1; then \ (groupadd -g ${GROUP_ID} user 2>/dev/null || groupadd user); \ fi && \ if ! 
getent passwd user >/dev/null 2>&1; then \ (useradd -u ${USER_ID} -g user -m -s /bin/bash user 2>/dev/null || useradd -g user -m -s /bin/bash user); \ fi && \ mkdir -p /home/user/go && \ chown -R user:user /home/user USER user WORKDIR /home/user/gosnowflake ================================================ FILE: ci/gofix.sh ================================================ #!/usr/bin/env bash set -euo pipefail CI_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" cd "$CI_DIR/.." GOOS_LIST=(linux darwin windows) GOARCH_LIST=(amd64 arm64) # Standard GOOS/GOARCH values — handled by the matrix, not via -tags. # Version tags (go1.X) and toolchain tags (gc, gccgo, ignore) are also excluded. STANDARD_TAGS=( linux darwin windows freebsd openbsd netbsd plan9 solaris aix js wasip1 android ios amd64 arm64 386 arm mips mips64 mipsle mips64le ppc64 ppc64le riscv64 s390x wasm cgo gc gccgo ignore ) ensure_clean_worktree() { if ! git diff --quiet --ignore-submodules -- || \ ! git diff --cached --quiet --ignore-submodules --; then echo "ERROR: working tree is dirty before go fix runs." echo "Run this check from a clean checkout so failures only reflect go fix changes." exit 1 fi } # Automatically discover custom build tags from //go:build lines. # Strips boolean operators and negations, deduplicates, then removes # standard tags and go1.X version constraints. 
discover_custom_tags() { while IFS= read -r tag; do # Skip go1.X version tags [[ "$tag" =~ ^go[0-9] ]] && continue # Skip standard GOOS/GOARCH/toolchain tags skip=false for std in "${STANDARD_TAGS[@]}"; do [[ "$tag" == "$std" ]] && skip=true && break done $skip || echo "$tag" done < <( git grep -h '//go:build' -- '*.go' \ | sed 's|//go:build||g' \ | tr '!&|() \t' '\n' \ | grep -v '^$' \ | sort -u ) } ensure_clean_worktree CUSTOM_TAGS=() while IFS= read -r tag; do CUSTOM_TAGS+=("$tag") done < <(discover_custom_tags) TAGS_LIST=("" "${CUSTOM_TAGS[@]}") TOTAL=$(( ${#GOOS_LIST[@]} * ${#GOARCH_LIST[@]} * ${#TAGS_LIST[@]} + ${#TAGS_LIST[@]} )) RUN=0 echo "Discovered custom build tags: ${CUSTOM_TAGS[*]:-none}" echo "Running go fix across all OS/arch/tag combinations (CGO_ENABLED=0)..." for os in "${GOOS_LIST[@]}"; do for arch in "${GOARCH_LIST[@]}"; do for tags in "${TAGS_LIST[@]}"; do RUN=$(( RUN + 1 )) tag_flag="" tag_label="(no tags)" if [[ -n "$tags" ]]; then tag_flag="-tags=$tags" tag_label="tags=$tags" fi echo " [$RUN/$TOTAL] CGO_ENABLED=0 GOOS=$os GOARCH=$arch $tag_label" # "no cgo types" is a harmless warning from go/packages when it cannot # invoke the cgo preprocessor (cross-compilation, no C toolchain, etc.). # No go fix fixer depends on cgo type information, so suppress the noise. CGO_ENABLED=0 GOOS="$os" GOARCH="$arch" go fix $tag_flag ./... \ 2> >(grep -v "^go fix: warning: no cgo types:" >&2) done done done # Run cgo-enabled passes on the native target so that files with # `import "C"` (excluded when CGO_ENABLED=0) are also checked. # Cross-GOOS/GOARCH is not needed here because cgo requires a # C cross-compiler that is not generally available. echo "Running go fix with CGO_ENABLED=1 (native target)..." for tags in "${TAGS_LIST[@]}"; do RUN=$(( RUN + 1 )) tag_flag="" tag_label="(no tags)" if [[ -n "$tags" ]]; then tag_flag="-tags=$tags" tag_label="tags=$tags" fi echo " [$RUN/$TOTAL] CGO_ENABLED=1 (native) $tag_label" CGO_ENABLED=1 go fix $tag_flag ./... 
done echo "Checking for uncommitted changes..." if ! git diff --exit-code; then echo "" echo "ERROR: go fix produced changes." echo "Run 'ci/gofix.sh' locally and commit the result." exit 1 fi echo "All files are up to date." ================================================ FILE: ci/image/Dockerfile ================================================ FROM artifactory.int.snowflakecomputing.com/development-chainguard-virtual/snowflake.com/go:1.24.0-dev USER root RUN apk update && apk add python3 python3-dev jq aws-cli gosu py3-pip RUN python3 -m ensurepip RUN pip install -U snowflake-connector-python # workspace RUN mkdir -p /home/user && \ chmod 777 /home/user WORKDIR /mnt/host # entry point COPY scripts/entrypoint.sh /usr/local/bin/entrypoint.sh RUN chmod +x /usr/local/bin/entrypoint.sh ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] ================================================ FILE: ci/image/build.sh ================================================ #!/usr/bin/env bash -e # # Build Docker images # set -o pipefail THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" source $THIS_DIR/../_init.sh for name in "${!TEST_IMAGE_NAMES[@]}"; do docker build \ --platform linux/amd64 \ --file $THIS_DIR/Dockerfile \ --label snowflake \ --label $DRIVER_NAME \ --tag ${TEST_IMAGE_NAMES[$name]} . 
done ================================================ FILE: ci/image/scripts/entrypoint.sh ================================================ #!/bin/bash -ex # Add local user # Either use the LOCAL_USER_ID if passed in at runtime or # fallback USER_ID=${LOCAL_USER_ID:-9001} echo "Starting with UID : $USER_ID" adduser -s /bin/bash -u $USER_ID -h /home/user -D user export HOME=/home/user mkdir -p /home/user/.cache chown user:user /home/user/.cache exec gosu user "$@" ================================================ FILE: ci/image/update.sh ================================================ #!/usr/bin/env bash -e # # Build Docker images # set -o pipefail THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" source $THIS_DIR/../_init.sh for image in $(docker images --format "{{.ID}},{{.Repository}}:{{.Tag}}" | grep "artifactory.ci1.us-west-2.aws-dev.app.snowflake.com" | grep "client-$DRIVER_NAME"); do target_id=$(echo $image | awk -F, '{print $1}') target_name=$(echo $image | awk -F, '{print $2}') for name in "${!TEST_IMAGE_NAMES[@]}"; do if [[ "$target_name" == "${TEST_IMAGE_NAMES[$name]}" ]]; then echo $name docker_hub_image_name=$(echo ${TEST_IMAGE_NAMES[$name]/$DOCKER_REGISTRY_NAME/snowflakedb}) set -x docker tag $target_id $docker_hub_image_name set +x docker push "${TEST_IMAGE_NAMES[$name]}" fi done done ================================================ FILE: ci/scripts/.gitignore ================================================ wiremock-standalone-*.jar ================================================ FILE: ci/scripts/README.md ================================================ # Refreshing wiremock test cert Password for CA is `password`. 
```bash openssl x509 -req -in wiremock.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out wiremock.crt -days 365 -sha256 -extfile wiremock.v3.ext openssl pkcs12 -export -out wiremock.p12 -inkey wiremock.key -in wiremock.crt ``` # Refreshing ECDSA cert ```bash openssl x509 -req -in wiremock-ecdsa.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out wiremock-ecdsa.crt -days 365 -sha256 -extfile wiremock.v3.ext openssl pkcs12 -export -inkey wiremock-ecdsa.key -in wiremock-ecdsa.crt -out wiremock-ecdsa.p12 ``` ================================================ FILE: ci/scripts/ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIF1zCCA7+gAwIBAgIUXh8f8hI5mKqCrUJaDn0zF6qGmw0wDQYJKoZIhvcNAQEL BQAwezELMAkGA1UEBhMCUEwxFDASBgNVBAgMC01hem93aWVja2llMQ8wDQYDVQQH DAZXYXJzYXcxEjAQBgNVBAoMCVNub3dmbGFrZTEQMA4GA1UECwwHRHJpdmVyczEf MB0GA1UEAwwWU25vd2ZsYWtlIHRlc3QgUm9vdCBDQTAeFw0yNTAzMDUwOTQ0MTha Fw0zNTAzMDMwOTQ0MThaMHsxCzAJBgNVBAYTAlBMMRQwEgYDVQQIDAtNYXpvd2ll Y2tpZTEPMA0GA1UEBwwGV2Fyc2F3MRIwEAYDVQQKDAlTbm93Zmxha2UxEDAOBgNV BAsMB0RyaXZlcnMxHzAdBgNVBAMMFlNub3dmbGFrZSB0ZXN0IFJvb3QgQ0EwggIi MA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQCW0bhevdDp+6S3eIqEAWvFJ66M ST3WcvYUwdEILGRHyjYT34R2dM2HsmJ8NUA17NFpnWIRbv+f8oKFec90dDfKOdzQ vZmiHHun0zYLOf/QE0wj6rtB9zcn8Skwio7f9BQAed9Krovb6/f5tfRMzhDqsk6u Ut+ra2INrA4apAEaw1hZVMN8htkH+M7GSha4hLIM+HOSmBt8pulxlwVFaqpvwZR6 8ettpR9lX3PXFP2s09rY3Pq2PfB6JNF9qmMZzqlgr4qI0HKu5VTTSL3eWmJiZmVb mplISSzL7kKjPoBXLeNJTRtkfO1XKBvDXrNfnfexIlv8lJ9eCVaHaHLw+qgJNq3v TR/BbmrfroLfdpzW2DlF9PDNEookrri2oZyky2DwGklyH5DsUU5T5xTk+eOHsSvB JQEBrl9JCEhWNgVCgzPcQ9Ma7PaIaKw9SQAXWDFd5DLzAZ7Q5dHXy82k942Cp6kZ O6/s9SnhHPQQZg4H4ruqGuy1CdsOvd9ZpCRYUKXZoZYcEidqLRAb+rYCsf8dWiMn Qvru0/V18upRsK9BCgRAQcP0R//HXBH199nqGuCnPCGgRIiRfwawyp/C5rXCb0BN eYfBhdvdnd144CgvHq5tsAHjdw7yhP87zF6Wa+bKThfihfK/LKpIwVLRnN/e6Nea uWSu1Ns+6aywd5MBNwIDAQABo1MwUTAdBgNVHQ4EFgQU0GVyoh2s3w5Ka8ynllvA pHtFh6AwHwYDVR0jBBgwFoAU0GVyoh2s3w5Ka8ynllvApHtFh6AwDwYDVR0TAQH/ 
BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAgEANPpuw7bno0cTkaY0CA+0YsHf5r8T lSSNUtvREGudH09gPUmVFnU7MMNe+q6gOFkPIl+Mdbj/loaN8eNeZ3OO84VjbVvR 2MtuQti7OcxhptUG9YkS6BeW/ZIp4QGYthDByg5Kc0Wf8mkNqCWuXYnQK7zyTqIM 37TmPZMfD0+ck5Nc5r3S1n2xH0sTTwKjhw54OUpDxxfXARkdCg0u7wJlm/kxiUiA rhw9fXVVkeLh1J8sRIyXsLdJBDjDhVOoz/lBCgEUYJ0R/icUxl7jGt7XEXqUY4ER xYb8oVdEmUPYRR5m7Q5076HKCLXNY/Jn5BvtfaPCs288jXWSidY9B71baaBzeN6C Y+1Yh9m/+SVz+g+5/PAm0kdzvWytewi53GDnG6P1peJi3TZOMhL+WU1gv3JSNiZ5 +JbmQIM3jM22QJeElMA+tavB+Hm1PDIqgfVsvOOmpd/npKUc8AlNDA9/sNA9h0V7 0ldbQoPXVh81+7O+uDMrN3x8naCOAdsAaz4mHEBlhSn55snvbeXSkEw2oVtzt9fB qscc02cN/9gf0UdIXsyDpL0ZL/rkjbmauE5QC45WKRc87cZYH8OhnROg+A2Dr3bk 0LIZdOSbsZmVoyKWDO5P2p3l3z4x3D1P+KBWIxx/fCtdIvHg1EFHmn0SHuyoxzsO gVB+n3ggLTYRR0s= -----END CERTIFICATE----- ================================================ FILE: ci/scripts/ca.key ================================================ -----BEGIN ENCRYPTED PRIVATE KEY----- MIIJtTBfBgkqhkiG9w0BBQ0wUjAxBgkqhkiG9w0BBQwwJAQQR+n/YtOhd0h7AmwV GU9glAICCAAwDAYIKoZIhvcNAgkFADAdBglghkgBZQMEASoEED/jsTIXZ/aJZt1B 0sr61w4EgglQKHQFRJpHyd4I84WNu87VANMHLwxApMnsag9ccKEIDCiMOpESkiE8 vh7gE+MkeZUCjXTsFslz00u0ZTeSGsE6BlasS0FnITzkM+3y4HiW4ezC8EU15hd9 Acs4n7cNMPPPFvnUKtE4gye9DqdxUnEYEcT+fasHhIhmzpn/WaBaBmdv7pZPz2MX AblwJ860qO1W33+nn6bGcCokvNC1GIePbh1DdsaSJvLy3zljOeqO2jyp74n1DRnH 2XWR9e+IYa68kpuHNrosHNSkOmkxb+zTQeL4rFgeQi6gdnJdMzKSyrGKg2/feVUF K9QlqtJuest2SDKwmECO/nTKdMTicv3CnMuwXURaggceFLHE0ea7AdoZc2gSx2Zr ePjqKlKMF0lYirA6ZTpL1FLptFju4IS2rxI6uKf21eMSM8sQ7ui97IELZhKdykwo PEmj7d0aO5J7OaatGtNreVpSarYdSO4rfZW/iGbRda74NJnH3Wiy958UHcMob+45 MEQtww3NoLZbSbdfvn4+xoLZIzqm6uu4avsb952imq4UxwgEBcVaDGjeGJF34yuC uYXQqQRTjSjD9Cru579gW6wZXzW3G9hsuC66f686CvaE3nJK2+OkRtSYogSfk2lq O8G9UFQ7tGtUrsXWIt6+iUWRv1PA7OFIwXjxumoMFsMK2xxI9UNXuIUeC7qWAeOB tlXCygdrYBoZekfjM5yeWRCC4KZSEnD4DDXR+f40GJU9cHIjSTiBbWFHDIgLm49y 8JdtzRZKMnxUt2jetEPoTMCIzsbHYK4D5+SkQQ2S4ti9qdmqFTW+E9vDDOHMrmfZ cvbMTOCBrr2AP5itXcNs2m0tyXYl4cWR/3c8owFvivZljav+TARxhYzZRUXX7Ozv 
Ht20/tJNtofWp4vd9QyrWYo06krgSl+P1EWpHQlpc9zb8AMjuCN8k80/eK5uF2Dd uTQa3+6PIeL/jf0vstDSbhAu5C2cFOF1REifaBtgsXDgnAUaemMgNBcA211frzcT Fbp7p1qoQ7jwcYyq1khdk3W2qLpNTJILgdQaeLEGFUzDGmKBlbloBiW+43bCbTII mm7SuY7rLcQQc1REfcLEkZo+KFRfZkLt8gd1bUMTZ2XdGw22P2BfQFFTvSCm9hrJ GMmUnT7W9fb6vPl1QoGlqrG+6o+LAGaPx/wlrd6Ut19YqRaZmYY8n/kqEGllo2eH 5wA4sO9OjXcIK3BHoeZDvdvEueqq4ynEWohW21M9w8HptxaeguiSIWaXpxeNSKOx +H0dfGG+s1MkQVMxpFT/WzmQXWM15ESy7SLYbj4qKj5M1cnfSTk5e67rrZYaXDoL qKx1Ta3ol8KtqJmHs2wPSlrg5hi7iwl+mz1Q+er1NmUitm3+9nDHfBCqKPIA3Nsn ffGaaRvRp/nkidgDewjCh5QxbzeHeqYqgwn6MA+ybKbVmLeceS/8djVoRCBlUH4u s94lcruWkEfhx0dflOjbNqctfGIkDDX5OBwab+eaPswFgg99ijJ2TcuvAxNTSTrs efd3KXyD0wWvLvJfBRxfenLzrEt3zbN2tNah+guR48D6dM3T/g+U1W9MzmvToo5L pPtOjL7xvb6lrkzfemmI4yVex8/otNcpLfVMlY16twAjaybaBR2Aoq9rEr8j5Kqa TP+6H3krV36Vbed+6aVFfF4CsraxVzUUHXyGaV9B9pwpubvaxHjqMuUdHm1LfsNL VcDXow4HMzOdnOXQ7CA/5d5VNG0bxqnhjPor3sL1mvdBz/JdLmlxnn56q9v+09d1 CnSQoAPyj2ZFMLbTJgiBY23ovfoV2PU7fQwtZOKG4xuJDgRIabrJchsRqwjw7Niu ucKCEFYPIc+MZCAQg1CxZ7/JofEgbiBAE6xwDwbycSbyLhRnEafEo76KPwp87Uck rzxrgeDEhPviXSmguidsrxMjJnkOeTS1ZoskbOdfQ0npdqTIscS80u705RuVzc7P M6OPLsuLuxII/lciKlDo3DuoqvRSrlTPkF1Kmp7lwN0AyqSkUcgXdNRQBeE5fGh3 m+Jdj2WMX5Rj0TVMos66uImvB3/b0MrOtZivmJ6Ed9oNQZg5msYCpxhzrd2A+AOQ sE2alhC3HtPPHjiXVev2i7CcGyvlBTApFT5qfOg605zT3h3ObT1fXR10a2SqwiHC KWfQAQPe+fs6OMSJNHgi8DjEa4YtJ498zW93vLvHu+X7I2mnQLbf+eJ2DBiB39eo 2oWj4R2SBK5JD6cc+Uq2pmdhTxLj/9KQ2MmWA6HIYv15qBPwYUh9bIjZ0/H3gDHH +BfLmfe2MSDaWKx3z+KhTH05fLI14QFY5uSogTvlUIWIR24FMU1SV6J00lQ8dujG cE/ayVRVLGvZN7VUynZ0mcmB7eowZBjblJQMwmxeUdmbjc/g5otAvBx5V+Xlio+4 z8uPUc/8D9A8+ja5NzXNeZhiPcvzU81L1LOva2hvB24w/E+2qt8TLs9Bc4FO/dVP BClriniw9CTbHFki4OFUVdvJkvXEnOWJGJzk/l2IuTs3Nm7ghyU9Z4ZV64q8um/A tn/aAvIt+v++IJPaT6/aHLVyJyLK45xP6mTdKNQkn3P2c2CsxsdXz8dT9nhoGBin c/WlbSQCrRADKYYJgpc8irZfZoy6gKldT441enzz+C8jUb4btaDh6dZftb90CHsl BplPKvHeu6kld4mVaQadaEZrfmX21SS7RJpbaNIZ5+HgRMwTSU51uie0iUj1mmZ0 Tyk7YI+PHzuwGEFRHPw3StwNy79ihmamq2ef2UKK4QjDzW/4SCbRDy3WI0AzYQOR 
oLgbcSB9dRYPdW8sHay8EQ+8jrnklvc6iWsu4zE1+ptZnCMgSvv2iqCKW5MELx3x Z9NlfcmbQIaZCN6LKZaim/L8rK/bocB5yM4teApYKvPOiTXh/9csmrrZccSXg7Ct sIiZnqA0VW3fWN8EtFhhZUGv/q7VEyi/Iz+j9RrFaZDZ/pM1uQvvGWqR2AdTppDj RuUPEma0xt5SpGnttERsH5MUV8YGgRVuoiLg0P15yJR7mNy/VpPJhxWWKG2x8R+u 75QzlRR12rg0HfbqA+d1ADNbKWTJEAY1hks2tA+DOWPK/4/cEOF7bJIxZY2MJgAz 8RhXbAQaxpX+cbPbHdMdvYKWpFi+GBNYXCIOoj3l79ATkCDguynmjrk= -----END ENCRYPTED PRIVATE KEY----- ================================================ FILE: ci/scripts/ca.srl ================================================ 54587BDD05D4BE6A6D8852CA7FDB421189EA1C6D ================================================ FILE: ci/scripts/execute_tests.sh ================================================ #!/bin/bash # # Build and Test Golang driver # set -e set -o pipefail CI_SCRIPTS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" TOPDIR=$(cd $CI_SCRIPTS_DIR/../.. && pwd) eval $(jq -r '.testconnection | to_entries | map("export \(.key)=\(.value|tostring)")|.[]' $TOPDIR/parameters.json) env | grep SNOWFLAKE | grep -v PASS | grep -v SECRET | sort cd $TOPDIR go install github.com/jstemmer/go-junit-report/v2@latest if [[ "$HOME_EMPTY" == "yes" ]] ; then export GOCACHE=$HOME/go-build export GOMODCACHE=$HOME/go-modules export HOME= fi COVPKGS=$(go list ./... | grep -v '/cmd/' | tr '\n' ',' | sed 's/,$//') if [[ "$SEQUENTIAL_TESTS" == "true" ]] ; then # Test each package separately to avoid buffering (slower but real-time output) PACKAGES=$(go list ./...) if [[ -n "$JENKINS_HOME" ]]; then export WORKSPACE=${WORKSPACE:-/mnt/workspace} ( for pkg in $PACKAGES; do # Convert full package path to relative path pkg_path=$(echo $pkg | sed "s|^github.com/snowflakedb/gosnowflake/v2||" | sed "s|^/||") if [[ -z "$pkg_path" ]]; then pkg_path="." 
else pkg_path="./$pkg_path" fi echo "=== Testing package: $pkg_path ===" >&2 GODEBUG=$TEST_GO_DEBUG go test $GO_TEST_PARAMS -timeout 90m -race -v "$pkg_path" done ) | /home/user/go/bin/go-junit-report -iocopy -out $WORKSPACE/junit-go.xml else set +e FAILED=0 ( for pkg in $PACKAGES; do pkg_path=$(echo $pkg | sed "s|^github.com/snowflakedb/gosnowflake/v2||" | sed "s|^/||") if [[ -z "$pkg_path" ]]; then pkg_path="." else pkg_path="./$pkg_path" fi echo "=== Testing package: $pkg_path ===" >&2 # Note: -coverprofile only works with single package, use -coverpkg for multiple GODEBUG=$TEST_GO_DEBUG go test $GO_TEST_PARAMS -timeout 90m -race -coverpkg="$COVPKGS" -coverprofile="${pkg_path//\//_}_coverage.txt" -covermode=atomic -v "$pkg_path" if [[ $? -ne 0 ]]; then FAILED=1 echo "[ERROR] Package $pkg_path tests failed" >&2 fi done # Merge coverage files go install github.com/wadey/gocovmerge@latest gocovmerge *_coverage.txt > coverage.txt rm -f *_coverage.txt exit $FAILED ) | tee test-output.txt TEST_EXIT_CODE=${PIPESTATUS[0]} cat test-output.txt | go-junit-report > test-report.junit.xml exit $TEST_EXIT_CODE fi else # Test all packages with ./... (parallel, faster, but buffered per package) if [[ -n "$JENKINS_HOME" ]]; then export WORKSPACE=${WORKSPACE:-/mnt/workspace} GODEBUG=$TEST_GO_DEBUG go test $GO_TEST_PARAMS -timeout 90m -race -v ./... | /home/user/go/bin/go-junit-report -iocopy -out $WORKSPACE/junit-go.xml else set +e GODEBUG=$TEST_GO_DEBUG go test $GO_TEST_PARAMS -timeout 90m -race -coverpkg="$COVPKGS" -coverprofile=coverage.txt -covermode=atomic -v ./... 
| tee test-output.txt TEST_EXIT_CODE=${PIPESTATUS[0]} cat test-output.txt | go-junit-report > test-report.junit.xml exit $TEST_EXIT_CODE fi fi ================================================ FILE: ci/scripts/hang_webserver.py ================================================ #!/usr/bin/env python3 import sys from http.server import BaseHTTPRequestHandler,HTTPServer from socketserver import ThreadingMixIn import threading import time import json class HTTPRequestHandler(BaseHTTPRequestHandler): invocations = 0 def do_POST(self): if self.path.startswith('/reset'): print("Resetting HTTP mocks") HTTPRequestHandler.invocations = 0 self.__respond(200) elif self.path.startswith('/invocations'): self.__respond(200, body=str(HTTPRequestHandler.invocations)) elif self.path.startswith('/ocsp'): print("ocsp") self.ocspMocks() elif self.path.startswith('/session/v1/login-request'): self.authMocks() def ocspMocks(self): if self.path.startswith('/ocsp/403'): self.send_response(403) self.send_header('Content-Type', 'text/plain') self.end_headers() elif self.path.startswith('/ocsp/404'): self.send_response(404) self.send_header('Content-Type', 'text/plain') self.end_headers() elif self.path.startswith('/ocsp/hang'): print("Hanging") time.sleep(300) self.send_response(200, 'OK') self.send_header('Content-Type', 'text/plain') self.end_headers() else: self.send_response(200, 'OK') self.send_header('Content-Type', 'text/plain') self.end_headers() def authMocks(self): content_length = int(self.headers.get('content-length', 0)) body = self.rfile.read(content_length) jsonBody = json.loads(body) if jsonBody['data']['ACCOUNT_NAME'] == "jwtAuthTokenTimeout": HTTPRequestHandler.invocations += 1 if HTTPRequestHandler.invocations >= 3: self.__respond(200, body='''{ "data": { "token": "someToken" }, "success": true }''') else: time.sleep(2000) self.send_response(200) else: print("Unknown auth request") self.send_response(500) def __respond(self, http_code, content_type='application/json', 
body=None):
        # Write an HTTP response with the given status code and optional body.
        print("responding:", body)
        self.send_response(http_code)
        self.send_header('Content-Type', content_type)
        self.end_headers()
        # PEP 8: compare against None with `is not`, not `!=`.
        if body is not None:
            responseBody = bytes(body, "utf-8")
            self.wfile.write(responseBody)

    do_GET = do_POST

class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    allow_reuse_address = True

    def shutdown(self):
        self.socket.close()
        HTTPServer.shutdown(self)

class SimpleHttpServer():
    def __init__(self, ip, port):
        self.server = ThreadedHTTPServer((ip,port), HTTPRequestHandler)

    def start(self):
        self.server_thread = threading.Thread(target=self.server.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()

    def waitForThread(self):
        self.server_thread.join()

    def stop(self):
        self.server.shutdown()
        self.waitForThread()

if __name__=='__main__':
    if len(sys.argv) != 2:
        print("Usage: python3 {} PORT".format(sys.argv[0]))
        sys.exit(2)

    PORT = int(sys.argv[1])
    server = SimpleHttpServer('localhost', PORT)
    print('HTTP Server Running on PORT {}..........'.format(PORT))
    server.start()
    server.waitForThread()

================================================
FILE: ci/scripts/login_internal_docker.sh
================================================
#!/bin/bash -e
#
# Login the Internal Docker Registry
#
if [[ -z "$GITHUB_ACTIONS" ]]; then
    echo "[INFO] Login the internal Docker Registry"
    if ! docker login $INTERNAL_REPO; then
        echo "[ERROR] Failed to connect to the Artifactory server. Ensure 'sf artifact oci auth' has been run."
        exit 1
    fi
else
    echo "[INFO] No login the internal Docker Registry"
fi

================================================
FILE: ci/scripts/run_wiremock.sh
================================================
#!/usr/bin/env bash
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd $SCRIPT_DIR

if [[ "$1" == "--ecdsa" || "$WIREMOCK_ENABLE_ECDSA" == "true" ]] ; then
  # Typo fixed: "ecliptic" -> "elliptic" (ECDSA uses elliptic-curve cryptography).
  echo "Using elliptic curves"
  pfxFile="$SCRIPT_DIR/wiremock-ecdsa.p12"
else
  echo "Using RSA"
  pfxFile="$SCRIPT_DIR/wiremock.p12"
fi

if [ !
-f "$SCRIPT_DIR/wiremock-standalone-3.11.0.jar" ]; then curl -O https://repo1.maven.org/maven2/org/wiremock/wiremock-standalone/3.11.0/wiremock-standalone-3.11.0.jar fi java -jar "$SCRIPT_DIR/wiremock-standalone-3.11.0.jar" --port ${WIREMOCK_PORT:=14355} --https-port ${WIREMOCK_HTTPS_PORT:=13567} --https-keystore "$pfxFile" --keystore-type PKCS12 --keystore-password password ================================================ FILE: ci/scripts/setup_connection_parameters.sh ================================================ #!/bin/bash -e # # Set connection parameters # CI_SCRIPTS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" if [[ "$CLOUD_PROVIDER" == "AZURE" ]]; then PARAMETER_FILE=parameters_azure_golang.json.gpg PRIVATE_KEY=rsa_key_golang_azure.p8.gpg elif [[ "$CLOUD_PROVIDER" == "GCP" ]]; then PARAMETER_FILE=parameters_gcp_golang.json.gpg PRIVATE_KEY=rsa_key_golang_gcp.p8.gpg else PARAMETER_FILE=parameters_aws_golang.json.gpg PRIVATE_KEY=rsa_key_golang_aws.p8.gpg fi gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $CI_SCRIPTS_DIR/../../parameters.json $CI_SCRIPTS_DIR/../../.github/workflows/$PARAMETER_FILE gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $CI_SCRIPTS_DIR/../../rsa-2048-private-key.p8 $CI_SCRIPTS_DIR/../../.github/workflows/rsa-2048-private-key.p8.gpg gpg --quiet --batch --yes --decrypt --passphrase="$GOLANG_PRIVATE_KEY_SECRET" --output $CI_SCRIPTS_DIR/../../.github/workflows/parameters/public/rsa_key_golang.p8 $CI_SCRIPTS_DIR/../../.github/workflows/parameters/public/$PRIVATE_KEY ================================================ FILE: ci/scripts/setup_gpg.sh ================================================ #!/bin/bash # GPG setup script for creating unique GPG home directory setup_gpg_home() { # Create unique GPG home directory export GNUPGHOME="${THIS_DIR}/.gnupg_$$_$(date +%s%N)_${BUILD_NUMBER:-}" mkdir -p "$GNUPGHOME" chmod 700 "$GNUPGHOME" cleanup_gpg() { if [[ -n "$GNUPGHOME" 
&& -d "$GNUPGHOME" ]]; then rm -rf "$GNUPGHOME" fi } trap cleanup_gpg EXIT } setup_gpg_home ================================================ FILE: ci/scripts/wiremock-ecdsa-pub.key ================================================ -----BEGIN PUBLIC KEY----- MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEX3j37DbAKoO6Cwn0TsoMcsVXEF52 lDa2tEHX2kMoxLExE4cgBipPyHgwNEblfAbaA1eC03fytJZw0wd08GvA+Q== -----END PUBLIC KEY----- ================================================ FILE: ci/scripts/wiremock-ecdsa.crt ================================================ -----BEGIN CERTIFICATE----- MIID/jCCAeagAwIBAgIUVFh73QXUvmptiFLKf9tCEYnqHG0wDQYJKoZIhvcNAQEL BQAwezELMAkGA1UEBhMCUEwxFDASBgNVBAgMC01hem93aWVja2llMQ8wDQYDVQQH DAZXYXJzYXcxEjAQBgNVBAoMCVNub3dmbGFrZTEQMA4GA1UECwwHRHJpdmVyczEf MB0GA1UEAwwWU25vd2ZsYWtlIHRlc3QgUm9vdCBDQTAeFw0yNjAzMDYxODQ4MjJa Fw0yNzAzMDYxODQ4MjJaMHkxCzAJBgNVBAYTAlBMMRQwEgYDVQQIDAtNYXpvd2ll Y2tpZTEPMA0GA1UEBwwGV2Fyc2F3MRIwEAYDVQQKDAlTbm93Zmxha2UxGzAZBgNV BAsMEkRldmVsb3BlciBwbGF0Zm9ybTESMBAGA1UEAwwJbG9jYWxob3N0MCowBQYD K2VwAyEAGLQr+l2G3bxeA8oXH6epvuZ1ZLY381WEwehREgaYpTyjdjB0MB8GA1Ud IwQYMBaAFNBlcqIdrN8OSmvMp5ZbwKR7RYegMAkGA1UdEwQCMAAwCwYDVR0PBAQD AgTwMBoGA1UdEQQTMBGHBH8AAAGCCWxvY2FsaG9zdDAdBgNVHQ4EFgQU/9pFFL7e 4Fr4IzzELxg3Y3nWns4wDQYJKoZIhvcNAQELBQADggIBAIE6g+wbA5JIWaU+atNL Qr62D+a1IlB4kE+Ysaz5iMCDNKIfbNe5/Mrgzbuc8iiRCz2QicPHEtS5OC39jeKM tX1JQGfA9G8P+IEX6POPgSYbBjO2uj9qdATFF3bjHtB9KPe/lF34rWD5v8ajMoOY oosRM+wOMT/H08AOmPRe3T1qVVCk9G87qGRw2cvpyoOh46dzcsaJ/4QNAMzp7PY1 yn8h8VRJoqkSHf/du1ACoqcmsfF26fMmVRjGmiMoIteIr/8CAFzc9yMXXTq4/F2P DT1XoWeQopdmWTkxS2DCiStxYWEYAVURzg4C1zeq3/KC48oZrNhNylkaHpsHx5x6 MxC8RoVN2zA8GZEsIVdRXi/gl8DjAwLieTwIErtczaMgNwmX1qU+qBoXAzZ4bEJT UuwfO/LcUywX6TZ91bO/tVsLOH2vNWjeQI/ewqUjpnPxqx9WG1QLaQ2wu2oqKBQQ YPZzpezG10tThgTkNyPlFyV0pT2YjfruDovC7EBGkaO1/ZheNvSbsZbuXJKDecr6 LhrAPh95V8mVUYCjI8bQK+K+u5feBN3pXtY9hfltcJ2611Xfv7Tm8R7JLgGQSlim 7D9i4/XWLKVfRtCbQLabgGzc46Kk8W5Ae8Ie1UrdhetehJPAMO/v8rOKnWqR3HxR i+s79C6kuYGYmRblr9LJ82pn -----END CERTIFICATE----- 
================================================ FILE: ci/scripts/wiremock-ecdsa.csr ================================================ -----BEGIN CERTIFICATE REQUEST----- MIH5MIGsAgEAMHkxCzAJBgNVBAYTAlBMMRQwEgYDVQQIDAtNYXpvd2llY2tpZTEP MA0GA1UEBwwGV2Fyc2F3MRIwEAYDVQQKDAlTbm93Zmxha2UxGzAZBgNVBAsMEkRl dmVsb3BlciBwbGF0Zm9ybTESMBAGA1UEAwwJbG9jYWxob3N0MCowBQYDK2VwAyEA GLQr+l2G3bxeA8oXH6epvuZ1ZLY381WEwehREgaYpTygADAFBgMrZXADQQAQX4XJ I6PxjoC2RofZayHk+ud2oyXdLE1M9NarUY6+2lKntFIIhn/s1F+4UK0cnDB40vJp MXV6quLOTF06azUM -----END CERTIFICATE REQUEST----- ================================================ FILE: ci/scripts/wiremock-ecdsa.key ================================================ -----BEGIN PRIVATE KEY----- MC4CAQAwBQYDK2VwBCIEICQI1T3B7DZ45py/Oa4fEjhdz3kMDlRFXvY8vv9DA5Io -----END PRIVATE KEY----- ================================================ FILE: ci/scripts/wiremock.crt ================================================ -----BEGIN CERTIFICATE----- MIIF7TCCA9WgAwIBAgIUVFh73QXUvmptiFLKf9tCEYnqHGwwDQYJKoZIhvcNAQEL BQAwezELMAkGA1UEBhMCUEwxFDASBgNVBAgMC01hem93aWVja2llMQ8wDQYDVQQH DAZXYXJzYXcxEjAQBgNVBAoMCVNub3dmbGFrZTEQMA4GA1UECwwHRHJpdmVyczEf MB0GA1UEAwwWU25vd2ZsYWtlIHRlc3QgUm9vdCBDQTAeFw0yNjAzMDYxMzE0MDZa Fw0yNzAzMDYxMzE0MDZaMG4xCzAJBgNVBAYTAlBMMRQwEgYDVQQIDAtNYXpvd2ll Y2tpZTEPMA0GA1UEBwwGV2Fyc2F3MRIwEAYDVQQKDAlTbm93Zmxha2UxEDAOBgNV BAsMB0RyaXZlcnMxEjAQBgNVBAMMCWxvY2FsaG9zdDCCAiIwDQYJKoZIhvcNAQEB BQADggIPADCCAgoCggIBAMMpVsRRrW7/UFzfb/WfkjF5tKIJBNze/90qC2xheSsq h3yQPPgfQXnSPLTCR0Z0ZEhV5NbiZPlSS5Nl9zD/JwSryFuFAtTrYhOcqBpnzz46 n3bZUHNfC/sD6qNVL43LsyvfKWWBVyxlSpCMmEdgyqvPTRHJ3l3EW8uCBUxHQM35 FxUNpTdc/tFCXVDZgRGUwQ23yRmwGx2HbXN1PEsmJ/yZ/mZg9oIWNUqTWGj6DY8R 8gmf5oXgkjPlu2G6xxb6lo6cAToAWhjBuCVzo7ciCXpaGVxXv4IyksB+xJxjYFll 1CBeYKXw5+UdCjzA04MA8Q+E0TNRRiv74sHYq2egS80+6NByjmHolzd/6nOUo5ed e96Mj5rfOojGn0Omwf8r1B/+aYZcYtOHyN44ZskZnDMv1NGlyn5o0lcn+RJyMi4D +MgwgOEYvDcByp9YG5y6MxAUo3Gexl8cifCGbBRZaL2PNWKhHVB0IKZwvY5WLPMD 0d8pDl5+LrMq/1ra5ObhPhiOdgjpaPuH5lnyTkx0YG9adNsaczPFzzXARHIj3Il7 
WuEqBbf5a/iZcKlPOTNhlxhWIYUJ+1qunKXt3mhZx3IVX1pqionSGJkYwNTkWtJl tCzJquaPWmdMBfdtDNoavH5pRnbCtI/DB37gJ3u4VHfqZU2R7hXBkwW22IOiSKjv AgMBAAGjdjB0MB8GA1UdIwQYMBaAFNBlcqIdrN8OSmvMp5ZbwKR7RYegMAkGA1Ud EwQCMAAwCwYDVR0PBAQDAgTwMBoGA1UdEQQTMBGHBH8AAAGCCWxvY2FsaG9zdDAd BgNVHQ4EFgQUn/a/Lb80EZf1PHprSyO+qvRv3y4wDQYJKoZIhvcNAQELBQADggIB AFxnpTGBeUmdeef8N04X0LoUNiDTrhgPnJy5DYhFwfK27wsFHH4uWTf4Fg61VblG QJOhVYkZshvltdVRDr/Y1iAfCvwRlweA43QrXtMnDy+326ig277E2Z1C7K3f7lHS t/vUFR3fmSRdOAzFoJQISgzwL4tFw0wS36lwh6bOYHp/pm7BG4g+Z3ftWw8eUjmv udpupYXG36SflfZWasy4I1fl0mDWIS6eKkR76DqqugBMH1QMprTwr0OjXaWiku6r z3IsMPVnVXeejNNoP/67AfGzEb3FeFGVMl+qg7lL155blga1ph8upWo4k6qsZF7S 4ZlscEaYSZj20ZR5ZN/n8F8d43uqzL0RUbaNyvYS12nnun5XnkfVFa2QJdq/EOV7 dEyp9/GCIazqMf3cNUnQWUaQ/ow6zzL6+2bc5GnjRYps8z2+zyFFUgfINxrcg3K1 T3C2ZNV3lSOwuzlyMD236HgM+Kt7mq2nmiDTlcp7JqrsLr6qzidL8jfnqjG9Jyg4 y6cJzWPKTfVmqsJtfx1YBnIkddh4NYtpUgBGjYkYIRIonZ7eu9fapKKiRguckD4T P1BTd3BzwYqTmNXlxVV2uVhh7mPZo+jghK2HtuUcjsZPbWm2ju8kPmRo83fpBvk7 6OYjoXKwQZxnQSqJ9rPf1fqGepn4kQR6qvM6phVSBs5x -----END CERTIFICATE----- ================================================ FILE: ci/scripts/wiremock.csr ================================================ -----BEGIN CERTIFICATE REQUEST----- MIIEszCCApsCAQAwbjELMAkGA1UEBhMCUEwxFDASBgNVBAgMC01hem93aWVja2ll MQ8wDQYDVQQHDAZXYXJzYXcxEjAQBgNVBAoMCVNub3dmbGFrZTEQMA4GA1UECwwH RHJpdmVyczESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEFAAOC Ag8AMIICCgKCAgEAwylWxFGtbv9QXN9v9Z+SMXm0ogkE3N7/3SoLbGF5KyqHfJA8 +B9BedI8tMJHRnRkSFXk1uJk+VJLk2X3MP8nBKvIW4UC1OtiE5yoGmfPPjqfdtlQ c18L+wPqo1UvjcuzK98pZYFXLGVKkIyYR2DKq89NEcneXcRby4IFTEdAzfkXFQ2l N1z+0UJdUNmBEZTBDbfJGbAbHYdtc3U8SyYn/Jn+ZmD2ghY1SpNYaPoNjxHyCZ/m heCSM+W7YbrHFvqWjpwBOgBaGMG4JXOjtyIJeloZXFe/gjKSwH7EnGNgWWXUIF5g pfDn5R0KPMDTgwDxD4TRM1FGK/viwdirZ6BLzT7o0HKOYeiXN3/qc5Sjl5173oyP mt86iMafQ6bB/yvUH/5phlxi04fI3jhmyRmcMy/U0aXKfmjSVyf5EnIyLgP4yDCA 4Ri8NwHKn1gbnLozEBSjcZ7GXxyJ8IZsFFlovY81YqEdUHQgpnC9jlYs8wPR3ykO Xn4usyr/Wtrk5uE+GI52COlo+4fmWfJOTHRgb1p02xpzM8XPNcBEciPciXta4SoF 
t/lr+JlwqU85M2GXGFYhhQn7Wq6cpe3eaFnHchVfWmqKidIYmRjA1ORa0mW0LMmq 5o9aZ0wF920M2hq8fmlGdsK0j8MHfuAne7hUd+plTZHuFcGTBbbYg6JIqO8CAwEA AaAAMA0GCSqGSIb3DQEBCwUAA4ICAQBHoiHRzxkLHkWfgq1wbFrVnsHrnALSY+Nl 994fFykF4fDA5eLvfIWmuU5YZwyz+9Bw0SGoefb9RfFxZbQByBglhFbHPEvID1Sw 3ByJPMLccep7lkLd/BfIgyZ7vSyIK3mKY4wSnGqf3eiQeMU57ViP3AL6Q0Uos3Jm jmUWIeEHrSE2HfHREK8ar0xGKTimQymW6P+ecRKQKs7I7aEJL5t3/zp2w+EyxIGC ezP+rtH8QdfDJN3nui+2ljgonvbwrYMJTBJYZ/oOx/msKUF4EO2FT/VJKQsOZnyL s0HXMEEJ9AKlFo9gagZ6ZqxnVYCPoeW8Nfb56YwZ9im2wbo2yaNAFTMaKoH1/2g0 LHZd1vq1sU6xT3V3R+5Iiw4k7u8mx6ietSbwuyOkHkQ+RZf5hZKvdHSymKTuN/e4 40XzGBhcTqs57KHbsiWFBnRFiIZgFq5kbC0G+c927g8XRB9j3xiMjBBwUR0Kp78q bTvAzod0ZhYeltFw63TkNe/yH4RZefseub0eice6Fjmpv0BgjYNP2guCnd3u7KaG H0zYSFHzN00jtDNNs1Jx1drsHZcr6fAOeeUmI9ExsDkt8vyMmpshd+w3LEh/ZVL2 pvvtcut0s24OszF5HCRScxSXv3SSUDX1asRyUHY5STLdK74o+dfqXT+ja+MRJEEh IiE2ITiP8Q== -----END CERTIFICATE REQUEST----- ================================================ FILE: ci/scripts/wiremock.key ================================================ -----BEGIN PRIVATE KEY----- MIIJQQIBADANBgkqhkiG9w0BAQEFAASCCSswggknAgEAAoICAQDDKVbEUa1u/1Bc 32/1n5IxebSiCQTc3v/dKgtsYXkrKod8kDz4H0F50jy0wkdGdGRIVeTW4mT5UkuT Zfcw/ycEq8hbhQLU62ITnKgaZ88+Op922VBzXwv7A+qjVS+Ny7Mr3yllgVcsZUqQ jJhHYMqrz00Ryd5dxFvLggVMR0DN+RcVDaU3XP7RQl1Q2YERlMENt8kZsBsdh21z dTxLJif8mf5mYPaCFjVKk1ho+g2PEfIJn+aF4JIz5bthuscW+paOnAE6AFoYwbgl c6O3Igl6WhlcV7+CMpLAfsScY2BZZdQgXmCl8OflHQo8wNODAPEPhNEzUUYr++LB 2KtnoEvNPujQco5h6Jc3f+pzlKOXnXvejI+a3zqIxp9DpsH/K9Qf/mmGXGLTh8je OGbJGZwzL9TRpcp+aNJXJ/kScjIuA/jIMIDhGLw3AcqfWBucujMQFKNxnsZfHInw hmwUWWi9jzVioR1QdCCmcL2OVizzA9HfKQ5efi6zKv9a2uTm4T4YjnYI6Wj7h+ZZ 8k5MdGBvWnTbGnMzxc81wERyI9yJe1rhKgW3+Wv4mXCpTzkzYZcYViGFCftarpyl 7d5oWcdyFV9aaoqJ0hiZGMDU5FrSZbQsyarmj1pnTAX3bQzaGrx+aUZ2wrSPwwd+ 4Cd7uFR36mVNke4VwZMFttiDokio7wIDAQABAoICAAgrmeCm1A5FOAsQpkeagkH5 /hBD37qTchNt6C6Ft3nm0jyVGUhV8/rH92yl2YVfPWIzM7JfUKozbMs4m0Gnh5hQ IheFblnq73SHZsORkavhmRLJBETgN3MvIHVCuAvv+Ynzp3BYGtsr877bc/XrsnBr 
lvwQqcjefe1Q0yyfVbI0eb09kKt3BDVPLvLsjX+77N0d0u3Ktp06MeCB3vVScp1w 9k/jl/kC5FZBQZPw1qfPsNoATLlRboLSXPw5bTj5YrDeYnAYMFgVpsJCoMRQ83lL flZPAiB5l4qMLr+mqr5ItLm/hGejZJdDQPjMJc634l+rnXUliOeHKGDEfmCHOxpu N2C8iXJysQJhDGfHvLmNeKdaXgJt+T37W8M8t02oHDECpMwMSOHMlVpxut8DBhpa hz9olGxwp7c2fSemJGiWNUXCfMtkhUl4VLRAqZ7pD91VtmQAi8gAIg15MHIjlGAh EVQZZE1qd0SUxy4nCNYt9L3AhU2I/I8k7cQMKBX0vOrQQvaZmBo5FI3uSejMeNgn MQWQvzR1XIzBeMCv8c5kgRr6C6RPGYzycxO3fP93TfpwY/vehuBwAh+38qYY6Azn zVYqjn5hTnxhH3pCG3ugoqiLSnfrptw/TUVR9GOwMPNwD3QR6Hv57EljLyaaDQho byLkPdKXEQUmFHEoLTWRAoIBAQDmh/yS0gmoWBDDA8/xIzB2a8NM9VfrutoI2HNM cnrQXWDdgjcLM/AAuV3ESyP0+1PFFv5gxCg35fPX+uj24dydsyxCAbFBxCBPvBUC 3Mc2PskEDmFyuYDwxbLItxDgjMZX1kWhCGONV2LOHfxy1itkZ6aWhP4p77/+9oaU 26Uq5mcWMMUV0wWX6IS7ttpK6xmXY3LauEzqmwgQITfrLBMdpDyJZjGYYpYLOWvg hGIkkEH+ACyrU1SZOYl3tCYmteXSfJeuwLP4g2vcLaj0j+z7fvJ1YAVByeuOHKV/ JgHv1XE3tRZH1ZZ1QoeHHlaizjzjCCic/ld93SHzYwgFDyN5AoIBAQDYuQGKIEbS KlZpaZAvyU9XYEXDSRLGnkKLOo0A54IsM/2YueYPgJ3ovMyVU5coMcXC3AACo0Zy OREHXdmNmKe+PcZArbn/BvTMihXChLKeGc/MFyCBqbniDM6/LSkqT1mU2jL8AEKz xwU9kHX4NZrq4CfYoqA3b8x/dVCgV/8L0o5+mubHm7NUyFYlOHPWQ0u+auEKEdAB dVtv3VuPUwgkmE4OgsDv3q165jQ0Yr5cxXwlNUHd0yJlo9QklN8ua6rxCLU+ylbB RgU+tALD7pBPF2pa+m5G3efOUOTFhwWFsQ/mABZscz9emiQXNHVuwj7feLOOE/Yq PkhecmmsPm2nAoIBAEG4BKXqYLxwFp8xqAcLTBaGVA/NZXobM2sQIZZqkF50MFgV dhGohcP/FB8QeLivKUtnaa82XGzLDj/FFMLE0rrWSEis6NZhzgBNEwRU4imxrmaM nvUwsvRwt64GmjYZi7WgrQriNFcn0VAHNl+adJZUAiao1TgpU+egae9nymc3da3a y2SUWuTacXR+BS8UZKBGxohZv/ulpJ/MiH9veieaGXPmAT9642FhxkIkG0JnKZj6 fcF9qQFhaLIKVlH0ywa9ZBR6dRPki0wibCcEHL/5ia8yZ21A3fkOa5OaxzSS+Yqz Ah4KYrEc/Tvkxzf0aWEjg0h2LYUBFFupILEohqkCggEAUhLGHXwZte+op+T9YMt5 C5r+8HTU8njutHFpAsWpy3mo+VS1ZnuL0Z7mT0rHvMYUobXVHyqcPBeWdla5U+FS 7T3RvZ7NCGKnBGrq0K6WQj9+LUk420HejlfRWB8PLuG8CB4WHs8uc4zUVDtIIcaT M43OKUF1MWlaZY6VCRQqF10W76VT7pXtdRclYJUfcS4tGiC5tqmGP3clOJj42q9U Lx+qt94WmQCYbCmP7aLTeqijWifwGMSjiyBe77edSaQmqX9lvDC+aBVPWS6suWy4 I+u3MFsUtivFZKHH8XIvyjCC19SCqXF/tyDiuBL6wgY370NzpEO0/sx1dacYk81U 
kwKCAQB6g2V31JRn6CjkCTDG9Lf71AQwW1ZaLB71rhKyepAVV3vnYBRjuApGx3cN WFVIU9Cc010xDjeBmlbkqsfDujZRKTdU8aq9U8N26UWNkiwQjD1kCQR7KrvatZaU wglJ04BXZhVW/qT5Q2j/bgBmEjjbes83ZNWwWbx9x+h/YUVcCJ+n6OQCmDRBEvz6 1XkRpWt1HR9yEpH8kIuwWBqe/+afmASaLCK19jQcQ80QDvEcn8cy8A0UHM3FToWf R3OBlkcHYlUMZbj0VpiDktEUxl/ycPVWesH7WOhsB4HxSqtLpjebJBffzU/e+k+u Q39oXb8n1ljeCNi/Ksj8e/KstwzI -----END PRIVATE KEY----- ================================================ FILE: ci/scripts/wiremock.v3.ext ================================================ authorityKeyIdentifier=keyid,issuer basicConstraints=CA:FALSE keyUsage = digitalSignature, nonRepudiation, keyEncipherment, dataEncipherment subjectAltName = @alt_names [alt_names] IP.1 = 127.0.0.1 DNS.1 = localhost ================================================ FILE: ci/test.bat ================================================ REM Test Golang driver setlocal EnableDelayedExpansion start /b python ci\scripts\hang_webserver.py 12345 curl -O https://repo1.maven.org/maven2/org/wiremock/wiremock-standalone/3.11.0/wiremock-standalone-3.11.0.jar START /B java -jar wiremock-standalone-3.11.0.jar --port %WIREMOCK_PORT% -https-port %WIREMOCK_HTTPS_PORT% --https-keystore ci/scripts/wiremock.p12 --keystore-type PKCS12 --keystore-password password if "%CLOUD_PROVIDER%"=="AWS" ( set PARAMETER_FILENAME=parameters_aws_golang.json.gpg set PRIVATE_KEY=rsa_key_golang_aws.p8.gpg ) else if "%CLOUD_PROVIDER%"=="AZURE" ( set PARAMETER_FILENAME=parameters_azure_golang.json.gpg set PRIVATE_KEY=rsa_key_golang_azure.p8.gpg ) else if "%CLOUD_PROVIDER%"=="GCP" ( set PARAMETER_FILENAME=parameters_gcp_golang.json.gpg set PRIVATE_KEY=rsa_key_golang_gcp.p8.gpg ) if not defined PARAMETER_FILENAME ( echo [ERROR] failed to detect CLOUD_PROVIDER: %CLOUD_PROVIDER% exit /b 1 ) gpg --quiet --batch --yes --decrypt --passphrase="%PARAMETERS_SECRET%" --output parameters.json .github/workflows/%PARAMETER_FILENAME% if %ERRORLEVEL% NEQ 0 ( echo [ERROR] failed to decrypt the test parameters exit /b 1 ) gpg 
--quiet --batch --yes --decrypt --passphrase="%PARAMETERS_SECRET%" --output rsa-2048-private-key.p8 .github/workflows/rsa-2048-private-key.p8.gpg if %ERRORLEVEL% NEQ 0 ( echo [ERROR] failed to decrypt the rsa-2048 private key exit /b 1 ) REM Create directory structure for golang private key if not exist ".github\workflows\parameters\public" mkdir ".github\workflows\parameters\public" gpg --quiet --batch --yes --decrypt --passphrase="%GOLANG_PRIVATE_KEY_SECRET%" --output .github\workflows\parameters\public\rsa_key_golang.p8 .github\workflows\parameters\public\%PRIVATE_KEY% if %ERRORLEVEL% NEQ 0 ( echo [ERROR] failed to decrypt the golang private key exit /b 1 ) echo @echo off>parameters.bat jq -r ".testconnection | to_entries | map(\"set \(.key)=\(.value)\") | .[]" parameters.json >> parameters.bat call parameters.bat if %ERRORLEVEL% NEQ 0 ( echo [ERROR] failed to set the test parameters exit /b 1 ) echo [INFO] Account: %SNOWFLAKE_TEST_ACCOUNT% echo [INFO] User : %SNOWFLAKE_TEST_USER% echo [INFO] Database: %SNOWFLAKE_TEST_DATABASE% echo [INFO] Warehouse: %SNOWFLAKE_TEST_WAREHOUSE% echo [INFO] Role: %SNOWFLAKE_TEST_ROLE% go install github.com/jstemmer/go-junit-report/v2@latest REM Build coverpkg list excluding cmd/ packages set COVPKGS= for /f "usebackq delims=" %%p in (`go list ./...`) do ( echo %%p | findstr /C:"/cmd/" >nul if !ERRORLEVEL! 
NEQ 0 ( if "!COVPKGS!"=="" ( set COVPKGS=%%p ) else ( set COVPKGS=!COVPKGS!,%%p ) ) ) REM Test based on SEQUENTIAL_TESTS setting if "%SEQUENTIAL_TESTS%"=="true" ( REM Test each package separately to avoid buffering - real-time output but slower echo [INFO] Running tests sequentially for real-time output REM Clear any existing output file if exist test-output.txt del test-output.txt REM Track if any test failed set TEST_FAILED=0 REM Loop through each package and test separately for /f "usebackq delims=" %%p in (`go list ./...`) do ( set PKG=%%p REM Convert full package path to relative path set PKG_PATH=!PKG:github.com/snowflakedb/gosnowflake/v2=! if "!PKG_PATH!"=="" ( set PKG_PATH=. ) else ( set PKG_PATH=.!PKG_PATH! ) echo === Testing package: !PKG_PATH! === echo === Testing package: !PKG_PATH! === >> test-output.txt REM Test package and append to output (no -race on Windows ARM) REM Replace / with _ for coverage filename set COV_FILE=!PKG_PATH:/=_!_coverage.txt go test %GO_TEST_PARAMS% --timeout 90m -coverpkg=!COVPKGS! -coverprofile=!COV_FILE! -covermode=atomic -v !PKG_PATH! >> test-output.txt 2>&1 REM Track failure but continue testing other packages if !ERRORLEVEL! NEQ 0 ( echo [ERROR] Package !PKG_PATH! tests failed set TEST_FAILED=1 ) ) REM Merge coverage files go install github.com/wadey/gocovmerge@latest gocovmerge *_coverage.txt > coverage.txt del *_coverage.txt REM Set exit code based on whether any test failed set TEST_EXIT=!TEST_FAILED! ) else ( REM Test all packages with ./... - parallel, faster, but buffered echo [INFO] Running tests in parallel go test %GO_TEST_PARAMS% --timeout 90m -coverpkg=!COVPKGS! -coverprofile=coverage.txt -covermode=atomic -v ./... > test-output.txt 2>&1 set TEST_EXIT=!ERRORLEVEL! 
) REM Display the test output type test-output.txt REM Generate JUnit report from the saved output type test-output.txt | go-junit-report > test-report.junit.xml REM End local scope and exit with the test exit code endlocal & exit /b %TEST_EXIT% ================================================ FILE: ci/test.sh ================================================ #!/bin/bash # # Test Golang driver # set -e set -o pipefail CI_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" $CI_DIR/scripts/run_wiremock.sh & if [[ -n "$JENKINS_HOME" ]]; then ROOT_DIR="$(cd "${CI_DIR}/.." && pwd)" export WORKSPACE=${WORKSPACE:-/tmp} source $CI_DIR/_init.sh declare -A TARGET_TEST_IMAGES if [[ -n "$TARGET_DOCKER_TEST_IMAGE" ]]; then echo "[INFO] TARGET_DOCKER_TEST_IMAGE: $TARGET_DOCKER_TEST_IMAGE" IMAGE_NAME=${TEST_IMAGE_NAMES[$TARGET_DOCKER_TEST_IMAGE]} if [[ -z "$IMAGE_NAME" ]]; then echo "[ERROR] The target platform $TARGET_DOCKER_TEST_IMAGE doesn't exist. Check $CI_DIR/_init.sh" exit 1 fi TARGET_TEST_IMAGES=([$TARGET_DOCKER_TEST_IMAGE]=$IMAGE_NAME) else echo "[ERROR] Set TARGET_DOCKER_TEST_IMAGE to the docker image name to run the test" for name in "${!TEST_IMAGE_NAMES[@]}"; do echo " " $name done exit 2 fi for name in "${!TARGET_TEST_IMAGES[@]}"; do echo "[INFO] Testing $DRIVER_NAME on $name" docker container run \ --rm \ --network=host \ -v $ROOT_DIR:/mnt/host \ -v $WORKSPACE:/mnt/workspace \ -e LOCAL_USER_ID=$(id -u ${USER}) \ -e GIT_COMMIT \ -e GIT_BRANCH \ -e GIT_URL \ -e AWS_ACCESS_KEY_ID \ -e AWS_SECRET_ACCESS_KEY \ -e GITHUB_ACTIONS \ -e GITHUB_SHA \ -e GITHUB_REF \ -e RUNNER_TRACKING_ID \ -e JOB_NAME \ -e BUILD_NUMBER \ -e JENKINS_HOME \ ${TEST_IMAGE_NAMES[$name]} \ /mnt/host/ci/container/test_component.sh echo "[INFO] Test Results: $WORKSPACE/junit.xml" done else source $CI_DIR/scripts/setup_connection_parameters.sh cd $CI_DIR/.. 
make test fi ================================================ FILE: ci/test_authentication.sh ================================================ #!/bin/bash -e set -o pipefail export THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" source "$THIS_DIR/scripts/setup_gpg.sh" export WORKSPACE=${WORKSPACE:-/tmp} CI_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" if [[ -n "$JENKINS_HOME" ]]; then ROOT_DIR="$(cd "${CI_DIR}/.." && pwd)" export WORKSPACE=${WORKSPACE:-/tmp} source $CI_DIR/_init.sh echo "Use /sbin/ip" IP_ADDR=$(/sbin/ip -4 addr show scope global dev eth0 | grep inet | awk '{print $2}' | cut -d / -f 1) fi gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $THIS_DIR/../.github/workflows/parameters_aws_auth_tests.json "$THIS_DIR/../.github/workflows/parameters_aws_auth_tests.json.gpg" gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $THIS_DIR/../.github/workflows/rsa_keys/rsa_key.p8 "$THIS_DIR/../.github/workflows/rsa_keys/rsa_key.p8.gpg" gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $THIS_DIR/../.github/workflows/rsa_keys/rsa_key_invalid.p8 "$THIS_DIR/../.github/workflows/rsa_keys/rsa_key_invalid.p8.gpg" docker run \ -v $(cd $THIS_DIR/.. && pwd):/mnt/host \ -v $WORKSPACE:/mnt/workspace \ --rm \ artifactory.ci1.us-west-2.aws-dev.app.snowflake.com/internal-production-docker-snowflake-virtual/docker/snowdrivers-test-external-browser-golang:8 \ "/mnt/host/ci/container/test_authentication.sh" ================================================ FILE: ci/test_revocation.sh ================================================ #!/bin/bash # # Test certificate revocation validation using the revocation-validation framework. 
# set -o pipefail THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" DRIVER_DIR="$( dirname "${THIS_DIR}")" WORKSPACE=${WORKSPACE:-${DRIVER_DIR}} echo "[Info] Starting revocation validation tests" echo "[Info] Go driver path: $DRIVER_DIR" set -e # Clone revocation-validation framework REVOCATION_DIR="/tmp/revocation-validation" REVOCATION_BRANCH="${REVOCATION_BRANCH:-main}" rm -rf "$REVOCATION_DIR" if [ -n "$GITHUB_USER" ] && [ -n "$GITHUB_TOKEN" ]; then git clone --depth 1 --branch "$REVOCATION_BRANCH" "https://${GITHUB_USER}:${GITHUB_TOKEN}@github.com/snowflake-eng/revocation-validation.git" "$REVOCATION_DIR" else git clone --depth 1 --branch "$REVOCATION_BRANCH" "https://github.com/snowflake-eng/revocation-validation.git" "$REVOCATION_DIR" fi cd "$REVOCATION_DIR" # Point the framework at the local Go driver checkout go mod edit -replace "github.com/snowflakedb/gosnowflake/v2=${DRIVER_DIR}" go mod tidy echo "[Info] Replaced gosnowflake module with local checkout: $DRIVER_DIR" echo "[Info] Running tests with Go $(go version | grep -oE 'go[0-9]+\.[0-9]+')..." go run . \ --client snowflake \ --output "${WORKSPACE}/revocation-results.json" \ --output-html "${WORKSPACE}/revocation-report.html" \ --log-level debug EXIT_CODE=$? 
if [ -f "${WORKSPACE}/revocation-results.json" ]; then echo "[Info] Results: ${WORKSPACE}/revocation-results.json" fi if [ -f "${WORKSPACE}/revocation-report.html" ]; then echo "[Info] Report: ${WORKSPACE}/revocation-report.html" fi exit $EXIT_CODE ================================================ FILE: ci/test_rockylinux9.sh ================================================ #!/bin/bash -e # # Test GoSnowflake driver in Rocky Linux 9 # NOTES: # - Go version MUST be passed in as the first argument, e.g: "1.24.2" # - This is the script that test_rockylinux9_docker.sh runs inside of the docker container if [[ -z "${1}" ]]; then echo "[ERROR] Go version is required as first argument (e.g., '1.24.2')" echo "Usage: $0 " exit 1 fi GO_VERSION="${1}" THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" CONNECTOR_DIR="$( dirname "${THIS_DIR}")" # Validate prerequisites if [[ ! -f "${CONNECTOR_DIR}/parameters.json" ]]; then echo "[ERROR] parameters.json not found - connection parameters must be decrypted first" exit 1 fi if [[ ! -f "${CONNECTOR_DIR}/.github/workflows/parameters/public/rsa_key_golang.p8" ]]; then echo "[ERROR] Private key not found - must be decrypted first" exit 1 fi # Setup Go environment echo "[Info] Using Go ${GO_VERSION}" # Extract short version for wrapper script GO_VERSION_SHORT=$(echo ${GO_VERSION} | cut -d. -f1,2) if ! command -v go${GO_VERSION_SHORT} &> /dev/null; then echo "[ERROR] Go ${GO_VERSION_SHORT} not found!" 
exit 1 fi # Set GOROOT to short version directory (e.g., /usr/local/go1.24) export GOROOT="/usr/local/go${GO_VERSION_SHORT}" export PATH="${GOROOT}/bin:$PATH" export GOPATH="/home/user/go" export PATH="$GOPATH/bin:$PATH" echo "[Info] Go ${GO_VERSION} version: $(go version)" cd $CONNECTOR_DIR echo "[Info] Downloading Go modules" go mod download # Load connection parameters eval $(jq -r '.testconnection | to_entries | map("export \(.key)=\(.value|tostring)")|.[]' ${CONNECTOR_DIR}/parameters.json) export SNOWFLAKE_TEST_PRIVATE_KEY="${CONNECTOR_DIR}/.github/workflows/parameters/public/rsa_key_golang.p8" # Start WireMock ${CONNECTOR_DIR}/ci/scripts/run_wiremock.sh & # Run tests using make test cd ${CONNECTOR_DIR} make test ================================================ FILE: ci/test_rockylinux9_docker.sh ================================================ #!/bin/bash -e # Test GoSnowflake driver in Rocky Linux 9 Docker # NOTES: # - Go version MUST be specified as first argument # - Usage: ./test_rockylinux9_docker.sh "1.24.2" set -o pipefail if [[ -z "${1}" ]]; then echo "[ERROR] Go version is required as first argument (e.g., '1.24.2')" echo "Usage: $0 " exit 1 fi GO_ENV=${1} # Set constants THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" CONNECTOR_DIR="$( dirname "${THIS_DIR}")" WORKSPACE=${WORKSPACE:-${CONNECTOR_DIR}} # TODO: Uncomment when set_base_image.sh is created for Go # source $THIS_DIR/set_base_image.sh cd $THIS_DIR/docker/rockylinux9 CONTAINER_NAME=test_gosnowflake_rockylinux9 echo "[Info] Building docker image for Rocky Linux 9 with Go ${GO_ENV}" # Get current user/group IDs to match host permissions USER_ID=$(id -u) GROUP_ID=$(id -g) docker build --pull -t ${CONTAINER_NAME}:1.0 \ --build-arg BASE_IMAGE=rockylinux:9 \ --build-arg GO_VERSION=$GO_ENV \ --build-arg USER_ID=$USER_ID \ --build-arg GROUP_ID=$GROUP_ID \ . 
-f Dockerfile # Use setup_connection_parameters.sh like native jobs (outside container) if [[ "$GITHUB_ACTIONS" == "true" ]]; then source ${CONNECTOR_DIR}/ci/scripts/setup_connection_parameters.sh fi docker run --network=host \ -e TERM=vt102 \ -e JENKINS_HOME \ -e GITHUB_ACTIONS \ -e CLOUD_PROVIDER \ -e GO_TEST_PARAMS \ -e WIREMOCK_PORT \ -e WIREMOCK_HTTPS_PORT \ --mount type=bind,source="${CONNECTOR_DIR}",target=/home/user/gosnowflake \ ${CONTAINER_NAME}:1.0 \ ci/test_rockylinux9.sh ${GO_ENV} ================================================ FILE: ci/test_wif.sh ================================================ #!/bin/bash -e set -o pipefail export THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" export RSA_KEY_PATH_AWS_AZURE="$THIS_DIR/wif/parameters/rsa_wif_aws_azure" export RSA_KEY_PATH_GCP="$THIS_DIR/wif/parameters/rsa_wif_gcp" export PARAMETERS_FILE_PATH="$THIS_DIR/wif/parameters/parameters_wif.json" run_tests_and_set_result() { local provider="$1" local host="$2" local snowflake_host="$3" local rsa_key_path="$4" local snowflake_user="$5" local impersonation_path="$6" local snowflake_user_for_impersonation="$7" # NOTE: /home/user is the only dir we can write to (SNOW-2231498 to improve WORKDIR) ssh -i "$rsa_key_path" -o IdentitiesOnly=yes -p 443 "$host" env BRANCH="$BRANCH" SNOWFLAKE_TEST_WIF_HOST="$snowflake_host" SNOWFLAKE_TEST_WIF_PROVIDER="$provider" SNOWFLAKE_TEST_WIF_ACCOUNT="$SNOWFLAKE_TEST_WIF_ACCOUNT SNOWFLAKE_TEST_WIF_USERNAME="$snowflake_user" SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH="$impersonation_path" SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION="$snowflake_user_for_impersonation"" bash << EOF set -e set -o pipefail docker run \ --rm \ --cpus=1 \ -m 2g \ -e BRANCH \ -e SNOWFLAKE_TEST_WIF_PROVIDER \ -e SNOWFLAKE_TEST_WIF_HOST \ -e SNOWFLAKE_TEST_WIF_ACCOUNT \ -e SNOWFLAKE_TEST_WIF_USERNAME \ -e SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH \ -e SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION \ snowflakedb/client-go-chainguard-go1.24-test:1 \ bash -c " cd 
/home/user echo 'Running tests on branch: \$BRANCH, provider: \$SNOWFLAKE_TEST_WIF_PROVIDER' if [[ \"\$BRANCH\" =~ ^PR-[0-9]+\$ ]]; then wget -O - https://github.com/snowflakedb/gosnowflake/archive/refs/pull/\$(echo \$BRANCH | cut -d- -f2)/head.tar.gz | tar -xz else wget -O - https://github.com/snowflakedb/gosnowflake/archive/refs/heads/$BRANCH.tar.gz | tar -xz fi mv gosnowflake-* gosnowflake cd gosnowflake SKIP_SETUP=true go test -v -run TestWorkloadIdentityAuthOnCloudVM " EOF local status=$? if [[ $status -ne 0 ]]; then echo "$provider tests failed with exit status: $status" EXIT_STATUS=1 else echo "$provider tests passed" fi } get_branch() { local branch if [[ -n "${GIT_BRANCH}" ]]; then # Jenkins branch="${GIT_BRANCH}" else # Local branch=$(git rev-parse --abbrev-ref HEAD) fi echo "${branch}" } setup_parameters() { source "$THIS_DIR/scripts/setup_gpg.sh" gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output "$RSA_KEY_PATH_AWS_AZURE" "${RSA_KEY_PATH_AWS_AZURE}.gpg" gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output "$RSA_KEY_PATH_GCP" "${RSA_KEY_PATH_GCP}.gpg" chmod 600 "$RSA_KEY_PATH_AWS_AZURE" chmod 600 "$RSA_KEY_PATH_GCP" gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output "$PARAMETERS_FILE_PATH" "${PARAMETERS_FILE_PATH}.gpg" eval $(jq -r '.wif | to_entries | map("export \(.key)=\(.value|tostring)")|.[]' $PARAMETERS_FILE_PATH) } BRANCH=$(get_branch) export BRANCH setup_parameters # Run tests for all cloud providers EXIT_STATUS=0 set +e # Don't exit on first failure run_tests_and_set_result "AZURE" "$HOST_AZURE" "$SNOWFLAKE_TEST_WIF_HOST_AZURE" "$RSA_KEY_PATH_AWS_AZURE" "$SNOWFLAKE_TEST_WIF_USERNAME_AZURE" run_tests_and_set_result "AWS" "$HOST_AWS" "$SNOWFLAKE_TEST_WIF_HOST_AWS" "$RSA_KEY_PATH_AWS_AZURE" "$SNOWFLAKE_TEST_WIF_USERNAME_AWS" "$SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH_AWS" "$SNOWFLAKE_TEST_WIF_USERNAME_AWS_IMPERSONATION" run_tests_and_set_result "GCP" "$HOST_GCP" 
"$SNOWFLAKE_TEST_WIF_HOST_GCP" "$RSA_KEY_PATH_GCP" "$SNOWFLAKE_TEST_WIF_USERNAME_GCP" "$SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH_GCP" "$SNOWFLAKE_TEST_WIF_USERNAME_GCP_IMPERSONATION" run_tests_and_set_result "GCP+OIDC" "$HOST_GCP" "$SNOWFLAKE_TEST_WIF_HOST_GCP" "$RSA_KEY_PATH_GCP" "$SNOWFLAKE_TEST_WIF_USERNAME_GCP_OIDC" set -e # Re-enable exit on error echo "Exit status: $EXIT_STATUS" exit $EXIT_STATUS ================================================ FILE: ci/wif/parameters/parameters_wif.json.gpg ================================================  'QW-qd YêőTkv5F2яyD`mwGLWݽd_\'q6T*'9_֮t %?wļHbZvfwӘ].h\Θ_&uzT[&1G0=)}V;j ==X;E(,k7I&@ŕZקЗ$>-@Ʈ9y0IF-;U']x)A5'D+$>3ܒAƯ~9?):}*m}]7^,e@!\Cl6ąUi-@9!k&VgN G{h'bw3/ >QXZjZub 'D.# {Dj'̪Tŋ,%QH5 ================================================ FILE: ci/wif/parameters/rsa_wif_aws_azure.gpg ================================================  髃6K5%܇ټ飐|eRk]nc-TloB,ܐ͒R7B] 0 { for val := range unknownValues { logger.Warnf("Unknown configuration entry: %s with value: %s", val, unknownValues[val]) } } err = validateClientConfiguration(&clientConfig) if err != nil { return nil, parsingClientConfigError(err) } return &clientConfig, nil } func getUnknownValues(fileContents []byte) map[string]any { var values map[string]any err := json.Unmarshal(fileContents, &values) if err != nil { return nil } if values["common"] == nil { return nil } commonValues := values["common"].(map[string]any) lowercaseCommonValues := make(map[string]any, len(commonValues)) for k, v := range commonValues { lowercaseCommonValues[strings.ToLower(k)] = v } delete(lowercaseCommonValues, "log_level") delete(lowercaseCommonValues, "log_path") return lowercaseCommonValues } func parsingClientConfigError(err error) error { return fmt.Errorf("parsing client config failed: %w", err) } func validateClientConfiguration(clientConfig *ClientConfig) error { if clientConfig == nil { return errors.New("client config not found") } if clientConfig.Common == nil { return errors.New("common 
section in client config not found") } return validateLogLevel(*clientConfig) } func validateLogLevel(clientConfig ClientConfig) error { var logLevel = clientConfig.Common.LogLevel if logLevel != "" { _, err := toLogLevel(logLevel) if err != nil { return err } } return nil } func toLogLevel(logLevelString string) (string, error) { var logLevel = strings.ToUpper(logLevelString) switch logLevel { case levelOff, levelError, levelWarn, levelInfo, levelDebug, levelTrace: return logLevel, nil default: return "", errors.New("unknown log level: " + logLevelString) } } ================================================ FILE: client_configuration_test.go ================================================ package gosnowflake import ( "fmt" "os" "path" "path/filepath" "strings" "testing" ) func TestFindConfigFileFromConnectionParameters(t *testing.T) { dirs := createTestDirectories(t) connParameterConfigPath := createFile(t, "conn_parameters_config.json", "random content", dirs.dir) envConfigPath := createFile(t, "env_var_config.json", "random content", dirs.dir) t.Setenv(clientConfEnvName, envConfigPath) createFile(t, defaultConfigName, "random content", dirs.predefinedDir1) createFile(t, defaultConfigName, "random content", dirs.predefinedDir2) clientConfigFilePath, err := findClientConfigFilePath(connParameterConfigPath, predefinedTestDirs(dirs)) assertEqualE(t, err, nil) assertEqualE(t, clientConfigFilePath, connParameterConfigPath, "config file path") } func TestFindConfigFileFromEnvVariable(t *testing.T) { dirs := createTestDirectories(t) envConfigPath := createFile(t, "env_var_config.json", "random content", dirs.dir) t.Setenv(clientConfEnvName, envConfigPath) createFile(t, defaultConfigName, "random content", dirs.predefinedDir1) createFile(t, defaultConfigName, "random content", dirs.predefinedDir2) clientConfigFilePath, err := findClientConfigFilePath("", predefinedTestDirs(dirs)) assertEqualE(t, err, nil) assertEqualE(t, clientConfigFilePath, envConfigPath, "config file 
path") } func TestFindConfigFileFromFirstPredefinedDir(t *testing.T) { dirs := createTestDirectories(t) configPath := createFile(t, defaultConfigName, "random content", dirs.predefinedDir1) createFile(t, defaultConfigName, "random content", dirs.predefinedDir2) clientConfigFilePath, err := findClientConfigFilePath("", predefinedTestDirs(dirs)) assertEqualE(t, err, nil) assertEqualE(t, clientConfigFilePath, configPath, "config file path") } func TestFindConfigFileFromSubsequentDirectoryIfNotFoundInPreviousOne(t *testing.T) { dirs := createTestDirectories(t) createFile(t, "wrong_file_name.json", "random content", dirs.predefinedDir1) configPath := createFile(t, defaultConfigName, "random content", dirs.predefinedDir2) clientConfigFilePath, err := findClientConfigFilePath("", predefinedTestDirs(dirs)) assertEqualE(t, err, nil) assertEqualE(t, clientConfigFilePath, configPath, "config file path") } func TestNotFindConfigFileWhenNotDefined(t *testing.T) { dirs := createTestDirectories(t) createFile(t, "wrong_file_name.json", "random content", dirs.predefinedDir1) createFile(t, "wrong_file_name.json", "random content", dirs.predefinedDir2) clientConfigFilePath, err := findClientConfigFilePath("", predefinedTestDirs(dirs)) assertEqualE(t, err, nil) assertEqualE(t, clientConfigFilePath, "", "config file path") } func TestCreatePredefinedDirs(t *testing.T) { skipOnMissingHome(t) exeDir, _ := os.Executable() appDir := filepath.Dir(exeDir) homeDir, err := os.UserHomeDir() assertNilF(t, err, "get home dir error") locations := clientConfigPredefinedDirs() assertEqualF(t, len(locations), 2, "size") assertEqualE(t, locations[0], appDir, "driver directory") assertEqualE(t, locations[1], homeDir, "home directory") } func TestGetClientConfig(t *testing.T) { dir := t.TempDir() fileName := "config.json" configContents := createClientConfigContent("INFO", "/some-path/some-directory") createFile(t, fileName, configContents, dir) filePath := path.Join(dir, fileName) clientConfigFilePath, 
_, err := getClientConfig(filePath) assertNilF(t, err) assertNotNilF(t, clientConfigFilePath) assertEqualE(t, clientConfigFilePath.Common.LogLevel, "INFO", "log level") assertEqualE(t, clientConfigFilePath.Common.LogPath, "/some-path/some-directory", "log path") } func TestNoResultForGetClientConfigWhenNoFileFound(t *testing.T) { clientConfigFilePath, _, err := getClientConfig("") assertNilF(t, err) assertNilF(t, clientConfigFilePath) } func TestParseConfiguration(t *testing.T) { dir := t.TempDir() testCases := []struct { testName string fileName string fileContents string expectedLogLevel string expectedLogPath string }{ { testName: "TestWithLogLevelUpperCase", fileName: "config_1.json", fileContents: createClientConfigContent("INFO", "/some-path/some-directory"), expectedLogLevel: "INFO", expectedLogPath: "/some-path/some-directory", }, { testName: "TestWithLogLevelLowerCase", fileName: "config_2.json", fileContents: createClientConfigContent("info", "/some-path/some-directory"), expectedLogLevel: "info", expectedLogPath: "/some-path/some-directory", }, { testName: "TestWithMissingValues", fileName: "config_3.json", fileContents: `{ "common": {} }`, expectedLogLevel: "", expectedLogPath: "", }, } for _, tc := range testCases { t.Run(tc.testName, func(t *testing.T) { fileName := createFile(t, tc.fileName, tc.fileContents, dir) config, err := parseClientConfiguration(fileName) assertNilF(t, err, "parse client configuration error") assertEqualE(t, config.Common.LogLevel, tc.expectedLogLevel, "log level") assertEqualE(t, config.Common.LogPath, tc.expectedLogPath, "log path") }) } } func TestParseAllLogLevels(t *testing.T) { dir := t.TempDir() for _, logLevel := range []string{"OFF", "ERROR", "WARN", "INFO", "DEBUG", "TRACE"} { t.Run(logLevel, func(t *testing.T) { fileContents := fmt.Sprintf(`{ "common": { "log_level" : "%s", "log_path" : "/some-path/some-directory" } }`, logLevel) fileName := createFile(t, fmt.Sprintf("config_%s.json", logLevel), fileContents, dir) 
config, err := parseClientConfiguration(fileName) assertNilF(t, err, "parse client config error") assertEqualE(t, config.Common.LogLevel, logLevel, "log level") }) } } func TestParseConfigurationFails(t *testing.T) { dir := t.TempDir() testCases := []struct { testName string fileName string FileContents string expectedErrorMessageToContain string }{ { testName: "TestWithWrongLogLevel", fileName: "config_1.json", FileContents: createClientConfigContent("something weird", "/some-path/some-directory"), expectedErrorMessageToContain: "unknown log level", }, { testName: "TestWithWrongTypeOfLogLevel", fileName: "config_2.json", FileContents: `{ "common": { "log_level" : 15, "log_path" : "/some-path/some-directory" } }`, expectedErrorMessageToContain: "ClientConfigCommonProps.common.log_level", }, { testName: "TestWithWrongTypeOfLogPath", fileName: "config_3.json", FileContents: `{ "common": { "log_level" : "INFO", "log_path" : true } }`, expectedErrorMessageToContain: "ClientConfigCommonProps.common.log_path", }, { testName: "TestWithoutCommon", fileName: "config_4.json", FileContents: "{}", expectedErrorMessageToContain: "common section in client config not found", }, } for _, tc := range testCases { t.Run(tc.testName, func(t *testing.T) { fileName := createFile(t, tc.fileName, tc.FileContents, dir) _, err := parseClientConfiguration(fileName) assertNotNilF(t, err, "parse client configuration error") errMessage := fmt.Sprint(err) expectedPrefix := "parsing client config failed" assertHasPrefixE(t, errMessage, expectedPrefix, "error message") assertStringContainsE(t, errMessage, tc.expectedErrorMessageToContain, "error message") }) } } func TestUnknownValues(t *testing.T) { testCases := []struct { testName string inputString string expectedOutput map[string]string }{ { testName: "EmptyCommon", inputString: `{ "common": {} }`, expectedOutput: map[string]string{}, }, { testName: "CommonMissing", inputString: `{ }`, expectedOutput: map[string]string{}, }, { testName: 
"UnknownProperty", inputString: `{ "common": { "unknown_key": "unknown_value" } }`, expectedOutput: map[string]string{ "unknown_key": "unknown_value", }, }, { testName: "KnownAndUnknownProperty", inputString: `{ "common": { "lOg_level": "level", "log_PATH": "path", "unknown_key": "unknown_value" } }`, expectedOutput: map[string]string{ "unknown_key": "unknown_value", }, }, { testName: "KnownProperties", inputString: `{ "common": { "log_level": "level", "log_path": "path" } }`, expectedOutput: map[string]string{}, }, { testName: "EmptyInput", inputString: "", expectedOutput: map[string]string{}, }, } for _, tc := range testCases { t.Run(tc.testName, func(t *testing.T) { inputBytes := []byte(tc.inputString) result := getUnknownValues(inputBytes) assertEqualE(t, fmt.Sprint(result), fmt.Sprint(tc.expectedOutput)) }) } } func TestConfigFileOpenSymlinkFail(t *testing.T) { skipOnWindows(t, "file permission is different") dir := t.TempDir() configFilePath := createFile(t, defaultConfigName, "random content", dir) symlinkFile := path.Join(dir, "test_symlink") expectedErrMsg := "too many levels of symbolic links" err := os.Symlink(configFilePath, symlinkFile) assertNilF(t, err, "failed to create symlink") _, err = getFileContents(symlinkFile, os.FileMode(1<<4|1<<1)) assertNotNilF(t, err, "should have blocked opening symlink") assertTrueF(t, strings.Contains(err.Error(), expectedErrMsg)) } func createFile(t *testing.T, fileName string, fileContents string, directory string) string { fullFileName := path.Join(directory, fileName) err := os.WriteFile(fullFileName, []byte(fileContents), 0644) assertNilF(t, err, "create file error") return fullFileName } func createTestDirectories(t *testing.T) struct { dir string predefinedDir1 string predefinedDir2 string } { dir := t.TempDir() predefinedDir1 := path.Join(dir, "dir1") err := os.Mkdir(predefinedDir1, 0700) assertNilF(t, err, "predefined dir1 error") predefinedDir2 := path.Join(dir, "dir2") err = os.Mkdir(predefinedDir2, 0700) 
assertNilF(t, err, "predefined dir2 error") return struct { dir string predefinedDir1 string predefinedDir2 string }{ dir: dir, predefinedDir1: predefinedDir1, predefinedDir2: predefinedDir2, } } func predefinedTestDirs(dirs struct { dir string predefinedDir1 string predefinedDir2 string }) []string { return []string{dirs.predefinedDir1, dirs.predefinedDir2} } func createClientConfigContent(logLevel string, logPath string) string { return fmt.Sprintf(`{ "common": { "log_level" : "%s", "log_path" : "%s" } }`, logLevel, strings.ReplaceAll(logPath, "\\", "\\\\"), ) } ================================================ FILE: client_test.go ================================================ package gosnowflake import ( "context" "net/http" "net/url" "testing" ) type DummyTransport struct { postRequests int getRequests int } func (t *DummyTransport) RoundTrip(r *http.Request) (*http.Response, error) { if r.URL.Path == "" { switch r.Method { case http.MethodGet: t.getRequests++ case http.MethodPost: t.postRequests++ } return &http.Response{StatusCode: 200}, nil } return createTestNoRevocationTransport().RoundTrip(r) } func TestInternalClient(t *testing.T) { config, err := ParseDSN(dsn) assertNilF(t, err, "failed to parse dsn") transport := DummyTransport{} config.Transporter = &transport driver := SnowflakeDriver{} db, err := driver.OpenWithConfig(context.Background(), *config) assertNilF(t, err, "failed to open with config") internalClient := (db.(*snowflakeConn)).internal resp, err := internalClient.Get(context.Background(), &url.URL{}, make(map[string]string), 0) assertNilF(t, err, "GET request should succeed") assertEqualF(t, resp.StatusCode, 200, "GET response status code should be 200") assertEqualF(t, transport.getRequests, 1, "Expected exactly one GET request") resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, defaultTimeProvider) assertNilF(t, err, "POST request should succeed") assertEqualF(t, 
resp.StatusCode, 200, "POST response status code should be 200") assertEqualF(t, transport.postRequests, 1, "Expected exactly one POST request") db.Close() } ================================================ FILE: cmd/arrow/.gitignore ================================================ arrow_batches transform_batches_to_rows/transform_batches_to_rows ================================================ FILE: cmd/arrow/Makefile ================================================ SUBDIRS := batches transform_batches_to_rows TARGETS := all install run lint fmt $(TARGETS): subdirs subdirs: $(SUBDIRS) $(SUBDIRS): @$(MAKE) -C $@ $(filter $(TARGETS),$(MAKECMDGOALS)) .PHONY: subdirs $(TARGETS) $(SUBDIRS) ================================================ FILE: cmd/arrow/transform_batches_to_rows/Makefile ================================================ include ../../../gosnowflake.mak CMD_TARGET=transform_batches_to_rows ## Install install: cinstall ## Run run: crun ## Lint lint: clint ## Format source codes fmt: cfmt .PHONY: install run lint fmt ================================================ FILE: cmd/arrow/transform_batches_to_rows/transform_batches_to_rows.go ================================================ package main import ( "context" "database/sql" "database/sql/driver" "errors" "flag" "io" "log" sf "github.com/snowflakedb/gosnowflake/v2" "github.com/snowflakedb/gosnowflake/v2/arrowbatches" ) func main() { if !flag.Parsed() { flag.Parse() } cfg, err := sf.GetConfigFromEnv([]*sf.ConfigParam{ {Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true}, {Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true}, {Name: "Password", EnvName: "SNOWFLAKE_TEST_PASSWORD", FailOnMissing: true}, {Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false}, {Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false}, {Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false}, }) if err != nil { log.Fatalf("failed to create Config, 
err: %v", err) } connector := sf.NewConnector(sf.SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) defer db.Close() conn, err := db.Conn(context.Background()) if err != nil { log.Fatalf("cannot create a connection. %v", err) } defer conn.Close() _, err = conn.ExecContext(context.Background(), "ALTER SESSION SET GO_QUERY_RESULT_FORMAT = json") if err != nil { log.Fatalf("cannot force JSON as result format. %v", err) } var rows driver.Rows err = conn.Raw(func(x any) error { rows, err = x.(driver.QueryerContext).QueryContext(arrowbatches.WithArrowBatches(context.Background()), "SELECT 1, 'hello' UNION SELECT 2, 'hi' UNION SELECT 3, 'howdy'", nil) return err }) if err != nil { log.Fatalf("cannot run a query. %v", err) } defer rows.Close() _, err = arrowbatches.GetArrowBatches(rows.(sf.SnowflakeRows)) var se *sf.SnowflakeError if !errors.As(err, &se) || se.Number != sf.ErrNonArrowResponseInArrowBatches { log.Fatalf("expected to fail while retrieving arrow batches") } res := make([]driver.Value, 2) for { err = rows.Next(res) if err == io.EOF { break } println(res[0].(string), res[1].(string)) } } ================================================ FILE: cmd/logger/Makefile ================================================ include ../../gosnowflake.mak CMD_TARGET=logger ## Install install: cinstall ## Run run: crun ## Lint lint: clint ## Format source codes fmt: cfmt .PHONY: install run lint fmt ================================================ FILE: cmd/logger/logger.go ================================================ package main import ( "bytes" sf "github.com/snowflakedb/gosnowflake/v2" "log" "strings" ) func main() { buf := &bytes.Buffer{} buf2 := &bytes.Buffer{} var mylog = sf.GetLogger() mylog.SetOutput(buf) mylog.Info("Hello I am default") mylog.Info("Hello II amm default") mylog.Debug("Default I am debug NOT SHOWN") _ = mylog.SetLogLevel("debug") mylog.Debug("Default II amm debug TO SHOW") var testlog = sf.CreateDefaultLogger() _ = testlog.SetLogLevel("debug") 
testlog.SetOutput(buf) testlog.SetOutput(buf2) sf.SetLogger(testlog) var mylog2 = sf.GetLogger() mylog2.Debug("test debug log is shown") _ = mylog2.SetLogLevel("info") mylog2.Debug("test debug log is not shownII") log.Print("Expect all true values:") // verify logger switch var strbuf = buf.String() log.Printf("%t:%t:%t:%t", strings.Contains(strbuf, "I am default"), strings.Contains(strbuf, "II amm default"), !strings.Contains(strbuf, "test debug log is shown"), strings.Contains(buf2.String(), "test debug log is shown")) // verify log level switch log.Printf("%t:%t:%t:%t", !strings.Contains(strbuf, "Default I am debug NOT SHOWN"), strings.Contains(strbuf, "Default II amm debug TO SHOW"), strings.Contains(buf2.String(), "test debug log is shown"), !strings.Contains(buf2.String(), "test debug log is not shownII")) } ================================================ FILE: cmd/mfa/Makefile ================================================ include ../../gosnowflake.mak CMD_TARGET=mfa ## Install install: cinstall ## Run run: crun ## Lint lint: clint ## Format source codes fmt: cfmt .PHONY: install run lint fmt ================================================ FILE: cmd/mfa/mfa.go ================================================ package main import ( "database/sql" "flag" "fmt" "log" sf "github.com/snowflakedb/gosnowflake/v2" ) func main() { if !flag.Parsed() { flag.Parse() } cfg, err := sf.GetConfigFromEnv([]*sf.ConfigParam{ {Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true}, {Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true}, {Name: "Password", EnvName: "SNOWFLAKE_TEST_PASSWORD", FailOnMissing: true}, {Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false}, {Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false}, {Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false}, }) if err != nil { log.Fatalf("failed to create Config, err: %v", err) } cfg.Authenticator = 
sf.AuthTypeUsernamePasswordMFA
	dsn, err := sf.DSN(cfg)
	if err != nil {
		log.Fatalf("failed to create DSN from Config. err: %v", err)
	}
	// Authentication (including the MFA prompt) starts with the call to Open.
	db, err := sql.Open("snowflake", dsn)
	if err != nil {
		log.Fatalf("failed to connect. err: %v", err)
	}
	defer db.Close()
	query := "SELECT 1"
	rows, err := db.Query(query)
	if err != nil {
		log.Fatalf("failed to run a query. %v, err: %v", query, err)
	}
	defer rows.Close()
	var v int
	for rows.Next() {
		err := rows.Scan(&v)
		if err != nil {
			log.Fatalf("failed to get result. err: %v", err)
		}
		if v != 1 {
			log.Fatalf("failed to get 1. got: %v", v)
		}
		fmt.Printf("Congrats! You have successfully run %v with Snowflake DB!", query)
	}
}
================================================
FILE: cmd/programmatic_access_token/.gitignore
================================================
pat
================================================
FILE: cmd/programmatic_access_token/Makefile
================================================
include ../../gosnowflake.mak
CMD_TARGET=pat
## Install
install: cinstall
## Run
run: crun
## Lint
lint: clint
## Format source codes
fmt: cfmt
.PHONY: install run lint fmt
================================================
FILE: cmd/programmatic_access_token/pat.go
================================================
// you have to configure PAT on your user
package main

import (
	"database/sql"
	"flag"
	"fmt"
	sf "github.com/snowflakedb/gosnowflake/v2"
	"log"
)

// main connects with a Programmatic Access Token taken from the environment
// and runs a sanity query. (Continues on the next archive line.)
func main() {
	if !flag.Parsed() {
		flag.Parse()
	}
	cfg, err := sf.GetConfigFromEnv([]*sf.ConfigParam{
		{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true},
		{Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true},
		{Name: "Token", EnvName: "SNOWFLAKE_TEST_PAT", FailOnMissing: true},
		{Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false},
		{Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false},
		{Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false},
}) cfg.Authenticator = sf.AuthTypePat if err != nil { log.Fatalf("cannot build config. %v", err) } connector := sf.NewConnector(sf.SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) defer db.Close() query := "SELECT 1" rows, err := db.Query(query) if err != nil { log.Fatalf("failed to run a query. %v, err: %v", query, err) } defer rows.Close() var v int if !rows.Next() { log.Fatalf("no rows returned") } if err = rows.Scan(&v); err != nil { log.Fatalf("failed to scan rows. %v", err) } if v != 1 { log.Fatalf("unexpected result, expected 1, got %v", v) } fmt.Printf("Congrats! You have successfully run %v with Snowflake DB!\n", query) } ================================================ FILE: cmd/tomlfileconnection/.gitignore ================================================ tomlfileconnection.go ================================================ FILE: cmd/tomlfileconnection/Makefile ================================================ include ../../gosnowflake.mak CMD_TARGET=tomlfileconnection ## Install install: cinstall ## Run run: crun ## Lint lint: clint ## Format source codes fmt: cfmt .PHONY: install run lint fmt ================================================ FILE: cmd/variant/Makefile ================================================ include ../../gosnowflake.mak CMD_TARGET=variant ## Install install: cinstall ## Run run: crun ## Lint lint: clint ## Format source codes fmt: cfmt .PHONY: install run lint fmt ================================================ FILE: cmd/variant/insertvariantobject.go ================================================ package main import ( "context" "database/sql" "encoding/json" "flag" "fmt" "log" "strconv" "time" sf "github.com/snowflakedb/gosnowflake/v2" ) func main() { if !flag.Parsed() { flag.Parse() } cfg, err := sf.GetConfigFromEnv([]*sf.ConfigParam{ {Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true}, {Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true}, {Name: "Password", EnvName: 
"SNOWFLAKE_TEST_PASSWORD", FailOnMissing: true}, {Name: "Warehouse", EnvName: "SNOWFLAKE_TEST_WAREHOUSE", FailOnMissing: true}, {Name: "Database", EnvName: "SNOWFLAKE_TEST_DATABASE", FailOnMissing: true}, {Name: "Schema", EnvName: "SNOWFLAKE_TEST_SCHEMA", FailOnMissing: true}, {Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false}, {Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false}, {Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false}, }) if err != nil { log.Fatalf("failed to create Config, err: %v", err) } dsn, err := sf.DSN(cfg) if err != nil { log.Fatalf("failed to create DSN from Config: %v, err: %v", cfg, err) } db, err := sql.Open("snowflake", dsn) if err != nil { log.Fatalf("failed to connect. %v, err: %v", dsn, err) } defer db.Close() ctx := context.Background() conn, err := db.Conn(ctx) if err != nil { log.Fatalf("Failed to acquire connection. err: %v", err) } defer conn.Close() tablename := "insert_variant_object_" + strconv.FormatInt(time.Now().UnixNano(), 10) param := map[string]string{"key": "value"} jsonStr, err := json.Marshal(param) if err != nil { log.Fatalf("failed to marshal json. err: %v", err) } createTableQuery := "CREATE TABLE " + tablename + " (c1 VARIANT, c2 OBJECT)" // https://docs.snowflake.com/en/sql-reference/functions/parse_json // can do with TO_VARIANT(PARSE_JSON(..)) as well, but PARSE_JSON already produces VARIANT insertQuery := "INSERT INTO " + tablename + " (c1, c2) SELECT PARSE_JSON(?), TO_OBJECT(PARSE_JSON(?))" // https://docs.snowflake.com/en/sql-reference/data-types-semistructured#object insertOnlyObject := "INSERT INTO " + tablename + " (c2) SELECT OBJECT_CONSTRUCT('name', 'Jones'::VARIANT, 'age', 42::VARIANT)" selectQuery := "SELECT c1, c2 FROM " + tablename dropQuery := "DROP TABLE " + tablename fmt.Printf("Creating table: %v\n", createTableQuery) _, err = conn.ExecContext(ctx, createTableQuery) if err != nil { log.Fatalf("failed to run the query. 
%v, err: %v", createTableQuery, err) } defer func() { fmt.Printf("Dropping the table: %v\n", dropQuery) _, err = conn.ExecContext(ctx, dropQuery) if err != nil { log.Fatalf("failed to run the query. %v, err: %v", dropQuery, err) } }() fmt.Printf("Inserting VARIANT and OBJECT data into table: %v\n", insertQuery) _, err = conn.ExecContext(ctx, insertQuery, string(jsonStr), string(jsonStr), ) if err != nil { log.Fatalf("failed to run the query. %v, err: %v", insertQuery, err) } fmt.Printf("Now for another approach: %v\n", insertOnlyObject) _, err = conn.ExecContext(ctx, insertOnlyObject) if err != nil { log.Fatalf("failed to run the query. %v, err: %v", insertOnlyObject, err) } fmt.Printf("Querying the table into which we just inserted the data: %v\n", selectQuery) rows, err := conn.QueryContext(ctx, selectQuery) if err != nil { log.Fatalf("failed to run the query. %v, err: %v", selectQuery, err) } defer rows.Close() var c1, c2 any for rows.Next() { err := rows.Scan(&c1, &c2) if err != nil { log.Fatalf("failed to get result. 
err: %v", err) } fmt.Printf("%v (type: %T), %v (type: %T)\n", c1, c1, c2, c2) } if rows.Err() != nil { fmt.Printf("ERROR: %v\n", rows.Err()) return } } ================================================ FILE: codecov.yml ================================================ parsers: go: partials_as_hits: true ignore: - "cmd/" ================================================ FILE: connection.go ================================================ package gosnowflake import ( "context" "database/sql" "database/sql/driver" "encoding/json" "fmt" "net/http" "net/url" "regexp" "strconv" "strings" "sync/atomic" "time" sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config" "github.com/snowflakedb/gosnowflake/v2/internal/errors" ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow" "go.opentelemetry.io/otel/propagation" ) const ( httpHeaderContentType = "Content-Type" httpHeaderAccept = "accept" httpHeaderUserAgent = "User-Agent" httpHeaderServiceName = "X-Snowflake-Service" httpHeaderContentLength = "Content-Length" httpHeaderHost = "Host" httpHeaderValueOctetStream = "application/octet-stream" httpHeaderContentEncoding = "Content-Encoding" httpClientAppID = "CLIENT_APP_ID" httpClientAppVersion = "CLIENT_APP_VERSION" ) const ( statementTypeIDSelect = int64(0x1000) statementTypeIDDml = int64(0x3000) statementTypeIDMultiTableInsert = statementTypeIDDml + int64(0x500) statementTypeIDMultistatement = int64(0xA000) ) const ( sessionClientSessionKeepAlive = "client_session_keep_alive" sessionClientSessionKeepAliveHeartbeatFrequency = "client_session_keep_alive_heartbeat_frequency" sessionClientValidateDefaultParameters = "CLIENT_VALIDATE_DEFAULT_PARAMETERS" sessionArrayBindStageThreshold = "client_stage_array_binding_threshold" serviceName = "service_name" ) type resultType string const ( snowflakeResultType ContextKey = "snowflakeResultType" execResultType resultType = "exec" queryResultType resultType = "query" ) type execKey string const ( executionType execKey = 
"executionType" executionTypeStatement string = "statement" ) // snowflakeConn manages its own context. // External cancellation should not be supported because the connection // may be reused after the original query/request has completed. type snowflakeConn struct { ctx context.Context cfg *Config rest *snowflakeRestful sequenceCounter uint64 telemetry *snowflakeTelemetry internal InternalClient queryContextCache queryContextCache currentTimeProvider currentTimeProvider syncParams syncParams idToken string mfaToken string } var ( queryIDPattern = `[\w\-_]+` queryIDRegexp = regexp.MustCompile(queryIDPattern) ) func (sc *snowflakeConn) exec( ctx context.Context, query string, noResult bool, isInternal bool, describeOnly bool, bindings []driver.NamedValue) ( *execResponse, error) { if sc.cfg.LogQueryText || isLogQueryTextEnabled(ctx) { if len(bindings) > 0 && (sc.cfg.LogQueryParameters || isLogQueryParametersEnabled(ctx)) { logger.WithContext(ctx).Infof("Executing query: %v with bindings: %v", query, bindings) } else { logger.WithContext(ctx).Infof("Executing query: %v", query) } } else { logger.WithContext(ctx).Infof("Executing query") } var err error counter := atomic.AddUint64(&sc.sequenceCounter, 1) // query sequence counter _, _, sessionID := safeGetTokens(sc.rest) ctx = context.WithValue(ctx, SFSessionIDKey, sessionID) queryContext, err := buildQueryContext(&sc.queryContextCache) if err != nil { logger.WithContext(ctx).Errorf("error while building query context: %v", err) } req := execRequest{ SQLText: query, AsyncExec: noResult, Parameters: map[string]any{}, IsInternal: isInternal, DescribeOnly: describeOnly, SequenceID: counter, QueryContext: queryContext, } if key := ctx.Value(multiStatementCount); key != nil { req.Parameters[string(multiStatementCount)] = key } if tag := ctx.Value(queryTag); tag != nil { req.Parameters[string(queryTag)] = tag } logger.WithContext(ctx).Debugf("parameters: %v", req.Parameters) // handle bindings, if required requestID := 
getOrGenerateRequestIDFromContext(ctx) if len(bindings) > 0 { if err = sc.processBindings(ctx, bindings, describeOnly, requestID, &req); err != nil { return nil, err } } logger.WithContext(ctx).Debugf("bindings: %v", req.Bindings) // populate headers headers := getHeaders() if isFileTransfer(query) { headers[httpHeaderAccept] = headerContentTypeApplicationJSON } // propagate traceID and spanID via traceparent header. this is a no-op if invalid IDs propagator := propagation.TraceContext{} propagator.Inject(ctx, propagation.MapCarrier(headers)) if sn, ok := sc.syncParams.get(serviceName); ok { headers[httpHeaderServiceName] = *sn } jsonBody, err := json.Marshal(req) if err != nil { return nil, err } data, err := sc.rest.FuncPostQuery(ctx, sc.rest, &url.Values{}, headers, jsonBody, sc.rest.RequestTimeout, requestID, sc.cfg) if err != nil { return data, err } code := -1 if data.Code != "" { code, err = strconv.Atoi(data.Code) if err != nil { return data, err } } logger.WithContext(ctx).Debugf("Success: %v, Code: %v", data.Success, code) if !sc.cfg.DisableQueryContextCache && data.Data.QueryContext != nil { queryContext, err := extractQueryContext(data) if err != nil { logger.WithContext(ctx).Errorf("error while decoding query context: %v", err) } else { sc.queryContextCache.add(sc, queryContext.Entries...) 
} }
	if !data.Success {
		err = exceptionTelemetry(populateErrorFields(code, data), sc)
		return nil, err
	}
	// handle PUT/GET commands
	fileTransferChan := make(chan error, 1)
	if isFileTransfer(query) {
		// Run the transfer in a goroutine so cancellation of ctx can interrupt it.
		go func() {
			data, err = sc.processFileTransfer(ctx, data, query, isInternal)
			fileTransferChan <- err
		}()
		select {
		case <-ctx.Done():
			logger.WithContext(ctx).Debugf("File transfer has been cancelled")
			return nil, ctx.Err()
		case err := <-fileTransferChan:
			if err != nil {
				return nil, err
			}
		}
	}
	logger.WithContext(ctx).Debugf("Exec/Query: queryId=%v SUCCESS with total=%v, returned=%v ", data.Data.QueryID, data.Data.Total, data.Data.Returned)
	// Keep the cached session state in sync with what the server reports back.
	if data.Data.FinalDatabaseName != "" {
		sc.cfg.Database = data.Data.FinalDatabaseName
	}
	if data.Data.FinalSchemaName != "" {
		sc.cfg.Schema = data.Data.FinalSchemaName
	}
	if data.Data.FinalWarehouseName != "" {
		sc.cfg.Warehouse = data.Data.FinalWarehouseName
	}
	if data.Data.FinalRoleName != "" {
		sc.cfg.Role = data.Data.FinalRoleName
	}
	sc.populateSessionParameters(data.Data.Parameters)
	return data, err
}

// extractQueryContext decodes the query-context JSON blob returned by the server.
func extractQueryContext(data *execResponse) (queryContext, error) {
	var queryContext queryContext
	err := json.Unmarshal(data.Data.QueryContext, &queryContext)
	return queryContext, err
}

// buildQueryContext converts the cached query-context entries into the
// request payload that is attached to each query.
func buildQueryContext(qcc *queryContextCache) (requestQueryContext, error) {
	rqc := requestQueryContext{}
	if qcc == nil || len(qcc.entries) == 0 {
		logger.Debugf("empty qcc")
		return rqc, nil
	}
	for _, qce := range qcc.entries {
		contextData := contextData{}
		// NOTE(review): this assignment only runs when qce.Context is empty, which
		// makes it a no-op; it looks like the condition was meant to be `!= ""` so
		// that non-empty context data is forwarded — confirm against the server
		// protocol before changing.
		if qce.Context == "" {
			contextData.Base64Data = qce.Context
		}
		rqc.Entries = append(rqc.Entries, requestQueryContextEntry{
			ID:        qce.ID,
			Priority:  qce.Priority,
			Timestamp: qce.Timestamp,
			Context:   contextData,
		})
	}
	return rqc, nil
}

// Begin starts a transaction with a background context.
func (sc *snowflakeConn) Begin() (driver.Tx, error) {
	return sc.BeginTx(context.Background(), driver.TxOptions{})
}

// BeginTx starts a transaction. Read-only transactions and non-default
// isolation levels are rejected. (Continues on the next archive line.)
func (sc *snowflakeConn) BeginTx(
	ctx context.Context, opts driver.TxOptions) (
	driver.Tx, error) {
	logger.WithContext(ctx).Debug("BeginTx")
	if opts.ReadOnly {
		return
nil, exceptionTelemetry(&SnowflakeError{ Number: ErrNoReadOnlyTransaction, SQLState: SQLStateFeatureNotSupported, Message: errors.ErrMsgNoReadOnlyTransaction, }, sc) } if int(opts.Isolation) != int(sql.LevelDefault) { return nil, exceptionTelemetry(&SnowflakeError{ Number: ErrNoDefaultTransactionIsolationLevel, SQLState: SQLStateFeatureNotSupported, Message: errors.ErrMsgNoDefaultTransactionIsolationLevel, }, sc) } if sc.rest == nil { return nil, driver.ErrBadConn } isDesc := isDescribeOnly(ctx) isInternal := isInternal(ctx) if _, err := sc.exec(ctx, "BEGIN", false, /* noResult */ isInternal, isDesc, nil); err != nil { return nil, err } return &snowflakeTx{sc, ctx}, nil } func (sc *snowflakeConn) cleanup() { // must flush log buffer while the process is running. logger.WithContext(sc.ctx).Debug("Snowflake connection closing.") if sc.rest != nil && sc.rest.Client != nil { sc.rest.Client.CloseIdleConnections() } } func (sc *snowflakeConn) Close() (err error) { logger.WithContext(sc.ctx).Info("Closing connection") if err := sc.telemetry.sendBatch(); err != nil { logger.WithContext(sc.ctx).Warnf("error while sending telemetry. 
%v", err) } sc.stopHeartBeat() sc.rest.HeartBeat = nil defer sc.cleanup() if sc.cfg != nil && !sc.cfg.ServerSessionKeepAlive { logger.WithContext(sc.ctx).Debug("Closing session since ServerSessionKeepAlive is false") // we have to replace context with background, otherwise we can use a one that is cancelled or timed out if err = sc.rest.FuncCloseSession(context.Background(), sc.rest, sc.rest.RequestTimeout); err != nil { logger.WithContext(sc.ctx).Errorf("error while closing session: %v", err) } } else { logger.WithContext(sc.ctx).Info("Skipping session close since ServerSessionKeepAlive is true") } return nil } func (sc *snowflakeConn) PrepareContext( ctx context.Context, query string) ( driver.Stmt, error) { logger.WithContext(sc.ctx).Debugf("Prepare Context") if sc.rest == nil { return nil, driver.ErrBadConn } stmt := &snowflakeStmt{ sc: sc, query: query, } return stmt, nil } func (sc *snowflakeConn) ExecContext( ctx context.Context, query string, args []driver.NamedValue) ( driver.Result, error) { if sc.rest == nil { return nil, driver.ErrBadConn } _, _, sessionID := safeGetTokens(sc.rest) ctx = context.WithValue(ctx, SFSessionIDKey, sessionID) logger.WithContext(ctx).Debug("ExecContext:") noResult := isAsyncMode(ctx) isDesc := isDescribeOnly(ctx) isInternal := isInternal(ctx) ctx = setResultType(ctx, execResultType) data, err := sc.exec(ctx, query, noResult, isInternal, isDesc, args) if err != nil { logger.WithContext(ctx).Errorf("error: %v", err) if data != nil { code, e := strconv.Atoi(data.Code) if e != nil { return nil, e } return nil, exceptionTelemetry(&SnowflakeError{ Number: code, SQLState: data.Data.SQLState, Message: err.Error(), QueryID: data.Data.QueryID, }, sc) } return nil, err } // if async exec, return result object right away if noResult { return data.Data.AsyncResult, nil } if isDml(data.Data.StatementTypeID) { // collects all values from the returned row sets updatedRows, err := updateRows(data.Data) if err != nil { return nil, err } 
logger.WithContext(ctx).Debugf("number of updated rows: %#v", updatedRows) return &snowflakeResult{ affectedRows: updatedRows, insertID: -1, queryID: data.Data.QueryID, }, nil // last insert id is not supported by Snowflake } else if isMultiStmt(&data.Data) { return sc.handleMultiExec(ctx, data.Data) } else if isDql(&data.Data) { logger.WithContext(ctx).Debug("This query is DQL") if isStatementContext(ctx) { return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil } return driver.ResultNoRows, nil } logger.WithContext(ctx).Debug("This query is DDL") if isStatementContext(ctx) { return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil } return driver.ResultNoRows, nil } func (sc *snowflakeConn) QueryContext( ctx context.Context, query string, args []driver.NamedValue) ( driver.Rows, error) { qid, err := getResumeQueryID(ctx) if err != nil { return nil, err } if qid == "" { return sc.queryContextInternal(ctx, query, args) } // check the query status to find out if there is a result to fetch _, err = sc.checkQueryStatus(ctx, qid) snowflakeErr, isSnowflakeError := err.(*SnowflakeError) if err == nil || (isSnowflakeError && snowflakeErr.Number == ErrQueryIsRunning) { // the query is running. Rows object will be returned from here. 
return sc.buildRowsForRunningQuery(ctx, qid) } return nil, err } func (sc *snowflakeConn) queryContextInternal( ctx context.Context, query string, args []driver.NamedValue) ( driver.Rows, error) { if sc.rest == nil { return nil, driver.ErrBadConn } _, _, sessionID := safeGetTokens(sc.rest) ctx = context.WithValue(setResultType(ctx, queryResultType), SFSessionIDKey, sessionID) logger.WithContext(ctx).Debug("QueryContextInternal") noResult := isAsyncMode(ctx) isDesc := isDescribeOnly(ctx) isInternal := isInternal(ctx) data, err := sc.exec(ctx, query, noResult, isInternal, isDesc, args) if err != nil { logger.WithContext(ctx).Errorf("error: %v", err) if data != nil { code, e := strconv.Atoi(data.Code) if e != nil { return nil, e } return nil, exceptionTelemetry(&SnowflakeError{ Number: code, SQLState: data.Data.SQLState, Message: err.Error(), QueryID: data.Data.QueryID, }, sc) } return nil, err } // if async query, return row object right away if noResult { return data.Data.AsyncRows, nil } rows := new(snowflakeRows) rows.sc = sc rows.queryID = data.Data.QueryID rows.ctx = ctx if isMultiStmt(&data.Data) { // handleMultiQuery is responsible to fill rows with childResults if err = sc.handleMultiQuery(ctx, data.Data, rows); err != nil { return nil, err } } else { rows.addDownloader(populateChunkDownloader(ctx, sc, data.Data)) } err = rows.ChunkDownloader.start() return rows, err } func (sc *snowflakeConn) Prepare(query string) (driver.Stmt, error) { return sc.PrepareContext(context.Background(), query) } func (sc *snowflakeConn) Exec( query string, args []driver.Value) ( driver.Result, error) { return sc.ExecContext(context.Background(), query, toNamedValues(args)) } func (sc *snowflakeConn) Query( query string, args []driver.Value) ( driver.Rows, error) { return sc.QueryContext(context.Background(), query, toNamedValues(args)) } func (sc *snowflakeConn) Ping(ctx context.Context) error { logger.WithContext(ctx).Debug("Ping") if sc.rest == nil { return driver.ErrBadConn } 
noResult := isAsyncMode(ctx)
	isDesc := isDescribeOnly(ctx)
	isInternal := isInternal(ctx)
	ctx = setResultType(ctx, execResultType)
	// Ping is implemented as a trivial round-trip query.
	_, err := sc.exec(ctx, "SELECT 1", noResult, isInternal, isDesc, []driver.NamedValue{})
	return err
}

// CheckNamedValue determines which types are handled by this driver aside from
// the instances captured by driver.Value
func (sc *snowflakeConn) CheckNamedValue(nv *driver.NamedValue) error {
	if supportedNullBind(nv) || supportedDecfloatBind(nv) || supportedArrayBind(nv) ||
		supportedStructuredObjectWriterBind(nv) || supportedStructuredArrayBind(nv) || supportedStructuredMapBind(nv) {
		return nil
	}
	return driver.ErrSkip
}

// GetQueryStatus fetches the server-side status of a query by its ID and maps
// it into a SnowflakeQueryStatus.
func (sc *snowflakeConn) GetQueryStatus(
	ctx context.Context, queryID string) (
	*SnowflakeQueryStatus, error) {
	queryRet, err := sc.checkQueryStatus(ctx, queryID)
	if err != nil {
		return nil, err
	}
	return &SnowflakeQueryStatus{
		queryRet.SQLText,
		queryRet.StartTime,
		queryRet.EndTime,
		queryRet.ErrorCode,
		queryRet.ErrorMessage,
		queryRet.Stats.ScanBytes,
		queryRet.Stats.ProducedRows,
	}, nil
}

// AddTelemetryData queues a telemetry log entry stamped with eventDate.
func (sc *snowflakeConn) AddTelemetryData(_ context.Context, eventDate time.Time, data map[string]string) error {
	td := &telemetryData{
		Timestamp: eventDate.UnixMilli(),
		Message:   data,
	}
	return sc.telemetry.addLog(td)
}

// QueryArrowStream executes a query and returns an ArrowStreamLoader for
// streaming raw Arrow IPC record batches from the result.
func (sc *snowflakeConn) QueryArrowStream(ctx context.Context, query string, bindings ...driver.NamedValue) (ArrowStreamLoader, error) { ctx = ia.EnableArrowBatches(context.WithValue(ctx, asyncMode, false)) ctx = setResultType(ctx, queryResultType) isDesc := isDescribeOnly(ctx) isInternal := isInternal(ctx) data, err := sc.exec(ctx, query, false, isInternal, isDesc, bindings) if err != nil { logger.WithContext(ctx).Errorf("error: %v", err) if data != nil { code, e := strconv.Atoi(data.Code) if e != nil { return nil, e } return nil, exceptionTelemetry(&SnowflakeError{ Number: code, SQLState: data.Data.SQLState, Message: err.Error(), QueryID: data.Data.QueryID, }, sc) } return nil, err } var resultIDs []string if len(data.Data.ResultIDs) > 0 { resultIDs = strings.Split(data.Data.ResultIDs, ",") } scd := &snowflakeArrowStreamChunkDownloader{ sc: sc, ChunkMetas: data.Data.Chunks, Total: data.Data.Total, Qrmk: data.Data.Qrmk, ChunkHeader: data.Data.ChunkHeaders, FuncGet: getChunk, RowSet: rowSetType{ RowType: data.Data.RowType, JSON: data.Data.RowSet, RowSetBase64: data.Data.RowSetBase64, }, resultIDs: resultIDs, } if scd.hasNextResultSet() { if err = scd.NextResultSet(ctx); err != nil { return nil, err } } return scd, nil } // buildSnowflakeConn creates a new snowflakeConn. // The provided context is used only for establishing the initial connection. 
func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, error) {
	sc := &snowflakeConn{
		sequenceCounter:     0,
		ctx:                 ctx,
		cfg:                 &config,
		currentTimeProvider: defaultTimeProvider,
	}
	initPlatformDetection()
	err := initEasyLogging(config.ClientConfigFile)
	if err != nil {
		return nil, err
	}
	// Log the connection parameters; only whether a password exists, never
	// the password itself.
	logger.Debugf("Building snowflakeConn: %v", fmt.Sprintf("host: %v, account: %v, user: %v, password existed: %v, role: %v, database: %v, schema: %v, warehouse: %v, %v",
		config.Host, config.Account, config.User, config.Password != "", config.Role, config.Database, config.Schema, config.Warehouse, sfconfig.DescribeProxy(&config)))
	telemetry := &snowflakeTelemetry{}
	transportFactory := newTransportFactory(&config, telemetry)
	st, err := transportFactory.createTransport(defaultTransportConfigs.forTransportType(transportTypeSnowflake))
	if err != nil {
		return nil, err
	}
	// Prefer a caller-supplied token accessor; otherwise use the default
	// simple in-memory accessor.
	var tokenAccessor TokenAccessor
	if sc.cfg.TokenAccessor != nil {
		tokenAccessor = sc.cfg.TokenAccessor
	} else {
		tokenAccessor = getSimpleTokenAccessor()
	}
	// authenticate
	sc.rest = &snowflakeRestful{
		Host:     sc.cfg.Host,
		Port:     sc.cfg.Port,
		Protocol: sc.cfg.Protocol,
		Client: &http.Client{
			// request timeout including reading response body
			Timeout:   sc.cfg.ClientTimeout,
			Transport: st,
		},
		JWTClient: &http.Client{
			Timeout:   sc.cfg.JWTClientTimeout,
			Transport: st,
		},
		TokenAccessor:       tokenAccessor,
		LoginTimeout:        sc.cfg.LoginTimeout,
		RequestTimeout:      sc.cfg.RequestTimeout,
		MaxRetryCount:       sc.cfg.MaxRetryCount,
		FuncPost:            postRestful,
		FuncGet:             getRestful,
		FuncAuthPost:        postAuthRestful,
		FuncPostQuery:       postRestfulQuery,
		FuncPostQueryHelper: postRestfulQueryHelper,
		FuncRenewSession:    renewRestfulSession,
		FuncPostAuth:        postAuth,
		FuncCloseSession:    closeSession,
		FuncCancelQuery:     cancelQuery,
		FuncPostAuthSAML:    postAuthSAML,
		FuncPostAuthOKTA:    postAuthOKTA,
		FuncGetSSO:          getSSO,
	}
	telemetry.sr = sc.rest
	sc.telemetry = telemetry
	sc.syncParams = newSyncParams(sc.cfg.Params)
	return sc, nil
}
================================================
FILE: connection_configuration_test.go
================================================
package gosnowflake

import (
	"database/sql"
	toml "github.com/BurntSushi/toml"
	"os"
	"strconv"
	"testing"
)

// TODO move this test to config package when we have wiremock support in an internal package
// TestTomlConnection writes the wiremock connection parameters into
// test_data/connections.toml and verifies that sql.Open("snowflake",
// "autoConfig") picks them up.
func TestTomlConnection(t *testing.T) {
	os.Setenv("SNOWFLAKE_HOME", "./test_data/")                       // TODO replace with snowflakeHome const
	os.Setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "toml-connection") // TODO replace with snowflakeConnectionName const
	defer os.Unsetenv("SNOWFLAKE_HOME")                               // TODO replace with snowflakeHome const
	defer os.Unsetenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME")            // TODO replace with snowflakeHome const
	wiremock.registerMappings(t,
		wiremockMapping{filePath: "auth/password/successful_flow.json"},
		wiremockMapping{filePath: "select1.json", params: map[string]string{
			"%AUTHORIZATION_HEADER%": "session token",
		}},
	)
	type Connection struct {
		Account  string `toml:"account"`
		User     string `toml:"user"`
		Password string `toml:"password"`
		Host     string `toml:"host"`
		Port     string `toml:"port"`
		Protocol string `toml:"protocol"`
	}
	type TomlStruct struct {
		Connection Connection `toml:"toml-connection"`
	}
	cfg := wiremock.connectionConfig()
	connection := &TomlStruct{
		Connection: Connection{
			Account:  cfg.Account,
			User:     cfg.User,
			Password: cfg.Password,
			Host:     cfg.Host,
			Port:     strconv.Itoa(cfg.Port),
			Protocol: cfg.Protocol,
		},
	}
	f, err := os.OpenFile("./test_data/connections.toml", os.O_APPEND|os.O_WRONLY, 0600)
	assertNilF(t, err, "Failed to create connections.toml file")
	defer f.Close()
	encoder := toml.NewEncoder(f)
	err = encoder.Encode(connection)
	assertNilF(t, err, "Failed to parse the config to toml structure")
	if !isWindows {
		err = os.Chmod("./test_data/connections.toml", 0600)
		assertNilF(t, err, "The error occurred because you cannot change the file permission")
	}
	db, err := sql.Open("snowflake", "autoConfig")
	assertNilF(t, err, "The error occurred because the db cannot be established")
	runSmokeQuery(t, db)
}
================================================
FILE: connection_test.go
================================================
package gosnowflake

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"errors"
	"fmt"
	errors2 "github.com/snowflakedb/gosnowflake/v2/internal/errors"
	"math/big"
	"strconv"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/sdk/trace"
)

const (
	serviceNameStub   = "SV"
	serviceNameAppend = "a"
)

// TestInvalidConnection verifies that Close is idempotent and that Exec,
// Query and Begin all fail on a closed DB.
func TestInvalidConnection(t *testing.T) {
	db := openDB(t)
	if err := db.Close(); err != nil {
		t.Error("should not cause error in Close")
	}
	if err := db.Close(); err != nil {
		t.Error("should not cause error in the second call of Close")
	}
	if _, err := db.ExecContext(context.Background(), "CREATE TABLE OR REPLACE test0(c1 int)"); err == nil {
		t.Error("should fail to run Exec")
	}
	if _, err := db.QueryContext(context.Background(), "SELECT CURRENT_TIMESTAMP()"); err == nil {
		t.Error("should fail to run Query")
	}
	if _, err := db.BeginTx(context.Background(), nil); err == nil {
		t.Error("should fail to run Begin")
	}
}

// postQueryMock generates a response based on the X-Snowflake-Service header,
// to generate a response with the SERVICE_NAME field appending a character at
// the end of the header.
// This way it could test both the send and receive logic
func postQueryMock(_ context.Context, _ *snowflakeRestful, _ *url.Values, headers map[string]string, _ []byte, _ time.Duration, _ UUID, _ *Config) (*execResponse, error) {
	var serviceName string
	if serviceHeader, ok := headers[httpHeaderServiceName]; ok {
		serviceName = serviceHeader + serviceNameAppend
	} else {
		serviceName = serviceNameStub
	}
	dd := &execResponseData{
		Parameters: []nameValueParameter{{"SERVICE_NAME", serviceName}},
	}
	return &execResponse{
		Data:    *dd,
		Message: "",
		Code:    "0",
		Success: true,
	}, nil
}

// TestExecWithEmptyRequestID checks that exec generates a non-empty request
// ID when the context carries the nil UUID.
func TestExecWithEmptyRequestID(t *testing.T) {
	ctx := WithRequestID(context.Background(), nilUUID)
	postQueryMock := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, requestID UUID, _ *Config) (*execResponse, error) {
		// ensure the same requestID from context is used
		if len(requestID) == 0 {
			t.Fatal("requestID is empty")
		}
		dd := &execResponseData{}
		return &execResponse{
			Data:    *dd,
			Message: "",
			Code:    "0",
			Success: true,
		}, nil
	}
	sr := &snowflakeRestful{
		FuncPostQuery: postQueryMock,
	}
	sc := &snowflakeConn{
		cfg:  &Config{},
		rest: sr,
	}
	if _, err := sc.exec(ctx, "", false /* noResult */, false, /* isInternal */ false /* describeOnly */, nil); err != nil {
		t.Fatalf("err: %v", err)
	}
}

// TestGetQueryResultUsesTokenFromTokenAccessor checks that the Authorization
// header carries the token stored in the TokenAccessor.
func TestGetQueryResultUsesTokenFromTokenAccessor(t *testing.T) {
	ta := getSimpleTokenAccessor()
	token := "snowflake-test-token"
	ta.SetTokens(token, "", 1)
	funcGetMock := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ time.Duration) (*http.Response, error) {
		if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, token) {
			t.Fatalf("header authorization key is not correct: %v", headers[headerAuthorizationKey])
		}
		dd := &execResponseData{}
		er := &execResponse{
			Data:    *dd,
			Message: "",
			Code:    "0",
			Success: true,
		}
		ba, err := json.Marshal(er)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		return &http.Response{
			StatusCode: http.StatusOK,
			Body:       &fakeResponseBody{body: ba},
		}, nil
	}
	sr := &snowflakeRestful{
		FuncGet:       funcGetMock,
		TokenAccessor: ta,
	}
	sc := &snowflakeConn{
		cfg:                 &Config{},
		rest:                sr,
		currentTimeProvider: defaultTimeProvider,
	}
	if _, err := sc.getQueryResultResp(context.Background(), ""); err != nil {
		t.Fatalf("err: %v", err)
	}
}

// TestGetQueryResultTokenExpiry checks that a session-expired response code
// triggers session renewal and that the new tokens end up in the accessor.
func TestGetQueryResultTokenExpiry(t *testing.T) {
	ta := getSimpleTokenAccessor()
	token := "snowflake-test-token"
	ta.SetTokens(token, "", 1)
	funcGetMock := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ time.Duration) (*http.Response, error) {
		respData := execResponseData{}
		er := &execResponse{
			Data:    respData,
			Message: "",
			Code:    sessionExpiredCode,
			Success: true,
		}
		ba, err := json.Marshal(er)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		return &http.Response{
			StatusCode: http.StatusOK,
			Body:       &fakeResponseBody{body: ba},
		}, nil
	}
	expectedToken := "new token"
	expectedMaster := "new master"
	expectedSession := int64(321)
	renewSessionDummy := func(_ context.Context, sr *snowflakeRestful, _ time.Duration) error {
		ta.SetTokens(expectedToken, expectedMaster, expectedSession)
		return nil
	}
	sr := &snowflakeRestful{
		FuncGet:          funcGetMock,
		FuncRenewSession: renewSessionDummy,
		TokenAccessor:    ta,
	}
	sc := &snowflakeConn{
		cfg:                 &Config{},
		rest:                sr,
		currentTimeProvider: defaultTimeProvider,
	}
	_, err := sc.getQueryResultResp(context.Background(), "")
	assertNilF(t, err, fmt.Sprintf("err: %v", err))
	updatedToken, updatedMaster, updatedSession := ta.GetTokens()
	assertEqualF(t, updatedToken, expectedToken)
	assertEqualF(t, updatedMaster, expectedMaster)
	assertEqualF(t, updatedSession, expectedSession)
}

// TestGetQueryResultTokenNotSet is the same scenario as above but with no
// initial token in the accessor.
func TestGetQueryResultTokenNotSet(t *testing.T) {
	ta := getSimpleTokenAccessor()
	funcGetMock := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ time.Duration) (*http.Response, error) {
		respData := execResponseData{}
		er := &execResponse{
			Data:    respData,
			Message: "",
			Code:    sessionExpiredCode,
			Success: true,
		}
		ba, err := json.Marshal(er)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		return &http.Response{
			StatusCode: http.StatusOK,
			Body:       &fakeResponseBody{body: ba},
		}, nil
	}
	expectedToken := "new token"
	expectedMaster := "new master"
	expectedSession := int64(321)
	renewSessionDummy := func(_ context.Context, sr *snowflakeRestful, _ time.Duration) error {
		ta.SetTokens(expectedToken, expectedMaster, expectedSession)
		return nil
	}
	sr := &snowflakeRestful{
		FuncGet:          funcGetMock,
		FuncRenewSession: renewSessionDummy,
		TokenAccessor:    ta,
	}
	sc := &snowflakeConn{
		cfg:                 &Config{},
		rest:                sr,
		currentTimeProvider: defaultTimeProvider,
	}
	_, err := sc.getQueryResultResp(context.Background(), "")
	assertNilF(t, err, fmt.Sprintf("err: %v", err))
	updatedToken, updatedMaster, updatedSession := ta.GetTokens()
	assertEqualF(t, updatedToken, expectedToken)
	assertEqualF(t, updatedMaster, expectedMaster)
	assertEqualF(t, updatedSession, expectedSession)
}

// TestCheckNamedValue exercises CheckNamedValue with nil typed pointers
// (must not panic) and with supported/unsupported value types.
func TestCheckNamedValue(t *testing.T) {
	sc := &snowflakeConn{}
	t.Run("dont panic on nil UUID", func(t *testing.T) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("expected not to panic, but did panic")
			}
		}()
		var nilUUID *UUID
		nv := driver.NamedValue{Value: nilUUID}
		err := sc.CheckNamedValue(&nv) // should not panic and return false
		assertErrIsE(t, err, driver.ErrSkip, "expected not to support binding nil *UUID")
	})
	t.Run("dont panic on nil pointer array", func(t *testing.T) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("expected not to panic, but did panic")
			}
		}()
		var nilArray *[]string
		nv := driver.NamedValue{Value: nilArray}
		err := sc.CheckNamedValue(&nv) // should not panic and return false
		assertErrIsE(t, err, driver.ErrSkip, "expected not to support binding nil []string")
	})
	t.Run("dont panic on nil pointer", func(t *testing.T) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("expected not to panic, but did panic")
			}
		}()
		var nilTime *time.Time
		nv := driver.NamedValue{Value: nilTime}
		err := sc.CheckNamedValue(&nv) //
		// should not panic and return false
		assertErrIsE(t, err, driver.ErrSkip, "expected not to support binding nil *time.Time")
	})
	t.Run("dont panic on nil *big.Float", func(t *testing.T) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("expected not to panic, but did panic")
			}
		}()
		var nilBigFloat *big.Float
		nv := driver.NamedValue{Value: nilBigFloat}
		err := sc.CheckNamedValue(&nv) // should not panic and return false
		assertErrIsE(t, err, driver.ErrSkip, "expected not to support binding nil *big.Float")
	})
	t.Run("Is Valid for big.Float", func(t *testing.T) {
		val := big.NewFloat(123.456)
		nv := driver.NamedValue{Value: val}
		err := sc.CheckNamedValue(&nv)
		assertNilE(t, err, "expected to support binding big.Float")
	})
	t.Run("Is Not Valid for other types", func(t *testing.T) {
		val := 123.456 // float64
		nv := driver.NamedValue{Value: val}
		err := sc.CheckNamedValue(&nv)
		assertErrIsE(t, err, driver.ErrSkip, "expected not to support binding float64")
	})
}

// TestExecWithSpecificRequestID checks that a request ID supplied via the
// context is forwarded to the post-query function unchanged.
func TestExecWithSpecificRequestID(t *testing.T) {
	origRequestID := NewUUID()
	ctx := WithRequestID(context.Background(), origRequestID)
	postQueryMock := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, requestID UUID, _ *Config) (*execResponse, error) {
		// ensure the same requestID from context is used
		if requestID != origRequestID {
			t.Fatal("requestID doesn't match")
		}
		dd := &execResponseData{}
		return &execResponse{
			Data:    *dd,
			Message: "",
			Code:    "0",
			Success: true,
		}, nil
	}
	sr := &snowflakeRestful{
		FuncPostQuery: postQueryMock,
	}
	sc := &snowflakeConn{
		cfg:  &Config{},
		rest: sr,
	}
	if _, err := sc.exec(ctx, "", false /* noResult */, false, /* isInternal */ false /* describeOnly */, nil); err != nil {
		t.Fatalf("err: %v", err)
	}
}

// TestExecContextPropagationIntegrationTest checks that an OpenTelemetry span
// on the context is propagated as a W3C traceparent request header.
func TestExecContextPropagationIntegrationTest(t *testing.T) {
	originalTracerProvider := otel.GetTracerProvider()
	tp := trace.NewTracerProvider()
	otel.SetTracerProvider(tp)
	t.Cleanup(func() {
		otel.SetTracerProvider(originalTracerProvider)
	})
	tracer := otel.Tracer("TestExecContextPropagationTracer")
	ctx, span := tracer.Start(context.Background(), "test-span")
	defer span.End()
	traceID := span.SpanContext().TraceID().String()
	spanID := span.SpanContext().SpanID().String()
	// expected header values
	expectedTraceparent := fmt.Sprintf("00-%s-%s-01", traceID, spanID)
	postQueryMock := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, headers map[string]string, _ []byte, _ time.Duration, _ UUID, _ *Config) (*execResponse, error) {
		// ensure the traceID and spanID from the ctx passed in has been injected into the headers
		// in W3 Trace Context format
		assertEqualE(t, headers["traceparent"], expectedTraceparent)
		dd := &execResponseData{}
		return &execResponse{
			Data:    *dd,
			Message: "",
			Code:    "0",
			Success: true,
		}, nil
	}
	sr := &snowflakeRestful{
		FuncPostQuery: postQueryMock,
	}
	sc := &snowflakeConn{
		cfg:  &Config{},
		rest: sr,
	}
	_, err := sc.exec(ctx, "", false /* noResult */, false, /* isInternal */ false /* describeOnly */, nil)
	assertNilF(t, err)
}

// TestServiceName tests two things:
// 1. request header contains X-Snowflake-Service if the cfg parameters
// contains SERVICE_NAME
// 2. SERVICE_NAME is updated by response payload
// Uses interactive postQueryMock that generates a response based on header
func TestServiceName(t *testing.T) {
	sr := &snowflakeRestful{
		FuncPostQuery: postQueryMock,
	}
	sc := &snowflakeConn{
		cfg:  &Config{},
		rest: sr,
	}
	expectServiceName := serviceNameStub
	for range 5 {
		_, err := sc.exec(context.Background(), "", false, /* noResult */ false /* isInternal */, false /* describeOnly */, nil)
		assertNilF(t, err)
		if actualServiceName, ok := sc.syncParams.get(serviceName); ok {
			if *actualServiceName != expectServiceName {
				t.Errorf("service name mis-match. expected %v, actual %v", expectServiceName, actualServiceName)
			}
		} else {
			t.Error("No service name in the response")
		}
		expectServiceName += serviceNameAppend
	}
}

var closedSessionCount = 0

var testTelemetry = &snowflakeTelemetry{
	mutex: &sync.Mutex{},
}

// closeSessionMock counts close attempts and always reports the session as
// already gone.
func closeSessionMock(_ context.Context, _ *snowflakeRestful, _ time.Duration) error {
	closedSessionCount++
	return &SnowflakeError{
		Number: ErrSessionGone,
	}
}

// TestCloseIgnoreSessionGone checks that Close swallows ErrSessionGone.
func TestCloseIgnoreSessionGone(t *testing.T) {
	sr := &snowflakeRestful{
		FuncCloseSession: closeSessionMock,
	}
	sc := &snowflakeConn{
		cfg:       &Config{},
		rest:      sr,
		telemetry: testTelemetry,
	}
	if sc.Close() != nil {
		t.Error("Close should let go session gone error")
	}
}

// TestClientSessionPersist checks that ServerSessionKeepAlive prevents the
// session from being closed.
func TestClientSessionPersist(t *testing.T) {
	sr := &snowflakeRestful{
		FuncCloseSession: closeSessionMock,
	}
	sc := &snowflakeConn{
		cfg:       &Config{},
		rest:      sr,
		telemetry: testTelemetry,
	}
	sc.cfg.ServerSessionKeepAlive = true
	count := closedSessionCount
	if sc.Close() != nil {
		t.Error("Connection close should not return error")
	}
	if count != closedSessionCount {
		t.Fatal("close session was called")
	}
}

// TestFetchResultByQueryID fetches a completed query's result by its query ID
// via WithFetchResultByID, against wiremock.
func TestFetchResultByQueryID(t *testing.T) {
	wiremock.registerMappings(t,
		wiremockMapping{filePath: "auth/password/successful_flow.json"},
		wiremockMapping{filePath: "query/query_execution.json"},
		wiremockMapping{filePath: "query/query_monitoring.json"},
	)
	cfg := wiremock.connectionConfig()
	connector := NewConnector(SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	defer db.Close()
	conn, err := db.Conn(context.Background())
	assertNilF(t, err)
	defer conn.Close()
	var qid string
	err = conn.Raw(func(x any) error {
		rows1, err := x.(driver.QueryerContext).QueryContext(context.Background(), "SELECT 1", nil)
		if err != nil {
			return err
		}
		defer rows1.Close()
		qid = rows1.(SnowflakeRows).GetQueryID()
		return nil
	})
	assertNilF(t, err)
	ctx := WithFetchResultByID(context.Background(), qid)
	rows2, err := db.QueryContext(ctx, "")
	assertNilF(t, err)
	closeCh := make(chan bool, 1)
	rows2ext := &RowsExtended{rows: rows2,
		closeChan: &closeCh, t: t}
	defer rows2ext.Close()
	var ms, sum int
	rows2ext.mustNext()
	rows2ext.mustScan(&ms, &sum)
	assertEqualE(t, ms, 1)
	assertEqualE(t, sum, 5050)
}

// TestFetchRunningQueryByID is the same scenario as TestFetchResultByQueryID
// but with the monitoring endpoint reporting a still-running query.
func TestFetchRunningQueryByID(t *testing.T) {
	wiremock.registerMappings(t,
		wiremockMapping{filePath: "auth/password/successful_flow.json"},
		wiremockMapping{filePath: "query/query_execution.json"},
		wiremockMapping{filePath: "query/query_monitoring_running.json"},
	)
	cfg := wiremock.connectionConfig()
	connector := NewConnector(SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	defer db.Close()
	conn, err := db.Conn(context.Background())
	assertNilF(t, err)
	defer conn.Close()
	var qid string
	err = conn.Raw(func(x any) error {
		rows1, err := x.(driver.QueryerContext).QueryContext(context.Background(), "SELECT 1", nil)
		if err != nil {
			return err
		}
		defer rows1.Close()
		qid = rows1.(SnowflakeRows).GetQueryID()
		return nil
	})
	assertNilF(t, err)
	ctx := WithFetchResultByID(context.Background(), qid)
	rows2, err := db.QueryContext(ctx, "")
	assertNilF(t, err)
	closeCh := make(chan bool, 1)
	rows2ext := &RowsExtended{rows: rows2, closeChan: &closeCh, t: t}
	defer rows2ext.Close()
	var ms, sum int
	rows2ext.mustNext()
	rows2ext.mustScan(&ms, &sum)
	assertEqualE(t, ms, 1)
	assertEqualE(t, sum, 5050)
}

// TestFetchErrorQueryByID checks that fetching a failed query by ID surfaces
// a SnowflakeError with ErrQueryReportedError.
func TestFetchErrorQueryByID(t *testing.T) {
	wiremock.registerMappings(t,
		wiremockMapping{filePath: "auth/password/successful_flow.json"},
		wiremockMapping{filePath: "query/query_execution.json"},
		wiremockMapping{filePath: "query/query_monitoring_error.json"},
	)
	cfg := wiremock.connectionConfig()
	connector := NewConnector(SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	defer db.Close()
	conn, err := db.Conn(context.Background())
	assertNilF(t, err)
	defer conn.Close()
	var qid string
	err = conn.Raw(func(x any) error {
		rows1, err := x.(driver.QueryerContext).QueryContext(context.Background(), "SELECT 1", nil)
		if err != nil {
			return err
		}
		defer rows1.Close()
		qid = rows1.(SnowflakeRows).GetQueryID()
		return nil
	})
	assertNilF(t, err)
	ctx := WithFetchResultByID(context.Background(), qid)
	_, err = db.QueryContext(ctx, "")
	assertNotNilF(t, err, "Expected error when fetching failed query")
	var se *SnowflakeError
	assertErrorsAsF(t, err, &se)
	assertEqualE(t, se.Number, ErrQueryReportedError)
}

// TestFetchMalformedJsonQueryByID checks that a malformed monitoring payload
// surfaces a JSON parse error.
func TestFetchMalformedJsonQueryByID(t *testing.T) {
	wiremock.registerMappings(t,
		wiremockMapping{filePath: "auth/password/successful_flow.json"},
		wiremockMapping{filePath: "query/query_execution.json"},
		wiremockMapping{filePath: "query/query_monitoring_malformed.json"},
	)
	cfg := wiremock.connectionConfig()
	connector := NewConnector(SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	defer db.Close()
	// Execute a query to get a query ID using raw connection
	conn, err := db.Conn(context.Background())
	assertNilF(t, err)
	defer conn.Close()
	var qid string
	err = conn.Raw(func(x any) error {
		rows1, err := x.(driver.QueryerContext).QueryContext(context.Background(), "SELECT 1", nil)
		if err != nil {
			return err
		}
		defer rows1.Close()
		qid = rows1.(SnowflakeRows).GetQueryID()
		return nil
	})
	assertNilF(t, err)
	ctx := WithFetchResultByID(context.Background(), qid)
	_, err = db.QueryContext(ctx, "")
	assertNotNilF(t, err, "Expected error when fetching malformed JSON")
	assertStringContainsF(t, err.Error(), "invalid character")
}

// TestIsPrivateLink table-tests checkIsPrivateLink on hosts with and without
// a (case-insensitive) privatelink segment.
func TestIsPrivateLink(t *testing.T) {
	for _, tc := range []struct {
		host          string
		isPrivatelink bool
	}{
		{"testaccount.us-east-1.snowflakecomputing.com", false},
		{"testaccount-no-privatelink.snowflakecomputing.com", false},
		{"testaccount.us-east-1.privatelink.snowflakecomputing.com", true},
		{"testaccount.cn-region.snowflakecomputing.cn", false},
		{"testaccount.cn-region.privaTELINk.snowflakecomputing.cn", true},
		{"testaccount.some-region.privatelink.snowflakecomputing.mil", true},
		{"testaccount.us-east-1.privatelink.snowflakecOMPUTING.com", true},
		{"snowhouse.snowflakecomputing.xyz", false},
		{"snowhouse.privatelink.snowflakecomputing.xyz", true},
		{"snowhouse.PRIVATELINK.snowflakecomputing.xyz", true},
	} {
		t.Run(tc.host, func(t *testing.T) {
			assertEqualE(t, checkIsPrivateLink(tc.host), tc.isPrivatelink)
		})
	}
}

// TestBuildPrivatelinkConn checks the OCSP cache/retry URLs derived for a
// privatelink host.
func TestBuildPrivatelinkConn(t *testing.T) {
	ov := newOcspValidator(&Config{
		Host:     "testaccount.us-east-1.privatelink.snowflakecomputing.com",
		Account:  "testaccount",
		User:     "testuser",
		Password: "testpassword",
	})
	assertEqualE(t, ov.cacheServerURL, "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/ocsp_response_cache.json")
	assertEqualE(t, ov.retryURL, "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/retry/%v/%v")
}

// TestOcspAddressesSetup table-tests OCSP cache server and retry URL setup
// across default, privatelink, and non-.com hosts.
func TestOcspAddressesSetup(t *testing.T) {
	for _, tc := range []struct {
		host                string
		cacheURL            string
		privateLinkRetryURL string
	}{
		{
			host:                "testaccount.us-east-1.snowflakecomputing.com",
			cacheURL:            fmt.Sprintf("%v/%v", defaultCacheServerHost, cacheFileBaseName),
			privateLinkRetryURL: "",
		},
		{
			host:                "testaccount-no-privatelink.snowflakecomputing.com",
			cacheURL:            fmt.Sprintf("%v/%v", defaultCacheServerHost, cacheFileBaseName),
			privateLinkRetryURL: "",
		},
		{
			host:                "testaccount.us-east-1.privatelink.snowflakecomputing.com",
			cacheURL:            "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/ocsp_response_cache.json",
			privateLinkRetryURL: "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/retry/%v/%v",
		},
		{
			host:                "testaccount.cn-region.snowflakecomputing.cn",
			cacheURL:            "http://ocsp.testaccount.cn-region.snowflakecomputing.cn/ocsp_response_cache.json",
			privateLinkRetryURL: "", // not a privatelink env, no need to setup retry URL
		},
		{
			host:                "testaccount.cn-region.privaTELINk.snowflakecomputing.cn",
			cacheURL:            "http://ocsp.testaccount.cn-region.privatelink.snowflakecomputing.cn/ocsp_response_cache.json",
			privateLinkRetryURL: "http://ocsp.testaccount.cn-region.privatelink.snowflakecomputing.cn/retry/%v/%v",
		},
		{
			host:     "testaccount.some-region.privatelink.snowflakecomputing.mil",
			cacheURL: "http://ocsp.testaccount.some-region.privatelink.snowflakecomputing.mil/ocsp_response_cache.json",
			privateLinkRetryURL: "http://ocsp.testaccount.some-region.privatelink.snowflakecomputing.mil/retry/%v/%v",
		},
	} {
		t.Run(tc.host, func(t *testing.T) {
			ov := newOcspValidator(&Config{
				Host: tc.host,
			})
			assertEqualE(t, ov.cacheServerURL, tc.cacheURL)
			assertEqualE(t, ov.retryURL, tc.privateLinkRetryURL)
		})
	}
}

// TestGetQueryStatus runs a real query and verifies GetQueryStatus reports
// no error, positive scan bytes, and the expected produced row count.
func TestGetQueryStatus(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		sct.mustExec(`create or replace table ut_conn(c1 number, c2 string) as (select seq4() as seq, concat('str',to_varchar(seq)) as str1 from table(generator(rowcount => 100)))`, nil)
		rows := sct.mustQueryContext(sct.sc.ctx, "select min(c1) as ms, sum(c1) from ut_conn group by (c1 % 10) order by ms", nil)
		qid := rows.(SnowflakeResult).GetQueryID()
		// use conn as type holder for SnowflakeConnection placeholder
		var conn any = sct.sc
		qStatus, err := conn.(SnowflakeConnection).GetQueryStatus(sct.sc.ctx, qid)
		if err != nil {
			t.Errorf("failed to get query status err = %s", err.Error())
			return
		}
		if qStatus == nil {
			t.Error("there was no query status returned")
			return
		}
		if qStatus.ErrorCode != "" || qStatus.ScanBytes <= 0 || qStatus.ProducedRows != 10 {
			t.Errorf("expected no error. got: %v, scan bytes: %v, produced rows: %v", qStatus.ErrorCode, qStatus.ScanBytes, qStatus.ProducedRows)
			return
		}
	})
}

// TestAddTelemetryDataViaSnowflakeConnection sends custom telemetry through
// the raw connection against wiremock.
func TestAddTelemetryDataViaSnowflakeConnection(t *testing.T) {
	wiremock.registerMappings(t,
		newWiremockMapping("auth/password/successful_flow.json"),
		newWiremockMapping("telemetry/custom_telemetry.json"))
	cfg := wiremock.connectionConfig()
	connector := NewConnector(SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	defer db.Close()
	conn, err := db.Conn(context.Background())
	assertNilF(t, err)
	err = conn.Raw(func(x any) error {
		m := map[string]string{}
		m["test_key"] = "test_value"
		return x.(SnowflakeConnection).AddTelemetryData(context.Background(), time.Now(), m)
	})
	assertNilF(t, err)
}

// TestConfigureTelemetry checks that the server-side telemetry flag toggles
// sc.telemetry.enabled on login.
func TestConfigureTelemetry(t *testing.T) {
	for _, enabled := range []bool{true, false} {
		t.Run(strconv.FormatBool(enabled), func(t *testing.T) {
			wiremock.registerMappings(t,
				wiremockMapping{
					filePath: "auth/password/successful_flow_with_telemetry.json",
					params:   map[string]string{"%CLIENT_TELEMETRY_ENABLED%": strconv.FormatBool(enabled)},
				},
			)
			cfg := wiremock.connectionConfig()
			connector := NewConnector(SnowflakeDriver{}, *cfg)
			db := sql.OpenDB(connector)
			defer db.Close()
			conn, err := db.Conn(context.Background())
			assertNilF(t, err)
			err = conn.Raw(func(x any) error {
				sc := x.(*snowflakeConn)
				assertEqualE(t, sc.telemetry.enabled, enabled)
				return nil
			})
			assertNilF(t, err)
		})
	}
}

// TestGetInvalidQueryStatus checks that checkQueryStatus errors out for an
// unknown query ID.
func TestGetInvalidQueryStatus(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		sct.sc.rest.RequestTimeout = 1 * time.Second
		qStatus, err := sct.sc.checkQueryStatus(sct.sc.ctx, "1234")
		if err == nil || qStatus != nil {
			t.Error("expected an error")
		}
	})
}

// TestExecWithServerSideError checks that an unsuccessful response with no
// error code maps to the "unknown server side error" sentinel values.
func TestExecWithServerSideError(t *testing.T) {
	postQueryMock := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, requestID UUID, _ *Config) (*execResponse, error) {
		dd := &execResponseData{}
		return &execResponse{
			Data:    *dd,
			Message: "",
			Code:    "",
			Success: false,
		}, nil
	}
	sr := &snowflakeRestful{
		FuncPostQuery: postQueryMock,
	}
	sc := &snowflakeConn{
		cfg:       &Config{},
		rest:      sr,
		telemetry: testTelemetry,
	}
	_, err := sc.exec(context.Background(), "", false, /* noResult */ false /* isInternal */, false /* describeOnly */, nil)
	if err == nil {
		t.Error("expected a server side error")
	}
	sfe := err.(*SnowflakeError)
	errUnknownError := errors2.ErrUnknownError()
	if sfe.Number != -1 || sfe.SQLState != "-1" || sfe.QueryID != "-1" {
		t.Errorf("incorrect snowflake error. expected: %v, got: %v", errUnknownError, *sfe)
	}
	if !strings.Contains(sfe.Message, "an unknown server side error occurred") {
		t.Errorf("incorrect message. expected: %v, got: %v", errUnknownError.Message, sfe.Message)
	}
}

// TestConcurrentReadOnParams hammers prepare/query/scan from 10 goroutines to
// detect races on shared connection parameters.
func TestConcurrentReadOnParams(t *testing.T) {
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatal("Failed to parse dsn")
	}
	connector := NewConnector(SnowflakeDriver{}, *config)
	db := sql.OpenDB(connector)
	defer db.Close()
	var successCount, failureCount int32
	wg := sync.WaitGroup{}
	for range 10 {
		wg.Add(1)
		go func() {
			for range 10 {
				func() {
					stmt, err := db.PrepareContext(context.Background(), "SELECT table_schema FROM information_schema.columns WHERE table_schema = ? LIMIT 1")
					if err != nil || stmt == nil {
						atomic.AddInt32(&failureCount, 1)
						return // Skip this iteration if PrepareContext fails
					}
					defer stmt.Close()
					rows, err := stmt.Query("INFORMATION_SCHEMA")
					if err != nil {
						atomic.AddInt32(&failureCount, 1)
						return
					}
					defer rows.Close()
					rows.Next()
					var tableName string
					err = rows.Scan(&tableName)
					if err != nil {
						atomic.AddInt32(&failureCount, 1)
					} else {
						atomic.AddInt32(&successCount, 1)
					}
				}()
			}
			wg.Done()
		}()
	}
	wg.Wait()
	totalOperations := int32(100) // 10 goroutines × 10 operations each
	if successCount != totalOperations {
		t.Errorf("Expected all %d concurrent operations to succeed, got %d successes, %d failures", totalOperations, successCount, failureCount)
	} else {
		t.Logf("All %d concurrent operations completed successfully", successCount)
	}
}

// postQueryTest fails without any response payload.
func postQueryTest(_ context.Context, _ *snowflakeRestful, _ *url.Values, headers map[string]string, _ []byte, _ time.Duration, _ UUID, _ *Config) (*execResponse, error) {
	return nil, errors.New("failed to get query response")
}

// postQueryFail fails but also returns a structured error response.
func postQueryFail(_ context.Context, _ *snowflakeRestful, _ *url.Values, headers map[string]string, _ []byte, _ time.Duration, _ UUID, _ *Config) (*execResponse, error) {
	dd := &execResponseData{
		QueryID:  "1eFhmhe23242kmfd540GgGre",
		SQLState: "22008",
	}
	return &execResponse{
		Data:    *dd,
		Message: "failed to get query response",
		Code:    "12345",
		Success: false,
	}, errors.New("failed to get query response")
}

// TestErrorReportingOnConcurrentFails verifies that concurrent failing
// queries each receive the error message for their own statement.
func TestErrorReportingOnConcurrentFails(t *testing.T) {
	db := openDB(t)
	defer db.Close()
	var wg sync.WaitGroup
	n := 5
	wg.Add(3 * n)
	for range n {
		go executeQueryAndConfirmMessage(db, "SELECT * FROM TABLE_ABC", "TABLE_ABC", t, &wg)
		go executeQueryAndConfirmMessage(db, "SELECT * FROM TABLE_DEF", "TABLE_DEF", t, &wg)
		go executeQueryAndConfirmMessage(db, "SELECT * FROM TABLE_GHI", "TABLE_GHI", t, &wg)
	}
	wg.Wait()
}

// executeQueryAndConfirmMessage runs a failing query and asserts the error
// message mentions the expected table name.
func executeQueryAndConfirmMessage(db *sql.DB, query string, expectedErrorTable string, t *testing.T, wg *sync.WaitGroup) {
	defer wg.Done()
	_, err :=
// NOTE(review): the statements below are the tail of a test function whose
// beginning lies outside this chunk; only comments/formatting were added.
db.Exec(query)
message := err.(*SnowflakeError).Message
if !strings.Contains(message, expectedErrorTable) {
	t.Errorf("QueryID: %s, Message %s ###### Expected error message table name: %s", err.(*SnowflakeError).QueryID, err.(*SnowflakeError).Message, expectedErrorTable)
}
}

// TestQueryArrowStreamError verifies that QueryArrowStream propagates errors
// from the REST layer, and that a failing post query yields a *SnowflakeError.
func TestQueryArrowStreamError(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		numrows := 50000
		query := fmt.Sprintf(selectRandomGenerator, numrows)
		sct.sc.rest = &snowflakeRestful{
			FuncPostQuery:    postQueryTest,
			FuncCloseSession: closeSessionMock,
			TokenAccessor:    getSimpleTokenAccessor(),
			RequestTimeout:   10,
		}
		_, err := sct.sc.QueryArrowStream(sct.sc.ctx, query)
		if err == nil {
			t.Error("should have raised an error")
		}
		sct.sc.rest.FuncPostQuery = postQueryFail
		_, err = sct.sc.QueryArrowStream(sct.sc.ctx, query)
		if err == nil {
			t.Error("should have raised an error")
		}
		_, ok := err.(*SnowflakeError)
		if !ok {
			t.Fatalf("should be snowflake error. err: %v", err)
		}
	})
}

// TestExecContextError verifies that ExecContext surfaces errors from the REST layer.
func TestExecContextError(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		sct.sc.rest = &snowflakeRestful{
			FuncPostQuery:    postQueryTest,
			FuncCloseSession: closeSessionMock,
			TokenAccessor:    getSimpleTokenAccessor(),
			RequestTimeout:   10,
		}
		_, err := sct.sc.ExecContext(sct.sc.ctx, "SELECT 1", []driver.NamedValue{})
		if err == nil {
			t.Fatalf("should have raised an error")
		}
		sct.sc.rest.FuncPostQuery = postQueryFail
		_, err = sct.sc.ExecContext(sct.sc.ctx, "SELECT 1", []driver.NamedValue{})
		if err == nil {
			t.Fatalf("should have raised an error")
		}
	})
}

// TestQueryContextError verifies that QueryContext surfaces errors from the
// REST layer and that a failing post query yields a *SnowflakeError.
func TestQueryContextError(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		sct.sc.rest = &snowflakeRestful{
			FuncPostQuery:    postQueryTest,
			FuncCloseSession: closeSessionMock,
			TokenAccessor:    getSimpleTokenAccessor(),
			RequestTimeout:   10,
		}
		_, err := sct.sc.QueryContext(sct.sc.ctx, "SELECT 1", []driver.NamedValue{})
		if err == nil {
			t.Fatalf("should have raised an error")
		}
		sct.sc.rest.FuncPostQuery = postQueryFail
		_, err = sct.sc.QueryContext(sct.sc.ctx, "SELECT 1", []driver.NamedValue{})
		if err == nil {
			t.Fatalf("should have raised an error")
		}
		_, ok := err.(*SnowflakeError)
		if !ok {
			t.Fatalf("should be snowflake error. err: %v", err)
		}
	})
}

// TestPrepareQuery verifies that a simple statement can be prepared on the connection.
func TestPrepareQuery(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		_, err := sct.sc.Prepare("SELECT 1")
		if err != nil {
			t.Fatalf("failed to prepare query. err: %v", err)
		}
	})
}

// TestBeginCreatesTransaction verifies Begin returns a non-nil transaction.
func TestBeginCreatesTransaction(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		tx, _ := sct.sc.Begin()
		if tx == nil {
			t.Fatal("should have created a transaction with connection")
		}
	})
}

// EmptyTransporter is a stub http.RoundTripper that performs no request.
type EmptyTransporter struct{}

// RoundTrip implements http.RoundTripper; it returns a nil response and no error.
func (t EmptyTransporter) RoundTrip(*http.Request) (*http.Response, error) {
	return nil, nil
}

// castToTransport safely casts http.RoundTripper to *http.Transport
// Returns nil if the cast fails
func castToTransport(rt http.RoundTripper) *http.Transport {
	if transport, ok := rt.(*http.Transport); ok {
		return transport
	}
	return nil
}

// TestGetTransport checks which transport / round-tripper the transport
// factory produces for various Config values (OCSP on/off, custom
// Transporter, nil Config). The table continues past the end of this chunk.
func TestGetTransport(t *testing.T) {
	testcases := []struct {
		name              string
		cfg               *Config
		transportCheck    func(t *testing.T, transport *http.Transport)
		roundTripperCheck func(t *testing.T, roundTripper http.RoundTripper)
	}{
		{
			name: "DisableOCSPChecks",
			cfg:  &Config{Account: "one", DisableOCSPChecks: false},
			transportCheck: func(t *testing.T, transport *http.Transport) {
				// We should have a verifier function
				assertNotNilF(t, transport)
				assertNotNilF(t, transport.TLSClientConfig)
				assertNotNilF(t, transport.TLSClientConfig.VerifyPeerCertificate)
			},
		},
		{
			name: "DisableOCSPChecks missing from Config",
			cfg:  &Config{Account: "four"},
			transportCheck: func(t *testing.T, transport *http.Transport) {
				// We should have a verifier function
				assertNotNilF(t, transport)
				assertNotNilF(t, transport.TLSClientConfig)
				assertNotNilF(t, transport.TLSClientConfig.VerifyPeerCertificate)
			},
		},
		{
			name: "whole Config is missing",
			cfg:  nil,
			transportCheck: func(t *testing.T, transport *http.Transport) {
				// We should not have a TLSClientConfig
				assertNotNilF(t, transport)
// NOTE(review): continuation of the TestGetTransport table started in the
// previous chunk; only comments/formatting were added to this code.
assertNilF(t, transport.TLSClientConfig)
			},
		},
		{
			name: "Using custom Transporter",
			cfg:  &Config{Account: "five", DisableOCSPChecks: true, Transporter: EmptyTransporter{}},
			roundTripperCheck: func(t *testing.T, roundTripper http.RoundTripper) {
				// We should have a custom Transporter
				assertNotNilF(t, roundTripper)
				assertTrueE(t, roundTripper == EmptyTransporter{})
			},
		},
	}
	for _, test := range testcases {
		t.Run(test.name, func(t *testing.T) {
			result, err := newTransportFactory(test.cfg, nil).createTransport(transportConfigFor(transportTypeSnowflake))
			assertNilE(t, err)
			if test.transportCheck != nil {
				test.transportCheck(t, castToTransport(result))
			}
			if test.roundTripperCheck != nil {
				test.roundTripperCheck(t, result)
			}
		})
	}
}

// TestGetCRLTransport verifies the factory builds a CRL transport with the
// expected connection-pool settings when CRL revocation checking is enabled.
func TestGetCRLTransport(t *testing.T) {
	t.Run("Using CRLs", func(t *testing.T) {
		crlCfg := &Config{
			CertRevocationCheckMode: CertRevocationCheckEnabled,
			DisableOCSPChecks:       true,
		}
		transportFactory := newTransportFactory(crlCfg, nil)
		crlRoundTripper, err := transportFactory.createTransport(transportConfigFor(transportTypeCRL))
		assertNilF(t, err)
		transport := castToTransport(crlRoundTripper)
		assertNotNilF(t, transport, "Expected http.Transport")
		assertEqualE(t, transport.MaxIdleConns, defaultTransportConfigs.forTransportType(transportTypeCRL).MaxIdleConns)
	})
}

================================================ FILE: connection_util.go ================================================

package gosnowflake

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"maps"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"
)

// isClientSessionKeepAliveEnabled reports whether the session parameter
// controlling client session keep-alive is set to "true".
func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool {
	v, ok := sc.syncParams.get(sessionClientSessionKeepAlive)
	if !ok {
		return false
	}
	// NOTE(review): staticcheck SA1004 — `*v == "true"` is the idiomatic form.
	return strings.Compare(*v, "true") == 0
}

// getClientSessionKeepAliveHeartbeatFrequency returns the configured heartbeat
// frequency. The second return value is false when the parameter is absent or
// cannot be parsed as an integer number of seconds.
func (sc *snowflakeConn) getClientSessionKeepAliveHeartbeatFrequency() (time.Duration, bool) {
	v, ok := sc.syncParams.get(sessionClientSessionKeepAliveHeartbeatFrequency)
	if !ok {
		return 0, false
	}
	num, err := strconv.Atoi(*v)
	if err != nil {
		logger.Warnf("Failed to parse client session keepalive heartbeat frequency: %v. Falling back to default.", err)
		return 0, false
	}
	return time.Duration(num) * time.Second, true
}

// startHeartBeat starts the session keep-alive heartbeat on the REST layer,
// using the session-configured frequency when available, else the default.
// No-op when keep-alive is disabled or the REST layer is absent.
func (sc *snowflakeConn) startHeartBeat() {
	if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() {
		return
	}
	if sc.rest != nil {
		if heartbeatFrequency, ok := sc.getClientSessionKeepAliveHeartbeatFrequency(); ok {
			sc.rest.HeartBeat = newHeartBeat(sc.rest, heartbeatFrequency)
		} else {
			sc.rest.HeartBeat = newDefaultHeartBeat(sc.rest)
		}
		logger.WithContext(sc.ctx).Debug("Start heart beat")
		sc.rest.HeartBeat.start()
	}
}

// stopHeartBeat stops a running keep-alive heartbeat, if any.
func (sc *snowflakeConn) stopHeartBeat() {
	if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() {
		return
	}
	if sc.rest != nil && sc.rest.HeartBeat != nil {
		logger.WithContext(sc.ctx).Debug("Stop heart beat")
		sc.rest.HeartBeat.stop()
	}
}

// getArrayBindStageThreshold returns the array-bind stage threshold session
// parameter as an int; 0 when the parameter is absent or unparsable.
func (sc *snowflakeConn) getArrayBindStageThreshold() int {
	v, ok := sc.syncParams.get(sessionArrayBindStageThreshold)
	if !ok {
		return 0
	}
	num, err := strconv.Atoi(*v)
	if err != nil {
		return 0
	}
	return num
}

// connectionTelemetry records driver/connection metadata (plus all current
// session parameters) as an in-band telemetry log and flushes the batch.
// Failures are logged but not propagated.
func (sc *snowflakeConn) connectionTelemetry(cfg *Config) {
	data := &telemetryData{
		Message: map[string]string{
			typeKey:          connectionParameters,
			sourceKey:        telemetrySource,
			driverTypeKey:    "Go",
			driverVersionKey: SnowflakeGoDriverVersion,
			golangVersionKey: runtime.Version(),
		},
		Timestamp: time.Now().UnixNano() / int64(time.Millisecond),
	}
	maps.Insert(data.Message, sc.syncParams.All())
	if err := sc.telemetry.addLog(data); err != nil {
		logger.WithContext(sc.ctx).Warnf("cannot add telemetry log: %v", err)
	}
	if err := sc.telemetry.sendBatch(); err != nil {
		logger.WithContext(sc.ctx).Warnf("cannot send telemetry batch: %v", err)
	}
}

// processFileTransfer creates a snowflakeFileTransferAgent object to process
// any PUT/GET commands with their specified options
func (sc *snowflakeConn) processFileTransfer(
	ctx context.Context,
	data *execResponse,
	query string,
	isInternal bool) (
	*execResponse, error) {
	options := &SnowflakeFileTransferOptions{}
	sfa := snowflakeFileTransferAgent{
		ctx:          ctx,
		sc:           sc,
		data:         &data.Data,
		command:      query,
		options:      options,
		streamBuffer: new(bytes.Buffer),
	}
	fs, err := getFileStream(ctx)
	if err != nil {
		return nil, err
	}
	if fs != nil {
		sfa.sourceStream = fs
		if isInternal {
			sfa.data.AutoCompress = false
		}
	}
	if op := getFileTransferOptions(ctx); op != nil {
		sfa.options = op
	}
	if sfa.options.MultiPartThreshold == 0 {
		sfa.options.MultiPartThreshold = multiPartThreshold
		// for streaming download, use a smaller default part size
		if sfa.commandType == downloadCommand && isFileGetStream(ctx) {
			sfa.options.MultiPartThreshold = streamingMultiPartThreshold
		}
	}
	if err := sfa.execute(); err != nil {
		return nil, err
	}
	data, err = sfa.result()
	if err != nil {
		return nil, err
	}
	if sfa.options != nil && isFileGetStream(ctx) {
		if err := writeFileStream(ctx, sfa.streamBuffer); err != nil {
			return nil, err
		}
	}
	return data, nil
}

// getFileStream returns the io.Reader stored in the context under
// filePutStream, nil when none is set, or an error on a wrong type.
func getFileStream(ctx context.Context) (io.Reader, error) {
	s := ctx.Value(filePutStream)
	if s == nil {
		return nil, nil
	}
	r, ok := s.(io.Reader)
	if !ok {
		return nil, errors.New("incorrect io.Reader")
	}
	return r, nil
}

// isFileGetStream reports whether a GET-stream target is present in the context.
func isFileGetStream(ctx context.Context) bool {
	v := ctx.Value(fileGetStream)
	return v != nil
}

// getFileTransferOptions returns the *SnowflakeFileTransferOptions from the
// context, or nil when absent or of the wrong type.
func getFileTransferOptions(ctx context.Context) *SnowflakeFileTransferOptions {
	v := ctx.Value(fileTransferOptions)
	if v == nil {
		return nil
	}
	o, ok := v.(*SnowflakeFileTransferOptions)
	if !ok {
		return nil
	}
	return o
}

// writeFileStream copies the buffered GET result into the io.Writer stored
// in the context under fileGetStream.
func writeFileStream(ctx context.Context, streamBuf *bytes.Buffer) error {
	s := ctx.Value(fileGetStream)
	w, ok := s.(io.Writer)
	if !ok {
		return errors.New("expected an io.Writer")
	}
	_, err := streamBuf.WriteTo(w)
	if err != nil {
		return err
	}
	return nil
}

// populateSessionParameters stores the server-reported session parameters
// (stringified) into the connection's synchronized parameter map, keyed by
// lowercased parameter name.
func (sc *snowflakeConn) populateSessionParameters(parameters []nameValueParameter) {
	// other session parameters (not all)
	logger.WithContext(sc.ctx).Tracef("params: %#v", parameters)
	for _, param := range parameters {
		v := ""
		switch param.Value.(type) {
		case int64:
			if vv, ok := param.Value.(int64); ok {
				v = strconv.FormatInt(vv, 10)
			}
		case float64:
			if vv, ok := param.Value.(float64); ok {
				v = strconv.FormatFloat(vv, 'g', -1, 64)
			}
		case bool:
			if vv, ok := param.Value.(bool); ok {
				v = strconv.FormatBool(vv)
			}
		default:
			if vv, ok := param.Value.(string); ok {
				v = vv
			}
		}
		logger.WithContext(sc.ctx).Tracef("parameter. name: %v, value: %v", param.Name, v)
		sc.syncParams.set(strings.ToLower(param.Name), &v)
	}
}

// configureTelemetry enables in-band telemetry on the connection when the
// client_telemetry_enabled session parameter is "true".
func (sc *snowflakeConn) configureTelemetry() {
	telemetryEnabled, ok := sc.syncParams.get("client_telemetry_enabled")
	// In-band telemetry is enabled by default on the backend side.
	if ok && telemetryEnabled != nil && *telemetryEnabled == "true" {
		sc.telemetry.flushSize = defaultFlushSize
		sc.telemetry.sr = sc.rest
		sc.telemetry.mutex = &sync.Mutex{}
		sc.telemetry.enabled = true
	}
}

// isAsyncMode reports whether async query execution was requested via context.
func isAsyncMode(ctx context.Context) bool {
	return isBooleanContextEnabled(ctx, asyncMode)
}

// isDescribeOnly reports whether describe-only execution was requested via context.
func isDescribeOnly(ctx context.Context) bool {
	return isBooleanContextEnabled(ctx, describeOnly)
}

// isInternal reports whether the query was flagged as internal via context.
func isInternal(ctx context.Context) bool {
	return isBooleanContextEnabled(ctx, internalQuery)
}

// isLogQueryTextEnabled reports whether query-text logging was enabled via context.
func isLogQueryTextEnabled(ctx context.Context) bool {
	return isBooleanContextEnabled(ctx, logQueryText)
}

// isLogQueryParametersEnabled reports whether query-parameter logging was
// enabled via context.
func isLogQueryParametersEnabled(ctx context.Context) bool {
	return isBooleanContextEnabled(ctx, logQueryParameters)
}

// isBooleanContextEnabled reports whether the context value under key is the
// boolean true (false for absent or non-bool values).
func isBooleanContextEnabled(ctx context.Context, key ContextKey) bool {
	v := ctx.Value(key)
	if v == nil {
		return false
	}
	d, ok := v.(bool)
	return ok && d
}

// setResultType stores the desired result type in the context.
func setResultType(ctx context.Context, resType resultType) context.Context {
	return context.WithValue(ctx, snowflakeResultType, resType)
}

// getResultType reads the result type previously stored with setResultType.
// NOTE(review): panics when the value is absent — callers must set it first.
func getResultType(ctx context.Context) resultType {
	return ctx.Value(snowflakeResultType).(resultType)
}

// isDml returns true if the statement type code is in the range of DML.
func isDml(v int64) bool {
	return statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert
}

// isDql reports whether the response is a plain SELECT (and not a multi-statement).
func isDql(data *execResponseData) bool {
	return data.StatementTypeID == statementTypeIDSelect && !isMultiStmt(data)
}

// updateRows sums the affected-row counts from the first row of the result
// set (one column per row type); returns -1 and the error on a parse failure.
func updateRows(data execResponseData) (int64, error) {
	if data.RowSet == nil {
		return 0, nil
	}
	var count int64
	for i, n := 0, len(data.RowType); i < n; i++ {
		v, err := strconv.ParseInt(*data.RowSet[0][i], 10, 64)
		if err != nil {
			return -1, err
		}
		count += v
	}
	return count, nil
}

// isMultiStmt returns true if the statement code is of type multistatement
// Note that the statement type code is also equivalent to type INSERT, so an
// additional check of the name is required
func isMultiStmt(data *execResponseData) bool {
	var isMultistatementByReturningSelect = data.StatementTypeID == statementTypeIDSelect && data.RowType[0].Name == "multiple statement execution"
	return isMultistatementByReturningSelect || data.StatementTypeID == statementTypeIDMultistatement
}

// getResumeQueryID extracts a query ID from the context (fetchResultByID) and
// validates its format; empty string when none is set.
func getResumeQueryID(ctx context.Context) (string, error) {
	val := ctx.Value(fetchResultByID)
	if val == nil {
		return "", nil
	}
	strVal, ok := val.(string)
	if !ok {
		return "", fmt.Errorf("failed to cast val %+v to string", val)
	}
	// so there is a queryID in context for which we want to fetch the result
	if !queryIDRegexp.MatchString(strVal) {
		return strVal, &SnowflakeError{
			Number:  ErrQueryIDFormat,
			Message: "Invalid QID",
			QueryID: strVal,
		}
	}
	return strVal, nil
}

// returns snowflake chunk downloader by default or stream based chunk
// downloader if option provided through context
func populateChunkDownloader(
	ctx context.Context,
	sc *snowflakeConn,
	data execResponseData) chunkDownloader {
	return &snowflakeChunkDownloader{
		sc:                 sc,
		ctx:                ctx,
		pool:               getAllocator(ctx),
		CurrentChunk:       make([]chunkRowType, len(data.RowSet)),
		ChunkMetas:         data.Chunks,
		Total:              data.Total,
		TotalRowIndex:      int64(-1),
		CellCount:          len(data.RowType),
		Qrmk:               data.Qrmk,
		QueryResultFormat:  data.QueryResultFormat,
		ChunkHeader:        data.ChunkHeaders,
		FuncDownload:       downloadChunk,
		FuncDownloadHelper: downloadChunkHelper,
		FuncGet:            getChunk,
		RowSet: rowSetType{
			RowType:      data.RowType,
			JSON:         data.RowSet,
			RowSetBase64: data.RowSetBase64,
		},
	}
}

/**
 * We can only tell if private link is enabled for certain hosts when the hostname contains the subdomain
 * 'privatelink.snowflakecomputing.' but we don't have a good way of telling if a private link connection is
 * expected for internal stages for example.
 */
func checkIsPrivateLink(host string) bool {
	return strings.Contains(strings.ToLower(host), ".privatelink.snowflakecomputing.")
}

// isStatementContext reports whether the context marks statement-type execution.
func isStatementContext(ctx context.Context) bool {
	v := ctx.Value(executionType)
	return v == executionTypeStatement
}

================================================ FILE: connectivity_diagnosis.go ================================================

package gosnowflake

import (
	"context"
	"crypto/x509"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config"
	"io"
	"net"
	"net/http"
	"net/url"
	"os"
	"slices"
	"strconv"
	"strings"
	"time"
)

// connectivityDiagnoser runs DNS/HTTP/HTTPS reachability checks against hosts
// listed in an allowlist file, using a dedicated diagnostic HTTP client.
type connectivityDiagnoser struct {
	diagnosticClient *http.Client
}

// newConnectivityDiagnoser builds a diagnoser whose client is derived from cfg.
func newConnectivityDiagnoser(cfg *Config) *connectivityDiagnoser {
	return &connectivityDiagnoser{
		diagnosticClient: createDiagnosticClient(cfg),
	}
}

// allowlistEntry is one host/port/type record from allowlist.json.
type allowlistEntry struct {
	Host string `json:"host"`
	Port int    `json:"port"`
	Type string `json:"type"`
}

// allowlist is the parsed contents of allowlist.json.
type allowlist struct {
	Entries []allowlistEntry
}

// acceptable HTTP status codes for connectivity diagnosis
// for the sake of connectivity, e.g.
// HTTP403 from AWS S3 is perfectly fine
// GCS bucket and Azure blob responds HTTP400 upon connecting with plain GET, its okay from connection standpoint
var connDiagAcceptableStatusCodes = []int{http.StatusOK, http.StatusForbidden, http.StatusBadRequest}

// map of already-fetched CRLs to not test them more than once as they can be quite large
// NOTE(review): accessed without synchronization — confirm diagnosis only runs single-threaded.
var connDiagTestedCrls = make(map[string]string)

// create a diagnostic client with the appropriate transport for the given config
func createDiagnosticClient(cfg *Config) *http.Client {
	transport := createDiagnosticTransport(cfg)
	clientTimeout := cfg.ClientTimeout
	if clientTimeout == 0 {
		clientTimeout = time.Duration(sfconfig.DefaultClientTimeout)
	}
	return &http.Client{
		Timeout:   clientTimeout,
		Transport: transport,
	}
}

// necessary to be able to log the IP address of the remote host to which we actually connected
// might be even different from the result of DNS resolution
func createDiagnosticDialContext() func(ctx context.Context, network, addr string) (net.Conn, error) {
	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	return func(ctx context.Context, network, addr string) (net.Conn, error) {
		conn, err := dialer.DialContext(ctx, network, addr)
		if err != nil {
			return nil, err
		}
		if remoteAddr := conn.RemoteAddr(); remoteAddr != nil {
			remoteIPStr := remoteAddr.String()
			// parse out just the IP (maybe port is present)
			if host, _, err := net.SplitHostPort(remoteIPStr); err == nil {
				remoteIPStr = host
			}
			// get hostname
			hostname, _, _ := net.SplitHostPort(addr)
			if hostname == "" {
				hostname = addr
			}
			logger.Infof("[createDiagnosticDialContext] Connected to %s (remote IP: %s)", hostname, remoteIPStr)
		}
		return conn, nil
	}
}

// enhance the transport with IP logging
func createDiagnosticTransport(cfg *Config) *http.Transport {
	baseTransport, err := newTransportFactory(cfg, &snowflakeTelemetry{enabled: false}).createTransport(transportConfigFor(transportTypeSnowflake))
	if err != nil {
		// NOTE(review): Fatalf terminates the whole process; consider propagating the error instead.
		logger.Fatalf("[createDiagnosticTransport] failed to get the transport from the config: %v", err)
	}
	if baseTransport == nil {
		logger.Fatal("[createDiagnosticTransport] transport from config is nil")
	}
	var httpTransport = baseTransport.(*http.Transport)
	// return a new transport enhanced with remote IP logging
	// for SnowflakeNoOcspTransport, TLSClientConfig is nil
	return &http.Transport{
		TLSClientConfig: httpTransport.TLSClientConfig,
		MaxIdleConns:    httpTransport.MaxIdleConns,
		IdleConnTimeout: httpTransport.IdleConnTimeout,
		Proxy:           httpTransport.Proxy,
		DialContext:     createDiagnosticDialContext(),
	}
}

// openAndReadAllowlistJSON reads and parses the allowlist JSON file; when
// filePath is empty it falls back to "allowlist.json" in the current directory.
func (cd *connectivityDiagnoser) openAndReadAllowlistJSON(filePath string) (allowlist allowlist, err error) {
	if filePath == "" {
		logger.Info("[openAndReadAllowlistJSON] allowlist.json location not specified, trying to load from current directory.")
		filePath = "allowlist.json"
	}
	logger.Infof("[openAndReadAllowlistJSON] reading allowlist from %s.", filePath)
	fileContent, err := os.ReadFile(filePath)
	if err != nil {
		return allowlist, err
	}
	logger.Debug("[openAndReadAllowlistJSON] parsing allowlist.json")
	err = json.Unmarshal(fileContent, &allowlist.Entries)
	return allowlist, err
}

// look up the host, using the local resolver
func (cd *connectivityDiagnoser) resolveHostname(hostname string) {
	ips, err := net.LookupIP(hostname)
	if err != nil {
		logger.Errorf("[resolveHostname] error resolving hostname %s: %s", hostname, err)
		return
	}
	for _, ip := range ips {
		logger.Infof("[resolveHostname] resolved hostname %s to %s", hostname, ip.String())
		if checkIsPrivateLink(hostname) && !ip.IsPrivate() {
			logger.Errorf("[resolveHostname] this hostname %s should resolve to a private IP, but %s is public IP. Please, check your DNS configuration.", hostname, ip.String())
		}
	}
}

// isAcceptableStatusCode reports whether statusCode is in acceptableCodes.
func (cd *connectivityDiagnoser) isAcceptableStatusCode(statusCode int, acceptableCodes []int) bool {
	return slices.Contains(acceptableCodes, statusCode)
}

// fetchCRL downloads and parses the CRL at uri, logging issuer/update/entry
// details; already-tested URIs (connDiagTestedCrls) are skipped.
func (cd *connectivityDiagnoser) fetchCRL(uri string) error {
	if _, ok := connDiagTestedCrls[uri]; ok {
		logger.Infof("[fetchCRL] CRL for %s already fetched and parsed.", uri)
		return nil
	}
	logger.Infof("[fetchCRL] fetching %s", uri)
	req, err := cd.createRequest(uri)
	if err != nil {
		logger.Errorf("[fetchCRL] error creating request: %v", err)
		return err
	}
	resp, err := cd.diagnosticClient.Do(req)
	if err != nil {
		return fmt.Errorf("[fetchCRL] HTTP GET to %s endpoint failed: %w", uri, err)
	}
	// if closing response body is unsuccessful for some reason
	defer func(Body io.ReadCloser) {
		err := Body.Close()
		if err != nil {
			logger.Errorf("[fetchCRL] Failed to close response body: %v", err)
			return
		}
	}(resp.Body)
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("[fetchCRL] HTTP response status from endpoint: %s", resp.Status)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("[fetchCRL] failed to read response body: %w", err)
	}
	logger.Infof("[fetchCRL] %s retrieved successfully (%d bytes)", uri, len(body))
	logger.Infof("[fetchCRL] Parsing CRL fetched from %s", uri)
	crl, err := x509.ParseRevocationList(body)
	if err != nil {
		return fmt.Errorf("[fetchCRL] Failed to parse CRL: %w", err)
	}
	logger.Infof(" CRL Issuer: %s", crl.Issuer)
	logger.Infof(" This Update: %s", crl.ThisUpdate)
	logger.Infof(" Next Update: %s", crl.NextUpdate)
	logger.Infof(" Revoked Certificates#: %s", strconv.Itoa(len(crl.RevokedCertificateEntries)))
	connDiagTestedCrls[uri] = ""
	return nil
}

// doHTTP performs a plain HTTP GET and validates the response status; for the
// OCSP cache host it appends /ocsp_response_cache.json to the URL first.
func (cd *connectivityDiagnoser) doHTTP(request *http.Request) error {
	if strings.HasPrefix(request.URL.Host, "ocsp.snowflakecomputing.") {
		fullOCSPCacheURI := request.URL.String() + "/ocsp_response_cache.json"
		newURL, err := url.Parse(fullOCSPCacheURI)
		if err != nil {
			return fmt.Errorf("failed to parse the full OCSP cache URL: %w", err)
		}
		request.URL = newURL
	}
	logger.Infof("[doHTTP] testing HTTP connection to %s", request.URL.String())
	resp, err := cd.diagnosticClient.Do(request)
	if err != nil {
		return fmt.Errorf("HTTP GET to %s endpoint failed: %w", request.URL.String(), err)
	}
	defer func(Body io.ReadCloser) {
		err := Body.Close()
		if err != nil {
			logger.Errorf("[doHTTP] Failed to close response body: %v", err)
			return
		}
	}(resp.Body)
	if !cd.isAcceptableStatusCode(resp.StatusCode, connDiagAcceptableStatusCodes) {
		return fmt.Errorf("HTTP response status from %s endpoint: %s", request.URL.String(), resp.Status)
	}
	logger.Infof("[doHTTP] Successfully connected to %s, HTTP response status: %s", request.URL.String(), resp.Status)
	return nil
}

// doHTTPSGetCerts performs an HTTPS GET, logs the server certificate chain
// (and optionally downloads each certificate's CRLs when downloadCRLs is set).
func (cd *connectivityDiagnoser) doHTTPSGetCerts(request *http.Request, downloadCRLs bool) error {
	logger.Infof("[doHTTPSGetCerts] connecting to %s", request.URL.String())
	resp, err := cd.diagnosticClient.Do(request)
	if err != nil {
		return fmt.Errorf("failed to connect: %w", err)
	}
	defer func(Body io.ReadCloser) {
		err := Body.Close()
		if err != nil {
			logger.Errorf("[doHTTPSGetCerts] Failed to close response body: %v", err)
			return
		}
	}(resp.Body)
	if !cd.isAcceptableStatusCode(resp.StatusCode, connDiagAcceptableStatusCodes) {
		return fmt.Errorf("HTTP response status from %s endpoint: %s", request.URL.String(), resp.Status)
	}
	logger.Infof("[doHTTPSGetCerts] Successfully connected to %s, HTTP response status: %s", request.URL.String(), resp.Status)
	logger.Debug("[doHTTPSGetCerts] getting TLS connection state")
	tlsState := resp.TLS
	if tlsState == nil {
		return errors.New("no TLS connection state available")
	}
	logger.Debug("[doHTTPSGetCerts] getting certificate chain")
	certs := tlsState.PeerCertificates
	logger.Infof("[doHTTPSGetCerts] Retrieved %d certificate(s).", len(certs))
	// log individual cert details
	for i, cert := range certs {
		logger.Infof("[doHTTPSGetCerts] Certificate %d, serial number: %x", i+1, cert.SerialNumber)
		logger.Infof("[doHTTPSGetCerts] Subject: %s", cert.Subject)
		logger.Infof("[doHTTPSGetCerts] Issuer: %s", cert.Issuer)
		logger.Infof("[doHTTPSGetCerts] Valid: %s to %s", cert.NotBefore, cert.NotAfter)
		logger.Infof("[doHTTPSGetCerts] For further details, check https://crt.sh/?serial=%x (non-Snowflake site)", cert.SerialNumber)
		// if cert has CRL endpoint, log them too
		if len(cert.CRLDistributionPoints) > 0 {
			logger.Infof("[doHTTPSGetCerts] CRL Distribution Points:")
			for _, dp := range cert.CRLDistributionPoints {
				logger.Infof("[doHTTPSGetCerts] - %s", dp)
				// only try to download the actual CRL if configured to do so
				if downloadCRLs {
					if err := cd.fetchCRL(dp); err != nil {
						logger.Errorf("[doHTTPSGetCerts] Failed to fetch or parse CRL: %v", err)
					}
				}
			}
		} else {
			logger.Infof("[doHTTPSGetCerts] CRL Distribution Points not included in the certificate.")
		}
		// dump the full PEM data too on DEBUG loglevel
		pemData := pem.EncodeToMemory(&pem.Block{
			Type:  "CERTIFICATE",
			Bytes: cert.Raw,
		})
		logger.Debug("[doHTTPSGetCerts] certificate PEM:")
		logger.Debug(string(pemData))
	}
	return nil
}

// createRequest builds a GET request for uri.
func (cd *connectivityDiagnoser) createRequest(uri string) (*http.Request, error) {
	logger.Infof("[createRequest] creating GET request to %s", uri)
	req, err := http.NewRequest("GET", uri, nil)
	if err != nil {
		return nil, err
	}
	return req, nil
}

// checkProxy logs the proxy (if any) that the diagnostic transport would use
// for the given request.
func (cd *connectivityDiagnoser) checkProxy(req *http.Request) {
	// NOTE(review): a failed type assertion here panics before the nil check
	// below runs — confirm Transport is always *http.Transport at this point.
	diagnosticTransport := cd.diagnosticClient.Transport.(*http.Transport)
	if diagnosticTransport == nil {
		logger.Errorf("[checkProxy] diagnosticTransport is nil")
		return
	}
	if diagnosticTransport.Proxy == nil {
		// no proxy configured, nothing to log
		return
	}
	p, err := diagnosticTransport.Proxy(req)
	if err != nil {
		logger.Errorf("[checkProxy] problem determining PROXY: %v", err)
	}
	if p != nil {
		logger.Infof("[checkProxy] PROXY detected in the connection: %v", p)
	}
}

// performConnectivityCheck runs an HTTP (port 80) or HTTPS (port 443) check
// against host, logging errors; other ports are rejected.
func (cd *connectivityDiagnoser) performConnectivityCheck(entryType, host string, port int, downloadCRLs bool) (err error) {
	var protocol string
	var req *http.Request
	switch port {
	case 80:
		protocol = "http"
	case 443:
		protocol = "https"
	default:
		return fmt.Errorf("[performConnectivityCheck] unsupported port: %d", port)
	}
	logger.Infof("[performConnectivityCheck] %s check for %s %s", strings.ToUpper(protocol), entryType, host)
	req, err = cd.createRequest(fmt.Sprintf("%s://%s", protocol, host))
	if err != nil {
		logger.Errorf("[performConnectivityCheck] error creating request: %v", err)
		return err
	}
	cd.checkProxy(req)
	switch protocol {
	case "http":
		err = cd.doHTTP(req)
	case "https":
		err = cd.doHTTPSGetCerts(req, downloadCRLs)
	}
	if err != nil {
		logger.Errorf("[performConnectivityCheck] error performing %s check: %v", strings.ToUpper(protocol), err)
		return err
	}
	return nil
}

// performDiagnosis drives the whole connectivity diagnosis: it loads the
// allowlist, resolves each host via DNS, and runs HTTP/HTTPS checks for
// entries on ports 80/443. All failures are logged and iteration continues.
func performDiagnosis(cfg *Config, downloadCRLs bool) {
	allowlistFile := cfg.ConnectionDiagnosticsAllowlistFile
	logger.Info("[performDiagnosis] starting connectivity diagnosis based on allowlist file.")
	if downloadCRLs {
		logger.Info("[performDiagnosis] CRLs will be attempted to be downloaded and parsed during https tests.")
	}
	diag := newConnectivityDiagnoser(cfg)
	allowlist, err := diag.openAndReadAllowlistJSON(allowlistFile)
	if err != nil {
		logger.Errorf("[performDiagnosis] error opening and parsing allowlist file: %v", err)
		return
	}
	for _, entry := range allowlist.Entries {
		host := entry.Host
		port := entry.Port
		entryType := entry.Type
		logger.Infof("[performDiagnosis] DNS check - resolving %s hostname %s", entryType, host)
		diag.resolveHostname(host)
		if port == 80 || port == 443 {
			err := diag.performConnectivityCheck(entryType, host, port, downloadCRLs)
			if err != nil {
				continue
			}
		}
	}
}

================================================ FILE: connectivity_diagnosis_test.go ================================================

package gosnowflake

import (
	"bytes"
	"context"
	"crypto/tls"
	"encoding/pem"
	"fmt"
	sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"strings"
"testing" "time" ) /* * for the tests, we need to capture log output and assert on their content * this is done by creating a fresh logger to log into a buffer and look at that buffer * but we also need to preserve the original global logger to not modify that with tests * and restore original logger after the tests */ func setupTestLogger() (buffer *bytes.Buffer, cleanup func()) { originalLogger := logger testLogger := CreateDefaultLogger() // from log.go buffer = &bytes.Buffer{} testLogger.SetOutput(buffer) _ = testLogger.SetLogLevel("INFO") logger = testLogger cleanup = func() { logger = originalLogger } return buffer, cleanup } func TestSetupTestLogger(t *testing.T) { // save original global logger originalLogger := logger // and restore it after test defer func() { logger = originalLogger }() buffer, cleanup := setupTestLogger() assertNotNilE(t, buffer, "buffer should not be nil") assertNotNilE(t, cleanup, "cleanup function should not be nil") // the test message should be in the buffer testMessage := "test log message for setupTestLogger" logger.Info(testMessage) logOutput := buffer.String() assertStringContainsE(t, logOutput, testMessage, "buffer should capture log output") // now cleanup cleanup() assertEqualE(t, logger, originalLogger, "cleanup should restore original logger") // clear the buffer, log a new message into it // logs should not go to the test logger anymore buffer.Reset() logger.Info("this should not appear in test buffer") assertEqualE(t, buffer.String(), "", "buffer should be empty after cleanup") } // test case types type tcDiagnosticClient struct { name string config *Config expectedTimeout time.Duration } type tcOpenAllowlistJSON struct { name string setup func() (string, func()) shouldError bool expectedLength int } type tcAcceptableStatusCode struct { statusCode int isAcceptable bool } type tcFetchCRL struct { name string setupServer func() *httptest.Server shouldError bool errorContains string } type tcCreateRequest struct { name 
string uri string shouldError bool } type tcDoHTTP struct { name string setupServer func() *httptest.Server setupRequest func(serverURL string) *http.Request shouldError bool errorContains string } type tcDoHTTPSGetCerts struct { name string setupServer func() *httptest.Server downloadCRLs bool shouldError bool errorContains string } type tcResolveHostname struct { name string hostname string } type tcPerformConnectivityCheck struct { name string entryType string host string port int downloadCRLs bool expectedLog string } func TestCreateDiagnosticClient(t *testing.T) { testcases := []tcDiagnosticClient{ { name: "Diagnostic Client with default timeout", config: &Config{ ClientTimeout: 0, }, expectedTimeout: sfconfig.DefaultClientTimeout, }, { name: "Diagnostic Client with custom timeout", config: &Config{ ClientTimeout: 60 * time.Second, }, expectedTimeout: 60 * time.Second, }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { client := createDiagnosticClient(tc.config) assertNotNilE(t, client, "client should not be nil") assertEqualE(t, client.Timeout, tc.expectedTimeout, "timeout did not match") assertNotNilE(t, client.Transport, "transport should not be nil") }) } } func TestCreateDiagnosticDialContext(t *testing.T) { dialContext := createDiagnosticDialContext() assertNotNilE(t, dialContext, "dialContext should not be nil") // new simple server to test basic connectivity server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) defer server.Close() u, _ := url.Parse(server.URL) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() _, err := dialContext(ctx, "tcp", u.Host) assertNilE(t, err, "error should be nil") } func TestOpenAndReadAllowlistJSON(t *testing.T) { var diagTest connectivityDiagnoser testcases := []tcOpenAllowlistJSON{ { name: "Open and Read Allowlist - valid file path, 2 entries", // create a temp allowlist file and then 
delete it setup: func() (filePath string, cleanup func()) { content := `[{"host":"myaccount.snowflakecomputing.com","port":443,"type":"SNOWFLAKE_DEPLOYMENT"},{"host":"ocsp.snowflakecomputing.com","port":80,"type":"OCSP_CACHE"}]` tmpFile, err := os.CreateTemp("", "allowlist_*.json") assertNilF(t, err, "Error during creating temp allowlist file.") _, err = tmpFile.WriteString(content) assertNilF(t, err, "Error during writing temp allowlist file.") tmpFile.Close() return tmpFile.Name(), func() { os.Remove(tmpFile.Name()) } }, shouldError: false, expectedLength: 2, }, { name: "Open and Read Allowlist - empty file path", setup: func() (filePath string, cleanup func()) { content := `[{"host":"myaccount.snowflakecomputing.com","port":443,"type":"SNOWFLAKE_DEPLOYMENT"}]` _ = os.WriteFile("allowlist.json", []byte(content), 0644) return "", func() { os.Remove("allowlist.json") } }, shouldError: false, expectedLength: 1, }, { name: "Open and Read Allowlist - non existent file", setup: func() (filePath string, cleanup func()) { return "/non/existent/file.json", func() {} }, shouldError: true, expectedLength: 0, }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { filePath, cleanup := tc.setup() defer cleanup() allowlist, err := diagTest.openAndReadAllowlistJSON(filePath) if tc.shouldError { assertNotNilE(t, err, "error should not be nil") } else { assertNilE(t, err, "error should be nil") assertNotNilE(t, allowlist, "file content should not be nil") assertEqualE(t, len(allowlist.Entries), tc.expectedLength, "allowlist length did not match") } }) } } func TestIsAcceptableStatusCode(t *testing.T) { var diagTest connectivityDiagnoser acceptableCodes := []int{http.StatusOK, http.StatusForbidden, http.StatusBadRequest} testcases := []tcAcceptableStatusCode{ {http.StatusOK, true}, {http.StatusForbidden, true}, {http.StatusInternalServerError, false}, {http.StatusUnauthorized, false}, {http.StatusBadRequest, true}, } for _, tc := range testcases { 
t.Run(fmt.Sprintf("Is Acceptable Status Code - status %d", tc.statusCode), func(t *testing.T) { result := diagTest.isAcceptableStatusCode(tc.statusCode, acceptableCodes) assertEqualE(t, result, tc.isAcceptable, "http status code acceptance is wrong") }) } } func TestFetchCRL(t *testing.T) { config := &Config{ ClientTimeout: 30 * time.Second, } diagTest := newConnectivityDiagnoser(config) crlPEM := `-----BEGIN X509 CRL----- MIIBuDCBoQIBATANBgkqhkiG9w0BAQsFADBeMQswCQYDVQQGEwJVUzELMAkGA1UE CAwCQ0ExDTALBgNVBAcMBFRlc3QxEDAOBgNVBAoMB0V4YW1wbGUxDzANBgNVBAsM BlRlc3RDQTEQMA4GA1UEAwwHVGVzdCBDQRcNMjUwNzI1MTYyMTQzWhcNMzMxMDEx MTYyMTQzWqAPMA0wCwYDVR0UBAQCAhAAMA0GCSqGSIb3DQEBCwUAA4IBAQCakfe4 yaYe6jhSVZc177/y7a+qV6Vs8Ly+CwQiYCM/LieEI7coUpcMtF43ShfzG5FawrMI xa3L2ew5EHDPelrMAdc56GzGCZFlOp16++3HS8qUpodctMdWWcR9Jn0OAfR1I3cY KtLfQbYqwr+Ti6LT0SDp8kXhltH8ZfUcDWH779WF1IQatu5J+GoyHnfFCsP9gI3H Aacyfk7Pp7MyAUChvbM6miyUbWm5NLW9nZgmMxqi9VpMnGZSCwqpS9J01k8YAbwS S3HAS4o7ePBmhiERTPjqmwqEUdrKzEYMtdCFHHfnnDSZxdAmb+Ep6WjRgU1AHxak 6aJpJF0+Ic2kaXXI -----END X509 CRL-----` block, _ := pem.Decode([]byte(crlPEM)) testcases := []tcFetchCRL{ { name: "Fetch CRL - successful fetch", setupServer: func() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write(block.Bytes) })) }, shouldError: false, }, { name: "Fetch CRL - server returns 404", setupServer: func() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) })) }, shouldError: true, errorContains: "HTTP response status", }, { name: "Fetch CRL - server returns 500", setupServer: func() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) }, shouldError: true, errorContains: "HTTP response status", }, } for _, tc := range testcases { t.Run(tc.name, func(t 
*testing.T) {
			server := tc.setupServer()
			defer server.Close()
			err := diagTest.fetchCRL(server.URL)
			if tc.shouldError {
				assertNotNilE(t, err, "error should not be nil")
				if tc.errorContains != "" {
					assertStringContainsE(t, err.Error(), tc.errorContains, "error message should contain the expected string")
				}
			} else {
				assertNilE(t, err, "error should be nil")
			}
		})
	}
}

// TestCreateRequest verifies that createRequest builds a GET request for valid
// http/https URIs and returns an error for a malformed URI.
func TestCreateRequest(t *testing.T) {
	var diagTest connectivityDiagnoser
	testcases := []tcCreateRequest{
		{
			name:        "Create Request - valid http url",
			uri:         "http://ocsp.snowflakecomputing.com",
			shouldError: false,
		},
		{
			name:        "Create Request - valid https url",
			uri:         "https://myaccount.snowflakecomputing.com",
			shouldError: false,
		},
		{
			name:        "Create Request - invalid url",
			uri:         ":/invalid-url",
			shouldError: true,
		},
	}
	for _, tc := range testcases {
		t.Run(tc.name, func(t *testing.T) {
			req, err := diagTest.createRequest(tc.uri)
			if tc.shouldError {
				assertNotNilE(t, err, "error should not be nil")
			} else {
				// On success the request must be a GET for exactly the requested URI.
				assertNilE(t, err, "error should be nil")
				assertNotNilE(t, req, "request should not be nil")
				assertEqualE(t, req.Method, "GET", "method should be GET")
				assertEqualE(t, req.URL.String(), tc.uri, "url should match")
			}
		})
	}
}

// TestDoHTTP exercises doHTTP against disposable local test servers, covering
// plain success, the OCSP cache host URL rewriting, and acceptable vs.
// unacceptable HTTP status codes.
func TestDoHTTP(t *testing.T) {
	var diagTest connectivityDiagnoser
	testcases := []tcDoHTTP{
		// simple disposable server to test basic connectivity
		{
			name: "Do HTTP - successful http request",
			setupServer: func() *httptest.Server {
				return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusOK)
				}))
			},
			setupRequest: func(serverURL string) *http.Request {
				req, _ := http.NewRequest("GET", serverURL, nil)
				return req
			},
			shouldError: false,
		},
		{
			name: "Do HTTP - ocsp.snowflakecomputing.com url modification",
			setupServer: func() *httptest.Server {
				return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					// doHTTP should automatically add ocsp_response_cache.json to the full url
					assertStringContainsE(t, r.URL.Path,
"ocsp_response_cache.json", "url path should contain ocsp_response_cache.json added") w.WriteHeader(http.StatusOK) })) }, setupRequest: func(serverURL string) *http.Request { req, _ := http.NewRequest("GET", serverURL, nil) req.URL.Host = "ocsp.snowflakecomputing.com" return req }, shouldError: false, }, { name: "Do HTTP - (CHINA) ocsp.snowflakecomputing.cn url modification", setupServer: func() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assertStringContainsE(t, r.URL.Path, "ocsp_response_cache.json", "url path should contain ocsp_response_cache.json added") w.WriteHeader(http.StatusOK) })) }, setupRequest: func(serverURL string) *http.Request { req, _ := http.NewRequest("GET", serverURL, nil) req.URL.Host = "ocsp.snowflakecomputing.cn" return req }, // http://ocsp.snowflakecomputing.cn/ocsp_response_cache.json throws HTTP404 shouldError: true, }, { name: "Do HTTP - server returns forbidden, acceptable", setupServer: func() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusForbidden) })) }, setupRequest: func(serverURL string) *http.Request { req, _ := http.NewRequest("GET", serverURL, nil) return req }, shouldError: false, }, { name: "Do HTTP - server returns internal server error, not acceptable", setupServer: func() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) }, setupRequest: func(serverURL string) *http.Request { req, _ := http.NewRequest("GET", serverURL, nil) return req }, shouldError: true, errorContains: "HTTP response status", }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { server := tc.setupServer() defer server.Close() // modify the diagnostic client to use a shorter timeout diagTest.diagnosticClient = &http.Client{Timeout: 10 * time.Second} req := 
tc.setupRequest(server.URL) err := diagTest.doHTTP(req) if tc.shouldError { assertNotNilE(t, err, "error should not be nil") if tc.errorContains != "" { assertStringContainsE(t, err.Error(), tc.errorContains, "error message should contain the expected string") } } else { assertNilE(t, err, "error should be nil") } }) } } func TestDoHTTPSGetCerts(t *testing.T) { var diagTest connectivityDiagnoser testcases := []tcDoHTTPSGetCerts{ // simple disposable server with TLS to test basic connectivity { name: "Do HTTPS - successful https request", setupServer: func() *httptest.Server { return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) }, downloadCRLs: false, shouldError: false, }, { name: "Do HTTPS - server returns forbidden, acceptable", setupServer: func() *httptest.Server { return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusForbidden) })) }, downloadCRLs: false, shouldError: false, }, { name: "Do HTTPS - server returns internal server error, not acceptable", setupServer: func() *httptest.Server { return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) }, downloadCRLs: false, shouldError: true, errorContains: "HTTP response status", }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { server := tc.setupServer() defer server.Close() // modify the diagnostic client to use a shorter timeout // and to ignore the server's certificate diagTest.diagnosticClient = &http.Client{ Timeout: 10 * time.Second, Transport: &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, }, } req, _ := http.NewRequest("GET", server.URL, nil) err := diagTest.doHTTPSGetCerts(req, tc.downloadCRLs) if tc.shouldError { assertNotNilE(t, err, "error should not be nil") if tc.errorContains != "" { assertStringContainsE(t, err.Error(), 
tc.errorContains, "error message should contain the expected string") } } else { assertNilE(t, err, "error should be nil") } }) } } func TestCheckProxy(t *testing.T) { config := &Config{ ClientTimeout: 30 * time.Second, } diagTest := newConnectivityDiagnoser(config) t.Run("Check Proxy - with proxy configured", func(t *testing.T) { // setup test logger then restore original after test buffer, cleanup := setupTestLogger() defer cleanup() // set up transport with proxy proxyURL, _ := url.Parse("http://my.pro.xy:8080") diagTest.diagnosticClient.Transport = &http.Transport{ Proxy: func(req *http.Request) (*url.URL, error) { return proxyURL, nil }, } // this should generate a log output which indicates we use a proxy req, _ := http.NewRequest("GET", "https://myaccount.snowflakecomputing.com", nil) diagTest.checkProxy(req) logOutput := buffer.String() assertStringContainsE(t, logOutput, "[checkProxy] PROXY detected in the connection:", "log should contain proxy detection message") assertStringContainsE(t, logOutput, "http://my.pro.xy:8080", "log should contain the proxy URL") }) t.Run("Check Proxy - no proxy configured", func(t *testing.T) { // setup test logger then restore original after test buffer, cleanup := setupTestLogger() defer cleanup() // set up transport without proxy diagTest.diagnosticClient.Transport = &http.Transport{ Proxy: nil, } req, _ := http.NewRequest("GET", "https://myaccount.snowflakecomputing.com", nil) diagTest.checkProxy(req) // verify log output does NOT contain proxy detection logOutput := buffer.String() if strings.Contains(logOutput, "[checkProxy] PROXY detected") { t.Errorf("log should not contain proxy detection message when no proxy is configured, but got: %s", logOutput) } }) t.Run("Check Proxy - proxy function returns error", func(t *testing.T) { // setup test logger then restore original after test buffer, cleanup := setupTestLogger() defer cleanup() // deliberately return an error from the proxy function 
diagTest.diagnosticClient.Transport = &http.Transport{ Proxy: func(req *http.Request) (*url.URL, error) { return nil, fmt.Errorf("proxy configuration error") }, } req, _ := http.NewRequest("GET", "https://myaccount.snowflakecomputing.com", nil) diagTest.checkProxy(req) // verify log output contains error message logOutput := buffer.String() assertStringContainsE(t, logOutput, "[checkProxy] problem determining PROXY:", "log should contain proxy error message") assertStringContainsE(t, logOutput, "proxy configuration error", "log should contain the specific error") }) } func TestResolveHostname(t *testing.T) { var diagTest connectivityDiagnoser testcases := []tcResolveHostname{ { name: "Resolve Hostname - valid hostname myaccount.snowflakecomputing.com", hostname: "myaccount.snowflakecomputing.com", }, { name: "Resolve Hostname - invalid hostname", hostname: "this.is.invalid", }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { // setup test logger then restore original after test buffer, cleanup := setupTestLogger() defer cleanup() diagTest.resolveHostname(tc.hostname) logOutput := buffer.String() // check for expected log patterns based on hostname if tc.hostname == "this.is.invalid" { assertStringContainsE(t, logOutput, "[resolveHostname] error resolving hostname", "should contain error message for invalid hostname") assertStringContainsE(t, logOutput, tc.hostname, "should contain the hostname in error message") } else { // expect success message assertStringContainsE(t, logOutput, "[resolveHostname] resolved hostname", "should contain success message for valid hostname") assertStringContainsE(t, logOutput, tc.hostname, "should contain the hostname in success message") } }) } } func TestPerformConnectivityCheck(t *testing.T) { config := &Config{ ClientTimeout: 30 * time.Second, } diagTest := newConnectivityDiagnoser(config) testcases := []tcPerformConnectivityCheck{ { name: "HTTP check for port 80", entryType: "OCSP_CACHE", host: 
"ocsp.snowflakecomputing.com", port: 80, downloadCRLs: false, expectedLog: "[performConnectivityCheck] HTTP check", }, { name: "HTTPS check for port 443", entryType: "DUMMY_SNOWFLAKE_DEPLOYMENT", host: "www.snowflake.com", port: 443, downloadCRLs: false, expectedLog: "[performConnectivityCheck] HTTPS check", }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { // setup test logger then restore original after test buffer, cleanup := setupTestLogger() defer cleanup() err := diagTest.performConnectivityCheck(tc.entryType, tc.host, tc.port, tc.downloadCRLs) logOutput := buffer.String() // verify expected log message appears assertStringContainsE(t, logOutput, tc.expectedLog, fmt.Sprintf("should contain '%s' log message", tc.expectedLog)) assertStringContainsE(t, logOutput, tc.entryType, "should contain entry type in log") assertStringContainsE(t, logOutput, tc.host, "should contain host in log") // if error occurred, verify error log format if err != nil { assertStringContainsE(t, logOutput, "[performConnectivityCheck] error performing", "should contain error log message") } }) } } func TestPerformDiagnosis(t *testing.T) { t.Run("Perform Diagnosis - CRL download disabled", func(t *testing.T) { // setup test logger then restore original after test buffer, cleanup := setupTestLogger() defer cleanup() allowlistContent := `[ {"host":"ocsp.snowflakecomputing.com","port":80,"type":"OCSP_CACHE"}, {"host":"www.snowflake.com","port":443,"type":"DUMMY_SNOWFLAKE_DEPLOYMENT"} ]` tmpFile, err := os.CreateTemp("", "test_allowlist_*.json") assertNilE(t, err, "failed to create temp allowlist file") defer os.Remove(tmpFile.Name()) _, _ = tmpFile.WriteString(allowlistContent) tmpFile.Close() config := &Config{ ConnectionDiagnosticsAllowlistFile: tmpFile.Name(), ClientTimeout: 30 * time.Second, } // perform the diagnosis without downloading CRL performDiagnosis(config, false) // verify expected log messages from performDiagnosis and underlying functions logOutput := 
buffer.String() assertStringContainsE(t, logOutput, "[performDiagnosis] starting connectivity diagnosis", "should contain diagnosis start message") // DNS resolution assertStringContainsE(t, logOutput, "[performDiagnosis] DNS check - resolving OCSP_CACHE hostname ocsp.snowflakecomputing.com", "should contain DNS check for OCSP cache") assertStringContainsE(t, logOutput, "[performDiagnosis] DNS check - resolving DUMMY_SNOWFLAKE_DEPLOYMENT hostname www.snowflake.com", "should contain DNS check for Snowflake host") assertStringContainsE(t, logOutput, "[resolveHostname] resolved hostname", "should contain hostname resolution results") // HTTP check assertStringContainsE(t, logOutput, "[performConnectivityCheck] HTTP check for OCSP_CACHE ocsp.snowflakecomputing.com", "should contain HTTP check message") assertStringContainsE(t, logOutput, "[createRequest] creating GET request to http://ocsp.snowflakecomputing.com", "should contain request creation log") assertStringContainsE(t, logOutput, "[doHTTP] testing HTTP connection to", "should contain HTTP connection test log") // HTTPS check assertStringContainsE(t, logOutput, "[performConnectivityCheck] HTTPS check for DUMMY_SNOWFLAKE_DEPLOYMENT www.snowflake.com", "should contain HTTPS check message") assertStringContainsE(t, logOutput, "[createRequest] creating GET request to https://www.snowflake.com", "should contain HTTPS request creation log") assertStringContainsE(t, logOutput, "[doHTTPSGetCerts] connecting to https://www.snowflake.com", "should contain HTTPS connection log") // diagnostic dial context assertStringContainsE(t, logOutput, "[createDiagnosticDialContext] Connected to", "should contain dial context connection logs") assertStringContainsE(t, logOutput, "remote IP:", "should contain remote IP information") // should NOT contain CRL download messages when disabled if strings.Contains(logOutput, "[performDiagnosis] CRLs will be attempted to be downloaded") { t.Errorf("should not contain CRL download message 
when disabled, but got: %s", logOutput) } }) t.Run("Perform Diagnosis - CRL download enabled", func(t *testing.T) { // setup test logger then restore original after test buffer, cleanup := setupTestLogger() defer cleanup() // Create a temporary allowlist file with HTTPS entries to trigger CRL download attempts allowlistContent := `[ {"host":"ocsp.snowflakecomputing.com","port":80,"type":"OCSP_CACHE"}, {"host":"www.snowflake.com","port":443,"type":"DUMMY_SNOWFLAKE_DEPLOYMENT"} ]` tmpFile, err := os.CreateTemp("", "test_allowlist_*.json") assertNilE(t, err, "failed to create temp allowlist file") defer os.Remove(tmpFile.Name()) _, err = tmpFile.WriteString(allowlistContent) assertNilF(t, err, "Failed to write temp allowlist.json.") tmpFile.Close() config := &Config{ ConnectionDiagnosticsAllowlistFile: tmpFile.Name(), CertRevocationCheckMode: CertRevocationCheckAdvisory, ClientTimeout: 30 * time.Second, DisableOCSPChecks: true, } downloadCRLs := config.CertRevocationCheckMode.String() == "ADVISORY" // driver should download CRLs due to ADVISORY CRL mode // Note that there's a log.Fatalf in performDiagnosis that may cause the test to fail. 
performDiagnosis(config, downloadCRLs) // verify expected log messages including CRL download logOutput := buffer.String() assertStringContainsE(t, logOutput, "[performDiagnosis] starting connectivity diagnosis", "should contain diagnosis start message") assertStringContainsE(t, logOutput, "[performDiagnosis] CRLs will be attempted to be downloaded and parsed during https tests", "should contain CRL download enabled message") // DNS resolution assertStringContainsE(t, logOutput, "[performDiagnosis] DNS check - resolving OCSP_CACHE hostname ocsp.snowflakecomputing.com", "should contain DNS check for OCSP cache") assertStringContainsE(t, logOutput, "[performDiagnosis] DNS check - resolving DUMMY_SNOWFLAKE_DEPLOYMENT hostname www.snowflake.com", "should contain DNS check for Snowflake host") assertStringContainsE(t, logOutput, "[resolveHostname] resolved hostname", "should contain hostname resolution results") // HTTPS check assertStringContainsE(t, logOutput, "[performConnectivityCheck] HTTPS check for DUMMY_SNOWFLAKE_DEPLOYMENT www.snowflake.com", "should contain HTTPS check message") assertStringContainsE(t, logOutput, "[doHTTPSGetCerts] connecting to https://www.snowflake.com", "should contain HTTPS connection log") assertStringContainsE(t, logOutput, "[doHTTPSGetCerts] Retrieved", "should contain certificate retrieval log") assertStringContainsE(t, logOutput, "certificate(s)", "should contain certificate count information") // diagnostic dial context assertStringContainsE(t, logOutput, "[createDiagnosticDialContext] Connected to", "should contain dial context connection logs") assertStringContainsE(t, logOutput, "remote IP:", "should contain remote IP information") // CRL download // if certificate has CRLDistributionPoints this message is logged if strings.Contains(logOutput, "CRL Distribution Points:") { // and we should see CRL fetch attempts logged. 
// we don't care if it's successful or not, we just want to see the log
			assertStringContainsE(t, logOutput, "[fetchCRL] fetching", "should contain CRL fetch attempt log")
		}
	})
}

================================================
FILE: connector.go
================================================

package gosnowflake

import (
	"context"
	"database/sql/driver"

	sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config"
)

// InternalSnowflakeDriver is the interface for an internal Snowflake driver
// Deprecated: will be removed in a future release.
type InternalSnowflakeDriver interface {
	Open(dsn string) (driver.Conn, error)
	OpenWithConfig(ctx context.Context, config Config) (driver.Conn, error)
}

// Connector creates Driver with the specified Config
type Connector struct {
	driver InternalSnowflakeDriver
	cfg    Config
}

// NewConnector creates a new connector with the given SnowflakeDriver and Config.
func NewConnector(driver InternalSnowflakeDriver, config Config) driver.Connector {
	return Connector{driver, config}
}

// Connect creates a new connection.
func (t Connector) Connect(ctx context.Context) (driver.Conn, error) {
	// Fill defaults into a local copy so the Connector's stored cfg is never mutated.
	cfg := t.cfg
	err := sfconfig.FillMissingConfigParameters(&cfg)
	if err != nil {
		return nil, err
	}
	return t.driver.OpenWithConfig(ctx, cfg)
}

// Driver creates a new driver.
func (t Connector) Driver() driver.Driver { return t.driver } ================================================ FILE: connector_test.go ================================================ package gosnowflake import ( "bytes" "context" "database/sql/driver" sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config" "github.com/snowflakedb/gosnowflake/v2/internal/errors" "reflect" "strings" "testing" "time" ) type noopTestDriver struct { config Config conn *snowflakeConn } func (d *noopTestDriver) Open(_ string) (driver.Conn, error) { return nil, nil } func (d *noopTestDriver) OpenWithConfig(_ context.Context, config Config) (driver.Conn, error) { d.config = config return d.conn, nil } func TestConnector(t *testing.T) { conn := snowflakeConn{} mock := noopTestDriver{conn: &conn} // Use fake DSN for unit test - should not make real connections fakeDSN := "testuser:testpass@testaccount.snowflakecomputing.com:443/testdb/testschema?warehouse=testwh&role=testrole" config, err := ParseDSN(fakeDSN) if err != nil { t.Fatal("Failed to parse dsn") } config.Authenticator = AuthTypeSnowflake config.PrivateKey = nil connector := NewConnector(&mock, *config) connection, err := connector.Connect(context.Background()) if err != nil { t.Fatalf("Connect error %s", err) } if connection != &conn { t.Fatalf("Connection mismatch %s", connection) } assertNilF(t, sfconfig.FillMissingConfigParameters(config)) if reflect.DeepEqual(config, mock.config) { t.Fatalf("Config should be equal, expected %v, actual %v", config, mock.config) } if connector.Driver() == nil { t.Fatalf("Missing driver") } } func TestConnectorWithMissingConfig(t *testing.T) { conn := snowflakeConn{} mock := noopTestDriver{conn: &conn} config := Config{ User: "u", Password: "p", Account: "", } expectedErr := errors.ErrEmptyAccount() connector := NewConnector(&mock, config) _, err := connector.Connect(context.Background()) assertNotNilF(t, err, "the connection should have failed due to empty account.") driverErr, ok := 
err.(*SnowflakeError)
	assertTrueF(t, ok, "should be a SnowflakeError")
	assertEqualE(t, driverErr.Number, expectedErr.Number)
	assertEqualE(t, driverErr.Message, expectedErr.Message)
}

// TestConnectorCancelContext verifies that a connection does not retain the
// context it was created with: cancelling that context before Close must not
// produce a "context canceled" message in the logs.
func TestConnectorCancelContext(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	origLogger := GetLogger()
	// Create a test logger with buffer for capturing log output
	testLogger := CreateDefaultLogger()
	// Create a buffer for capturing log output
	var buf bytes.Buffer
	testLogger.SetOutput(&buf)
	SetLogger(testLogger)
	// Restore default logger after the test completes
	defer func() {
		// Recreate the default logger instead of trying to restore a proxy
		SetLogger(origLogger)
	}()
	// pass in our context which should only be used for establishing the initial connection; not persisted.
	sfConn, err := buildSnowflakeConn(ctx, Config{
		Params:        make(map[string]*string),
		Authenticator: AuthTypeSnowflake, // Force password authentication
		PrivateKey:    nil,               // Ensure no private key
	})
	assertNilF(t, err)
	// patch close handler
	sfConn.rest = &snowflakeRestful{
		FuncCloseSession: func(ctx context.Context, sr *snowflakeRestful, d time.Duration) error {
			return ctx.Err()
		},
	}
	// cancel context BEFORE closing the connection.
	// this may occur if the *snowflakeConn was spawned by a QueryContext(), and the query has completed.
	cancel()
	assertNilF(t, sfConn.Close())
	// if the following log is emitted, the connection is holding onto context that it shouldn't be.
	assertFalseF(t, strings.Contains(buf.String(), "context canceled"))
}

================================================
FILE: converter.go
================================================

package gosnowflake

import (
	"bytes"
	"context"
	"database/sql"
	"database/sql/driver"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	errors2 "github.com/snowflakedb/gosnowflake/v2/internal/errors"
	"github.com/snowflakedb/gosnowflake/v2/internal/query"
	"github.com/snowflakedb/gosnowflake/v2/internal/types"
	"math"
	"math/big"
	"reflect"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/decimal128"
	ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"
)

// format is a Go reference-time layout with nanosecond precision.
const format = "2006-01-02 15:04:05.999999999"
const numberDefaultPrecision = 38
const jsonFormatStr = "json"
const numberMaxPrecisionInBits = 127

// 38 (max precision) + 1 (for possible '-') + 1 (for possible '.')
const decfloatPrintingPrec = 40

// timezoneType distinguishes the timestamp flavors used for array binds.
type timezoneType int

var errUnsupportedTimeArrayBind = errors.New("unsupported time array bind.
Set the type to TimestampNTZType, TimestampLTZType, TimestampTZType, DateType or TimeType") var errNativeArrowWithoutProperContext = errors.New("structured types must be enabled to use with native arrow") const ( // TimestampNTZType denotes a NTZ timezoneType for array binds TimestampNTZType timezoneType = iota // TimestampLTZType denotes a LTZ timezoneType for array binds TimestampLTZType // TimestampTZType denotes a TZ timezoneType for array binds TimestampTZType // DateType denotes a date type for array binds DateType // TimeType denotes a time type for array binds TimeType ) type interfaceArrayBinding struct { hasTimezone bool tzType timezoneType timezoneTypeArray any } func isInterfaceArrayBinding(t any) bool { switch t.(type) { case interfaceArrayBinding: return true case *interfaceArrayBinding: return true default: return false } } func isJSONFormatType(tsmode types.SnowflakeType) bool { return tsmode == types.ObjectType || tsmode == types.ArrayType || tsmode == types.SliceType } // goTypeToSnowflake translates Go data type to Snowflake data type. func goTypeToSnowflake(v driver.Value, tsmode types.SnowflakeType) types.SnowflakeType { if isJSONFormatType(tsmode) { return tsmode } if v == nil { return types.NullType } switch t := v.(type) { case int64, sql.NullInt64: return types.FixedType case float64, sql.NullFloat64: return types.RealType case bool, sql.NullBool: return types.BooleanType case string, sql.NullString: return types.TextType case []byte: if tsmode == types.BinaryType { return types.BinaryType // may be redundant but ensures BINARY type } if t == nil { return types.NullType // invalid byte array. 
won't take as BINARY } if len(t) != 1 { return types.ArrayType } if _, err := dataTypeMode(t); err != nil { return types.UnSupportedType } return types.ChangeType case time.Time, sql.NullTime: return tsmode } if supportedArrayBind(&driver.NamedValue{Value: v}) { return types.SliceType } // structured objects if _, ok := v.(StructuredObjectWriter); ok { return types.ObjectType } else if _, ok := v.(reflect.Type); ok && tsmode == types.NilObjectType { return types.NilObjectType } // structured arrays if reflect.TypeOf(v).Kind() == reflect.Slice || (reflect.TypeOf(v).Kind() == reflect.Pointer && reflect.ValueOf(v).Elem().Kind() == reflect.Slice) { return types.ArrayType } else if tsmode == types.NilArrayType { return types.NilArrayType } else if tsmode == types.NilMapType { return types.NilMapType } else if reflect.TypeOf(v).Kind() == reflect.Map || (reflect.TypeOf(v).Kind() == reflect.Pointer && reflect.ValueOf(v).Elem().Kind() == reflect.Map) { return types.MapType } return types.UnSupportedType } // snowflakeTypeToGo translates Snowflake data type to Go data type. 
func snowflakeTypeToGo(ctx context.Context, dbtype types.SnowflakeType, precision int64, scale int64, fields []query.FieldMetadata) reflect.Type { structuredTypesEnabled := structuredTypesEnabled(ctx) switch dbtype { case types.FixedType: if higherPrecisionEnabled(ctx) { if scale == 0 { if precision >= 19 { return reflect.TypeFor[*big.Int]() } return reflect.TypeFor[int64]() } return reflect.TypeFor[*big.Float]() } if scale == 0 { if precision >= 19 { return reflect.TypeFor[string]() } return reflect.TypeFor[int64]() } return reflect.TypeFor[float64]() case types.RealType: return reflect.TypeFor[float64]() case types.DecfloatType: if !decfloatMappingEnabled(ctx) { return reflect.TypeFor[string]() } if higherPrecisionEnabled(ctx) { return reflect.TypeFor[*big.Float]() } return reflect.TypeFor[float64]() case types.TextType, types.VariantType: return reflect.TypeFor[string]() case types.DateType, types.TimeType, types.TimestampLtzType, types.TimestampNtzType, types.TimestampTzType: return reflect.TypeOf(time.Now()) case types.BinaryType: return reflect.TypeFor[[]byte]() case types.BooleanType: return reflect.TypeFor[bool]() case types.ObjectType: if len(fields) > 0 && structuredTypesEnabled { return reflect.TypeFor[ObjectType]() } return reflect.TypeFor[string]() case types.ArrayType: if len(fields) == 0 || !structuredTypesEnabled { return reflect.TypeFor[string]() } if len(fields) != 1 { logger.WithContext(ctx).Warn("Unexpected fields number: " + strconv.Itoa(len(fields))) return reflect.TypeFor[string]() } switch types.GetSnowflakeType(fields[0].Type) { case types.FixedType: if fields[0].Scale == 0 && higherPrecisionEnabled(ctx) { return reflect.TypeFor[[]*big.Int]() } else if fields[0].Scale == 0 && !higherPrecisionEnabled(ctx) { return reflect.TypeFor[[]int64]() } else if fields[0].Scale != 0 && higherPrecisionEnabled(ctx) { return reflect.TypeFor[[]*big.Float]() } return reflect.TypeFor[[]float64]() case types.RealType: return reflect.TypeFor[[]float64]() case 
types.TextType: return reflect.TypeFor[[]string]() case types.DateType, types.TimeType, types.TimestampLtzType, types.TimestampNtzType, types.TimestampTzType: return reflect.TypeFor[[]time.Time]() case types.BooleanType: return reflect.TypeFor[[]bool]() case types.BinaryType: return reflect.TypeFor[[][]byte]() case types.ObjectType: return reflect.TypeFor[[]ObjectType]() } return nil case types.MapType: if !structuredTypesEnabled { return reflect.TypeFor[string]() } switch types.GetSnowflakeType(fields[0].Type) { case types.TextType: return snowflakeTypeToGoForMaps[string](ctx, fields[1]) case types.FixedType: return snowflakeTypeToGoForMaps[int64](ctx, fields[1]) } return reflect.TypeFor[map[any]any]() } logger.WithContext(ctx).Errorf("unsupported dbtype is specified. %v", dbtype) return reflect.TypeFor[string]() } func snowflakeTypeToGoForMaps[K comparable](ctx context.Context, valueMetadata query.FieldMetadata) reflect.Type { switch types.GetSnowflakeType(valueMetadata.Type) { case types.TextType: return reflect.TypeFor[map[K]string]() case types.FixedType: if higherPrecisionEnabled(ctx) && valueMetadata.Scale == 0 { return reflect.TypeFor[map[K]*big.Int]() } else if higherPrecisionEnabled(ctx) && valueMetadata.Scale != 0 { return reflect.TypeFor[map[K]*big.Float]() } else if !higherPrecisionEnabled(ctx) && valueMetadata.Scale == 0 { return reflect.TypeFor[map[K]int64]() } else { return reflect.TypeFor[map[K]float64]() } case types.RealType: return reflect.TypeFor[map[K]float64]() case types.BooleanType: return reflect.TypeFor[map[K]bool]() case types.BinaryType: return reflect.TypeFor[map[K][]byte]() case types.TimeType, types.DateType, types.TimestampTzType, types.TimestampNtzType, types.TimestampLtzType: return reflect.TypeFor[map[K]time.Time]() } logger.WithContext(ctx).Errorf("unsupported dbtype is specified for map value") return reflect.TypeFor[string]() } // valueToString converts arbitrary golang type to a string. 
This is mainly used in binding data with placeholders
// in queries.
func valueToString(v driver.Value, tsmode types.SnowflakeType, params *syncParams) (bindingValue, error) {
	isJSONFormat := isJSONFormatType(tsmode)
	if v == nil {
		if isJSONFormat {
			return bindingValue{nil, jsonFormatStr, nil}, nil
		}
		return bindingValue{nil, "", nil}, nil
	}
	v1 := reflect.Indirect(reflect.ValueOf(v))
	if valuer, ok := v.(driver.Valuer); ok {
		// check for driver.Valuer satisfaction and honor that first
		if value, err := valuer.Value(); err == nil && value != nil {
			// if the output value is a valid string, return that
			if strVal, ok := value.(string); ok {
				if isJSONFormat {
					return bindingValue{&strVal, jsonFormatStr, nil}, nil
				}
				return bindingValue{&strVal, "", nil}, nil
			}
		}
	}
	// *big.Float bound as DECFLOAT is formatted directly with the driver's decfloat precision.
	if tsmode == types.DecfloatType && v1.Type() == reflect.TypeFor[big.Float]() {
		s := v.(*big.Float).Text('g', decfloatPrintingPrec)
		return bindingValue{&s, "", nil}, nil
	}
	switch v1.Kind() {
	case reflect.Bool:
		s := strconv.FormatBool(v1.Bool())
		return bindingValue{&s, "", nil}, nil
	case reflect.Int64:
		s := strconv.FormatInt(v1.Int(), 10)
		return bindingValue{&s, "", nil}, nil
	case reflect.Float64:
		// NOTE(review): bitSize is 32 even though the kind is Float64, so the value is
		// rounded to float32 precision before formatting — confirm this is intentional.
		s := strconv.FormatFloat(v1.Float(), 'g', -1, 32)
		return bindingValue{&s, "", nil}, nil
	case reflect.String:
		s := v1.String()
		if isJSONFormat {
			return bindingValue{&s, jsonFormatStr, nil}, nil
		}
		return bindingValue{&s, "", nil}, nil
	case reflect.Slice, reflect.Array:
		return arrayToString(v, tsmode, params)
	case reflect.Map:
		return mapToString(v, tsmode, params)
	case reflect.Struct:
		return structValueToString(v, tsmode, params)
	}
	return bindingValue{}, fmt.Errorf("unsupported type: %v", v1.Kind())
}

// isUUIDImplementer checks if a value is a UUID that satisfies RFC 4122
func isUUIDImplementer(v reflect.Value) bool {
	rt := v.Type()
	// Check if the type is an array of 16 bytes
	if v.Kind() == reflect.Array && rt.Elem().Kind() == reflect.Uint8 && rt.Len() == 16 {
		// Check if the type implements fmt.Stringer
		vInt := v.Interface()
		if stringer, ok := vInt.(fmt.Stringer); ok {
			uuidStr := stringer.String()
			rfc4122Regex := `^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$`
			matched, err := regexp.MatchString(rfc4122Regex, uuidStr)
			if err != nil {
				return false
			}
			if matched {
				// parse the UUID and ensure it is the same as the original string
				u := ParseUUID(uuidStr)
				return u.String() == uuidStr
			}
		}
	}
	return false
}

// arrayToString converts a slice or array bound as a query parameter into its
// string form plus the binding format/schema. Branch order matters: binary
// slices-of-slices, []time.Time, arrays of structured objects, empty
// collections, []byte, UUID-like 16-byte arrays, then generic JSON marshaling.
func arrayToString(v driver.Value, tsmode types.SnowflakeType, params *syncParams) (bindingValue, error) {
	v1 := reflect.Indirect(reflect.ValueOf(v))
	if v1.Kind() == reflect.Slice && v1.IsNil() {
		return bindingValue{nil, jsonFormatStr, nil}, nil
	}
	if bd, ok := v.([][]byte); ok && tsmode == types.BinaryType {
		// [][]byte with BINARY mode: hex-encode each element into a JSON array of strings.
		schema := bindingSchema{
			Typ:      "array",
			Nullable: true,
			Fields: []query.FieldMetadata{
				{
					Type:     "binary",
					Nullable: true,
				},
			},
		}
		if len(bd) == 0 {
			res := "[]"
			return bindingValue{value: &res, format: jsonFormatStr, schema: &schema}, nil
		}
		s := ""
		for _, b := range bd {
			s += "\"" + hex.EncodeToString(b) + "\","
		}
		s = "[" + s[:len(s)-1] + "]"
		return bindingValue{&s, jsonFormatStr, &schema}, nil
	} else if times, ok := v.([]time.Time); ok {
		// []time.Time: format each element using the session's input format for the
		// target date/time type, then JSON-encode the resulting strings.
		typ := types.DriverTypeToSnowflake[tsmode]
		sfFormat, err := dateTimeInputFormatByType(typ, params)
		if err != nil {
			return bindingValue{nil, "", nil}, err
		}
		goFormat, err := snowflakeFormatToGoFormat(sfFormat)
		if err != nil {
			return bindingValue{nil, "", nil}, err
		}
		arr := make([]string, len(times))
		for idx, t := range times {
			arr[idx] = t.Format(goFormat)
		}
		res, err := json.Marshal(arr)
		if err != nil {
			return bindingValue{nil, jsonFormatStr, &bindingSchema{
				Typ:      "array",
				Nullable: true,
				Fields: []query.FieldMetadata{
					{
						Type:     typ,
						Nullable: true,
					},
				},
			}}, err
		}
		resString := string(res)
		return bindingValue{&resString, jsonFormatStr, nil}, nil
	} else if isArrayOfStructs(v) {
		// Array of structured objects: serialize each element via structValueToString
		// and derive the element schema from the element's Go type.
		stringEntries := make([]string, v1.Len())
		sowcForSingleElement, err := buildSowcFromType(params, reflect.TypeOf(v).Elem())
		if err != nil {
			return bindingValue{}, err
		}
		for i := 0; i < v1.Len(); i++ {
			potentialSow := v1.Index(i)
			if sow, ok := potentialSow.Interface().(StructuredObjectWriter); ok {
				bv, err := structValueToString(sow, tsmode, params)
				if err != nil {
					return bindingValue{nil, jsonFormatStr, nil}, err
				}
				stringEntries[i] = *bv.value
			}
		}
		value := "[" + strings.Join(stringEntries, ",") + "]"
		arraySchema := &bindingSchema{
			Typ:      "array",
			Nullable: true,
			Fields: []query.FieldMetadata{
				{
					Type:     "OBJECT",
					Nullable: true,
					Fields:   sowcForSingleElement.toFields(),
				},
			},
		}
		return bindingValue{&value, jsonFormatStr, arraySchema}, nil
	} else if reflect.ValueOf(v).Len() == 0 {
		value := "[]"
		return bindingValue{&value, jsonFormatStr, nil}, nil
	} else if barr, ok := v.([]byte); ok {
		// []byte: hex string for BINARY mode, otherwise a JSON array of byte values.
		if tsmode == types.BinaryType {
			res := hex.EncodeToString(barr)
			return bindingValue{&res, jsonFormatStr, nil}, nil
		}
		schemaForBytes := bindingSchema{
			Typ:      "array",
			Nullable: true,
			Fields: []query.FieldMetadata{
				{
					Type:     "FIXED",
					Nullable: true,
				},
			},
		}
		if len(barr) == 0 {
			res := "[]"
			return bindingValue{&res, jsonFormatStr, &schemaForBytes}, nil
		}
		res := "["
		for _, b := range barr {
			res += fmt.Sprint(b) + ","
		}
		res = res[0:len(res)-1] + "]"
		return bindingValue{&res, jsonFormatStr, &schemaForBytes}, nil
	} else if isUUIDImplementer(v1) {
		// special case for UUIDs (snowflake type and other implementers)
		stringer := v.(fmt.Stringer) // we don't need to validate if it's a fmt.Stringer because we already checked if it's a UUID type with a stringer
		value := stringer.String()
		return bindingValue{&value, "", nil}, nil
	} else if isSliceOfSlices(v) {
		return bindingValue{}, errors.New("array of arrays is not supported")
	}
	res, err := json.Marshal(v)
	if err != nil {
		return bindingValue{nil, jsonFormatStr, nil}, err
	}
	resString := string(res)
	return bindingValue{&resString, jsonFormatStr, nil}, nil
}

// mapToString converts a map bound as a query parameter into a JSON string plus
// a MAP binding schema. Specialized element handling exists for []byte values
// (BINARY mode), time.Time / sql.NullTime, the sql.Null* wrappers, and
// StructuredObjectWriter values; anything else is JSON-marshaled as-is.
func mapToString(v driver.Value, tsmode types.SnowflakeType, params *syncParams) (bindingValue, error) {
	var err error
	valOf := reflect.Indirect(reflect.ValueOf(v))
	if valOf.IsNil() {
		return bindingValue{nil, "", nil}, nil
	}
	typOf := reflect.TypeOf(v)
	var jsonBytes []byte
	if tsmode == types.BinaryType {
		// Values are []byte; hex-encode non-nil entries, keep nils as JSON null.
		m := make(map[string]*string, valOf.Len())
		iter := valOf.MapRange()
		for iter.Next() {
			val := iter.Value().Interface().([]byte)
			if val != nil {
				s := hex.EncodeToString(val)
				m[stringOrIntToString(iter.Key())] = &s
			} else {
				m[stringOrIntToString(iter.Key())] = nil
			}
		}
		jsonBytes, err = json.Marshal(m)
		if err != nil {
			return bindingValue{}, err
		}
	} else if typOf.Elem().AssignableTo(reflect.TypeFor[time.Time]()) || typOf.Elem().AssignableTo(reflect.TypeFor[sql.NullTime]()) {
		// Time values are rendered with the session format for the target type.
		m := make(map[string]*string, valOf.Len())
		iter := valOf.MapRange()
		for iter.Next() {
			val, valid, err := toNullableTime(iter.Value().Interface())
			if err != nil {
				return bindingValue{}, err
			}
			if !valid {
				m[stringOrIntToString(iter.Key())] = nil
			} else {
				typ := types.DriverTypeToSnowflake[tsmode]
				s, err := timeToString(val, typ, params)
				if err != nil {
					return bindingValue{}, err
				}
				m[stringOrIntToString(iter.Key())] = &s
			}
		}
		jsonBytes, err = json.Marshal(m)
		if err != nil {
			return bindingValue{}, err
		}
	} else if typOf.Elem().AssignableTo(reflect.TypeFor[sql.NullString]()) {
		m := make(map[string]*string, valOf.Len())
		iter := valOf.MapRange()
		for iter.Next() {
			val := iter.Value().Interface().(sql.NullString)
			if val.Valid {
				m[stringOrIntToString(iter.Key())] = &val.String
			} else {
				m[stringOrIntToString(iter.Key())] = nil
			}
		}
		jsonBytes, err = json.Marshal(m)
		if err != nil {
			return bindingValue{}, err
		}
	} else if typOf.Elem().AssignableTo(reflect.TypeFor[sql.NullByte]()) || typOf.Elem().AssignableTo(reflect.TypeFor[sql.NullInt16]()) || typOf.Elem().AssignableTo(reflect.TypeFor[sql.NullInt32]()) || typOf.Elem().AssignableTo(reflect.TypeFor[sql.NullInt64]()) {
		// All nullable integer wrappers are widened to *int64 for JSON output.
		m := make(map[string]*int64, valOf.Len())
		iter := valOf.MapRange()
		for iter.Next() {
			val, valid := toNullableInt64(iter.Value().Interface())
			if valid {
				m[stringOrIntToString(iter.Key())] = &val
			} else {
				m[stringOrIntToString(iter.Key())] = nil
			}
		}
		jsonBytes, err = json.Marshal(m)
		if err != nil {
			return bindingValue{}, err
		}
	} else if typOf.Elem().AssignableTo(reflect.TypeFor[sql.NullFloat64]()) {
		m := make(map[string]*float64, valOf.Len())
		iter := valOf.MapRange()
		for iter.Next() {
			val := iter.Value().Interface().(sql.NullFloat64)
			if val.Valid {
				m[stringOrIntToString(iter.Key())] = &val.Float64
			} else {
				m[stringOrIntToString(iter.Key())] = nil
			}
		}
		jsonBytes, err = json.Marshal(m)
		if err != nil {
			return bindingValue{}, err
		}
	} else if typOf.Elem().AssignableTo(reflect.TypeFor[sql.NullBool]()) {
		m := make(map[string]*bool, valOf.Len())
		iter := valOf.MapRange()
		for iter.Next() {
			val := iter.Value().Interface().(sql.NullBool)
			if val.Valid {
				m[stringOrIntToString(iter.Key())] = &val.Bool
			} else {
				m[stringOrIntToString(iter.Key())] = nil
			}
		}
		jsonBytes, err = json.Marshal(m)
		if err != nil {
			return bindingValue{}, err
		}
	} else if typOf.Elem().AssignableTo(structuredObjectWriterType) {
		// Structured object values: serialize each via its Write method and derive
		// the value schema from the first non-nil element (or from the static type
		// when every value is nil). This branch returns directly with its own schema.
		m := make(map[string]map[string]any, valOf.Len())
		iter := valOf.MapRange()
		var valueMetadata *query.FieldMetadata
		for iter.Next() {
			sowc := structuredObjectWriterContext{}
			sowc.init(params)
			if iter.Value().IsNil() {
				m[stringOrIntToString(iter.Key())] = nil
				continue
			}
			err = iter.Value().Interface().(StructuredObjectWriter).Write(&sowc)
			if err != nil {
				return bindingValue{}, err
			}
			m[stringOrIntToString(iter.Key())] = sowc.values
			if valueMetadata == nil {
				valueMetadata = &query.FieldMetadata{
					Type:     "OBJECT",
					Nullable: true,
					Fields:   sowc.toFields(),
				}
			}
		}
		if valueMetadata == nil {
			sowcFromValueType, err := buildSowcFromType(params, typOf.Elem())
			if err != nil {
				return bindingValue{}, err
			}
			valueMetadata = &query.FieldMetadata{
				Type:     "OBJECT",
				Nullable: true,
				Fields:   sowcFromValueType.toFields(),
			}
		}
		jsonBytes, err = json.Marshal(m)
		if err != nil {
			return bindingValue{}, err
		}
		jsonString := string(jsonBytes)
		keyMetadata, err := goTypeToFieldMetadata(typOf.Key(), types.TextType, params)
		if err != nil {
			return bindingValue{}, err
		}
		schema := bindingSchema{
			Typ:    "MAP",
			Fields: []query.FieldMetadata{keyMetadata, *valueMetadata},
		}
		return bindingValue{&jsonString, jsonFormatStr, &schema}, nil
	} else {
		jsonBytes, err = json.Marshal(v)
		if err != nil {
			return bindingValue{}, err
		}
	}
	// Common tail for all non-structured-object branches: build the MAP schema
	// from the map's Go key and element types.
	jsonString := string(jsonBytes)
	keyMetadata, err := goTypeToFieldMetadata(typOf.Key(), types.TextType, params)
	if err != nil {
		return bindingValue{}, err
	}
	valueMetadata, err := goTypeToFieldMetadata(typOf.Elem(), tsmode, params)
	if err != nil {
		return bindingValue{}, err
	}
	schema := bindingSchema{
		Typ:    "MAP",
		Fields: []query.FieldMetadata{keyMetadata, valueMetadata},
	}
	return bindingValue{&jsonString, jsonFormatStr, &schema}, nil
}

// toNullableInt64 widens any sql.Null{Byte,Int16,Int32,Int64} to (int64, valid).
// Panics on any other type; callers only pass the four types above.
func toNullableInt64(val any) (int64, bool) {
	switch v := val.(type) {
	case sql.NullByte:
		return int64(v.Byte), v.Valid
	case sql.NullInt16:
		return int64(v.Int16), v.Valid
	case sql.NullInt32:
		return int64(v.Int32), v.Valid
	case sql.NullInt64:
		return v.Int64, v.Valid
	}
	// should never happen, the list above is exhaustive
	panic("Only byte, int16, int32 or int64 are supported")
}

// toNullableTime normalizes time.Time and sql.NullTime to (time, valid, err);
// any other type yields an error.
func toNullableTime(val any) (time.Time, bool, error) {
	switch v := val.(type) {
	case time.Time:
		return v, true, nil
	case sql.NullTime:
		return v.Time, v.Valid, nil
	}
	return time.Now(), false, fmt.Errorf("cannot use %T as time", val)
}

// stringOrIntToString renders a reflect map key as a string; integer keys are
// formatted in decimal, everything else via reflect's String.
func stringOrIntToString(v reflect.Value) string {
	if v.CanInt() {
		return strconv.Itoa(int(v.Int()))
	}
	return v.String()
}

// goTypeToFieldMetadata maps a Go type (pointers dereferenced) to the Snowflake
// field metadata used in binding schemas. tsmode disambiguates struct types
// that carry time semantics (DATE/TIME/TIMESTAMP_*) and the Nil* placeholder modes.
func goTypeToFieldMetadata(typ reflect.Type, tsmode types.SnowflakeType, params *syncParams) (query.FieldMetadata, error) {
	if tsmode == types.BinaryType {
		return query.FieldMetadata{
			Type:     "BINARY",
			Nullable: true,
		}, nil
	}
	if typ.Kind() == reflect.Pointer {
		typ = typ.Elem()
	}
	switch typ.Kind() {
	case reflect.String:
		return query.FieldMetadata{
			Type:     "TEXT",
			Nullable: true,
		}, nil
	case reflect.Bool:
		return query.FieldMetadata{
			Type:     "BOOLEAN",
			Nullable: true,
		}, nil
	case reflect.Int, reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Int32, reflect.Int64:
		return query.FieldMetadata{
			Type:      "FIXED",
			Precision: numberDefaultPrecision,
			Nullable:  true,
		}, nil
	case reflect.Float32, reflect.Float64:
		return query.FieldMetadata{
			Type:     "REAL",
			Nullable: true,
		}, nil
	case reflect.Struct:
		// sql.Null* wrappers first, then time-flavored modes, then structured objects.
		if typ.AssignableTo(reflect.TypeFor[sql.NullString]()) {
			return query.FieldMetadata{
				Type:     "TEXT",
				Nullable: true,
			}, nil
		} else if typ.AssignableTo(reflect.TypeFor[sql.NullBool]()) {
			return query.FieldMetadata{
				Type:     "BOOLEAN",
				Nullable: true,
			}, nil
		} else if typ.AssignableTo(reflect.TypeFor[sql.NullByte]()) || typ.AssignableTo(reflect.TypeFor[sql.NullInt16]()) || typ.AssignableTo(reflect.TypeFor[sql.NullInt32]()) || typ.AssignableTo(reflect.TypeFor[sql.NullInt64]()) {
			return query.FieldMetadata{
				Type:      "FIXED",
				Precision: numberDefaultPrecision,
				Nullable:  true,
			}, nil
		} else if typ.AssignableTo(reflect.TypeFor[sql.NullFloat64]()) {
			return query.FieldMetadata{
				Type:     "REAL",
				Nullable: true,
			}, nil
		} else if tsmode == types.DateType {
			return query.FieldMetadata{
				Type:     "DATE",
				Nullable: true,
			}, nil
		} else if tsmode == types.TimeType {
			return query.FieldMetadata{
				Type:     "TIME",
				Nullable: true,
			}, nil
		} else if tsmode == types.TimestampTzType {
			return query.FieldMetadata{
				Type:     "TIMESTAMP_TZ",
				Nullable: true,
			}, nil
		} else if tsmode == types.TimestampNtzType {
			return query.FieldMetadata{
				Type:     "TIMESTAMP_NTZ",
				Nullable: true,
			}, nil
		} else if tsmode == types.TimestampLtzType {
			return query.FieldMetadata{
				Type:     "TIMESTAMP_LTZ",
				Nullable: true,
			}, nil
		} else if typ.AssignableTo(structuredObjectWriterType) || tsmode == types.NilObjectType {
			sowc, err := buildSowcFromType(params, typ)
			if err != nil {
				return query.FieldMetadata{}, err
			}
			return query.FieldMetadata{
				Type:     "OBJECT",
				Nullable: true,
				Fields:   sowc.toFields(),
			}, nil
		} else if tsmode == types.NilArrayType || tsmode == types.NilMapType {
			sowc, err := buildSowcFromType(params, typ)
			if err != nil {
				return query.FieldMetadata{}, err
			}
			return query.FieldMetadata{
				Type:     "OBJECT",
				Nullable: true,
				Fields:   sowc.toFields(),
			}, nil
		}
	case reflect.Slice:
		metadata, err := goTypeToFieldMetadata(typ.Elem(), tsmode, params)
		if err != nil {
			return query.FieldMetadata{}, err
		}
		return query.FieldMetadata{
			Type:     "ARRAY",
			Nullable: true,
			Fields:   []query.FieldMetadata{metadata},
		}, nil
	case reflect.Map:
		keyMetadata, err := goTypeToFieldMetadata(typ.Key(), tsmode, params)
		if err != nil {
			return query.FieldMetadata{}, err
		}
		valueMetadata, err := goTypeToFieldMetadata(typ.Elem(), tsmode, params)
		if err != nil {
			return query.FieldMetadata{}, err
		}
		return query.FieldMetadata{
			Type:     "MAP",
			Nullable: true,
			Fields:   []query.FieldMetadata{keyMetadata, valueMetadata},
		}, nil
	}
	return query.FieldMetadata{}, fmt.Errorf("cannot build field metadata for %v (mode %v)", typ.Kind().String(), tsmode.String())
}

// isSliceOfSlices reports whether v is a slice whose element type is also a slice.
func isSliceOfSlices(v any) bool {
	typ := reflect.TypeOf(v)
	return typ.Kind() == reflect.Slice && typ.Elem().Kind() == reflect.Slice
}

// isArrayOfStructs reports whether v's element type is a struct or pointer to struct.
func isArrayOfStructs(v any) bool {
	return reflect.TypeOf(v).Elem().Kind() == reflect.Struct || (reflect.TypeOf(v).Elem().Kind() == reflect.Pointer && reflect.TypeOf(v).Elem().Elem().Kind() == reflect.Struct)
}

// structValueToString converts a struct-kinded binding value: time.Time and the
// sql.Null* wrappers map to plain strings; StructuredObjectWriter values are
// serialized to JSON with an object schema; reflect.Type / NilMapTypes values
// carry only schema information for nil structured bindings.
func structValueToString(v driver.Value, tsmode types.SnowflakeType, params *syncParams) (bindingValue, error) {
	switch typedVal := v.(type) {
	case time.Time:
		return timeTypeValueToString(typedVal, tsmode)
	case sql.NullTime:
		if !typedVal.Valid {
			return bindingValue{nil, "", nil}, nil
		}
		return timeTypeValueToString(typedVal.Time, tsmode)
	case sql.NullBool:
		if !typedVal.Valid {
			return bindingValue{nil, "", nil}, nil
		}
		s := strconv.FormatBool(typedVal.Bool)
		return bindingValue{&s, "", nil}, nil
	case sql.NullInt64:
		if !typedVal.Valid {
			return bindingValue{nil, "", nil}, nil
		}
		s := strconv.FormatInt(typedVal.Int64, 10)
		return bindingValue{&s, "", nil}, nil
	case sql.NullFloat64:
		if !typedVal.Valid {
			return bindingValue{nil, "", nil}, nil
		}
		// NOTE(review): bitSize 32 rounds to float32 precision, matching valueToString's
		// Float64 handling — confirm this is intentional.
		s := strconv.FormatFloat(typedVal.Float64, 'g', -1, 32)
		return bindingValue{&s, "", nil}, nil
	case sql.NullString:
		// NOTE(review): local `fmt` shadows the fmt package within this case; safe here
		// because the branch makes no fmt package calls.
		fmt := ""
		if isJSONFormatType(tsmode) {
			fmt = jsonFormatStr
		}
		if !typedVal.Valid {
			return bindingValue{nil, fmt, nil}, nil
		}
		return bindingValue{&typedVal.String, fmt, nil}, nil
	}
	if sow, ok := v.(StructuredObjectWriter); ok {
		sowc := &structuredObjectWriterContext{}
		sowc.init(params)
		err := sow.Write(sowc)
		if err != nil {
			return bindingValue{nil, "", nil}, err
		}
		jsonBytes, err := json.Marshal(sowc.values)
		if err != nil {
			return bindingValue{nil, "", nil}, err
		}
		jsonString := string(jsonBytes)
		schema := bindingSchema{
			Typ:      "object",
			Nullable: true,
			Fields:   sowc.toFields(),
		}
		return bindingValue{&jsonString, jsonFormatStr, &schema}, nil
	} else if typ, ok := v.(reflect.Type); ok && tsmode == types.NilArrayType {
		metadata, err := goTypeToFieldMetadata(typ, tsmode, params)
		if err != nil {
			return bindingValue{}, err
		}
		schema := bindingSchema{
			Typ:      "ARRAY",
			Nullable: true,
			Fields: []query.FieldMetadata{
				metadata,
			},
		}
		return bindingValue{nil, jsonFormatStr, &schema}, nil
	} else if t, ok := v.(NilMapTypes); ok && tsmode == types.NilMapType {
		keyMetadata, err := goTypeToFieldMetadata(t.Key, tsmode, params)
		if err != nil {
			return bindingValue{}, err
		}
		valueMetadata, err := goTypeToFieldMetadata(t.Value, tsmode, params)
		if err != nil {
			return bindingValue{}, err
		}
		schema := bindingSchema{
			Typ:      "map",
			Nullable: true,
			Fields:   []query.FieldMetadata{keyMetadata, valueMetadata},
		}
		return bindingValue{nil, jsonFormatStr, &schema}, nil
	} else if typ, ok := v.(reflect.Type); ok && tsmode == types.NilObjectType {
		metadata, err := goTypeToFieldMetadata(typ, tsmode, params)
		if err != nil {
			return bindingValue{}, err
		}
		schema := bindingSchema{
			Typ:      "object",
			Nullable: true,
			Fields:   metadata.Fields,
		}
		return bindingValue{nil, jsonFormatStr, &schema}, nil
	}
	return bindingValue{}, fmt.Errorf("unknown binding for type %T and mode %v", v, tsmode)
}

// timeTypeValueToString renders a time.Time for a date/time binding mode:
// DATE as epoch milliseconds of the local calendar date, TIME as nanoseconds
// since midnight, TIMESTAMP_* via convertTimeToTimeStamp.
func timeTypeValueToString(tm time.Time, tsmode types.SnowflakeType) (bindingValue, error) {
	switch tsmode {
	case types.DateType:
		// Shift by the zone offset so the emitted epoch reflects the wall-clock date.
		_, offset := tm.Zone()
		tm = tm.Add(time.Second * time.Duration(offset))
		s := strconv.FormatInt(tm.Unix()*1000, 10)
		return bindingValue{&s, "", nil}, nil
	case types.TimeType:
		// Nanoseconds elapsed since midnight.
		s := fmt.Sprintf("%d", (tm.Hour()*3600+tm.Minute()*60+tm.Second())*1e9+tm.Nanosecond())
		return bindingValue{&s, "", nil}, nil
	case types.TimestampNtzType, types.TimestampLtzType, types.TimestampTzType:
		s, err := convertTimeToTimeStamp(tm, tsmode)
		if err != nil {
			return bindingValue{nil, "", nil}, err
		}
		return bindingValue{&s, "", nil}, nil
	}
	return bindingValue{nil, "", nil}, fmt.Errorf("unsupported time type: %v", tsmode)
}

// extractTimestamp extracts the internal timestamp data to epoch time in seconds and nanoseconds.
// The input is "<seconds>" or "<seconds>.<fraction>"; the fraction is right-padded
// with zeros to 9 digits to form nanoseconds.
func extractTimestamp(srcValue *string) (sec int64, nsec int64, err error) {
	logger.Debugf("SRC: %v", srcValue)
	var i int
	for i = 0; i < len(*srcValue); i++ {
		if (*srcValue)[i] == '.' {
			sec, err = strconv.ParseInt((*srcValue)[0:i], 10, 64)
			if err != nil {
				return 0, 0, err
			}
			break
		}
	}
	if i == len(*srcValue) {
		// no fraction
		sec, err = strconv.ParseInt(*srcValue, 10, 64)
		if err != nil {
			return 0, 0, err
		}
		nsec = 0
	} else {
		s := (*srcValue)[i+1:]
		nsec, err = strconv.ParseInt(s+strings.Repeat("0", 9-len(s)), 10, 64)
		if err != nil {
			return 0, 0, err
		}
	}
	logger.Infof("sec: %v, nsec: %v", sec, nsec)
	return sec, nsec, nil
}

// stringToValue converts a pointer of string data to an arbitrary golang variable
// This is mainly used in fetching data.
func stringToValue(ctx context.Context, dest *driver.Value, srcColumnMeta query.ExecResponseRowType, srcValue *string, loc *time.Location, params *syncParams) error {
	if srcValue == nil {
		logger.Debugf("snowflake data type: %v, raw value: nil", srcColumnMeta.Type)
		*dest = nil
		return nil
	}
	structuredTypesEnabled := structuredTypesEnabled(ctx)
	// Truncate large strings before logging to avoid secret masking performance issues
	valueForLogging := *srcValue
	if len(valueForLogging) > 1024 {
		valueForLogging = valueForLogging[:1024] + fmt.Sprintf("... (%d bytes total)", len(*srcValue))
	}
	logger.Debugf("snowflake data type: %v, raw value: %v", srcColumnMeta.Type, valueForLogging)
	switch srcColumnMeta.Type {
	case "object":
		if len(srcColumnMeta.Fields) == 0 || !structuredTypesEnabled {
			// semistructured type without schema
			*dest = *srcValue
			return nil
		}
		m := make(map[string]any)
		decoder := decoderWithNumbersAsStrings(srcValue)
		if err := decoder.Decode(&m); err != nil {
			return err
		}
		v, err := buildStructuredTypeRecursive(ctx, m, srcColumnMeta.Fields, params)
		if err != nil {
			return err
		}
		*dest = v
		return nil
	case "text", "real", "variant":
		*dest = *srcValue
		return nil
	case "fixed":
		// With higher precision enabled, scale-0 values with precision >= 19 become
		// big.Int and scaled values become big.Float; otherwise the raw string is kept.
		if higherPrecisionEnabled(ctx) {
			if srcColumnMeta.Scale == 0 {
				if srcColumnMeta.Precision >= 19 {
					bigInt := big.NewInt(0)
					bigInt.SetString(*srcValue, 10)
					*dest = *bigInt
					return nil
				}
				*dest = *srcValue
				return nil
			}
			bigFloat, _, err := big.ParseFloat(*srcValue, 10, numberMaxPrecisionInBits, big.AwayFromZero)
			if err != nil {
				return err
			}
			*dest = *bigFloat
			return nil
		}
		*dest = *srcValue
		return nil
	case "decfloat":
		if !decfloatMappingEnabled(ctx) {
			*dest = *srcValue
			return nil
		}
		bf := new(big.Float).SetPrec(127)
		if _, ok := bf.SetString(*srcValue); !ok {
			return fmt.Errorf("cannot convert %v to %T", *srcValue, bf)
		}
		if higherPrecisionEnabled(ctx) {
			*dest = *bf
		} else {
			*dest, _ = bf.Float64()
		}
		return nil
	case "date":
		// The wire value is days since the Unix epoch.
		v, err := strconv.ParseInt(*srcValue, 10, 64)
		if err != nil {
			return err
		}
		*dest = time.Unix(v*86400, 0).UTC()
		return nil
	case "time":
		// Duration since midnight applied to the zero time.Time.
		sec, nsec, err := extractTimestamp(srcValue)
		if err != nil {
			return err
		}
		t0 := time.Time{}
		*dest = t0.Add(time.Duration(sec*1e9 + nsec))
		return nil
	case "timestamp_ntz":
		sec, nsec, err := extractTimestamp(srcValue)
		if err != nil {
			return err
		}
		*dest = time.Unix(sec, nsec).UTC()
		return nil
	case "timestamp_ltz":
		sec, nsec, err := extractTimestamp(srcValue)
		if err != nil {
			return err
		}
		if loc == nil {
			loc = time.Now().Location()
		}
		*dest = time.Unix(sec, nsec).In(loc)
		return nil
	case "timestamp_tz":
		// Wire format: "<epoch[.fraction]> <offsetMinutes>", offset biased by 1440.
		logger.Debugf("tz: %v", *srcValue)
		tm := strings.Split(*srcValue, " ")
		if len(tm) != 2 {
			return &SnowflakeError{
				Number:   ErrInvalidTimestampTz,
				SQLState: SQLStateInvalidDataTimeFormat,
				Message:  fmt.Sprintf("invalid TIMESTAMP_TZ data. The value doesn't consist of two numeric values separated by a space: %v", *srcValue),
			}
		}
		sec, nsec, err := extractTimestamp(&tm[0])
		if err != nil {
			return err
		}
		offset, err := strconv.ParseInt(tm[1], 10, 64)
		if err != nil {
			return &SnowflakeError{
				Number:   ErrInvalidTimestampTz,
				SQLState: SQLStateInvalidDataTimeFormat,
				Message:  fmt.Sprintf("invalid TIMESTAMP_TZ data. The offset value is not integer: %v", tm[1]),
			}
		}
		loc := Location(int(offset) - 1440)
		tt := time.Unix(sec, nsec)
		*dest = tt.In(loc)
		return nil
	case "binary":
		b, err := hex.DecodeString(*srcValue)
		if err != nil {
			return &SnowflakeError{
				Number:   ErrInvalidBinaryHexForm,
				SQLState: SQLStateNumericValueOutOfRange,
				Message:  err.Error(),
			}
		}
		*dest = b
		return nil
	case "array":
		if len(srcColumnMeta.Fields) == 0 || !structuredTypesEnabled {
			*dest = *srcValue
			return nil
		}
		if len(srcColumnMeta.Fields) > 1 {
			return errors.New("got more than one field for array")
		}
		var arr []any
		decoder := decoderWithNumbersAsStrings(srcValue)
		if err := decoder.Decode(&arr); err != nil {
			return err
		}
		v, err := buildStructuredArray(ctx, srcColumnMeta.Fields[0], arr, params)
		if err != nil {
			return err
		}
		*dest = v
		return nil
	case "map":
		var err error
		*dest, err = jsonToMap(ctx, srcColumnMeta.Fields[0], srcColumnMeta.Fields[1], *srcValue, params)
		return err
	}
	// Unknown column type: pass the raw string through unchanged.
	*dest = *srcValue
	return nil
}

// jsonToMap decodes a structured MAP column's JSON text into a Go map keyed by
// string ("text") or int64 ("fixed"). Object values become *structuredType;
// other value types are dispatched to jsonToMapWithKeyType.
func jsonToMap(ctx context.Context, keyMetadata, valueMetadata query.FieldMetadata, srcValue string, params *syncParams) (snowflakeValue, error) {
	structuredTypesEnabled := structuredTypesEnabled(ctx)
	if !structuredTypesEnabled {
		return srcValue, nil
	}
	switch keyMetadata.Type {
	case "text":
		var m map[string]any
		decoder := decoderWithNumbersAsStrings(&srcValue)
		err := decoder.Decode(&m)
		if err != nil {
			return nil, err
		}
		// returning snowflakeValue of complex types does not work with generics
		if valueMetadata.Type == "object" {
			res := make(map[string]*structuredType)
			for k, v := range m {
				if v == nil || reflect.ValueOf(v).IsNil() {
					res[k] = nil
				} else {
					res[k] = buildStructuredTypeFromMap(v.(map[string]any), valueMetadata.Fields, params)
				}
			}
			return res, nil
		}
		return jsonToMapWithKeyType[string](ctx, valueMetadata, m, params)
	case "fixed":
		var m map[int64]any
		decoder := decoderWithNumbersAsStrings(&srcValue)
		err := decoder.Decode(&m)
		if err != nil {
			return nil, err
		}
		if valueMetadata.Type == "object" {
			res := make(map[int64]*structuredType)
			for k, v := range m {
				res[k] = buildStructuredTypeFromMap(v.(map[string]any), valueMetadata.Fields, params)
			}
			return res, nil
		}
		return jsonToMapWithKeyType[int64](ctx, valueMetadata, m, params)
	default:
		return nil, fmt.Errorf("unsupported map key type: %v", keyMetadata.Type)
	}
}

// jsonToMapWithKeyType converts the decoded JSON map values to their Go
// representations for a given key type K. When nullable map values are enabled
// in ctx, values use the sql.Null* wrappers; otherwise plain types are used and
// a null value is an error (except for inherently-nilable []byte).
func jsonToMapWithKeyType[K comparable](ctx context.Context, valueMetadata query.FieldMetadata, m map[K]any, params *syncParams) (snowflakeValue, error) {
	mapValuesNullableEnabled := embeddedValuesNullableEnabled(ctx)
	switch valueMetadata.Type {
	case "text":
		return buildMapValues[K, sql.NullString, string](mapValuesNullableEnabled, m, func(v any) (string, error) {
			return v.(string), nil
		}, func(v any) (sql.NullString, error) {
			return sql.NullString{Valid: v != nil, String: ifNotNullOrDefault(v, "")}, nil
		}, false)
	case "boolean":
		return buildMapValues[K, sql.NullBool, bool](mapValuesNullableEnabled, m, func(v any) (bool, error) {
			return v.(bool), nil
		}, func(v any) (sql.NullBool, error) {
			return sql.NullBool{Valid: v != nil, Bool: ifNotNullOrDefault(v, false)}, nil
		}, false)
	case "fixed":
		// Scale 0 maps to int64; any other scale maps to float64.
		if valueMetadata.Scale == 0 {
			return buildMapValues[K, sql.NullInt64, int64](mapValuesNullableEnabled, m, func(v any) (int64, error) {
				return strconv.ParseInt(string(v.(json.Number)), 10, 64)
			}, func(v any) (sql.NullInt64, error) {
				if v != nil {
					i64, err := strconv.ParseInt(string(v.(json.Number)), 10, 64)
					if err != nil {
						return sql.NullInt64{}, err
					}
					return sql.NullInt64{Valid: true, Int64: i64}, nil
				}
				return sql.NullInt64{Valid: false}, nil
			}, false)
		}
		return buildMapValues[K, sql.NullFloat64, float64](mapValuesNullableEnabled, m, func(v any) (float64, error) {
			return strconv.ParseFloat(string(v.(json.Number)), 64)
		}, func(v any) (sql.NullFloat64, error) {
			if v != nil {
				f64, err := strconv.ParseFloat(string(v.(json.Number)), 64)
				if err != nil {
					return sql.NullFloat64{}, err
				}
				return sql.NullFloat64{Valid: true, Float64: f64}, nil
			}
			return sql.NullFloat64{Valid: false}, nil
		}, false)
	case "real":
		return buildMapValues[K, sql.NullFloat64, float64](mapValuesNullableEnabled, m, func(v any) (float64, error) {
			return strconv.ParseFloat(string(v.(json.Number)), 64)
		}, func(v any) (sql.NullFloat64, error) {
			if v != nil {
				f64, err := strconv.ParseFloat(string(v.(json.Number)), 64)
				if err != nil {
					return sql.NullFloat64{}, err
				}
				return sql.NullFloat64{Valid: true, Float64: f64}, nil
			}
			return sql.NullFloat64{Valid: false}, nil
		}, false)
	case "binary":
		// []byte is nilable on its own, so nulls are allowed even in non-nullable mode.
		return buildMapValues[K, []byte, []byte](mapValuesNullableEnabled, m, func(v any) ([]byte, error) {
			if v == nil {
				return nil, nil
			}
			return hex.DecodeString(v.(string))
		}, func(v any) ([]byte, error) {
			if v == nil {
				return nil, nil
			}
			return hex.DecodeString(v.(string))
		}, true)
	case "date", "time", "timestamp_tz", "timestamp_ltz", "timestamp_ntz":
		return buildMapValues[K, sql.NullTime, time.Time](mapValuesNullableEnabled, m, func(v any) (time.Time, error) {
			sfFormat, err := dateTimeOutputFormatByType(valueMetadata.Type, params)
			if err != nil {
				return time.Time{}, err
			}
			goFormat, err := snowflakeFormatToGoFormat(sfFormat)
			if err != nil {
				return time.Time{}, err
			}
			return time.Parse(goFormat, v.(string))
		}, func(v any) (sql.NullTime, error) {
			if v == nil {
				return sql.NullTime{Valid: false}, nil
			}
			sfFormat, err := dateTimeOutputFormatByType(valueMetadata.Type, params)
			if err != nil {
				return sql.NullTime{}, err
			}
			goFormat, err := snowflakeFormatToGoFormat(sfFormat)
			if err != nil {
				return sql.NullTime{}, err
			}
			time, err := time.Parse(goFormat, v.(string))
			if err != nil {
				return sql.NullTime{}, err
			}
			return sql.NullTime{Valid: true, Time: time}, nil
		}, false)
	case "array":
		arrayMetadata := valueMetadata.Fields[0]
		switch arrayMetadata.Type {
		case "text":
			return buildArrayFromMap[K, string](ctx, arrayMetadata, m, params)
		case "fixed":
			if arrayMetadata.Scale == 0 {
				return buildArrayFromMap[K, int64](ctx, arrayMetadata, m, params)
			}
			return buildArrayFromMap[K, float64](ctx, arrayMetadata, m, params)
		case "real":
			return buildArrayFromMap[K, float64](ctx, arrayMetadata, m, params)
		case "binary":
			return buildArrayFromMap[K, []byte](ctx, arrayMetadata, m, params)
		case "boolean":
			return buildArrayFromMap[K, bool](ctx, arrayMetadata, m, params)
		case "date", "time", "timestamp_ltz", "timestamp_tz", "timestamp_ntz":
			return buildArrayFromMap[K, time.Time](ctx, arrayMetadata, m, params)
		}
	}
	return nil, fmt.Errorf("unsupported map value type: %v", valueMetadata.Type)
}

// buildArrayFromMap converts a map whose values are JSON arrays into map[K][]V,
// delegating element conversion to buildStructuredArray. Nil values stay nil.
func buildArrayFromMap[K comparable, V any](ctx context.Context, valueMetadata query.FieldMetadata, m map[K]any, params *syncParams) (snowflakeValue, error) {
	res := make(map[K][]V)
	for k, v := range m {
		if v == nil {
			res[k] = nil
		} else {
			structuredArray, err := buildStructuredArray(ctx, valueMetadata, v.([]any), params)
			if err != nil {
				return nil, err
			}
			res[k] = structuredArray.([]V)
		}
	}
	return res, nil
}

// buildStructuredTypeFromMap wraps already-decoded object values and their field
// metadata into a *structuredType.
func buildStructuredTypeFromMap(values map[string]any, fieldMetadata []query.FieldMetadata, params *syncParams) *structuredType {
	return &structuredType{
		values:        values,
		params:        params,
		fieldMetadata: fieldMetadata,
	}
}

// ifNotNullOrDefault returns t as T, or def when t is nil.
func ifNotNullOrDefault[T any](t any, def T) T {
	if t == nil {
		return def
	}
	return t.(T)
}

// buildMapValues converts every value in m using either the nullable or the
// non-nullable converter, depending on mapValuesNullableEnabled. In
// non-nullable mode a nil value is an error unless nullableByDefault is set
// (used for inherently-nilable types such as []byte).
func buildMapValues[K comparable, Vnullable any, VnotNullable any](mapValuesNullableEnabled bool, m map[K]any, buildNotNullable func(v any) (VnotNullable, error), buildNullable func(v any) (Vnullable, error), nullableByDefault bool) (snowflakeValue, error) {
	var err error
	if mapValuesNullableEnabled {
		result := make(map[K]Vnullable, len(m))
		for k, v := range m {
			if result[k], err = buildNullable(v); err != nil {
				return nil, err
			}
		}
		return result, nil
	}
	result := make(map[K]VnotNullable, len(m))
	for k, v := range m {
		if v == nil && !nullableByDefault {
			return nil, errors2.ErrNullValueInMapError()
		}
		if result[k], err = buildNotNullable(v); err != nil {
			return nil, err
		}
	}
	return result, nil
}

// buildStructuredArray converts decoded JSON array elements to a typed Go slice
// according to the element field metadata. Unknown element types fall through
// and return the input unchanged.
func buildStructuredArray(ctx context.Context, fieldMetadata query.FieldMetadata, srcValue []any, params *syncParams) (any, error) {
	switch fieldMetadata.Type {
	case "text":
		return copyArrayAndConvert[string](srcValue, func(input any) (string, error) {
			return input.(string), nil
		})
	case "fixed":
		// Scale 0 maps to int64; any other scale maps to float64.
		if fieldMetadata.Scale == 0 {
			return copyArrayAndConvert[int64](srcValue, func(input any) (int64, error) {
				return strconv.ParseInt(string(input.(json.Number)), 10, 64)
			})
		}
		return copyArrayAndConvert[float64](srcValue, func(input any) (float64, error) {
			return strconv.ParseFloat(string(input.(json.Number)), 64)
		})
	case "real":
		return copyArrayAndConvert[float64](srcValue, func(input any) (float64, error) {
			return strconv.ParseFloat(string(input.(json.Number)), 64)
		})
	case "time", "date", "timestamp_ltz", "timestamp_ntz", "timestamp_tz":
		return copyArrayAndConvert[time.Time](srcValue, func(input any) (time.Time, error) {
			sfFormat, err := dateTimeOutputFormatByType(fieldMetadata.Type, params)
			if err != nil {
				return time.Time{}, err
			}
			goFormat, err := snowflakeFormatToGoFormat(sfFormat)
			if err != nil {
				return time.Time{}, err
			}
			return time.Parse(goFormat, input.(string))
		})
	case "boolean":
		return copyArrayAndConvert[bool](srcValue, func(input any) (bool, error) {
			return input.(bool), nil
		})
	case "binary":
		return copyArrayAndConvert[[]byte](srcValue, func(input any) ([]byte, error) {
			return hex.DecodeString(input.(string))
		})
	case "object":
		return copyArrayAndConvert[*structuredType](srcValue, func(input any) (*structuredType, error) {
			return buildStructuredTypeRecursive(ctx, input.(map[string]any), fieldMetadata.Fields, params)
		})
	case "array":
		// Nested arrays: recurse one level with the inner element metadata.
		switch fieldMetadata.Fields[0].Type {
		case "text":
			return buildStructuredArrayRecursive[string](ctx, fieldMetadata.Fields[0], srcValue, params)
		case "fixed":
			if fieldMetadata.Fields[0].Scale == 0 {
				return buildStructuredArrayRecursive[int64](ctx, fieldMetadata.Fields[0], srcValue, params)
			}
			return buildStructuredArrayRecursive[float64](ctx, fieldMetadata.Fields[0], srcValue, params)
		case "real":
			return buildStructuredArrayRecursive[float64](ctx, fieldMetadata.Fields[0], srcValue, params)
		case "boolean":
			return buildStructuredArrayRecursive[bool](ctx, fieldMetadata.Fields[0], srcValue, params)
		case "binary":
			return buildStructuredArrayRecursive[[]byte](ctx, fieldMetadata.Fields[0], srcValue, params)
		case "date", "time", "timestamp_ltz", "timestamp_ntz", "timestamp_tz":
			return buildStructuredArrayRecursive[time.Time](ctx, fieldMetadata.Fields[0], srcValue, params)
		}
	}
	return srcValue, nil
}

// buildStructuredArrayRecursive converts an array of arrays into [][]T by
// applying buildStructuredArray to each inner element.
func buildStructuredArrayRecursive[T any](ctx context.Context, fieldMetadata query.FieldMetadata, srcValue []any, params *syncParams) ([][]T, error) {
	arr := make([][]T, len(srcValue))
	for i, v := range srcValue {
		structuredArray, err := buildStructuredArray(ctx, fieldMetadata, v.([]any), params)
		if err != nil {
			return nil, err
		}
		arr[i] = structuredArray.([]T)
	}
	return arr, nil
}

// copyArrayAndConvert maps each element of input through convertFunc into []T,
// failing fast on the first conversion error.
func copyArrayAndConvert[T any](input []any, convertFunc func(input any) (T, error)) ([]T, error) {
	var err error
	output := make([]T, len(input))
	for i, s := range input {
		if output[i], err = convertFunc(s); err != nil {
			return nil, err
		}
	}
	return output, nil
}

// buildStructuredTypeRecursive converts nested array/map/object fields of a
// decoded object in place, then wraps the result in a *structuredType.
func buildStructuredTypeRecursive(ctx context.Context, m map[string]any, fields []query.FieldMetadata, params *syncParams) (*structuredType, error) {
	var err error
	for _, fm := range fields {
		if fm.Type == "array" && m[fm.Name] != nil {
			if m[fm.Name], err = buildStructuredArray(ctx, fm.Fields[0], m[fm.Name].([]any), params); err != nil {
				return nil, err
			}
		} else if fm.Type == "map" && m[fm.Name] != nil {
			if m[fm.Name], err = jsonToMapWithKeyType(ctx, fm.Fields[1], m[fm.Name].(map[string]any), params); err != nil {
				return nil, err
			}
		} else if fm.Type == "object" && m[fm.Name] != nil {
			if m[fm.Name], err = buildStructuredTypeRecursive(ctx, m[fm.Name].(map[string]any), fm.Fields, params); err != nil {
				return nil, err
			}
		}
	}
	return &structuredType{
		values:        m,
		fieldMetadata: fields,
		params:        params,
	}, nil
}

// decimalShift is 2^64, used to combine a decimal128's high and low 64-bit words.
var decimalShift = new(big.Int).Exp(big.NewInt(2), big.NewInt(64), nil)

// intToBigFloat returns val / 10^scale as a big.Float.
func intToBigFloat(val int64, scale int64) *big.Float {
	f := new(big.Float).SetInt64(val)
	s := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(scale), nil))
	return new(big.Float).Quo(f, s)
}

// decimalToBigInt reassembles a 128-bit decimal from its high/low words: high*2^64 + low.
func decimalToBigInt(num decimal128.Num) *big.Int {
	high := new(big.Int).SetInt64(num.HighBits())
	low := new(big.Int).SetUint64(num.LowBits())
	return new(big.Int).Add(new(big.Int).Mul(high, decimalShift), low)
}

// decimalToBigFloat returns the decimal128 value divided by 10^scale as a big.Float.
func decimalToBigFloat(num decimal128.Num, scale int64) *big.Float {
	f := new(big.Float).SetInt(decimalToBigInt(num))
	s := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(scale), nil))
	return new(big.Float).Quo(f, s)
}

// arrowSnowflakeTimestampToTime converts one timestamp cell of an Arrow column
// to *time.Time (nil for SQL NULL). NTZ/LTZ columns arrive either as a struct
// of (epoch, fraction) or as a single scaled int64; TZ columns arrive as a
// struct with a trailing timezone field (offset minutes biased by 1440).
func arrowSnowflakeTimestampToTime(
	column arrow.Array,
	sfType types.SnowflakeType,
	scale int,
	recIdx int,
	loc *time.Location) *time.Time {
	if column.IsNull(recIdx) {
		return nil
	}
	var ret time.Time
	switch sfType {
	case types.TimestampNtzType:
		if column.DataType().ID() == arrow.STRUCT {
			structData := column.(*array.Struct)
			epoch := structData.Field(0).(*array.Int64).Int64Values()
			fraction := structData.Field(1).(*array.Int32).Int32Values()
			ret = time.Unix(epoch[recIdx], int64(fraction[recIdx])).UTC()
		} else {
			intData := column.(*array.Int64)
			value := intData.Value(recIdx)
			epoch := extractEpoch(value, scale)
			fraction := extractFraction(value, scale)
			ret = time.Unix(epoch, fraction).UTC()
		}
	case types.TimestampLtzType:
		if column.DataType().ID() == arrow.STRUCT {
			structData := column.(*array.Struct)
			epoch := structData.Field(0).(*array.Int64).Int64Values()
			fraction := structData.Field(1).(*array.Int32).Int32Values()
			ret = time.Unix(epoch[recIdx], int64(fraction[recIdx])).In(loc)
		} else {
			intData := column.(*array.Int64)
			value := intData.Value(recIdx)
			epoch := extractEpoch(value, scale)
			fraction := extractFraction(value, scale)
			ret = time.Unix(epoch, fraction).In(loc)
		}
	case types.TimestampTzType:
		structData := column.(*array.Struct)
		if structData.NumField() == 2 {
			// Two fields: (scaled value, timezone).
			value := structData.Field(0).(*array.Int64).Int64Values()
			timezone := structData.Field(1).(*array.Int32).Int32Values()
			epoch := extractEpoch(value[recIdx], scale)
			fraction := extractFraction(value[recIdx], scale)
			locTz := Location(int(timezone[recIdx]) - 1440)
			ret = time.Unix(epoch, fraction).In(locTz)
		} else {
			// Three fields: (epoch, fraction, timezone).
			epoch := structData.Field(0).(*array.Int64).Int64Values()
			fraction := structData.Field(1).(*array.Int32).Int32Values()
			timezone := structData.Field(2).(*array.Int32).Int32Values()
			locTz := Location(int(timezone[recIdx]) - 1440)
			ret = time.Unix(epoch[recIdx], int64(fraction[recIdx])).In(locTz)
		}
	}
	return &ret
}

// extractEpoch splits the whole-seconds part out of a value scaled by 10^scale.
func extractEpoch(value int64, scale int) int64 {
	return value / int64(math.Pow10(scale))
}

// extractFraction converts the sub-second remainder of a scaled value to nanoseconds.
func extractFraction(value int64, scale int) int64 {
	return (value % int64(math.Pow10(scale))) * int64(math.Pow10(9-scale))
}

// Arrow Interface (Column) converter. This is called when Arrow chunks are
// downloaded to convert to the corresponding row type.
func arrowToValues(
	ctx context.Context,
	destcol []snowflakeValue,
	srcColumnMeta query.ExecResponseRowType,
	srcValue arrow.Array,
	loc *time.Location,
	higherPrecision bool,
	params *syncParams) error {
	if len(destcol) != srcValue.Len() {
		return fmt.Errorf("array interface length mismatch")
	}
	logger.Debugf("snowflake data type: %v, arrow data type: %v", srcColumnMeta.Type, srcValue.DataType())
	var err error
	snowflakeType := types.GetSnowflakeType(srcColumnMeta.Type)
	for i := range destcol {
		if destcol[i], err = arrowToValue(ctx, i, srcColumnMeta.ToFieldMetadata(), srcValue, loc, higherPrecision, params, snowflakeType); err != nil {
			return err
		}
	}
	return nil
}

func arrowToValue(ctx context.Context, rowIdx int, srcColumnMeta query.FieldMetadata, srcValue arrow.Array, loc *time.Location, higherPrecision bool, params *syncParams, snowflakeType types.SnowflakeType) (snowflakeValue, error) {
	structuredTypesEnabled := structuredTypesEnabled(ctx)
	switch snowflakeType {
	case types.FixedType:
		// Snowflake data types that are fixed-point numbers will fall into this category
		// e.g. NUMBER, DECIMAL/NUMERIC, INT/INTEGER
		switch numericValue := srcValue.(type) {
		case *array.Decimal128:
			return arrowDecimal128ToValue(numericValue, rowIdx, higherPrecision, srcColumnMeta), nil
		case *array.Int64:
			return arrowInt64ToValue(numericValue, rowIdx, higherPrecision, srcColumnMeta), nil
		case *array.Int32:
			return arrowInt32ToValue(numericValue, rowIdx, higherPrecision, srcColumnMeta), nil
		case *array.Int16:
			return arrowInt16ToValue(numericValue, rowIdx, higherPrecision, srcColumnMeta), nil
		case *array.Int8:
			return arrowInt8ToValue(numericValue, rowIdx, higherPrecision, srcColumnMeta), nil
		}
		return nil, fmt.Errorf("unsupported data type")
	case types.RealType:
		// Snowflake data types that are floating-point numbers will fall in this category
		// e.g. FLOAT/REAL/DOUBLE
		return arrowRealToValue(srcValue.(*array.Float64), rowIdx), nil
	case types.DecfloatType:
		return arrowDecFloatToValue(ctx, srcValue.(*array.Struct), rowIdx)
	case types.BooleanType:
		return arrowBoolToValue(srcValue.(*array.Boolean), rowIdx), nil
	case types.TextType, types.VariantType:
		strings := srcValue.(*array.String)
		if !srcValue.IsNull(rowIdx) {
			return strings.Value(rowIdx), nil
		}
		return nil, nil
	case types.ArrayType:
		if len(srcColumnMeta.Fields) == 0 || !structuredTypesEnabled {
			// semistructured type without schema
			strings := srcValue.(*array.String)
			if !srcValue.IsNull(rowIdx) {
				return strings.Value(rowIdx), nil
			}
			return nil, nil
		}
		strings, ok := srcValue.(*array.String)
		if ok {
			// structured array as json
			if !srcValue.IsNull(rowIdx) {
				val := strings.Value(rowIdx)
				var arr []any
				decoder := decoderWithNumbersAsStrings(&val)
				if err := decoder.Decode(&arr); err != nil {
					return nil, err
				}
				return buildStructuredArray(ctx, srcColumnMeta.Fields[0], arr, params)
			}
			return nil, nil
		}
		if !structuredTypesEnabled {
			return nil, errNativeArrowWithoutProperContext
		}
		return buildListFromNativeArrow(ctx, rowIdx, srcColumnMeta.Fields[0], srcValue, loc,
higherPrecision, params) case types.ObjectType: if len(srcColumnMeta.Fields) == 0 || !structuredTypesEnabled { // semistructured type without schema strings := srcValue.(*array.String) if !srcValue.IsNull(rowIdx) { return strings.Value(rowIdx), nil } return nil, nil } strings, ok := srcValue.(*array.String) if ok { // structured objects as json if !srcValue.IsNull(rowIdx) { m := make(map[string]any) value := strings.Value(rowIdx) decoder := decoderWithNumbersAsStrings(&value) if err := decoder.Decode(&m); err != nil { return nil, err } return buildStructuredTypeRecursive(ctx, m, srcColumnMeta.Fields, params) } return nil, nil } // structured objects as native arrow if !structuredTypesEnabled { return nil, errNativeArrowWithoutProperContext } if srcValue.IsNull(rowIdx) { return nil, nil } structs := srcValue.(*array.Struct) return arrowToStructuredType(ctx, structs, srcColumnMeta.Fields, loc, rowIdx, higherPrecision, params) case types.MapType: if srcValue.IsNull(rowIdx) { return nil, nil } strings, ok := srcValue.(*array.String) if ok { // structured map as json if !srcValue.IsNull(rowIdx) { return jsonToMap(ctx, srcColumnMeta.Fields[0], srcColumnMeta.Fields[1], strings.Value(rowIdx), params) } } else { // structured map as native arrow if !structuredTypesEnabled { return nil, errNativeArrowWithoutProperContext } return buildMapFromNativeArrow(ctx, rowIdx, srcColumnMeta.Fields[0], srcColumnMeta.Fields[1], srcValue, loc, higherPrecision, params) } case types.BinaryType: return arrowBinaryToValue(srcValue.(*array.Binary), rowIdx), nil case types.DateType: return arrowDateToValue(srcValue.(*array.Date32), rowIdx), nil case types.TimeType: return arrowTimeToValue(srcValue, rowIdx, int(srcColumnMeta.Scale)), nil case types.TimestampNtzType, types.TimestampLtzType, types.TimestampTzType: v := arrowSnowflakeTimestampToTime(srcValue, snowflakeType, int(srcColumnMeta.Scale), rowIdx, loc) if v != nil { return *v, nil } return nil, nil } return nil, fmt.Errorf("unsupported 
data type") } func buildMapFromNativeArrow(ctx context.Context, rowIdx int, keyMetadata, valueMetadata query.FieldMetadata, srcValue arrow.Array, loc *time.Location, higherPrecision bool, params *syncParams) (snowflakeValue, error) { arrowMap := srcValue.(*array.Map) if arrowMap.IsNull(rowIdx) { return nil, nil } keys := arrowMap.Keys() items := arrowMap.Items() offsets := arrowMap.Offsets() switch keyMetadata.Type { case "text": keyFunc := func(j int) (string, error) { return keys.(*array.String).Value(j), nil } return buildStructuredMapFromArrow(ctx, rowIdx, valueMetadata, offsets, keyFunc, items, higherPrecision, loc, params) case "fixed": keyFunc := func(j int) (int64, error) { k, err := extractInt64(keys, int(j)) if err != nil { return 0, err } return k, nil } return buildStructuredMapFromArrow(ctx, rowIdx, valueMetadata, offsets, keyFunc, items, higherPrecision, loc, params) } return nil, nil } func buildListFromNativeArrow(ctx context.Context, rowIdx int, fieldMetadata query.FieldMetadata, srcValue arrow.Array, loc *time.Location, higherPrecision bool, params *syncParams) (snowflakeValue, error) { list := srcValue.(*array.List) if list.IsNull(rowIdx) { return nil, nil } values := list.ListValues() offsets := list.Offsets() snowflakeType := types.GetSnowflakeType(fieldMetadata.Type) switch snowflakeType { case types.FixedType: switch typedValues := values.(type) { case *array.Decimal128: if higherPrecision && fieldMetadata.Scale == 0 { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (*big.Int, error) { bigInt := arrowDecimal128ToValue(typedValues, j, higherPrecision, fieldMetadata) if bigInt == nil { return nil, nil } return bigInt.(*big.Int), nil }) } else if higherPrecision && fieldMetadata.Scale != 0 { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (*big.Float, error) { bigFloat := arrowDecimal128ToValue(typedValues, j, higherPrecision, fieldMetadata) if bigFloat == nil { return nil, nil } return 
bigFloat.(*big.Float), nil }) } else if !higherPrecision && fieldMetadata.Scale == 0 { if embeddedValuesNullableEnabled(ctx) { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (sql.NullInt64, error) { v := arrowDecimal128ToValue(typedValues, j, higherPrecision, fieldMetadata) if v == nil { return sql.NullInt64{Valid: false}, nil } val, err := strconv.ParseInt(v.(string), 10, 64) if err != nil { return sql.NullInt64{Valid: false}, err } return sql.NullInt64{Valid: true, Int64: val}, nil }) } return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (int64, error) { v := arrowDecimal128ToValue(typedValues, j, higherPrecision, fieldMetadata) if v == nil { return 0, errors2.ErrNullValueInArrayError() } return strconv.ParseInt(v.(string), 10, 64) }) } else { if embeddedValuesNullableEnabled(ctx) { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (sql.NullFloat64, error) { v := arrowDecimal128ToValue(typedValues, j, higherPrecision, fieldMetadata) if v == nil { return sql.NullFloat64{Valid: false}, nil } val, err := strconv.ParseFloat(v.(string), 64) if err != nil { return sql.NullFloat64{Valid: false}, err } return sql.NullFloat64{Valid: true, Float64: val}, nil }) } return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (float64, error) { v := arrowDecimal128ToValue(typedValues, j, higherPrecision, fieldMetadata) if v == nil { return 0, errors2.ErrNullValueInArrayError() } return strconv.ParseFloat(v.(string), 64) }) } case *array.Int64: if embeddedValuesNullableEnabled(ctx) { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (sql.NullInt64, error) { resInt := arrowInt64ToValue(typedValues, j, higherPrecision, fieldMetadata) if resInt == nil { return sql.NullInt64{Valid: false}, nil } return sql.NullInt64{Valid: true, Int64: resInt.(int64)}, nil }) } return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (int64, error) { resInt := arrowInt64ToValue(typedValues, j, 
higherPrecision, fieldMetadata) if resInt == nil { return 0, errors2.ErrNullValueInArrayError() } return resInt.(int64), nil }) case *array.Int32: if embeddedValuesNullableEnabled(ctx) { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (sql.NullInt32, error) { resInt := arrowInt32ToValue(typedValues, j, higherPrecision, fieldMetadata) if resInt == nil { return sql.NullInt32{Valid: false}, nil } return sql.NullInt32{Valid: true, Int32: resInt.(int32)}, nil }) } return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (int32, error) { resInt := arrowInt32ToValue(typedValues, j, higherPrecision, fieldMetadata) if resInt == nil { return 0, errors2.ErrNullValueInArrayError() } return resInt.(int32), nil }) case *array.Int16: if embeddedValuesNullableEnabled(ctx) { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (sql.NullInt16, error) { resInt := arrowInt16ToValue(typedValues, j, higherPrecision, fieldMetadata) if resInt == nil { return sql.NullInt16{Valid: false}, nil } return sql.NullInt16{Valid: true, Int16: resInt.(int16)}, nil }) } return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (int16, error) { resInt := arrowInt16ToValue(typedValues, j, higherPrecision, fieldMetadata) if resInt == nil { return 0, errors2.ErrNullValueInArrayError() } return resInt.(int16), nil }) case *array.Int8: if embeddedValuesNullableEnabled(ctx) { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (sql.NullByte, error) { resInt := arrowInt8ToValue(typedValues, j, higherPrecision, fieldMetadata) if resInt == nil { return sql.NullByte{Valid: false}, nil } return sql.NullByte{Valid: true, Byte: resInt.(byte)}, nil }) } return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (int8, error) { resInt := arrowInt8ToValue(typedValues, j, higherPrecision, fieldMetadata) if resInt == nil { return 0, errors2.ErrNullValueInArrayError() } return resInt.(int8), nil }) } case types.RealType: if 
embeddedValuesNullableEnabled(ctx) { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (sql.NullFloat64, error) { resFloat := arrowRealToValue(values.(*array.Float64), j) if resFloat == nil { return sql.NullFloat64{Valid: false}, nil } return sql.NullFloat64{Valid: true, Float64: resFloat.(float64)}, nil }) } return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (float64, error) { resFloat := arrowRealToValue(values.(*array.Float64), j) if resFloat == nil { return 0, errors2.ErrNullValueInArrayError() } return resFloat.(float64), nil }) case types.TextType: if embeddedValuesNullableEnabled(ctx) { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (sql.NullString, error) { resString := arrowStringToValue(values.(*array.String), j) if resString == nil { return sql.NullString{Valid: false}, nil } return sql.NullString{Valid: true, String: resString.(string)}, nil }) } return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (string, error) { resString := arrowStringToValue(values.(*array.String), j) if resString == nil { return "", errors2.ErrNullValueInArrayError() } return resString.(string), nil }) case types.BooleanType: if embeddedValuesNullableEnabled(ctx) { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (sql.NullBool, error) { resBool := arrowBoolToValue(values.(*array.Boolean), j) if resBool == nil { return sql.NullBool{Valid: false}, nil } return sql.NullBool{Valid: true, Bool: resBool.(bool)}, nil }) } return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (bool, error) { resBool := arrowBoolToValue(values.(*array.Boolean), j) if resBool == nil { return false, errors2.ErrNullValueInArrayError() } return resBool.(bool), nil }) case types.BinaryType: return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) ([]byte, error) { res := arrowBinaryToValue(values.(*array.Binary), j) if res == nil { return nil, nil } return res.([]byte), nil }) case 
types.DateType: if embeddedValuesNullableEnabled(ctx) { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (sql.NullTime, error) { v := arrowDateToValue(values.(*array.Date32), j) if v == nil { return sql.NullTime{Valid: false}, nil } return sql.NullTime{Valid: true, Time: v.(time.Time)}, nil }) } return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (time.Time, error) { v := arrowDateToValue(values.(*array.Date32), j) if v == nil { return time.Time{}, errors2.ErrNullValueInArrayError() } return v.(time.Time), nil }) case types.TimeType: if embeddedValuesNullableEnabled(ctx) { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (sql.NullTime, error) { v := arrowTimeToValue(values, j, fieldMetadata.Scale) if v == nil { return sql.NullTime{Valid: false}, nil } return sql.NullTime{Valid: true, Time: v.(time.Time)}, nil }) } return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (time.Time, error) { v := arrowTimeToValue(values, j, fieldMetadata.Scale) if v == nil { return time.Time{}, errors2.ErrNullValueInArrayError() } return v.(time.Time), nil }) case types.TimestampNtzType, types.TimestampLtzType, types.TimestampTzType: if embeddedValuesNullableEnabled(ctx) { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (sql.NullTime, error) { ptr := arrowSnowflakeTimestampToTime(values, snowflakeType, fieldMetadata.Scale, j, loc) if ptr != nil { return sql.NullTime{Valid: true, Time: *ptr}, nil } return sql.NullTime{Valid: false}, nil }) } return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (time.Time, error) { ptr := arrowSnowflakeTimestampToTime(values, snowflakeType, fieldMetadata.Scale, j, loc) if ptr != nil { return *ptr, nil } return time.Time{}, errors2.ErrNullValueInArrayError() }) case types.ObjectType: return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) (*structuredType, error) { if values.IsNull(j) { return nil, nil } m := 
make(map[string]any, len(fieldMetadata.Fields)) for fieldIdx, field := range fieldMetadata.Fields { m[field.Name] = values.(*array.Struct).Field(fieldIdx).ValueStr(j) } return buildStructuredTypeRecursive(ctx, m, fieldMetadata.Fields, params) }) case types.ArrayType: switch fieldMetadata.Fields[0].Type { case "text": if embeddedValuesNullableEnabled(ctx) { return buildArrowListRecursive[sql.NullString](ctx, rowIdx, fieldMetadata, offsets, values, loc, higherPrecision, params) } return buildArrowListRecursive[string](ctx, rowIdx, fieldMetadata, offsets, values, loc, higherPrecision, params) case "fixed": if fieldMetadata.Fields[0].Scale == 0 { if embeddedValuesNullableEnabled(ctx) { return buildArrowListRecursive[sql.NullInt64](ctx, rowIdx, fieldMetadata, offsets, values, loc, higherPrecision, params) } return buildArrowListRecursive[int64](ctx, rowIdx, fieldMetadata, offsets, values, loc, higherPrecision, params) } if embeddedValuesNullableEnabled(ctx) { return buildArrowListRecursive[sql.NullFloat64](ctx, rowIdx, fieldMetadata, offsets, values, loc, higherPrecision, params) } return buildArrowListRecursive[float64](ctx, rowIdx, fieldMetadata, offsets, values, loc, higherPrecision, params) case "real": if embeddedValuesNullableEnabled(ctx) { return buildArrowListRecursive[sql.NullFloat64](ctx, rowIdx, fieldMetadata, offsets, values, loc, higherPrecision, params) } return buildArrowListRecursive[float64](ctx, rowIdx, fieldMetadata, offsets, values, loc, higherPrecision, params) case "boolean": if embeddedValuesNullableEnabled(ctx) { return buildArrowListRecursive[sql.NullBool](ctx, rowIdx, fieldMetadata, offsets, values, loc, higherPrecision, params) } return buildArrowListRecursive[bool](ctx, rowIdx, fieldMetadata, offsets, values, loc, higherPrecision, params) case "binary": return buildArrowListRecursive[[]byte](ctx, rowIdx, fieldMetadata, offsets, values, loc, higherPrecision, params) case "date", "time", "timestamp_ltz", "timestamp_ntz", "timestamp_tz": if 
embeddedValuesNullableEnabled(ctx) { return buildArrowListRecursive[sql.NullTime](ctx, rowIdx, fieldMetadata, offsets, values, loc, higherPrecision, params) } return buildArrowListRecursive[time.Time](ctx, rowIdx, fieldMetadata, offsets, values, loc, higherPrecision, params) } } return nil, nil } func buildArrowListRecursive[T any](ctx context.Context, rowIdx int, fieldMetadata query.FieldMetadata, offsets []int32, values arrow.Array, loc *time.Location, higherPrecision bool, params *syncParams) (snowflakeValue, error) { return mapStructuredArrayNativeArrowRows(offsets, rowIdx, func(j int) ([]T, error) { arrowList, err := buildListFromNativeArrow(ctx, j, fieldMetadata.Fields[0], values, loc, higherPrecision, params) if err != nil { return nil, err } if arrowList == nil { return nil, nil } return arrowList.([]T), nil }) } func mapStructuredArrayNativeArrowRows[T any](offsets []int32, rowIdx int, createValueFunc func(j int) (T, error)) (snowflakeValue, error) { arr := make([]T, offsets[rowIdx+1]-offsets[rowIdx]) for j := offsets[rowIdx]; j < offsets[rowIdx+1]; j++ { v, err := createValueFunc(int(j)) if err != nil { return nil, err } arr[j-offsets[rowIdx]] = v } return arr, nil } func extractInt64(values arrow.Array, j int) (int64, error) { switch typedValues := values.(type) { case *array.Decimal128: return int64(typedValues.Value(j).LowBits()), nil case *array.Int64: return typedValues.Value(j), nil case *array.Int32: return int64(typedValues.Value(j)), nil case *array.Int16: return int64(typedValues.Value(j)), nil case *array.Int8: return int64(typedValues.Value(j)), nil } return 0, fmt.Errorf("unsupported map type: %T", values.DataType().Name()) } func buildStructuredMapFromArrow[K comparable](ctx context.Context, rowIdx int, valueMetadata query.FieldMetadata, offsets []int32, keyFunc func(j int) (K, error), items arrow.Array, higherPrecision bool, loc *time.Location, params *syncParams) (snowflakeValue, error) { mapNullValuesEnabled := 
embeddedValuesNullableEnabled(ctx) switch valueMetadata.Type { case "text": if mapNullValuesEnabled { return mapStructuredMapNativeArrowRows(make(map[K]sql.NullString), offsets, rowIdx, keyFunc, func(j int) (sql.NullString, error) { if items.IsNull(j) { return sql.NullString{Valid: false}, nil } return sql.NullString{Valid: true, String: items.(*array.String).Value(j)}, nil }) } return mapStructuredMapNativeArrowRows(make(map[K]string), offsets, rowIdx, keyFunc, func(j int) (string, error) { if items.IsNull(j) { return "", errors2.ErrNullValueInMapError() } return items.(*array.String).Value(j), nil }) case "boolean": if mapNullValuesEnabled { return mapStructuredMapNativeArrowRows(make(map[K]sql.NullBool), offsets, rowIdx, keyFunc, func(j int) (sql.NullBool, error) { if items.IsNull(j) { return sql.NullBool{Valid: false}, nil } return sql.NullBool{Valid: true, Bool: items.(*array.Boolean).Value(j)}, nil }) } return mapStructuredMapNativeArrowRows(make(map[K]bool), offsets, rowIdx, keyFunc, func(j int) (bool, error) { if items.IsNull(j) { return false, errors2.ErrNullValueInMapError() } return items.(*array.Boolean).Value(j), nil }) case "fixed": if higherPrecision && valueMetadata.Scale == 0 { return mapStructuredMapNativeArrowRows(make(map[K]*big.Int), offsets, rowIdx, keyFunc, func(j int) (*big.Int, error) { if items.IsNull(j) { return nil, nil } return mapStructuredMapNativeArrowFixedValue[*big.Int](valueMetadata, j, items, higherPrecision, nil) }) } else if higherPrecision && valueMetadata.Scale != 0 { return mapStructuredMapNativeArrowRows(make(map[K]*big.Float), offsets, rowIdx, keyFunc, func(j int) (*big.Float, error) { if items.IsNull(j) { return nil, nil } return mapStructuredMapNativeArrowFixedValue[*big.Float](valueMetadata, j, items, higherPrecision, nil) }) } else if !higherPrecision && valueMetadata.Scale == 0 { if mapNullValuesEnabled { return mapStructuredMapNativeArrowRows(make(map[K]sql.NullInt64), offsets, rowIdx, keyFunc, func(j int) 
(sql.NullInt64, error) { if items.IsNull(j) { return sql.NullInt64{Valid: false}, nil } s, err := mapStructuredMapNativeArrowFixedValue[string](valueMetadata, j, items, higherPrecision, "") if err != nil { return sql.NullInt64{}, err } i64, err := strconv.ParseInt(s, 10, 64) return sql.NullInt64{Valid: true, Int64: i64}, err }) } return mapStructuredMapNativeArrowRows(make(map[K]int64), offsets, rowIdx, keyFunc, func(j int) (int64, error) { if items.IsNull(j) { return 0, errors2.ErrNullValueInMapError() } s, err := mapStructuredMapNativeArrowFixedValue[string](valueMetadata, j, items, higherPrecision, "") if err != nil { return 0, err } return strconv.ParseInt(s, 10, 64) }) } else { if mapNullValuesEnabled { return mapStructuredMapNativeArrowRows(make(map[K]sql.NullFloat64), offsets, rowIdx, keyFunc, func(j int) (sql.NullFloat64, error) { if items.IsNull(j) { return sql.NullFloat64{Valid: false}, nil } s, err := mapStructuredMapNativeArrowFixedValue[string](valueMetadata, j, items, higherPrecision, "") if err != nil { return sql.NullFloat64{}, err } f64, err := strconv.ParseFloat(s, 64) return sql.NullFloat64{Valid: true, Float64: f64}, err }) } return mapStructuredMapNativeArrowRows(make(map[K]float64), offsets, rowIdx, keyFunc, func(j int) (float64, error) { if items.IsNull(j) { return 0, errors2.ErrNullValueInMapError() } s, err := mapStructuredMapNativeArrowFixedValue[string](valueMetadata, j, items, higherPrecision, "") if err != nil { return 0, err } return strconv.ParseFloat(s, 64) }) } case "real": if mapNullValuesEnabled { return mapStructuredMapNativeArrowRows(make(map[K]sql.NullFloat64), offsets, rowIdx, keyFunc, func(j int) (sql.NullFloat64, error) { if items.IsNull(j) { return sql.NullFloat64{Valid: false}, nil } f64 := items.(*array.Float64).Value(j) return sql.NullFloat64{Valid: true, Float64: f64}, nil }) } return mapStructuredMapNativeArrowRows(make(map[K]float64), offsets, rowIdx, keyFunc, func(j int) (float64, error) { if items.IsNull(j) { return 
0, errors2.ErrNullValueInMapError() } return arrowRealToValue(items.(*array.Float64), j).(float64), nil }) case "binary": return mapStructuredMapNativeArrowRows(make(map[K][]byte), offsets, rowIdx, keyFunc, func(j int) ([]byte, error) { if items.IsNull(j) { return nil, nil } return arrowBinaryToValue(items.(*array.Binary), j).([]byte), nil }) case "date": return buildTimeFromNativeArrowArray(mapNullValuesEnabled, offsets, rowIdx, keyFunc, items, func(j int) time.Time { return arrowDateToValue(items.(*array.Date32), j).(time.Time) }) case "time": return buildTimeFromNativeArrowArray(mapNullValuesEnabled, offsets, rowIdx, keyFunc, items, func(j int) time.Time { return arrowTimeToValue(items, j, valueMetadata.Scale).(time.Time) }) case "timestamp_ltz", "timestamp_ntz", "timestamp_tz": return buildTimeFromNativeArrowArray(mapNullValuesEnabled, offsets, rowIdx, keyFunc, items, func(j int) time.Time { return *arrowSnowflakeTimestampToTime(items, types.GetSnowflakeType(valueMetadata.Type), valueMetadata.Scale, j, loc) }) case "object": return mapStructuredMapNativeArrowRows(make(map[K]*structuredType), offsets, rowIdx, keyFunc, func(j int) (*structuredType, error) { if items.IsNull(j) { return nil, nil } var err error m := make(map[string]any) for fieldIdx, field := range valueMetadata.Fields { snowflakeType := types.GetSnowflakeType(field.Type) m[field.Name], err = arrowToValue(ctx, j, field, items.(*array.Struct).Field(fieldIdx), loc, higherPrecision, params, snowflakeType) if err != nil { return nil, err } } return &structuredType{ values: m, fieldMetadata: valueMetadata.Fields, params: params, }, nil }) case "array": switch valueMetadata.Fields[0].Type { case "text": return buildListFromNativeArrowMap[K, string](ctx, rowIdx, valueMetadata, offsets, keyFunc, items, higherPrecision, loc, params) case "fixed": if valueMetadata.Fields[0].Scale == 0 { return buildListFromNativeArrowMap[K, int64](ctx, rowIdx, valueMetadata, offsets, keyFunc, items, higherPrecision, loc, 
params) } return buildListFromNativeArrowMap[K, float64](ctx, rowIdx, valueMetadata, offsets, keyFunc, items, higherPrecision, loc, params) case "real": return buildListFromNativeArrowMap[K, float64](ctx, rowIdx, valueMetadata, offsets, keyFunc, items, higherPrecision, loc, params) case "binary": return buildListFromNativeArrowMap[K, []byte](ctx, rowIdx, valueMetadata, offsets, keyFunc, items, higherPrecision, loc, params) case "boolean": return buildListFromNativeArrowMap[K, bool](ctx, rowIdx, valueMetadata, offsets, keyFunc, items, higherPrecision, loc, params) case "date", "time", "timestamp_ltz", "timestamp_ntz", "timestamp_tz": return buildListFromNativeArrowMap[K, time.Time](ctx, rowIdx, valueMetadata, offsets, keyFunc, items, higherPrecision, loc, params) } } return nil, errors.New("Unsupported map value: " + valueMetadata.Type) } func buildListFromNativeArrowMap[K comparable, V any](ctx context.Context, rowIdx int, valueMetadata query.FieldMetadata, offsets []int32, keyFunc func(j int) (K, error), items arrow.Array, higherPrecision bool, loc *time.Location, params *syncParams) (snowflakeValue, error) { return mapStructuredMapNativeArrowRows(make(map[K][]V), offsets, rowIdx, keyFunc, func(j int) ([]V, error) { if items.IsNull(j) { return nil, nil } list, err := buildListFromNativeArrow(ctx, j, valueMetadata.Fields[0], items, loc, higherPrecision, params) return list.([]V), err }) } func buildTimeFromNativeArrowArray[K comparable](mapNullValuesEnabled bool, offsets []int32, rowIdx int, keyFunc func(j int) (K, error), items arrow.Array, buildTime func(j int) time.Time) (snowflakeValue, error) { if mapNullValuesEnabled { return mapStructuredMapNativeArrowRows(make(map[K]sql.NullTime), offsets, rowIdx, keyFunc, func(j int) (sql.NullTime, error) { if items.IsNull(j) { return sql.NullTime{Valid: false}, nil } return sql.NullTime{Valid: true, Time: buildTime(j)}, nil }) } return mapStructuredMapNativeArrowRows(make(map[K]time.Time), offsets, rowIdx, keyFunc, func(j 
int) (time.Time, error) { if items.IsNull(j) { return time.Time{}, errors2.ErrNullValueInMapError() } return buildTime(j), nil }) } func mapStructuredMapNativeArrowFixedValue[V any](valueMetadata query.FieldMetadata, j int, items arrow.Array, higherPrecision bool, defaultValue V) (V, error) { v, err := extractNumberFromArrow(&items, j, higherPrecision, valueMetadata) if err != nil { return defaultValue, err } return v.(V), nil } func extractNumberFromArrow(values *arrow.Array, j int, higherPrecision bool, srcColumnMeta query.FieldMetadata) (snowflakeValue, error) { switch typedValues := (*values).(type) { case *array.Decimal128: return arrowDecimal128ToValue(typedValues, j, higherPrecision, srcColumnMeta), nil case *array.Int64: return arrowInt64ToValue(typedValues, j, higherPrecision, srcColumnMeta), nil case *array.Int32: return arrowInt32ToValue(typedValues, j, higherPrecision, srcColumnMeta), nil case *array.Int16: return arrowInt16ToValue(typedValues, j, higherPrecision, srcColumnMeta), nil case *array.Int8: return arrowInt8ToValue(typedValues, j, higherPrecision, srcColumnMeta), nil } return 0, fmt.Errorf("unknown number type: %T", values) } func mapStructuredMapNativeArrowRows[K comparable, V any](m map[K]V, offsets []int32, rowIdx int, keyFunc func(j int) (K, error), itemFunc func(j int) (V, error)) (map[K]V, error) { for j := offsets[rowIdx]; j < offsets[rowIdx+1]; j++ { k, err := keyFunc(int(j)) if err != nil { return nil, err } if m[k], err = itemFunc(int(j)); err != nil { return nil, err } } return m, nil } func arrowToStructuredType(ctx context.Context, structs *array.Struct, fieldMetadata []query.FieldMetadata, loc *time.Location, rowIdx int, higherPrecision bool, params *syncParams) (*structuredType, error) { var err error m := make(map[string]any) for colIdx := 0; colIdx < structs.NumField(); colIdx++ { var v any switch types.GetSnowflakeType(fieldMetadata[colIdx].Type) { case types.FixedType: v = structs.Field(colIdx).ValueStr(rowIdx) switch 
typedValues := structs.Field(colIdx).(type) { case *array.Decimal128: v = arrowDecimal128ToValue(typedValues, rowIdx, higherPrecision, fieldMetadata[colIdx]) case *array.Int64: v = arrowInt64ToValue(typedValues, rowIdx, higherPrecision, fieldMetadata[colIdx]) case *array.Int32: v = arrowInt32ToValue(typedValues, rowIdx, higherPrecision, fieldMetadata[colIdx]) case *array.Int16: v = arrowInt16ToValue(typedValues, rowIdx, higherPrecision, fieldMetadata[colIdx]) case *array.Int8: v = arrowInt8ToValue(typedValues, rowIdx, higherPrecision, fieldMetadata[colIdx]) } case types.BooleanType: v = arrowBoolToValue(structs.Field(colIdx).(*array.Boolean), rowIdx) case types.RealType: v = arrowRealToValue(structs.Field(colIdx).(*array.Float64), rowIdx) case types.BinaryType: v = arrowBinaryToValue(structs.Field(colIdx).(*array.Binary), rowIdx) case types.DateType: v = arrowDateToValue(structs.Field(colIdx).(*array.Date32), rowIdx) case types.TimeType: v = arrowTimeToValue(structs.Field(colIdx), rowIdx, fieldMetadata[colIdx].Scale) case types.TextType: v = arrowStringToValue(structs.Field(colIdx).(*array.String), rowIdx) case types.TimestampLtzType, types.TimestampTzType, types.TimestampNtzType: ptr := arrowSnowflakeTimestampToTime(structs.Field(colIdx), types.GetSnowflakeType(fieldMetadata[colIdx].Type), fieldMetadata[colIdx].Scale, rowIdx, loc) if ptr != nil { v = *ptr } case types.ObjectType: if !structs.Field(colIdx).IsNull(rowIdx) { if v, err = arrowToStructuredType(ctx, structs.Field(colIdx).(*array.Struct), fieldMetadata[colIdx].Fields, loc, rowIdx, higherPrecision, params); err != nil { return nil, err } } case types.ArrayType: if !structs.Field(colIdx).IsNull(rowIdx) { var err error if v, err = buildListFromNativeArrow(ctx, rowIdx, fieldMetadata[colIdx].Fields[0], structs.Field(colIdx), loc, higherPrecision, params); err != nil { return nil, err } } case types.MapType: if !structs.Field(colIdx).IsNull(rowIdx) { var err error if v, err = buildMapFromNativeArrow(ctx, 
rowIdx, fieldMetadata[colIdx].Fields[0], fieldMetadata[colIdx].Fields[1], structs.Field(colIdx), loc, higherPrecision, params); err != nil { return nil, err } } } m[fieldMetadata[colIdx].Name] = v } return &structuredType{ values: m, fieldMetadata: fieldMetadata, params: params, }, nil } func arrowStringToValue(srcValue *array.String, rowIdx int) snowflakeValue { if srcValue.IsNull(rowIdx) { return nil } return srcValue.Value(rowIdx) } func arrowDecimal128ToValue(srcValue *array.Decimal128, rowIdx int, higherPrecision bool, srcColumnMeta query.FieldMetadata) snowflakeValue { if !srcValue.IsNull(rowIdx) { num := srcValue.Value(rowIdx) if srcColumnMeta.Scale == 0 { if higherPrecision { return num.BigInt() } return num.ToString(0) } f := decimalToBigFloat(num, int64(srcColumnMeta.Scale)) if higherPrecision { return f } return fmt.Sprintf("%.*f", srcColumnMeta.Scale, f) } return nil } func arrowInt64ToValue(srcValue *array.Int64, rowIdx int, higherPrecision bool, srcColumnMeta query.FieldMetadata) snowflakeValue { if !srcValue.IsNull(rowIdx) { val := srcValue.Value(rowIdx) return arrowIntToValue(srcColumnMeta, higherPrecision, val) } return nil } func arrowInt32ToValue(srcValue *array.Int32, rowIdx int, higherPrecision bool, srcColumnMeta query.FieldMetadata) snowflakeValue { if !srcValue.IsNull(rowIdx) { val := srcValue.Value(rowIdx) return arrowIntToValue(srcColumnMeta, higherPrecision, int64(val)) } return nil } func arrowInt16ToValue(srcValue *array.Int16, rowIdx int, higherPrecision bool, srcColumnMeta query.FieldMetadata) snowflakeValue { if !srcValue.IsNull(rowIdx) { val := srcValue.Value(rowIdx) return arrowIntToValue(srcColumnMeta, higherPrecision, int64(val)) } return nil } func arrowInt8ToValue(srcValue *array.Int8, rowIdx int, higherPrecision bool, srcColumnMeta query.FieldMetadata) snowflakeValue { if !srcValue.IsNull(rowIdx) { val := srcValue.Value(rowIdx) return arrowIntToValue(srcColumnMeta, higherPrecision, int64(val)) } return nil } func 
arrowIntToValue(srcColumnMeta query.FieldMetadata, higherPrecision bool, val int64) snowflakeValue { if srcColumnMeta.Scale == 0 { if higherPrecision { if srcColumnMeta.Precision >= 19 { return big.NewInt(val) } return val } return fmt.Sprintf("%d", val) } if higherPrecision { f := intToBigFloat(val, int64(srcColumnMeta.Scale)) return f } return fmt.Sprintf("%.*f", srcColumnMeta.Scale, float64(val)/math.Pow10(srcColumnMeta.Scale)) } func arrowRealToValue(srcValue *array.Float64, rowIdx int) snowflakeValue { if !srcValue.IsNull(rowIdx) { return srcValue.Value(rowIdx) } return nil } func arrowDecFloatToValue(ctx context.Context, srcValue *array.Struct, rowIdx int) (snowflakeValue, error) { if !srcValue.IsNull(rowIdx) { exponent := srcValue.Field(0).(*array.Int16).Value(rowIdx) mantissaBytes := srcValue.Field(1).(*array.Binary).Value(rowIdx) mantissaInt, err := parseTwosComplementBigEndian(mantissaBytes) if err != nil { return nil, fmt.Errorf("failed to parse mantissa bytes: %s, error: %v", hex.EncodeToString(mantissaBytes), err) } if decfloatMappingEnabled(ctx) { mantissa := new(big.Float).SetPrec(127).SetInt(mantissaInt) if result, ok := new(big.Float).SetPrec(127).SetString(fmt.Sprintf("%ve%v", mantissa.Text('G', 38), exponent)); ok { return result, nil } return nil, fmt.Errorf("failed to create decfloat from mantissa %s and exponent %d", mantissa.Text('G', 38), exponent) } mantissaStr := mantissaInt.String() if mantissaStr == "0" { return "0", nil } negative := mantissaStr[0] == '-' mantissaUnsigned := strings.TrimLeft(mantissaStr, "-") mantissaLen := len(mantissaUnsigned) if mantissaLen > 1 { mantissaUnsigned = mantissaUnsigned[0:1] + "." 
+ mantissaUnsigned[1:] } if negative { mantissaStr = "-" + mantissaUnsigned } else { mantissaStr = mantissaUnsigned } exponent = exponent + int16(mantissaLen) - 1 result := mantissaStr if exponent != 0 { result = mantissaStr + "e" + strconv.Itoa(int(exponent)) } return result, nil } return nil, nil } func parseTwosComplementBigEndian(b []byte) (*big.Int, error) { if len(b) > 16 { return nil, fmt.Errorf("input byte slice is too long (max 16 bytes)") } val := new(big.Int) val.SetBytes(b) // big.Int.SetBytes treats the bytes as an unsigned magnitude // If the sign bit is 1, the number is negative. if b[0]&0x80 != 0 { // Calculate 2^(bit length) for subtraction bitLength := uint(len(b) * 8) powerOfTwo := new(big.Int).Exp(big.NewInt(2), big.NewInt(int64(bitLength)), nil) // Subtract 2^(bit length) from the unsigned value to get the signed value. val.Sub(val, powerOfTwo) } return val, nil } func arrowBoolToValue(srcValue *array.Boolean, rowIdx int) snowflakeValue { if !srcValue.IsNull(rowIdx) { return srcValue.Value(rowIdx) } return nil } func arrowBinaryToValue(srcValue *array.Binary, rowIdx int) snowflakeValue { if !srcValue.IsNull(rowIdx) { return srcValue.Value(rowIdx) } return nil } func arrowDateToValue(srcValue *array.Date32, rowID int) snowflakeValue { if !srcValue.IsNull(rowID) { return time.Unix(int64(srcValue.Value(rowID))*86400, 0).UTC() } return nil } func arrowTimeToValue(srcValue arrow.Array, rowIdx int, scale int) snowflakeValue { t0 := time.Time{} if srcValue.DataType().ID() == arrow.INT64 { if !srcValue.IsNull(rowIdx) { return t0.Add(time.Duration(srcValue.(*array.Int64).Value(rowIdx) * int64(math.Pow10(9-scale)))) } } else { if !srcValue.IsNull(rowIdx) { return t0.Add(time.Duration(int64(srcValue.(*array.Int32).Value(rowIdx)) * int64(math.Pow10(9-scale)))) } } return nil } type ( intArray []int int32Array []int32 int64Array []int64 float64Array []float64 float32Array []float32 decfloatArray []*big.Float boolArray []bool stringArray []string byteArray 
[][]byte timestampNtzArray []time.Time timestampLtzArray []time.Time timestampTzArray []time.Time dateArray []time.Time timeArray []time.Time ) // Array takes in a column of a row to be inserted via array binding, bulk or // otherwise, and converts it into a native snowflake type for binding func Array(a any, typ ...any) (any, error) { switch t := a.(type) { case []int: return (*intArray)(&t), nil case []int32: return (*int32Array)(&t), nil case []int64: return (*int64Array)(&t), nil case []float64: return (*float64Array)(&t), nil case []float32: return (*float32Array)(&t), nil case []*big.Float: if len(typ) == 1 { if b, ok := typ[0].([]byte); ok && bytes.Equal(b, DataTypeDecfloat) { return (*decfloatArray)(&t), nil } } return nil, errors.New("unsupported *big.Float array bind. Set the type to DataTypeDecfloat to use decfloatArray") case []bool: return (*boolArray)(&t), nil case []string: return (*stringArray)(&t), nil case [][]byte: return (*byteArray)(&t), nil case []time.Time: if len(typ) < 1 { return nil, errUnsupportedTimeArrayBind } switch typ[0] { case TimestampNTZType: return (*timestampNtzArray)(&t), nil case TimestampLTZType: return (*timestampLtzArray)(&t), nil case TimestampTZType: return (*timestampTzArray)(&t), nil case DateType: return (*dateArray)(&t), nil case TimeType: return (*timeArray)(&t), nil default: return nil, errUnsupportedTimeArrayBind } case *[]int: return (*intArray)(t), nil case *[]int32: return (*int32Array)(t), nil case *[]int64: return (*int64Array)(t), nil case *[]float64: return (*float64Array)(t), nil case *[]float32: return (*float32Array)(t), nil case *[]*big.Float: if len(typ) == 1 { if b, ok := typ[0].([]byte); ok && bytes.Equal(b, DataTypeDecfloat) { return (*decfloatArray)(t), nil } } return nil, errors.New("unsupported *big.Float array bind. 
Set the type to DataTypeDecfloat to use decfloatArray") case *[]bool: return (*boolArray)(t), nil case *[]string: return (*stringArray)(t), nil case *[][]byte: return (*byteArray)(t), nil case *[]time.Time: if len(typ) < 1 { return nil, errUnsupportedTimeArrayBind } switch typ[0] { case TimestampNTZType: return (*timestampNtzArray)(t), nil case TimestampLTZType: return (*timestampLtzArray)(t), nil case TimestampTZType: return (*timestampTzArray)(t), nil case DateType: return (*dateArray)(t), nil case TimeType: return (*timeArray)(t), nil default: return nil, errUnsupportedTimeArrayBind } case []any, *[]any: // Support for bulk array binding insertion using []any / *[]any if len(typ) < 1 { return interfaceArrayBinding{ hasTimezone: false, timezoneTypeArray: a, }, nil } return interfaceArrayBinding{ hasTimezone: true, tzType: typ[0].(timezoneType), timezoneTypeArray: a, }, nil default: return nil, fmt.Errorf("unknown array type for binding: %T", a) } } // snowflakeArrayToString converts the array binding to snowflake's native // string type. The string value differs whether it's directly bound or // uploaded via stream. 
func snowflakeArrayToString(nv *driver.NamedValue, stream bool) (types.SnowflakeType, []*string, error) {
	var t types.SnowflakeType
	var arr []*string
	// Dispatch on the concrete named-slice type produced by Array(); each case
	// sets the Snowflake type tag and renders every element as a string.
	switch reflect.TypeOf(nv.Value) {
	case reflect.TypeFor[*intArray]():
		t = types.FixedType
		a := nv.Value.(*intArray)
		for _, x := range *a {
			v := strconv.Itoa(x)
			arr = append(arr, &v)
		}
	case reflect.TypeFor[*int64Array]():
		t = types.FixedType
		a := nv.Value.(*int64Array)
		for _, x := range *a {
			v := strconv.FormatInt(x, 10)
			arr = append(arr, &v)
		}
	case reflect.TypeFor[*int32Array]():
		t = types.FixedType
		a := nv.Value.(*int32Array)
		for _, x := range *a {
			v := strconv.Itoa(int(x))
			arr = append(arr, &v)
		}
	case reflect.TypeFor[*float64Array]():
		t = types.RealType
		a := nv.Value.(*float64Array)
		for _, x := range *a {
			v := fmt.Sprintf("%g", x)
			arr = append(arr, &v)
		}
	case reflect.TypeFor[*float32Array]():
		t = types.RealType
		a := nv.Value.(*float32Array)
		for _, x := range *a {
			v := fmt.Sprintf("%g", x)
			arr = append(arr, &v)
		}
	case reflect.TypeFor[*decfloatArray]():
		t = types.TextType
		a := nv.Value.(*decfloatArray)
		for _, x := range *a {
			v := x.Text('g', decfloatPrintingPrec)
			arr = append(arr, &v)
		}
	case reflect.TypeFor[*boolArray]():
		t = types.BooleanType
		a := nv.Value.(*boolArray)
		for _, x := range *a {
			v := strconv.FormatBool(x)
			arr = append(arr, &v)
		}
	case reflect.TypeFor[*stringArray]():
		t = types.TextType
		a := nv.Value.(*stringArray)
		for _, x := range *a {
			v := x // necessary for address to be not overwritten
			arr = append(arr, &v)
		}
	case reflect.TypeFor[*byteArray]():
		t = types.BinaryType
		a := nv.Value.(*byteArray)
		for _, x := range *a {
			v := hex.EncodeToString(x)
			arr = append(arr, &v)
		}
	case reflect.TypeFor[*timestampNtzArray]():
		t = types.TimestampNtzType
		a := nv.Value.(*timestampNtzArray)
		for _, x := range *a {
			v, err := getTimestampBindValue(x, stream, t)
			if err != nil {
				return types.UnSupportedType, nil, err
			}
			arr = append(arr, &v)
		}
	case reflect.TypeFor[*timestampLtzArray]():
		t = types.TimestampLtzType
		a := nv.Value.(*timestampLtzArray)
		for _, x := range *a {
			v, err := getTimestampBindValue(x, stream, t)
			if err != nil {
				return types.UnSupportedType, nil, err
			}
			arr = append(arr, &v)
		}
	case reflect.TypeFor[*timestampTzArray]():
		t = types.TimestampTzType
		a := nv.Value.(*timestampTzArray)
		for _, x := range *a {
			v, err := getTimestampBindValue(x, stream, t)
			if err != nil {
				return types.UnSupportedType, nil, err
			}
			arr = append(arr, &v)
		}
	case reflect.TypeFor[*dateArray]():
		t = types.DateType
		a := nv.Value.(*dateArray)
		for _, x := range *a {
			var v string
			if stream {
				v = x.Format("2006-01-02")
			} else {
				// Direct bind: shift to local wall-clock and send epoch milliseconds.
				_, offset := x.Zone()
				x = x.Add(time.Second * time.Duration(offset))
				v = fmt.Sprintf("%d", x.Unix()*1000)
			}
			arr = append(arr, &v)
		}
	case reflect.TypeFor[*timeArray]():
		t = types.TimeType
		a := nv.Value.(*timeArray)
		for _, x := range *a {
			var v string
			if stream {
				v = fmt.Sprintf("%02d:%02d:%02d.%09d", x.Hour(), x.Minute(), x.Second(), x.Nanosecond())
			} else {
				// Direct bind: nanoseconds since midnight.
				h, m, s := x.Clock()
				tm := int64(h)*int64(time.Hour) + int64(m)*int64(time.Minute) + int64(s)*int64(time.Second) + int64(x.Nanosecond())
				v = strconv.FormatInt(tm, 10)
			}
			arr = append(arr, &v)
		}
	default:
		// Support for bulk array binding insertion using []any / *[]any
		nvValue := reflect.ValueOf(nv)
		if nvValue.Kind() == reflect.Pointer {
			value := reflect.Indirect(reflect.ValueOf(nv.Value))
			if isInterfaceArrayBinding(value.Interface()) {
				timeStruct, ok := value.Interface().(interfaceArrayBinding)
				if ok {
					timeInterfaceSlice := reflect.Indirect(reflect.ValueOf(timeStruct.timezoneTypeArray))
					if timeStruct.hasTimezone {
						return interfaceSliceToString(timeInterfaceSlice, stream, timeStruct.tzType)
					}
					return interfaceSliceToString(timeInterfaceSlice, stream)
				}
			}
		}
		return types.UnSupportedType, nil, nil
	}
	return t, arr, nil
}

func interfaceSliceToString(interfaceSlice reflect.Value, stream bool, tzType ...timezoneType) (types.SnowflakeType, []*string, error) {
	var t types.SnowflakeType
	var arr []*string
	for i := 0; i <
interfaceSlice.Len(); i++ {
		val := interfaceSlice.Index(i)
		if val.CanInterface() {
			v := val.Interface()
			// Render each element as a string and set the Snowflake type tag
			// from the element's Go type.
			switch x := v.(type) {
			case int:
				t = types.FixedType
				v := strconv.Itoa(x)
				arr = append(arr, &v)
			case int32:
				t = types.FixedType
				v := strconv.Itoa(int(x))
				arr = append(arr, &v)
			case int64:
				t = types.FixedType
				v := strconv.FormatInt(x, 10)
				arr = append(arr, &v)
			case float32:
				t = types.RealType
				v := fmt.Sprintf("%g", x)
				arr = append(arr, &v)
			case float64:
				t = types.RealType
				v := fmt.Sprintf("%g", x)
				arr = append(arr, &v)
			case bool:
				t = types.BooleanType
				v := strconv.FormatBool(x)
				arr = append(arr, &v)
			case string:
				t = types.TextType
				arr = append(arr, &x)
			case []byte:
				t = types.BinaryType
				v := hex.EncodeToString(x)
				arr = append(arr, &v)
			case time.Time:
				// time.Time needs the caller-supplied timezone type to pick the format.
				if len(tzType) < 1 {
					return types.UnSupportedType, nil, nil
				}
				switch tzType[0] {
				case TimestampNTZType:
					t = types.TimestampNtzType
					v, err := getTimestampBindValue(x, stream, t)
					if err != nil {
						return types.UnSupportedType, nil, err
					}
					arr = append(arr, &v)
				case TimestampLTZType:
					t = types.TimestampLtzType
					v, err := getTimestampBindValue(x, stream, t)
					if err != nil {
						return types.UnSupportedType, nil, err
					}
					arr = append(arr, &v)
				case TimestampTZType:
					t = types.TimestampTzType
					v, err := getTimestampBindValue(x, stream, t)
					if err != nil {
						return types.UnSupportedType, nil, err
					}
					arr = append(arr, &v)
				case DateType:
					t = types.DateType
					_, offset := x.Zone()
					x = x.Add(time.Second * time.Duration(offset))
					v := fmt.Sprintf("%d", x.Unix()*1000)
					arr = append(arr, &v)
				case TimeType:
					t = types.TimeType
					var v string
					if stream {
						v = x.Format(format[11:19])
					} else {
						h, m, s := x.Clock()
						tm := int64(h)*int64(time.Hour) + int64(m)*int64(time.Minute) + int64(s)*int64(time.Second) + int64(x.Nanosecond())
						v = strconv.FormatInt(tm, 10)
					}
					arr = append(arr, &v)
				default:
					return types.UnSupportedType, nil, nil
				}
			case driver.Valuer: // honor each driver's Valuer interface
				if value, err := x.Value(); err == nil && value != nil {
					// if the output value is a valid string, return that
					if strVal, ok := value.(string); ok {
						t = types.TextType
						arr = append(arr, &strVal)
					}
				} else if v != nil {
					return types.UnSupportedType, nil, nil
				} else {
					arr = append(arr, nil)
				}
			default:
				if val.Interface() != nil {
					// UUID-like values are bound via their Stringer form.
					if isUUIDImplementer(val) {
						t = types.TextType
						x := v.(fmt.Stringer).String()
						arr = append(arr, &x)
						continue
					}
					return types.UnSupportedType, nil, nil
				}
				arr = append(arr, nil)
			}
		}
	}
	return t, arr, nil
}

// higherPrecisionEnabled reports whether higher-precision number mapping is
// enabled on the context.
func higherPrecisionEnabled(ctx context.Context) bool {
	return ia.HigherPrecisionEnabled(ctx)
}

// decfloatMappingEnabled reports whether the enableDecfloat context flag is
// present and set to true.
func decfloatMappingEnabled(ctx context.Context) bool {
	v := ctx.Value(enableDecfloat)
	if v == nil {
		return false
	}
	d, ok := v.(bool)
	return ok && d
}

// TypedNullTime is required to properly bind the null value with the snowflakeType as the Snowflake functions
// require the type of the field to be provided explicitly for the null values
type TypedNullTime struct {
	Time   sql.NullTime
	TzType timezoneType
}

// convertTzTypeToSnowflakeType maps a driver timezoneType to the wire-level
// Snowflake type; unknown values map to UnSupportedType.
func convertTzTypeToSnowflakeType(tzType timezoneType) types.SnowflakeType {
	switch tzType {
	case TimestampNTZType:
		return types.TimestampNtzType
	case TimestampLTZType:
		return types.TimestampLtzType
	case TimestampTZType:
		return types.TimestampTzType
	case DateType:
		return types.DateType
	case TimeType:
		return types.TimeType
	}
	return types.UnSupportedType
}

// getTimestampBindValue formats x for binding: the driver's timestamp layout
// when streaming, otherwise the numeric epoch form from convertTimeToTimeStamp.
func getTimestampBindValue(x time.Time, stream bool, t types.SnowflakeType) (string, error) {
	if stream {
		return x.Format(format), nil
	}
	return convertTimeToTimeStamp(x, t)
}

// convertTimeToTimeStamp renders x as epoch nanoseconds (computed in big.Int to
// avoid overflow); for TIMESTAMP_TZ it appends the zone offset in minutes + 1440.
func convertTimeToTimeStamp(x time.Time, t types.SnowflakeType) (string, error) {
	unixTime, _ := new(big.Int).SetString(fmt.Sprintf("%d", x.Unix()), 10)
	m, ok := new(big.Int).SetString(strconv.FormatInt(1e9, 10), 10)
	if !ok {
		return "", errors.New("failed to parse big int from string: invalid format or unsupported characters")
	}
	unixTime.Mul(unixTime, m)
	tmNanos, _ := new(big.Int).SetString(fmt.Sprintf("%d", x.Nanosecond()), 10)
	if t == types.TimestampTzType {
		_, offset := x.Zone()
		return fmt.Sprintf("%v %v", unixTime.Add(unixTime, tmNanos), offset/60+1440), nil
	}
	return unixTime.Add(unixTime, tmNanos).String(), nil
}

// decoderWithNumbersAsStrings returns a JSON decoder over srcValue that keeps
// numbers as json.Number to avoid float64 precision loss.
func decoderWithNumbersAsStrings(srcValue *string) *json.Decoder {
	decoder := json.NewDecoder(bytes.NewBufferString(*srcValue))
	decoder.UseNumber()
	return decoder
}

================================================
FILE: converter_test.go
================================================
package gosnowflake

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"github.com/snowflakedb/gosnowflake/v2/internal/query"
	"github.com/snowflakedb/gosnowflake/v2/internal/types"
	"io"
	"math"
	"math/big"
	"math/cmplx"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/decimal128"
	"github.com/apache/arrow-go/v18/arrow/memory"
)

// stringIntToDecimal converts a decimal integer string to an Arrow decimal128.
func stringIntToDecimal(src string) (decimal128.Num, bool) {
	b, ok := new(big.Int).SetString(src, 10)
	if !ok {
		return decimal128.Num{}, ok
	}
	var high, low big.Int
	high.QuoRem(b, decimalShift, &low)
	return decimal128.New(high.Int64(), low.Uint64()), ok
}

// stringFloatToDecimal scales a decimal string by 10^scale and converts the
// resulting integer to an Arrow decimal128; fails if the scaled value is not integral.
func stringFloatToDecimal(src string, scale int64) (decimal128.Num, bool) {
	b, ok := new(big.Float).SetString(src)
	if !ok {
		return decimal128.Num{}, ok
	}
	s := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(scale), nil))
	n := new(big.Float).Mul(b, s)
	if !n.IsInt() {
		return decimal128.Num{}, false
	}
	var high, low, z big.Int
	n.Int(&z)
	high.QuoRem(&z, decimalShift, &low)
	return decimal128.New(high.Int64(), low.Uint64()), ok
}

// stringFloatToInt scales a decimal string by 10^scale and returns it as int64;
// fails if the result does not fit.
func stringFloatToInt(src string, scale int64) (int64, bool) {
	b, ok := new(big.Float).SetString(src)
	if !ok {
		return 0, ok
	}
	s := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(scale), nil))
	n := new(big.Float).Mul(b, s)
	var z big.Int
	n.Int(&z)
	if !z.IsInt64() {
		return 0, false
	}
	return z.Int64(), true
}

// testValueToStringStructuredObject is a fixture implementing structured-object
// writing for TestValueToString.
type testValueToStringStructuredObject struct {
	s    string
	i    int32
	date time.Time
}

func (o *testValueToStringStructuredObject) Write(sowc StructuredObjectWriterContext) error {
	if err := sowc.WriteString("s", o.s); err != nil {
		return err
	}
	if err := sowc.WriteInt32("i", o.i); err != nil {
		return err
	}
	if err := sowc.WriteTime("date", o.date, DataTypeDate); err != nil {
		return err
	}
	return nil
}

func TestValueToString(t *testing.T) {
	v := cmplx.Sqrt(-5 + 12i) // should never happen as Go sql package must have already validated.
	_, err := valueToString(v, types.NullType, nil)
	if err == nil {
		t.Errorf("should raise error: %v", v)
	}

	params := newSyncParams(make(map[string]*string))
	dateFormat := "YYYY-MM-DD"
	params.set("date_output_format", &dateFormat)

	// both localTime and utcTime should yield the same unix timestamp
	localTime := time.Date(2019, 2, 6, 14, 17, 31, 123456789, time.FixedZone("-08:00", -8*3600))
	utcTime := time.Date(2019, 2, 6, 22, 17, 31, 123456789, time.UTC)
	expectedUnixTime := "1549491451123456789" // time.Unix(1549491451, 123456789).Format(time.RFC3339) == "2019-02-06T14:17:31-08:00"
	expectedBool := "true"
	expectedInt64 := "1"
	expectedFloat64 := "1.1"
	expectedString := "teststring"

	bv, err := valueToString(localTime, types.TimestampLtzType, nil)
	assertNilF(t, err)
	assertEmptyStringE(t, bv.format)
	assertNilE(t, bv.schema)
	assertEqualE(t, *bv.value, expectedUnixTime)

	bv, err = valueToString(utcTime, types.TimestampLtzType, nil)
	assertNilF(t, err)
	assertEmptyStringE(t, bv.format)
	assertNilE(t, bv.schema)
	assertEqualE(t, *bv.value, expectedUnixTime)

	bv, err = valueToString(sql.NullBool{Bool: true, Valid: true}, types.TimestampLtzType, nil)
	assertNilF(t, err)
	assertEmptyStringE(t, bv.format)
	assertNilE(t, bv.schema)
	assertEqualE(t, *bv.value, expectedBool)

	bv, err = valueToString(sql.NullInt64{Int64: 1, Valid: true}, types.TimestampLtzType, nil)
	assertNilF(t, err)
	assertEmptyStringE(t, bv.format)
	assertNilE(t, bv.schema)
	assertEqualE(t, *bv.value, expectedInt64)

	bv, err = valueToString(sql.NullFloat64{Float64: 1.1, Valid: true},
		types.TimestampLtzType, nil)
	assertNilF(t, err)
	assertEmptyStringE(t, bv.format)
	assertNilE(t, bv.schema)
	assertEqualE(t, *bv.value, expectedFloat64)

	bv, err = valueToString(sql.NullString{String: "teststring", Valid: true}, types.TimestampLtzType, nil)
	assertNilF(t, err)
	assertEmptyStringE(t, bv.format)
	assertNilE(t, bv.schema)
	assertEqualE(t, *bv.value, expectedString)

	t.Run("SQL Time", func(t *testing.T) {
		bv, err := valueToString(sql.NullTime{Time: localTime, Valid: true}, types.TimestampLtzType, nil)
		assertNilF(t, err)
		assertEmptyStringE(t, bv.format)
		assertNilE(t, bv.schema)
		assertEqualE(t, *bv.value, expectedUnixTime)
	})

	t.Run("arrays", func(t *testing.T) {
		bv, err := valueToString([2]int{1, 2}, types.ObjectType, nil)
		assertNilF(t, err)
		assertEqualE(t, bv.format, jsonFormatStr)
		assertEqualE(t, *bv.value, "[1,2]")
	})

	t.Run("slices", func(t *testing.T) {
		bv, err := valueToString([]int{1, 2}, types.ObjectType, nil)
		assertNilF(t, err)
		assertEqualE(t, bv.format, jsonFormatStr)
		assertEqualE(t, *bv.value, "[1,2]")
	})

	t.Run("UUID - should return string", func(t *testing.T) {
		u := NewUUID()
		bv, err := valueToString(u, types.TextType, nil)
		assertNilF(t, err)
		assertEmptyStringE(t, bv.format)
		assertEqualE(t, *bv.value, u.String())
	})

	t.Run("database/sql/driver - Valuer interface", func(t *testing.T) {
		u := newTestUUID()
		bv, err := valueToString(u, types.TextType, nil)
		assertNilF(t, err)
		assertEmptyStringE(t, bv.format)
		assertEqualE(t, *bv.value, u.String())
	})

	t.Run("testUUID", func(t *testing.T) {
		u := newTestUUID()
		assertEqualE(t, u.String(), parseTestUUID(u.String()).String())
		bv, err := valueToString(u, types.TextType, nil)
		assertNilF(t, err)
		assertEmptyStringE(t, bv.format)
		assertEqualE(t, *bv.value, u.String())
	})

	// Structured object binds as JSON with a binding schema describing each field.
	bv, err = valueToString(&testValueToStringStructuredObject{s: "some string", i: 123, date: time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC)}, types.TimestampLtzType, &params)
	assertNilF(t, err)
	assertEqualE(t, bv.format, jsonFormatStr)
	assertDeepEqualE(t,
		*bv.schema, bindingSchema{
			Typ:      "object",
			Nullable: true,
			Fields: []query.FieldMetadata{
				{
					Name:     "s",
					Type:     "text",
					Nullable: true,
					Length:   134217728,
				},
				{
					Name:      "i",
					Type:      "fixed",
					Nullable:  true,
					Precision: 38,
					Scale:     0,
				},
				{
					Name:     "date",
					Type:     "date",
					Nullable: true,
					Scale:    9,
				},
			},
		})
	assertEqualIgnoringWhitespaceE(t, *bv.value, `{"date": "2024-05-24", "i": 123, "s": "some string"}`)
}

func TestExtractTimestamp(t *testing.T) {
	// Malformed timestamp strings must be rejected.
	s := "1234abcdef" // pragma: allowlist secret
	_, _, err := extractTimestamp(&s)
	if err == nil {
		t.Errorf("should raise error: %v", s)
	}
	s = "1234abc.def"
	_, _, err = extractTimestamp(&s)
	if err == nil {
		t.Errorf("should raise error: %v", s)
	}
	s = "1234.def"
	_, _, err = extractTimestamp(&s)
	if err == nil {
		t.Errorf("should raise error: %v", s)
	}
}

func TestStringToValue(t *testing.T) {
	var source string
	var dest driver.Value
	var err error
	var rowType *query.ExecResponseRowType
	source = "abcdefg"

	// Non-parsable source must fail for every temporal/binary type.
	types := []string{
		"date", "time", "timestamp_ntz", "timestamp_ltz", "timestamp_tz", "binary",
	}

	for _, tt := range types {
		t.Run(tt, func(t *testing.T) {
			rowType = &query.ExecResponseRowType{
				Type: tt,
			}
			if err = stringToValue(context.Background(), &dest, *rowType, &source, nil, nil); err == nil {
				t.Errorf("should raise error. type: %v, value:%v", tt, source)
			}
		})
	}

	sources := []string{
		"12345K78 2020",
		"12345678 20T0",
	}

	types = []string{
		"timestamp_tz",
	}

	for _, ss := range sources {
		for _, tt := range types {
			t.Run(ss+tt, func(t *testing.T) {
				rowType = &query.ExecResponseRowType{
					Type: tt,
				}
				if err = stringToValue(context.Background(), &dest, *rowType, &ss, nil, nil); err == nil {
					t.Errorf("should raise error. type: %v, value:%v", tt, source)
				}
			})
		}
	}

	// A well-formed epoch.nanos string converts to the expected time.Time.
	src := "1549491451.123456789"
	if err = stringToValue(context.Background(), &dest, query.ExecResponseRowType{Type: "timestamp_ltz"}, &src, nil, nil); err != nil {
		t.Errorf("unexpected error: %v", err)
	} else if ts, ok := dest.(time.Time); !ok {
		t.Errorf("expected type: 'time.Time', got '%v'", reflect.TypeOf(dest))
	} else if ts.UnixNano() != 1549491451123456789 {
		t.Errorf("expected unix timestamp: 1549491451123456789, got %v", ts.UnixNano())
	}
}

// tcArrayToString is one case for TestArrayToString.
type tcArrayToString struct {
	in  driver.NamedValue
	typ types.SnowflakeType
	out []string
}

func TestArrayToString(t *testing.T) {
	testcases := []tcArrayToString{
		{in: driver.NamedValue{Value: &intArray{1, 2}}, typ: types.FixedType, out: []string{"1", "2"}},
		{in: driver.NamedValue{Value: &int32Array{1, 2}}, typ: types.FixedType, out: []string{"1", "2"}},
		{in: driver.NamedValue{Value: &int64Array{3, 4, 5}}, typ: types.FixedType, out: []string{"3", "4", "5"}},
		{in: driver.NamedValue{Value: &float64Array{6.7}}, typ: types.RealType, out: []string{"6.7"}},
		{in: driver.NamedValue{Value: &float32Array{1.5}}, typ: types.RealType, out: []string{"1.5"}},
		{in: driver.NamedValue{Value: &boolArray{true, false}}, typ: types.BooleanType, out: []string{"true", "false"}},
		{in: driver.NamedValue{Value: &stringArray{"foo", "bar", "baz"}}, typ: types.TextType, out: []string{"foo", "bar", "baz"}},
	}
	for _, test := range testcases {
		t.Run(strings.Join(test.out, "_"), func(t *testing.T) {
			s, a, err := snowflakeArrayToString(&test.in, false)
			assertNilF(t, err)
			if s != test.typ {
				t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.typ, s)
			}
			for i, v := range a {
				if *v != test.out[i] {
					t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.out[i], a)
				}
			}
		})
	}
}

func TestArrowToValues(t *testing.T) {
	dest := make([]snowflakeValue, 2)

	pool := memory.NewCheckedAllocator(memory.NewGoAllocator())
	defer pool.AssertSize(t, 0)

	var valids []bool // AppendValues() with an empty valid array adds every value by default

	localTime := time.Date(2019, 2, 6, 14, 17, 31, 123456789, time.FixedZone("-08:00", -8*3600))

	field1 := arrow.Field{Name: "epoch", Type: &arrow.Int64Type{}}
	field2 := arrow.Field{Name: "timezone", Type: &arrow.Int32Type{}}
	tzStruct := arrow.StructOf(field1, field2)

	type testObj struct {
		field1 int
		field2 string
	}

	for _, tc := range []struct {
		logical         string
		physical        string
		rowType         query.ExecResponseRowType
		values          any
		builder         array.Builder
		append          func(b array.Builder, vs any)
		compare         func(src any, dst []snowflakeValue) int
		higherPrecision bool
	}{
		{
			logical:         "fixed",
			physical:        "number", // default: number(38, 0)
			values:          []int64{1, 2},
			builder:         array.NewInt64Builder(pool),
			append:          func(b array.Builder, vs any) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) },
			higherPrecision: true,
		},
		{
			logical:  "fixed",
			physical: "number(38,5)",
			rowType:  query.ExecResponseRowType{Scale: 5},
			values:   []string{"1.05430", "2.08983"},
			builder:  array.NewInt64Builder(pool),
			append: func(b array.Builder, vs any) {
				for _, s := range vs.([]string) {
					num, ok := stringFloatToInt(s, 5)
					if !ok {
						t.Fatalf("failed to convert to int")
					}
					b.(*array.Int64Builder).Append(num)
				}
			},
			compare: func(src any, dst []snowflakeValue) int {
				srcvs := src.([]string)
				for i := range srcvs {
					num, ok := stringFloatToInt(srcvs[i], 5)
					if !ok {
						return i
					}
					srcDec := intToBigFloat(num, 5)
					dstDec := dst[i].(*big.Float)
					if srcDec.Cmp(dstDec) != 0 {
						return i
					}
				}
				return -1
			},
			higherPrecision: true,
		},
		{
			logical:  "fixed",
			physical: "number(38,5)",
			rowType:  query.ExecResponseRowType{Scale: 5},
			values:   []string{"1.05430", "2.08983"},
			builder:  array.NewInt64Builder(pool),
			append: func(b array.Builder, vs any) {
				for _, s := range vs.([]string) {
					num, ok := stringFloatToInt(s, 5)
					if !ok {
						t.Fatalf("failed to convert to int")
					}
					b.(*array.Int64Builder).Append(num)
				}
			},
			compare: func(src any, dst []snowflakeValue) int {
				srcvs := src.([]string)
				for i := range srcvs {
					num, ok := stringFloatToInt(srcvs[i], 5)
					if !ok {
						return i
					}
					srcDec := fmt.Sprintf("%.*f", 5, float64(num)/math.Pow10(int(5)))
					dstDec := dst[i]
					if srcDec != dstDec {
						return i
					}
				}
				return -1
			},
			higherPrecision: false,
		},
		{
			logical:  "fixed",
			physical: "number(38,0)",
			values:   []string{"10000000000000000000000000000000000000", "-12345678901234567890123456789012345678"},
			builder:  array.NewDecimal128Builder(pool, &arrow.Decimal128Type{Precision: 30, Scale: 2}),
			append: func(b array.Builder, vs any) {
				for _, s := range vs.([]string) {
					num, ok := stringIntToDecimal(s)
					if !ok {
						t.Fatalf("failed to convert to big.Int")
					}
					b.(*array.Decimal128Builder).Append(num)
				}
			},
			compare: func(src any, dst []snowflakeValue) int {
				srcvs := src.([]string)
				for i := range srcvs {
					num, ok := stringIntToDecimal(srcvs[i])
					if !ok {
						return i
					}
					srcDec := decimalToBigInt(num)
					dstDec := dst[i].(*big.Int)
					if srcDec.Cmp(dstDec) != 0 {
						return i
					}
				}
				return -1
			},
			higherPrecision: true,
		},
		{
			logical:  "fixed",
			physical: "number(38,37)",
			rowType:  query.ExecResponseRowType{Scale: 37},
			values:   []string{"1.2345678901234567890123456789012345678", "-9.9999999999999999999999999999999999999"},
			builder:  array.NewDecimal128Builder(pool, &arrow.Decimal128Type{Precision: 38, Scale: 37}),
			append: func(b array.Builder, vs any) {
				for _, s := range vs.([]string) {
					num, ok := stringFloatToDecimal(s, 37)
					if !ok {
						t.Fatalf("failed to convert to big.Rat")
					}
					b.(*array.Decimal128Builder).Append(num)
				}
			},
			compare: func(src any, dst []snowflakeValue) int {
				srcvs := src.([]string)
				for i := range srcvs {
					num, ok := stringFloatToDecimal(srcvs[i], 37)
					if !ok {
						return i
					}
					srcDec := decimalToBigFloat(num, 37)
					dstDec := dst[i].(*big.Float)
					if srcDec.Cmp(dstDec) != 0 {
						return i
					}
				}
return -1 }, higherPrecision: true, }, { logical: "fixed", physical: "int8", values: []int8{1, 2}, builder: array.NewInt8Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int8Builder).AppendValues(vs.([]int8), valids) }, compare: func(src any, dst []snowflakeValue) int { srcvs := src.([]int8) for i := range srcvs { if int64(srcvs[i]) != dst[i].(int64) { return i } } return -1 }, higherPrecision: true, }, { logical: "fixed", physical: "int16", values: []int16{1, 2}, builder: array.NewInt16Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int16Builder).AppendValues(vs.([]int16), valids) }, compare: func(src any, dst []snowflakeValue) int { srcvs := src.([]int16) for i := range srcvs { if int64(srcvs[i]) != dst[i].(int64) { return i } } return -1 }, higherPrecision: true, }, { logical: "fixed", physical: "int16", values: []string{"1.2345", "2.3456"}, rowType: query.ExecResponseRowType{Scale: 4}, builder: array.NewInt16Builder(pool), append: func(b array.Builder, vs any) { for _, s := range vs.([]string) { num, ok := stringFloatToInt(s, 4) if !ok { t.Fatalf("failed to convert to int") } b.(*array.Int16Builder).Append(int16(num)) } }, compare: func(src any, dst []snowflakeValue) int { srcvs := src.([]string) for i := range srcvs { num, ok := stringFloatToInt(srcvs[i], 4) if !ok { return i } srcDec := intToBigFloat(num, 4) dstDec := dst[i].(*big.Float) if srcDec.Cmp(dstDec) != 0 { return i } } return -1 }, higherPrecision: true, }, { logical: "fixed", physical: "int16", values: []string{"1.2345", "2.3456"}, rowType: query.ExecResponseRowType{Scale: 4}, builder: array.NewInt16Builder(pool), append: func(b array.Builder, vs any) { for _, s := range vs.([]string) { num, ok := stringFloatToInt(s, 4) if !ok { t.Fatalf("failed to convert to int") } b.(*array.Int16Builder).Append(int16(num)) } }, compare: func(src any, dst []snowflakeValue) int { srcvs := src.([]string) for i := range srcvs { num, ok := stringFloatToInt(srcvs[i], 4) if !ok { return 
i } srcDec := fmt.Sprintf("%.*f", 4, float64(num)/math.Pow10(int(4))) dstDec := dst[i] if srcDec != dstDec { return i } } return -1 }, higherPrecision: false, }, { logical: "fixed", physical: "int32", values: []int32{1, 2}, builder: array.NewInt32Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Int32Builder).AppendValues(vs.([]int32), valids) }, compare: func(src any, dst []snowflakeValue) int { srcvs := src.([]int32) for i := range srcvs { if int64(srcvs[i]) != dst[i] { return i } } return -1 }, higherPrecision: true, }, { logical: "fixed", physical: "int32", values: []string{"1.23456", "2.34567"}, rowType: query.ExecResponseRowType{Scale: 5}, builder: array.NewInt32Builder(pool), append: func(b array.Builder, vs any) { for _, s := range vs.([]string) { num, ok := stringFloatToInt(s, 5) if !ok { t.Fatalf("failed to convert to int") } b.(*array.Int32Builder).Append(int32(num)) } }, compare: func(src any, dst []snowflakeValue) int { srcvs := src.([]string) for i := range srcvs { num, ok := stringFloatToInt(srcvs[i], 5) if !ok { return i } srcDec := intToBigFloat(num, 5) dstDec := dst[i].(*big.Float) if srcDec.Cmp(dstDec) != 0 { return i } } return -1 }, higherPrecision: true, }, { logical: "fixed", physical: "int32", values: []string{"1.23456", "2.34567"}, rowType: query.ExecResponseRowType{Scale: 5}, builder: array.NewInt32Builder(pool), append: func(b array.Builder, vs any) { for _, s := range vs.([]string) { num, ok := stringFloatToInt(s, 5) if !ok { t.Fatalf("failed to convert to int") } b.(*array.Int32Builder).Append(int32(num)) } }, compare: func(src any, dst []snowflakeValue) int { srcvs := src.([]string) for i := range srcvs { num, ok := stringFloatToInt(srcvs[i], 5) if !ok { return i } srcDec := fmt.Sprintf("%.*f", 5, float64(num)/math.Pow10(int(5))) dstDec := dst[i] if srcDec != dstDec { return i } } return -1 }, higherPrecision: false, }, { logical: "fixed", physical: "int64", values: []int64{1, 2}, builder: array.NewInt64Builder(pool), 
append: func(b array.Builder, vs any) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) }, higherPrecision: true, }, { logical: "boolean", values: []bool{true, false}, builder: array.NewBooleanBuilder(pool), append: func(b array.Builder, vs any) { b.(*array.BooleanBuilder).AppendValues(vs.([]bool), valids) }, }, { logical: "real", physical: "float", values: []float64{1, 2}, builder: array.NewFloat64Builder(pool), append: func(b array.Builder, vs any) { b.(*array.Float64Builder).AppendValues(vs.([]float64), valids) }, }, { logical: "text", physical: "string", values: []string{"foo", "bar"}, builder: array.NewStringBuilder(pool), append: func(b array.Builder, vs any) { b.(*array.StringBuilder).AppendValues(vs.([]string), valids) }, }, { logical: "binary", values: [][]byte{[]byte("foo"), []byte("bar")}, builder: array.NewBinaryBuilder(pool, arrow.BinaryTypes.Binary), append: func(b array.Builder, vs any) { b.(*array.BinaryBuilder).AppendValues(vs.([][]byte), valids) }, }, { logical: "date", values: []time.Time{time.Now(), localTime}, builder: array.NewDate32Builder(pool), append: func(b array.Builder, vs any) { for _, d := range vs.([]time.Time) { b.(*array.Date32Builder).Append(arrow.Date32(d.Unix())) } }, }, { logical: "time", values: []time.Time{time.Now(), time.Now()}, rowType: query.ExecResponseRowType{Scale: 9}, builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs any) { for _, t := range vs.([]time.Time) { b.(*array.Int64Builder).Append(t.UnixNano()) } }, compare: func(src any, dst []snowflakeValue) int { srcvs := src.([]time.Time) for i := range srcvs { if srcvs[i].Nanosecond() != dst[i].(time.Time).Nanosecond() { return i } } return -1 }, higherPrecision: true, }, { logical: "timestamp_ntz", values: []time.Time{time.Now(), localTime}, rowType: query.ExecResponseRowType{Scale: 9}, builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs any) { for _, t := range vs.([]time.Time) { 
b.(*array.Int64Builder).Append(t.UnixNano()) } }, compare: func(src any, dst []snowflakeValue) int { srcvs := src.([]time.Time) for i := range srcvs { if srcvs[i].UnixNano() != dst[i].(time.Time).UnixNano() { return i } } return -1 }, }, { logical: "timestamp_ltz", values: []time.Time{time.Now(), localTime}, rowType: query.ExecResponseRowType{Scale: 9}, builder: array.NewInt64Builder(pool), append: func(b array.Builder, vs any) { for _, t := range vs.([]time.Time) { b.(*array.Int64Builder).Append(t.UnixNano()) } }, compare: func(src any, dst []snowflakeValue) int { srcvs := src.([]time.Time) for i := range srcvs { if srcvs[i].UnixNano() != dst[i].(time.Time).UnixNano() { return i } } return -1 }, }, { logical: "timestamp_tz", values: []time.Time{time.Now(), localTime}, builder: array.NewStructBuilder(pool, tzStruct), append: func(b array.Builder, vs any) { sb := b.(*array.StructBuilder) valids = []bool{true, true} sb.AppendValues(valids) for _, t := range vs.([]time.Time) { sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix()) sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.UnixNano())) } }, compare: func(src any, dst []snowflakeValue) int { srcvs := src.([]time.Time) for i := range srcvs { if srcvs[i].Unix() != dst[i].(time.Time).Unix() { return i } } return -1 }, }, { logical: "array", values: [][]string{{"foo", "bar"}, {"baz", "quz", "quux"}}, builder: array.NewStringBuilder(pool), append: func(b array.Builder, vs any) { for _, a := range vs.([][]string) { b.(*array.StringBuilder).Append(fmt.Sprint(a)) } }, compare: func(src any, dst []snowflakeValue) int { srcvs := src.([][]string) for i, o := range srcvs { if fmt.Sprint(o) != dst[i].(string) { return i } } return -1 }, }, { logical: "object", values: []testObj{{0, "foo"}, {1, "bar"}}, builder: array.NewStringBuilder(pool), append: func(b array.Builder, vs any) { for _, o := range vs.([]testObj) { b.(*array.StringBuilder).Append(fmt.Sprint(o)) } }, compare: func(src any, dst []snowflakeValue) int { 
srcvs := src.([]testObj) for i, o := range srcvs { if fmt.Sprint(o) != dst[i].(string) { return i } } return -1 }, }, } { testName := tc.logical if tc.physical != "" { testName += " " + tc.physical } t.Run(testName, func(t *testing.T) { b := tc.builder tc.append(b, tc.values) arr := b.NewArray() defer arr.Release() meta := tc.rowType meta.Type = tc.logical withHigherPrecision := tc.higherPrecision if err := arrowToValues(context.Background(), dest, meta, arr, localTime.Location(), withHigherPrecision, nil); err != nil { // TODO t.Fatalf("error: %s", err) } elemType := reflect.TypeOf(tc.values).Elem() if tc.compare != nil { idx := tc.compare(tc.values, dest) if idx != -1 { t.Fatalf("error: column array value mistmatch at index %v", idx) } } else { for _, d := range dest { if reflect.TypeOf(d) != elemType { t.Fatalf("error: expected type %s, got type %s", reflect.TypeOf(d), elemType) } } } }) } } // TestArrowToRecord has been moved to arrowbatches/converter_test.go // (all test case data removed from this file) func TestTimestampLTZLocation(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { src := "1549491451.123456789" var dest driver.Value loc, _ := time.LoadLocation(PSTLocation) if err := stringToValue(context.Background(), &dest, query.ExecResponseRowType{Type: "timestamp_ltz"}, &src, loc, nil); err != nil { t.Errorf("unexpected error: %v", err) } ts, ok := dest.(time.Time) if !ok { t.Errorf("expected type: 'time.Time', got '%v'", reflect.TypeOf(dest)) } if ts.Location() != loc { t.Errorf("expected location to be %v, got '%v'", loc, ts.Location()) } if err := stringToValue(context.Background(), &dest, query.ExecResponseRowType{Type: "timestamp_ltz"}, &src, nil, nil); err != nil { t.Errorf("unexpected error: %v", err) } ts, ok = dest.(time.Time) if !ok { t.Errorf("expected type: 'time.Time', got '%v'", reflect.TypeOf(dest)) } if ts.Location() != time.Local { t.Errorf("expected location to be local, got '%v'", ts.Location()) } }) } func 
TestSmallTimestampBinding(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { ctx := context.Background() timeValue, err := time.Parse("2006-01-02 15:04:05", "1600-10-10 10:10:10") if err != nil { t.Fatalf("failed to parse time: %v", err) } parameters := []driver.NamedValue{ {Ordinal: 1, Value: DataTypeTimestampNtz}, {Ordinal: 2, Value: timeValue}, } rows := sct.mustQueryContext(ctx, "SELECT ?", parameters) defer func() { assertNilF(t, rows.Close()) }() scanValues := make([]driver.Value, 1) for { if err := rows.Next(scanValues); err == io.EOF { break } else if err != nil { t.Fatalf("failed to run query: %v", err) } if scanValues[0] != timeValue { t.Fatalf("unexpected result. expected: %v, got: %v", timeValue, scanValues[0]) } } }) } // TestTimestampConversionWithoutArrowBatches tests all 10 timestamp scales // (0-9) because each scale exercises a mathematically distinct code path in // the timestamp conversion logic. See TestTimestampConversionDistantDates in // arrowbatches/batches_test.go for rationale on why the full scale range is // required. func TestTimestampConversionWithoutArrowBatches(t *testing.T) { timestamps := [3]string{ "2000-10-10 10:10:10.123456789", // neutral "9999-12-12 23:59:59.999999999", // max "0001-01-01 00:00:00.000000000"} // min types := [3]string{"TIMESTAMP_NTZ", "TIMESTAMP_LTZ", "TIMESTAMP_TZ"} runDBTest(t, func(sct *DBTest) { ctx := context.Background() for _, tsStr := range timestamps { ts, err := time.Parse("2006-01-02 15:04:05", tsStr) if err != nil { t.Fatalf("failed to parse time: %v", err) } for _, tp := range types { t.Run(tp+"_"+tsStr, func(t *testing.T) { // Batch all 10 scales into a single multi-column query to reduce round trips. 
// TestTimestampConversionWithoutArrowBatches tests all 10 timestamp scales
// (0-9) because each scale exercises a mathematically distinct code path in
// the timestamp conversion logic. See TestTimestampConversionDistantDates in
// arrowbatches/batches_test.go for rationale on why the full scale range is
// required.
func TestTimestampConversionWithoutArrowBatches(t *testing.T) {
	timestamps := [3]string{
		"2000-10-10 10:10:10.123456789", // neutral
		"9999-12-12 23:59:59.999999999", // max
		"0001-01-01 00:00:00.000000000"} // min
	types := [3]string{"TIMESTAMP_NTZ", "TIMESTAMP_LTZ", "TIMESTAMP_TZ"}
	runDBTest(t, func(sct *DBTest) {
		ctx := context.Background()
		for _, tsStr := range timestamps {
			ts, err := time.Parse("2006-01-02 15:04:05", tsStr)
			if err != nil {
				t.Fatalf("failed to parse time: %v", err)
			}
			for _, tp := range types {
				t.Run(tp+"_"+tsStr, func(t *testing.T) {
					// Batch all 10 scales into a single multi-column query to reduce round trips.
					var cols []string
					for scale := 0; scale <= 9; scale++ {
						cols = append(cols, fmt.Sprintf("'%s'::%s(%v)", tsStr, tp, scale))
					}
					query := "SELECT " + strings.Join(cols, ", ")
					rows := sct.mustQueryContext(ctx, query, nil)
					defer func() { assertNilF(t, rows.Close()) }()
					if !rows.Next() {
						t.Fatalf("failed to run query: %v", query)
					}
					scanVals := make([]time.Time, 10)
					scanPtrs := make([]any, 10)
					for i := range scanVals {
						scanPtrs[i] = &scanVals[i]
					}
					assertNilF(t, rows.Scan(scanPtrs...))
					for scale := 0; scale <= 9; scale++ {
						// Column at index `scale` was cast to that scale, so the
						// expected value is the source truncated to 10^(9-scale) ns.
						exp := ts.Truncate(time.Duration(math.Pow10(9 - scale)))
						act := scanVals[scale]
						if !exp.Equal(act) {
							t.Fatalf("scale %d: unexpected result. expected: %v, got: %v", scale, exp, act)
						}
					}
				})
			}
		}
	})
}

// TestTimeTypeValueToString pins the wire encoding produced by
// timeTypeValueToString for each time-flavored Snowflake bind type:
// DATE as epoch millis, TIME as nanos since midnight, TIMESTAMP_* as epoch
// nanos, and TIMESTAMP_TZ with a trailing offset token (minutes + 1440).
func TestTimeTypeValueToString(t *testing.T) {
	timeValue, err := time.Parse("2006-01-02 15:04:05", "2020-01-02 10:11:12")
	if err != nil {
		t.Fatal(err)
	}
	// Same wall-clock value, but in a +6h zone, to exercise the offset suffix.
	offsetTimeValue, err := time.ParseInLocation("2006-01-02 15:04:05", "2020-01-02 10:11:12", Location(6*60))
	if err != nil {
		t.Fatal(err)
	}
	testcases := []struct {
		in     time.Time
		tsmode types.SnowflakeType
		out    string
	}{
		{timeValue, types.DateType, "1577959872000"},
		{timeValue, types.TimeType, "36672000000000"},
		{timeValue, types.TimestampNtzType, "1577959872000000000"},
		{timeValue, types.TimestampLtzType, "1577959872000000000"},
		{timeValue, types.TimestampTzType, "1577959872000000000 1440"},
		{offsetTimeValue, types.TimestampTzType, "1577938272000000000 1800"},
	}
	for _, tc := range testcases {
		t.Run(tc.out, func(t *testing.T) {
			bv, err := timeTypeValueToString(tc.in, tc.tsmode)
			assertNilF(t, err)
			// Only the value is populated; format and schema stay empty/nil.
			assertEmptyStringE(t, bv.format)
			assertNilE(t, bv.schema)
			assertEqualE(t, tc.out, *bv.value)
		})
	}
}

// TestIsArrayOfStructs checks that isArrayOfStructs accepts slices of structs
// (by value or pointer) and rejects slices of scalars and pointers to slices.
func TestIsArrayOfStructs(t *testing.T) {
	testcases := []struct {
		value    any
		expected bool
	}{
		{[]simpleObject{}, true},
		{[]*simpleObject{}, true},
		{[]int{1}, false},
		{[]string{"abc"}, false},
		{&[]bool{true}, false},
	}
	for _, tc := range testcases {
		t.Run(fmt.Sprintf("%v", tc.value), func(t *testing.T) {
			res := isArrayOfStructs(tc.value)
			if res != tc.expected {
				t.Errorf("expected %v to result in %v", tc.value, tc.expected)
			}
		})
	}
}
isArrayOfStructs(tc.value) if res != tc.expected { t.Errorf("expected %v to result in %v", tc.value, tc.expected) } }) } } func TestSqlNull(t *testing.T) { runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQuery("SELECT 1, NULL UNION SELECT 2, 'test' ORDER BY 1") defer rows.Close() var rowID int var nullStr sql.Null[string] assertTrueF(t, rows.Next()) assertNilF(t, rows.Scan(&rowID, &nullStr)) assertEqualE(t, nullStr, sql.Null[string]{Valid: false}) assertTrueF(t, rows.Next()) assertNilF(t, rows.Scan(&rowID, &nullStr)) assertEqualE(t, nullStr, sql.Null[string]{Valid: true, V: "test"}) }) } func TestNumbersScanType(t *testing.T) { for _, forceFormat := range []string{forceJSON, forceARROW} { t.Run(forceFormat, func(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.mustExecT(t, forceFormat) t.Run("scale == 0", func(t *testing.T) { t.Run("without higher precision", func(t *testing.T) { rows := dbt.mustQueryContext(context.Background(), "SELECT 1, 300::NUMBER(15, 0), 600::NUMBER(18, 0), 700::NUMBER(19, 0), 900::NUMBER(38, 0), 123456789012345678901234567890") defer rows.Close() rows.mustNext() var i1, i2, i3 int64 var i4, i5, i6 string rows.mustScan(&i1, &i2, &i3, &i4, &i5, &i6) assertEqualE(t, i1, int64(1)) assertEqualE(t, i2, int64(300)) assertEqualE(t, i3, int64(600)) assertEqualE(t, i4, "700") assertEqualE(t, i5, "900") assertEqualE(t, i6, "123456789012345678901234567890") // pragma: allowlist secret types, err := rows.ColumnTypes() assertNilF(t, err) assertEqualE(t, types[0].ScanType(), reflect.TypeFor[int64]()) assertEqualE(t, types[1].ScanType(), reflect.TypeFor[int64]()) assertEqualE(t, types[2].ScanType(), reflect.TypeFor[int64]()) assertEqualE(t, types[3].ScanType(), reflect.TypeFor[string]()) assertEqualE(t, types[4].ScanType(), reflect.TypeFor[string]()) assertEqualE(t, types[5].ScanType(), reflect.TypeFor[string]()) }) t.Run("without higher precision - regardless of scan type, int parsing should still work", func(t *testing.T) { rows := 
// TestNumbersScanType verifies ColumnType.ScanType() and the actually scanned
// Go types for NUMBER columns in both result formats (JSON and Arrow), at
// scale 0 and non-0, for precisions that fit int64 and those that do not,
// with and without the WithHigherPrecision context option.
func TestNumbersScanType(t *testing.T) {
	for _, forceFormat := range []string{forceJSON, forceARROW} {
		t.Run(forceFormat, func(t *testing.T) {
			runDBTest(t, func(dbt *DBTest) {
				dbt.mustExecT(t, forceFormat)
				t.Run("scale == 0", func(t *testing.T) {
					t.Run("without higher precision", func(t *testing.T) {
						rows := dbt.mustQueryContext(context.Background(), "SELECT 1, 300::NUMBER(15, 0), 600::NUMBER(18, 0), 700::NUMBER(19, 0), 900::NUMBER(38, 0), 123456789012345678901234567890")
						defer rows.Close()
						rows.mustNext()
						// Precision <= 18 scans as int64; wider columns fall back to string.
						var i1, i2, i3 int64
						var i4, i5, i6 string
						rows.mustScan(&i1, &i2, &i3, &i4, &i5, &i6)
						assertEqualE(t, i1, int64(1))
						assertEqualE(t, i2, int64(300))
						assertEqualE(t, i3, int64(600))
						assertEqualE(t, i4, "700")
						assertEqualE(t, i5, "900")
						assertEqualE(t, i6, "123456789012345678901234567890") // pragma: allowlist secret
						types, err := rows.ColumnTypes()
						assertNilF(t, err)
						assertEqualE(t, types[0].ScanType(), reflect.TypeFor[int64]())
						assertEqualE(t, types[1].ScanType(), reflect.TypeFor[int64]())
						assertEqualE(t, types[2].ScanType(), reflect.TypeFor[int64]())
						assertEqualE(t, types[3].ScanType(), reflect.TypeFor[string]())
						assertEqualE(t, types[4].ScanType(), reflect.TypeFor[string]())
						assertEqualE(t, types[5].ScanType(), reflect.TypeFor[string]())
					})
					t.Run("without higher precision - regardless of scan type, int parsing should still work", func(t *testing.T) {
						rows := dbt.mustQueryContext(context.Background(), "SELECT 1, 300::NUMBER(15, 0), 600::NUMBER(18, 0), 700::NUMBER(19, 0), 900::NUMBER(38, 0), 123456789012345678901234567890")
						defer rows.Close()
						rows.mustNext()
						// Even though ScanType reports string for wide columns,
						// values that fit still scan into int64 targets.
						var i1, i2, i3, i4, i5 int64
						var i6 string
						rows.mustScan(&i1, &i2, &i3, &i4, &i5, &i6)
						assertEqualE(t, i1, int64(1))
						assertEqualE(t, i2, int64(300))
						assertEqualE(t, i3, int64(600))
						assertEqualE(t, i4, int64(700))
						assertEqualE(t, i5, int64(900))
						assertEqualE(t, i6, "123456789012345678901234567890") // pragma: allowlist secret
						types, err := rows.ColumnTypes()
						assertNilF(t, err)
						assertEqualE(t, types[0].ScanType(), reflect.TypeFor[int64]())
						assertEqualE(t, types[1].ScanType(), reflect.TypeFor[int64]())
						assertEqualE(t, types[2].ScanType(), reflect.TypeFor[int64]())
						assertEqualE(t, types[3].ScanType(), reflect.TypeFor[string]())
						assertEqualE(t, types[4].ScanType(), reflect.TypeFor[string]())
						assertEqualE(t, types[5].ScanType(), reflect.TypeFor[string]())
					})
					t.Run("with higher precision", func(t *testing.T) {
						rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()), "SELECT 1::NUMBER(1, 0), 300::NUMBER(15, 0), 600::NUMBER(19, 0), 700::NUMBER(20, 0), 900::NUMBER(38, 0), 123456789012345678901234567890")
						defer rows.Close()
						rows.mustNext()
						// With higher precision, precision >= 19 scans as *big.Int.
						var i1, i2 int64
						var i3, i4, i5, i6 *big.Int
						rows.mustScan(&i1, &i2, &i3, &i4, &i5, &i6)
						assertEqualE(t, i1, int64(1))
						assertEqualE(t, i2, int64(300))
						assertEqualE(t, i3.Cmp(big.NewInt(600)), 0)
						assertEqualE(t, i4.Cmp(big.NewInt(700)), 0)
						assertEqualE(t, i5.Cmp(big.NewInt(900)), 0)
						bigInt123456789012345678901234567890 := &big.Int{}
						bigInt123456789012345678901234567890.SetString("123456789012345678901234567890", 10) // pragma: allowlist secret
						assertEqualE(t, i6.Cmp(bigInt123456789012345678901234567890), 0)
						types, err := rows.ColumnTypes()
						assertNilF(t, err)
						assertEqualE(t, types[0].ScanType(), reflect.TypeFor[int64]())
						assertEqualE(t, types[1].ScanType(), reflect.TypeFor[int64]())
						assertEqualE(t, types[2].ScanType(), reflect.TypeFor[*big.Int]())
						assertEqualE(t, types[3].ScanType(), reflect.TypeFor[*big.Int]())
						assertEqualE(t, types[4].ScanType(), reflect.TypeFor[*big.Int]())
						assertEqualE(t, types[5].ScanType(), reflect.TypeFor[*big.Int]())
					})
				})
				t.Run("scale != 0", func(t *testing.T) {
					t.Run("without higher precision", func(t *testing.T) {
						rows := dbt.mustQueryContext(context.Background(), "SELECT 1.5, 300.5::NUMBER(15, 1), 600.5::NUMBER(18, 1), 700.5::NUMBER(19, 1), 900.5::NUMBER(38, 1), 123456789012345678901234567890.5")
						defer rows.Close()
						rows.mustNext()
						// Non-zero scale without higher precision always scans as float64.
						var i1, i2, i3, i4, i5, i6 float64
						rows.mustScan(&i1, &i2, &i3, &i4, &i5, &i6)
						assertEqualE(t, i1, 1.5)
						assertEqualE(t, i2, 300.5)
						assertEqualE(t, i3, 600.5)
						assertEqualE(t, i4, 700.5)
						assertEqualE(t, i5, 900.5)
						assertEqualE(t, i6, 123456789012345678901234567890.5)
						types, err := rows.ColumnTypes()
						assertNilF(t, err)
						assertEqualE(t, types[0].ScanType(), reflect.TypeFor[float64]())
						assertEqualE(t, types[1].ScanType(), reflect.TypeFor[float64]())
						assertEqualE(t, types[2].ScanType(), reflect.TypeFor[float64]())
						assertEqualE(t, types[3].ScanType(), reflect.TypeFor[float64]())
						assertEqualE(t, types[4].ScanType(), reflect.TypeFor[float64]())
						assertEqualE(t, types[5].ScanType(), reflect.TypeFor[float64]())
					})
					t.Run("with higher precision", func(t *testing.T) {
						rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()), "SELECT 1.5, 300.5::NUMBER(15, 1), 600.5::NUMBER(18, 1), 700.5::NUMBER(19, 1), 900.5::NUMBER(38, 1), 123456789012345678901234567890.5")
						defer rows.Close()
						rows.mustNext()
						// Non-zero scale with higher precision scans as *big.Float.
						var i1, i2, i3, i4, i5, i6 *big.Float
						rows.mustScan(&i1, &i2, &i3, &i4, &i5, &i6)
						assertEqualE(t, i1.Cmp(big.NewFloat(1.5)), 0)
						assertEqualE(t, i2.Cmp(big.NewFloat(300.5)), 0)
						assertEqualE(t, i3.Cmp(big.NewFloat(600.5)), 0)
						assertEqualE(t, i4.Cmp(big.NewFloat(700.5)), 0)
						assertEqualE(t, i5.Cmp(big.NewFloat(900.5)), 0)
						bigInt123456789012345678901234567890, _, err := big.ParseFloat("123456789012345678901234567890.5", 10, numberMaxPrecisionInBits, big.AwayFromZero)
						assertNilF(t, err)
						assertEqualE(t, i6.Cmp(bigInt123456789012345678901234567890), 0)
						types, err := rows.ColumnTypes()
						assertNilF(t, err)
						assertEqualE(t, types[0].ScanType(), reflect.TypeFor[*big.Float]())
						assertEqualE(t, types[1].ScanType(), reflect.TypeFor[*big.Float]())
						assertEqualE(t, types[2].ScanType(), reflect.TypeFor[*big.Float]())
						assertEqualE(t, types[3].ScanType(), reflect.TypeFor[*big.Float]())
						assertEqualE(t, types[4].ScanType(), reflect.TypeFor[*big.Float]())
						assertEqualE(t, types[5].ScanType(), reflect.TypeFor[*big.Float]())
					})
				})
			})
		})
	}
}
// mustArray wraps Array() and panics on conversion failure; test-only
// convenience for building array bind values.
func mustArray(v any, typ ...any) driver.Value {
	array, err := Array(v, typ...)
	if err != nil {
		panic(fmt.Sprintf("failed to convert to array: %v", err))
	}
	return array
}

================================================
FILE: crl.go
================================================
package gosnowflake

import (
	"crypto/x509"
	"encoding/asn1"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"runtime"
	"slices"
	"strings"
	"sync"
	"time"

	sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config"
)

// snowflakeCrlCacheValidityTimeEnv names the env var that overrides how long
// a cached CRL is treated as fresh (parsed with time.ParseDuration).
const snowflakeCrlCacheValidityTimeEnv = "SNOWFLAKE_CRL_CACHE_VALIDITY_TIME"

// idpOID is the ASN.1 OID (2.5.29.28) of the Issuing Distribution Point CRL
// extension.
var idpOID = asn1.ObjectIdentifier{2, 5, 29, 28}

// distributionPointName mirrors the DistributionPointName ASN.1 structure;
// only the fullName (tag 0) choice is decoded here.
type distributionPointName struct {
	FullName []asn1.RawValue `asn1:"optional,tag:0"`
}

// issuingDistributionPoint is a partial decoding of the IDP extension body;
// only the distributionPoint field (tag 0) is of interest.
type issuingDistributionPoint struct {
	DistributionPoint distributionPointName `asn1:"optional,tag:0"`
}

// crlValidator performs CRL-based certificate revocation checks during TLS
// handshakes, with optional in-memory and on-disk CRL caching.
type crlValidator struct {
	certRevocationCheckMode        CertRevocationCheckMode
	allowCertificatesWithoutCrlURL bool
	inMemoryCacheDisabled          bool
	onDiskCacheDisabled            bool
	crlDownloadMaxSize             int
	httpClient                     *http.Client
	telemetry                      *snowflakeTelemetry
}

// crlCacheCleanerType owns the background goroutine that periodically evicts
// stale CRLs from both caches. cleanupStopChan/cleanupDoneChan are non-nil
// only while the cleaner goroutine is running.
type crlCacheCleanerType struct {
	mu                      sync.Mutex
	cacheValidityTime       time.Duration
	onDiskCacheRemovalDelay time.Duration
	onDiskCacheDir          string
	cleanupStopChan         chan struct{}
	cleanupDoneChan         chan struct{}
}

// crlInMemoryCacheValueType is an in-memory cache entry: the parsed CRL plus
// the time it was downloaded (or the file mod time when promoted from disk).
type crlInMemoryCacheValueType struct {
	crl          *x509.RevocationList
	downloadTime *time.Time
}
= time.Hour crlInMemoryCache = make(map[string]*crlInMemoryCacheValueType) crlInMemoryCacheMutex = &sync.Mutex{} crlURLMutexes = make(map[string]*sync.Mutex) crlCacheCleanerMu = &sync.Mutex{} crlCacheCleaner *crlCacheCleanerType ) func newCrlValidator(certRevocationCheckMode CertRevocationCheckMode, allowCertificatesWithoutCrlURL bool, inMemoryCacheDisabled, onDiskCacheDisabled bool, crlDownloadMaxSize int, httpClient *http.Client, telemetry *snowflakeTelemetry) (*crlValidator, error) { initCrlCacheCleaner() cv := &crlValidator{ certRevocationCheckMode: certRevocationCheckMode, allowCertificatesWithoutCrlURL: allowCertificatesWithoutCrlURL, inMemoryCacheDisabled: inMemoryCacheDisabled, onDiskCacheDisabled: onDiskCacheDisabled, crlDownloadMaxSize: crlDownloadMaxSize, httpClient: httpClient, telemetry: telemetry, } return cv, nil } func initCrlCacheCleaner() { crlCacheCleanerMu.Lock() defer crlCacheCleanerMu.Unlock() if crlCacheCleaner != nil { return } var err error validityTime := defaultCrlCacheValidityTime if validityTimeStr := os.Getenv(snowflakeCrlCacheValidityTimeEnv); validityTimeStr != "" { if validityTime, err = time.ParseDuration(os.Getenv(snowflakeCrlCacheValidityTimeEnv)); err != nil { logger.Infof("failed to parse %v: %v, using default value %v", snowflakeCrlCacheValidityTimeEnv, err, defaultCrlCacheValidityTime) validityTime = defaultCrlCacheValidityTime } } onDiskCacheRemovalDelay := defaultCrlOnDiskCacheRemovalDelay if onDiskCacheRemovalDelayStr := os.Getenv("SNOWFLAKE_CRL_ON_DISK_CACHE_REMOVAL_DELAY"); onDiskCacheRemovalDelayStr != "" { if onDiskCacheRemovalDelay, err = time.ParseDuration(onDiskCacheRemovalDelayStr); err != nil { logger.Infof("failed to parse SNOWFLAKE_CRL_ON_DISK_CACHE_REMOVAL_DELAY: %v, using default value %v", err, defaultCrlOnDiskCacheRemovalDelay) onDiskCacheRemovalDelay = defaultCrlOnDiskCacheRemovalDelay } } onDiskCacheDir := os.Getenv("SNOWFLAKE_CRL_ON_DISK_CACHE_DIR") if onDiskCacheDir == "" { if onDiskCacheDir, err = 
defaultCrlOnDiskCacheDir(); err != nil { logger.Infof("failed to get default CRL on-disk cache directory: %v", err) onDiskCacheDir = "" // it will work only if on-disk cache is disabled } } if onDiskCacheDir != "" { if err = os.MkdirAll(onDiskCacheDir, 0755); err != nil { logger.Errorf("error while preparing cache dir for CRLs: %v", err) } } crlCacheCleaner = &crlCacheCleanerType{ cacheValidityTime: validityTime, onDiskCacheRemovalDelay: onDiskCacheRemovalDelay, onDiskCacheDir: onDiskCacheDir, cleanupStopChan: nil, cleanupDoneChan: nil, } } // CertRevocationCheckMode defines the modes for certificate revocation checks. type CertRevocationCheckMode = sfconfig.CertRevocationCheckMode const ( // CertRevocationCheckDisabled means that certificate revocation checks are disabled. CertRevocationCheckDisabled = sfconfig.CertRevocationCheckDisabled // CertRevocationCheckAdvisory means that certificate revocation checks are advisory, and the driver will not fail if the checks end with error (cannot verify revocation status). // Driver will fail only if a certicate is revoked. CertRevocationCheckAdvisory = sfconfig.CertRevocationCheckAdvisory // CertRevocationCheckEnabled means that every certificate revocation check must pass, otherwise the driver will fail. 
CertRevocationCheckEnabled = sfconfig.CertRevocationCheckEnabled ) type crlValidationResult int const ( crlRevoked crlValidationResult = iota crlUnrevoked crlError ) type certValidationResult int const ( certRevoked certValidationResult = iota certUnrevoked certError ) const ( defaultCrlHTTPClientTimeout = 10 * time.Second defaultCrlCacheValidityTime = 24 * time.Hour defaultCrlOnDiskCacheRemovalDelay = 7 * time.Hour defaultCrlDownloadMaxSize = 20 * 1024 * 1024 // 20 MB ) func (cv *crlValidator) verifyPeerCertificates(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { if cv.certRevocationCheckMode == CertRevocationCheckDisabled { logger.Debug("certificate revocation check is disabled, skipping CRL validation") return nil } crlValidationResults := cv.validateChains(verifiedChains) allRevoked := true for _, result := range crlValidationResults { if result == crlUnrevoked { logger.Debug("found certificate chain with no revoked certificates") return nil } if result != crlRevoked { allRevoked = false } } if allRevoked { return fmt.Errorf("every verified certificate chain contained revoked certificates") } logger.Warn("some certificate chains didn't pass or driver wasn't able to peform the checks") if cv.certRevocationCheckMode == CertRevocationCheckAdvisory { logger.Warn("certificate revocation check is set to CERT_REVOCATION_CHECK_ADVISORY, so assuming that certificates are not revoked") return nil } return fmt.Errorf("certificate revocation check failed") } func (cv *crlValidator) validateChains(chains [][]*x509.Certificate) []crlValidationResult { crlValidationResults := make([]crlValidationResult, len(chains)) for i, chain := range chains { crlValidationResults[i] = crlUnrevoked var chainStr strings.Builder for _, cert := range chain { fmt.Fprintf(&chainStr, "%v -> ", cert.Subject) } logger.Debugf("validating certificate chain %d: %s", i, chainStr.String()) for j, cert := range chain { if j == len(chain)-1 { logger.Debugf("skipping root certificate %v 
// validateCrlAgainstCrlURL checks cert's serial number against the CRL served
// at crlURL, using the cached copy when it is still fresh and downloading a
// new one otherwise. All work for a given URL is serialized by a per-URL
// mutex so concurrent handshakes do not download the same CRL twice.
func (cv *crlValidator) validateCrlAgainstCrlURL(cert *x509.Certificate, crlURL string, parent *x509.Certificate) certValidationResult {
	now := time.Now()
	mu := cv.getOrCreateMutex(crlURL)
	mu.Lock()
	defer mu.Unlock()
	crl, downloadTime := cv.getFromCache(crlURL)
	// Fresh download needed when: no cached CRL, the CRL itself expired, or
	// the cache entry outlived the configured validity window.
	needsFreshCrl := crl == nil || crl.NextUpdate.Before(now) || downloadTime.Add(crlCacheCleaner.cacheValidityTime).Before(now)
	shouldUpdateCrl := false
	if needsFreshCrl {
		newCrl, newDownloadTime, err := cv.downloadCrl(crlURL)
		if err != nil {
			logger.Warnf("failed to download CRL from %v: %v", crlURL, err)
		}
		if newCrl != nil && newCrl.NextUpdate.Before(now) {
			// A freshly downloaded but already-expired CRL is unusable; only
			// fail hard if we have no (stale) cached copy to fall back on.
			logger.Warnf("downloaded CRL from %v is already expired (next update at %v)", crlURL, newCrl.NextUpdate)
			newCrl = nil
			if crl == nil {
				return certError
			}
		}
		// Only replace the cache entry when the download is genuinely newer.
		shouldUpdateCrl = newCrl != nil && (crl == nil || newCrl.ThisUpdate.After(crl.ThisUpdate))
		if shouldUpdateCrl {
			logger.Debugf("Found updated CRL for %v", crlURL)
			crl = newCrl
			downloadTime = newDownloadTime
		} else {
			if crl != nil && crl.NextUpdate.After(now) {
				logger.Debugf("CRL for %v is up-to-date, using cached version", crlURL)
			} else {
				logger.Warnf("CRL for %v is not available or outdated", crlURL)
				return certError
			}
		}
	}
	logger.Debugf("CRL has %v entries, next update at %v", len(crl.RevokedCertificateEntries), crl.NextUpdate)
	// Validate issuer, signature and IDP extension before trusting the CRL.
	if err := cv.validateCrl(crl, parent, crlURL); err != nil {
		return certError
	}
	if shouldUpdateCrl {
		// Cache is updated only after the CRL passed validation.
		logger.Debugf("CRL for %v is valid, updating cache", crlURL)
		cv.updateCache(crlURL, crl, downloadTime)
	}
	for _, rce := range crl.RevokedCertificateEntries {
		if cert.SerialNumber.Cmp(rce.SerialNumber) == 0 {
			logger.Warnf("certificate for %v (serial number %v) has been revoked at %v, reason: %v", cert.Subject, rce.SerialNumber, rce.RevocationTime, rce.ReasonCode)
			return certRevoked
		}
	}
	return certUnrevoked
}

// validateCrl verifies that the CRL was really issued by parent: the issuer
// name must match parent's subject, the CRL signature must verify against
// parent's key, and the IDP extension (if present) must list crlURL.
func (cv *crlValidator) validateCrl(crl *x509.RevocationList, parent *x509.Certificate, crlURL string) error {
	if crl.Issuer.String() != parent.Subject.String() {
		err := fmt.Errorf("CRL issuer %v does not match parent certificate subject %v for %v", crl.Issuer, parent.Subject, crlURL)
		logger.Warn(err.Error())
		return err
	}
	if err := crl.CheckSignatureFrom(parent); err != nil {
		logger.Warnf("CRL signature verification failed for %v: %v", crlURL, err)
		return err
	}
	if err := cv.verifyAgainstIdpExtension(crl, crlURL); err != nil {
		logger.Warnf("CRL IDP extension verification failed for %v: %v", crlURL, err)
		return err
	}
	return nil
}
// getFromCache looks up the CRL for crlURL, first in the in-memory cache and
// then on disk (promoting a disk hit back into memory). Returns (nil, nil)
// on any miss or read/parse failure; callers then download a fresh CRL.
func (cv *crlValidator) getFromCache(crlURL string) (*x509.RevocationList, *time.Time) {
	if cv.inMemoryCacheDisabled {
		logger.Debugf("in-memory cache is disabled")
	} else {
		crlInMemoryCacheMutex.Lock()
		cacheValue, exists := crlInMemoryCache[crlURL]
		crlInMemoryCacheMutex.Unlock()
		if exists {
			logger.Debugf("found CRL in cache for %v", crlURL)
			return cacheValue.crl, cacheValue.downloadTime
		}
	}
	if cv.onDiskCacheDisabled {
		logger.Debugf("CRL cache is disabled, not checking disk for %v", crlURL)
		return nil, nil
	}
	crlFilePath := cv.crlURLToPath(crlURL)
	fileHandle, err := os.Open(crlFilePath)
	if err != nil {
		logger.Debugf("cannot open CRL from disk for %v (%v): %v", crlURL, crlFilePath, err)
		return nil, nil
	}
	defer func() {
		if err := fileHandle.Close(); err != nil {
			logger.Warnf("failed to close CRL file handle for %v (%v): %v", crlURL, crlFilePath, err)
		}
	}()
	stat, err := fileHandle.Stat()
	if err != nil {
		logger.Debugf("cannot stat CRL file for %v (%v): %v", crlURL, crlFilePath, err)
		return nil, nil
	}
	crlBytes, err := io.ReadAll(fileHandle)
	if err != nil {
		logger.Debugf("cannot read CRL from disk for %v (%v): %v", crlURL, crlFilePath, err)
		return nil, nil
	}
	crl, err := x509.ParseRevocationList(crlBytes)
	if err != nil {
		logger.Warnf("cannot parse CRL from disk for %v (%v): %v", crlURL, crlFilePath, err)
		return nil, nil
	}
	modTime := stat.ModTime()
	if !cv.inMemoryCacheDisabled {
		// promote CRL to in-memory cache
		crlInMemoryCacheMutex.Lock()
		crlInMemoryCache[crlURL] = &crlInMemoryCacheValueType{
			crl: crl,
			// modTime is not the exact time the CRL was downloaded, but rather the last modification time of the file
			// still, it is good enough for our purposes
			downloadTime: &modTime,
		}
		crlInMemoryCacheMutex.Unlock()
	}
	return crl, &modTime
}
// updateCache stores a validated CRL in the in-memory cache and persists its
// raw DER bytes to disk (unless the respective cache is disabled). The cache
// directory is created with 0700 and its permissions re-asserted if it
// already existed.
func (cv *crlValidator) updateCache(crlURL string, crl *x509.RevocationList, downloadTime *time.Time) {
	if cv.inMemoryCacheDisabled {
		logger.Debugf("in-memory cache is disabled, not updating")
	} else {
		crlInMemoryCacheMutex.Lock()
		crlInMemoryCache[crlURL] = &crlInMemoryCacheValueType{
			crl:          crl,
			downloadTime: downloadTime,
		}
		crlInMemoryCacheMutex.Unlock()
	}
	if cv.onDiskCacheDisabled {
		logger.Debugf("CRL cache is disabled, not writing to disk for %v", crlURL)
		return
	}
	crlFilePath := cv.crlURLToPath(crlURL)
	crlDirPath := filepath.Dir(crlFilePath)
	crlDirParentPath := filepath.Dir(crlDirPath)
	if err := os.MkdirAll(crlDirParentPath, 0755); err != nil {
		logger.Warnf("failed to create directory for CRL file %v: %v", crlFilePath, err)
		return
	}
	if err := os.Mkdir(crlDirPath, 0700); err != nil {
		if !errors.Is(err, os.ErrExist) {
			logger.Warnf("failed to create directory for CRL file %v: %v", crlFilePath, err)
			return
		}
		// Directory already existed; make sure it is private anyway.
		if err = os.Chmod(crlDirPath, 0700); err != nil {
			logger.Warnf("failed to chmod existing directory for CRL file %v: %v", crlFilePath, err)
			return
		}
	}
	if err := os.WriteFile(crlFilePath, crl.Raw, 0600); err != nil {
		logger.Warnf("failed to write CRL to disk for %v (%v): %v", crlURL, crlFilePath, err)
	}
}

// downloadCrl fetches and parses the CRL at crlURL, enforcing the configured
// maximum download size, and records size/timing stats to telemetry (the
// telemetry log is emitted via defer even on failure). Returns the parsed
// CRL and the time the download started.
func (cv *crlValidator) downloadCrl(crlURL string) (*x509.RevocationList, *time.Time, error) {
	telemetryEvent := &telemetryData{
		Timestamp: time.Now().UnixNano() / int64(time.Millisecond),
		Message: map[string]string{
			"type":    "client_crl_stats",
			"crl_url": crlURL,
		},
	}
	defer func() {
		if err := cv.telemetry.addLog(telemetryEvent); err != nil {
			logger.Warnf("failed to add telemetry log for CRL download: %v", err)
		}
	}()
	logger.Debugf("downloading CRL from %v", crlURL)
	now := time.Now()
	resp, err := cv.httpClient.Get(crlURL)
	if err != nil {
		return nil, nil, err
	}
	defer func() {
		if err = resp.Body.Close(); err != nil {
			logger.Warnf("failed to close response body for CRL downloaded from %v: %v", crlURL, err)
		}
	}()
	if resp.StatusCode >= 400 {
		return nil, nil, fmt.Errorf("failed to download CRL from %v, status code: %v", crlURL, resp.StatusCode)
	}
	// Cap the read at crlDownloadMaxSize regardless of the advertised
	// Content-Length (which may be absent or untrustworthy).
	maxSize := resp.ContentLength
	if maxSize <= 0 || maxSize > int64(cv.crlDownloadMaxSize) {
		maxSize = int64(cv.crlDownloadMaxSize)
	}
	crlBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
	if err != nil {
		return nil, nil, err
	}
	// Hitting the limit exactly implies the body was (or may have been) truncated.
	if cv.crlDownloadMaxSize > 0 && len(crlBytes) >= cv.crlDownloadMaxSize {
		return nil, nil, fmt.Errorf("CRL from %v exceeds maximum size of %d bytes", crlURL, cv.crlDownloadMaxSize)
	}
	telemetryEvent.Message["crl_bytes"] = fmt.Sprintf("%d", len(crlBytes))
	downloadTime := time.Since(now)
	telemetryEvent.Message["crl_download_time_ms"] = fmt.Sprintf("%d", downloadTime.Milliseconds())
	logger.Debugf("downloaded %v bytes for CRL %v", len(crlBytes), crlURL)
	timeBeforeParsing := time.Now()
	crl, err := x509.ParseRevocationList(crlBytes)
	logger.Debugf("parsed CRL from %v, error: %v", crlURL, err)
	if err != nil {
		return nil, nil, err
	}
	logger.Debugf("parsed CRL from %v, next update at %v", crlURL, crl.NextUpdate)
	telemetryEvent.Message["crl_parse_time_ms"] = fmt.Sprintf("%d", time.Since(timeBeforeParsing).Milliseconds())
	telemetryEvent.Message["crl_revoked_certificates"] = fmt.Sprintf("%d", len(crl.RevokedCertificateEntries))
	return crl, &now, err
}

// crlURLToPath maps a CRL URL to its on-disk cache file path by URL-escaping
// the whole URL into a single file name under the cache directory.
func (cv *crlValidator) crlURLToPath(crlURL string) string {
	// Convert CRL URL to a file path, e.g., by replacing slashes with underscores
	return filepath.Join(crlCacheCleaner.onDiskCacheDir, url.QueryEscape(crlURL))
}
// verifyAgainstIdpExtension checks the CRL's Issuing Distribution Point
// extension (if present): distributionPoint must appear among the listed
// full names. A CRL without an IDP extension passes unconditionally.
func (cv *crlValidator) verifyAgainstIdpExtension(crl *x509.RevocationList, distributionPoint string) error {
	for _, ext := range append(crl.Extensions, crl.ExtraExtensions...) {
		if ext.Id.Equal(idpOID) {
			var idp issuingDistributionPoint
			_, err := asn1.Unmarshal(ext.Value, &idp)
			if err != nil {
				return fmt.Errorf("failed to unmarshal IDP extension: %w", err)
			}
			for _, dp := range idp.DistributionPoint.FullName {
				// dp.Bytes holds the raw URI content of the GeneralName entry.
				if string(dp.Bytes) == distributionPoint {
					logger.Debugf("distribution point %v matches CRL IDP extension", distributionPoint)
					return nil
				}
			}
			return fmt.Errorf("distribution point %v not found in CRL IDP extension", distributionPoint)
		}
	}
	return nil
}

// getOrCreateMutex returns the per-URL mutex used to serialize CRL work for
// crlURL, creating it on first use. The mutex map itself is guarded by
// crlInMemoryCacheMutex.
func (cv *crlValidator) getOrCreateMutex(crlURL string) *sync.Mutex {
	crlInMemoryCacheMutex.Lock()
	mu, ok := crlURLMutexes[crlURL]
	if !ok {
		mu = &sync.Mutex{}
		crlURLMutexes[crlURL] = mu
	}
	crlInMemoryCacheMutex.Unlock()
	return mu
}

// isShortLivedCertificate reports whether cert qualifies as a "short-lived
// subscriber certificate" under the CA/Browser Forum baseline requirements;
// such certificates are exempt from CRL checks.
func isShortLivedCertificate(cert *x509.Certificate) bool {
	// https://cabforum.org/working-groups/server/baseline-requirements/requirements/
	// See Short-lived Subscriber Certificate section
	if cert.NotBefore.Before(time.Date(2024, time.March, 15, 0, 0, 0, 0, time.UTC)) {
		// Certificates issued before March 15, 2024 are not considered short-lived
		return false
	}
	// 10 days before March 15, 2026; 7 days from then on.
	maximumValidityPeriod := 7 * 24 * time.Hour
	if cert.NotBefore.Before(time.Date(2026, time.March, 15, 0, 0, 0, 0, time.UTC)) {
		maximumValidityPeriod = 10 * 24 * time.Hour
	}
	maximumValidityPeriod += time.Minute // Fix inclusion start and end time
	certValidityPeriod := cert.NotAfter.Sub(cert.NotBefore)
	return maximumValidityPeriod > certValidityPeriod
}
case <-ccc.cleanupStopChan: close(ccc.cleanupDoneChan) return } } }() } func (ccc *crlCacheCleanerType) stopPeriodicCacheCleanup() { ccc.mu.Lock() defer ccc.mu.Unlock() logger.Debug("stopping periodic CRL cache cleanup") if ccc.cleanupStopChan != nil { close(ccc.cleanupStopChan) <-ccc.cleanupDoneChan ccc.cleanupStopChan = nil ccc.cleanupDoneChan = nil } else { logger.Debugf("CRL cache cleaner was not running, nothing to stop") } } func (ccc *crlCacheCleanerType) cleanupInMemoryCache() { now := time.Now() logger.Debugf("cleaning up in-memory CRL cache at %v", now) crlInMemoryCacheMutex.Lock() for k, v := range crlInMemoryCache { expired := v.crl.NextUpdate.Before(now) evicted := v.downloadTime.Add(ccc.cacheValidityTime).Before(now) logger.Debugf("testing CRL for %v (nextUpdate=%v, downloadTime=%v) from in-memory cache (expired: %v, evicted: %v)", k, v.crl.NextUpdate, v.downloadTime, expired, evicted) if expired || evicted { delete(crlInMemoryCache, k) } } crlInMemoryCacheMutex.Unlock() } func (ccc *crlCacheCleanerType) cleanupOnDiskCache() { now := time.Now() logger.Debugf("cleaning up on-disk CRL cache at %v", now) entries, err := os.ReadDir(ccc.onDiskCacheDir) if err != nil { logger.Warnf("failed to read CRL cache dir: %v", err) return } for _, entry := range entries { if !entry.Type().IsRegular() { continue } path := filepath.Join(ccc.onDiskCacheDir, entry.Name()) crlBytes, err := os.ReadFile(path) if err != nil { logger.Warnf("failed to read CRL file %v: %v", path, err) continue } crl, err := x509.ParseRevocationList(crlBytes) if err != nil { logger.Warnf("failed to parse CRL file %v: %v", path, err) continue } if crl.NextUpdate.Add(ccc.onDiskCacheRemovalDelay).Before(now) { logger.Debugf("CRL file %v is expired, removing", path) if err := os.Remove(path); err != nil { logger.Warnf("failed to remove expired CRL file %v: %v", path, err) } } } } func defaultCrlOnDiskCacheDir() (string, error) { switch runtime.GOOS { case "windows": return 
filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local", "Snowflake", "Caches", "crls"), nil case "darwin": home := os.Getenv("HOME") if home == "" { return "", errors.New("HOME is blank") } return filepath.Join(home, "Library", "Caches", "Snowflake", "crls"), nil default: home := os.Getenv("HOME") if home == "" { return "", errors.New("HOME is blank") } return filepath.Join(home, ".cache", "snowflake", "crls"), nil } } ================================================ FILE: crl_test.go ================================================ package gosnowflake import ( "cmp" "context" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/x509" "crypto/x509/pkix" "database/sql" "encoding/asn1" "encoding/base64" "fmt" "math/big" "net" "net/http" "os" "path/filepath" "sync" "testing" "time" ) var serialNumber = int64(0) // to be incremented type allowCertificatesWithoutCrlURLType bool type inMemoryCacheDisabledType bool type onDiskCacheDisabledType bool type downloadMaxSizeType int type notAfterType time.Time type crlEndpointType string type revokedCert *x509.Certificate type thisUpdateType time.Time type nextUpdateType time.Time func newTestCrlValidator(t *testing.T, checkMode CertRevocationCheckMode, args ...any) *crlValidator { httpClient := &http.Client{} allowCertificatesWithoutCrlURL := false inMemoryCacheDisabled := false onDiskCacheDisabled := false downloadMaxSize := defaultCrlDownloadMaxSize telemetry := &snowflakeTelemetry{} for _, arg := range args { switch v := arg.(type) { case *http.Client: httpClient = v case allowCertificatesWithoutCrlURLType: allowCertificatesWithoutCrlURL = bool(v) case inMemoryCacheDisabledType: inMemoryCacheDisabled = bool(v) case onDiskCacheDisabledType: onDiskCacheDisabled = bool(v) case downloadMaxSizeType: downloadMaxSize = int(v) case *snowflakeTelemetry: telemetry = v default: t.Fatalf("unexpected argument type %T", v) } } cv, err := newCrlValidator(checkMode, allowCertificatesWithoutCrlURL, inMemoryCacheDisabled, 
onDiskCacheDisabled, downloadMaxSize, httpClient, telemetry)
	assertNilF(t, err)
	return cv
}

// TestCrlCheckModeDisabledNoHttpCall verifies that no CRL download is
// attempted when revocation checking is disabled.
func TestCrlCheckModeDisabledNoHttpCall(t *testing.T) {
	caKey, caCert := createCa(t, nil, nil, "root CA", 0)
	_, leafCert := createLeafCert(t, caCert, caKey, 0, crlEndpointType("/rootCrl"))
	crt := &countingRoundTripper{}
	cv := newTestCrlValidator(t, CertRevocationCheckDisabled, &http.Client{Transport: crt})
	err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
	assertNilE(t, err)
	assertEqualE(t, crt.totalRequests(), 0, "no HTTP request should be made when check mode is disabled")
}

// TestCrlModes runs the revocation scenarios under both ENABLED and ADVISORY
// check modes; ADVISORY tolerates check errors but still fails on revocation.
func TestCrlModes(t *testing.T) {
	for _, checkMode := range []CertRevocationCheckMode{CertRevocationCheckEnabled, CertRevocationCheckAdvisory} {
		t.Run(fmt.Sprintf("checkMode=%v", checkMode), func(t *testing.T) {
			t.Run("ShortLivedCertDoesNotNeedCRL", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode, allowCertificatesWithoutCrlURLType(false))
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", 0, "")
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, 0, "", notAfterType(time.Now().Add(4*24*time.Hour)))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				assertNilE(t, err)
			})
			t.Run("LeafCertNotRevoked", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
				crl := createCrl(t, caCert, caPrivateKey)
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				assertNilE(t, err)
			})
			t.Run("LeafCertRevoked", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
				crl := createCrl(t, caCert, caPrivateKey, revokedCert(leafCert))
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				assertNotNilF(t, err)
				assertEqualE(t, err.Error(), "every verified certificate chain contained revoked certificates")
			})
			t.Run("LeafOneCrlErrorAndOneNotRevoked", func(t *testing.T) {
				cleanupCrlCache(t)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/404"), crlEndpointType("rootCrl"))
				crl := createCrl(t, caCert, caPrivateKey)
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
				cv := newTestCrlValidator(t, checkMode)
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				switch checkMode {
				case CertRevocationCheckEnabled:
					assertNotNilF(t, err)
					assertEqualE(t, err.Error(), "certificate revocation check failed")
				case CertRevocationCheckAdvisory:
					assertNilE(t, err)
				}
			})
			t.Run("LeafOneCrlErrorAndOneRevoked", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/404"), crlEndpointType("/rootCrl"))
				crl := createCrl(t, caCert, caPrivateKey, revokedCert(leafCert))
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				assertNotNilF(t, err)
				assertEqualE(t, err.Error(), "every verified certificate chain contained revoked certificates")
			})
			t.Run("TestLeafNotRevokedAndRootDoesNotProvideCrl", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				rootCaPrivateKey, rootCaCert := createCa(t, nil, nil, "root CA", port)
				intermediateCaKey, intermediateCaCert := createCa(t, rootCaCert, rootCaPrivateKey, "intermediate CA", port)
				_, leafCert := createLeafCert(t, intermediateCaCert, intermediateCaKey, port, crlEndpointType("/intermediateCrl"))
				intermediateCrl := createCrl(t, intermediateCaCert, intermediateCaKey)
				registerCrlEndpoints(t, server, newCrlEndpointDef("/intermediateCrl", intermediateCrl))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, intermediateCaCert, rootCaCert}})
				if checkMode == CertRevocationCheckEnabled {
					assertEqualE(t, err.Error(), "certificate revocation check failed")
				} else {
					assertNilE(t, err)
				}
			})
			t.Run("IntermediateRevokedAndLeafDoesNotProvideCrl", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				rootCaPrivateKey, rootCaCert := createCa(t, nil, nil, "root CA", port)
				intermediateCaKey, intermediateCaCert := createCa(t, rootCaCert, rootCaPrivateKey, "intermediate CA", port, crlEndpointType("/rootCrl"))
				_, leafCert := createLeafCert(t, intermediateCaCert, intermediateCaKey, port, crlEndpointType("/intermediateCrl"))
				rootCrl := createCrl(t, rootCaCert, rootCaPrivateKey, revokedCert(intermediateCaCert))
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", rootCrl))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, intermediateCaCert, rootCaCert}})
				assertEqualE(t, err.Error(), "every verified certificate chain contained revoked certificates")
			})
			// NOTE(review): this subtest name duplicates the previous one; the
			// scenario differs (leaf has no CRL endpoint at all) — consider renaming.
			t.Run("IntermediateRevokedAndLeafDoesNotProvideCrl", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				rootCaPrivateKey, rootCaCert := createCa(t, nil, nil, "root CA", port)
				intermediateCaKey, intermediateCaCert :=
createCa(t, rootCaCert, rootCaPrivateKey, "intermediate CA", port, "/rootCrl")
				_, leafCert := createLeafCert(t, intermediateCaCert, intermediateCaKey, port)
				rootCrl := createCrl(t, rootCaCert, rootCaPrivateKey, revokedCert(intermediateCaCert))
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", rootCrl))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, intermediateCaCert, rootCaCert}})
				assertEqualE(t, err.Error(), "every verified certificate chain contained revoked certificates")
			})
			t.Run("DownloadedCrlIsExpiredAndNoneValidExists", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
				// The only available CRL is already past its NextUpdate time.
				crl := createCrl(t, caCert, caPrivateKey, thisUpdateType(time.Now().Add(-2*time.Hour)), nextUpdateType(time.Now().Add(-1*time.Hour)))
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				if checkMode == CertRevocationCheckEnabled {
					assertNotNilF(t, err)
					assertStringContainsE(t, err.Error(), "certificate revocation check failed")
				} else {
					assertNilE(t, err)
				}
			})
			t.Run("DownloadedCrlIsExpiredButTheValidExists", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
				// Server serves an expired CRL, but a still-valid CRL is seeded
				// in the in-memory cache.
				oldCrl := createCrl(t, caCert, caPrivateKey, thisUpdateType(time.Now().Add(-50*time.Hour)), nextUpdateType(time.Now().Add(48*time.Hour)))
				newCrl := createCrl(t, caCert, caPrivateKey, thisUpdateType(time.Now().Add(-2*time.Hour)), nextUpdateType(time.Now().Add(-1*time.Hour)))
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", newCrl))
				oldCrlDownloadTime := time.Now().Add(-48 * time.Hour)
				crlInMemoryCache[fullCrlURL(port, "/rootCrl")] = &crlInMemoryCacheValueType{
					crl:          oldCrl,
					downloadTime: &oldCrlDownloadTime,
				}
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				assertNilE(t, err)
			})
			t.Run("CrlSignatureInvalid", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				otherCaPrivateKey, _ := createCa(t, nil, nil, "other CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
				crl := createCrl(t, caCert, otherCaPrivateKey) // signed with wrong key
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				if checkMode == CertRevocationCheckEnabled {
					assertStringContainsE(t, err.Error(), "certificate revocation check failed")
				} else {
					assertNilE(t, err)
				}
			})
			t.Run("CrlIssuerMismatch", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				otherKey, otherCert := createCa(t, nil, nil, "other CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
				crl := createCrl(t, otherCert, otherKey) // issued by other CA
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				if checkMode == CertRevocationCheckEnabled {
					assertStringContainsE(t, err.Error(), "certificate revocation check failed")
				} else {
					assertNilE(t, err)
				}
			})
			t.Run("CertWithNoCrlDistributionPoints", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port)
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				if checkMode == CertRevocationCheckEnabled {
					assertEqualE(t, err.Error(), "certificate revocation check failed")
				} else {
					assertNilE(t, err)
				}
			})
			t.Run("CertWithNoCrlDistributionPointsAllowed", func(t *testing.T) {
				cleanupCrlCache(t)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", 0)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, 0)
				cv := newTestCrlValidator(t, checkMode, allowCertificatesWithoutCrlURLType(true))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				assertNilE(t, err)
			})
			t.Run("DownloadCrlFailsOnUnparsableCrl", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode, &http.Client{
					Transport: &malformedCrlRoundTripper{},
				})
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				if checkMode == CertRevocationCheckEnabled {
					assertEqualE(t, err.Error(), "certificate revocation check failed")
				} else {
					assertNilE(t, err)
				}
			})
			t.Run("DownloadCrlFailsOn404", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				if checkMode == CertRevocationCheckEnabled {
					assertEqualE(t, err.Error(), "certificate revocation check failed")
				} else {
					assertNilE(t, err)
				}
			})
			t.Run("CrlFitsLimit", func(t *testing.T) {
cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode, downloadMaxSizeType(1024*1024))
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
				crl := createCrl(t, caCert, caPrivateKey)
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				assertNilE(t, err)
			})
			t.Run("CrlTooLargeToDownload", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode, downloadMaxSizeType(10))
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
				crl := createCrl(t, caCert, caPrivateKey)
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				if checkMode == CertRevocationCheckEnabled {
					assertEqualE(t, err.Error(), "certificate revocation check failed")
				} else {
					assertNilE(t, err)
				}
			})
			t.Run("VerifyAgainstIdpExtensionWithDistributionPointMatch", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
				// IDP extension lists the same URL the cert points at, so the check passes.
				idpValue, err := asn1.Marshal(issuingDistributionPoint{
					DistributionPoint: distributionPointName{
						FullName: []asn1.RawValue{
							{Bytes: fmt.Appendf(nil, "http://localhost:%v/rootCrl", port)},
						},
					},
				})
				assertNilF(t, err)
				idpExtension := &pkix.Extension{
					Id:    idpOID,
					Value: idpValue,
				}
				crl := createCrl(t, caCert, caPrivateKey, idpExtension)
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
				err = cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				assertNilE(t, err)
			})
			t.Run("TestVerifyAgainstIdpExtensionWithDistributionPointMismatch", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
				// IDP extension lists a different URL, so the CRL must be rejected.
				idpValue, err := asn1.Marshal(issuingDistributionPoint{
					DistributionPoint: distributionPointName{
						FullName: []asn1.RawValue{
							{Bytes: fmt.Appendf(nil, "http://localhost:%v/otherCrl", port)},
						},
					},
				})
				assertNilF(t, err)
				idpExtension := &pkix.Extension{
					Id:    idpOID,
					Value: idpValue,
				}
				crl := createCrl(t, caCert, caPrivateKey, idpExtension)
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
				err = cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
				if checkMode == CertRevocationCheckEnabled {
					assertNotNilF(t, err)
					assertEqualE(t, err.Error(), "certificate revocation check failed")
				} else {
					assertNilE(t, err)
				}
			})
			t.Run("AnyValidChainCausesSuccess", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, revokedLeaf := createLeafCert(t, caCert, caKey, port, crlEndpointType("/rootCrl"))
				_, validLeaf := createLeafCert(t, caCert, caKey, port, crlEndpointType("/rootCrl"))
				// CRL revokes only the first leaf
				crl := createCrl(t, caCert, caKey, revokedCert(revokedLeaf))
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
				// First chain: revoked, second chain: valid
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{
					{revokedLeaf, caCert},
					{validLeaf, caCert},
				})
				assertNilE(t, err)
			})
			t.Run("OneChainIsRevokedAndOtherIsError", func(t *testing.T) {
				cleanupCrlCache(t)
				cv := newTestCrlValidator(t, checkMode)
				server, port := createCrlServer(t)
				defer closeServer(t, server)
				caKey, caCert := createCa(t, nil, nil, "root CA", port)
				_, revokedLeaf := createLeafCert(t, caCert, caKey, port, crlEndpointType("/rootCrl"))
				_, errorLeaf := createLeafCert(t, caCert, caKey, port, crlEndpointType("/missingCrl"))
				// CRL revokes only the first leaf
				crl := createCrl(t, caCert, caKey, revokedCert(revokedLeaf))
				registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
				// First chain: revoked, second chain: valid
				err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{
					{revokedLeaf, caCert},
					{errorLeaf, caCert},
				})
				if checkMode == CertRevocationCheckEnabled {
					assertNotNilF(t, err)
					assertEqualE(t, err.Error(), "certificate revocation check failed")
				} else {
					assertNilE(t, err)
				}
			})
			t.Run("CacheTests", func(t *testing.T) {
				t.Run("should use in-memory cache", func(t *testing.T) {
					cleanupCrlCache(t)
					crt := newCountingRoundTripper(createTestNoRevocationTransport())
					cv := newTestCrlValidator(t, checkMode, &http.Client{
						Transport: crt,
					})
					server, port := createCrlServer(t)
					defer closeServer(t, server)
					caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
					_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
					crl := createCrl(t, caCert, caPrivateKey)
					// Pre-populate the in-memory cache so no download should occur.
					downloadTime := time.Now().Add(-1 * time.Minute)
					crlInMemoryCache[fullCrlURL(port, "/rootCrl")] = &crlInMemoryCacheValueType{
						crl:          crl,
						downloadTime: &downloadTime,
					}
					err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
					assertNilE(t, err)
					assertEqualE(t, crt.totalRequests(), 0)
					_, err = os.Open(cv.crlURLToPath("/rootCrl"))
					assertErrIsE(t, err, os.ErrNotExist, "CRL file should not be created in the cache directory")
				})
				t.Run("should promote on-disk cache to memory and not modify on-disk entry", func(t *testing.T) {
					skipOnMissingHome(t)
					cleanupCrlCache(t)
					crt := newCountingRoundTripper(createTestNoRevocationTransport())
					cv := newTestCrlValidator(t, checkMode, &http.Client{
						Transport:
crt,
					})
					server, port := createCrlServer(t)
					defer closeServer(t, server)
					caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
					_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
					crl := createCrl(t, caCert, caPrivateKey)
					assertNilF(t, os.WriteFile(cv.crlURLToPath(fullCrlURL(port, "/rootCrl")), crl.Raw, 0600)) // simulate a cached CRL
					statBefore, err := os.Stat(cv.crlURLToPath(fullCrlURL(port, "/rootCrl")))
					assertNilF(t, err)
					err = cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
					assertNilE(t, err)
					assertEqualE(t, crt.totalRequests(), 0)
					statAfter, err := os.Stat(cv.crlURLToPath(fullCrlURL(port, "/rootCrl")))
					assertNilF(t, err)
					assertEqualE(t, statBefore.ModTime().Equal(statAfter.ModTime()), true, "CRL file should not be modified in the cache directory")
				})
				t.Run("should redownload when nextUpdate is reached", func(t *testing.T) {
					cleanupCrlCache(t)
					crt := newCountingRoundTripper(createTestNoRevocationTransport())
					cv := newTestCrlValidator(t, checkMode, &http.Client{
						Transport: crt,
					})
					server, port := createCrlServer(t)
					defer closeServer(t, server)
					caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
					_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
					// Cached CRL is past its NextUpdate; server offers a fresh one.
					oldCrl := createCrl(t, caCert, caPrivateKey, thisUpdateType(time.Now().Add(-2*time.Minute)), nextUpdateType(time.Now().Add(-1*time.Minute)))
					newCrl := createCrl(t, caCert, caPrivateKey, thisUpdateType(time.Now()), nextUpdateType(time.Now().Add(time.Hour)))
					registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", newCrl))
					previousDownloadTime := time.Now().Add(-1 * time.Minute)
					crlInMemoryCache[fullCrlURL(port, "/rootCrl")] = &crlInMemoryCacheValueType{
						crl:          oldCrl,
						downloadTime: &previousDownloadTime,
					}
					err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
					assertNilE(t, err)
					assertEqualE(t, crt.totalRequests(), 1)
					fd, err := os.Open(cv.crlURLToPath(fullCrlURL(port, "/rootCrl")))
					assertNilE(t, err, "CRL file should be created in the cache directory")
					defer fd.Close()
					assertTrueE(t, crlInMemoryCache[fullCrlURL(port, "/rootCrl")].downloadTime.After(previousDownloadTime))
					assertTrueE(t, crlInMemoryCache[fullCrlURL(port, "/rootCrl")].crl.NextUpdate.Equal(newCrl.NextUpdate))
				})
				t.Run("should redownload when evicted in cache", func(t *testing.T) {
					cleanupCrlCache(t)
					crt := newCountingRoundTripper(createTestNoRevocationTransport())
					cv := newTestCrlValidator(t, checkMode, &http.Client{
						Transport: crt,
					})
					server, port := createCrlServer(t)
					defer closeServer(t, server)
					caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
					_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
					// Cached CRL is still valid by NextUpdate but older than the
					// (shortened) cache validity window, so it must be re-fetched.
					oldCrl := createCrl(t, caCert, caPrivateKey, thisUpdateType(time.Now().Add(-2*time.Hour)), nextUpdateType(time.Now().Add(time.Hour)))
					newCrl := createCrl(t, caCert, caPrivateKey, thisUpdateType(time.Now()), nextUpdateType(time.Now().Add(4*time.Hour)))
					registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", newCrl))
					previousValidityTime := crlCacheCleaner.cacheValidityTime
					defer func() {
						crlCacheCleaner.cacheValidityTime = previousValidityTime
					}()
					crlCacheCleaner.cacheValidityTime = 10 * time.Minute
					previousDownloadTime := time.Now().Add(-1 * time.Hour)
					crlInMemoryCache[fullCrlURL(port, "/rootCrl")] = &crlInMemoryCacheValueType{
						crl:          oldCrl,
						downloadTime: &previousDownloadTime,
					}
					err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
					assertNilE(t, err)
					assertEqualE(t, crt.totalRequests(), 1)
					fd, err := os.Open(cv.crlURLToPath(fullCrlURL(port, "/rootCrl")))
					assertNilE(t, err, "CRL file should be created in the cache directory")
					defer fd.Close()
					assertTrueE(t, crlInMemoryCache[fullCrlURL(port, "/rootCrl")].downloadTime.After(previousDownloadTime))
					assertTrueE(t, crlInMemoryCache[fullCrlURL(port, "/rootCrl")].crl.NextUpdate.Equal(newCrl.NextUpdate))
					if !isWindows {
						stat, err := os.Stat(filepath.Dir(cv.crlURLToPath(fullCrlURL(port, "/rootCrl"))))
						assertNilF(t, err)
						assertEqualE(t, stat.Mode().Perm(), os.FileMode(0700), "cache directory permissions should be 0700")
					}
				})
				t.Run("should not save to on-disk cache when disabled", func(t *testing.T) {
					cleanupCrlCache(t)
					cv := newTestCrlValidator(t, checkMode, onDiskCacheDisabledType(true))
					server, port := createCrlServer(t)
					defer closeServer(t, server)
					caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
					_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
					crl := createCrl(t, caCert, caPrivateKey)
					registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
					err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
					assertNilE(t, err)
					_, err = os.Open(cv.crlURLToPath(fullCrlURL(port, "/rootCrl")))
					assertErrIsE(t, err, os.ErrNotExist, "CRL file should not be created in the cache directory when on-disk cache is disabled")
					assertNotNilE(t, crlInMemoryCache[fullCrlURL(port, "/rootCrl")]) // in-memory cache should still be used
				})
				t.Run("should not read from on-disk cache when disabled", func(t *testing.T) {
					cleanupCrlCache(t)
					crt := newCountingRoundTripper(createTestNoRevocationTransport())
					cv := newTestCrlValidator(t, checkMode, onDiskCacheDisabledType(true), &http.Client{
						Transport: crt,
					})
					server, port := createCrlServer(t)
					defer closeServer(t, server)
					caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
					_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
					oldCrl := createCrl(t, caCert, caPrivateKey, nextUpdateType(time.Now()))
					newCrl := createCrl(t, caCert, caPrivateKey)
					registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", newCrl))
					assertNilF(t, os.WriteFile(cv.crlURLToPath(fullCrlURL(port, "/rootCrl")), oldCrl.Raw, 0600)) // simulate a cached CRL
					statBefore, err := os.Stat(cv.crlURLToPath(fullCrlURL(port, "/rootCrl")))
					assertNilF(t, err)
					err =
cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
					assertNilE(t, err)
					assertEqualE(t, crt.totalRequests(), 1, "CRL should be downloaded from the server")
					assertNotNilE(t, crlInMemoryCache[fullCrlURL(port, "/rootCrl")]) // in-memory cache should still be used
					statAfter, err := os.Stat(cv.crlURLToPath(fullCrlURL(port, "/rootCrl")))
					assertNilF(t, err)
					assertTrueE(t, statBefore.ModTime().Equal(statAfter.ModTime()), "CRL file should be modified in the cache directory")
				})
				t.Run("should not use in-memory cache when disabled", func(t *testing.T) {
					skipOnMissingHome(t)
					cleanupCrlCache(t)
					cv := newTestCrlValidator(t, checkMode, inMemoryCacheDisabledType(true))
					server, port := createCrlServer(t)
					defer closeServer(t, server)
					caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
					_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
					crl := createCrl(t, caCert, caPrivateKey)
					registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
					err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
					assertNilE(t, err)
					assertEqualE(t, len(crlInMemoryCache), 0, "in-memory cache should not be used when disabled")
					fd, err := os.Open(cv.crlURLToPath(fullCrlURL(port, "/rootCrl")))
					assertNilE(t, err) // on-disk cache should still be used
					defer fd.Close()
				})
				t.Run("should not use on disk cache when disabled", func(t *testing.T) {
					cleanupCrlCache(t)
					// Both cache layers disabled: neither memory nor disk entries appear.
					cv := newTestCrlValidator(t, checkMode, inMemoryCacheDisabledType(true), onDiskCacheDisabledType(true))
					server, port := createCrlServer(t)
					defer closeServer(t, server)
					caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port)
					_, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl"))
					crl := createCrl(t, caCert, caPrivateKey)
					registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl))
					err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}})
					assertNilE(t, err)
					assertNilE(t,
crlInMemoryCache[fullCrlURL(port, "/rootCrl")], "in-memory cache should not be used when disabled") _, err = os.Open(cv.crlURLToPath(fullCrlURL(port, "/rootCrl"))) assertErrIsE(t, err, os.ErrNotExist, "CRL file should not be created in the cache directory when on-disk cache is disabled") }) t.Run("should clean up cache", func(t *testing.T) { skipOnMissingHome(t) cleanupCrlCache(t) cv := newTestCrlValidator(t, checkMode) server, port := createCrlServer(t) defer closeServer(t, server) caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port) _, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl")) crl := createCrl(t, caCert, caPrivateKey, nextUpdateType(time.Now().Add(3000*time.Millisecond))) registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl)) previousValidityTime := crlCacheCleaner.cacheValidityTime previousOnDiskCacheRemovalDelay := crlCacheCleaner.onDiskCacheRemovalDelay defer func() { crlCacheCleaner.cacheValidityTime = previousValidityTime crlCacheCleaner.onDiskCacheRemovalDelay = previousOnDiskCacheRemovalDelay }() crlCacheCleaner.cacheValidityTime = 1000 * time.Millisecond crlCacheCleaner.onDiskCacheRemovalDelay = 2000 * time.Millisecond crlCacheCleaner.stopPeriodicCacheCleanup() previousCacheCleanerTickRate := crlCacheCleanerTickRate defer func() { crlCacheCleanerTickRate = previousCacheCleanerTickRate }() crlCacheCleanerTickRate = 500 * time.Millisecond crlCacheCleaner.startPeriodicCacheCleanup() defer crlCacheCleaner.stopPeriodicCacheCleanup() err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}}) assertNilE(t, err) crlInMemoryCacheMutex.Lock() assertNotNilE(t, crlInMemoryCache[fullCrlURL(port, "/rootCrl")], "in-memory cache should be populated") crlInMemoryCacheMutex.Unlock() fd, err := os.Open(cv.crlURLToPath(fullCrlURL(port, "/rootCrl"))) assertNilE(t, err, "CRL file should be created in the cache directory") fd.Close() time.Sleep(3000 * time.Millisecond) // wait for cleanup 
to happen crlInMemoryCacheMutex.Lock() assertNilE(t, crlInMemoryCache[fullCrlURL(port, "/rootCrl")], "in-memory cache should be cleaned up") crlInMemoryCacheMutex.Unlock() fd, err = os.Open(cv.crlURLToPath(fullCrlURL(port, "/rootCrl"))) assertNilE(t, err, "CRL file should still be present in the cache directory") fd.Close() time.Sleep(4000 * time.Millisecond) // wait for removal delay to pass _, err = os.Open(cv.crlURLToPath(fullCrlURL(port, "/rootCrl"))) assertErrIsE(t, err, os.ErrNotExist, "CRL file should be removed from the cache directory after removal delay") }) }) }) } } func cleanupCrlCache(t *testing.T) { crlCacheCleanerMu.Lock() if crlCacheCleaner != nil { crlCacheCleaner.stopPeriodicCacheCleanup() err := os.RemoveAll(crlCacheCleaner.onDiskCacheDir) assertNilF(t, err) crlCacheCleaner = nil } crlCacheCleanerMu.Unlock() crlInMemoryCache = make(map[string]*crlInMemoryCacheValueType) } func TestRealCrlWithIdpExtension(t *testing.T) { crlBytes, err := base64.StdEncoding.DecodeString(`MIIWCzCCFbECAQEwCgYIKoZIzj0EAwIwOzELMAkGA1UEBhMCVVMxHjAcBgNVBAoTFUdvb2dsZSBUcnVzdCBTZXJ2aWNlczEMMAoGA1UEAxMDV0UyFw0yNTA2MDMwNTE0MjZaFw0yNTA2MTMwNDE0MjVaMIIU1TAiAhEA+GNmsfmkiSYS3So6PtM4YRcNMjUwNTMwMDgzMDU0WjAiAhEAjnadf1gDhyYKPKaa/12+7xcNMjUwNTMwMDgzNDMyWjAhAhBE9QlX3xRpuxJ814WV+K/1Fw0yNTA1MzAxMTA0MzNaMCICEQCqN2nq4YSOEwkyJCn6HYQlFw0yNTA1MzAxMTM0MzNaMCECEDBfFh8CphcdEJF+zBTMw74XDTI1MDUzMDEyMDA1M1owIQIQalbjU7py90YQObvUekSOhBcNMjUwNTMwMTIwNDMzWjAiAhEAr2k4vZwyJnISwutcyf2nyRcNMjUwNTMwMTMwNDMzWjAhAhB35TMXvzwpYwooflxIqWDEFw0yNTA1MzAxMzMwNTNaMCECEAGHFbYpRjuyEmwHBjVy54gXDTI1MDUzMDEzMzQzMlowIgIRAId502qqmD3KEDgIHLdDwZYXDTI1MDUzMDE0MTg1MFowIgIRAJEe803uv+NQEJUBE5Q6P0kXDTI1MDUzMDE0MTg1MFowIgIRAOLFs7G+1xolCsv2TgVXc0AXDTI1MDUzMDE4MDQzMlowIQIQUsjln6aQLBgQRpsXpimESRcNMjUwNTMwMTgzNDMyWjAiAhEA62yPgGbg8uAKRBAp3N7zjRcNMjUwNTMwMjAwNDMyWjAiAhEAsjA4b2hRSeQJ3HSOmSCsfxcNMjUwNTMwMjAzNDMzWjAiAhEA5vGSk0V5AiQSSlJJgHBO/RcNMjUwNTMwMjEwMDUzWjAhAhBC5Bb9vfzyyQkPGoyM+1y3Fw0yNTA1MzEwMDA0MzNaMCICEQCk2xXPFJlcFAq8gAoYZcWKFw0y
NTA1MzEwMDM0MzJaMCICEQDoXOJPuECUGwpzgim5mc9mFw0yNTA1MzEwMTAwNTNaMCECEHgn0iqA3FOqEGZkc3nMlQsXDTI1MDUzMTAyMzA1NFowIQIQdnsVe7yop/YSZC36hn8k0hcNMjUwNTMxMDUwMDUzWjAiAhEA988MkvjARu0K+NJ1aVwOIRcNMjUwNTMxMDcwMDUzWjAiAhEAwFdObfm70cMSBKAflw/KCxcNMjUwNTMxMDczMDUzWjAiAhEAqX2jbkbYhlwKl2fgguEfdRcNMjUwNTMxMDgzMDUzWjAhAhAcfL0AhaLI2xAfTjDas2e4Fw0yNTA1MzEwODM0MzNaMCECEHcuTXPmmCULECe4qj6t/woXDTI1MDUzMTA5MDQzMlowIgIRAL0tNF+V7aarEjS5X52ozVwXDTI1MDUzMTA5MzA1M1owIQIQEWjKzEnAuZAQOdBZQMCcLRcNMjUwNTMxMTAzMDUzWjAhAhA2l4kUNXKzpwoDbrMlYN65Fw0yNTA1MzExMTA0MzJaMCICEQDQMi07YAslxglpYDrFllr0Fw0yNTA1MzExMTMwNTNaMCECEEfIJzk/qTOVEDehcdaIr3YXDTI1MDUzMTEyMzQzMlowIgIRAPs9bOlpEQZzEL71JmOr4gMXDTI1MDUzMTE0MzA1NFowIgIRAKA4/laWgpf+CX5Xqdui57sXDTI1MDUzMTE0MzQzMlowIQIQIJL+kywlXcIQoNk1IR4hABcNMjUwNTMxMTcwMDUzWjAiAhEA10YhoTDr3JIJdDwoUvU7PBcNMjUwNTMxMTgzNDMyWjAhAhBjqqc9j1zo+grP13nPYjlrFw0yNTA1MzExOTMwNTNaMCECEFvJXOjJWg4XCg9lgBLgFCUXDTI1MDUzMTIxMDA1NFowIQIQHjWkZX62R5gKS9bus/vO3hcNMjUwNTMxMjIwNDMzWjAiAhEArzROq2M27voKXANmOzjg4BcNMjUwNTMxMjIzMDU0WjAhAhBGoxuPheM5twmSM9LO0NZuFw0yNTA1MzEyMzAwNTNaMCICEQClgDoqCxhihxDvXApTEN/QFw0yNTA1MzEyMzM0MzJaMCICEQCjffeJqicvMxCaQlnCRp1kFw0yNTA2MDEwMDM0MzNaMCECEB3bMsobz0qRCdm+plUwrNUXDTI1MDYwMTAxMDA1NFowIgIRANusCipK0XOVEC0+C1Ce+bsXDTI1MDYwMTAyMDA1NFowIgIRANsRDccCPVBrEGplnFXS3y0XDTI1MDYwMTAyMzQzMlowIQIQZBPFmHRcxzESJeZSri7+fBcNMjUwNjAxMDMwMDUyWjAhAhBUeunArcVjrApcJ9uR1v0cFw0yNTA2MDEwMzA0MzJaMCECEH7M2GgoJPa3Ccjz9nx1FmwXDTI1MDYwMTAzMzA1M1owIgIRAKwbWa1xrjjgCvB5I6ICstAXDTI1MDYwMTA0MzA1M1owIgIRAKRJvSq/BfQqEPgYyqN/lkwXDTI1MDYwMTA1MDA1NFowIQIQPJxOkr7drV4Qjxa9rYfUwhcNMjUwNjAxMDUwNDMzWjAiAhEA8lQTTLlsfBoJlrx6CydL7hcNMjUwNjAxMDUzNDMzWjAiAhEAluoSt/87SbUKN6WD8WO/uBcNMjUwNjAxMDYzMDUzWjAiAhEAi1z9zzq3ecYQYbpyjZcV0BcNMjUwNjAxMDcwMDU0WjAhAhBDYZctZbp9NQkS+H75yhEmFw0yNTA2MDExMDM0MzJaMCICEQDhKSZ6X/VHjQpM79Em7auJFw0yNTA2MDExMTAwNTNaMCICEQCzngaFAi5rTBJBHMJnGgjCFw0yNTA2MDExMTMwNTRaMCECEAi0b7W58XDnEHtR8u+d+TwXDTI1MDMwNDEyMTIyNVowIgIRANw8VR+umOAsEpehwNHqCWkXDTI1MDYwMTEyMzQzMlowIQIQVDJ7+F+QyfQSUexffugxPBcNMjUwNjAxMTMw
MDUzWjAiAhEA3kZX5ACREf4Ql7R88uTRiBcNMjUwMzA0MTQ1MTU2WjAhAhBAmF4m8TDJfxCB93DGRJ5SFw0yNTA2MDExNTMwNTRaMCECED2nNXiAdcbkCorz/3SaOXkXDTI1MDMwNDE2MDY0MFowIgIRAJPjTBx12IeKCsZC+WsYtqwXDTI1MDYwMTE4MzA1M1owIQIQH89eMYtFX+ESUBJx9drNdxcNMjUwNjAxMTkwMDUzWjAiAhEA9h1UKrkPonEJ3oHf6DAdeRcNMjUwNjAxMTkzMDUzWjAiAhEAx7HcWI25jVsJzEFAa8H6hhcNMjUwNjAxMTkzNDMyWjAiAhEA2xt7Vz1eC9US2Lx9U7IdQxcNMjUwNjAyMDEzMDUzWjAhAhBLBChzFL7nMBKrkgfIqmL4Fw0yNTA2MDIwMjA0MzJaMCICEQCoWrPIkhkCEwoZoBW8Wi7iFw0yNTA2MDIwMjMwNTNaMCICEQCW9nREFwgFExAhQPkEcX1GFw0yNTA2MDIwNzA0MzJaMCECEFJpjh2fOfnwEPYEmgM4vAsXDTI1MDYwMjEwMzQzMlowIgIRAMARWx58ovYeCYlv9x/+dXUXDTI1MDYwMjExMzQzMlowIgIRANGVJSxAtM0+CmvyDk5yemEXDTI1MDYwMjEyMzQzMlowIQIQLuR16MKk7VIJsPZDdxmxjBcNMjUwMzA1MTMxNTQ2WjAhAhBgWj2KpFDd1hLS8czTxP9WFw0yNTA2MDIxMzMwNTNaMCICEQDpBAXC4tks2RA3PmivojEYFw0yNTA2MDIxMzM0MzJaMCICEQDPAqlDrpaIZRLOv4dkWD9YFw0yNTA2MDIxNDM0MzJaMCECECHJcaelQHswEjWQOK4shmQXDTI1MDYwMjE1MDA1M1owIgIRAKSC4iHRwdOXEI4MVwjYASMXDTI1MDYwMjE4MDQzMlowIgIRAPhnb/McQolNCT5KPL9WBy0XDTI1MDYwMjE5MDA1M1owIQIQA/fNWPLbkQ8SJc6T1ykDtxcNMjUwNjAyMTkwNDMyWjAhAhBp1e5W8/pEFgoVhg1GywuhFw0yNTA2MDIxOTM0MzJaMCECEDa16LoaHM7jEBLVfZOw+2EXDTI1MDYwMjIwMDA1NFowIgIRANhoeJQh/bgAChCj0tjaOhoXDTI1MDYwMjIwMzQzMlowIgIRAPGCJfkpjnA0Ep42ikTZTDQXDTI1MDYwMjIxMDQzMlowIgIRANBfcQ5tm+jQEIrc4G9uz30XDTI1MDYwMjIyMDQzM1owIQIQDvMAXxXjJV0Q07lbQyqRlRcNMjUwNjAzMDIzMDUzWjAhAhABUapKRf9bwxJ9pM421HlyFw0yNTA2MDMwMjM0MzNaMCICEQDE7QlV4jWoawmVVFlPlN5ZFw0yNTA2MDMwMzAwNTNaMCECEDrfc2dpmptdEOBKNuW5dN0XDTI1MDMwNzE2MDY0NVowIgIRAO08CoY80ZYZCnASAJsibosXDTI1MDMwOTE2MDYzOVowIgIRAO3z/WMJKFPwEqGv+wIQqVUXDTI1MDMxMTE2MDYzOVowIgIRAOGk/CY9/86iEkStcRIR74oXDTI1MDMxMTE3MzMzNVowIgIRALmyt1+31WZtCrklPUahHsoXDTI1MDMxMTIxMjcyM1owIgIRAN0K49cWZ5XVCRUwnqkyzAcXDTI1MDMxNTE3MzM0MFowIgIRAKHAD2cxPWesCiXtOaFLRMwXDTI1MDMxNTE4MzcxM1owIQIQerJr0+WomOYQqOCLMwwQQhcNMjUwMzE5MDYxNDA5WjAhAhAX1xTDBKnX9RBHto7Yo8lVFw0yNTAzMjAwOTM4NTdaMCICEQDrpjOSW5W9fgqtI2heAOexFw0yNTAzMjMwNjE0MDlaMCICEQCdIwrsmoZRIhIDnY2gQhZZFw0yNTAzMjYwODI1MDRaMCICEQCc3wlTpAB6ZxJB5SLJ1cGFFw0yNTAzMzExMDQ0NDVa
MCECEDvSrWlzrD2bEHLHvZ+Ak9sXDTI1MDQwMjA3NTk0NFowIgIRAMJ2ztUSpiKpCqYpTx6GEWwXDTI1MDQwNjA4MDA0OFowIQIQed72ikZNyBISyOL/lLPDIxcNMjUwNDA4MjA1NjM5WjAiAhEA+fjeN7n4PugS5Mh4kSSUhhcNMjUwNDA5MDQ1ODM4WjAhAhA2Gg3BxIzzaAqR0K/EYS9uFw0yNTA0MDkwNTU4MzBaMCECEA6iX6ZA2cvtCvqLywYZkGEXDTI1MDQwOTA4MDIwMFowIQIQaajjpNdTR+MSotZQd0le4BcNMjUwNDExMDk0ODQxWjAiAhEA+Z7TKxQHRP8KXarTEkKl/xcNMjUwNDExMTY0MTUxWjAhAhAYS5W1oCus3gqsNhnA9lgNFw0yNTA0MTExODUyMDBaMCECEB9WtUrjbzKNCcLJuZELbPIXDTI1MDQxMTE4NTIwMFowIQIQJuqczPhm8x4JCjjS5UEV4hcNMjUwNDEzMjExMzQ4WjAiAhEA8pC1AgBcHQMK98lYehVRqBcNMjUwNDEzMjIxNzE3WjAhAhBEe078o0AX4hCPOfwW08DgFw0yNTA0MTUyMTEzNDhaMCECEFtBlrwO2/yCEI5FaTjhEMUXDTI1MDQxNTIzNDgxMVowIgIRAOAhdu/DwnQZEGh9ABuntsEXDTI1MDQxNzE5MjA0NlowIgIRAKblmThTrKCLCaAfU80cgHUXDTI1MDQxNzIwMjU0N1owIQIQJ+PW+89xTOgJv3sKUFzpFRcNMjUwNDE4MjEzOTI0WjAhAhBreCVIZnxIxQkm0n/lw8XuFw0yNTA0MjAxODAxMzJaMCICEQDLHBY49bRaWxAUwMRRaYGkFw0yNTA0MjExNjU3NDlaMCECEBKDWcexQm8uCQPht1B2WCMXDTI1MDQyMjE2NTc0OVowIgIRANEuLddZ+6e/Cinj83AK2TIXDTI1MDQyMzE2NTc0OFowIQIQQRs5pdt3rw0Kj3yAi9nB8BcNMjUwNDI0MTgwMTMxWjAiAhEA2++UC5BwrkkSDLuijbOlhxcNMjUwNDI0MTgxNTI0WjAiAhEAso/DvQaXc8cQJQzH3vT39xcNMjUwNDI1MTgwOTAxWjAhAhA6Wxu2SrTNQAqGYEIlmug6Fw0yNTA0MzAxNTE2NTZaMCECECrQTDxnQf4UCjmTomNx6uoXDTI1MDQzMDE2MTU1MVowIgIRAOzK9hrrhUpREDNdMK+UhKYXDTI1MDUwMjIyNDgzNFowIQIQJywBgwts3CYJlswBuEfC4BcNMjUwNTAyMjM1MTQ1WjAhAhAv+aqUHySyHQnqo/kXTj07Fw0yNTA1MDMyMjQ4MzVaMCICEQDlfhMr/mGCeQrUug4RBfCwFw0yNTA1MDQyMzUxNDVaMCICEQC12JkjoHkyGQrXnfDh1Ak3Fw0yNTA1MDgxNjM3MjJaMCICEQD+ChJzg9zffhJvICXO5egWFw0yNTA1MDkxNDQ0MjNaMCICEQDINngvxFORLgmtenUC0eReFw0yNTA1MTAxNTQ4MTFaMCICEQDCRQG/17P8RgkCuuqVCqOEFw0yNTA1MTIxNDQ0MjRaMCECEG/pHThaOXIYEA6gUwBN2AAXDTI1MDUxMjE1NDgxMFowIQIQKRDCPxMlRDkQtVuZlc1y/BcNMjUwNTE1MTYzMjQ4WjAhAhAQJithNwlgHhBJtOo4cr7PFw0yNTA1MTgxNjMyNDhaMCICEQDuzLB0Dym1dAopKKRwqg+FFw0yNTA1MTkxNjMyNDhaMCICEQCic+mqwTKh2wlW/M9hFsKUFw0yNTA1MjExNDE5NTJaMCICEQCkcISpajRR8gloWttjVtWYFw0yNTA1MjIxNDE5NTFaMCECEC+QfsXidSEECVCY2XJcobsXDTI1MDUyNTEzMTU1NFowIQIQGaVPji8ez7sQc2BEKZ6zQRcNMjUwNTI2MTQxOTUxWjAhAhAWzpGux+Vc
MBLCf/uAu+UHFw0yNTA1MjgxNjAzMDdaMCICEQCVcpW8k5oxiwkAGBCtQXleFw0yNTA1MjkxNTMwNTRaMCICEQDCbweEznzXHxLmEoYkMXAXFw0yNTA1MjkxNjQyMjZaMCICEQC0/LnZiZ/wlhAYZ7QNFoMOFw0yNTA1MzAxNDE4NTBaMCICEQDNuyNRBRFsWhC2IgBtBr4jFw0yNTA1MzAxNDE4NTBaMCICEQCqTNQ5/wthcQoKTERGUrPiFw0yNTA2MDExNTMwNTNaoGwwajAfBgNVHSMEGDAWgBR1vsR3ron2RDd9z7FoHx0a69w0WTALBgNVHRQEBAICCwswOgYDVR0cAQH/BDAwLqApoCeGJWh0dHA6Ly9jLnBraS5nb29nL3dlMi95SzVuUGh0SEtRcy5jcmyBAf8wCgYIKoZIzj0EAwIDSAAwRQIhANnRHxa67XPmeX/SrH7l5sMJxA+OLg6eAjiUCBHW7NeKAiBZTWzYLK9IDgfUffYcRLtITegsRIjm02lrBd1I1I+QbQ==`) assertNilF(t, err) crl, err := x509.ParseRevocationList(crlBytes) assertNilF(t, err) cv := newTestCrlValidator(t, CertRevocationCheckEnabled) err = cv.verifyAgainstIdpExtension(crl, "http://c.pki.goog/we2/yK5nPhtHKQs.crl") assertNilE(t, err) err = cv.verifyAgainstIdpExtension(crl, "http://c.pki.goog/we2/other.crl") assertNotNilF(t, err) assertStringContainsE(t, err.Error(), "distribution point http://c.pki.goog/we2/other.crl not found in CRL IDP extension") } func TestParallelRequestToTheSameCrl(t *testing.T) { cleanupCrlCache(t) server, port := createCrlServer(t) defer closeServer(t, server) caPrivateKey, caCert := createCa(t, nil, nil, "root CA", port) _, leafCert := createLeafCert(t, caCert, caPrivateKey, port, crlEndpointType("/rootCrl")) crl := createCrl(t, caCert, caPrivateKey) registerCrlEndpoints(t, server, newCrlEndpointDef("/rootCrl", crl)) brt := newBlockingRoundTripper(createTestNoRevocationTransport(), 100*time.Millisecond) crt := newCountingRoundTripper(brt) cv := newTestCrlValidator(t, CertRevocationCheckEnabled, &http.Client{ Transport: crt, }) var wg sync.WaitGroup for range 10 { wg.Add(1) go func() { defer wg.Done() err := cv.verifyPeerCertificates(nil, [][]*x509.Certificate{{leafCert, caCert}}) assertNilE(t, err) }() } wg.Wait() assertEqualE(t, crt.totalRequests(), 1) } func TestIsShortLivedCertificate(t *testing.T) { tests := []struct { name string cert *x509.Certificate expected bool }{ { name: "Issued before March 
15, 2024 (not short-lived)", cert: &x509.Certificate{ NotBefore: time.Date(2024, time.March, 1, 0, 0, 0, 0, time.UTC), NotAfter: time.Date(2024, time.March, 10, 0, 0, 0, 0, time.UTC), }, expected: false, }, { name: "Issued after March 15, 2024, validity less than 10, but more than 7 days (short-lived)", cert: &x509.Certificate{ NotBefore: time.Date(2024, time.March, 16, 0, 0, 0, 0, time.UTC), NotAfter: time.Date(2024, time.March, 24, 0, 0, 0, 0, time.UTC), }, expected: true, }, { name: "Issued after March 15, 2024, validity less than 7 days (short-lived)", cert: &x509.Certificate{ NotBefore: time.Date(2024, time.March, 16, 0, 0, 0, 0, time.UTC), NotAfter: time.Date(2024, time.March, 22, 0, 0, 0, 0, time.UTC), }, expected: true, }, { name: "Issued after March 15, 2024, validity exactly 10 days (short-lived)", cert: &x509.Certificate{ NotBefore: time.Date(2024, time.March, 16, 0, 0, 0, 0, time.UTC), NotAfter: time.Date(2024, time.March, 26, 0, 0, 0, 0, time.UTC), }, expected: true, }, { name: "Issued after March 15, 2024, validity more than 10 days (not short-lived)", cert: &x509.Certificate{ NotBefore: time.Date(2024, time.March, 16, 0, 0, 0, 0, time.UTC), NotAfter: time.Date(2024, time.March, 27, 0, 0, 0, 0, time.UTC), }, expected: false, }, { name: "Issued after March 15, 2026, validity less than 7 days (short-lived)", cert: &x509.Certificate{ NotBefore: time.Date(2026, time.March, 16, 0, 0, 0, 0, time.UTC), NotAfter: time.Date(2026, time.March, 20, 0, 0, 0, 0, time.UTC), }, expected: true, }, { name: "Issued after March 15, 2026, validity exactly 7 days (short-lived)", cert: &x509.Certificate{ NotBefore: time.Date(2026, time.March, 16, 0, 0, 0, 0, time.UTC), NotAfter: time.Date(2026, time.March, 23, 0, 0, 0, 0, time.UTC), }, expected: true, }, { name: "Issued after March 15, 2026, validity more than 7 days (not short-lived)", cert: &x509.Certificate{ NotBefore: time.Date(2026, time.March, 16, 0, 0, 0, 0, time.UTC), NotAfter: time.Date(2026, time.March, 24, 0, 0, 
0, 0, time.UTC),
			},
			expected: false,
		},
	}
	// Table-driven check of the short-lived-certificate classification for
	// each NotBefore/NotAfter pair.
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assertEqualE(t, isShortLivedCertificate(tt.cert), tt.expected)
		})
	}
}

// malformedCrlRoundTripper is an http.RoundTripper stub whose responses are
// always HTTP 200 with an empty body — simulating a CRL endpoint that serves
// content that cannot be parsed as a CRL.
type malformedCrlRoundTripper struct {
}

// RoundTrip implements http.RoundTripper; it ignores the request and answers
// with StatusOK and http.NoBody.
func (m *malformedCrlRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	response := http.Response{
		StatusCode: http.StatusOK,
	}
	response.Body = http.NoBody
	return &response, nil
}

// createCa builds a CA certificate (CertSign|CRLSign key usage, 10-year
// validity) named cn, optionally carrying CRL distribution points for the
// given endpoints, and delegates signing to createCert. Passing nil issuer
// cert/key yields a self-signed CA (see createCert's cmp.Or fallbacks).
func createCa(t *testing.T, issuerCert *x509.Certificate, issuerPrivateKey *rsa.PrivateKey, cn string, port int, crlEndpoints ...crlEndpointType) (*rsa.PrivateKey, *x509.Certificate) {
	caTemplate := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			Organization:       []string{"Snowflake"},
			OrganizationalUnit: []string{"Drivers"},
			Locality:           []string{"Warsaw"},
			CommonName:         cn,
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().AddDate(10, 0, 0),
		IsCA:                  true,
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
		BasicConstraintsValid: true,
		SignatureAlgorithm:    x509.SHA256WithRSA,
	}
	return createCert(t, caTemplate, issuerCert, issuerPrivateKey, port, crlEndpoints)
}

// createLeafCert builds a leaf ("localhost") certificate signed by the given
// issuer. Optional variadic params may carry a notAfterType (overrides the
// default 1-year expiry) and any number of crlEndpointType values; values of
// other types are silently ignored (switch has no default branch). Each call
// bumps the shared serialNumber counter so serials stay unique.
func createLeafCert(t *testing.T, issuerCert *x509.Certificate, issuerPrivateKey *rsa.PrivateKey, port int, params ...any) (*rsa.PrivateKey, *x509.Certificate) {
	notAfter := time.Now().AddDate(1, 0, 0)
	var crlEndpoints []crlEndpointType
	for _, param := range params {
		switch v := param.(type) {
		case notAfterType:
			notAfter = time.Time(v)
		case crlEndpointType:
			crlEndpoints = append(crlEndpoints, v)
		}
	}
	serialNumber++
	certTemplate := &x509.Certificate{
		SerialNumber: big.NewInt(serialNumber),
		Subject: pkix.Name{
			Organization:       []string{"Snowflake"},
			OrganizationalUnit: []string{"Drivers"},
			Locality:           []string{"Warsaw"},
			CommonName:         "localhost",
		},
		NotBefore:          time.Now(),
		NotAfter:           notAfter,
		IsCA:               false,
		SignatureAlgorithm: x509.SHA256WithRSA,
	}
	return createCert(t, certTemplate, issuerCert, issuerPrivateKey, port, crlEndpoints)
}

// createCert generates an RSA key pair and issues a certificate from template;
// continued below (signature spans the chunk boundary).
func createCert(t *testing.T, template, issuerCert *x509.Certificate,
	issuerPrivateKey *rsa.PrivateKey, port int, crlEndpoints []crlEndpointType) (*rsa.PrivateKey, *x509.Certificate) {
	var distributionPoints []string
	for _, crlEndpoint := range crlEndpoints {
		distributionPoints = append(distributionPoints, fmt.Sprintf("http://localhost:%v%v", port, crlEndpoint))
		// Reassigned on every iteration; the final assignment carries the full list.
		template.CRLDistributionPoints = distributionPoints
	}
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	assertNilF(t, err)
	template.SubjectKeyId = calculateKeyID(t, &privateKey.PublicKey)
	// cmp.Or falls back to the freshly generated key / the template itself,
	// so a nil issuer produces a self-signed certificate.
	signerPrivateKey := cmp.Or(issuerPrivateKey, privateKey)
	issuerCertOrSelfSigned := cmp.Or(issuerCert, template)
	certBytes, err := x509.CreateCertificate(rand.Reader, template, issuerCertOrSelfSigned, &privateKey.PublicKey, signerPrivateKey)
	assertNilF(t, err)
	cert, err := x509.ParseCertificate(certBytes)
	assertNilF(t, err)
	return privateKey, cert
}

// calculateKeyID derives a subject key identifier as the SHA-256 digest of
// the PKIX-marshaled public key.
func calculateKeyID(t *testing.T, pubKey any) []byte {
	pubBytes, err := x509.MarshalPKIXPublicKey(pubKey)
	assertNilF(t, err)
	hash := sha256.Sum256(pubBytes)
	return hash[:]
}

// createCrl builds and signs a CRL for issuerCert. Variadic args may contain
// revokedCert entries, *pkix.Extension values, and thisUpdateType /
// nextUpdateType overrides (defaults: one hour in the past / future); any
// other argument type fails the test.
func createCrl(t *testing.T, issuerCert *x509.Certificate, issuerPrivateKey *rsa.PrivateKey, args ...any) *x509.RevocationList {
	var revokedCertEntries []x509.RevocationListEntry
	var extensions []pkix.Extension
	thisUpdate := time.Now().Add(-time.Hour)
	nextUpdate := time.Now().Add(time.Hour)
	for _, arg := range args {
		switch v := arg.(type) {
		case revokedCert:
			revokedCertEntries = append(revokedCertEntries, x509.RevocationListEntry{
				SerialNumber:   v.SerialNumber,
				RevocationTime: time.Now().Add(-time.Hour * 24),
			})
		case *pkix.Extension:
			extensions = append(extensions, *v)
		case thisUpdateType:
			thisUpdate = time.Time(v)
		case nextUpdateType:
			nextUpdate = time.Time(v)
		default:
			t.Fatalf("unexpected argument type: %T", arg)
		}
	}
	crlTemplate := &x509.RevocationList{
		Number:                    big.NewInt(1),
		RevokedCertificateEntries: revokedCertEntries,
		ExtraExtensions:           extensions,
		ThisUpdate:                thisUpdate,
		NextUpdate:                nextUpdate,
	}
	crlBytes, err := x509.CreateRevocationList(rand.Reader, crlTemplate,
		issuerCert, issuerPrivateKey)
	assertNilF(t, err)
	crl, err := x509.ParseRevocationList(crlBytes)
	assertNilF(t, err)
	return crl
}

// crlEndpointDef pairs an HTTP endpoint path with the CRL it should serve.
type crlEndpointDef struct {
	endpoint string
	crl      *x509.RevocationList
}

// newCrlEndpointDef builds a crlEndpointDef for the given path and CRL.
func newCrlEndpointDef(endpoint string, crl *x509.RevocationList) *crlEndpointDef {
	return &crlEndpointDef{
		endpoint: endpoint,
		crl:      crl,
	}
}

// createCrlServer starts an HTTP server with a fresh ServeMux on a random
// free port (":0") and returns the server together with the chosen port.
// The server runs in a background goroutine until shut down by the caller.
func createCrlServer(t *testing.T) (*http.Server, int) {
	listener, err := net.Listen("tcp", ":0")
	assertNilF(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	server := &http.Server{
		Addr:    fmt.Sprintf(":%v", port),
		Handler: http.NewServeMux(),
	}
	go func() {
		err := server.Serve(listener)
		// Serve always returns non-nil; after a clean Shutdown it must be
		// http.ErrServerClosed.
		assertErrIsF(t, err, http.ErrServerClosed)
	}()
	return server, port
}

// registerCrlEndpoints registers each endpoint on the server's ServeMux,
// responding with the raw DER bytes of its CRL.
func registerCrlEndpoints(t *testing.T, server *http.Server, endpointDefs ...*crlEndpointDef) {
	for _, endpointDef := range endpointDefs {
		server.Handler.(*http.ServeMux).HandleFunc(endpointDef.endpoint, func(responseWriter http.ResponseWriter, request *http.Request) {
			responseWriter.WriteHeader(http.StatusOK)
			_, err := responseWriter.Write(endpointDef.crl.Raw)
			assertNilF(t, err)
		})
	}
}

// fullCrlURL composes the localhost URL for a CRL endpoint on the given port.
func fullCrlURL(port int, endpoint string) string {
	return fmt.Sprintf("http://localhost:%v%v", port, endpoint)
}

// closeServer gracefully shuts the test server down.
func closeServer(t *testing.T, server *http.Server) {
	err := server.Shutdown(context.Background())
	assertNilF(t, err)
}

// TestCrlE2E exercises CRL-based revocation checking against a real Snowflake
// connection (continued past the chunk boundary below).
func TestCrlE2E(t *testing.T) {
	t.Run("Successful flow", func(t *testing.T) {
		skipOnJenkins(t, "Jenkins tests use HTTP connection to SF, so CRL is not used")
		cleanupCrlCache(t)
		defer cleanupCrlCache(t) // to reset cache cleaner after test
		crlCacheCleanerTickRate = 1 * time.Second
		cacheValidityTimeOverride := overrideEnv(snowflakeCrlCacheValidityTimeEnv, "15s")
		defer cacheValidityTimeOverride.rollback()
		cfg, err := ParseDSN(dsn)
		assertNilF(t, err, "Failed to parse DSN")
		// Add CRL-specific test parameters
		cfg.CertRevocationCheckMode = CertRevocationCheckEnabled
		cfg.CrlAllowCertificatesWithoutCrlURL = ConfigBoolTrue
		cfg.DisableOCSPChecks = true
		cfg.CrlOnDiskCacheDisabled = true
		db := sql.OpenDB(NewConnector(SnowflakeDriver{}, *cfg))
		defer db.Close()
		rows, err := db.Query("SELECT 1")
		assertNilF(t, err, "CRL E2E test failed")
		defer rows.Close()
		crlInMemoryCacheMutex.Lock()
		memoryEntriesAfterSnowflakeConnection := len(crlInMemoryCache)
		crlInMemoryCacheMutex.Unlock()
		logger.Debugf("memory entries after Snowflake connection: %v", memoryEntriesAfterSnowflakeConnection)
		assertTrueE(t, memoryEntriesAfterSnowflakeConnection > 0) // additional entries for connecting to cloud providers and checking their certs
		cwd, err := os.Getwd()
		assertNilF(t, err, "Failed to get current working directory")
		_, err = db.Exec(fmt.Sprintf("PUT file://%v @~/%v", filepath.Join(cwd, "test_data", "put_get_1.txt"), "put_get_1.txt"))
		assertNilF(t, err, "Failed to execute PUT file")
		crlInMemoryCacheMutex.Lock()
		memoryEntriesAfterCSPConnection := len(crlInMemoryCache)
		crlInMemoryCacheMutex.Unlock()
		logger.Debugf("memory entries after CSP connection: %v", memoryEntriesAfterCSPConnection)
		// PUT talks to the cloud storage provider, which adds CRL entries on
		// top of those from the Snowflake connection.
		assertTrueE(t, memoryEntriesAfterCSPConnection > memoryEntriesAfterSnowflakeConnection)
		// 17s > the 15s cache validity override above, so the 1s-tick cleaner
		// must have evicted everything by now.
		time.Sleep(17 * time.Second) // wait for the cache cleaner to run
		crlInMemoryCacheMutex.Lock()
		assertEqualE(t, len(crlInMemoryCache), 0)
		crlInMemoryCacheMutex.Unlock()
	})
	t.Run("OCSP and CRL cannot be enabled at the same time", func(t *testing.T) {
		crlInMemoryCache = make(map[string]*crlInMemoryCacheValueType) // cleanup to ensure our test will fill it
		cfg := &Config{
			User:                    username,
			Password:                pass,
			Account:                 account,
			Database:                dbname,
			Schema:                  schemaname,
			CertRevocationCheckMode: CertRevocationCheckEnabled,
		}
		// DisableOCSPChecks is left false, so enabling CRL must be rejected
		// before any connection (and any cache fill) happens.
		_, err := buildSnowflakeConn(context.Background(), *cfg)
		assertStringContainsE(t, err.Error(), "both OCSP and CRL cannot be enabled at the same time")
		assertEqualE(t, len(crlInMemoryCache), 0)
	})
}

================================================ FILE: ctx_test.go ================================================

package gosnowflake

import (
	"bytes"
	"context"
	"fmt"
	"strings"
	"testing"
)

// TestCtxVal verifies that context.WithValue values are visible through the
// derived-context chain and that absent keys are reported as missing.
func TestCtxVal(t *testing.T) {
	type favContextKey string
	// f reports an error when key k is not present in ctx.
	f := func(ctx context.Context, k favContextKey) error {
		if v := ctx.Value(k); v != nil {
			return nil
		}
		return fmt.Errorf("key not found: %v", k)
	}
	k := favContextKey("language")
	ctx := context.WithValue(context.Background(), k, "Go")
	k2 := favContextKey("data")
	ctx2 := context.WithValue(ctx, k2, "Snowflake")
	if err := f(ctx, k); err != nil {
		t.Error(err)
	}
	if err := f(ctx, "color"); err == nil {
		t.Error("should not have been found in context")
	}
	// ctx2 derives from ctx, so it must expose both k and k2.
	if err := f(ctx2, k); err != nil {
		t.Error(err)
	}
	if err := f(ctx2, k2); err != nil {
		t.Error(err)
	}
}

// TestLogCtx verifies that session ID and user stored in the context are
// emitted as LOG_SESSION_ID / LOG_USER fields by the context-aware logger.
func TestLogCtx(t *testing.T) {
	log := CreateDefaultLogger()
	sessCtx := context.WithValue(context.Background(), SFSessionIDKey, "sessID1")
	ctx := context.WithValue(sessCtx, SFSessionUserKey, "admin")
	var b bytes.Buffer
	log.SetOutput(&b)
	assertNilF(t, log.SetLogLevel("trace"), "could not set log level")
	l := log.WithContext(ctx)
	l.Info("Hello 1")
	l.Warn("Hello 2")
	s := b.String()
	if len(s) <= 0 {
		t.Error("nothing written")
	}
	if !strings.Contains(s, "LOG_SESSION_ID=sessID1") {
		t.Error("context ctx1 keys/values not logged")
	}
	if !strings.Contains(s, "LOG_USER=admin") {
		t.Error("context ctx2 keys/values not logged")
	}
}

================================================ FILE: datatype.go ================================================

package gosnowflake

import (
	"bytes"
	"database/sql"
	"database/sql/driver"
	"fmt"

	"github.com/snowflakedb/gosnowflake/v2/internal/errors"
	"github.com/snowflakedb/gosnowflake/v2/internal/types"
)

// Single-byte markers used as driver.Value type hints for binding.
var (
	// DataTypeFixed is a FIXED datatype.
	DataTypeFixed = []byte{types.FixedType.Byte()}
	// DataTypeReal is a REAL datatype.
	DataTypeReal = []byte{types.RealType.Byte()}
	// DataTypeDecfloat is a DECFLOAT datatype.
	DataTypeDecfloat = []byte{types.DecfloatType.Byte()}
	// DataTypeText is a TEXT datatype.
	DataTypeText = []byte{types.TextType.Byte()}
	// DataTypeDate is a Date datatype.
	DataTypeDate = []byte{types.DateType.Byte()}
	// DataTypeVariant is a VARIANT datatype.
	DataTypeVariant = []byte{types.VariantType.Byte()}
	// DataTypeTimestampLtz is a TIMESTAMP_LTZ datatype.
	DataTypeTimestampLtz = []byte{types.TimestampLtzType.Byte()}
	// DataTypeTimestampNtz is a TIMESTAMP_NTZ datatype.
	DataTypeTimestampNtz = []byte{types.TimestampNtzType.Byte()}
	// DataTypeTimestampTz is a TIMESTAMP_TZ datatype.
	DataTypeTimestampTz = []byte{types.TimestampTzType.Byte()}
	// DataTypeObject is an OBJECT datatype.
	DataTypeObject = []byte{types.ObjectType.Byte()}
	// DataTypeArray is an ARRAY datatype.
	DataTypeArray = []byte{types.ArrayType.Byte()}
	// DataTypeBinary is a BINARY datatype.
	DataTypeBinary = []byte{types.BinaryType.Byte()}
	// DataTypeTime is a TIME datatype.
	DataTypeTime = []byte{types.TimeType.Byte()}
	// DataTypeBoolean is a BOOLEAN datatype.
	DataTypeBoolean = []byte{types.BooleanType.Byte()}
	// DataTypeNilObject represents a nil structured object.
	DataTypeNilObject = []byte{types.NilObjectType.Byte()}
	// DataTypeNilArray represents a nil structured array.
	DataTypeNilArray = []byte{types.NilArrayType.Byte()}
	// DataTypeNilMap represents a nil structured map.
	DataTypeNilMap = []byte{types.NilMapType.Byte()}
)

// dataTypeMode returns the subsequent data type in a string representation.
func dataTypeMode(v driver.Value) (tsmode types.SnowflakeType, err error) { if bd, ok := v.([]byte); ok { switch { case bytes.Equal(bd, DataTypeDecfloat): tsmode = types.DecfloatType case bytes.Equal(bd, DataTypeDate): tsmode = types.DateType case bytes.Equal(bd, DataTypeTime): tsmode = types.TimeType case bytes.Equal(bd, DataTypeTimestampLtz): tsmode = types.TimestampLtzType case bytes.Equal(bd, DataTypeTimestampNtz): tsmode = types.TimestampNtzType case bytes.Equal(bd, DataTypeTimestampTz): tsmode = types.TimestampTzType case bytes.Equal(bd, DataTypeBinary): tsmode = types.BinaryType case bytes.Equal(bd, DataTypeObject): tsmode = types.ObjectType case bytes.Equal(bd, DataTypeArray): tsmode = types.ArrayType case bytes.Equal(bd, DataTypeVariant): tsmode = types.VariantType case bytes.Equal(bd, DataTypeNilObject): tsmode = types.NilObjectType case bytes.Equal(bd, DataTypeNilArray): tsmode = types.NilArrayType case bytes.Equal(bd, DataTypeNilMap): tsmode = types.NilMapType default: return types.NullType, fmt.Errorf(errors.ErrMsgInvalidByteArray, v) } } else { return types.NullType, fmt.Errorf(errors.ErrMsgInvalidByteArray, v) } return tsmode, nil } // SnowflakeParameter includes the columns output from SHOW PARAMETER command. 
type SnowflakeParameter struct {
	Key                       string
	Value                     string
	Default                   string
	Level                     string
	Description               string
	SetByUser                 string
	SetInJob                  string
	SetOn                     string
	SetByThreadID             string
	SetByThreadName           string
	SetByClass                string
	ParameterComment          string
	Type                      string
	IsExpired                 string
	ExpiresAt                 string
	SetByControllingParameter string
	ActivateVersion           string
	PartialRollout            string
	Unknown                   string // Reserve for added parameter
}

// populateSnowflakeParameter maps a SHOW PARAMETERS result column name to the
// address of the corresponding SnowflakeParameter field, so rows.Scan can
// write directly into the struct. Unrecognized columns are logged and routed
// into the Unknown field.
func populateSnowflakeParameter(colname string, p *SnowflakeParameter) any {
	switch colname {
	case "key":
		return &p.Key
	case "value":
		return &p.Value
	case "default":
		return &p.Default
	case "level":
		return &p.Level
	case "description":
		return &p.Description
	case "set_by_user":
		return &p.SetByUser
	case "set_in_job":
		return &p.SetInJob
	case "set_on":
		return &p.SetOn
	case "set_by_thread_id":
		return &p.SetByThreadID
	case "set_by_thread_name":
		return &p.SetByThreadName
	case "set_by_class":
		return &p.SetByClass
	case "parameter_comment":
		return &p.ParameterComment
	case "type":
		return &p.Type
	case "is_expired":
		return &p.IsExpired
	case "expires_at":
		return &p.ExpiresAt
	case "set_by_controlling_parameter":
		return &p.SetByControllingParameter
	case "activate_version":
		return &p.ActivateVersion
	case "partial_rollout":
		return &p.PartialRollout
	default:
		logger.Debugf("unknown type: %v", colname)
		return &p.Unknown
	}
}

// ScanSnowflakeParameter binds SnowflakeParameter variable with an array of column buffer.
func ScanSnowflakeParameter(rows *sql.Rows) (*SnowflakeParameter, error) {
	var err error
	var columns []string
	columns, err = rows.Columns()
	if err != nil {
		return nil, err
	}
	colNum := len(columns)
	p := SnowflakeParameter{}
	cols := make([]any, colNum)
	for i := range colNum {
		cols[i] = populateSnowflakeParameter(columns[i], &p)
	}
	err = rows.Scan(cols...)
return &p, err } ================================================ FILE: datatype_test.go ================================================ package gosnowflake import ( "database/sql/driver" "fmt" "github.com/snowflakedb/gosnowflake/v2/internal/errors" "github.com/snowflakedb/gosnowflake/v2/internal/types" "testing" ) func TestDataTypeMode(t *testing.T) { var testcases = []struct { tp driver.Value tmode types.SnowflakeType err error }{ {tp: DataTypeTimestampLtz, tmode: types.TimestampLtzType, err: nil}, {tp: DataTypeTimestampNtz, tmode: types.TimestampNtzType, err: nil}, {tp: DataTypeTimestampTz, tmode: types.TimestampTzType, err: nil}, {tp: DataTypeDate, tmode: types.DateType, err: nil}, {tp: DataTypeTime, tmode: types.TimeType, err: nil}, {tp: DataTypeBinary, tmode: types.BinaryType, err: nil}, {tp: DataTypeObject, tmode: types.ObjectType, err: nil}, {tp: DataTypeArray, tmode: types.ArrayType, err: nil}, {tp: DataTypeVariant, tmode: types.VariantType, err: nil}, {tp: DataTypeFixed, tmode: types.FixedType, err: fmt.Errorf(errors.ErrMsgInvalidByteArray, DataTypeFixed)}, {tp: DataTypeReal, tmode: types.RealType, err: fmt.Errorf(errors.ErrMsgInvalidByteArray, DataTypeFixed)}, {tp: 123, tmode: types.NullType, err: fmt.Errorf(errors.ErrMsgInvalidByteArray, 123)}, } for _, ts := range testcases { t.Run(fmt.Sprintf("%v_%v", ts.tp, ts.tmode), func(t *testing.T) { tmode, err := dataTypeMode(ts.tp) if ts.err == nil { if err != nil { t.Errorf("failed to get datatype mode: %v", err) } if tmode != ts.tmode { t.Errorf("wrong data type: %v", tmode) } } else { if err == nil { t.Errorf("should raise an error: %v", ts.err) } } }) } } func TestPopulateSnowflakeParameter(t *testing.T) { columns := []string{"key", "value", "default", "level", "description", "set_by_user", "set_in_job", "set_on", "set_by_thread_id", "set_by_thread_name", "set_by_class", "parameter_comment", "type", "is_expired", "expires_at", "set_by_controlling_parameter", "activate_version", "partial_rollout"} p := 
SnowflakeParameter{} cols := make([]any, len(columns)) for i := range columns { cols[i] = populateSnowflakeParameter(columns[i], &p) } for i := range cols { if cols[i] == nil { t.Fatal("failed to populate parameter") } } } ================================================ FILE: datetime.go ================================================ package gosnowflake import ( "errors" "regexp" "strconv" "strings" "time" ) var incorrectSecondsFractionRegex = regexp.MustCompile(`[^.,]FF`) var correctSecondsFractionRegex = regexp.MustCompile(`FF(?P\d?)`) type formatReplacement struct { input string output string } var formatReplacements = []formatReplacement{ {input: "YYYY", output: "2006"}, {input: "YY", output: "06"}, {input: "MMMM", output: "January"}, {input: "MM", output: "01"}, {input: "MON", output: "Jan"}, {input: "DD", output: "02"}, {input: "DY", output: "Mon"}, {input: "HH24", output: "15"}, {input: "HH12", output: "03"}, {input: "AM", output: "PM"}, {input: "MI", output: "04"}, {input: "SS", output: "05"}, {input: "TZH", output: "Z07"}, {input: "TZM", output: "00"}, } func timeToString(t time.Time, dateTimeType string, sp *syncParams) (string, error) { sfFormat, err := dateTimeInputFormatByType(dateTimeType, sp) if err != nil { return "", err } goFormat, err := snowflakeFormatToGoFormat(sfFormat) if err != nil { return "", err } return t.Format(goFormat), nil } func snowflakeFormatToGoFormat(sfFormat string) (string, error) { res := sfFormat for _, replacement := range formatReplacements { res = strings.ReplaceAll(res, replacement.input, replacement.output) } if incorrectSecondsFractionRegex.MatchString(res) { return "", errors.New("incorrect second fraction - golang requires fraction to be preceded by comma or decimal point") } for { submatch := correctSecondsFractionRegex.FindStringSubmatch(res) if submatch == nil { break } fractionNumbers := 9 if submatch[1] != "" { var err error fractionNumbers, err = strconv.Atoi(submatch[1]) if err != nil { return "", err } } 
		res = strings.ReplaceAll(res, submatch[0], strings.Repeat("0", fractionNumbers))
	}
	return res, nil
}

// dateTimeOutputFormatByType resolves the session output format parameter for
// the given date/time type; timestamp variants fall back to
// timestamp_output_format when their specific parameter is empty or unset.
func dateTimeOutputFormatByType(dateTimeType string, sp *syncParams) (string, error) {
	var format *string
	switch strings.ToLower(dateTimeType) {
	case "date":
		format, _ = sp.get("date_output_format")
	case "time":
		format, _ = sp.get("time_output_format")
	case "timestamp_ltz":
		format, _ = sp.get("timestamp_ltz_output_format")
		if format == nil || *format == "" {
			format, _ = sp.get("timestamp_output_format")
		}
	case "timestamp_tz":
		format, _ = sp.get("timestamp_tz_output_format")
		if format == nil || *format == "" {
			format, _ = sp.get("timestamp_output_format")
		}
	case "timestamp_ntz":
		format, _ = sp.get("timestamp_ntz_output_format")
		if format == nil || *format == "" {
			format, _ = sp.get("timestamp_output_format")
		}
	}
	// Unknown types (or all lookups empty) end up here with format == nil.
	if format != nil {
		return *format, nil
	}
	return "", errors.New("not known output format parameter for " + dateTimeType)
}

// dateTimeInputFormatByType resolves the session input format parameter for
// the given date/time type, cascading: type-specific input format, generic
// input format, type-specific output format, generic output format.
func dateTimeInputFormatByType(dateTimeType string, sp *syncParams) (string, error) {
	var format *string
	var ok bool
	switch strings.ToLower(dateTimeType) {
	case "date":
		if format, ok = sp.get("date_input_format"); !ok || format == nil || *format == "" {
			format, _ = sp.get("date_output_format")
		}
	case "time":
		if format, ok = sp.get("time_input_format"); !ok || format == nil || *format == "" {
			format, _ = sp.get("time_output_format")
		}
	case "timestamp_ltz":
		if format, ok = sp.get("timestamp_ltz_input_format"); !ok || format == nil || *format == "" {
			if format, ok = sp.get("timestamp_input_format"); !ok || format == nil || *format == "" {
				if format, ok = sp.get("timestamp_ltz_output_format"); !ok || format == nil || *format == "" {
					format, _ = sp.get("timestamp_output_format")
				}
			}
		}
	case "timestamp_tz":
		if format, ok = sp.get("timestamp_tz_input_format"); !ok || format == nil || *format == "" {
			if format, ok = sp.get("timestamp_input_format"); !ok || format == nil || *format == "" {
				if format, ok = sp.get("timestamp_tz_output_format"); !ok ||
					format == nil || *format == "" {
					format, _ = sp.get("timestamp_output_format")
				}
			}
		}
	case "timestamp_ntz":
		if format, ok = sp.get("timestamp_ntz_input_format"); !ok || format == nil || *format == "" {
			if format, ok = sp.get("timestamp_input_format"); !ok || format == nil || *format == "" {
				if format, ok = sp.get("timestamp_ntz_output_format"); !ok || format == nil || *format == "" {
					format, _ = sp.get("timestamp_output_format")
				}
			}
		}
	}
	// Unknown types (or all lookups empty) end up here with format == nil.
	if format != nil {
		return *format, nil
	}
	return "", errors.New("not known input format parameter for " + dateTimeType)
}

================================================ FILE: datetime_test.go ================================================

package gosnowflake

import (
	"testing"
	"time"
)

// TestSnowflakeFormatToGoFormatUnitTest verifies both the translated Go
// layout string and the rendering of two fixed instants for a set of
// Snowflake format strings.
func TestSnowflakeFormatToGoFormatUnitTest(t *testing.T) {
	location, err := time.LoadLocation("Europe/Warsaw")
	assertNilF(t, err)
	someTime1 := time.Date(2024, time.January, 19, 3, 42, 33, 123456789, location)
	someTime2 := time.Date(1973, time.December, 5, 13, 5, 3, 987000000, location)
	testcases := []struct {
		inputFormat string
		output      string
		formatted1  string
		formatted2  string
	}{
		{
			inputFormat: "YYYY-MM-DD HH24:MI:SS.FF TZH:TZM",
			output:      "2006-01-02 15:04:05.000000000 Z07:00",
			formatted1:  "2024-01-19 03:42:33.123456789 +01:00",
			formatted2:  "1973-12-05 13:05:03.987000000 +01:00",
		},
		{
			inputFormat: "YY-MM-DD HH12:MI:SS,FF5AM TZHTZM",
			output:      "06-01-02 03:04:05,00000PM Z0700",
			formatted1:  "24-01-19 03:42:33,12345AM +0100",
			formatted2:  "73-12-05 01:05:03,98700PM +0100",
		},
		{
			inputFormat: "MMMM DD, YYYY DY HH24:MI:SS.FF9 TZH:TZM",
			output:      "January 02, 2006 Mon 15:04:05.000000000 Z07:00",
			formatted1:  "January 19, 2024 Fri 03:42:33.123456789 +01:00",
			formatted2:  "December 05, 1973 Wed 13:05:03.987000000 +01:00",
		},
		{
			inputFormat: "MON DD, YYYY HH12:MI:SS,FF9PM TZH:TZM",
			output:      "Jan 02, 2006 03:04:05,000000000PM Z07:00",
			formatted1:  "Jan 19, 2024 03:42:33,123456789AM +01:00",
			formatted2:  "Dec 05, 1973 01:05:03,987000000PM +01:00",
		},
		{
			inputFormat:
"HH24:MI:SS.FF3 HH12:MI:SS,FF9", output: "15:04:05.000 03:04:05,000000000", formatted1: "03:42:33.123 03:42:33,123456789", formatted2: "13:05:03.987 01:05:03,987000000", }, } for _, tc := range testcases { t.Run(tc.inputFormat, func(t *testing.T) { goFormat, err := snowflakeFormatToGoFormat(tc.inputFormat) assertNilF(t, err) assertEqualE(t, tc.output, goFormat) assertEqualE(t, tc.formatted1, someTime1.Format(goFormat)) assertEqualE(t, tc.formatted2, someTime2.Format(goFormat)) }) } } func TestIncorrectSecondsFraction(t *testing.T) { _, err := snowflakeFormatToGoFormat("HH24 MI SS FF") assertHasPrefixE(t, err.Error(), "incorrect second fraction") } func TestSnowflakeFormatToGoFormatIntegrationTest(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.mustExec("ALTER SESSION SET TIME_OUTPUT_FORMAT = 'HH24:MI:SS.FF'") dbt.mustExec("ALTER SESSION SET TIMESTAMP_OUTPUT_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF3 TZHTZM'") dbt.mustExec("ALTER SESSION SET TIMESTAMP_NTZ_OUTPUT_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF3'") for _, forceFormat := range []string{forceJSON, forceARROW} { dbt.mustExec(forceFormat) for _, tc := range []struct { sfType string formatParamName string sfFunction string }{ { sfType: "TIMESTAMPLTZ", formatParamName: "TIMESTAMP_OUTPUT_FORMAT", sfFunction: "CURRENT_TIMESTAMP", }, { sfType: "TIMESTAMPTZ", formatParamName: "TIMESTAMP_OUTPUT_FORMAT", sfFunction: "CURRENT_TIMESTAMP", }, { sfType: "TIMESTAMPNTZ", formatParamName: "TIMESTAMP_NTZ_OUTPUT_FORMAT", sfFunction: "CURRENT_TIMESTAMP", }, { sfType: "DATE", formatParamName: "DATE_OUTPUT_FORMAT", sfFunction: "CURRENT_DATE", }, { sfType: "TIME", formatParamName: "TIME_OUTPUT_FORMAT", sfFunction: "CURRENT_TIME", }, } { t.Run(tc.sfType+"___"+forceFormat, func(t *testing.T) { params := dbt.mustQuery("show parameters like '" + tc.formatParamName + "'") defer params.Close() params.Next() defaultTimestampOutputFormat, err := ScanSnowflakeParameter(params.rows) assertNilF(t, err) rows := dbt.mustQuery("SELECT " + tc.sfFunction + 
"()::" + tc.sfType + ", " + tc.sfFunction + "()::" + tc.sfType + "::varchar") defer rows.Close() var t1 time.Time var t2 string rows.Next() err = rows.Scan(&t1, &t2) assertNilF(t, err) goFormat, err := snowflakeFormatToGoFormat(defaultTimestampOutputFormat.Value) assertNilF(t, err) assertEqualE(t, t1.Format(goFormat), t2) parseResult, err := time.Parse(goFormat, t2) assertNilF(t, err) if tc.sfType != "TIME" { assertEqualE(t, t1.UTC(), parseResult.UTC()) } else { assertEqualE(t, t1.Hour(), parseResult.Hour()) assertEqualE(t, t1.Minute(), parseResult.Minute()) assertEqualE(t, t1.Second(), parseResult.Second()) } }) } } }) } ================================================ FILE: doc.go ================================================ /* Package gosnowflake is a pure Go Snowflake driver for the database/sql package. Clients can use the database/sql package directly. For example: import ( "database/sql" _ "github.com/snowflakedb/gosnowflake/v2" "log" ) func main() { db, err := sql.Open("snowflake", "user:password@my_organization-my_account/mydb") if err != nil { log.Fatal(err) } defer db.Close() ... } # Connection String Use the Open() function to create a database handle with connection parameters: db, err := sql.Open("snowflake", "") The Go Snowflake Driver supports the following connection syntaxes (or data source name (DSN) formats): - username[:password]@/dbname/schemaname[?param1=value&...¶mN=valueN] - username[:password]@/dbname[?param1=value&...¶mN=valueN] - username[:password]@hostname:port/dbname/schemaname?account=[¶m1=value&...¶mN=valueN] where all parameters must be escaped or use Config and DSN to construct a DSN string. For information about account identifiers, see the Snowflake documentation (https://docs.snowflake.com/en/user-guide/admin-account-identifier.html). 
The following example opens a database handle with the Snowflake account named "my_account" under the organization named "my_organization", where the username is "jsmith", password is "mypassword", database is "mydb", schema is "testschema", and warehouse is "mywh": db, err := sql.Open("snowflake", "jsmith:mypassword@my_organization-my_account/mydb/testschema?warehouse=mywh") # Connection Parameters The connection string (DSN) can contain both connection parameters (described below) and session parameters (https://docs.snowflake.com/en/sql-reference/parameters.html). The following connection parameters are supported: - account : Specifies your Snowflake account, where "" is the account identifier assigned to your account by Snowflake. For information about account identifiers, see the Snowflake documentation (https://docs.snowflake.com/en/user-guide/admin-account-identifier.html). If you are using a global URL, then append the connection group and ".global" (e.g. "-.global"). The account identifier and the connection group are separated by a dash ("-"), as shown above. This parameter is optional if your account identifier is specified after the "@" character in the connection string. - region : DEPRECATED. You may specify a region, such as "eu-central-1", with this parameter. However, since this parameter is deprecated, it is best to specify the region as part of the account parameter. For details, see the description of the account parameter. - --> Important note: for the database object and other objects (schema, role, etc), please always adhere to the rules for Snowflake Object Identifiers; especially https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers. As mentioned in the docs, if you have e.g. a database with mIxEDcAsE naming, as you needed to create it with enclosing it in double quotes, similarly you'll need to reference it also with double quotes when specifying it in the connection string / DSN. 
In practice, this means you'll need to escape the second pair of double quotes, which are part of the database name, and not the String notation. - database: Specifies the database to use by default in the client session (can be changed after login). - schema: Specifies the database schema to use by default in the client session (can be changed after login). - warehouse: Specifies the virtual warehouse to use by default for queries, loading, etc. in the client session (can be changed after login). - role: Specifies the role to use by default for accessing Snowflake objects in the client session (can be changed after login). - passcode: Specifies the passcode provided by Duo when using multi-factor authentication (MFA) for login. - passcodeInPassword: false by default. Set to true if the MFA passcode is embedded in the login password. Appends the MFA passcode to the end of the password. - loginTimeout: Specifies the timeout, in seconds, for login. The default is 60 seconds. The login request gives up after the timeout length if the HTTP response is success. - requestTimeout: Specifies the timeout, in seconds, for a query to complete. 0 (zero) specifies that the driver should wait indefinitely. The default is 0 seconds. The query request gives up after the timeout length if the HTTP response is success. - authenticator: Specifies the authenticator to use for authenticating user credentials. See "Authenticator Values" section below for supported values. - singleAuthenticationPrompt: specifies whether only one authentication should be performed at the same time for authentications that needs human interactions (like MFA or OAuth authorization code). By default it is true. - application: Identifies your application to Snowflake Support. - disableOCSPChecks: false by default. Set to true to bypass the Online Certificate Status Protocol (OCSP) certificate revocation check. OCSP module caches responses internally. 
If your application is long running, you can enable cache clearing by calling StartOCSPCacheClearer and disable by calling StopOCSPCacheClearer. IMPORTANT: Change the default value for testing or emergency situations only. - token: a token that can be used to authenticate. Should be used in conjunction with the "oauth" authenticator. - client_session_keep_alive: Set to true have a heartbeat in the background every hour by default or the value of client_session_keep_alive_heartbeat_frequency, if set, to keep the connection alive such that the connection session will never expire. Care should be taken in using this option as it opens up the access forever as long as the process is alive. - client_session_keep_alive_heartbeat_frequency: Number of seconds in-between client attempts to update the token for the session. > The default is 3600 seconds > Minimum value is 900 seconds. A smaller value will be reset to 900 seconds. > Maximum value is 3600 seconds. A larger value will be reset to 3600 seconds. > This parameter is only valid if client_session_keep_alive is set to true. - ocspFailOpen: true by default. Set to false to make OCSP check fail closed mode. - certRevocationCheckMode (enabled, advisory, disabled): Specifies the certificate revocation check mode. When enabled, the driver performs a certificate revocation check using CRL. When advisory, the driver performs a certificate revocation check using CRL, but fails the connection only if the certificate is revoked. If the status cannot be determined, the connection is established. When disabled, the driver does not perform a certificate revocation check. Keep in mind that the certificate revocation check with CRLs is a heavy task, both for memory and CPU. The default is disabled. - crlAllowCertificatesWithoutCrlURL: if a certificate does not have a CRL URL, the driver will allow the connection to be established. The default is false. 
- SNOWFLAKE_CRL_CACHE_VALIDITY_TIME (environment variable): specifies the validity time of the CRL cache in seconds. - crlInMemoryCacheDisabled: set to disable in-memory caching of CRLs. - crlOnDiskCacheDisabled: set to disable on-disk caching of CRLs (on-disk cache may help with cold starts). - crlDownloadMaxSize: maximum size (in bytes) of a CRL to download. Default is 20MB. - SNOWFLAKE_CRL_ON_DISK_CACHE_DIR (environment variable): set to customize the directory for on-disk caching of CRLs. - SNOWFLAKE_CRL_ON_DISK_CACHE_REMOVAL_DELAY (environment variable): set the delay (in seconds) for removing the on-disk cache (for debuggability). - crlHTTPClientTimeout: customize the HTTP client timeout for downloading CRLs. - validateDefaultParameters: true by default. Set to false to disable checks on existence and privileges check for Database, Schema, Warehouse and Role when setting up the connection --> Important note: with the default true value, the connection will fail as the validation fails, if you specify a non-existent database/schema/etc name. This is particularly important when you have a miXedCaSE-named object (e.g. database) and you forgot to properly double quote it. This behaviour is still preferable as it provides a very clear, fail-fast indication of the configuration error. If you would still like to forego this validation, which ensures that the driver always connects with proper database, schema etc. and creates a proper context for it, you can set this configuration to false to allow connection with invalid object identifiers. In this case (with this default validation deliberately turned off) the driver cannot guarantee that the actual behaviour inside the session will match with the one you'd expect, i.e. not actually using the database you expect, and so on. - tracing: Specifies the logging level to be used. Set to error by default. Valid values are off, fatal, error, warn, info, debug, trace. 
- logQueryText: when set to true, the full query text will be logged. Be aware that it may include sensitive information. Default value is false. - logQueryParameters: when set to true, the parameters will be logged. Requires logQueryText to be enabled first. Be aware that it may include sensitive information. Default value is false. - disableQueryContextCache: disables parsing of query context returned from server and resending it to server as well. Default value is false. - clientConfigFile: specifies the location of the client configuration json file. In this file you can configure Easy Logging feature. - disableSamlURLCheck: disables the SAML URL check. Default value is false. All other parameters are interpreted as session parameters (https://docs.snowflake.com/en/sql-reference/parameters.html). For example, the TIMESTAMP_OUTPUT_FORMAT session parameter can be set by adding: ...&TIMESTAMP_OUTPUT_FORMAT=MM-DD-YYYY... A complete connection string looks similar to the following: my_user_name:my_password@ac123456/my_database/my_schema?my_warehouse=inventory_warehouse&role=my_user_role&DATE_OUTPUT_FORMAT=YYYY-MM-DD ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ connection connection session parameter parameter parameter Session-level parameters can also be set by using the SQL command "ALTER SESSION" (https://docs.snowflake.com/en/sql-reference/sql/alter-session.html). Alternatively, use OpenWithConfig() function to create a database handle with the specified Config. # Authenticator values - To use the internal Snowflake authenticator, specify snowflake (Default). - To use programmatic access tokens, specify programmatic_access_token. - If you want to cache your MFA logins, specify username_password_mfa. You can pass TOTP in a separate passcode parameter or append it to the password setting in which case you need to set passcodeInPassword = true. - To authenticate through Okta, specify https://.okta.com (URL prefix for Okta). 
- To authenticate using your IDP via a browser, specify externalbrowser. - To authenticate via OAuth with token, specify oauth and provide an OAuth Access Token (see the token parameter below). - To authenticate via full OAuth flow, specify oauth_authorization_code or oauth_client_credentials and fill relevant parameters (oauthClientId, oauthClientSecret, oauthAuthorizationUrl, oauthTokenRequestUrl, oauthRedirectUri, oauthScope). Specify URLs if you want to use external OAuth2 IdP, otherwise Snowflake will be used as a default IdP. If oauthScope is not configured, the role is used (giving session:role: scope). For more information, please reach to official Snowflake documentation. - To authenticate via workload identity, specify workload_identity. This option requires workloadIdentityProvider option to be set (AWS, GCP, AZURE, OIDC). When workloadIdentityProvider=AZURE, workloadIdentityEntraResource can be optionally set to customize entra resource used to fetch JWT token. When workloadIdentityProvider=GCP or AWS, workloadIdentityImpersonationPath can be optionally set to customize impersonation path. This is a comma separated list. For GCP the last parameter is a target service account and the rest are chained delegation. For AWS this is the list of role ARNs to assume. For more details, refer to the usage guide: https://docs.snowflake.com/en/user-guide/workload-identity-federation # Connection Config You can also connect to your warehouse using the connection config. The database/sql package is appropriate when you want driver-specific connection features that aren’t available in a connection string. Each driver supports its own set of connection properties, often providing ways to customize the connection request specific to the DBMS. 
For example: c := &gosnowflake.Config{ ~your credentials go here~ } connector := gosnowflake.NewConnector(gosnowflake.SnowflakeDriver{}, *c) db := sql.OpenDB(connector) When Host is a full Snowflake hostname (the host string contains ".snowflakecomputing.", consistent with DSN-based URLs) and Account is left empty, the driver derives Account from the first DNS label of Host while completing configuration (for example, database/sql.Connector Connect invokes FillMissingConfigParameters). If Host does not contain that substring, you must set Account explicitly (for example private-link or custom endpoints). When Account is already non-empty, it is kept as provided. Truncating a dotted account value from DSN query parameters happens inside ParseDSN before FillMissingConfigParameters; that normalization does not apply to every programmatic Config. If you are using this method, you don't need to pass a driver name to specify the driver type in which you are looking to connect. Since the driver name is not needed, you can optionally bypass driver registration on startup. To do this, set `GOSNOWFLAKE_SKIP_REGISTRATION` in your environment. This is useful if you wish to register multiple versions of the driver. Note: `GOSNOWFLAKE_SKIP_REGISTRATION` should not be used if sql.Open() is used as the method to connect to the server, as sql.Open will require registration so it can map the driver name to the driver type, which in this case is "snowflake" and SnowflakeDriver{}. You can load the connection configuration with .toml file format. With two environment variables, `SNOWFLAKE_HOME` (`connections.toml` file directory) and `SNOWFLAKE_DEFAULT_CONNECTION_NAME` (DSN name), the driver will search the config file and load the connection. 
You can find how to use this connection way at ./cmd/tomlfileconnection or Snowflake doc: https://docs.snowflake.com/en/developer-guide/snowflake-cli-v2/connecting/specify-credentials If the connection.toml file is readable by others, a warning will be logged. To disable it you need to set the environment variable `SF_SKIP_WARNING_FOR_READ_PERMISSIONS_ON_CONFIG_FILE` to true. If you wish to specify a custom transporter (e.g. to provide a custom TLS config to be used with your custom truststore) pass it through the `NewConnector`. Example: tlsConfig := &tls.Config{ // your custom fields here } config := Config{ Transporter: &http.Transport{ TLSClientConfig: tlsConfig, }, } connector := NewConnector(SnowflakeDriver{}, config) db := sql.OpenDB(connector) As an alternative, you can use the `RegisterTLSConfig` / `DeregisterTLSConfig` functions as seen in the unit tests: https://github.com/snowflakedb/gosnowflake/blob/v1.16.0/transport_test.go#L127 # Proxy The Go Snowflake Driver honors the environment variables HTTP_PROXY, HTTPS_PROXY and NO_PROXY for the forward proxy setting. NO_PROXY specifies which hostname endings should be allowed to bypass the proxy server, e.g. no_proxy=.amazonaws.com means that Amazon S3 access does not need to go through the proxy. NO_PROXY does not support wildcards. Each value specified should be one of the following: - The end of a hostname (or a complete hostname), for example: ".amazonaws.com" or "xy12345.snowflakecomputing.com". - An IP address, for example "192.196.1.15". If more than one value is specified, values should be separated by commas, for example: no_proxy=localhost,.my_company.com,xy12345.snowflakecomputing.com,192.168.1.15,192.168.1.16 In addition to environment variables, the Go Snowflake Driver also supports configuring the proxy via connection parameters. 
When these parameters are provided in the connection string or DSN, they take precedence and any environment proxy settings (HTTP_PROXY, HTTPS_PROXY, NO_PROXY) will be ignored. | Parameter | Description | Default | |-----------------|-----------------------------------------------------------------------------|---------| | `proxyHost` | Hostname or IP address of the proxy server. | | | `proxyPort` | Port number of the proxy server. | | | `proxyUser` | Username for proxy authentication. | | | `proxyPassword` | Password for proxy authentication. | | | `proxyProtocol` | Protocol to use for proxy connection. Valid values: `http`, `https`. | `http` | | `noProxy` | Comma-separated list of hosts that should bypass the proxy. | | For more details, please refer to the example in ./cmd/proxyconnection. # Logging By default, the driver uses a built-in slog-based logger at ERROR level. The driver automatically masks secrets in all log messages to prevent credential leakage. Users can customize logging in two ways: 1. Using a custom slog.Handler (if you want to use slog with custom formatting): import ( "log/slog" "os" sf "github.com/snowflakedb/gosnowflake/v2" ) // Create your custom handler customHandler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelDebug, }) // Get the default logger and set your handler logger := sf.GetLogger() if sl, ok := logger.(sf.SFSlogLogger); ok { sl.SetHandler(customHandler) } 2. 
Using a complete custom logger implementation (if you want full control): // Implement the sf.SFLogger interface type MyCustomLogger struct { // your implementation } // Set your custom logger customLogger := &MyCustomLogger{} sf.SetLogger(customLogger) Important notes: - Secret masking is automatically applied to all loggers (both custom and default) - To change log level: logger.SetLogLevel("debug") - To redirect output: logger.SetOutput(writer) - For examples, see log_client_test.go If you want to define S3 client logging, override S3LoggingMode variable using configuration: https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws#ClientLogMode Example: import ( sf "github.com/snowflakedb/gosnowflake/v2" "github.com/aws/aws-sdk-go-v2/aws" ) ... sf.S3LoggingMode = aws.LogRequest | aws.LogResponseWithBody | aws.LogRetries # Query tag A custom query tag can be set in the context. Each query run with this context will include the custom query tag as metadata that will appear in the Query Tag column in the Query History log. For example: queryTag := "my custom query tag" ctxWithQueryTag := WithQueryTag(ctx, queryTag) rows, err := db.QueryContext(ctxWithQueryTag, query) # Query request ID A specific query request ID can be set in the context and will be passed through in place of the default randomized request ID. For example: requestID := ParseUUID("6ba7b812-9dad-11d1-80b4-00c04fd430c8") ctxWithID := WithRequestID(ctx, requestID) rows, err := db.QueryContext(ctxWithID, query) # Last query ID If you need query ID for your query you have to use raw connection. 
For queries: ``` err := conn.Raw(func(x any) error { stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECT 1") rows, err := stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) rows.(SnowflakeRows).GetQueryID() stmt.(SnowflakeStmt).GetQueryID() return nil } ``` For execs: ``` err := conn.Raw(func(x any) error { stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "INSERT INTO TestStatementQueryIdForExecs VALUES (1)") result, err := stmt.(driver.StmtExecContext).ExecContext(ctx, nil) result.(SnowflakeResult).GetQueryID() stmt.(SnowflakeStmt).GetQueryID() return nil } ``` # Fetch Results by Query ID The result of your query can be retrieved by setting the query ID in the WithFetchResultByID context. ``` // Get the query ID using raw connection as mentioned above: err := conn.Raw(func(x any) error { rows1, err = x.(driver.QueryerContext).QueryContext(ctx, "SELECT 1", nil) queryID = rows1.(sf.SnowflakeRows).GetQueryID() return nil } // Update the Context object to specify the query ID fetchResultByIDCtx = sf.WithFetchResultByID(ctx, queryID) // Execute an empty string query rows2, err := db.QueryContext(fetchResultByIDCtx, "") // Retrieve the results as usual for rows2.Next() { err = rows2.Scan(...) ... } ``` # Canceling Query by CtrlC From 0.5.0, a signal handling responsibility has moved to the applications. If you want to cancel a query/command by Ctrl+C, add a os.Interrupt trap in context to execute methods that can take the context parameter (e.g. QueryContext, ExecContext). // handle interrupt signal ctx, cancel := context.WithCancel(context.Background()) c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) defer func() { signal.Stop(c) }() go func() { select { case <-c: cancel() case <-ctx.Done(): } }() ... (connection) // execute a query rows, err := db.QueryContext(ctx, query) ... (Ctrl+C to cancel the query) See cmd/selectmany.go for the full example. 
# OpenTelemetry headers A context containing OpenTelemetry headers for distributed tracing can be created. Each query, both synchronous and asynchronous, run with this context will include the Trace ID and Span ID as metadata. If you are instrumenting your program with OpenTelemetry and exporting telemetry data to Snowflake, then queries run with this context will be properly nested under the appropriate parent span. This can be viewed in the Traces and Logs tab in Snowsight. For example: ctx, parent_span := tracer.Start(context.Background(), "parent_span") defer parent_span.End() rows, err := db.QueryContext(ctx, query) # Supported Data Types The Go Snowflake Driver now supports the Arrow data format for data transfers between Snowflake and the Golang client. The Arrow data format avoids extra conversions between binary and textual representations of the data. The Arrow data format can improve performance and reduce memory consumption in clients. Snowflake continues to support the JSON data format. The data format is controlled by the session-level parameter GO_QUERY_RESULT_FORMAT. To use JSON format, execute: ALTER SESSION SET GO_QUERY_RESULT_FORMAT = 'JSON'; The valid values for the parameter are: - ARROW (default) - JSON If the user attempts to set the parameter to an invalid value, an error is returned. The parameter name and the parameter value are case-insensitive. This parameter can be set only at the session level. Usage notes: - The Arrow data format reduces rounding errors in floating point numbers. You might see slightly different values for floating point numbers when using Arrow format than when using JSON format. In order to take advantage of the increased precision, you must pass in the context.Context object provided by the WithHigherPrecision function when querying. - Traditionally, the rows.Scan() method returned a string when a variable of types interface was passed in. 
Turning on the flag ENABLE_HIGHER_PRECISION via WithHigherPrecision will return the natural, expected data type as well. - For some numeric data types, the driver can retrieve larger values when using the Arrow format than when using the JSON format. For example, using Arrow format allows the full range of SQL NUMERIC(38,0) values to be retrieved, while using JSON format allows only values in the range supported by the Golang int64 data type. Users should ensure that Golang variables are declared using the appropriate data type for the full range of values contained in the column. For an example, see below. When using the Arrow format, the driver supports more Golang data types and more ways to convert SQL values to those Golang data types. The table below lists the supported Snowflake SQL data types and the corresponding Golang data types. The columns are: 1. The SQL data type. 2. The default Golang data type that is returned when you use snowflakeRows.Scan() to read data from Arrow data format via an interface{}. 3. The possible Golang data types that can be returned when you use snowflakeRows.Scan() to read data from Arrow data format directly. 4. The default Golang data type that is returned when you use snowflakeRows.Scan() to read data from JSON data format via an interface{}. (All returned values are strings.) 5. The standard Golang data type that is returned when you use snowflakeRows.Scan() to read data from JSON data format directly. 
Go Data Types for Scan() =================================================================================================================== | ARROW | JSON =================================================================================================================== SQL Data Type | Default Go Data Type | Supported Go Data | Default Go Data Type | Supported Go Data | for Scan() interface{} | Types for Scan() | for Scan() interface{} | Types for Scan() =================================================================================================================== BOOLEAN | bool | string | bool ------------------------------------------------------------------------------------------------------------------- VARCHAR | string | string ------------------------------------------------------------------------------------------------------------------- DOUBLE | float32, float64 [1] , [2] | string | float32, float64 ------------------------------------------------------------------------------------------------------------------- INTEGER that | int, int8, int16, int32, int64 | string | int, int8, int16, fits in int64 | [1] , [2] | | int32, int64 ------------------------------------------------------------------------------------------------------------------- INTEGER that doesn't | int, int8, int16, int32, int64, *big.Int | string | error fit in int64 | [1] , [2] , [3] , [4] | ------------------------------------------------------------------------------------------------------------------- NUMBER(P, S) | float32, float64, *big.Float | string | float32, float64 where S > 0 | [1] , [2] , [3] , [5] | ------------------------------------------------------------------------------------------------------------------- DATE | time.Time | string | time.Time ------------------------------------------------------------------------------------------------------------------- TIME | time.Time | string | time.Time 
------------------------------------------------------------------------------------------------------------------- TIMESTAMP_LTZ | time.Time | string | time.Time ------------------------------------------------------------------------------------------------------------------- TIMESTAMP_NTZ | time.Time | string | time.Time ------------------------------------------------------------------------------------------------------------------- TIMESTAMP_TZ | time.Time | string | time.Time ------------------------------------------------------------------------------------------------------------------- BINARY | []byte | string | []byte ------------------------------------------------------------------------------------------------------------------- ARRAY [6] | string / array | string / array ------------------------------------------------------------------------------------------------------------------- OBJECT [6] | string / struct | string / struct ------------------------------------------------------------------------------------------------------------------- VARIANT | string | string ------------------------------------------------------------------------------------------------------------------- MAP | map | map [1] Converting from a higher precision data type to a lower precision data type via the snowflakeRows.Scan() method can lose low bits (lose precision), lose high bits (completely change the value), or result in error. [2] Attempting to convert from a higher precision data type to a lower precision data type via interface{} causes an error. [3] Higher precision data types like *big.Int and *big.Float can be accessed by querying with a context returned by WithHigherPrecision(). [4] You cannot directly Scan() into the alternative data types via snowflakeRows.Scan(), but can convert to those data types by using .Int64()/.String()/.Uint64() methods. For an example, see below. 
[5] You cannot directly Scan() into the alternative data types via snowflakeRows.Scan(), but can convert to those data types by using .Float32()/.String()/.Float64() methods. For an example, see below. [6] Arrays and objects can be either semistructured or structured, see more info in section below. Note: SQL NULL values are converted to Golang nil values, and vice-versa. # Semistructured and structured types Snowflake supports two flavours of "structured data" - semistructured and structured. Semistructured types are variants, objects and arrays without schema. When data is fetched, it's represented as strings and the client is responsible for its interpretation. Example table definition: CREATE TABLE semistructured (v VARIANT, o OBJECT, a ARRAY) The data not have any corresponding schema, so values in table may be slightly different. Semistuctured variants, objects and arrays are always represented as strings for scanning: rows, err := db.Query("SELECT {'a': 'b'}::OBJECT") // handle error defer rows.Close() rows.Next() var v string err := rows.Scan(&v) When inserting, a marker indicating correct type must be used, for example: db.Exec("CREATE TABLE test_object_binding (obj OBJECT)") db.Exec("INSERT INTO test_object_binding SELECT (?)", DataTypeObject, "{'s': 'some string'}") Structured types differentiate from semistructured types by having specific schema. In all rows of the table, values must conform to this schema. Example table definition: CREATE TABLE structured (o OBJECT(s VARCHAR, i INTEGER), a ARRAY(INTEGER), m MAP(VARCHAR, BOOLEAN)) To retrieve structured objects, follow these steps: 1. 
Create a struct implementing sql.Scanner interface, example: a) type simpleObject struct { s string i int32 } func (so *simpleObject) Scan(val any) error { st := val.(StructuredObject) var err error if so.s, err = st.GetString("s"); err != nil { return err } if so.i, err = st.GetInt32("i"); err != nil { return err } return nil } b) type simpleObject struct { S string `sf:"otherName"` I int32 `sf:"i,ignore"` } func (so *simpleObject) Scan(val any) error { st := val.(StructuredObject) return st.ScanTo(so) } Automatic scan goes through all fields in a struct and read object fields. Struct fields have to be public. Embedded structs have to be pointers. Matching name is built using struct field name with first letter lowercase. Additionally, `sf` tag can be added: - first value is always a name of a field in an SQL object - additionally `ignore` parameter can be passed to omit this field 2. Use WithStructuredTypesEnabled context while querying data. 3. Use it in regular scan: var res simpleObject err := rows.Scan(&res) See StructuredObject for all available operations including null support, embedding nested structs, etc. Retrieving array of simple types works exactly the same like normal values - using Scan function. You can use WithEmbeddedValuesNullable context to handle null values in maps and arrays of simple types in the database. In that case, sql null types will be used: ctx := WithEmbeddedValuesNullable(WithStructuredTypesEnabled(context.Background())) ... var res []sql.NullBool err := rows.Scan(&res) If you want to scan array of structs, you have to use a helper function ScanArrayOfScanners: var res []*simpleObject err := rows.Scan(ScanArrayOfScanners(&res)) Retrieving structured maps is very similar to retrieving arrays: var res map[string]*simpleObject err := rows.Scan(ScanMapOfScanners(&res)) To bind structured objects use: 1. 
Create a type which implements a StructuredObjectWriter interface, example: a) type simpleObject struct { s string i int32 } func (so *simpleObject) Write(sowc StructuredObjectWriterContext) error { if err := sowc.WriteString("s", so.s); err != nil { return err } if err := sowc.WriteInt32("i", so.i); err != nil { return err } return nil } b) type simpleObject struct { S string `sf:"otherName"` I int32 `sf:"i,ignore"` } func (so *simpleObject) Write(sowc StructuredObjectWriterContext) error { return sowc.WriteAll(so) } 2. Use an instance as regular bind. 3. If you need to bind nil value, use special syntax: db.Exec('INSERT INTO some_table VALUES ?', sf.DataTypeNilObject, reflect.TypeOf(simpleObject{})) Binding structured arrays is like any other parameter. The only difference is - if you want to insert an empty array (not nil but empty), you have to use: db.Exec('INSERT INTO some_table VALUES ?', sf.DataTypeEmptyArray, reflect.TypeOf(simpleObject{})) # Using higher precision numbers The following example shows how to retrieve very large values using the math/big package. This example retrieves a large INTEGER value to an interface and then extracts a big.Int value from that interface. If the value fits into an int64, then the code also copies the value to a variable of type int64. Note that a context that enables higher precision must be passed in with the query. import "context" import "math/big" ... var my_interface interface{} var my_big_int_pointer *big.Int var my_int64 int64 var rows snowflakeRows ... 
rows = db.QueryContext(WithHigherPrecision(context.Background()), query) rows.Scan(&my_interface) my_big_int_pointer, ok = my_interface.(*big.Int) if my_big_int_pointer.IsInt64() { my_int64 = my_big_int_pointer.Int64() } If the variable named "rows" is known to contain a big.Int, then you can use the following instead of scanning into an interface and then converting to a big.Int: rows.Scan(&my_big_int_pointer) If the variable named "rows" contains a big.Int, then each of the following fails: rows.Scan(&my_int64) my_int64, _ = my_interface.(int64) Similar code and rules also apply to big.Float values. If you are not sure what data type will be returned, you can use code similar to the following to check the data type of the returned value: // Create variables into which you can scan the returned values. var i64 int64 var bigIntPtr *big.Int for rows.Next() { // Get the data type info. column_types, err := rows.ColumnTypes() if err != nil { log.Fatalf("ERROR: ColumnTypes() failed. err: %v", err) } // The data type of the zeroeth column in the row. column_type := column_types[0].ScanType() // Choose the appropriate variable based on the data type. switch column_type { case reflect.TypeOf(i64): err = rows.Scan(&i64) fmt.Println("INFO: retrieved int64 value:") fmt.Println(i64) case reflect.TypeOf(bigIntPtr): err = rows.Scan(&bigIntPtr) fmt.Println("INFO: retrieved bigIntPtr value:") fmt.Println(bigIntPtr) } } # Using decfloats By default, DECFLOAT values are returned as string values. If you want to retrieve them as numbers, you have to use the WithDecfloatMappingEnabled context. If higher precision is enabled, the driver will return them as *big.Float values. Otherwise, they will be returned as float64 values. Keep in mind that both float64 and *big.Float are not able to precisely represent some DECFLOAT values. If precision is important, you have to use string representation and use your own library to parse it. 
# Arrow batches You can retrieve data in a columnar format similar to the format a server returns, without transposing them to rows. Arrow Batches mode is available through the separate `arrowbatches` sub-package (`github.com/snowflakedb/gosnowflake/v2/arrowbatches`). This sub-package provides access to Arrow columnar data using ArrowBatch structs, which correspond to data chunks received from the backend. They allow for access to specific arrow.Record structs. The arrow-compute dependency (which significantly increases binary size) is only pulled in when you import the arrowbatches sub-package. If you don't need Arrow batch support, simply don't import it. An ArrowBatch can exist in a state where the underlying data has not yet been loaded. The data is downloaded and translated only on demand. Translation options are retrieved from a context.Context interface, which is either passed from query context or set by the user using WithContext(ctx) method. In order to access them you must use `arrowbatches.WithArrowBatches` context, similar to the following: var rows driver.Rows err = conn.Raw(func(x interface{}) error { rows, err = x.(driver.QueryerContext).QueryContext(ctx, query, nil) return err }) ... batches, err := arrowbatches.GetArrowBatches(rows.(sf.SnowflakeRows)) ... // use Arrow records This returns []*arrowbatches.ArrowBatch. ArrowBatch functions: GetRowCount(): Returns the number of rows in the ArrowBatch. Note that this returns 0 if the data has not yet been loaded, irrespective of it’s actual size. WithContext(ctx context.Context): Sets the context of the ArrowBatch to the one provided. Note that the context will not retroactively apply to data that has already been downloaded. For example: records1, _ := batch.Fetch() records2, _ := batch.WithContext(ctx).Fetch() will produce the same result in records1 and records2, irrespective of the newly provided ctx. 
Contexts worth noting are: -arrowbatches.WithTimestampOption -WithHigherPrecision -arrowbatches.WithUtf8Validation described in more detail later. Fetch(): Returns the underlying records as *[]arrow.Record. When this function is called, the ArrowBatch checks whether the underlying data has already been loaded, and downloads it if not. Limitations: 1. For some queries Snowflake may decide to return data in JSON format (examples: `SHOW PARAMETERS` or `ls @stage`). You cannot use JSON with Arrow batches context. See alternative below. 2. Snowflake handles timestamps in a range which is broader than available space in Arrow timestamp type. Because of that special treatment should be used (see below). 3. When using numbers, Snowflake chooses the smallest type that covers all values in a batch. So even when your column is NUMBER(38, 0), if all values are 8bits, array.Int8 is used. How to handle timestamps in Arrow batches: Snowflake returns timestamps natively (from backend to driver) in multiple formats. The Arrow timestamp is an 8-byte data type, which is insufficient to handle the larger date and time ranges used by Snowflake. Also, Snowflake supports 0-9 (nanosecond) digit precision for seconds, while Arrow supports only 3 (millisecond), 6 (microsecond), and 9 (nanosecond) precision. Consequently, Snowflake uses a custom timestamp format in Arrow, which differs on timestamp type and precision. If you want to use timestamps in Arrow batches, you have two options: 1. The Go driver can reduce timestamp struct into simple Arrow Timestamp, if you set `arrowbatches.WithTimestampOption` to nanosecond, microsecond, millisecond or second. For nanosecond, some timestamp values might not fit into Arrow timestamp. E.g. after year 2262 or before 1677. 2. You can use native Snowflake values. In that case you will receive complex structs as described above. To transform Snowflake values into the Golang time.Time struct you can use `ArrowSnowflakeTimestampToTime`. 
To enable this feature, you must use `arrowbatches.WithTimestampOption` context with value set to `UseOriginalTimestamp`. How to handle invalid UTF-8 characters in Arrow batches: Snowflake previously allowed users to upload data with invalid UTF-8 characters. Consequently, Arrow records containing string columns in Snowflake could include these invalid UTF-8 characters. However, according to the Arrow specifications (https://arrow.apache.org/docs/cpp/api/datatype.html and https://github.com/apache/arrow/blob/a03d957b5b8d0425f9d5b6c98b6ee1efa56a1248/go/arrow/datatype.go#L73-L74), Arrow string columns should only contain UTF-8 characters. To address this issue and prevent potential downstream disruptions, the context arrowbatches.WithUtf8Validation is introduced. When enabled, this feature iterates through all values in string columns, identifying and replacing any invalid characters with `�`. This ensures that Arrow records conform to the UTF-8 standards, preventing validation failures in downstream services like the Rust Arrow library that impose strict validation checks. How to handle higher precision in Arrow batches: To preserve BigDecimal values within Arrow batches, use WithHigherPrecision. This offers two main benefits: it helps avoid precision loss and defers the conversion to upstream services. Alternatively, without this setting, all non-zero scale numbers will be converted to float64, potentially resulting in loss of precision. Zero-scale numbers (DECIMAL256, DECIMAL128) will be converted to int64, which could lead to overflow. When using NUMBERs with non-zero scale, the value is returned as an integer type and a scale is provided in record metadata. Example: When we have a 123.45 value that comes from NUMBER(9, 4), it will be represented as 1234500 with scale equal to 4. It is the client's responsibility to interpret it correctly. Also - see limitations section above. 
How to handle JSON responses in Arrow batches: Due to technical limitations Snowflake backend may return JSON even if client expects Arrow. In that case Arrow batches are not available and the error with code ErrNonArrowResponseInArrowBatches is returned. The response is parsed to regular rows. You can read rows in a way described in transform_batches_to_rows.go example. This has a very strong limitation though - this is a very low level API (Go driver API), so there are no conversions ready. All values are returned as strings. Alternative approach is to rerun a query, but without enabling Arrow batches and use a general Go SQL API instead of driver API. It can be optimized by using `WithRequestID`, so backend returns results from cache. # Binding Parameters Binding allows a SQL statement to use a value that is stored in a Golang variable. Without binding, a SQL statement specifies values by specifying literals inside the statement. For example, the following statement uses the literal value “42“ in an UPDATE statement: _, err = db.Exec("UPDATE table1 SET integer_column = 42 WHERE ID = 1000") With binding, you can execute a SQL statement that uses a value that is inside a variable. For example: var my_integer_variable int = 42 _, err = db.Exec("UPDATE table1 SET integer_column = ? WHERE ID = 1000", my_integer_variable) The “?“ inside the “VALUES“ clause specifies that the SQL statement uses the value from a variable. Binding data that involves time zones can require special handling. For details, see the section titled "Timestamps with Time Zones". 
Version 1.6.23 (and later) of the driver takes advantage of sql.Null types which enables the proper handling of null parameters inside function calls, i.e.: rows, err := db.Query("SELECT * FROM TABLE(SOMEFUNCTION(?))", sql.NullBool{}) The timestamp nullability had to be achieved by wrapping the sql.NullTime type as the Snowflake provides several date and time types which are mapped to single Go time.Time type: rows, err := db.Query("SELECT * FROM TABLE(SOMEFUNCTION(?))", sf.TypedNullTime{sql.NullTime{}, sf.TimestampLTZType}) # Binding Parameters to Array Variables Version 1.3.9 (and later) of the Go Snowflake Driver supports the ability to bind an array variable to a parameter in a SQL INSERT statement. You can use this technique to insert multiple rows in a single batch. As an example, the following code inserts rows into a table that contains integer, float, boolean, and string columns. The example binds arrays to the parameters in the INSERT statement. // Create a table containing an integer, float, boolean, and string column. _, err = db.Exec("create or replace table my_table(c1 int, c2 float, c3 boolean, c4 string)") ... // Define the arrays containing the data to insert. intArray := []int{1, 2, 3} fltArray := []float64{0.1, 2.34, 5.678} boolArray := []bool{true, false, true} strArray := []string{"test1", "test2", "test3"} ... // Insert the data from the arrays and wrap in an Array() function into the table. intArr, err := Array(&intArray) fltArr, err := Array(&fltArray) boolArr, err := Array(&boolArray) strArr, err := Array(&strArray) _, err = db.Exec("insert into my_table values (?, ?, ?, ?)", intArr, fltArr, boolArr, strArr) If the array contains SQL NULL values, use slice []interface{}, which allows Golang nil values. This feature is available in version 1.6.12 (and later) of the driver. For example, // Define the arrays containing the data to insert. 
strArray := make([]interface{}, 3) strArray[0] = "test1" strArray[1] = "test2" strArray[2] = nil // This line is optional as nil is the default value. ... // Create a table and insert the data from the array as shown above. strArr, err := Array(&strArray) _, err = db.Exec("create or replace table my_table(c1 string)") _, err = db.Exec("insert into my_table values (?)", strArr) ... // Use sql.NullString to fetch the string column that contains NULL values. var s sql.NullString rows, _ := db.Query("select * from my_table") for rows.Next() { err := rows.Scan(&s) if err != nil { log.Fatalf("Failed to scan. err: %v", err) } if s.Valid { fmt.Println("Retrieved value:", s.String) } else { fmt.Println("Retrieved value: NULL") } } For slices []interface{} containing time.Time values, a binding parameter flag is required for the preceding array variable in the Array() function. This feature is available in version 1.6.13 (and later) of the driver. For example, ntzArr, err := Array(&ntzArray, sf.TimestampNTZType) ltzArr, err := Array(&ltzArray, sf.TimestampLTZType) _, err = db.Exec("create or replace table my_table(c1 timestamp_ntz, c2 timestamp_ltz)") _, err = db.Exec("insert into my_table values (?,?)", ntzArr, ltzArr) Note: For alternative ways to load data into the Snowflake database (including bulk loading using the COPY command), see Loading Data into Snowflake (https://docs.snowflake.com/en/user-guide-data-load.html). # Batch Inserts and Binding Parameters When you use array binding to insert a large number of values, the driver can improve performance by streaming the data (without creating files on the local machine) to a temporary stage for ingestion. The driver automatically does this when the number of values exceeds a threshold (no changes are needed to user code). 
In order for the driver to send the data to a temporary stage, the user must have the following privilege on the schema: CREATE STAGE If the user does not have this privilege, the driver falls back to sending the data with the query to the Snowflake database. In addition, the current database and schema for the session must be set. If these are not set, the CREATE TEMPORARY STAGE command executed by the driver can fail with the following error: CREATE TEMPORARY STAGE SYSTEM$BIND file_format=(type=csv field_optionally_enclosed_by='"') Cannot perform CREATE STAGE. This session does not have a current schema. Call 'USE SCHEMA', or use a qualified name. For alternative ways to load data into the Snowflake database (including bulk loading using the COPY command), see Loading Data into Snowflake (https://docs.snowflake.com/en/user-guide-data-load.html). # Binding a Parameter to a Time Type Go's database/sql package supports the ability to bind a parameter in a SQL statement to a time.Time variable. However, when the client binds data to send to the server, the driver cannot determine the correct Snowflake date/timestamp data type to associate with the binding parameter. For example: dbt.mustExec("CREATE OR REPLACE TABLE tztest (id int, ntz, timestamp_ntz, ltz timestamp_ltz)") // ... stmt, err :=dbt.db.Prepare("INSERT INTO tztest(id,ntz,ltz) VALUES(1, ?, ?)") // ... tmValue time.Now() // ... Is tmValue a TIMESTAMP_NTZ or TIMESTAMP_LTZ? _, err = stmt.Exec(tmValue, tmValue) To resolve this issue, a binding parameter flag is introduced that associates any subsequent time.Time type to the DATE, TIME, TIMESTAMP_LTZ, TIMESTAMP_NTZ or BINARY data type. The above example could be rewritten as follows: import ( sf "github.com/snowflakedb/gosnowflake/v2" ) dbt.mustExec("CREATE OR REPLACE TABLE tztest (id int, ntz, timestamp_ntz, ltz timestamp_ltz)") // ... stmt, err :=dbt.db.Prepare("INSERT INTO tztest(id,ntz,ltz) VALUES(1, ?, ?)") // ... tmValue time.Now() // ... 
_, err = stmt.Exec(sf.DataTypeTimestampNtz, tmValue, sf.DataTypeTimestampLtz, tmValue) # Timestamps with Time Zones The driver fetches TIMESTAMP_TZ (timestamp with time zone) data using the offset-based Location types, which represent a collection of time offsets in use in a geographical area, such as CET (Central European Time) or UTC (Coordinated Universal Time). The offset-based Location data is generated and cached when a Go Snowflake Driver application starts, and if the given offset is not in the cache, it is generated dynamically. Currently, Snowflake does not support the name-based Location types (e.g. "America/Los_Angeles"). For more information about Location types, see the Go documentation for https://golang.org/pkg/time/#Location. # Binary Data Internally, this feature leverages the []byte data type. As a result, BINARY data cannot be bound without the binding parameter flag. In the following example, sf is an alias for the gosnowflake package: var b = []byte{0x01, 0x02, 0x03} _, err = stmt.Exec(sf.DataTypeBinary, b) # JWT authentication The Go Snowflake Driver supports JWT (JSON Web Token) authentication. To enable this feature, construct the DSN with fields "authenticator=SNOWFLAKE_JWT&privateKey=", or using a Config structure specifying: config := &Config{ ... Authenticator: AuthTypeJwt, PrivateKey: "", } The should be a base64 URL encoded PKCS8 rsa private key string. One way to encode a byte slice to URL base 64 URL format is through the base64.URLEncoding.EncodeToString() function. On the server side, you can alter the public key with the SQL command: ALTER USER SET RSA_PUBLIC_KEY=''; The should be a base64 Standard encoded PKI public key string. One way to encode a byte slice to base 64 Standard format is through the base64.StdEncoding.EncodeToString() function. 
To generate the valid key pair, you can execute the following commands in the shell: # generate 2048-bit pkcs8 encoded RSA private key openssl genpkey -algorithm RSA \ -pkeyopt rsa_keygen_bits:2048 \ -pkeyopt rsa_keygen_pubexp:65537 | \ openssl pkcs8 -topk8 -outform der > rsa-2048-private-key.p8 # extract 2048-bit PKI encoded RSA public key from the private key openssl pkey -pubout -inform der -outform der \ -in rsa-2048-private-key.p8 \ -out rsa-2048-public-key.spki Note: As of February 2020, Golang's official library does not support passcode-encrypted PKCS8 private key. For security purposes, Snowflake highly recommends that you store the passcode-encrypted private key on the disk and decrypt the key in your application using a library you trust. JWT tokens are recreated on each retry and they are valid (`exp` claim) for `jwtTimeout` seconds. Each retry timeout is configured by `jwtClientTimeout`. Retries are limited by total time of `loginTimeout`. # External browser authentication The driver allows to authenticate using the external browser. When a connection is created, the driver will open the browser window and ask the user to sign in. To enable this feature, construct the DSN with field "authenticator=EXTERNALBROWSER" or using a Config structure with following Authenticator specified: config := &Config{ ... Authenticator: AuthTypeExternalBrowser, } The external browser authentication implements timeout mechanism. This prevents the driver from hanging interminably when browser window was closed, or not responding. Timeout defaults to 120s and can be changed through setting DSN field "externalBrowserTimeout=240" (time in seconds) or using a Config structure with following ExternalBrowserTimeout specified: config := &Config{ ExternalBrowserTimeout: 240 * time.Second, // Requires time.Duration } # Executing Multiple Statements in One Call This feature is available in version 1.3.8 or later of the driver. 
By default, Snowflake returns an error for queries issued with multiple statements. This restriction helps protect against SQL Injection attacks (https://en.wikipedia.org/wiki/SQL_injection). The multi-statement feature allows users skip this restriction and execute multiple SQL statements through a single Golang function call. However, this opens up the possibility for SQL injection, so it should be used carefully. The risk can be reduced by specifying the exact number of statements to be executed, which makes it more difficult to inject a statement by appending it. More details are below. The Go Snowflake Driver provides two functions that can execute multiple SQL statements in a single call: - db.QueryContext(): This function is used to execute queries, such as SELECT statements, that return a result set. - db.ExecContext(): This function is used to execute statements that don't return a result set (i.e. most DML and DDL statements). To compose a multi-statement query, simply create a string that contains all the queries, separated by semicolons, in the order in which the statements should be executed. To protect against SQL Injection attacks while using the multi-statement feature, pass a Context that specifies the number of statements in the string. For example: import ( "context" "database/sql" ) var multiStatementQuery = "SELECT c1 FROM t1; SELECT c2 FROM t2" var number_of_statements = 2 ctx := WithMultiStatement(context.Background(), number_of_statements) rows, err := db.QueryContext(ctx, multiStatementQuery) When multiple queries are executed by a single call to QueryContext(), multiple result sets are returned. After you process the first result set, get the next result set (for the next SQL statement) by calling NextResultSet(). 
The following pseudo-code shows how to process multiple result sets: Execute the statement and get the result set(s): rows, err := db.QueryContext(ctx, multiStmtQuery) Retrieve the rows in the first query's result set: while rows.Next() { err = rows.Scan(&variable_1) if err != nil { t.Errorf("failed to scan: %#v", err) } ... } Retrieve the remaining result sets and the rows in them: while rows.NextResultSet() { while rows.Next() { ... } } The function db.ExecContext() returns a single result, which is the sum of the number of rows changed by each individual statement. For example, if your multi-statement query executed two UPDATE statements, each of which updated 10 rows, then the result returned would be 20. Individual row counts for individual statements are not available. The following code shows how to retrieve the result of a multi-statement query executed through db.ExecContext(): Execute the SQL statements: res, err := db.ExecContext(ctx, multiStmtQuery) Get the summed result and store it in the variable named count: count, err := res.RowsAffected() Note: Because a multi-statement ExecContext() returns a single value, you cannot detect offsetting errors. For example, suppose you expected the return value to be 20 because you expected each UPDATE statement to update 10 rows. If one UPDATE statement updated 15 rows and the other UPDATE statement updated only 5 rows, the total would still be 20. You would see no indication that the UPDATES had not functioned as expected. The ExecContext() function does not return an error if passed a query (e.g. a SELECT statement). However, it still returns only a single value, not a result set, so using it to execute queries (or a mix of queries and non-query statements) is impractical. The QueryContext() function does not return an error if passed non-query statements (e.g. DML). The function returns a result set for each statement, whether or not the statement is a query. 
For each non-query statement, the result set contains a single row that contains a single column; the value is the number of rows changed by the statement. If you want to execute a mix of query and non-query statements (e.g. a mix of SELECT and DML statements) in a multi-statement query, use QueryContext(). You can retrieve the result sets for the queries, and you can retrieve or ignore the row counts for the non-query statements. Note: PUT statements are not supported for multi-statement queries. If a SQL statement passed to ExecContext() or QueryContext() fails to compile or execute, that statement is aborted, and subsequent statements are not executed. Any statements prior to the aborted statement are unaffected. For example, if the statements below are run as one multi-statement query, the multi-statement query fails on the third statement, and an exception is thrown. CREATE OR REPLACE TABLE test(n int); INSERT INTO TEST VALUES (1), (2); INSERT INTO TEST VALUES ('not_an_integer'); -- execution fails here INSERT INTO TEST VALUES (3); If you then query the contents of the table named "test", the values 1 and 2 would be present. When using the QueryContext() and ExecContext() functions, Golang code can check for errors the usual way. For example: rows, err := db.QueryContext(ctx, multiStmtQuery) if err != nil { Fatalf("failed to query multiple statements: %v", err) } Preparing statements and using bind variables are also not supported for multi-statement queries. # Asynchronous Queries The Go Snowflake Driver supports asynchronous execution of SQL statements. Asynchronous execution allows you to start executing a statement and then retrieve the result later without being blocked while waiting. While waiting for the result of a SQL statement, you can perform other tasks, including executing other SQL statements. Most of the steps to execute an asynchronous query are the same as the steps to execute a synchronous query. 
However, there is an additional step, which is that you must call the WithAsyncMode() function to update your Context object to specify that asynchronous mode is enabled. In the code below, the call to "WithAsyncMode()" is specific to asynchronous mode. The rest of the code is compatible with both asynchronous mode and synchronous mode. ... // Update your Context object to specify asynchronous mode: ctx := WithAsyncMode(context.Background()) // Execute your query as usual by calling: rows, _ := db.QueryContext(ctx, query_string) // Retrieve the results as usual by calling: for rows.Next() { err := rows.Scan(...) ... } The function db.QueryContext() returns an object of type snowflakeRows regardless of whether the query is synchronous or asynchronous. However: - If the query is synchronous, then db.QueryContext() does not return until the query has finished and the result set has been loaded into the snowflakeRows object. - If the query is asynchronous, then db.QueryContext() returns a potentially incomplete snowflakeRows object that is filled in later in the background. The call to the Next() function of snowflakeRows is always synchronous (i.e. blocking). If the query has not yet completed and the snowflakeRows object (named "rows" in this example) has not been filled in yet, then rows.Next() waits until the result set has been filled in. More generally, calls to any Golang SQL API function implemented in snowflakeRows or snowflakeResult are blocking calls, and wait if results are not yet available. (Examples of other synchronous calls include: snowflakeRows.Err(), snowflakeRows.Columns(), snowflakeRows.columnTypes(), snowflakeRows.Scan(), and snowflakeResult.RowsAffected().) Because the example code above executes only one query and no other activity, there is no significant difference in behavior between asynchronous and synchronous behavior. 
The differences become significant if, for example, you want to perform some other activity after the query starts and before it completes. The example code below starts a query, which run in the background, and then retrieves the results later. This example uses small SELECT statements that do not retrieve enough data to require asynchronous handling. However, the technique works for larger data sets, and for situations where the programmer might want to do other work after starting the queries and before retrieving the results. For a more elaborative example please see cmd/async/async.go package gosnowflake import ( "context" "database/sql" "database/sql/driver" "fmt" "log" "os" sf "github.com/snowflakedb/gosnowflake/v2" ) ... func DemonstrateAsyncMode(db *sql.DB) { // Enable asynchronous mode ctx := sf.WithAsyncMode(context.Background()) // Run the query with asynchronous context rows, err := db.QueryContext(ctx, "select 1") if err != nil { // handle error } // do something as the workflow continues whereas the query is computing in the background ... // Get the data when you are ready to handle it var val int err = rows.Scan(&val) if err != nil { // handle error } ... } ==> Some considerations related to the ServerSessionKeepAlive configuration option in context of asynchronous query execution When SQL Go connection is being closed, it performs the following actions: * stops the scheduled heartbeats (CLIENT_SESSION_KEEP_ALIVE), if it was enabled * cleans up all the http connections which are already idle - doesn't touch the ones which are in active use currently * if Config.ServerSessionKeepAlive is false (default), then actively logs out the current Snowflake session. !! Caveat: If there are any queries which are currently executing in the same Snowflake session (e.g. 
async queries sent with WithAsyncMode()), then those queries are automatically cancelled from the client side a couple minutes later after the Close() call, as a Snowflake session which has been actively logged out from, cannot sustain any queries. You can govern this behaviour with setting Config.ServerSessionKeepAlive to true; when the corresponding Snowflake session will be kept alive for a long time (determined by the Snowflake engine) even after an explicit Connection.Close() call past the time when the last running query in the session finished executing. The behaviour is also dependent on ABORT_DETACHED_QUERY parameter, please see the detailed explanation in the parameter description at https://docs.snowflake.com/en/sql-reference/parameters#abort-detached-query. As a consequence, best practice would be to isolate all long-running async tasks (especially ones supposed to be continued after the connection is closed) into a separate connection. # Support For PUT and GET The Go Snowflake Driver supports the PUT and GET commands. The PUT command copies a file from a local computer (the computer where the Golang client is running) to a stage on the cloud platform. The GET command copies data files from a stage on the cloud platform to a local computer. See the following for information on the syntax and supported parameters: - PUT: https://docs.snowflake.com/en/sql-reference/sql/put.html - GET: https://docs.snowflake.com/en/sql-reference/sql/get.html Using PUT: The following example shows how to run a PUT command by passing a string to the db.Query() function: db.Query("PUT file:// ") "" should include the file path as well as the name. Snowflake recommends using an absolute path rather than a relative path. For example: db.Query("PUT file:///tmp/my_data_file @~ auto_compress=false overwrite=false") Different client platforms (e.g. linux, Windows) have different path name conventions. Ensure that you specify path names appropriately. 
This is particularly important on Windows, which uses the backslash character as both an escape character and as a separator in path names. To send information from a stream (rather than a file) use code similar to the code below. (The ReplaceAll() function is needed on Windows to handle backslashes in the path to the file.) fileStream, _ := os.Open(fname) defer func() { if fileStream != nil { fileStream.Close() } } () sql := "put 'file://%v' @%%%v auto_compress=true parallel=30" sqlText := fmt.Sprintf(sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName) dbt.mustExecContext(WithFilePutStream(context.Background(), fileStream), sqlText) Note: PUT statements are not supported for multi-statement queries. Using GET: The following example shows how to run a GET command by passing a string to the db.Query() function: db.Query("GET file:// ") "" should include the file path as well as the name. Snowflake recommends using an absolute path rather than a relative path. For example: db.Query("GET @~ file:///tmp/my_data_file auto_compress=false overwrite=false") To download a file into an in-memory stream (rather than a file) use code similar to the code below. var streamBuf bytes.Buffer ctx := WithFileGetStream(context.Background(), &streamBuf) sql := "get @~/data1.txt.gz file:///tmp/testData" dbt.mustExecContext(ctx, sql) // streamBuf is now filled with the stream. Use bytes.NewReader(streamBuf.Bytes()) to read uncompressed stream or // use gzip.NewReader(&streamBuf) for to read compressed stream. Note: GET statements are not supported for multi-statement queries. Specifying temporary directory for encryption and compression: Putting and getting requires compression and/or encryption, which is done in the OS temporary directory. If you cannot use default temporary directory for your OS or you want to specify it yourself, you can use "tmpDirPath" DSN parameter. Remember, to encode slashes. 
Example: u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&tmpDirPath=%2Fother%2Ftmp Using custom configuration for PUT/GET: If you want to override some default configuration options, you can use `WithFileTransferOptions` context. There are multiple config parameters including progress bars or compression. # Minicore (Native Library) The Go Snowflake Driver includes an embedded native library called "minicore" that verifies loading of native Rust extensions on various platforms. By default, minicore is enabled and loaded dynamically at runtime. ## Disabling Minicore There are two ways to disable minicore: 1. **At runtime using environment variable:** Set the SF_DISABLE_MINICORE environment variable to "true" to disable minicore loading: export SF_DISABLE_MINICORE=true This is useful when you want to disable minicore for a specific run without recompiling. 2. **At compile time using build tags:** Build with the -tags minicore_disabled flag to completely exclude minicore from the binary: go build -tags minicore_disabled ./... This is required for static linking (e.g., CGO_ENABLED=0) because minicore relies on dynamic library loading (dlopen) which is incompatible with static binaries. Benefits of compile-time disable: - Smaller binary size (no embedded native libraries) - No CGO dependency for POSIX systems - Compatible with static linking Example for fully static build: CGO_ENABLED=0 go build -tags minicore_disabled ./... ## Static Linking On Linux, if the binary is fully statically linked (e.g., built with -linkmode external -extldflags '-static'), the driver automatically detects this and skips minicore loading. Calling dlopen from a statically linked glibc binary would crash with SIGFPE, so the driver inspects the ELF header for a dynamic linker (PT_INTERP) and gracefully skips minicore if none is found. 
When minicore is disabled (either at runtime, at compile time, or automatically due to static linking), the driver continues to work normally but without the additional functionality provided by the native library. # FIPS forcing If you force FIPS mode using fips140 GODEBUG option, driver will switch OCSP requests from SHA-1 to SHA-256. Be aware, that Snowflake cache server doesn't support OCSP requests signed with SHA-256, so driver may work slower, and, in case of OCSP cache server unavailability, OCSP requests will fail, and if OCSP is enabled, then connection attempts will fail as well. # Connectivity diagnostics ==> Relevant configuration - `ConnectionDiagnosticsEnabled` (default: false) - main flag to enable the diagnostics - `ConnectionDiagnosticsAllowlistFile` - specify `/path/to/allowlist.json` to use a specific allowlist file which the driver should parse. If not specified, the driver tries to open `allowlist.json` from the current directory. The `ConnectionDiagnosticsAllowlistFile` is only taken into consideration when `ConnectionDiagnosticsEnabled=true` ==> Flow of operation when `ConnectionDiagnosticsEnabled=true` 1. upon initial startup, driver opens and reads the `allowlist.json` to determine which hosts it needs to connect to, and then for each entry in the allowlist 2. perform a DNS resolution test to see if the hostname is resolvable 3. driver logs an Error, when encountering a _public_ IP address for a host which looks to be a _private_ link hostname 4. checks if proxy is used in the connection 5. sets up a connection; for which we use the same transport which is driven by the driver's config (custom transport, or when OCSP disabled then OCSP-less transport, or by default, the OCSP-enabled transport) 6. for HTTP endpoints, issues a HTTP GET request and see if it connects 7. 
for HTTPS endpoints, the same , plus - verifies if HTTPS connectivity is set up correctly - parses the certificate chain and logs information on each certificate (on DEBUG loglevel, dump the whole cert) - if (implicitly) configured from `CertRevocationCheckMode` being `advisory` or `enabled`, also tries to connect to the CRL endpoints 8. the driver exits after performing diagnostics. If you want to use the driver 'normally' after performing connection diagnostics, set `ConnectionDiagnosticsEnabled=false` or remove it from the config */ package gosnowflake ================================================ FILE: driver.go ================================================ package gosnowflake import ( "context" "database/sql" "database/sql/driver" sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config" "os" "strconv" "strings" "time" ) // SnowflakeDriver is a context of Go Driver type SnowflakeDriver struct{} // Open creates a new connection. func (d SnowflakeDriver) Open(dsn string) (driver.Conn, error) { var cfg *Config var err error logger.Info("Open") ctx := context.Background() if dsn == "autoConfig" { cfg, err = sfconfig.LoadConnectionConfig() } else { cfg, err = ParseDSN(dsn) } if err != nil { return nil, err } return d.OpenWithConfig(ctx, *cfg) } // OpenConnector creates a new connector with parsed DSN. func (d SnowflakeDriver) OpenConnector(dsn string) (driver.Connector, error) { var cfg *Config var err error if dsn == "autoConfig" { cfg, err = sfconfig.LoadConnectionConfig() } else { cfg, err = ParseDSN(dsn) } if err != nil { return Connector{}, err } return NewConnector(d, *cfg), nil } // OpenWithConfig creates a new connection with the given Config. 
func (d SnowflakeDriver) OpenWithConfig(ctx context.Context, config Config) (driver.Conn, error) { timer := time.Now() if err := config.Validate(); err != nil { return nil, err } if config.Params == nil { config.Params = make(map[string]*string) } if config.Tracing != "" { if err := logger.SetLogLevel(config.Tracing); err != nil { return nil, err } } logger.WithContext(ctx).Info("OpenWithConfig") if config.ConnectionDiagnosticsEnabled { connDiagDownloadCrl := (config.CertRevocationCheckMode.String() == "ADVISORY") || (config.CertRevocationCheckMode.String() == "ENABLED") logger.WithContext(ctx).Infof("Connection diagnostics enabled. Allowlist file specified in config: %s, will download CRLs in certificates: %s", config.ConnectionDiagnosticsAllowlistFile, strconv.FormatBool(connDiagDownloadCrl)) performDiagnosis(&config, connDiagDownloadCrl) logger.WithContext(ctx).Info("Connection diagnostics finished.") logger.WithContext(ctx).Warn("A connection to Snowflake was not created because the driver is running in diagnostics mode. If this is unintended then disable diagnostics check by removing the ConnectionDiagnosticsEnabled connection parameter") os.Exit(0) } sc, err := buildSnowflakeConn(ctx, config) if err != nil { return nil, err } if strings.HasSuffix(strings.ToLower(config.Host), sfconfig.CnDomain) { logger.WithContext(ctx).Info("Connecting to CHINA Snowflake domain") } else { logger.WithContext(ctx).Info("Connecting to GLOBAL Snowflake domain") } if err = authenticateWithConfig(sc); err != nil { logger.WithContext(ctx).Errorf("Failed to authenticate. 
Connection failed after %v milliseconds", time.Since(timer).String()) return nil, err } sc.connectionTelemetry(&config) sc.startHeartBeat() sc.internal = &httpClient{sr: sc.rest} // Check context before returning since connectionTelemetry doesn't handle cancellation if ctx.Err() != nil { return nil, ctx.Err() } logger.WithContext(ctx).Infof("Connected successfully after %v milliseconds", time.Since(timer).String()) return sc, nil } func runningOnGithubAction() bool { return os.Getenv("GITHUB_ACTIONS") != "" } // GOSNOWFLAKE_SKIP_REGISTRATION is an environment variable which can be set client side to // bypass dbSql driver registration. This should not be used if sql.Open() is used as the method // to connect to the server, as sql.Open will require registration so it can map the driver name // to the driver type, which in this case is "snowflake" and SnowflakeDriver{}. If you wish to call // into multiple versions of the driver from one client, this is needed because calling register // twice with the same name on init will cause the driver to panic. 
func skipRegistration() bool { return os.Getenv("GOSNOWFLAKE_SKIP_REGISTRATION") != "" } func init() { if !skipRegistration() { sql.Register("snowflake", &SnowflakeDriver{}) } // Set initial log level _ = GetLogger().SetLogLevel("error") if runningOnGithubAction() { _ = GetLogger().SetLogLevel("fatal") } } ================================================ FILE: driver_ocsp_test.go ================================================ package gosnowflake import ( "context" "crypto/tls" "crypto/x509" "database/sql" "errors" "fmt" "net/http" "net/url" "os" "strings" "testing" "time" ) func setenv(k, v string) { err := os.Setenv(k, v) if err != nil { panic(err) } } func unsetenv(k string) { err := os.Unsetenv(k) if err != nil { panic(err) } } // deleteOCSPCacheFile deletes the OCSP response cache file func deleteOCSPCacheFile() { os.Remove(cacheFileName) } // deleteOCSPCacheAll deletes all entries in the OCSP response cache on memory func deleteOCSPCacheAll() { syncUpdateOcspResponseCache(func() { ocspResponseCache = make(map[certIDKey]*certCacheValue) }) } func cleanup() { deleteOCSPCacheFile() deleteOCSPCacheAll() unsetenv(cacheServerURLEnv) unsetenv(ocspTestResponderURLEnv) unsetenv(ocspTestNoOCSPURLEnv) unsetenv(cacheDirEnv) } func TestOCSPFailOpen(t *testing.T) { cleanup() defer cleanup() config := &Config{ Account: "fakeaccount1", User: "fakeuser", Password: "fakepassword", LoginTimeout: 10 * time.Second, OCSPFailOpen: OCSPFailOpenTrue, Authenticator: AuthTypeSnowflake, PrivateKey: nil, } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. %v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. %v", testURL) } if strings.Contains(err.Error(), "HTTP Status: 513. 
Hanging?") { return } driverErr, ok := err.(*SnowflakeError) if !ok { t.Fatalf("failed to extract error SnowflakeError: %v", err) } if isFailToConnectOrAuthErr(driverErr) { t.Fatalf("should failed to connect %v", err) } } func isFailToConnectOrAuthErr(driverErr *SnowflakeError) bool { return driverErr.Number != ErrCodeFailedToConnect && driverErr.Number != ErrFailedToAuth } func TestOCSPFailOpenWithoutFileCache(t *testing.T) { cleanup() defer cleanup() setenv(cacheDirEnv, "/NEVER_EXISTS") config := &Config{ Account: "fakeaccount1", User: "fakeuser", Password: "fakepassword", LoginTimeout: 10 * time.Second, OCSPFailOpen: OCSPFailOpenTrue, Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, // Ensure no private key } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. %v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. %v", testURL) } if strings.Contains(err.Error(), "HTTP Status: 513. 
Hanging?") { return } driverErr, ok := err.(*SnowflakeError) if !ok { t.Fatalf("failed to extract error SnowflakeError: %v", err) } if isFailToConnectOrAuthErr(driverErr) { t.Fatalf("should failed to connect %v", err) } } func TestOCSPFailOpenRevokedStatus(t *testing.T) { t.Skip("revoked.badssl.com certificate expired") cleanup() defer cleanup() ocspCacheServerEnabled = false config := &Config{ Account: "fakeaccount6", User: "fakeuser", Password: "fakepassword", Host: "revoked.badssl.com", LoginTimeout: 10 * time.Second, OCSPFailOpen: OCSPFailOpenTrue, Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. %v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. %v", testURL) } if strings.Contains(err.Error(), "HTTP Status: 513. 
Hanging?") { return } urlErr, ok := err.(*url.Error) if !ok { t.Fatalf("failed to extract error URL Error: %v", err) } var driverErr *SnowflakeError driverErr, ok = urlErr.Err.(*SnowflakeError) if !ok { t.Fatalf("failed to extract error SnowflakeError: %v", err) } if driverErr.Number != ErrOCSPStatusRevoked { t.Fatalf("should failed to connect %v", err) } } func TestOCSPFailClosedRevokedStatus(t *testing.T) { t.Skip("revoked.badssl.com certificate expired") cleanup() defer cleanup() ocspCacheServerEnabled = false config := &Config{ Account: "fakeaccount7", Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, // Ensure no private key User: "fakeuser", Password: "fakepassword", Host: "revoked.badssl.com", LoginTimeout: 20 * time.Second, OCSPFailOpen: OCSPFailOpenFalse, } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. %v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. %v", testURL) } if strings.Contains(err.Error(), "HTTP Status: 513. 
Hanging?") { return } urlErr, ok := err.(*url.Error) if !ok { t.Fatalf("failed to extract error URL Error: %v", err) } var driverErr *SnowflakeError driverErr, ok = urlErr.Err.(*SnowflakeError) if !ok { t.Fatalf("failed to extract error SnowflakeError: %v", err) } if driverErr.Number != ErrOCSPStatusRevoked { t.Fatalf("should failed to connect %v", err) } } func TestOCSPFailOpenCacheServerTimeout(t *testing.T) { cleanup() defer cleanup() setenv(cacheServerURLEnv, fmt.Sprintf("http://localhost:%v/hang", wiremock.port)) wiremock.registerMappings(t, newWiremockMapping("hang.json")) origCacheServerTimeout := OcspCacheServerTimeout OcspCacheServerTimeout = time.Second defer func() { OcspCacheServerTimeout = origCacheServerTimeout }() config := &Config{ Account: "fakeaccount8", Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, // Ensure no private key User: "fakeuser", Password: "fakepassword", LoginTimeout: 10 * time.Second, OCSPFailOpen: OCSPFailOpenTrue, } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. %v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. %v", testURL) } if strings.Contains(err.Error(), "HTTP Status: 513. 
Hanging?") { return } driverErr, ok := err.(*SnowflakeError) if !ok { t.Fatalf("failed to extract error SnowflakeError: %v", err) } if isFailToConnectOrAuthErr(driverErr) { t.Fatalf("should failed to connect %v", err) } } func TestOCSPFailClosedCacheServerTimeout(t *testing.T) { cleanup() defer cleanup() setenv(cacheServerURLEnv, fmt.Sprintf("http://localhost:%v/hang", wiremock.port)) wiremock.registerMappings(t, newWiremockMapping("hang.json")) origCacheServerTimeout := OcspCacheServerTimeout OcspCacheServerTimeout = time.Second defer func() { OcspCacheServerTimeout = origCacheServerTimeout }() config := &Config{ Account: "fakeaccount9", Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, // Ensure no private key User: "fakeuser", Password: "fakepassword", LoginTimeout: 20 * time.Second, OCSPFailOpen: OCSPFailOpenFalse, } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. %v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. %v", testURL) } if err == nil { t.Fatalf("should failed to connect. err: %v", err) } if strings.Contains(err.Error(), "HTTP Status: 513. Hanging?") { return } switch errType := err.(type) { // Before Go 1.17 case *SnowflakeError: driverErr, ok := err.(*SnowflakeError) if !ok { t.Fatalf("failed to extract error SnowflakeError: %v", err) } if isFailToConnectOrAuthErr(driverErr) { t.Fatalf("should have failed to connect. err: %v", err) } // Go 1.18 and after rejects SHA-1 certificates, therefore a different error is returned (https://github.com/golang/go/issues/41682) case *url.Error: expectedErrMsg := "bad OCSP signature" if !strings.Contains(err.Error(), expectedErrMsg) { t.Fatalf("should have failed with bad OCSP signature. err: %v", err) } default: t.Fatalf("should failed to connect. 
err type: %v, err: %v", errType, err) } } func TestOCSPFailOpenResponderTimeout(t *testing.T) { cleanup() defer cleanup() ocspCacheServerEnabled = false setenv(ocspTestResponderURLEnv, fmt.Sprintf("http://localhost:%v/ocsp/hang", wiremock.port)) wiremock.registerMappings(t, newWiremockMapping("hang.json")) origOCSPResponderTimeout := OcspResponderTimeout OcspResponderTimeout = 1000 defer func() { OcspResponderTimeout = origOCSPResponderTimeout }() config := &Config{ Account: "fakeaccount10", Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, // Ensure no private key User: "fakeuser", Password: "fakepassword", LoginTimeout: 10 * time.Second, OCSPFailOpen: OCSPFailOpenTrue, } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. %v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. %v", testURL) } if strings.Contains(err.Error(), "HTTP Status: 513. 
Hanging?") { return } driverErr, ok := err.(*SnowflakeError) if !ok { t.Fatalf("failed to extract error SnowflakeError: %v", err) } if isFailToConnectOrAuthErr(driverErr) { t.Fatalf("should failed to connect %v", err) } } func TestOCSPFailClosedResponderTimeout(t *testing.T) { cleanup() defer cleanup() ocspCacheServerEnabled = false setenv(ocspTestResponderURLEnv, fmt.Sprintf("http://localhost:%v/hang", wiremock.port)) wiremock.registerMappings(t, newWiremockMapping("hang.json")) origOCSPResponderTimeout := OcspResponderTimeout origOCSPMaxRetryCount := OcspMaxRetryCount OcspResponderTimeout = 100 * time.Millisecond OcspMaxRetryCount = 1 defer func() { OcspResponderTimeout = origOCSPResponderTimeout OcspMaxRetryCount = origOCSPMaxRetryCount }() config := &Config{ Account: "fakeaccount11", Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, // Ensure no private key User: "fakeuser", Password: "fakepassword", LoginTimeout: 3 * time.Second, OCSPFailOpen: OCSPFailOpenFalse, } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. %v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. %v", testURL) } if strings.Contains(err.Error(), "HTTP Status: 513. 
Hanging?") { return } urlErr, ok := err.(*url.Error) if !ok { t.Fatalf("failed to extract error URL Error: %v", err) } urlErr0, ok := urlErr.Err.(*url.Error) if !ok { t.Fatalf("failed to extract error URL Error: %v", urlErr.Err) } if !strings.Contains(urlErr0.Err.Error(), "Client.Timeout") && !strings.Contains(urlErr0.Err.Error(), "connection refused") { t.Fatalf("the root cause is not timeout: %v", urlErr0.Err) } } func TestOCSPFailOpenResponder404(t *testing.T) { cleanup() defer cleanup() ocspCacheServerEnabled = false setenv(ocspTestResponderURLEnv, fmt.Sprintf("http://localhost:%v/404", wiremock.port)) config := &Config{ Account: "fakeaccount10", Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, // Ensure no private key User: "fakeuser", Password: "fakepassword", LoginTimeout: 5 * time.Second, OCSPFailOpen: OCSPFailOpenTrue, } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. %v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. %v", testURL) } if strings.Contains(err.Error(), "HTTP Status: 513. 
Hanging?") { return } driverErr, ok := err.(*SnowflakeError) if !ok { t.Fatalf("failed to extract error SnowflakeError: %v", err) } if isFailToConnectOrAuthErr(driverErr) { t.Fatalf("should failed to connect %v", err) } } func TestOCSPFailClosedResponder404(t *testing.T) { cleanup() defer cleanup() ocspCacheServerEnabled = false setenv(ocspTestResponderURLEnv, fmt.Sprintf("http://localhost:%v/404", wiremock.port)) config := &Config{ Account: "fakeaccount11", Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, // Ensure no private key User: "fakeuser", Password: "fakepassword", LoginTimeout: 5 * time.Second, OCSPFailOpen: OCSPFailOpenFalse, } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. %v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. %v", testURL) } if strings.Contains(err.Error(), "HTTP Status: 513. Hanging?") { return } urlErr, ok := err.(*url.Error) if !ok { t.Fatalf("failed to extract error SnowflakeError: %v", err) } if !strings.Contains(urlErr.Err.Error(), "404 Not Found") && !strings.Contains(urlErr.Err.Error(), "connection refused") { t.Fatalf("the root cause is not 404: %v", urlErr.Err) } } func TestExpiredCertificate(t *testing.T) { cleanup() defer cleanup() config := &Config{ Account: "fakeaccount10", Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, // Ensure no private key User: "fakeuser", Password: "fakepassword", Host: "expired.badssl.com", LoginTimeout: 10 * time.Second, OCSPFailOpen: OCSPFailOpenTrue, } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. 
%v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. %v", testURL) } urlErr, ok := err.(*url.Error) if !ok { t.Fatalf("failed to extract error URL Error: %v", err) } _, ok = urlErr.Err.(x509.CertificateInvalidError) if !ok { // Go 1.20 throws tls CertificateVerification error errString := urlErr.Err.Error() // badssl sometimes times out if !strings.Contains(errString, "certificate has expired or is not yet valid") && !strings.Contains(errString, "timeout") && !strings.Contains(errString, "connection attempt failed") { t.Fatalf("failed to extract error Certificate error: %v", err) } } } /* DISABLED: sicne it appeared self-signed.badssl.com is not well maintained, this test is no longer reliable. // TestSelfSignedCertificate tests self-signed certificate func TestSelfSignedCertificate(t *testing.T) { cleanup() defer cleanup() config := &Config{ Account: "fakeaccount10", Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, // Ensure no private key User: "fakeuser", Password: "fakepassword", Host: "self-signed.badssl.com", LoginTimeout: 10 * time.Second, OCSPFailOpen: OCSPFailOpenTrue, } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. %v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. 
%v", testURL) } urlErr, ok := err.(*url.Error) if !ok { t.Fatalf("failed to extract error URL Error: %v", err) } _, ok = urlErr.Err.(x509.UnknownAuthorityError) if !ok { t.Fatalf("failed to extract error Certificate error: %v", err) } } */ func TestOCSPFailOpenNoOCSPURL(t *testing.T) { cleanup() defer cleanup() ocspCacheServerEnabled = false setenv(ocspTestNoOCSPURLEnv, "true") config := &Config{ Account: "fakeaccount10", Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, // Ensure no private key User: "fakeuser", Password: "fakepassword", LoginTimeout: 10 * time.Second, OCSPFailOpen: OCSPFailOpenTrue, } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. %v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. %v", testURL) } if strings.Contains(err.Error(), "HTTP Status: 513. Hanging?") { return } driverErr, ok := err.(*SnowflakeError) if !ok { t.Fatalf("failed to extract error SnowflakeError: %v", err) } if isFailToConnectOrAuthErr(driverErr) { t.Fatalf("should failed to connect %v", err) } } func TestOCSPFailClosedNoOCSPURL(t *testing.T) { cleanup() defer cleanup() ocspCacheServerEnabled = false setenv(ocspTestNoOCSPURLEnv, "true") config := &Config{ Account: "fakeaccount11", Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, // Ensure no private key User: "fakeuser", Password: "fakepassword", LoginTimeout: 20 * time.Second, OCSPFailOpen: OCSPFailOpenFalse, } var db *sql.DB var err error var testURL string testURL, err = DSN(config) assertNilF(t, err, "failed to build URL from Config") if db, err = sql.Open("snowflake", testURL); err != nil { t.Fatalf("failed to open db. %v, err: %v", testURL, err) } defer db.Close() if err = db.Ping(); err == nil { t.Fatalf("should fail to ping. 
%v", testURL) } if strings.Contains(err.Error(), "HTTP Status: 513. Hanging?") { return } urlErr, ok := err.(*url.Error) if !ok { t.Fatalf("failed to extract error SnowflakeError: %v", err) } driverErr, ok := urlErr.Err.(*SnowflakeError) if !ok { if !strings.Contains(err.Error(), "HTTP Status: 513. Hanging?") { t.Fatalf("failed to extract error SnowflakeError: %v", err) } } if driverErr.Number != ErrOCSPNoOCSPResponderURL { t.Fatalf("should failed to connect %v", err) } } func TestOCSPUnexpectedResponses(t *testing.T) { cleanup() defer cleanup() ocspCacheServerEnabled = false cfg := wiremockHTTPS.connectionConfig(t) countingRoundTripper := newCountingRoundTripper(http.DefaultTransport) ocspTransport := wiremockHTTPS.ocspTransporter(t, countingRoundTripper) cfg.Transporter = ocspTransport runSampleQuery := func(cfg *Config) { connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) rows, err := db.Query("SELECT 1") assertNilF(t, err) defer rows.Close() var v int assertTrueF(t, rows.Next()) err = rows.Scan(&v) assertNilF(t, err) assertEqualE(t, v, 1) } t.Run("should retry when OCSP is not reachable", func(t *testing.T) { countingRoundTripper.reset() testResponderOverride := overrideEnv(ocspTestResponderURLEnv, "http://localhost:56734") defer testResponderOverride.rollback() wiremock.registerMappings(t, wiremockMapping{filePath: "select1.json"}, wiremockMapping{filePath: "auth/password/successful_flow.json"}, ) runSampleQuery(cfg) assertTrueE(t, countingRoundTripper.postReqCount["http://localhost:56734"] > 1) assertEqualE(t, countingRoundTripper.getReqCount["http://localhost:56734"], 0) }) t.Run("should fallback to GET when POST returns malformed response", func(t *testing.T) { countingRoundTripper.reset() testResponderOverride := overrideEnv(ocspTestResponderURLEnv, wiremock.baseURL()) defer testResponderOverride.rollback() wiremock.registerMappings(t, wiremockMapping{filePath: "ocsp/malformed.json"}, wiremockMapping{filePath: "select1.json"}, 
wiremockMapping{filePath: "auth/password/successful_flow.json"}, ) runSampleQuery(cfg) assertEqualE(t, countingRoundTripper.postReqCount[wiremock.baseURL()], 2) assertEqualE(t, countingRoundTripper.getReqCount[wiremock.baseURL()], 2) }) t.Run("should not fallback to GET when for POST unauthorized is returned", func(t *testing.T) { countingRoundTripper.reset() assertNilF(t, os.Setenv(ocspTestResponderURLEnv, wiremock.baseURL())) testResponderOverride := overrideEnv(ocspTestResponderURLEnv, wiremock.baseURL()) defer testResponderOverride.rollback() wiremock.registerMappings(t, wiremockMapping{filePath: "ocsp/unauthorized.json"}, wiremockMapping{filePath: "select1.json"}, wiremockMapping{filePath: "auth/password/successful_flow.json"}, ) runSampleQuery(cfg) assertEqualE(t, countingRoundTripper.postReqCount[wiremock.baseURL()], 2) assertEqualE(t, countingRoundTripper.getReqCount[wiremock.baseURL()], 0) }) } func TestConnectionToMultipleConfigurations(t *testing.T) { setenv(cacheServerURLEnv, defaultCacheServerHost) wiremockHTTPS.registerMappings(t, wiremockMapping{filePath: "auth/password/successful_flow.json"}) err := RegisterTLSConfig("wiremock", &tls.Config{ RootCAs: wiremockHTTPS.certPool(t), }) assertNilF(t, err) origOcspMaxRetryCount := OcspMaxRetryCount OcspMaxRetryCount = 1 defer func() { OcspMaxRetryCount = origOcspMaxRetryCount }() cfgForFailOpen := wiremockHTTPS.connectionConfig(t) cfgForFailOpen.OCSPFailOpen = OCSPFailOpenTrue cfgForFailOpen.Transporter = nil cfgForFailOpen.TLSConfigName = "wiremock" cfgForFailOpen.MaxRetryCount = 1 cfgForFailClose := wiremockHTTPS.connectionConfig(t) cfgForFailClose.OCSPFailOpen = OCSPFailOpenFalse cfgForFailClose.Transporter = nil cfgForFailClose.TLSConfigName = "wiremock" cfgForFailClose.MaxRetryCount = 1 // we ignore closing here, since these are only wiremock connections failOpenDb := sql.OpenDB(NewConnector(SnowflakeDriver{}, *cfgForFailOpen)) failCloseDb := sql.OpenDB(NewConnector(SnowflakeDriver{}, 
*cfgForFailClose)) _, err = failOpenDb.Conn(context.Background()) assertNilF(t, err) _, err = failCloseDb.Conn(context.Background()) assertNotNilF(t, err) var se *SnowflakeError assertTrueF(t, errors.As(err, &se)) assertStringContainsE(t, se.Error(), "no OCSP server is attached to the certificate") _, err = failOpenDb.Conn(context.Background()) assertNilF(t, err) // new connections should still behave the same way failOpenDb2 := sql.OpenDB(NewConnector(SnowflakeDriver{}, *cfgForFailOpen)) failCloseDb2 := sql.OpenDB(NewConnector(SnowflakeDriver{}, *cfgForFailClose)) _, err = failOpenDb2.Conn(context.Background()) assertNilF(t, err) _, err = failCloseDb2.Conn(context.Background()) assertNotNilF(t, err) assertTrueF(t, errors.As(err, &se)) assertStringContainsE(t, se.Error(), "no OCSP server is attached to the certificate") // and old connections should still behave the same way _, err = failOpenDb.Conn(context.Background()) assertNilF(t, err) _, err = failCloseDb.Conn(context.Background()) assertNotNilF(t, err) assertTrueF(t, errors.As(err, &se)) assertStringContainsE(t, se.Error(), "no OCSP server is attached to the certificate") } ================================================ FILE: driver_test.go ================================================ package gosnowflake import ( "cmp" "context" "crypto/rsa" "database/sql" "database/sql/driver" "encoding/base64" "encoding/pem" "errors" "flag" "fmt" "math" "math/big" "math/rand" "net/http" "net/url" "os" "os/signal" "path/filepath" "reflect" "runtime" "strconv" "strings" "sync" "syscall" "testing" "time" ) var ( username string pass string account string dbname string schemaname string warehouse string rolename string dsn string host string port string protocol string customPrivateKey bool // Whether user has specified the private key path testPrivKey *rsa.PrivateKey // Valid private key used for all test cases debugMode bool ) const ( selectNumberSQL = "SELECT %s::NUMBER(%v, %v) AS C" selectVariousTypes = "SELECT 
1.0::NUMBER(30,2) as C1, 2::NUMBER(18,0) AS C2, 22::NUMBER(38, 0) AS C2A, 't3' AS C3, 4.2::DOUBLE AS C4, 'abcd'::BINARY(8388608) AS C5, true AS C6" selectRandomGenerator = "SELECT SEQ8(), RANDSTR(1000, RANDOM()) FROM TABLE(GENERATOR(ROWCOUNT=>%v))" PSTLocation = "America/Los_Angeles" ) // The tests require the following parameters in the environment variables. // SNOWFLAKE_TEST_USER, SNOWFLAKE_TEST_PASSWORD, SNOWFLAKE_TEST_ACCOUNT, // SNOWFLAKE_TEST_DATABASE, SNOWFLAKE_TEST_SCHEMA, SNOWFLAKE_TEST_WAREHOUSE. // Optionally you may specify SNOWFLAKE_TEST_PROTOCOL, SNOWFLAKE_TEST_HOST // and SNOWFLAKE_TEST_PORT to specify the endpoint. func init() { // get environment variables env := func(key, defaultValue string) string { return cmp.Or(os.Getenv(key), defaultValue) } username = env("SNOWFLAKE_TEST_USER", "testuser") pass = env("SNOWFLAKE_TEST_PASSWORD", "testpassword") account = env("SNOWFLAKE_TEST_ACCOUNT", "testaccount") dbname = env("SNOWFLAKE_TEST_DATABASE", "testdb") schemaname = env("SNOWFLAKE_TEST_SCHEMA", "public") rolename = env("SNOWFLAKE_TEST_ROLE", "sysadmin") warehouse = env("SNOWFLAKE_TEST_WAREHOUSE", "testwarehouse") protocol = env("SNOWFLAKE_TEST_PROTOCOL", "https") host = os.Getenv("SNOWFLAKE_TEST_HOST") port = env("SNOWFLAKE_TEST_PORT", "443") if host == "" { host = fmt.Sprintf("%s.snowflakecomputing.com", account) } else { host = fmt.Sprintf("%s:%s", host, port) } setupPrivateKey() createDSN("UTC") debugMode, _ = strconv.ParseBool(os.Getenv("SNOWFLAKE_TEST_DEBUG")) if debugMode { _ = GetLogger().SetLogLevel("debug") } } func createDSN(timezone string) { // Check if we should use JWT authentication authenticator := os.Getenv("SNOWFLAKE_TEST_AUTHENTICATOR") if authenticator == "SNOWFLAKE_JWT" { // For JWT authentication, don't include password in the DSN dsn = fmt.Sprintf("%s@%s/%s/%s", username, host, dbname, schemaname) } else { // For standard password authentication dsn = fmt.Sprintf("%s:%s@%s/%s/%s", username, pass, host, dbname, schemaname) } 
parameters := url.Values{} parameters.Add("timezone", timezone) if protocol != "" { parameters.Add("protocol", protocol) } if account != "" { parameters.Add("account", account) } if warehouse != "" { parameters.Add("warehouse", warehouse) } if rolename != "" { parameters.Add("role", rolename) } // Add authenticator and private key for JWT authentication if authenticator == "SNOWFLAKE_JWT" { parameters.Add("authenticator", "SNOWFLAKE_JWT") parameters.Add("jwtClientTimeout", "20") privateKeyPath := os.Getenv("SNOWFLAKE_TEST_PRIVATE_KEY") if privateKeyPath != "" { // Read and encode the private key file privateKeyBytes, err := os.ReadFile(privateKeyPath) if err == nil { block, _ := pem.Decode(privateKeyBytes) if block != nil && block.Type == "PRIVATE KEY" { encodedKey := base64.URLEncoding.EncodeToString(block.Bytes) parameters.Add("privateKey", encodedKey) } else if block == nil { panic("Failed to decode PEM block from private key file") } else { panic("Expected 'PRIVATE KEY' block type") } } else { panic("Failed to read private key file") } } else { panic("SNOWFLAKE_TEST_PRIVATE_KEY environment variable is not set for JWT authentication") } } if len(parameters) > 0 { dsn += "?" + parameters.Encode() } } // setup creates a test schema so that all tests can run in the same schema func setup() (string, error) { env := func(key, defaultValue string) string { return cmp.Or(os.Getenv(key), defaultValue) } orgSchemaname := schemaname if env("GITHUB_WORKFLOW", "") != "" { githubRunnerID := env("RUNNER_TRACKING_ID", "GITHUB_RUNNER_ID") githubRunnerID = strings.ReplaceAll(githubRunnerID, "-", "_") githubSha := env("GITHUB_SHA", "GITHUB_SHA") schemaname = fmt.Sprintf("%v_%v", githubRunnerID, githubSha) } else { schemaname = fmt.Sprintf("golang_%v", time.Now().UnixNano()) } var db *sql.DB var err error if db, err = sql.Open("snowflake", dsn); err != nil { return "", fmt.Errorf("failed to open db. 
err: %v", err) } defer db.Close() if _, err = db.Exec(fmt.Sprintf("CREATE OR REPLACE SCHEMA %v", schemaname)); err != nil { return "", fmt.Errorf("failed to create schema. %v", err) } createDSN("UTC") return orgSchemaname, nil } // teardown drops the test schema func teardown() error { var db *sql.DB var err error if db, err = sql.Open("snowflake", dsn); err != nil { return fmt.Errorf("failed to open db. %v, err: %v", dsn, err) } defer db.Close() if _, err = db.Exec(fmt.Sprintf("DROP SCHEMA IF EXISTS %v", schemaname)); err != nil { return fmt.Errorf("failed to create schema. %v", err) } return nil } func TestMain(m *testing.M) { flag.Parse() signal.Ignore(syscall.SIGQUIT) if value := os.Getenv("SKIP_SETUP"); value != "" { os.Exit(m.Run()) } if _, err := setup(); err != nil { panic(err) } ret := m.Run() if err := teardown(); err != nil { panic(err) } os.Exit(ret) } type DBTest struct { *testing.T conn *sql.Conn } func (dbt *DBTest) mustQueryT(t *testing.T, query string, args ...any) (rows *RowsExtended) { t.Helper() // handler interrupt signal ctx, cancel := context.WithCancel(context.Background()) c := make(chan os.Signal, 1) c0 := make(chan bool, 1) signal.Notify(c, os.Interrupt) defer func() { signal.Stop(c) }() go func() { select { case <-c: fmt.Println("Caught signal, canceling...") cancel() case <-ctx.Done(): fmt.Println("Done") case <-c0: } close(c) }() rs, err := dbt.conn.QueryContext(ctx, query, args...) if err != nil { t.Fatalf("query, query=%v, err=%v", query, err) } return &RowsExtended{ rows: rs, closeChan: &c0, t: t, } } func (dbt *DBTest) mustQuery(query string, args ...any) (rows *RowsExtended) { dbt.Helper() return dbt.mustQueryT(dbt.T, query, args...) } func (dbt *DBTest) mustQueryContext(ctx context.Context, query string, args ...any) (rows *RowsExtended) { dbt.Helper() return dbt.mustQueryContextT(ctx, dbt.T, query, args...) 
} func (dbt *DBTest) mustQueryContextT(ctx context.Context, t *testing.T, query string, args ...any) (rows *RowsExtended) { t.Helper() // handler interrupt signal ctx, cancel := context.WithCancel(ctx) c := make(chan os.Signal, 1) c0 := make(chan bool, 1) signal.Notify(c, os.Interrupt) defer func() { signal.Stop(c) }() go func() { select { case <-c: fmt.Println("Caught signal, canceling...") cancel() case <-ctx.Done(): fmt.Println("Done") case <-c0: } close(c) }() rs, err := dbt.conn.QueryContext(ctx, query, args...) if err != nil { t.Fatalf("query, query=%v, err=%v", query, err) } return &RowsExtended{ rows: rs, closeChan: &c0, t: t, } } func (dbt *DBTest) query(query string, args ...any) (*sql.Rows, error) { return dbt.conn.QueryContext(context.Background(), query, args...) } func (dbt *DBTest) mustQueryAssertCount(query string, expected int, args ...any) { rows := dbt.mustQuery(query, args...) defer rows.Close() cnt := 0 for rows.Next() { cnt++ } if cnt != expected { dbt.Fatalf("expected %v, got %v", expected, cnt) } } func (dbt *DBTest) prepare(query string) (*sql.Stmt, error) { return dbt.conn.PrepareContext(context.Background(), query) } func (dbt *DBTest) fail(method, query string, err error) { if !debugMode && len(query) > 1000 { query = "[query too large to print]" } dbt.Fatalf("error on %s [%s]: %s", method, query, err.Error()) } func (dbt *DBTest) mustExec(query string, args ...any) (res sql.Result) { return dbt.mustExecContext(context.Background(), query, args...) } func (dbt *DBTest) mustExecT(t *testing.T, query string, args ...any) (res sql.Result) { return dbt.mustExecContextT(context.Background(), t, query, args...) } func (dbt *DBTest) mustExecContext(ctx context.Context, query string, args ...any) (res sql.Result) { res, err := dbt.conn.ExecContext(ctx, query, args...) 
if err != nil { dbt.fail("exec context", query, err) } return res } func (dbt *DBTest) mustExecContextT(ctx context.Context, t *testing.T, query string, args ...any) (res sql.Result) { res, err := dbt.conn.ExecContext(ctx, query, args...) if err != nil { t.Fatalf("exec context: query=%v, err=%v", query, err) } return res } func (dbt *DBTest) exec(query string, args ...any) (sql.Result, error) { return dbt.conn.ExecContext(context.Background(), query, args...) } func (dbt *DBTest) mustDecimalSize(ct *sql.ColumnType) (pr int64, sc int64) { var ok bool pr, sc, ok = ct.DecimalSize() if !ok { dbt.Fatalf("failed to get decimal size. %v", ct) } return pr, sc } func (dbt *DBTest) mustFailDecimalSize(ct *sql.ColumnType) { var ok bool if _, _, ok = ct.DecimalSize(); ok { dbt.Fatalf("should not return decimal size. %v", ct) } } func (dbt *DBTest) mustLength(ct *sql.ColumnType) (cLen int64) { var ok bool cLen, ok = ct.Length() if !ok { dbt.Fatalf("failed to get length. %v", ct) } return cLen } func (dbt *DBTest) mustFailLength(ct *sql.ColumnType) { var ok bool if _, ok = ct.Length(); ok { dbt.Fatalf("should not return length. %v", ct) } } func (dbt *DBTest) mustNullable(ct *sql.ColumnType) (canNull bool) { var ok bool canNull, ok = ct.Nullable() if !ok { dbt.Fatalf("failed to get length. 
%v", ct) } return canNull } func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) { stmt, err := dbt.conn.PrepareContext(context.Background(), query) if err != nil { dbt.fail("prepare", query, err) } return stmt } func (dbt *DBTest) forceJSON() { dbt.mustExec(forceJSON) } func (dbt *DBTest) forceArrow() { dbt.mustExec(forceARROW) dbt.mustExec("alter session set ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT = false") dbt.mustExec("alter session set FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT = false") } func (dbt *DBTest) forceNativeArrow() { // structured types dbt.mustExec(forceARROW) dbt.mustExec("alter session set ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT = true") dbt.mustExec("alter session set FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT = true") } func (dbt *DBTest) enableStructuredTypes() { _, err := dbt.exec("alter session set ENABLE_STRUCTURED_TYPES_IN_CLIENT_RESPONSE = true") if err != nil { dbt.Log(err) } _, err = dbt.exec("alter session set IGNORE_CLIENT_VESRION_IN_STRUCTURED_TYPES_RESPONSE = true") if err != nil { dbt.Log(err) } _, err = dbt.exec("alter session set ENABLE_STRUCTURED_TYPES_IN_FDN_TABLES = true") if err != nil { dbt.Log(err) } } func (dbt *DBTest) enableStructuredTypesBinding() { dbt.enableStructuredTypes() _, err := dbt.exec("ALTER SESSION SET ENABLE_OBJECT_TYPED_BINDS = true") if err != nil { dbt.Log(err) } _, err = dbt.exec("ALTER SESSION SET ENABLE_STRUCTURED_TYPES_IN_BINDS = Enable") if err != nil { dbt.Log(err) } } type SCTest struct { *testing.T sc *snowflakeConn } func (sct *SCTest) fail(method, query string, err error) { if !debugMode && len(query) > 300 { query = "[query too large to print]" } sct.Fatalf("error on %s [%s]: %s", method, query, err.Error()) } func (sct *SCTest) mustExec(query string, args []driver.Value) driver.Result { result, err := sct.sc.Exec(query, args) if err != nil { sct.fail("exec", query, err) } return result } func (sct *SCTest) mustQuery(query string, args []driver.Value) 
driver.Rows { rows, err := sct.sc.Query(query, args) if err != nil { sct.fail("query", query, err) } return rows } func (sct *SCTest) mustQueryContext(ctx context.Context, query string, args []driver.NamedValue) driver.Rows { rows, err := sct.sc.QueryContext(ctx, query, args) if err != nil { sct.fail("QueryContext", query, err) } return rows } type testConfig struct { dsn string } func runDBTest(t *testing.T, test func(dbt *DBTest)) { runDBTestWithConfig(t, &testConfig{dsn}, test) } func runDBTestWithConfig(t *testing.T, testCfg *testConfig, test func(dbt *DBTest)) { db, conn := openConn(t, testCfg) defer conn.Close() defer db.Close() dbt := &DBTest{t, conn} test(dbt) } func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) { runSnowflakeConnTestWithConfig(t, &testConfig{dsn}, test) } func runSnowflakeConnTestWithConfig(t *testing.T, testCfg *testConfig, test func(sct *SCTest)) { config, err := ParseDSN(testCfg.dsn) if err != nil { t.Error(err) } sc, err := buildSnowflakeConn(context.Background(), *config) if err != nil { t.Fatal(err) } defer sc.Close() if err = authenticateWithConfig(sc); err != nil { t.Fatal(err) } sct := &SCTest{t, sc} test(sct) } func getDbHandlerFromConfig(t *testing.T, cfg *Config) *sql.DB { dsn, err := DSN(cfg) assertNilF(t, err, "failed to create DSN from Config") db, err := sql.Open("snowflake", dsn) assertNilF(t, err, "failed to open database") return db } func runningOnAWS() bool { return os.Getenv("CLOUD_PROVIDER") == "AWS" } func runningOnGCP() bool { return os.Getenv("CLOUD_PROVIDER") == "GCP" } func runningOnLinux() bool { return runtime.GOOS == "linux" } func TestKnownUserInvalidPasswordParameters(t *testing.T) { wiremock.registerMappings(t, wiremockMapping{filePath: "auth/password/invalid_password.json"}, ) cfg := wiremock.connectionConfig() cfg.User = "testUser" cfg.Password = "INVALID_PASSWORD" cfg.Authenticator = AuthTypeSnowflake // Force password auth db := sql.OpenDB(NewConnector(SnowflakeDriver{}, *cfg)) defer 
db.Close() _, err := db.Exec("SELECT 1") assertNotNilF(t, err, "should cause an authentication error") var driverErr *SnowflakeError assertErrorsAsF(t, err, &driverErr) assertEqualE(t, driverErr.Number, 390100) } func TestCommentOnlyQuery(t *testing.T) { runDBTest(t, func(dbt *DBTest) { query := "--" // just a comment, no query rows, err := dbt.query(query) if err == nil { rows.Close() dbt.fail("query", query, err) } if driverErr, ok := err.(*SnowflakeError); ok { if driverErr.Number != 900 { // syntax error dbt.fail("query", query, err) } } }) } func TestEmptyQuery(t *testing.T) { runDBTest(t, func(dbt *DBTest) { query := "select 1 from dual where 1=0" // just a comment, no query rows := dbt.conn.QueryRowContext(context.Background(), query) var v1 any if err := rows.Scan(&v1); err != sql.ErrNoRows { dbt.Errorf("should fail. err: %v", err) } rows = dbt.conn.QueryRowContext(context.Background(), query) if err := rows.Scan(&v1); err != sql.ErrNoRows { dbt.Errorf("should fail. err: %v", err) } }) } func TestEmptyQueryWithRequestID(t *testing.T) { runDBTest(t, func(dbt *DBTest) { query := "select 1" ctx := WithRequestID(context.Background(), NewUUID()) rows := dbt.conn.QueryRowContext(ctx, query) var v1 any if err := rows.Scan(&v1); err != nil { dbt.Errorf("should not have failed with valid request id. 
err: %v", err) } }) } func TestRequestIDFromTwoDifferentSessions(t *testing.T) { db, err := sql.Open("snowflake", dsn) assertNilF(t, err) db.SetMaxOpenConns(10) conn, err := db.Conn(context.Background()) assertNilF(t, err) defer conn.Close() _, err = conn.ExecContext(context.Background(), forceJSON) assertNilF(t, err) conn2, err := db.Conn(context.Background()) assertNilF(t, err) defer conn2.Close() _, err = conn2.ExecContext(context.Background(), forceJSON) assertNilF(t, err) // creating table reqIDForCreate := NewUUID() _, err = conn.ExecContext(WithRequestID(context.Background(), reqIDForCreate), "CREATE TABLE req_id_testing (id INTEGER)") assertNilF(t, err) defer func() { _, err = db.Exec("DROP TABLE IF EXISTS req_id_testing") assertNilE(t, err) }() _, err = conn.ExecContext(WithRequestID(context.Background(), reqIDForCreate), "CREATE TABLE req_id_testing (id INTEGER)") assertNilF(t, err) defer func() { _, err = db.Exec("DROP TABLE IF EXISTS req_id_testing") assertNilE(t, err) }() // should fail as API v1 does not allow reusing requestID across sessions for DML statements _, err = conn2.ExecContext(WithRequestID(context.Background(), reqIDForCreate), "CREATE TABLE req_id_testing (id INTEGER)") assertNotNilE(t, err) assertStringContainsE(t, err.Error(), "already exists") // inserting a record reqIDForInsert := NewUUID() execResult, err := conn.ExecContext(WithRequestID(context.Background(), reqIDForInsert), "INSERT INTO req_id_testing VALUES (1)") assertNilF(t, err) rowsInserted, err := execResult.RowsAffected() assertNilF(t, err) assertEqualE(t, rowsInserted, int64(1)) _, err = conn2.ExecContext(WithRequestID(context.Background(), reqIDForInsert), "INSERT INTO req_id_testing VALUES (1)") assertNilF(t, err) rowsInserted2, err := execResult.RowsAffected() assertNilF(t, err) assertEqualE(t, rowsInserted2, int64(1)) // selecting data reqIDForSelect := NewUUID() rows, err := conn.QueryContext(WithRequestID(context.Background(), reqIDForSelect), "SELECT * FROM 
req_id_testing") assertNilF(t, err) defer rows.Close() var i int assertTrueE(t, rows.Next()) assertNilF(t, rows.Scan(&i)) assertEqualE(t, i, 1) i = 0 assertTrueE(t, rows.Next()) assertNilF(t, rows.Scan(&i)) assertEqualE(t, i, 1) assertFalseE(t, rows.Next()) rows2, err := conn.QueryContext(WithRequestID(context.Background(), reqIDForSelect), "SELECT * FROM req_id_testing") assertNilF(t, err) defer rows2.Close() assertTrueE(t, rows2.Next()) assertNilF(t, rows2.Scan(&i)) assertEqualE(t, i, 1) i = 0 assertTrueE(t, rows2.Next()) assertNilF(t, rows2.Scan(&i)) assertEqualE(t, i, 1) assertFalseE(t, rows2.Next()) // insert another data _, err = conn.ExecContext(context.Background(), "INSERT INTO req_id_testing VALUES (1)") assertNilF(t, err) // selecting using old request id rows3, err := conn.QueryContext(WithRequestID(context.Background(), reqIDForSelect), "SELECT * FROM req_id_testing") assertNilF(t, err) defer rows3.Close() assertTrueE(t, rows3.Next()) assertNilF(t, rows3.Scan(&i)) assertEqualE(t, i, 1) i = 0 assertTrueE(t, rows3.Next()) assertNilF(t, rows3.Scan(&i)) assertEqualE(t, i, 1) i = 0 assertFalseF(t, rows3.Next()) } func TestCRUD(t *testing.T) { runDBTest(t, func(dbt *DBTest) { // Create Table dbt.mustExec("CREATE OR REPLACE TABLE test (value BOOLEAN)") // Test for unexpected Data var out bool rows := dbt.mustQuery("SELECT * FROM test") defer rows.Close() if rows.Next() { dbt.Error("unexpected Data in empty table") } // Create Data res := dbt.mustExec("INSERT INTO test VALUES (true)") count, err := res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 1 { dbt.Fatalf("expected 1 affected row, got %d", count) } id, err := res.LastInsertId() if err != nil { dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error()) } if id != -1 { dbt.Fatalf( "expected InsertId -1, got %d. 
Snowflake doesn't support last insert ID", id) } // Read rows = dbt.mustQuery("SELECT value FROM test") defer func(rows *RowsExtended) { assertNilF(t, rows.Close()) }(rows) if rows.Next() { assertNilF(t, rows.Scan(&out)) if !out { dbt.Errorf("%t should be true", out) } if rows.Next() { dbt.Error("unexpected Data") } } else { dbt.Error("no Data") } // Update res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true) count, err = res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 1 { dbt.Fatalf("expected 1 affected row, got %d", count) } // Check Update rows = dbt.mustQuery("SELECT value FROM test") defer func(rows *RowsExtended) { assertNilF(t, rows.Close()) }(rows) if rows.Next() { assertNilF(t, rows.Scan(&out)) if out { dbt.Errorf("%t should be true", out) } if rows.Next() { dbt.Error("unexpected Data") } } else { dbt.Error("no Data") } // Delete res = dbt.mustExec("DELETE FROM test WHERE value = ?", false) count, err = res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 1 { dbt.Fatalf("expected 1 affected row, got %d", count) } // Check for unexpected rows res = dbt.mustExec("DELETE FROM test") count, err = res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 0 { dbt.Fatalf("expected 0 affected row, got %d", count) } }) } func TestInt(t *testing.T) { testInt(t, false) } func testInt(t *testing.T, json bool) { runDBTest(t, func(dbt *DBTest) { types := []string{"INT", "INTEGER"} in := int64(42) var out int64 var rows *RowsExtended // SIGNED for _, v := range types { t.Run(v, func(t *testing.T) { if json { dbt.mustExec(forceJSON) } dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") dbt.mustExec("INSERT INTO test VALUES (?)", in) rows = dbt.mustQuery("SELECT value FROM test") defer func() { assertNilF(t, rows.Close()) }() if rows.Next() { assertNilF(t, 
rows.Scan(&out)) if in != out { dbt.Errorf("%s: %d != %d", v, in, out) } } else { dbt.Errorf("%s: no data", v) } }) } dbt.mustExec("DROP TABLE IF EXISTS test") }) } func TestFloat32(t *testing.T) { testFloat32(t, false) } func testFloat32(t *testing.T, json bool) { runDBTest(t, func(dbt *DBTest) { types := [2]string{"FLOAT", "DOUBLE"} in := float32(42.23) var out float32 var rows *RowsExtended for _, v := range types { t.Run(v, func(t *testing.T) { if json { dbt.mustExec(forceJSON) } dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") dbt.mustExec("INSERT INTO test VALUES (?)", in) rows = dbt.mustQuery("SELECT value FROM test") defer func() { assertNilF(t, rows.Close()) }() if rows.Next() { err := rows.Scan(&out) if err != nil { dbt.Errorf("failed to scan data: %v", err) } if in != out { dbt.Errorf("%s: %g != %g", v, in, out) } } else { dbt.Errorf("%s: no data", v) } }) } dbt.mustExec("DROP TABLE IF EXISTS test") }) } func TestFloat64(t *testing.T) { testFloat64(t, false) } func testFloat64(t *testing.T, json bool) { runDBTest(t, func(dbt *DBTest) { types := [2]string{"FLOAT", "DOUBLE"} expected := 42.23 var out float64 var rows *RowsExtended for _, v := range types { t.Run(v, func(t *testing.T) { if json { dbt.mustExec(forceJSON) } dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") dbt.mustExec("INSERT INTO test VALUES (42.23)") rows = dbt.mustQuery("SELECT value FROM test") defer func() { assertNilF(t, rows.Close()) }() if rows.Next() { assertNilF(t, rows.Scan(&out)) if expected != out { dbt.Errorf("%s: %g != %g", v, expected, out) } } else { dbt.Errorf("%s: no data", v) } }) } dbt.mustExec("DROP TABLE IF EXISTS test") }) } func TestDecfloat(t *testing.T) { runDBTest(t, func(dbt *DBTest) { for _, format := range []string{"JSON", "ARROW"} { if format == "JSON" { dbt.mustExecT(t, forceJSON) } else { dbt.mustExecT(t, forceARROW) } for _, higherPrecision := range []bool{false, true} { for _, decfloatMappingEnabled := range []bool{true, false} { 
t.Run(fmt.Sprintf("format=%v,higherPrecision=%v,decfloatMappingEnabled=%v", format, higherPrecision, decfloatMappingEnabled), func(t *testing.T) { for _, tc := range []struct { in string standardPrecisionOutput float64 higherPrecisionOutput string decfloatDisabledOutput string }{ {in: "0", standardPrecisionOutput: 0, higherPrecisionOutput: "0", decfloatDisabledOutput: "0"}, {in: "-1", standardPrecisionOutput: -1, higherPrecisionOutput: "-1", decfloatDisabledOutput: "-1"}, {in: "-1.5", standardPrecisionOutput: -1.5, higherPrecisionOutput: "-1.5", decfloatDisabledOutput: "-1.5"}, {in: "1e1", standardPrecisionOutput: 10, higherPrecisionOutput: "10", decfloatDisabledOutput: "10"}, {in: "1e2", standardPrecisionOutput: 100, higherPrecisionOutput: "100", decfloatDisabledOutput: "100"}, {in: "-2e3", standardPrecisionOutput: -2000, higherPrecisionOutput: "-2000", decfloatDisabledOutput: "-2000"}, {in: "1e100", standardPrecisionOutput: math.Pow10(100), higherPrecisionOutput: "1e+100", decfloatDisabledOutput: "1e100"}, {in: "-1.2345e2", standardPrecisionOutput: -123.45, higherPrecisionOutput: "-123.45", decfloatDisabledOutput: "-123.45"}, {in: "1.23456e2", standardPrecisionOutput: 123.456, higherPrecisionOutput: "123.456", decfloatDisabledOutput: "123.456"}, {in: "-9.87654321E-250", standardPrecisionOutput: -9.876654321 * math.Pow10(-250), higherPrecisionOutput: "-9.87654321e-250", decfloatDisabledOutput: "-9.87654321e-250"}, {in: "1.2345678901234567890123456789012345678e37", standardPrecisionOutput: 12345678901234567525491324606797053952, higherPrecisionOutput: "12345678901234567890123456789012345678", decfloatDisabledOutput: "12345678901234567890123456789012345678"}, // pragma: allowlist secret } { t.Run(tc.in, func(t *testing.T) { ctx := context.Background() if higherPrecision { ctx = WithHigherPrecision(ctx) } if decfloatMappingEnabled { ctx = WithDecfloatMappingEnabled(ctx) } rows := dbt.mustQueryContextT(ctx, t, fmt.Sprintf("SELECT '%v'::DECFLOAT UNION SELECT NULL ORDER 
BY 1", tc.in)) defer rows.Close() rows.mustNext() if !decfloatMappingEnabled { var s string rows.mustScan(&s) if format == "ARROW" { assertEqualF(t, s, strings.ToLower(tc.in)) } else { assertEqualE(t, s, tc.decfloatDisabledOutput) } columnTypes, err := rows.ColumnTypes() assertNilF(t, err) assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[string]()) } else if higherPrecision { var bf *big.Float rows.mustScan(&bf) assertEqualE(t, bf.Text('g', 38), tc.higherPrecisionOutput) columnTypes, err := rows.ColumnTypes() assertNilF(t, err) assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[*big.Float]()) } else { var f float64 rows.mustScan(&f) assertEqualEpsilonE(t, f, tc.standardPrecisionOutput, 0.0001) columnTypes, err := rows.ColumnTypes() assertNilF(t, err) assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[float64]()) } rows.mustNext() if !decfloatMappingEnabled { var s sql.NullString rows.mustScan(&s) assertFalseE(t, s.Valid) } else if higherPrecision { var bf *big.Float rows.mustScan(&bf) assertNilE(t, bf) } else { var f sql.NullFloat64 rows.mustScan(&f) assertFalseE(t, f.Valid) } }) } }) } } } t.Run("Binding simple value", func(t *testing.T) { t.Run("As string", func(t *testing.T) { rows := dbt.mustQueryContextT(context.Background(), t, "SELECT ?::DECFLOAT", DataTypeDecfloat, "1234567890.1234567890123456789012345678") defer rows.Close() rows.mustNext() var s string rows.mustScan(&s) assertEqualE(t, s, "1.2345678901234567890123456789012345678e9") }) t.Run("As float", func(t *testing.T) { rows := dbt.mustQueryContextT(WithDecfloatMappingEnabled(context.Background()), t, "SELECT ?::DECFLOAT", DataTypeDecfloat, 123.45) defer rows.Close() rows.mustNext() var f float64 rows.mustScan(&f) assertEqualE(t, f, 123.45) }) t.Run("As *big.Float", func(t *testing.T) { bfFromString, ok := new(big.Float).SetPrec(127).SetString("1234567890.1234567890123456789012345678") assertTrueF(t, ok) println(bfFromString.Text('g', 40)) rows := 
dbt.mustQueryContextT(WithDecfloatMappingEnabled(WithHigherPrecision(context.Background())), t, "SELECT ?::DECFLOAT", DataTypeDecfloat, bfFromString) defer rows.Close() rows.mustNext() bf := new(big.Float).SetPrec(127) rows.mustScan(&bf) println(bf.Text('g', 40)) assertTrueE(t, bf.Cmp(bfFromString) == 0) }) }) t.Run("Binding array", func(t *testing.T) { bfFromString, ok := new(big.Float).SetPrec(127).SetString("1234567890.1234567890123456789012345678") assertTrueF(t, ok) arrays := []any{ mustArray([]string{"123.45", "1234567890.1234567890123456789012345678"}, DataTypeDecfloat), mustArray([]float64{123.45, 1234567890.1234567890123456789012345678}, DataTypeDecfloat), mustArray([]*big.Float{ new(big.Float).SetFloat64(123.45), bfFromString, }, DataTypeDecfloat), } for _, bulk := range []bool{false, true} { for idx, arr := range arrays { t.Run(fmt.Sprintf("bulk=%v, idx=%v", bulk, idx), func(t *testing.T) { if bulk { dbt.mustExecT(t, "ALTER SESSION SET CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = 1") } else { dbt.mustExecT(t, "ALTER SESSION SET CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = 100") } dbt.mustExecT(t, "CREATE OR REPLACE TABLE test_decfloat (value DECFLOAT)") defer dbt.mustExecT(t, "DROP TABLE IF EXISTS test_decfloat") _ = dbt.mustExecT(t, "INSERT INTO test_decfloat VALUES (?)", arr) rows := dbt.mustQueryT(t, "SELECT value FROM test_decfloat ORDER BY 1") defer rows.Close() rows.mustNext() var f float64 rows.mustScan(&f) assertEqualEpsilonE(t, f, 123.45, 0.01) rows.mustNext() if idx != 1 { // float64 cannot be bound with the full precision var s string rows.mustScan(&s) assertEqualE(t, s, "1.2345678901234567890123456789012345678e9") } else { rows.mustScan(&f) assertEqualEpsilonE(t, f, 1234567890.1234567890123456789012345678, 0.01) } }) } } }) }) } func TestString(t *testing.T) { testString(t, false) } func testString(t *testing.T, json bool) { runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } types := []string{"CHAR(255)", "VARCHAR(255)", "TEXT", 
"STRING"} in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" var out string var rows *RowsExtended for _, v := range types { t.Run(v, func(t *testing.T) { dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") dbt.mustExec("INSERT INTO test VALUES (?)", in) rows = dbt.mustQuery("SELECT value FROM test") defer func() { assertNilF(t, rows.Close()) }() if rows.Next() { assertNilF(t, rows.Scan(&out)) if in != out { dbt.Errorf("%s: %s != %s", v, in, out) } } else { dbt.Errorf("%s: no data", v) } }) } dbt.mustExec("DROP TABLE IF EXISTS test") // BLOB (Snowflake doesn't support BLOB type but STRING covers large text data) dbt.mustExec("CREATE OR REPLACE TABLE test (id int, value STRING)") id := 2 in = `Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. 
Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet.` dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) if err := dbt.conn.QueryRowContext(context.Background(), "SELECT value FROM test WHERE id = ?", id).Scan(&out); err != nil { dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) } else if out != in { dbt.Errorf("BLOB: %s != %s", in, out) } }) } /** TESTING TYPES **/ // testUUID is a wrapper around UUID for unit testing purposes and should not be used in production type testUUID struct { UUID } func newTestUUID() testUUID { return testUUID{NewUUID()} } func parseTestUUID(str string) testUUID { if str == "" { return testUUID{} } return testUUID{ParseUUID(str)} } // Scan implements sql.Scanner so UUIDs can be read from databases transparently. // Currently, database types that map to string and []byte are supported. Please // consult database-specific driver documentation for matching types. func (uuid *testUUID) Scan(src any) error { switch src := src.(type) { case nil: return nil case string: // if an empty UUID comes from a table, we return a null UUID if src == "" { return nil } // see Parse for required string format u := ParseUUID(src) *uuid = testUUID{u} case []byte: // if an empty UUID comes from a table, we return a null UUID if len(src) == 0 { return nil } // assumes a simple slice of bytes if 16 bytes // otherwise attempts to parse if len(src) != 16 { return uuid.Scan(string(src)) } copy((uuid.UUID)[:], src) default: return fmt.Errorf("Scan: unable to scan type %T into UUID", src) } return nil } // Value implements sql.Valuer so that UUIDs can be written to databases // transparently. Currently, UUIDs map to strings. Please consult // database-specific driver documentation for matching types. 
func (uuid testUUID) Value() (driver.Value, error) { return uuid.String(), nil } func TestUUID(t *testing.T) { t.Run("JSON", func(t *testing.T) { testUUIDWithFormat(t, true, false) }) t.Run("Arrow", func(t *testing.T) { testUUIDWithFormat(t, false, true) }) } func testUUIDWithFormat(t *testing.T, json, arrow bool) { runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } else if arrow { dbt.mustExec(forceARROW) } types := []string{"CHAR(255)", "VARCHAR(255)", "TEXT", "STRING"} in := make([]testUUID, len(types)) for i := range types { in[i] = newTestUUID() } for i, v := range types { t.Run(v, func(t *testing.T) { dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") dbt.mustExec("INSERT INTO test VALUES (?)", in[i]) rows := dbt.mustQuery("SELECT value FROM test") defer func() { assertNilF(t, rows.Close()) }() if rows.Next() { var out testUUID assertNilF(t, rows.Scan(&out)) if in[i] != out { dbt.Errorf("%s: %s != %s", v, in, out) } } else { dbt.Errorf("%s: no data", v) } }) } dbt.mustExec("DROP TABLE IF EXISTS test") }) } type tcDateTimeTimestamp struct { dbtype string tlayout string tests []timeTest } type timeTest struct { s string // source date time string t time.Time // expected fetched data } func (tt timeTest) genQuery() string { return "SELECT '%s'::%s" } func (tt timeTest) run(t *testing.T, dbt *DBTest, dbtype, tlayout string) { var rows *RowsExtended query := fmt.Sprintf(tt.genQuery(), tt.s, dbtype) rows = dbt.mustQuery(query) defer rows.Close() var err error if !rows.Next() { err = rows.Err() if err == nil { err = fmt.Errorf("no data") } dbt.Errorf("%s: %s", dbtype, err) return } var dst any if err = rows.Scan(&dst); err != nil { dbt.Errorf("%s: %s", dbtype, err) return } switch val := dst.(type) { case []uint8: str := string(val) if str == tt.s { return } dbt.Errorf("%s to string: expected %q, got %q", dbtype, tt.s, str, ) case time.Time: if val.UnixNano() == tt.t.UnixNano() { return } t.Logf("source:%v, expected: %v, got:%v", tt.s, 
tt.t, val) dbt.Errorf("%s to string: expected %q, got %q", dbtype, tt.s, val.Format(tlayout), ) default: dbt.Errorf("%s: unhandled type %T (is '%v')", dbtype, val, val, ) } } func TestSimpleDateTimeTimestampFetch(t *testing.T) { testSimpleDateTimeTimestampFetch(t, false) } func testSimpleDateTimeTimestampFetch(t *testing.T, json bool) { var scan = func(rows *RowsExtended, cd any, ct any, cts any) { if err := rows.Scan(cd, ct, cts); err != nil { t.Fatal(err) } } var fetchTypes = []func(*RowsExtended){ func(rows *RowsExtended) { var cd, ct, cts time.Time scan(rows, &cd, &ct, &cts) }, func(rows *RowsExtended) { var cd, ct, cts time.Time scan(rows, &cd, &ct, &cts) }, } runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } for _, f := range fetchTypes { rows := dbt.mustQuery("SELECT CURRENT_DATE(), CURRENT_TIME(), CURRENT_TIMESTAMP()") defer rows.Close() if rows.Next() { f(rows) } else { t.Fatal("no results") } } }) } func TestDateTime(t *testing.T) { testDateTime(t, false) } func testDateTime(t *testing.T, json bool) { afterTime := func(t time.Time, d string) time.Time { dur, err := time.ParseDuration(d) if err != nil { panic(err) } return t.Add(dur) } t0 := time.Time{} tstr0 := "0000-00-00 00:00:00.000000000" testcases := []tcDateTimeTimestamp{ {"DATE", format[:10], []timeTest{ {t: time.Date(2011, 11, 20, 0, 0, 0, 0, time.UTC)}, {t: time.Date(2, 8, 2, 0, 0, 0, 0, time.UTC), s: "0002-08-02"}, }}, {"TIME", format[11:19], []timeTest{ {t: afterTime(t0, "12345s")}, {t: t0, s: tstr0[11:19]}, }}, {"TIME(0)", format[11:19], []timeTest{ {t: afterTime(t0, "12345s")}, {t: t0, s: tstr0[11:19]}, }}, {"TIME(1)", format[11:21], []timeTest{ {t: afterTime(t0, "12345600ms")}, {t: t0, s: tstr0[11:21]}, }}, {"TIME(6)", format[11:], []timeTest{ {t: t0, s: tstr0[11:]}, }}, {"DATETIME", format[:19], []timeTest{ {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, }}, {"DATETIME(0)", format[:21], []timeTest{ {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, }}, 
{"DATETIME(1)", format[:21], []timeTest{ {t: time.Date(2011, 11, 20, 21, 27, 37, 100000000, time.UTC)}, }}, {"DATETIME(6)", format, []timeTest{ {t: time.Date(2011, 11, 20, 21, 27, 37, 123456000, time.UTC)}, }}, {"DATETIME(9)", format, []timeTest{ {t: time.Date(2011, 11, 20, 21, 27, 37, 123456789, time.UTC)}, }}, } runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } for _, setups := range testcases { t.Run(setups.dbtype, func(t *testing.T) { for _, setup := range setups.tests { if setup.s == "" { // fill time string wherever Go can reliable produce it setup.s = setup.t.Format(setups.tlayout) } setup.run(t, dbt, setups.dbtype, setups.tlayout) } }) } }) } func TestTimestampLTZ(t *testing.T) { testTimestampLTZ(t, false) } func testTimestampLTZ(t *testing.T, json bool) { // Set session time zone in Los Angeles, same as machine createDSN(PSTLocation) location, err := time.LoadLocation(PSTLocation) if err != nil { t.Error(err) } testcases := []tcDateTimeTimestamp{ { dbtype: "TIMESTAMP_LTZ(9)", tlayout: format, tests: []timeTest{ { s: "2016-12-30 05:02:03", t: time.Date(2016, 12, 30, 5, 2, 3, 0, location), }, { s: "2016-12-30 05:02:03 -00:00", t: time.Date(2016, 12, 30, 5, 2, 3, 0, time.UTC), }, { s: "2017-05-12 00:51:42", t: time.Date(2017, 5, 12, 0, 51, 42, 0, location), }, { s: "2017-03-12 01:00:00", t: time.Date(2017, 3, 12, 1, 0, 0, 0, location), }, { s: "2017-03-13 04:00:00", t: time.Date(2017, 3, 13, 4, 0, 0, 0, location), }, { s: "2017-03-13 04:00:00.123456789", t: time.Date(2017, 3, 13, 4, 0, 0, 123456789, location), }, }, }, { dbtype: "TIMESTAMP_LTZ(8)", tlayout: format, tests: []timeTest{ { s: "2017-03-13 04:00:00.123456789", t: time.Date(2017, 3, 13, 4, 0, 0, 123456780, location), }, }, }, } runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } for _, setups := range testcases { t.Run(setups.dbtype, func(t *testing.T) { for _, setup := range setups.tests { if setup.s == "" { // fill time string wherever Go can reliable produce it 
						setup.s = setup.t.Format(setups.tlayout)
					}
					setup.run(t, dbt, setups.dbtype, setups.tlayout)
				}
			})
		}
	})
	// Revert timezone to UTC, which is default for the test suite
	createDSN("UTC")
}

func TestTimestampTZ(t *testing.T) {
	testTimestampTZ(t, false)
}

// testTimestampTZ verifies TIMESTAMP_TZ values carrying explicit UTC offsets.
func testTimestampTZ(t *testing.T, json bool) {
	// sflo builds a fixed-offset location; falls back to UTC on parse failure
	sflo := func(offsets string) (loc *time.Location) {
		r, err := LocationWithOffsetString(offsets)
		if err != nil {
			return time.UTC
		}
		return r
	}
	testcases := []tcDateTimeTimestamp{
		{
			dbtype:  "TIMESTAMP_TZ(9)",
			tlayout: format,
			tests: []timeTest{
				{
					s: "2016-12-30 05:02:03 +07:00",
					t: time.Date(2016, 12, 30, 5, 2, 3, 0, sflo("+0700")),
				},
				{
					s: "2017-05-23 03:56:41 -09:00",
					t: time.Date(2017, 5, 23, 3, 56, 41, 0, sflo("-0900")),
				},
			},
		},
	}
	runDBTest(t, func(dbt *DBTest) {
		if json {
			dbt.mustExec(forceJSON)
		}
		for _, setups := range testcases {
			t.Run(setups.dbtype, func(t *testing.T) {
				for _, setup := range setups.tests {
					if setup.s == "" {
						// fill time string wherever Go can reliably produce it
						setup.s = setup.t.Format(setups.tlayout)
					}
					setup.run(t, dbt, setups.dbtype, setups.tlayout)
				}
			})
		}
	})
}

func TestNULL(t *testing.T) {
	testNULL(t, false)
}

// testNULL exercises NULL handling through every sql.Null* wrapper and []byte,
// in both directions (scanning NULLs out and binding nils in).
func testNULL(t *testing.T, json bool) {
	runDBTest(t, func(dbt *DBTest) {
		if json {
			dbt.mustExec(forceJSON)
		}
		nullStmt, err := dbt.conn.PrepareContext(context.Background(), "SELECT NULL")
		if err != nil {
			dbt.Fatal(err)
		}
		defer nullStmt.Close()
		nonNullStmt, err := dbt.conn.PrepareContext(context.Background(), "SELECT 1")
		if err != nil {
			dbt.Fatal(err)
		}
		defer nonNullStmt.Close()
		// NullBool
		var nb sql.NullBool
		// Invalid
		if err = nullStmt.QueryRow().Scan(&nb); err != nil {
			dbt.Fatal(err)
		}
		if nb.Valid {
			dbt.Error("valid NullBool which should be invalid")
		}
		// Valid
		if err = nonNullStmt.QueryRow().Scan(&nb); err != nil {
			dbt.Fatal(err)
		}
		if !nb.Valid {
			dbt.Error("invalid NullBool which should be valid")
		} else if !nb.Bool {
			dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool)
		}
		// NullFloat64
		var nf sql.NullFloat64
		// Invalid
		if err = nullStmt.QueryRow().Scan(&nf); err != nil {
			dbt.Fatal(err)
		}
		if nf.Valid {
			dbt.Error("valid NullFloat64 which should be invalid")
		}
		// Valid
		if err = nonNullStmt.QueryRow().Scan(&nf); err != nil {
			dbt.Fatal(err)
		}
		if !nf.Valid {
			dbt.Error("invalid NullFloat64 which should be valid")
		} else if nf.Float64 != float64(1) {
			dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64)
		}
		// NullInt64
		var ni sql.NullInt64
		// Invalid
		if err = nullStmt.QueryRow().Scan(&ni); err != nil {
			dbt.Fatal(err)
		}
		if ni.Valid {
			dbt.Error("valid NullInt64 which should be invalid")
		}
		// Valid
		if err = nonNullStmt.QueryRow().Scan(&ni); err != nil {
			dbt.Fatal(err)
		}
		if !ni.Valid {
			dbt.Error("invalid NullInt64 which should be valid")
		} else if ni.Int64 != int64(1) {
			dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64)
		}
		// NullString
		var ns sql.NullString
		// Invalid
		if err = nullStmt.QueryRow().Scan(&ns); err != nil {
			dbt.Fatal(err)
		}
		if ns.Valid {
			dbt.Error("valid NullString which should be invalid")
		}
		// Valid
		if err = nonNullStmt.QueryRow().Scan(&ns); err != nil {
			dbt.Fatal(err)
		}
		if !ns.Valid {
			dbt.Error("invalid NullString which should be valid")
		} else if ns.String != `1` {
			dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)")
		}
		// nil-bytes
		var b []byte
		// Read nil
		if err = nullStmt.QueryRow().Scan(&b); err != nil {
			dbt.Fatal(err)
		}
		if b != nil {
			dbt.Error("non-nil []byte which should be nil")
		}
		// Read non-nil
		if err = nonNullStmt.QueryRow().Scan(&b); err != nil {
			dbt.Fatal(err)
		}
		if b == nil {
			dbt.Error("nil []byte which should be non-nil")
		}
		// Insert nil
		b = nil
		success := false
		if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ? IS NULL", b).Scan(&success); err != nil {
			dbt.Fatal(err)
		}
		if !success {
			dbt.Error("inserting []byte(nil) as NULL failed")
			t.Fatal("stopping")
		}
		// Check input==output with input==nil
		b = nil
		if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ?", b).Scan(&b); err != nil {
			dbt.Fatal(err)
		}
		if b != nil {
			dbt.Error("non-nil echo from nil input")
		}
		// Check input==output with input!=nil
		b = []byte("")
		if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ?", b).Scan(&b); err != nil {
			dbt.Fatal(err)
		}
		if b == nil {
			dbt.Error("nil echo from non-nil input")
		}
		// Insert NULL
		// NOTE(review): "dummmy1" is a long-standing typo in the column name; the
		// column is never referenced by name, so it is left unchanged here.
		dbt.mustExec("CREATE OR REPLACE TABLE test (dummmy1 int, value int, dummy2 int)")
		dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2)
		var dummy1, out, dummy2 any
		rows := dbt.mustQuery("SELECT * FROM test")
		defer func() {
			assertNilF(t, rows.Close())
		}()
		if rows.Next() {
			assertNilF(t, rows.Scan(&dummy1, &out, &dummy2))
			if out != nil {
				dbt.Errorf("%v != nil", out)
			}
		} else {
			dbt.Error("no data")
		}
	})
}

func TestVariant(t *testing.T) {
	testVariant(t, false)
}

// testVariant checks that a VARIANT value produced by parse_json scans into a string.
func testVariant(t *testing.T, json bool) {
	runDBTest(t, func(dbt *DBTest) {
		if json {
			dbt.mustExec(forceJSON)
		}
		rows := dbt.mustQuery(`select parse_json('[{"id":1, "name":"test1"},{"id":2, "name":"test2"}]')`)
		defer rows.Close()
		var v string
		if rows.Next() {
			if err := rows.Scan(&v); err != nil {
				t.Fatal(err)
			}
		} else {
			t.Fatal("no rows")
		}
	})
}

func TestArray(t *testing.T) {
	testArray(t, false)
}

// testArray checks that an ARRAY value produced by as_array scans into a string.
func testArray(t *testing.T, json bool) {
	runDBTest(t, func(dbt *DBTest) {
		if json {
			dbt.mustExec(forceJSON)
		}
		rows := dbt.mustQuery(`select as_array(parse_json('[{"id":1, "name":"test1"},{"id":2, "name":"test2"}]'))`)
		defer rows.Close()
		var v string
		if rows.Next() {
			if err := rows.Scan(&v); err != nil {
				t.Fatal(err)
			}
		} else {
			t.Fatal("no rows")
		}
	})
}

func TestLargeSetResult(t *testing.T) {
	customJSONDecoderEnabled = false
	testLargeSetResult(t, 100000, false)
}

// testLargeSetResult fetches numrows generated rows and checks none are dropped.
func testLargeSetResult(t *testing.T, numrows int, json bool) {
	runDBTest(t, func(dbt *DBTest) {
		if json {
			dbt.mustExec(forceJSON)
		}
		rows := dbt.mustQuery(fmt.Sprintf(selectRandomGenerator, numrows))
		defer rows.Close()
		cnt := 0
		var idx int
		var v string
		for rows.Next() {
			if err := rows.Scan(&idx, &v); err != nil {
				t.Fatal(err)
			}
			cnt++
		}
		logger.Infof("NextResultSet: %v", rows.NextResultSet())
		if cnt != numrows {
			dbt.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt)
		}
	})
}

// TestPingpongQuery validates that the driver's ping-pong keepalive protocol
// maintains the connection during long-running queries. TIMELIMIT=>60 must be
// long enough to trigger the ping-pong mechanism. Do not reduce this value.
func TestPingpongQuery(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		numrows := 1
		rows := dbt.mustQuery("SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 60))")
		defer rows.Close()
		cnt := 0
		for rows.Next() {
			cnt++
		}
		if cnt != numrows {
			dbt.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt)
		}
	})
}

// TestDML verifies transactional INSERT visibility: a rolled-back insert leaves
// the table empty, a committed insert persists both rows.
func TestDML(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("CREATE OR REPLACE TABLE test(c1 int, c2 string)")
		if err := insertData(dbt, false); err != nil {
			dbt.Fatalf("failed to insert data: %v", err)
		}
		results, err := queryTest(dbt)
		if err != nil {
			dbt.Fatalf("failed to query test table: %v", err)
		}
		if len(*results) != 0 {
			dbt.Fatalf("number of returned data didn't match. expected 0, got: %v", len(*results))
		}
		if err = insertData(dbt, true); err != nil {
			dbt.Fatalf("failed to insert data: %v", err)
		}
		results, err = queryTest(dbt)
		if err != nil {
			dbt.Fatalf("failed to query test table: %v", err)
		}
		if len(*results) != 2 {
			dbt.Fatalf("number of returned data didn't match. expected 2, got: %v", len(*results))
		}
	})
}

// insertData inserts two rows inside a transaction, verifies they are visible
// within the transaction, then commits or rolls back per the commit flag.
func insertData(dbt *DBTest, commit bool) error {
	tx, err := dbt.conn.BeginTx(context.Background(), nil)
	if err != nil {
		dbt.Fatalf("failed to begin transaction: %v", err)
	}
	res, err := tx.Exec("INSERT INTO test VALUES(1, 'test1'), (2, 'test2')")
	if err != nil {
		dbt.Fatalf("failed to insert value into test: %v", err)
	}
	n, err := res.RowsAffected()
	if err != nil {
		dbt.Fatalf("failed to rows affected: %v", err)
	}
	if n != 2 {
		dbt.Fatalf("failed to insert value into test. expected: 2, got: %v", n)
	}
	results, err := queryTestTx(tx)
	if err != nil {
		dbt.Fatalf("failed to query test table: %v", err)
	}
	if len(*results) != 2 {
		dbt.Fatalf("number of returned data didn't match. expected 2, got: %v", len(*results))
	}
	if commit {
		if err = tx.Commit(); err != nil {
			return err
		}
	} else {
		if err = tx.Rollback(); err != nil {
			return err
		}
	}
	return err
}

// queryTestTx reads the test table within the given transaction into a c1->c2 map.
func queryTestTx(tx *sql.Tx) (*map[int]string, error) {
	var c1 int
	var c2 string
	rows, err := tx.Query("SELECT c1, c2 FROM test")
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	results := make(map[int]string, 2)
	for rows.Next() {
		if err = rows.Scan(&c1, &c2); err != nil {
			return nil, err
		}
		results[c1] = c2
	}
	return &results, nil
}

// queryTest reads the test table outside any transaction into a c1->c2 map.
func queryTest(dbt *DBTest) (*map[int]string, error) {
	var c1 int
	var c2 string
	rows, err := dbt.query("SELECT c1, c2 FROM test")
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	results := make(map[int]string, 2)
	for rows.Next() {
		if err = rows.Scan(&c1, &c2); err != nil {
			return nil, err
		}
		results[c1] = c2
	}
	return &results, nil
}

// TestCancelQuery checks that a context deadline aborts a long server-side wait.
func TestCancelQuery(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		_, err := dbt.conn.QueryContext(ctx, "CALL SYSTEM$WAIT(10, 'SECONDS')")
		if err == nil {
			dbt.Fatal("No timeout error returned")
		}
		if !errors.Is(err, context.DeadlineExceeded) {
			dbt.Fatalf("Timeout error mismatch: expect %v, receive %v", context.DeadlineExceeded,
err.Error()) } }) } func TestCancelQueryWithConnectionContext(t *testing.T) { testCases := []struct { name string setupConnection func(ctx context.Context, db *sql.DB) error }{ { name: "explicit connection", setupConnection: func(ctx context.Context, db *sql.DB) error { _, err := db.Conn(ctx) return err }, }, { name: "implicit connection", setupConnection: func(ctx context.Context, db *sql.DB) error { _, err := db.ExecContext(ctx, "SELECT 1") return err }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { db := openDB(t) defer db.Close() ctx, cancelConnectionContext := context.WithCancel(context.Background()) err := tc.setupConnection(ctx, db) assertNilF(t, err, "connection setup should succeed") cancelConnectionContext() _, err = db.ExecContext(context.Background(), "SELECT 1") assertNilF(t, err, "subsequent SELECT should work after cancelled connection context") cwd, err := os.Getwd() assertNilF(t, err, "Failed to get current working directory") filePath := filepath.Join(cwd, "test_data", "put_get_1.txt") putQuery := fmt.Sprintf("PUT file://%v @~/%v", filePath, "test_cancel_query_with_connection_context.txt") _, err = db.ExecContext(context.Background(), putQuery) assertNilF(t, err, "PUT statement should work after cancelled connection context") }) } } func TestPing(t *testing.T) { runDBTest(t, func(dbt *DBTest) { if err := dbt.conn.PingContext(context.Background()); err != nil { t.Fatalf("failed to ping. err: %v", err) } if err := dbt.conn.PingContext(context.Background()); err != nil { t.Fatalf("failed to ping with context. err: %v", err) } if err := dbt.conn.Close(); err != nil { t.Fatalf("failed to close db. 
err: %v", err) } if err := dbt.conn.PingContext(context.Background()); err == nil { t.Fatal("should have failed to ping") } if err := dbt.conn.PingContext(context.Background()); err == nil { t.Fatal("should have failed to ping with context") } }) } func TestDoubleDollar(t *testing.T) { // no escape is required for dollar signs runDBTest(t, func(dbt *DBTest) { sql := `create or replace function dateErr(I double) returns date language javascript strict as $$ var x = [ 0, "1400000000000", "2013-04-05", [], [1400000000000], "x1234", Number.NaN, null, undefined, {}, [1400000000000,1500000000000] ]; return x[I]; $$ ;` dbt.mustExec(sql) }) } func TestTimezoneSessionParameter(t *testing.T) { createDSN(PSTLocation) runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryT(t, "SHOW PARAMETERS LIKE 'TIMEZONE'") defer rows.Close() if !rows.Next() { t.Fatal("failed to get timezone.") } p, err := ScanSnowflakeParameter(rows.rows) if err != nil { t.Errorf("failed to run get timezone value. err: %v", err) } if p.Value != PSTLocation { t.Errorf("failed to get an expected timezone. got: %v", p.Value) } }) createDSN("UTC") } func TestLargeSetResultCancel(t *testing.T) { runDBTest(t, func(dbt *DBTest) { c := make(chan error) ctx, cancel := context.WithCancel(context.Background()) go func() { // attempt to run a 100 seconds query, but it should be canceled in 1 second timelimit := 100 rows, err := dbt.conn.QueryContext( ctx, fmt.Sprintf("SELECT COUNT(*) FROM TABLE(GENERATOR(timelimit=>%v))", timelimit)) if err != nil { c <- err return } defer rows.Close() c <- nil }() // cancel after 1 second time.Sleep(time.Second) cancel() ret := <-c if !errors.Is(ret, context.Canceled) { t.Fatalf("failed to cancel. 
err: %v", ret) } close(c) }) } func TestValidateDatabaseParameter(t *testing.T) { // Parse the global DSN to get base configuration with proper authentication cfg, err := ParseDSN(dsn) if err != nil { t.Fatal("Failed to parse global dsn") } testcases := []struct { description string dbname string schemaname string params map[string]string errorCode int }{ { description: "invalid_database_and_schema", dbname: "NOT_EXISTS", schemaname: "NOT_EXISTS", errorCode: ErrObjectNotExistOrAuthorized, }, { description: "invalid_schema", dbname: cfg.Database, schemaname: "NOT_EXISTS", errorCode: ErrObjectNotExistOrAuthorized, }, { description: "invalid_warehouse", dbname: cfg.Database, schemaname: cfg.Schema, params: map[string]string{ "warehouse": "NOT_EXIST", }, errorCode: ErrObjectNotExistOrAuthorized, }, { description: "invalid_role", dbname: cfg.Database, schemaname: cfg.Schema, params: map[string]string{ "role": "NOT_EXIST", }, errorCode: ErrRoleNotExist, }, } for idx, tc := range testcases { t.Run(tc.description, func(t *testing.T) { // Create a new config based on the global config (which already has proper authentication) testCfg := *cfg // Copy the config with proper authentication from global DSN testCfg.Database = tc.dbname testCfg.Schema = tc.schemaname // Override with test-specific parameters testCfg.Warehouse = tc.params["warehouse"] testCfg.Role = tc.params["role"] db := sql.OpenDB(NewConnector(SnowflakeDriver{}, testCfg)) defer db.Close() if _, err = db.Exec("SELECT 1"); err == nil { t.Fatal("should cause an error.") } if driverErr, ok := err.(*SnowflakeError); ok { if driverErr.Number != tc.errorCode { maskedErr := maskSecrets(err.Error()) t.Errorf("got unexpected error: %s in test case %d", maskedErr, idx) } } }) } } func TestSpecifyWarehouseDatabase(t *testing.T) { // Parse the global DSN to get base configuration with proper authentication cfg, err := ParseDSN(dsn) if err != nil { t.Fatal("Failed to parse global dsn") } // Override with test-specific 
settings cfg.Warehouse = warehouse db := sql.OpenDB(NewConnector(SnowflakeDriver{}, *cfg)) defer db.Close() if _, err = db.Exec("SELECT 1"); err != nil { maskedErr := maskSecrets(err.Error()) t.Fatalf("failed to execute a select 1: %s", maskedErr) } } func TestFetchNil(t *testing.T) { runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQuery("SELECT * FROM values(3,4),(null, 5) order by 2") defer rows.Close() var c1 sql.NullInt64 var c2 sql.NullInt64 var results []sql.NullInt64 for rows.Next() { if err := rows.Scan(&c1, &c2); err != nil { dbt.Fatal(err) } results = append(results, c1) } if results[1].Valid { t.Errorf("First element of second row must be nil (NULL). %v", results) } }) } func TestPingInvalidHost(t *testing.T) { config := Config{ Account: "NOT_EXISTS", User: "BOGUS_USER", Password: "barbar", LoginTimeout: 10 * time.Second, } testURL, err := DSN(&config) if err != nil { t.Fatalf("failed to parse config. config: %v, err: %v", config, err) } db, err := sql.Open("snowflake", testURL) assertNilF(t, err, "failed to initialize the connection") if err = db.PingContext(context.Background()); err == nil { t.Fatal("should cause an error") } if strings.Contains(err.Error(), "HTTP Status: 513. Hanging?") { return } if driverErr, ok := err.(*SnowflakeError); !ok || ok && isFailToConnectOrAuthErr(driverErr) { // Failed to connect error t.Fatalf("error didn't match") } } func TestOpenWithConfig(t *testing.T) { config := Config{ Account: "testaccount", User: "testuser", Password: "testpassword", Authenticator: AuthTypeSnowflake, // Force password authentication PrivateKey: nil, // Ensure no private key } testURL, err := DSN(&config) if err != nil { t.Fatalf("failed to parse config. config: %v, err: %v", config, err) } db, err := sql.Open("snowflake", testURL) assertNilF(t, err, "failed to initialize the connection") if err = db.PingContext(context.Background()); err == nil { t.Fatal("should cause an error") } if strings.Contains(err.Error(), "HTTP Status: 513. 
Hanging?") { return } if driverErr, ok := err.(*SnowflakeError); !ok || ok && isFailToConnectOrAuthErr(driverErr) { // Failed to connect error t.Fatalf("error didn't match") } } func TestOpenWithConfigCancel(t *testing.T) { wiremock.registerMappings(t, wiremockMapping{filePath: "auth/password/successful_flow_with_telemetry.json", params: map[string]string{"%CLIENT_TELEMETRY_ENABLED%": "true"}}, ) driver := SnowflakeDriver{} config := wiremock.connectionConfig() blockingRoundTripper := newBlockingRoundTripper(createTestNoRevocationTransport(), 0) countingRoundTripper := newCountingRoundTripper(blockingRoundTripper) config.Transporter = countingRoundTripper t.Run("canceled during request:login-request", func(t *testing.T) { blockingRoundTripper.setPathBlockTime("/session/v1/login-request", 50*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) defer cancel() _, err := driver.OpenWithConfig(ctx, *config) assertErrIsE(t, err, context.DeadlineExceeded) assertEqualE(t, countingRoundTripper.totalRequestsByPath("/session/v1/login-request"), 1) assertEqualE(t, countingRoundTripper.totalRequestsByPath("/telemetry/send"), 0) }) t.Run("canceled during request:telemetry/send", func(t *testing.T) { blockingRoundTripper.reset() countingRoundTripper.reset() blockingRoundTripper.setPathBlockTime("/telemetry/send", 400*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() _, err := driver.OpenWithConfig(ctx, *config) assertErrIsE(t, err, context.DeadlineExceeded) assertEqualE(t, countingRoundTripper.totalRequestsByPath("/session/v1/login-request"), 1) assertEqualE(t, countingRoundTripper.totalRequestsByPath("/telemetry/send"), 1) }) } func TestOpenWithInvalidConfig(t *testing.T) { config, err := ParseDSN("u:p@h?tmpDirPath=%2Fnon-existing") if err != nil { t.Fatalf("failed to parse dsn. 
err: %v", err) } config.Authenticator = AuthTypeSnowflake config.PrivateKey = nil driver := SnowflakeDriver{} _, err = driver.OpenWithConfig(context.Background(), *config) if err == nil || !strings.Contains(err.Error(), "/non-existing") { t.Fatalf("should fail on missing directory") } } func TestOpenWithTransport(t *testing.T) { config, err := ParseDSN(dsn) if err != nil { t.Fatalf("failed to parse dsn. err: %v", err) } countingTransport := newCountingRoundTripper(createTestNoRevocationTransport()) var transport http.RoundTripper = countingTransport config.Transporter = transport driver := SnowflakeDriver{} db, err := driver.OpenWithConfig(context.Background(), *config) assertNilF(t, err, fmt.Sprintf("failed to open with config. config: %v", config)) conn := db.(*snowflakeConn) if conn.rest.Client.Transport != transport { t.Fatal("transport doesn't match") } db.Close() if countingTransport.totalRequests() == 0 { t.Fatal("transport did not receive any requests") } // Test that transport override also works in OCSP checks disabled. countingTransport.reset() config.DisableOCSPChecks = true db, err = driver.OpenWithConfig(context.Background(), *config) assertNilF(t, err, fmt.Sprintf("failed to open with config. config: %v", config)) conn = db.(*snowflakeConn) if conn.rest.Client.Transport != transport { t.Fatal("transport doesn't match") } db.Close() if countingTransport.totalRequests() == 0 { t.Fatal("transport did not receive any requests") } } func TestClientSessionKeepAliveParameter(t *testing.T) { // This test doesn't really validate the CLIENT_SESSION_KEEP_ALIVE functionality but simply checks // the session parameter. 
customDsn := dsn + "&client_session_keep_alive=true" runDBTestWithConfig(t, &testConfig{dsn: customDsn}, func(dbt *DBTest) { rows := dbt.mustQuery("SHOW PARAMETERS LIKE 'CLIENT_SESSION_KEEP_ALIVE'") defer rows.Close() if !rows.Next() { t.Fatal("failed to get timezone.") } p, err := ScanSnowflakeParameter(rows.rows) assertNilF(t, err, "failed to run get client_session_keep_alive value") if p.Value != "true" { t.Fatalf("failed to get an expected client_session_keep_alive. got: %v", maskSecrets(p.Value)) } rows2 := dbt.mustQuery("select count(*) from table(generator(timelimit=>30))") defer rows2.Close() }) } func TestTimePrecision(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.mustExec("create or replace table z3 (t1 time(5))") rows := dbt.mustQuery("select * from z3") defer rows.Close() cols, err := rows.ColumnTypes() assertNilE(t, err, "failed to get column types") if pres, _, ok := cols[0].DecimalSize(); pres != 5 || !ok { t.Fatalf("Wrong value returned. Got %v instead of 5.", pres) } }) } func initPoolWithSize(t *testing.T, db *sql.DB, poolSize int) { wg := sync.WaitGroup{} wg.Add(poolSize) for range poolSize { go func(wg *sync.WaitGroup) { defer wg.Done() time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond) runSmokeQuery(t, db) }(&wg) } wg.Wait() } func initPoolWithSizeAndReturnErrors(db *sql.DB, poolSize int) []error { wg := sync.WaitGroup{} wg.Add(poolSize) errMu := sync.Mutex{} var errs []error for i := range poolSize { go func(wg *sync.WaitGroup) { defer wg.Done() // Wiremock handles incoming request in parallel, in non-atomic way. // If two requests start at the same time, they both see the same scenario state, // even if it should be changed after the request is matched to a particular scenario state. 
time.Sleep(time.Duration(i * 5 * int(time.Millisecond))) err := runSmokeQueryAndReturnErrors(db) if err != nil { errMu.Lock() errs = append(errs, err) errMu.Unlock() } }(&wg) } wg.Wait() return errs } func runSelectCurrentUser(t *testing.T, db *sql.DB) string { rows, err := db.Query("SELECT current_user()") assertNilF(t, err) defer rows.Close() assertTrueF(t, rows.Next()) var v string err = rows.Scan(&v) assertNilF(t, err) return v } func runSmokeQuery(t *testing.T, db *sql.DB) { rows, err := db.Query("SELECT 1") assertNilF(t, err) defer rows.Close() assertTrueF(t, rows.Next()) var v int err = rows.Scan(&v) assertNilF(t, err) assertEqualE(t, v, 1) } func runSmokeQueryAndReturnErrors(db *sql.DB) error { rows, err := db.Query("SELECT 1") if err != nil { return err } defer rows.Close() if !rows.Next() { return fmt.Errorf("no rows") } var v int err = rows.Scan(&v) if err != nil { return err } if v != 1 { return fmt.Errorf("value mismatch. expected 1, got %v", v) } return nil } func runSmokeQueryWithConn(t *testing.T, conn *sql.Conn) { rows, err := conn.QueryContext(context.Background(), "SELECT 1") assertNilF(t, err) defer rows.Close() assertTrueF(t, rows.Next()) var v int err = rows.Scan(&v) assertNilF(t, err) assertEqualE(t, v, 1) } ================================================ FILE: dsn.go ================================================ package gosnowflake import ( sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config" ) // Type aliases — re-exported from internal/config for backward compatibility. type ( // Config is a set of configuration parameters Config = sfconfig.Config // ConfigBool is a type to represent true or false in the Config ConfigBool = sfconfig.Bool // ConfigParam is used to bind the name of the Config field with the environment variable and set the requirement for it ConfigParam = sfconfig.Param ) // ConfigBool constants — re-exported from internal/config. 
const ( // configBoolNotSet represents the default value for the config field which is not set configBoolNotSet = sfconfig.BoolNotSet // ConfigBoolTrue represents true for the config field ConfigBoolTrue = sfconfig.BoolTrue // ConfigBoolFalse represents false for the config field ConfigBoolFalse = sfconfig.BoolFalse ) // DSN constructs a DSN for Snowflake db. func DSN(cfg *Config) (string, error) { return sfconfig.DSN(cfg) } // ParseDSN parses the DSN string to a Config. func ParseDSN(dsn string) (*Config, error) { return sfconfig.ParseDSN(dsn) } // GetConfigFromEnv is used to parse the environment variable values to specific fields of the Config func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) { return sfconfig.GetConfigFromEnv(properties) } func transportConfigFor(tt transportType) *transportConfig { return defaultTransportConfigs.forTransportType(tt) } ================================================ FILE: easy_logging.go ================================================ package gosnowflake import ( "errors" "fmt" errors2 "github.com/snowflakedb/gosnowflake/v2/internal/errors" "io" "os" "path" "runtime" "strings" "sync" loggerinternal "github.com/snowflakedb/gosnowflake/v2/internal/logger" ) type initTrials struct { everTriedToInitialize bool clientConfigFileInput string configureCounter int mu sync.Mutex } var easyLoggingInitTrials = initTrials{ everTriedToInitialize: false, clientConfigFileInput: "", configureCounter: 0, mu: sync.Mutex{}, } func (i *initTrials) setInitTrial(clientConfigFileInput string) { i.everTriedToInitialize = true i.clientConfigFileInput = clientConfigFileInput } func (i *initTrials) increaseReconfigureCounter() { i.configureCounter++ } func initEasyLogging(clientConfigFileInput string) error { easyLoggingInitTrials.mu.Lock() defer easyLoggingInitTrials.mu.Unlock() if !allowedToInitialize(clientConfigFileInput) { logger.Info("Skipping Easy Logging initialization as it is not allowed to initialize") return nil } 
logger.Infof("Trying to initialize Easy Logging") config, configPath, err := getClientConfig(clientConfigFileInput) if err != nil { logger.Errorf("Failed to initialize Easy Logging, err: %s", err) return easyLoggingInitError(err) } if config == nil { logger.Info("Easy Logging is disabled as no config has been found") easyLoggingInitTrials.setInitTrial(clientConfigFileInput) return nil } var logLevel string logLevel, err = getLogLevel(config.Common.LogLevel) if err != nil { logger.Errorf("Failed to initialize Easy Logging, err: %s", err) return easyLoggingInitError(err) } var logPath string logPath, err = getLogPath(config.Common.LogPath) if err != nil { logger.Errorf("Failed to initialize Easy Logging, err: %s", err) return easyLoggingInitError(err) } logger.Infof("Initializing Easy Logging with logPath=%s and logLevel=%s from file: %s", logPath, logLevel, configPath) err = reconfigureEasyLogging(logLevel, logPath) if err != nil { logger.Errorf("Failed to initialize Easy Logging, err: %s", err) } easyLoggingInitTrials.setInitTrial(clientConfigFileInput) easyLoggingInitTrials.increaseReconfigureCounter() return err } func easyLoggingInitError(err error) error { return &SnowflakeError{ Number: ErrCodeClientConfigFailed, Message: errors2.ErrMsgClientConfigFailed, MessageArgs: []any{err.Error()}, } } func reconfigureEasyLogging(logLevel string, logPath string) error { // don't allow any change if a non-default logger is already being used. 
currentLogger := GetLogger() if !loggerinternal.IsEasyLoggingLogger(currentLogger) { logger.Warnf("Cannot reconfigure easy logging: custom logger is in use") return nil // cannot replace custom logger } newLogger := CreateDefaultLogger() err := newLogger.SetLogLevel(logLevel) if err != nil { return err } var output io.Writer var file *os.File output, file, err = createLogWriter(logPath) if err != nil { return err } newLogger.SetOutput(output) err = loggerinternal.CloseFileOnLoggerReplace(newLogger, file) if err != nil { logger.Errorf("%s", err) } // Actually set the new logger as the global logger if err := SetLogger(newLogger); err != nil { logger.Errorf("Failed to set new logger: %s", err) return err } return nil } func createLogWriter(logPath string) (io.Writer, *os.File, error) { if strings.EqualFold(logPath, "STDOUT") { return os.Stdout, nil, nil } logFileName := path.Join(logPath, "snowflake.log") file, err := os.OpenFile(logFileName, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0640) if err != nil { return nil, nil, err } return file, file, nil } func allowedToInitialize(clientConfigFileInput string) bool { triedToInitializeWithoutConfigFile := easyLoggingInitTrials.everTriedToInitialize && easyLoggingInitTrials.clientConfigFileInput == "" isAllowedToInitialize := !easyLoggingInitTrials.everTriedToInitialize || (triedToInitializeWithoutConfigFile && clientConfigFileInput != "") if !isAllowedToInitialize && easyLoggingInitTrials.clientConfigFileInput != clientConfigFileInput { logger.Warnf("Easy logging will not be configured for CLIENT_CONFIG_FILE=%s because it was previously configured for a different client config", clientConfigFileInput) } return isAllowedToInitialize } func getLogLevel(logLevel string) (string, error) { if logLevel == "" { logger.Warn("LogLevel in client config not found. 
Using default value: OFF") return levelOff, nil } return toLogLevel(logLevel) } func getLogPath(logPath string) (string, error) { logPathOrDefault := logPath if logPath == "" { homeDir, err := os.UserHomeDir() if err != nil { return "", fmt.Errorf("user home directory is not accessible, err: %w", err) } logPathOrDefault = homeDir logger.Warnf("LogPath in client config not found. Using user home directory as a default value: %s", logPathOrDefault) } pathWithGoSubdir := path.Join(logPathOrDefault, "go") exists, err := dirExists(pathWithGoSubdir) if err != nil { return "", err } if !exists { err = os.MkdirAll(pathWithGoSubdir, 0700) if err != nil { return "", err } } logDirPermValid, perm, err := isDirAccessCorrect(pathWithGoSubdir) if err != nil { return "", err } if !logDirPermValid { logger.Warnf("Log directory: %s could potentially be accessed by others. Directory chmod: 0%o", pathWithGoSubdir, *perm) } return pathWithGoSubdir, nil } func isDirAccessCorrect(dirPath string) (bool, *os.FileMode, error) { if runtime.GOOS == "windows" { return true, nil, nil } dirStat, err := os.Stat(dirPath) if err != nil { return false, nil, err } perm := dirStat.Mode().Perm() if perm != 0700 { return false, &perm, nil } return true, &perm, nil } func dirExists(dirPath string) (bool, error) { stat, err := os.Stat(dirPath) if err == nil { return stat.IsDir(), nil } if errors.Is(err, os.ErrNotExist) { return false, nil } return false, err } ================================================ FILE: easy_logging_test.go ================================================ package gosnowflake import ( "context" "fmt" "os" "path" "path/filepath" "strings" "sync" "testing" loggerinternal "github.com/snowflakedb/gosnowflake/v2/internal/logger" ) func TestInitializeEasyLoggingOnlyOnceWhenConfigGivenAsAParameter(t *testing.T) { skipOnWindows(t, "Doesn't work on Windows") defer cleanUp() origLogLevel := logger.GetLogLevel() defer logger.SetLogLevel(origLogLevel) logger.SetLogLevel("error") logDir := 
t.TempDir() logLevel := levelError contents := createClientConfigContent(logLevel, logDir) configFilePath := createFile(t, "config.json", contents, logDir) easyLoggingInitTrials.reset() err := openWithClientConfigFile(t, configFilePath) assertNilF(t, err, "open config error") assertEqualE(t, toClientConfigLevel(logger.GetLogLevel()), logLevel, "error log level check") assertEqualE(t, easyLoggingInitTrials.configureCounter, 1) err = openWithClientConfigFile(t, "") assertNilF(t, err, "open config error") err = openWithClientConfigFile(t, configFilePath) assertNilF(t, err, "open config error") err = openWithClientConfigFile(t, "/another-config.json") assertNilF(t, err, "open config error") assertEqualE(t, toClientConfigLevel(logger.GetLogLevel()), logLevel, "error log level check") assertEqualE(t, easyLoggingInitTrials.configureCounter, 1) } func TestConfigureEasyLoggingOnlyOnceWhenInitializedWithoutConfigFilePath(t *testing.T) { skipOnWindows(t, "Doesn't work on Windows") skipOnMissingHome(t) origLogLevel := logger.GetLogLevel() defer logger.SetLogLevel(origLogLevel) logger.SetLogLevel("error") appExe, err := os.Executable() assertNilF(t, err, "application exe not accessible") userHome, err := os.UserHomeDir() assertNilF(t, err, "user home directory not accessible") testcases := []struct { name string dir string }{ { name: "user home directory", dir: userHome, }, { name: "application directory", dir: filepath.Dir(appExe), }, } for _, test := range testcases { t.Run(test.name, func(t *testing.T) { defer cleanUp() logDir := t.TempDir() assertNilF(t, err, "user home directory error") logLevel := levelError contents := createClientConfigContent(logLevel, logDir) configFilePath := createFile(t, defaultConfigName, contents, test.dir) defer os.Remove(configFilePath) easyLoggingInitTrials.reset() err = openWithClientConfigFile(t, "") assertNilF(t, err, "open config error") err = openWithClientConfigFile(t, "") assertNilF(t, err, "open config error") assertEqualE(t, 
toClientConfigLevel(logger.GetLogLevel()), logLevel, "error log level check") assertEqualE(t, easyLoggingInitTrials.configureCounter, 1) }) } } func TestReconfigureEasyLoggingIfConfigPathWasNotGivenForTheFirstTime(t *testing.T) { skipOnWindows(t, "Doesn't work on Windows") skipOnMissingHome(t) defer cleanUp() origLogLevel := logger.GetLogLevel() defer logger.SetLogLevel(origLogLevel) logger.SetLogLevel("error") configDir, err := os.UserHomeDir() logDir := t.TempDir() assertNilF(t, err, "user home directory error") homeConfigLogLevel := levelError homeConfigContent := createClientConfigContent(homeConfigLogLevel, logDir) homeConfigFilePath := createFile(t, defaultConfigName, homeConfigContent, configDir) defer os.Remove(homeConfigFilePath) customLogLevel := levelWarn customFileContent := createClientConfigContent(customLogLevel, logDir) customConfigFilePath := createFile(t, "config.json", customFileContent, configDir) easyLoggingInitTrials.reset() err = openWithClientConfigFile(t, "") logger.Error("Error message") assertNilF(t, err, "open config error") assertEqualE(t, toClientConfigLevel(logger.GetLogLevel()), homeConfigLogLevel, "tmp dir log level check") assertEqualE(t, easyLoggingInitTrials.configureCounter, 1) err = openWithClientConfigFile(t, customConfigFilePath) logger.Error("Warning message") assertNilF(t, err, "open config error") assertEqualE(t, toClientConfigLevel(logger.GetLogLevel()), customLogLevel, "custom dir log level check") assertEqualE(t, easyLoggingInitTrials.configureCounter, 2) var logContents []byte logContents, err = os.ReadFile(path.Join(logDir, "go", "snowflake.log")) assertNilF(t, err, "read file error") logs := notEmptyLines(string(logContents)) assertEqualE(t, len(logs), 2, "number of logs") } func TestEasyLoggingFailOnUnknownLevel(t *testing.T) { defer cleanUp() dir := t.TempDir() easyLoggingInitTrials.reset() configContent := createClientConfigContent("something_unknown", dir) configFilePath := createFile(t, "config.json", 
configContent, dir) err := openWithClientConfigFile(t, configFilePath) assertNotNilF(t, err, "open config error") assertStringContainsE(t, err.Error(), fmt.Sprint(ErrCodeClientConfigFailed), "error code") assertStringContainsE(t, err.Error(), "parsing client config failed", "error message") } func TestEasyLoggingFailOnNotExistingConfigFile(t *testing.T) { defer cleanUp() easyLoggingInitTrials.reset() err := openWithClientConfigFile(t, "/not-existing-file.json") assertNotNilF(t, err, "open config error") assertStringContainsE(t, err.Error(), fmt.Sprint(ErrCodeClientConfigFailed), "error code") assertStringContainsE(t, err.Error(), "parsing client config failed", "error message") } func TestLogToConfiguredFile(t *testing.T) { skipOnWindows(t, "Doesn't work on Windows") defer cleanUp() origLogLevel := logger.GetLogLevel() defer logger.SetLogLevel(origLogLevel) logger.SetLogLevel("error") dir := t.TempDir() easyLoggingInitTrials.reset() configContent := createClientConfigContent(levelWarn, dir) configFilePath := createFile(t, "config.json", configContent, dir) logFilePath := path.Join(dir, "go", "snowflake.log") err := openWithClientConfigFile(t, configFilePath) assertNilF(t, err, "open config error") logger.Error("Error message") logger.Warn("Warning message") logger.Info("Info message") logger.Trace("Trace message") var logContents []byte logContents, err = os.ReadFile(logFilePath) assertNilF(t, err, "read file error") logs := notEmptyLines(string(logContents)) assertEqualE(t, len(logs), 2, "number of logs") errorLogs := filterStrings(logs, func(val string) bool { return strings.Contains(val, "level=ERROR") }) assertEqualE(t, len(errorLogs), 1, "error logs count") warningLogs := filterStrings(logs, func(val string) bool { return strings.Contains(val, "level=WARN") }) assertEqualE(t, len(warningLogs), 1, "warning logs count") } func TestDataRace(t *testing.T) { n := 10 wg := sync.WaitGroup{} wg.Add(n) for range make([]int, n) { go func() { defer wg.Done() err := 
initEasyLogging("") assertNilF(t, err, "no error from db") }() } wg.Wait() } func notEmptyLines(lines string) []string { notEmptyFunc := func(val string) bool { return val != "" } return filterStrings(strings.Split(strings.ReplaceAll(lines, "\r\n", "\n"), "\n"), notEmptyFunc) } func cleanUp() { newLogger := CreateDefaultLogger() if _, ok := logger.(loggerinternal.EasyLoggingSupport); ok { SetLogger(newLogger) } easyLoggingInitTrials.reset() } func toClientConfigLevel(logLevel string) string { logLevelUpperCase := strings.ToUpper(logLevel) switch strings.ToUpper(logLevel) { case "WARNING": return levelWarn case levelOff, levelError, levelWarn, levelInfo, levelDebug, levelTrace: return logLevelUpperCase default: return "" } } func filterStrings(values []string, keep func(string) bool) []string { var filteredStrings []string for _, val := range values { if keep(val) { filteredStrings = append(filteredStrings, val) } } return filteredStrings } func defaultConfig(t *testing.T) *Config { config, err := ParseDSN(dsn) assertNilF(t, err, "parse dsn error") return config } func openWithClientConfigFile(t *testing.T, clientConfigFile string) error { driver := SnowflakeDriver{} config := defaultConfig(t) config.ClientConfigFile = clientConfigFile _, err := driver.OpenWithConfig(context.Background(), *config) return err } func (i *initTrials) reset() { i.mu.Lock() defer i.mu.Unlock() i.everTriedToInitialize = false i.clientConfigFileInput = "" i.configureCounter = 0 } ================================================ FILE: encrypt_util.go ================================================ package gosnowflake import ( "bytes" "crypto/aes" "crypto/cipher" "crypto/rand" "encoding/base64" "encoding/json" "fmt" "github.com/snowflakedb/gosnowflake/v2/internal/errors" "io" "os" "strconv" ) const gcmIvLengthInBytes = 12 var ( defaultKeyAad = make([]byte, 0) defaultDataAad = make([]byte, 0) ) // override default behavior for wrapper func (ew *encryptionWrapper) UnmarshalJSON(data []byte) 
error { // if GET, unmarshal slice of encryptionMaterial if err := json.Unmarshal(data, &ew.EncryptionMaterials); err == nil { return err } // else (if PUT), unmarshal the encryptionMaterial itself return json.Unmarshal(data, &ew.snowflakeFileEncryption) } // encryptStreamCBC encrypts a stream buffer using AES128 block cipher in CBC mode // with PKCS5 padding func encryptStreamCBC( sfe *snowflakeFileEncryption, src io.Reader, out io.Writer, chunkSize int) (*encryptMetadata, error) { if chunkSize == 0 { chunkSize = aes.BlockSize * 4 * 1024 } kek, err := base64.StdEncoding.DecodeString(sfe.QueryStageMasterKey) if err != nil { return nil, err } keySize := len(kek) fileKey := getSecureRandom(keySize) block, err := aes.NewCipher(fileKey) if err != nil { return nil, err } dataIv := getSecureRandom(block.BlockSize()) mode := cipher.NewCBCEncrypter(block, dataIv) cipherText := make([]byte, chunkSize) chunk := make([]byte, chunkSize) // encrypt file with CBC padded := false for { // read the stream buffer up to len(chunk) bytes into chunk // note that all spaces in chunk may be used even if Read() returns n < len(chunk) n, err := src.Read(chunk) if err != nil && err != io.EOF { return nil, fmt.Errorf("reading: %w", err) } if n == 0 { break } if n%aes.BlockSize != 0 { // add padding to the end of the chunk and update the length n chunk = padBytesLength(chunk[:n], aes.BlockSize) n = len(chunk) padded = true } // make sure only n bytes of chunk is used mode.CryptBlocks(cipherText, chunk[:n]) if _, err := out.Write(cipherText[:n]); err != nil { return nil, err } } // add padding if not yet added if !padded { padding := bytes.Repeat([]byte(string(rune(aes.BlockSize))), aes.BlockSize) mode.CryptBlocks(cipherText, padding) if _, err := out.Write(cipherText[:len(padding)]); err != nil { return nil, err } } // encrypt key with ECB fileKey = padBytesLength(fileKey, block.BlockSize()) encryptedFileKey := make([]byte, len(fileKey)) if err = encryptECB(encryptedFileKey, fileKey, kek); 
err != nil { return nil, err } matDesc := materialDescriptor{ fmt.Sprintf("%v", sfe.SMKID), sfe.QueryID, strconv.Itoa(keySize * 8), } matDescUnicode, err := matdescToUnicode(matDesc) if err != nil { return nil, err } return &encryptMetadata{ base64.StdEncoding.EncodeToString(encryptedFileKey), base64.StdEncoding.EncodeToString(dataIv), matDescUnicode, }, nil } func encryptECB(encrypted []byte, fileKey []byte, decodedKey []byte) error { block, err := aes.NewCipher(decodedKey) if err != nil { return err } if len(fileKey)%block.BlockSize() != 0 { return fmt.Errorf("input not full of blocks") } if len(encrypted) < len(fileKey) { return fmt.Errorf("output length is smaller than input length") } for len(fileKey) > 0 { block.Encrypt(encrypted, fileKey[:block.BlockSize()]) encrypted = encrypted[block.BlockSize():] fileKey = fileKey[block.BlockSize():] } return nil } func decryptECB(decrypted []byte, keyBytes []byte, decodedKey []byte) error { block, err := aes.NewCipher(decodedKey) if err != nil { return err } if len(keyBytes)%block.BlockSize() != 0 { return fmt.Errorf("input not full of blocks") } if len(decrypted) < len(keyBytes) { return fmt.Errorf("output length is smaller than input length") } for len(keyBytes) > 0 { block.Decrypt(decrypted, keyBytes[:block.BlockSize()]) keyBytes = keyBytes[block.BlockSize():] decrypted = decrypted[block.BlockSize():] } return nil } func encryptFileCBC( sfe *snowflakeFileEncryption, filename string, chunkSize int, tmpDir string) ( meta *encryptMetadata, fileName string, err error) { if chunkSize == 0 { chunkSize = aes.BlockSize * 4 * 1024 } tmpOutputFile, err := os.CreateTemp(tmpDir, baseName(filename)+"#") if err != nil { return nil, "", err } defer func() { if tmpErr := tmpOutputFile.Close(); tmpErr != nil && err == nil { err = tmpErr } }() infile, err := os.OpenFile(filename, os.O_CREATE|os.O_RDONLY, readWriteFileMode) if err != nil { return nil, "", err } defer func() { if tmpErr := infile.Close(); tmpErr != nil && err == nil { 
err = tmpErr } }() meta, err = encryptStreamCBC(sfe, infile, tmpOutputFile, chunkSize) if err != nil { return nil, "", err } return meta, tmpOutputFile.Name(), err } func decryptFileKeyECB( metadata *encryptMetadata, sfe *snowflakeFileEncryption) ([]byte, []byte, error) { decodedKey, err := base64.StdEncoding.DecodeString(sfe.QueryStageMasterKey) if err != nil { return nil, nil, err } keyBytes, err := base64.StdEncoding.DecodeString(metadata.key) // encrypted file key if err != nil { return nil, nil, err } ivBytes, err := base64.StdEncoding.DecodeString(metadata.iv) if err != nil { return nil, nil, err } // decrypt file key decryptedKey := make([]byte, len(keyBytes)) if err = decryptECB(decryptedKey, keyBytes, decodedKey); err != nil { return nil, nil, err } decryptedKey, err = paddingTrim(decryptedKey) if err != nil { return nil, nil, err } return decryptedKey, ivBytes, err } func initCBC(decryptedKey []byte, ivBytes []byte) (cipher.BlockMode, error) { block, err := aes.NewCipher(decryptedKey) if err != nil { return nil, err } mode := cipher.NewCBCDecrypter(block, ivBytes) return mode, err } func decryptFileCBC( metadata *encryptMetadata, sfe *snowflakeFileEncryption, filename string, chunkSize int, tmpDir string) (outputFileName string, err error) { tmpOutputFile, err := os.CreateTemp(tmpDir, baseName(filename)+"#") if err != nil { return "", err } defer func() { if tmpErr := tmpOutputFile.Close(); tmpErr != nil && err == nil { err = tmpErr } }() infile, err := os.Open(filename) if err != nil { return "", err } defer func() { if tmpErr := infile.Close(); tmpErr != nil && err == nil { err = tmpErr } }() totalFileSize, err := decryptStreamCBC(metadata, sfe, chunkSize, infile, tmpOutputFile) if err != nil { return "", err } err = tmpOutputFile.Truncate(int64(totalFileSize)) return tmpOutputFile.Name(), err } // Returns decrypted file size and any error that happened during decryption. 
func decryptStreamCBC( metadata *encryptMetadata, sfe *snowflakeFileEncryption, chunkSize int, src io.Reader, out io.Writer) (int, error) { if chunkSize == 0 { chunkSize = aes.BlockSize * 4 * 1024 } decryptedKey, ivBytes, err := decryptFileKeyECB(metadata, sfe) if err != nil { return 0, err } mode, err := initCBC(decryptedKey, ivBytes) if err != nil { return 0, err } var totalFileSize int var prevChunk []byte for { chunk := make([]byte, chunkSize) n, err := src.Read(chunk) if err != nil && err != io.EOF { return 0, fmt.Errorf("reading: %w", err) } if n == 0 { break } if n%aes.BlockSize != 0 { // add padding to the end of the chunk and update the length n chunk = padBytesLength(chunk[:n], aes.BlockSize) n = len(chunk) } totalFileSize += n chunk = chunk[:n] mode.CryptBlocks(chunk, chunk) if _, err := out.Write(chunk); err != nil { return 0, err } prevChunk = chunk } if prevChunk != nil { totalFileSize -= paddingOffset(prevChunk) } return totalFileSize, nil } func encryptGCM(iv []byte, plaintext []byte, encryptionKey []byte, aad []byte) ([]byte, error) { aead, err := initGcm(encryptionKey) if err != nil { return nil, err } return aead.Seal(nil, iv, plaintext, aad), nil } func decryptGCM(iv []byte, ciphertext []byte, encryptionKey []byte, aad []byte) ([]byte, error) { aead, err := initGcm(encryptionKey) if err != nil { return nil, err } return aead.Open(nil, iv, ciphertext, aad) } func initGcm(encryptionKey []byte) (cipher.AEAD, error) { block, err := aes.NewCipher(encryptionKey) if err != nil { return nil, err } return cipher.NewGCM(block) } func encryptFileGCM( sfe *snowflakeFileEncryption, filename string, tmpDir string) ( meta *gcmEncryptMetadata, outputFileName string, err error) { tmpOutputFile, err := os.CreateTemp(tmpDir, baseName(filename)+"#") if err != nil { return nil, "", err } defer func() { if tmpErr := tmpOutputFile.Close(); tmpErr != nil && err == nil { err = tmpErr } }() infile, err := os.OpenFile(filename, os.O_CREATE|os.O_RDONLY, readWriteFileMode) 
if err != nil { return nil, "", err } defer func() { if tmpErr := infile.Close(); tmpErr != nil && err == nil { err = tmpErr } }() plaintext, err := os.ReadFile(filename) if err != nil { return nil, "", err } kek, err := base64.StdEncoding.DecodeString(sfe.QueryStageMasterKey) if err != nil { return nil, "", err } keySize := len(kek) fileKey := getSecureRandom(keySize) keyIv := getSecureRandom(gcmIvLengthInBytes) encryptedFileKey, err := encryptGCM(keyIv, fileKey, kek, defaultKeyAad) if err != nil { return nil, "", err } dataIv := getSecureRandom(gcmIvLengthInBytes) encryptedData, err := encryptGCM(dataIv, plaintext, fileKey, defaultDataAad) if err != nil { return nil, "", err } _, err = tmpOutputFile.Write(encryptedData) if err != nil { return nil, "", err } matDesc := materialDescriptor{ fmt.Sprintf("%v", sfe.SMKID), sfe.QueryID, strconv.Itoa(keySize * 8), } matDescUnicode, err := matdescToUnicode(matDesc) if err != nil { return nil, "", err } meta = &gcmEncryptMetadata{ key: base64.StdEncoding.EncodeToString(encryptedFileKey), keyIv: base64.StdEncoding.EncodeToString(keyIv), dataIv: base64.StdEncoding.EncodeToString(dataIv), keyAad: base64.StdEncoding.EncodeToString(defaultKeyAad), dataAad: base64.StdEncoding.EncodeToString(defaultDataAad), matdesc: matDescUnicode, } return meta, tmpOutputFile.Name(), nil } func decryptFileGCM( metadata *gcmEncryptMetadata, sfe *snowflakeFileEncryption, filename string, tmpDir string) ( string, error) { kek, err := base64.StdEncoding.DecodeString(sfe.QueryStageMasterKey) if err != nil { return "", err } encryptedFileKey, err := base64.StdEncoding.DecodeString(metadata.key) if err != nil { return "", err } keyIv, err := base64.StdEncoding.DecodeString(metadata.keyIv) if err != nil { return "", err } keyAad, err := base64.StdEncoding.DecodeString(metadata.keyAad) if err != nil { return "", err } dataIv, err := base64.StdEncoding.DecodeString(metadata.dataIv) if err != nil { return "", err } dataAad, err := 
base64.StdEncoding.DecodeString(metadata.dataAad) if err != nil { return "", err } fileKey, err := decryptGCM(keyIv, encryptedFileKey, kek, keyAad) if err != nil { return "", err } ciphertext, err := os.ReadFile(filename) if err != nil { return "", err } plaintext, err := decryptGCM(dataIv, ciphertext, fileKey, dataAad) if err != nil { return "", err } tmpOutputFile, err := os.CreateTemp(tmpDir, baseName(filename)+"#") if err != nil { return "", err } _, err = tmpOutputFile.Write(plaintext) if err != nil { return "", err } return tmpOutputFile.Name(), nil } type materialDescriptor struct { SmkID string `json:"smkId"` QueryID string `json:"queryId"` KeySize string `json:"keySize"` } func matdescToUnicode(matdesc materialDescriptor) (string, error) { s, err := json.Marshal(&matdesc) if err != nil { return "", err } return string(s), nil } func getSecureRandom(byteLength int) []byte { token := make([]byte, byteLength) _, err := rand.Read(token) if err != nil { logger.Errorf("cannot init secure random. %v", err) } return token } func padBytesLength(src []byte, blockSize int) []byte { padLength := blockSize - len(src)%blockSize padText := bytes.Repeat([]byte{byte(padLength)}, padLength) return append(src, padText...) } func paddingTrim(src []byte) ([]byte, error) { if len(src) == 0 { logger.Errorf("padding trim failed - data length is 0") return nil, &SnowflakeError{ Number: ErrInvalidPadding, Message: "padding validation failed", } } unpadding := src[len(src)-1] n := int(unpadding) if n == 0 || n > len(src) { logger.Errorf("padding validation failed - invalid padding detected. 
data length: %d, padding value: %d", len(src), n) return nil, &SnowflakeError{ Number: ErrInvalidPadding, Message: errors.ErrMsgInvalidPadding, } } return src[:len(src)-n], nil } func paddingOffset(src []byte) int { length := len(src) return int(src[length-1]) } type contentKey struct { KeyID string `json:"KeyId,omitempty"` EncryptionKey string `json:"EncryptedKey,omitempty"` Algorithm string `json:"Algorithm,omitempty"` } type encryptionAgent struct { Protocol string `json:"Protocol,omitempty"` EncryptionAlgorithm string `json:"EncryptionAlgorithm,omitempty"` } type keyMetadata struct { EncryptionLibrary string `json:"EncryptionLibrary,omitempty"` } type encryptionData struct { EncryptionMode string `json:"EncryptionMode,omitempty"` WrappedContentKey contentKey `json:"WrappedContentKey"` EncryptionAgent encryptionAgent `json:"EncryptionAgent"` ContentEncryptionIV string `json:"ContentEncryptionIV,omitempty"` KeyWrappingMetadata keyMetadata `json:"KeyWrappingMetadata"` } type snowflakeFileEncryption struct { QueryStageMasterKey string `json:"queryStageMasterKey,omitempty"` QueryID string `json:"queryId,omitempty"` SMKID int64 `json:"smkId,omitempty"` } // PUT requests return a single encryptionMaterial object whereas GET requests // return a slice (array) of encryptionMaterial objects, both under the field // 'encryptionMaterial' type encryptionWrapper struct { snowflakeFileEncryption EncryptionMaterials []snowflakeFileEncryption } type encryptMetadata struct { key string iv string matdesc string } type gcmEncryptMetadata struct { key string keyIv string dataIv string keyAad string dataAad string matdesc string } ================================================ FILE: encrypt_util_test.go ================================================ package gosnowflake import ( "bufio" "compress/gzip" "encoding/base64" "errors" "fmt" "io" "math/rand" "os" "os/exec" "path" "path/filepath" "strconv" "testing" "testing/iotest" "time" ) const timeFormat = "2006-01-02T15:04:05" type 
encryptDecryptTestFile struct { numberOfBytesInEachRow int numberOfLines int } func TestEncryptDecryptFileCBC(t *testing.T) { encMat := snowflakeFileEncryption{ "ztke8tIdVt1zmlQIZm0BMA==", "123873c7-3a66-40c4-ab89-e3722fbccce1", 9223372036854775807, } data := "test data" inputFile := "test_encrypt_decrypt_file" fd, err := os.Create(inputFile) if err != nil { t.Error(err) } defer fd.Close() defer os.Remove(inputFile) if _, err = fd.Write([]byte(data)); err != nil { t.Error(err) } metadata, encryptedFile, err := encryptFileCBC(&encMat, inputFile, 0, "") if err != nil { t.Error(err) } defer os.Remove(encryptedFile) assertStringContainsE(t, metadata.matdesc, "9223372036854775807") decryptedFile, err := decryptFileCBC(metadata, &encMat, encryptedFile, 0, "") if err != nil { t.Error(err) } defer os.Remove(decryptedFile) fd, err = os.Open(decryptedFile) if err != nil { t.Error(err) } defer fd.Close() content, err := io.ReadAll(fd) if err != nil { t.Error(err) } if string(content) != data { t.Fatalf("data did not match content. 
expected: %v, got: %v", data, string(content)) } } func TestEncryptDecryptFilePadding(t *testing.T) { encMat := snowflakeFileEncryption{ "ztke8tIdVt1zmlQIZm0BMA==", "123873c7-3a66-40c4-ab89-e3722fbccce1", 3112, } testcases := []encryptDecryptTestFile{ // File size is a multiple of 65536 bytes (chunkSize) {numberOfBytesInEachRow: 8, numberOfLines: 16384}, {numberOfBytesInEachRow: 16, numberOfLines: 4096}, // File size is not a multiple of 65536 bytes (chunkSize) {numberOfBytesInEachRow: 8, numberOfLines: 10240}, {numberOfBytesInEachRow: 16, numberOfLines: 6144}, // The second chunk's size is a multiple of 16 bytes (aes.BlockSize) {numberOfBytesInEachRow: 16, numberOfLines: 4097}, // The second chunk's size is not a multiple of 16 bytes (aes.BlockSize) {numberOfBytesInEachRow: 12, numberOfLines: 5462}, {numberOfBytesInEachRow: 10, numberOfLines: 6556}, } for _, test := range testcases { t.Run(fmt.Sprintf("%v_%v", test.numberOfBytesInEachRow, test.numberOfLines), func(t *testing.T) { tmpDir, err := generateKLinesOfNByteRows(test.numberOfLines, test.numberOfBytesInEachRow, t.TempDir()) if err != nil { t.Error(err) } encryptDecryptFile(t, encMat, test.numberOfLines, tmpDir) }) } } func TestEncryptDecryptLargeFileCBC(t *testing.T) { encMat := snowflakeFileEncryption{ "ztke8tIdVt1zmlQIZm0BMA==", "123873c7-3a66-40c4-ab89-e3722fbccce1", 3112, } numberOfFiles := 1 numberOfLines := 10000 tmpDir, err := generateKLinesOfNFiles(numberOfLines, numberOfFiles, false, t.TempDir()) if err != nil { t.Error(err) } encryptDecryptFile(t, encMat, numberOfLines, tmpDir) } func TestEncryptStreamCBCReadError(t *testing.T) { sfe := snowflakeFileEncryption{ QueryStageMasterKey: "YWJjZGVmMTIzNDU2Nzg5MA==", QueryID: "unused", SMKID: 9223372036854775807, } wantErr := errors.New("test error") r := iotest.ErrReader(wantErr) n, err := encryptStreamCBC(&sfe, r, nil, 0) assertTrueF(t, errors.Is(err, wantErr), fmt.Sprintf("expected error: %v, got: %v", wantErr, err)) assertNilE(t, n, "expected no 
metadata on error") } func TestDecryptStreamCBCReadError(t *testing.T) { tmpDir := t.TempDir() tempFile, err := os.CreateTemp(tmpDir, "gcm") assertNilF(t, err) _, err = tempFile.Write([]byte("abc")) assertNilF(t, err) err = tempFile.Close() assertNilF(t, err) sfe := snowflakeFileEncryption{ QueryStageMasterKey: "YWJjZGVmMTIzNDU2Nzg5MA==", QueryID: "unused", SMKID: 9223372036854775807, } meta, _, err := encryptFileCBC(&sfe, tempFile.Name(), 0, tmpDir) assertNilF(t, err) assertStringContainsF(t, meta.matdesc, "9223372036854775807") wantErr := errors.New("test error") r := iotest.ErrReader(wantErr) n, err := decryptStreamCBC(meta, &sfe, 0, r, nil) assertTrueF(t, errors.Is(err, wantErr), fmt.Sprintf("expected error: %v, got: %v", wantErr, err)) assertEqualE(t, n, 0, "expected 0 bytes written") } func encryptDecryptFile(t *testing.T, encMat snowflakeFileEncryption, expected int, tmpDir string) { files, err := filepath.Glob(filepath.Join(tmpDir, "file*")) if err != nil { t.Error(err) } inputFile := files[0] metadata, encryptedFile, err := encryptFileCBC(&encMat, inputFile, 0, tmpDir) if err != nil { t.Error(err) } defer os.Remove(encryptedFile) decryptedFile, err := decryptFileCBC(metadata, &encMat, encryptedFile, 0, tmpDir) if err != nil { t.Error(err) } defer os.Remove(decryptedFile) cnt := 0 fd, err := os.Open(decryptedFile) if err != nil { t.Error(err) } defer fd.Close() scanner := bufio.NewScanner(fd) for scanner.Scan() { cnt++ } if err = scanner.Err(); err != nil { t.Error(err) } if cnt != expected { t.Fatalf("incorrect number of lines. 
expected: %v, got: %v", expected, cnt) } } func generateKLinesOfNByteRows(numLines int, numBytes int, tmpDir string) (string, error) { fname := path.Join(tmpDir, "file"+strconv.FormatInt(int64(numLines*numBytes), 10)) f, err := os.Create(fname) if err != nil { return "", err } for range numLines { str := randomString(numBytes - 1) // \n is the last character rec := fmt.Sprintf("%v\n", str) if _, err = f.Write([]byte(rec)); err != nil { return "", err } } err = f.Close() return tmpDir, err } func generateKLinesOfNFiles(k int, n int, compress bool, tmpDir string) (string, error) { for i := range n { fname := path.Join(tmpDir, "file"+strconv.FormatInt(int64(i), 10)) f, err := os.Create(fname) if err != nil { return "", err } for range k { num := rand.Float64() * 10000 min := time.Date(1970, 1, 0, 0, 0, 0, 0, time.UTC).Unix() max := time.Date(2070, 1, 0, 0, 0, 0, 0, time.UTC).Unix() delta := max - min sec := rand.Int63n(delta) + min tm := time.Unix(sec, 0) dt := tm.Format("2021-03-01") sec = rand.Int63n(delta) + min ts := time.Unix(sec, 0).Format(timeFormat) sec = rand.Int63n(delta) + min tsltz := time.Unix(sec, 0).Format(timeFormat) sec = rand.Int63n(delta) + min tsntz := time.Unix(sec, 0).Format(timeFormat) sec = rand.Int63n(delta) + min tstz := time.Unix(sec, 0).Format(timeFormat) pct := rand.Float64() * 1000 ratio := fmt.Sprintf("%.2f", rand.Float64()*1000) rec := fmt.Sprintf("%v,%v,%v,%v,%v,%v,%v,%v\n", num, dt, ts, tsltz, tsntz, tstz, pct, ratio) if _, err = f.Write([]byte(rec)); err != nil { return "", err } } if err = f.Close(); err != nil { return "", err } if compress { if !isWindows { gzipCmd := exec.Command("gzip", filepath.Join(tmpDir, "file"+strconv.FormatInt(int64(i), 10))) gzipOut, err := gzipCmd.StdoutPipe() if err != nil { return "", err } gzipErr, err := gzipCmd.StderrPipe() if err != nil { return "", err } if err = gzipCmd.Start(); err != nil { return "", err } if _, err = io.ReadAll(gzipOut); err != nil { return "", err } if _, err = 
io.ReadAll(gzipErr); err != nil { return "", err } if err = gzipCmd.Wait(); err != nil { return "", err } } else { fOut, err := os.Create(fname + ".gz") if err != nil { return "", err } w := gzip.NewWriter(fOut) fIn, err := os.Open(fname) if err != nil { return "", err } if _, err = io.Copy(w, fIn); err != nil { return "", err } w.Close() fOut.Close() fIn.Close() } } } return tmpDir, nil } func TestEncryptDecryptGCM(t *testing.T) { input := []byte("abc") iv := []byte("ab1234567890") // pragma: allowlist secret key := []byte("1234567890abcdef") // pragma: allowlist secret encrypted, err := encryptGCM(iv, input, key, nil) assertNilF(t, err) assertEqualE(t, base64.StdEncoding.EncodeToString(encrypted), "iG+lT4o27hkzj3kblYRzQikLVQ==") decrypted, err := decryptGCM(iv, encrypted, key, nil) assertNilF(t, err) assertDeepEqualE(t, decrypted, input) } func TestEncryptDecryptFileGCM(t *testing.T) { tmpDir := os.TempDir() tempFile, err := os.CreateTemp(tmpDir, "gcm") assertNilF(t, err) _, err = tempFile.Write([]byte("abc")) assertNilF(t, err) sfe := &snowflakeFileEncryption{ QueryStageMasterKey: "YWJjZGVmMTIzNDU2Nzg5MA==", QueryID: "unused", SMKID: 9223372036854775807, } meta, encryptedFileName, err := encryptFileGCM(sfe, tempFile.Name(), tmpDir) assertNilF(t, err) assertStringContainsE(t, meta.matdesc, "9223372036854775807") decryptedFileName, err := decryptFileGCM(meta, sfe, encryptedFileName, tmpDir) assertNilF(t, err) fileContent, err := os.ReadFile(decryptedFileName) assertNilF(t, err) assertEqualE(t, string(fileContent), "abc") } ================================================ FILE: errors.go ================================================ package gosnowflake import ( "fmt" "runtime/debug" "strconv" "time" sferrors "github.com/snowflakedb/gosnowflake/v2/internal/errors" ) // SnowflakeError is a error type including various Snowflake specific information. 
type SnowflakeError = sferrors.SnowflakeError func generateTelemetryExceptionData(se *SnowflakeError) *telemetryData { data := &telemetryData{ Message: map[string]string{ typeKey: sqlException, sourceKey: telemetrySource, driverTypeKey: "Go", driverVersionKey: SnowflakeGoDriverVersion, stacktraceKey: maskSecrets(string(debug.Stack())), }, Timestamp: time.Now().UnixNano() / int64(time.Millisecond), } if se.QueryID != "" { data.Message[queryIDKey] = se.QueryID } if se.SQLState != "" { data.Message[sqlStateKey] = se.SQLState } if se.Message != "" { data.Message[reasonKey] = se.Message } if len(se.MessageArgs) > 0 { data.Message[reasonKey] = fmt.Sprintf(se.Message, se.MessageArgs...) } if se.Number != 0 { data.Message[errorNumberKey] = strconv.Itoa(se.Number) } return data } // exceptionTelemetry generates telemetry data from the error and adds it to the telemetry queue. func exceptionTelemetry(se *SnowflakeError, sc *snowflakeConn) *SnowflakeError { if sc == nil || sc.telemetry == nil || !sc.telemetry.enabled { return se // skip expensive stacktrace generation below if telemetry is disabled } data := generateTelemetryExceptionData(se) if err := sc.telemetry.addLog(data); err != nil { logger.WithContext(sc.ctx).Debugf("failed to log to telemetry: %v", data) } return se } // return populated error fields replacing the default response func populateErrorFields(code int, data *execResponse) *SnowflakeError { err := sferrors.ErrUnknownError() if code != -1 { err.Number = code } if data.Data.SQLState != "" { err.SQLState = data.Data.SQLState } if data.Message != "" { err.Message = data.Message } if data.Data.QueryID != "" { err.QueryID = data.Data.QueryID } return err } // Snowflake Server Error code const ( queryNotExecutingCode = "000605" queryInProgressCode = "333333" queryInProgressAsyncCode = "333334" sessionExpiredCode = "390112" invalidOAuthAccessTokenCode = "390303" expiredOAuthAccessTokenCode = "390318" ) // Driver return errors — re-exported from internal/errors 
const (
	/* connection */

	// ErrCodeEmptyAccountCode is an error code for the case where a DSN doesn't include account parameter
	ErrCodeEmptyAccountCode = sferrors.ErrCodeEmptyAccountCode
	// ErrCodeEmptyUsernameCode is an error code for the case where a DSN doesn't include user parameter
	ErrCodeEmptyUsernameCode = sferrors.ErrCodeEmptyUsernameCode
	// ErrCodeEmptyPasswordCode is an error code for the case where a DSN doesn't include password parameter
	ErrCodeEmptyPasswordCode = sferrors.ErrCodeEmptyPasswordCode
	// ErrCodeFailedToParseHost is an error code for the case where a DSN includes an invalid host name
	ErrCodeFailedToParseHost = sferrors.ErrCodeFailedToParseHost
	// ErrCodeFailedToParsePort is an error code for the case where a DSN includes an invalid port number
	ErrCodeFailedToParsePort = sferrors.ErrCodeFailedToParsePort
	// ErrCodeIdpConnectionError is an error code for the case where an IDP connection failed
	ErrCodeIdpConnectionError = sferrors.ErrCodeIdpConnectionError
	// ErrCodeSSOURLNotMatch is an error code for the case where an SSO URL doesn't match
	ErrCodeSSOURLNotMatch = sferrors.ErrCodeSSOURLNotMatch
	// ErrCodeServiceUnavailable is an error code for the case where service is unavailable.
	ErrCodeServiceUnavailable = sferrors.ErrCodeServiceUnavailable
	// ErrCodeFailedToConnect is an error code for the case where a DB connection failed due to wrong account name
	ErrCodeFailedToConnect = sferrors.ErrCodeFailedToConnect
	// ErrCodeRegionOverlap is an error code for the case where a region is specified despite an account region present
	ErrCodeRegionOverlap = sferrors.ErrCodeRegionOverlap
	// ErrCodePrivateKeyParseError is an error code for the case where the private key is not parsed correctly
	ErrCodePrivateKeyParseError = sferrors.ErrCodePrivateKeyParseError
	// ErrCodeFailedToParseAuthenticator is an error code for the case where a DSN includes an invalid authenticator
	ErrCodeFailedToParseAuthenticator = sferrors.ErrCodeFailedToParseAuthenticator
	// ErrCodeClientConfigFailed is an error code for the case where clientConfigFile is invalid or applying client configuration fails
	ErrCodeClientConfigFailed = sferrors.ErrCodeClientConfigFailed
	// ErrCodeTomlFileParsingFailed is an error code for the case where parsing the toml file failed because of an invalid value.
	ErrCodeTomlFileParsingFailed = sferrors.ErrCodeTomlFileParsingFailed
	// ErrCodeFailedToFindDSNInToml is an error code for the case where the DSN does not exist in the toml file.
	ErrCodeFailedToFindDSNInToml = sferrors.ErrCodeFailedToFindDSNInToml
	// ErrCodeInvalidFilePermission is an error code for the case where the user does not have 0600 permission to the toml file.
	ErrCodeInvalidFilePermission = sferrors.ErrCodeInvalidFilePermission
	// ErrCodeEmptyPasswordAndToken is an error code for the case where a DSN includes neither password nor token
	ErrCodeEmptyPasswordAndToken = sferrors.ErrCodeEmptyPasswordAndToken
	// ErrCodeEmptyOAuthParameters is an error code for the case where the client ID or client secret are not provided for OAuth flows.
	ErrCodeEmptyOAuthParameters = sferrors.ErrCodeEmptyOAuthParameters
	// ErrMissingAccessATokenButRefreshTokenPresent is an error code for the case when the access token is not found in cache, but the refresh token is present.
	ErrMissingAccessATokenButRefreshTokenPresent = sferrors.ErrMissingAccessATokenButRefreshTokenPresent
	// ErrCodeMissingTLSConfig is an error code for the case where the TLS config is missing.
	ErrCodeMissingTLSConfig = sferrors.ErrCodeMissingTLSConfig

	/* network */

	// ErrFailedToPostQuery is an error code for the case where HTTP POST failed.
	ErrFailedToPostQuery = sferrors.ErrFailedToPostQuery
	// ErrFailedToRenewSession is an error code for the case where session renewal failed.
	ErrFailedToRenewSession = sferrors.ErrFailedToRenewSession
	// ErrFailedToCancelQuery is an error code for the case where cancel query failed.
	ErrFailedToCancelQuery = sferrors.ErrFailedToCancelQuery
	// ErrFailedToCloseSession is an error code for the case where close session failed.
	ErrFailedToCloseSession = sferrors.ErrFailedToCloseSession
	// ErrFailedToAuth is an error code for the case where authentication failed for unknown reason.
	ErrFailedToAuth = sferrors.ErrFailedToAuth
	// ErrFailedToAuthSAML is an error code for the case where authentication via SAML failed for unknown reason.
	ErrFailedToAuthSAML = sferrors.ErrFailedToAuthSAML
	// ErrFailedToAuthOKTA is an error code for the case where authentication via OKTA failed for unknown reason.
	ErrFailedToAuthOKTA = sferrors.ErrFailedToAuthOKTA
	// ErrFailedToGetSSO is an error code for the case where authentication via OKTA failed for unknown reason.
	ErrFailedToGetSSO = sferrors.ErrFailedToGetSSO
	// ErrFailedToParseResponse is an error code for when we cannot parse an external browser response from Snowflake.
	ErrFailedToParseResponse = sferrors.ErrFailedToParseResponse
	// ErrFailedToGetExternalBrowserResponse is an error code for when there's an error reading from the open socket.
	ErrFailedToGetExternalBrowserResponse = sferrors.ErrFailedToGetExternalBrowserResponse
	// ErrFailedToHeartbeat is an error code when a heartbeat fails.
	ErrFailedToHeartbeat = sferrors.ErrFailedToHeartbeat

	/* rows */

	// ErrFailedToGetChunk is an error code for the case where it failed to get chunk of result set
	ErrFailedToGetChunk = sferrors.ErrFailedToGetChunk
	// ErrNonArrowResponseInArrowBatches is an error code for case where ArrowBatches mode is enabled, but response is not Arrow-based
	ErrNonArrowResponseInArrowBatches = sferrors.ErrNonArrowResponseInArrowBatches

	/* transaction */

	// ErrNoReadOnlyTransaction is an error code for the case where readonly mode is specified.
	ErrNoReadOnlyTransaction = sferrors.ErrNoReadOnlyTransaction
	// ErrNoDefaultTransactionIsolationLevel is an error code for the case where non default isolation level is specified.
	ErrNoDefaultTransactionIsolationLevel = sferrors.ErrNoDefaultTransactionIsolationLevel

	/* file transfer */

	// ErrInvalidStageFs is an error code denoting an invalid stage in the file system
	ErrInvalidStageFs = sferrors.ErrInvalidStageFs
	// ErrFailedToDownloadFromStage is an error code denoting the failure to download a file from the stage
	ErrFailedToDownloadFromStage = sferrors.ErrFailedToDownloadFromStage
	// ErrFailedToUploadToStage is an error code denoting the failure to upload a file to the stage
	ErrFailedToUploadToStage = sferrors.ErrFailedToUploadToStage
	// ErrInvalidStageLocation is an error code denoting an invalid stage location
	ErrInvalidStageLocation = sferrors.ErrInvalidStageLocation
	// ErrLocalPathNotDirectory is an error code denoting a local path that is not a directory
	ErrLocalPathNotDirectory = sferrors.ErrLocalPathNotDirectory
	// ErrFileNotExists is an error code denoting the file to be transferred does not exist
	ErrFileNotExists = sferrors.ErrFileNotExists
	// ErrCompressionNotSupported is an error code denoting the user specified compression type is not supported
	ErrCompressionNotSupported = sferrors.ErrCompressionNotSupported
	// ErrInternalNotMatchEncryptMaterial is an error code denoting the encryption material specified does not match
	ErrInternalNotMatchEncryptMaterial = sferrors.ErrInternalNotMatchEncryptMaterial
	// ErrCommandNotRecognized is an error code denoting the PUT/GET command was not recognized
	ErrCommandNotRecognized = sferrors.ErrCommandNotRecognized
	// ErrFailedToConvertToS3Client is an error code denoting the failure of an interface to s3.Client conversion
	ErrFailedToConvertToS3Client = sferrors.ErrFailedToConvertToS3Client
	// ErrNotImplemented is an error code denoting the file transfer feature is not implemented
	ErrNotImplemented = sferrors.ErrNotImplemented
	// ErrInvalidPadding is an error code denoting the invalid padding of decryption key
	ErrInvalidPadding = sferrors.ErrInvalidPadding

	/* binding */

	// ErrBindSerialization is an error code for a failed serialization of bind variables
	ErrBindSerialization = sferrors.ErrBindSerialization
	// ErrBindUpload is an error code for the uploading process of bind elements to the stage
	ErrBindUpload = sferrors.ErrBindUpload

	/* async */

	// ErrAsync is an error code for an unknown async error
	ErrAsync = sferrors.ErrAsync

	/* multi-statement */

	// ErrNoResultIDs is an error code for empty result IDs for multi statement queries
	ErrNoResultIDs = sferrors.ErrNoResultIDs

	/* converter */

	// ErrInvalidTimestampTz is an error code for the case where a returned TIMESTAMP_TZ internal value is invalid
	ErrInvalidTimestampTz = sferrors.ErrInvalidTimestampTz
	// ErrInvalidOffsetStr is an error code for the case where an offset string is invalid. The input string must
	// consist of sHHMI where one sign character '+'/'-' followed by zero filled hours and minutes
	ErrInvalidOffsetStr = sferrors.ErrInvalidOffsetStr
	// ErrInvalidBinaryHexForm is an error code for the case where a binary data in hex form is invalid.
	ErrInvalidBinaryHexForm = sferrors.ErrInvalidBinaryHexForm
	// ErrTooHighTimestampPrecision is an error code for the case where cannot convert Snowflake timestamp to arrow.Timestamp
	ErrTooHighTimestampPrecision = sferrors.ErrTooHighTimestampPrecision
	// ErrNullValueInArray is an error code for the case where there are null values in an array without arrayValuesNullable set to true
	ErrNullValueInArray = sferrors.ErrNullValueInArray
	// ErrNullValueInMap is an error code for the case where there are null values in a map without mapValuesNullable set to true
	ErrNullValueInMap = sferrors.ErrNullValueInMap

	/* OCSP */

	// ErrOCSPStatusRevoked is an error code for the case where the certificate is revoked.
	ErrOCSPStatusRevoked = sferrors.ErrOCSPStatusRevoked
	// ErrOCSPStatusUnknown is an error code for the case where the certificate revocation status is unknown.
	ErrOCSPStatusUnknown = sferrors.ErrOCSPStatusUnknown
	// ErrOCSPInvalidValidity is an error code for the case where the OCSP response validity is invalid.
	ErrOCSPInvalidValidity = sferrors.ErrOCSPInvalidValidity
	// ErrOCSPNoOCSPResponderURL is an error code for the case where the OCSP responder URL is not attached.
	ErrOCSPNoOCSPResponderURL = sferrors.ErrOCSPNoOCSPResponderURL

	/* query Status */

	// ErrQueryStatus when check the status of a query, receive error or no status
	ErrQueryStatus = sferrors.ErrQueryStatus
	// ErrQueryIDFormat the query ID given to fetch its result is not valid
	ErrQueryIDFormat = sferrors.ErrQueryIDFormat
	// ErrQueryReportedError server side reports the query failed with error
	ErrQueryReportedError = sferrors.ErrQueryReportedError
	// ErrQueryIsRunning the query is still running
	ErrQueryIsRunning = sferrors.ErrQueryIsRunning

	/* GS error code */

	// ErrSessionGone is a GS error code for the case that session is already closed
	ErrSessionGone = sferrors.ErrSessionGone
	// ErrRoleNotExist is a GS error code for the case that the role specified does not exist
	ErrRoleNotExist = sferrors.ErrRoleNotExist
	// ErrObjectNotExistOrAuthorized is a GS error code for the case that the server-side object specified does not exist
	ErrObjectNotExistOrAuthorized = sferrors.ErrObjectNotExistOrAuthorized
)
%v", e) } e = &SnowflakeError{ Number: 1, Message: "test message: %v, %v", MessageArgs: []any{"C1", "C2"}, SQLState: "01112", } if !strings.Contains(e.Error(), "000001") { t.Errorf("failed to format error. %v", e) } if !strings.Contains(e.Error(), "test message") { t.Errorf("failed to format error. %v", e) } if !strings.Contains(e.Error(), "C1") { t.Errorf("failed to format error. %v", e) } if !strings.Contains(e.Error(), "01112") { t.Errorf("failed to format error. %v", e) } e = &SnowflakeError{ Number: 1, Message: "test message: %v, %v", MessageArgs: []any{"C1", "C2"}, SQLState: "01112", QueryID: "abcdef-abcdef-abcdef", } if !strings.Contains(e.Error(), "000001") { t.Errorf("failed to format error. %v", e) } if !strings.Contains(e.Error(), "test message") { t.Errorf("failed to format error. %v", e) } if !strings.Contains(e.Error(), "C1") { t.Errorf("failed to format error. %v", e) } if !strings.Contains(e.Error(), "01112") { t.Errorf("failed to format error. %v", e) } if strings.Contains(e.Error(), "abcdef-abcdef-abcdef") { // no quid t.Errorf("failed to format error. %v", e) } e.IncludeQueryID = true if !strings.Contains(e.Error(), "abcdef-abcdef-abcdef") { // no quid t.Errorf("failed to format error. 
%v", e) } } ================================================ FILE: file_compression_type.go ================================================ package gosnowflake import ( "bytes" "strings" "github.com/gabriel-vasile/mimetype" ) type compressionType struct { name string fileExtension string mimeSubtypes []string isSupported bool } var compressionTypes = map[string]*compressionType{ "GZIP": { "GZIP", ".gz", []string{"gzip", "x-gzip"}, true, }, "DEFLATE": { "DEFLATE", ".deflate", []string{"zlib", "deflate"}, true, }, "RAW_DEFLATE": { "RAW_DEFLATE", ".raw_deflate", []string{"raw_deflate"}, true, }, "BZIP2": { "BZIP2", ".bz2", []string{"bzip2", "x-bzip2", "x-bz2", "x-bzip", "bz2"}, true, }, "LZIP": { "LZIP", ".lz", []string{"lzip", "x-lzip"}, false, }, "LZMA": { "LZMA", ".lzma", []string{"lzma", "x-lzma"}, false, }, "LZO": { "LZO", ".lzo", []string{"lzo", "x-lzo"}, false, }, "XZ": { "XZ", ".xz", []string{"xz", "x-xz"}, false, }, "COMPRESS": { "COMPRESS", ".Z", []string{"compress", "x-compress"}, false, }, "PARQUET": { "PARQUET", ".parquet", []string{"parquet"}, true, }, "ZSTD": { "ZSTD", ".zst", []string{"zstd", "x-zstd"}, true, }, "BROTLI": { "BROTLI", ".br", []string{"br", "x-br"}, true, }, "ORC": { "ORC", ".orc", []string{"orc"}, true, }, } var mimeSubTypeToCompression map[string]*compressionType var extensionToCompression map[string]*compressionType func init() { mimeSubTypeToCompression = make(map[string]*compressionType) extensionToCompression = make(map[string]*compressionType) for _, meta := range compressionTypes { extensionToCompression[meta.fileExtension] = meta for _, subtype := range meta.mimeSubtypes { mimeSubTypeToCompression[subtype] = meta } } mimetype.Extend(func(raw []byte, limit uint32) bool { return bytes.HasPrefix(raw, []byte("PAR1")) }, "snowflake/parquet", ".parquet") mimetype.Extend(func(raw []byte, limit uint32) bool { return bytes.HasPrefix(raw, []byte("ORC")) }, "snowflake/orc", ".orc") } func lookupByMimeSubType(mimeSubType string) 
*compressionType { if val, ok := mimeSubTypeToCompression[strings.ToLower(mimeSubType)]; ok { return val } return nil } func lookupByExtension(extension string) *compressionType { if val, ok := extensionToCompression[strings.ToLower(extension)]; ok { return val } return nil } ================================================ FILE: file_transfer_agent.go ================================================ package gosnowflake //lint:file-ignore U1000 Ignore all unused code import ( "bytes" "cmp" "context" "database/sql/driver" "encoding/json" "errors" "fmt" errors2 "github.com/snowflakedb/gosnowflake/v2/internal/errors" "github.com/snowflakedb/gosnowflake/v2/internal/query" "io" "math" "net/url" "os" "path/filepath" "regexp" "runtime" "sort" "strings" "sync" "time" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/gabriel-vasile/mimetype" ) type ( cloudType string commandType string ) const ( fileProtocol = "file://" multiPartThreshold int64 = 64 * 1024 * 1024 streamingMultiPartThreshold int64 = 8 * 1024 * 1024 isWindows = runtime.GOOS == "windows" mb float64 = 1024.0 * 1024.0 ) const ( uploadCommand commandType = "UPLOAD" downloadCommand commandType = "DOWNLOAD" unknownCommand commandType = "UNKNOWN" putRegexp string = `(?i)^(?:/\*.*\*/\s*)*\s*put\s+` getRegexp string = `(?i)^(?:/\*.*\*/\s*)*\s*get\s+` ) const ( s3Client cloudType = "S3" azureClient cloudType = "AZURE" gcsClient cloudType = "GCS" local cloudType = "LOCAL_FS" ) type resultStatus int const ( errStatus resultStatus = iota uploaded downloaded skipped renewToken renewPresignedURL notFoundFile needRetry needRetryWithLowerConcurrency ) func (rs resultStatus) String() string { return [...]string{"ERROR", "UPLOADED", "DOWNLOADED", "SKIPPED", "RENEW_TOKEN", "RENEW_PRESIGNED_URL", "NOT_FOUND_FILE", "NEED_RETRY", "NEED_RETRY_WITH_LOWER_CONCURRENCY"}[rs] } func (rs resultStatus) isSet() bool { return uploaded <= rs && rs <= needRetryWithLowerConcurrency } // SnowflakeFileTransferOptions enables users to specify 
// SnowflakeFileTransferOptions enables users to specify options regarding
// files transfers such as PUT/GET
type SnowflakeFileTransferOptions struct {
	showProgressBar    bool
	MultiPartThreshold int64 // size above which uploads switch to multi-part transfer

	/* streaming PUT */
	compressSourceFromStream bool

	/* PUT */
	putCallback             *snowflakeProgressPercentage
	putAzureCallback        *snowflakeProgressPercentage
	putCallbackOutputStream *io.Writer

	/* GET */
	getCallback             *snowflakeProgressPercentage
	getAzureCallback        *snowflakeProgressPercentage
	getCallbackOutputStream *io.Writer
}

// snowflakeFileTransferAgent orchestrates a single PUT/GET command: it parses
// the command, builds per-file metadata, and drives the upload/download
// against the stage's cloud storage backend.
type snowflakeFileTransferAgent struct {
	ctx                         context.Context
	sc                          *snowflakeConn
	data                        *execResponseData // server response describing the transfer
	command                     string
	commandType                 commandType // UPLOAD / DOWNLOAD / UNKNOWN
	stageLocationType           cloudType   // S3 / AZURE / GCS / LOCAL_FS
	fileMetadata                []*fileMetadata
	encryptionMaterial          []*snowflakeFileEncryption
	stageInfo                   *execResponseStageInfo
	results                     []*fileMetadata
	sourceStream                io.Reader // non-nil for streaming PUT / bulk insert
	srcLocations                []string
	autoCompress                bool
	srcCompression              string
	parallel                    int64
	overwrite                   bool
	srcFiles                    []string
	localLocation               string // download destination directory
	srcFileToEncryptionMaterial map[string]*snowflakeFileEncryption
	useAccelerateEndpoint       bool
	presignedURLs               []string
	options                     *SnowflakeFileTransferOptions
	streamBuffer                *bytes.Buffer
}

// execute runs the whole transfer: parse command, build metadata, prepare
// local/stage directories, then upload or download. Steps are order-dependent
// (metadata must exist before compression detection and URL updates).
func (sfa *snowflakeFileTransferAgent) execute() error {
	var err error
	if err = sfa.parseCommand(); err != nil {
		return err
	}
	if err = sfa.initFileMetadata(); err != nil {
		return err
	}
	// Compression detection only matters for uploads.
	if sfa.commandType == uploadCommand {
		if err = sfa.processFileCompressionType(); err != nil {
			return err
		}
	}
	if err = sfa.transferAccelerateConfig(); err != nil {
		return err
	}
	// Ensure the download target directory exists.
	if sfa.commandType == downloadCommand {
		if _, err = os.Stat(sfa.localLocation); os.IsNotExist(err) {
			if err = os.MkdirAll(sfa.localLocation, os.ModePerm); err != nil {
				return err
			}
		}
	}
	// A local (filesystem) stage may also need its directory created.
	if sfa.stageLocationType == local {
		if _, err = os.Stat(sfa.stageInfo.Location); os.IsNotExist(err) {
			if err = os.MkdirAll(sfa.stageInfo.Location, os.ModePerm); err != nil {
				return err
			}
		}
	}
	if err = sfa.updateFileMetadataWithPresignedURL(); err != nil {
		return err
	}
	// Partition files: "large" files transfer with the requested parallelism,
	// "small" files transfer sequentially (parallel=1). Local stages always
	// take the small path.
	smallFileMetas := make([]*fileMetadata, 0)
	largeFileMetas := make([]*fileMetadata, 0)
	for _, meta := range sfa.fileMetadata {
		meta.overwrite = sfa.overwrite
		meta.sfa = sfa
		meta.options = sfa.options
		if sfa.stageLocationType != local {
			sizeThreshold := sfa.options.MultiPartThreshold
			meta.options.MultiPartThreshold = sizeThreshold
			if sfa.commandType == uploadCommand {
				if meta.srcFileSize > sizeThreshold {
					meta.parallel = sfa.parallel
					largeFileMetas = append(largeFileMetas, meta)
				} else {
					meta.parallel = 1
					smallFileMetas = append(smallFileMetas, meta)
				}
			} else {
				// Enable multi-part download for all files to improve performance.
				// The MultiPartThreshold will be passed to the Cloud Storage Provider to determine the part size.
				meta.parallel = sfa.parallel
				largeFileMetas = append(largeFileMetas, meta)
			}
		} else {
			meta.parallel = 1
			smallFileMetas = append(smallFileMetas, meta)
		}
	}
	if sfa.commandType == uploadCommand {
		if err = sfa.upload(largeFileMetas, smallFileMetas); err != nil {
			return err
		}
	} else {
		if err = sfa.download(largeFileMetas); err != nil {
			return err
		}
	}
	return nil
}
srcFile := range sfa.srcFiles { sfa.srcFileToEncryptionMaterial[srcFile] = sfa.encryptionMaterial[i] } } else if len(sfa.encryptionMaterial) != 0 { return exceptionTelemetry(&SnowflakeError{ Number: ErrInternalNotMatchEncryptMaterial, SQLState: sfa.data.SQLState, QueryID: sfa.data.QueryID, Message: errors2.ErrMsgInternalNotMatchEncryptMaterial, MessageArgs: []any{len(sfa.data.SrcLocations), len(sfa.encryptionMaterial)}, }, sfa.sc) } sfa.localLocation, err = expandUser(sfa.data.LocalLocation) if err != nil { return err } if fi, err := os.Stat(sfa.localLocation); err != nil || !fi.IsDir() { return exceptionTelemetry(&SnowflakeError{ Number: ErrLocalPathNotDirectory, SQLState: sfa.data.SQLState, QueryID: sfa.data.QueryID, Message: errors2.ErrMsgLocalPathNotDirectory, MessageArgs: []any{sfa.localLocation}, }, sfa.sc) } } sfa.parallel = 1 if sfa.data.Parallel != 0 { sfa.parallel = sfa.data.Parallel } sfa.overwrite = sfa.data.Overwrite sfa.stageLocationType = cloudType(strings.ToUpper(sfa.data.StageInfo.LocationType)) sfa.stageInfo = &sfa.data.StageInfo sfa.presignedURLs = make([]string, 0) if len(sfa.data.PresignedURLs) != 0 { sfa.presignedURLs = sfa.data.PresignedURLs } if sfa.getStorageClient(sfa.stageLocationType) == nil { return exceptionTelemetry(&SnowflakeError{ Number: ErrInvalidStageFs, SQLState: sfa.data.SQLState, QueryID: sfa.data.QueryID, Message: errors2.ErrMsgInvalidStageFs, MessageArgs: []any{sfa.stageLocationType}, }, sfa.sc) } return nil } func (sfa *snowflakeFileTransferAgent) initEncryptionMaterial() { sfa.encryptionMaterial = make([]*snowflakeFileEncryption, 0) wrapper := sfa.data.EncryptionMaterial if sfa.commandType == uploadCommand { if wrapper.QueryID != "" { sfa.encryptionMaterial = append(sfa.encryptionMaterial, &wrapper.snowflakeFileEncryption) } } else { for _, encmat := range wrapper.EncryptionMaterials { if encmat.QueryID != "" { sfa.encryptionMaterial = append(sfa.encryptionMaterial, &encmat) } } } } func (sfa *snowflakeFileTransferAgent) 
// initFileMetadata builds one fileMetadata entry per source file for the
// current command, attaching encryption material where present. Returns a
// SnowflakeError (via telemetry) when a source file is missing or unreadable.
func (sfa *snowflakeFileTransferAgent) initFileMetadata() error {
	sfa.fileMetadata = []*fileMetadata{}
	switch sfa.commandType {
	case uploadCommand:
		logger.Debugf("upload command initiated - file count: %d, query ID: %s, encryption materials: %d", len(sfa.srcFiles), sfa.data.QueryID, len(sfa.encryptionMaterial))
		if len(sfa.srcFiles) == 0 {
			// NOTE(review): fileName here is the whole []string of source
			// locations, not a single name — it is only used as a message arg.
			fileName := sfa.data.SrcLocations
			return exceptionTelemetry(&SnowflakeError{
				Number:      ErrFileNotExists,
				SQLState:    sfa.data.SQLState,
				QueryID:     sfa.data.QueryID,
				Message:     errors2.ErrMsgFileNotExists,
				MessageArgs: []any{fileName},
			}, sfa.sc)
		}
		// Handles bulk inserts by checking if sourceStream exists.
		// - If the file exists locally (PUT command), it saves the stream without loading it into memory.
		// - If not, treats it as an INSERT converted to PUT for bulk upload.
		if sfa.sourceStream != nil {
			fileName := sfa.srcFiles[0]
			fileInfo, err := os.Stat(fileName)
			if err != nil {
				// Bulk insert case: no local file, so the stream is buffered
				// fully in memory to learn its size.
				buf := new(bytes.Buffer)
				_, err := buf.ReadFrom(sfa.sourceStream)
				if err != nil {
					return exceptionTelemetry(&SnowflakeError{
						Number:      ErrFileNotExists,
						SQLState:    sfa.data.SQLState,
						QueryID:     sfa.data.QueryID,
						Message:     errors2.ErrMsgFailToReadDataFromBuffer,
						MessageArgs: []any{fileName},
					}, sfa.sc)
				}
				sfa.fileMetadata = append(sfa.fileMetadata, &fileMetadata{
					name:              baseName(fileName),
					srcFileName:       fileName,
					srcStream:         buf,
					fileStream:        sfa.sourceStream,
					srcFileSize:       int64(buf.Len()),
					stageLocationType: sfa.stageLocationType,
					stageInfo:         sfa.stageInfo,
				})
			} else {
				// PUT command with existing file: size comes from Stat, the
				// stream itself is not buffered.
				sfa.fileMetadata = append(sfa.fileMetadata, &fileMetadata{
					name:              baseName(fileName),
					srcFileName:       fileName,
					fileStream:        sfa.sourceStream,
					srcFileSize:       fileInfo.Size(),
					stageLocationType: sfa.stageLocationType,
					stageInfo:         sfa.stageInfo,
				})
			}
		} else {
			for i, fileName := range sfa.srcFiles {
				fi, err := os.Stat(fileName)
				if os.IsNotExist(err) {
					return exceptionTelemetry(&SnowflakeError{
						Number:      ErrFileNotExists,
						SQLState:    sfa.data.SQLState,
						QueryID:     sfa.data.QueryID,
						Message:     errors2.ErrMsgFileNotExists,
						MessageArgs: []any{fileName},
					}, sfa.sc)
				} else if fi.IsDir() {
					// Directories cannot be PUT; reported with the same
					// "file not exists" error as a missing file.
					return exceptionTelemetry(&SnowflakeError{
						Number:      ErrFileNotExists,
						SQLState:    sfa.data.SQLState,
						QueryID:     sfa.data.QueryID,
						Message:     errors2.ErrMsgFileNotExists,
						MessageArgs: []any{fileName},
					}, sfa.sc)
				}
				sfa.fileMetadata = append(sfa.fileMetadata, &fileMetadata{
					name:              baseName(fileName),
					srcFileName:       fileName,
					srcFileSize:       fi.Size(),
					stageLocationType: sfa.stageLocationType,
					stageInfo:         sfa.stageInfo,
				})
				// NOTE(review): this per-index assignment is made redundant by
				// the loop after the if/else below, which assigns the same
				// material to every entry.
				if len(sfa.encryptionMaterial) > 0 {
					sfa.fileMetadata[i].encryptionMaterial = sfa.encryptionMaterial[0]
				}
			}
		}
		// For uploads, all files share the single (first) encryption material.
		if len(sfa.encryptionMaterial) > 0 {
			for _, meta := range sfa.fileMetadata {
				meta.encryptionMaterial = sfa.encryptionMaterial[0]
			}
		}
	case downloadCommand:
		logger.Debugf("download command initiated - file count: %d, query ID: %s", len(sfa.srcFiles), sfa.data.QueryID)
		for _, fileName := range sfa.srcFiles {
			if len(fileName) > 0 {
				// The destination file name drops everything up to the first
				// "/" (the stage prefix); unchanged when there is no "/".
				_, after, ok := strings.Cut(fileName, "/")
				dstFileName := fileName
				if ok {
					dstFileName = after
				}
				sfa.fileMetadata = append(sfa.fileMetadata, &fileMetadata{
					name:              baseName(fileName),
					srcFileName:       fileName,
					dstFileName:       dstFileName,
					dstStream:         new(bytes.Buffer),
					stageLocationType: sfa.stageLocationType,
					stageInfo:         sfa.stageInfo,
					localLocation:     sfa.localLocation,
				})
			}
		}
		// Downloads use per-file encryption material keyed by source name.
		for _, meta := range sfa.fileMetadata {
			fileName := meta.srcFileName
			if val, ok := sfa.srcFileToEncryptionMaterial[fileName]; ok {
				meta.encryptionMaterial = val
			}
		}
	}
	return nil
}
currentFileCompressionType = lookupByExtension(mtype.Extension())
			}
			// Auto-detected but unsupported compression is rejected.
			if currentFileCompressionType != nil && !currentFileCompressionType.isSupported {
				return exceptionTelemetry(&SnowflakeError{
					Number:      ErrCompressionNotSupported,
					SQLState:    sfa.data.SQLState,
					QueryID:     sfa.data.QueryID,
					Message:     errors2.ErrMsgFeatureNotSupported,
					MessageArgs: []any{userSpecifiedSourceCompression},
				}, sfa.sc)
			}
		} else {
			currentFileCompressionType = userSpecifiedSourceCompression
		}
		if currentFileCompressionType != nil {
			meta.srcCompressionType = currentFileCompressionType
			if currentFileCompressionType.isSupported {
				// Already compressed: stage as-is, no re-compression.
				meta.dstCompressionType = currentFileCompressionType
				meta.requireCompress = false
				meta.dstFileName = meta.name
			} else {
				return exceptionTelemetry(&SnowflakeError{
					Number:      ErrCompressionNotSupported,
					SQLState:    sfa.data.SQLState,
					QueryID:     sfa.data.QueryID,
					Message:     errors2.ErrMsgFeatureNotSupported,
					MessageArgs: []any{userSpecifiedSourceCompression},
				}, sfa.sc)
			}
		} else {
			// Uncompressed source: gzip on the fly when AUTO_COMPRESS is on.
			meta.requireCompress = sfa.autoCompress
			meta.srcCompressionType = nil
			if sfa.autoCompress {
				dstFileName := meta.name + compressionTypes["GZIP"].fileExtension
				meta.dstFileName = dstFileName
				meta.dstCompressionType = gzipCompression
			} else {
				meta.dstFileName = meta.name
				meta.dstCompressionType = nil
			}
		}
	}
	return nil
}

// updateFileMetadataWithPresignedURL fetches per-file stage info / presigned
// URLs for GCS stages; other stage types are left untouched.
func (sfa *snowflakeFileTransferAgent) updateFileMetadataWithPresignedURL() error {
	// presigned URL only applies to GCS
	if sfa.stageLocationType == gcsClient {
		switch sfa.commandType {
		case uploadCommand:
			// SNOW-3309225 - When a downscoped token is available, the token already covers the entire stage prefix so per-file
			// re-querying is unnecessary. Skipping the extra round-trip also avoids a path mismatch on versioned stages.
			if sfa.stageInfo != nil && sfa.stageInfo.Creds.GcsAccessToken != "" {
				return nil
			}
			filePathToBeReplaced := sfa.getLocalFilePathFromCommand(sfa.command)
			for _, meta := range sfa.fileMetadata {
				// Re-run the PUT with a single concrete file so the server
				// returns stage info (and possibly a presigned URL) for it.
				// NOTE(review): TrimRight treats dstFileName as a character
				// set, not a suffix; it happens to strip the trailing file
				// name here, but verify for paths whose directory part shares
				// trailing characters with dstFileName.
				filePathToBeReplacedWith := strings.TrimRight(filePathToBeReplaced, meta.dstFileName) + meta.dstFileName
				commandWithSingleFile := strings.ReplaceAll(sfa.command, filePathToBeReplaced, filePathToBeReplacedWith)
				req := execRequest{
					SQLText: commandWithSingleFile,
				}
				headers := getHeaders()
				headers[httpHeaderAccept] = headerContentTypeApplicationJSON
				jsonBody, err := json.Marshal(req)
				if err != nil {
					return err
				}
				data, err := sfa.sc.rest.FuncPostQuery(
					sfa.ctx, sfa.sc.rest, &url.Values{}, headers, jsonBody,
					sfa.sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sfa.ctx), sfa.sc.cfg)
				if err != nil {
					return err
				}
				if data.Data.StageInfo != (execResponseStageInfo{}) {
					meta.stageInfo = &data.Data.StageInfo
					meta.presignedURL = nil
					if meta.stageInfo.PresignedURL != "" {
						meta.presignedURL, err = url.Parse(meta.stageInfo.PresignedURL)
						if err != nil {
							return err
						}
					}
				}
			}
		case downloadCommand:
			// Downloads receive a pre-fetched URL list, aligned by index.
			for i, meta := range sfa.fileMetadata {
				if len(sfa.presignedURLs) > 0 {
					var err error
					meta.presignedURL, err = url.Parse(sfa.presignedURLs[i])
					if err != nil {
						return err
					}
				} else {
					meta.presignedURL = nil
				}
			}
		default:
			return exceptionTelemetry(&SnowflakeError{
				Number:      ErrCommandNotRecognized,
				SQLState:    sfa.data.SQLState,
				QueryID:     sfa.data.QueryID,
				Message:     errors2.ErrMsgCommandNotRecognized,
				MessageArgs: []any{sfa.commandType},
			}, sfa.sc)
		}
	}
	return nil
}

// s3BucketAccelerateConfigGetter is the subset of the S3 client used to read
// a bucket's transfer-acceleration configuration (test seam).
type s3BucketAccelerateConfigGetter interface {
	GetBucketAccelerateConfiguration(ctx context.Context, params *s3.GetBucketAccelerateConfigurationInput, optFns ...func(*s3.Options)) (*s3.GetBucketAccelerateConfigurationOutput, error)
}

// s3ClientCreator abstracts S3 location parsing and client construction so
// tests can substitute mocks.
type s3ClientCreator interface {
	extractBucketNameAndPath(location string) (*s3Location, error)
	createClientWithConfig(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config, telemetry
*snowflakeTelemetry) (cloudClient, error)
}

// transferAccelerateConfigWithUtil probes the stage's S3 bucket for transfer
// acceleration and records the result on the agent. Probe errors are only
// logged (acceleration is an optimization, not a requirement).
func (sfa *snowflakeFileTransferAgent) transferAccelerateConfigWithUtil(s3Util s3ClientCreator) error {
	s3Loc, err := s3Util.extractBucketNameAndPath(sfa.stageInfo.Location)
	if err != nil {
		return err
	}
	s3Cli, err := s3Util.createClientWithConfig(sfa.stageInfo, false, sfa.sc.cfg, sfa.sc.telemetry)
	if err != nil {
		return err
	}
	client, ok := s3Cli.(s3BucketAccelerateConfigGetter)
	if !ok {
		return exceptionTelemetry(&SnowflakeError{
			Number:   ErrFailedToConvertToS3Client,
			SQLState: sfa.data.SQLState,
			QueryID:  sfa.data.QueryID,
			Message:  errors2.ErrMsgFailedToConvertToS3Client,
		}, sfa.sc)
	}
	ret, err := withCloudStorageTimeout(sfa.ctx, sfa.sc.cfg, func(ctx context.Context) (*s3.GetBucketAccelerateConfigurationOutput, error) {
		return client.GetBucketAccelerateConfiguration(ctx, &s3.GetBucketAccelerateConfigurationInput{
			Bucket: &s3Loc.bucketName,
		})
	})
	// Set the flag before inspecting err: a nil ret simply leaves it false,
	// and a probe failure is deliberately downgraded to a warning.
	sfa.useAccelerateEndpoint = ret != nil && ret.Status == "Enabled"
	if err != nil {
		logger.WithContext(sfa.sc.ctx).Warnf("An error occurred when getting accelerate config: %v", err)
	}
	return nil
}

// withCloudStorageTimeout runs f under cfg.CloudStorageTimeout when one is
// configured; otherwise it runs f with the caller's context unchanged.
func withCloudStorageTimeout[T any](ctx context.Context, cfg *Config, f func(ctx context.Context) (T, error)) (T, error) {
	if cfg.CloudStorageTimeout > 0 {
		cancelCtx, cancelFunc := context.WithTimeout(ctx, cfg.CloudStorageTimeout)
		defer cancelFunc()
		return f(cancelCtx)
	}
	return f(ctx)
}

// transferAccelerateConfig probes acceleration for S3 stages only.
func (sfa *snowflakeFileTransferAgent) transferAccelerateConfig() error {
	if sfa.stageLocationType == s3Client {
		s3Util := new(snowflakeS3Client)
		return sfa.transferAccelerateConfigWithUtil(s3Util)
	}
	return nil
}

// getLocalFilePathFromCommand extracts the local file path (the part after
// "file://") from a PUT statement, honoring optional single quotes.
func (sfa *snowflakeFileTransferAgent) getLocalFilePathFromCommand(command string) string {
	if len(command) == 0 || !strings.Contains(command, fileProtocol) {
		return ""
	}
	if !regexp.MustCompile(putRegexp).Match([]byte(command)) {
		return ""
	}
	filePathBeginIdx := strings.Index(command, fileProtocol)
	// Safe to look one byte back: the PUT regexp guarantees text precedes
	// the file protocol, so filePathBeginIdx > 0 here.
	isFilePathQuoted := command[filePathBeginIdx-1] == '\''
	filePathBeginIdx += len(fileProtocol)
	var filePathEndIdx int
	filePath := ""
	if isFilePathQuoted {
		// Quoted path: runs up to the closing quote.
		filePathEndIdx = filePathBeginIdx + strings.Index(command[filePathBeginIdx:], "'")
		if filePathEndIdx > filePathBeginIdx {
			filePath = command[filePathBeginIdx:filePathEndIdx]
		}
	} else {
		// Unquoted path: ends at the nearest space, newline, or semicolon,
		// or at end of command when no delimiter follows.
		indexList := make([]int, 0)
		delims := []rune{' ', '\n', ';'}
		for _, delim := range delims {
			index := strings.Index(command[filePathBeginIdx:], string(delim))
			if index != -1 {
				indexList = append(indexList, index)
			}
		}
		filePathEndIdx = -1
		if getMin(indexList) != -1 {
			filePathEndIdx = filePathBeginIdx + getMin(indexList)
		}
		if filePathEndIdx > filePathBeginIdx {
			filePath = command[filePathBeginIdx:filePathEndIdx]
		} else {
			filePath = command[filePathBeginIdx:]
		}
	}
	return filePath
}

// upload stages all files: small files concurrently, large files one at a
// time, sharing a single cloud client across both sets.
func (sfa *snowflakeFileTransferAgent) upload(
	largeFileMetadata []*fileMetadata,
	smallFileMetadata []*fileMetadata) error {
	client, err := sfa.getStorageClient(sfa.stageLocationType).
		createClient(sfa.stageInfo, sfa.useAccelerateEndpoint, sfa.sc.cfg, sfa.sc.telemetry)
	if err != nil {
		return err
	}
	for _, meta := range smallFileMetadata {
		meta.client = client
	}
	for _, meta := range largeFileMetadata {
		meta.client = client
	}
	if len(smallFileMetadata) > 0 {
		logger.WithContext(sfa.sc.ctx).Infof("uploading %v small files", len(smallFileMetadata))
		if err = sfa.uploadFilesParallel(smallFileMetadata); err != nil {
			return err
		}
	}
	if len(largeFileMetadata) > 0 {
		logger.WithContext(sfa.sc.ctx).Infof("uploading %v large files", len(largeFileMetadata))
		if err = sfa.uploadFilesSequential(largeFileMetadata); err != nil {
			return err
		}
	}
	return nil
}

// download fetches all requested files concurrently with a shared client.
func (sfa *snowflakeFileTransferAgent) download(
	fileMetadata []*fileMetadata) error {
	client, err := sfa.getStorageClient(sfa.stageLocationType).
createClient(sfa.stageInfo, sfa.useAccelerateEndpoint, sfa.sc.cfg, nil)
if err != nil {
	return err
}
for _, meta := range fileMetadata {
	meta.client = client
}
logger.WithContext(sfa.sc.ctx).Infof("downloading %v files", len(fileMetadata))
if err = sfa.downloadFilesParallel(fileMetadata); err != nil {
	return err
}
return nil
}

// uploadFilesParallel uploads files in windows of up to sfa.parallel
// goroutines, converting panics into errors and retrying entries whose cloud
// token or presigned URL has expired.
func (sfa *snowflakeFileTransferAgent) uploadFilesParallel(fileMetas []*fileMetadata) error {
	idx := 0
	fileMetaLen := len(fileMetas)
	var err error
	for idx < fileMetaLen {
		// Process at most sfa.parallel files in this window.
		endOfIdx := intMin(fileMetaLen, idx+int(sfa.parallel))
		targetMeta := fileMetas[idx:endOfIdx]
		for {
			var wg sync.WaitGroup
			results := make([]*fileMetadata, len(targetMeta))
			errors := make([]error, len(targetMeta))
			for i, meta := range targetMeta {
				wg.Add(1)
				go func(k int, m *fileMetadata) {
					defer wg.Done()
					// Turn a panic in one upload into an error entry so it
					// cannot take down the whole transfer.
					defer func() {
						if r := recover(); r != nil {
							errors[k] = fmt.Errorf("panic during file upload: %v", r)
							results[k] = nil
						}
					}()
					results[k], errors[k] = sfa.uploadOneFile(m)
				}(i, meta)
			}
			wg.Wait()
			// append errors with no result associated to separate array
			var errorMessages []string
			for i, result := range results {
				if result == nil {
					if errors[i] == nil {
						errorMessages = append(errorMessages, "unknown error")
					} else {
						errorMessages = append(errorMessages, errors[i].Error())
					}
				}
			}
			if errorMessages != nil {
				// sort the error messages to be more deterministic as the goroutines may finish in different order each time
				sort.Strings(errorMessages)
				return fmt.Errorf("errors during file upload:\n%v", strings.Join(errorMessages, "\n"))
			}
			// Split completed results from entries that must be retried with
			// fresh credentials or a fresh presigned URL.
			retryMeta := make([]*fileMetadata, 0)
			for i, result := range results {
				result.errorDetails = errors[i]
				if result.resStatus == renewToken || result.resStatus == renewPresignedURL {
					retryMeta = append(retryMeta, result)
				} else {
					sfa.results = append(sfa.results, result)
				}
			}
			if len(retryMeta) == 0 {
				break
			}
			needRenewToken := false
			for _, result := range retryMeta {
				if result.resStatus == renewToken {
					needRenewToken = true
				}
			}
			if needRenewToken {
				// Refresh the cloud client for the retried entries and for
				// all not-yet-processed files.
				client, err := sfa.renewExpiredClient()
				if err != nil {
					return err
				}
				for _, result := range retryMeta {
					result.client = client
				}
				if endOfIdx < fileMetaLen {
					for i := idx + int(sfa.parallel); i < fileMetaLen; i++ {
						fileMetas[i].client = client
					}
				}
			}
			for _, result := range retryMeta {
				if result.resStatus == renewPresignedURL {
					if err = sfa.updateFileMetadataWithPresignedURL(); err != nil {
						return err
					}
					break
				}
			}
			targetMeta = retryMeta
		}
		if endOfIdx == fileMetaLen {
			break
		}
		idx += int(sfa.parallel)
	}
	return err
}

// uploadFilesSequential uploads files one at a time, renewing credentials or
// presigned URLs in place and retrying the same file when required.
func (sfa *snowflakeFileTransferAgent) uploadFilesSequential(fileMetas []*fileMetadata) error {
	idx := 0
	fileMetaLen := len(fileMetas)
	for idx < fileMetaLen {
		res, err := sfa.uploadOneFile(fileMetas[idx])
		if err != nil {
			return err
		}
		if res.resStatus == renewToken {
			client, err := sfa.renewExpiredClient()
			if err != nil {
				return err
			}
			for i := idx; i < fileMetaLen; i++ {
				fileMetas[i].client = client
			}
			continue
		} else if res.resStatus == renewPresignedURL {
			if err = sfa.updateFileMetadataWithPresignedURL(); err != nil {
				return err
			}
			continue
		}
		sfa.results = append(sfa.results, res)
		idx++
	}
	return nil
}

// uploadOneFile compresses (if required), hashes, encrypts (if required) and
// uploads a single file; a scratch temp dir is created for file-based sources
// and removed on return.
func (sfa *snowflakeFileTransferAgent) uploadOneFile(meta *fileMetadata) (*fileMetadata, error) {
	meta.realSrcFileName = meta.srcFileName
	tmpDir := ""
	if meta.fileStream == nil {
		var err error
		tmpDir, err = os.MkdirTemp(sfa.sc.cfg.TmpDirPath, "")
		if err != nil {
			return nil, err
		}
		meta.tmpDir = tmpDir
	}
	defer func() {
		if err := os.RemoveAll(tmpDir); err != nil {
			logger.WithContext(sfa.sc.ctx).Warnf("failed to remove temp dir %v: %v", tmpDir, err)
		}
	}()
	fileUtil := new(snowflakeFileUtil)
	err := compressDataIfRequired(meta, fileUtil, tmpDir)
	if err != nil {
		return meta, err
	}
	err = updateUploadSize(meta, fileUtil)
	if err != nil {
		return meta, err
	}
	err = encryptDataIfRequired(meta, sfa.stageLocationType)
	if err != nil {
		return meta, err
	}
	client := sfa.getStorageClient(sfa.stageLocationType)
	if err = client.uploadOneFileWithRetry(sfa.ctx, meta); err != nil {
		return meta,
err } return meta, nil } func (sfa *snowflakeFileTransferAgent) downloadFilesParallel(fileMetas []*fileMetadata) error { idx := 0 fileMetaLen := len(fileMetas) var err error for idx < fileMetaLen { endOfIdx := intMin(fileMetaLen, idx+int(sfa.parallel)) targetMeta := fileMetas[idx:endOfIdx] for { var wg sync.WaitGroup results := make([]*fileMetadata, len(targetMeta)) errors := make([]error, len(targetMeta)) for i, meta := range targetMeta { wg.Add(1) go func(k int, m *fileMetadata) { defer wg.Done() defer func() { if r := recover(); r != nil { errors[k] = fmt.Errorf("panic during file download: %v", r) results[k] = nil } }() results[k], errors[k] = sfa.downloadOneFile(sfa.ctx, m) }(i, meta) } wg.Wait() retryMeta := make([]*fileMetadata, 0) for i, result := range results { result.errorDetails = errors[i] if result.resStatus == renewToken || result.resStatus == renewPresignedURL { retryMeta = append(retryMeta, result) } else { sfa.results = append(sfa.results, result) } } if len(retryMeta) == 0 { break } logger.WithContext(sfa.sc.ctx).Infof("%v retries found", len(retryMeta)) needRenewToken := false for _, result := range retryMeta { if result.resStatus == renewToken { needRenewToken = true } logger.WithContext(sfa.sc.ctx).Infof( "retying download file %v with status %v", result.name, result.resStatus) } if needRenewToken { client, err := sfa.renewExpiredClient() if err != nil { return err } for _, result := range retryMeta { result.client = client } if endOfIdx < fileMetaLen { for i := idx + int(sfa.parallel); i < fileMetaLen; i++ { fileMetas[i].client = client } } } for _, result := range retryMeta { if result.resStatus == renewPresignedURL { if err = sfa.updateFileMetadataWithPresignedURL(); err != nil { return err } break } } targetMeta = retryMeta } if endOfIdx == fileMetaLen { break } idx += int(sfa.parallel) } return err } func (sfa *snowflakeFileTransferAgent) downloadOneFile(ctx context.Context, meta *fileMetadata) (*fileMetadata, error) { if 
!isFileGetStream(ctx) {
	// Not streaming to the caller: download into a scratch temp dir that is
	// removed when this function returns.
	tmpDir, err := os.MkdirTemp(sfa.sc.cfg.TmpDirPath, "")
	if err != nil {
		return meta, err
	}
	meta.tmpDir = tmpDir
	defer func() {
		if err = os.RemoveAll(tmpDir); err != nil {
			logger.WithContext(sfa.sc.ctx).Warnf("failed to remove temp dir %v: %v", tmpDir, err)
		}
	}()
}
client := sfa.getStorageClient(sfa.stageLocationType)
if err := client.downloadOneFile(ctx, meta); err != nil {
	// Record failure details on the metadata so result() can report them.
	meta.dstFileSize = -1
	if !meta.resStatus.isSet() {
		meta.resStatus = errStatus
	}
	meta.errorDetails = errors.New(err.Error() + ", file=" + meta.dstFileName)
	return meta, err
}
return meta, nil
}

// getStorageClient returns the storage helper matching the stage location
// type, or nil for an unknown type.
func (sfa *snowflakeFileTransferAgent) getStorageClient(stageLocationType cloudType) storageUtil {
	switch stageLocationType {
	case local:
		return &localUtil{}
	case s3Client, azureClient, gcsClient:
		return &remoteStorageUtil{
			cfg:       sfa.sc.cfg,
			telemetry: sfa.sc.telemetry,
		}
	default:
		return nil
	}
}

// renewExpiredClient re-executes the original PUT/GET command to obtain fresh
// stage credentials and builds a new cloud client from them.
func (sfa *snowflakeFileTransferAgent) renewExpiredClient() (cloudClient, error) {
	data, err := sfa.sc.exec(
		sfa.ctx,
		sfa.command,
		false, false, false,
		[]driver.NamedValue{})
	if err != nil {
		return nil, err
	}
	storageClient := sfa.getStorageClient(sfa.stageLocationType)
	return storageClient.createClient(&data.Data.StageInfo, sfa.useAccelerateEndpoint, sfa.sc.cfg, nil)
}

// result renders the accumulated per-file transfer results as a synthetic
// JSON result set shaped like what the server returns for PUT/GET. Any file
// that failed aborts with an upload/download stage error.
func (sfa *snowflakeFileTransferAgent) result() (*execResponse, error) {
	// inherit old response data
	data := sfa.data
	rowset := make([]fileTransferResultType, 0)
	if sfa.commandType == uploadCommand {
		if len(sfa.results) > 0 {
			for _, meta := range sfa.results {
				// Missing compression types are reported as "NONE".
				var srcCompressionType, dstCompressionType *compressionType
				if meta.srcCompressionType != nil {
					srcCompressionType = meta.srcCompressionType
				} else {
					srcCompressionType = &compressionType{
						name: "NONE",
					}
				}
				if meta.dstCompressionType != nil {
					dstCompressionType = meta.dstCompressionType
				} else {
					dstCompressionType = &compressionType{
						name: "NONE",
					}
				}
				errorDetails := meta.errorDetails
				srcFileSize := meta.srcFileSize
				dstFileSize := meta.dstFileSize
				if errorDetails != nil {
					return nil, exceptionTelemetry(&SnowflakeError{
						Number:   ErrFailedToUploadToStage,
						SQLState: sfa.data.SQLState,
						QueryID:  sfa.data.QueryID,
						Message:  errorDetails.Error(),
					}, sfa.sc)
				}
				rowset = append(rowset, fileTransferResultType{
					meta.name,
					meta.srcFileName,
					meta.dstFileName,
					srcFileSize,
					dstFileSize,
					srcCompressionType,
					dstCompressionType,
					meta.resStatus,
					meta.errorDetails,
				})
			}
			// Deterministic ordering for clients.
			sort.Slice(rowset, func(i, j int) bool {
				return rowset[i].srcFileName < rowset[j].srcFileName
			})
			ccrs := make([][]*string, 0, len(rowset))
			for _, rs := range rowset {
				srcFileSize := fmt.Sprintf("%v", rs.srcFileSize)
				dstFileSize := fmt.Sprintf("%v", rs.dstFileSize)
				resStatus := rs.resStatus.String()
				errorStr := ""
				if rs.errorDetails != nil {
					errorStr = rs.errorDetails.Error()
				}
				ccrs = append(ccrs, []*string{
					&rs.srcFileName,
					&rs.dstFileName,
					&srcFileSize,
					&dstFileSize,
					&rs.srcCompressionType.name,
					&rs.dstCompressionType.name,
					&resStatus,
					&errorStr,
				})
			}
			data.RowSet = ccrs
			// NOTE(review): cc is populated but never referenced afterwards —
			// confirm populateJSONRowSet has no other needed side effect.
			cc := make([]chunkRowType, len(ccrs))
			populateJSONRowSet(cc, ccrs)
			data.QueryResultFormat = "json"
			rt := []query.ExecResponseRowType{
				{Name: "source", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
				{Name: "target", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
				{Name: "source_size", ByteLength: 64, Length: 64, Type: "FIXED", Scale: 0, Nullable: false},
				{Name: "target_size", ByteLength: 64, Length: 64, Type: "FIXED", Scale: 0, Nullable: false},
				{Name: "source_compression", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
				{Name: "target_compression", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
				{Name: "status", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
				{Name: "message", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
			}
			data.RowType = rt
			return &execResponse{Data: *data, Success: true}, nil
		}
	} else { // DOWNLOAD
		if len(sfa.results) > 0 {
			for _, meta := range sfa.results {
				dstFileSize := meta.dstFileSize
				errorDetails := meta.errorDetails
				if errorDetails != nil {
					return nil, exceptionTelemetry(&SnowflakeError{
						Number:   ErrFailedToDownloadFromStage,
						SQLState: sfa.data.SQLState,
						QueryID:  sfa.data.QueryID,
						Message:  errorDetails.Error(),
					}, sfa.sc)
				}
				rowset = append(rowset, fileTransferResultType{
					"",
					"",
					meta.dstFileName,
					0,
					dstFileSize,
					nil,
					nil,
					meta.resStatus,
					meta.errorDetails,
				})
			}
			sort.Slice(rowset, func(i, j int) bool {
				return rowset[i].srcFileName < rowset[j].srcFileName
			})
			ccrs := make([][]*string, 0, len(rowset))
			for _, rs := range rowset {
				dstFileSize := fmt.Sprintf("%v", rs.dstFileSize)
				resStatus := rs.resStatus.String()
				errorStr := ""
				if rs.errorDetails != nil {
					errorStr = rs.errorDetails.Error()
				}
				ccrs = append(ccrs, []*string{
					&rs.dstFileName,
					&dstFileSize,
					&resStatus,
					&errorStr,
				})
			}
			data.RowSet = ccrs
			// NOTE(review): same as above — cc appears unused after this call.
			cc := make([]chunkRowType, len(ccrs))
			populateJSONRowSet(cc, ccrs)
			data.QueryResultFormat = "json"
			rt := []query.ExecResponseRowType{
				{Name: "file", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
				{Name: "size", ByteLength: 64, Length: 64, Type: "FIXED", Scale: 0, Nullable: false},
				{Name: "status", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
				{Name: "message", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
			}
			data.RowType = rt
			return &execResponse{Data: *data, Success: true}, nil
		}
	}
	return nil, exceptionTelemetry(&SnowflakeError{
		Number:   ErrNotImplemented,
		SQLState: sfa.data.SQLState,
		QueryID:  sfa.data.QueryID,
		Message:  errors2.ErrMsgNotImplemented,
	}, sfa.sc)
}

// isFileTransfer reports whether query is a PUT or GET statement.
func isFileTransfer(query string) bool {
	putRe := regexp.MustCompile(putRegexp)
	getRe := regexp.MustCompile(getRegexp)
	return putRe.Match([]byte(query)) || getRe.Match([]byte(query))
}

// snowflakeProgressPercentage tracks and renders transfer progress for a
// single file.
type snowflakeProgressPercentage struct {
	filename        string
	fileSize        float64
	outputStream    *io.Writer
	showProgressBar bool
	seenSoFar       int64
	done            bool
	startTime       time.Time
}

func
(spp *snowflakeProgressPercentage) call(bytesAmount int64) {
	// call accumulates transferred bytes and re-renders progress until done.
	if spp.outputStream != nil {
		spp.seenSoFar += bytesAmount
		percentage := spp.percent(spp.seenSoFar, spp.fileSize)
		if !spp.done {
			spp.done = spp.updateProgress(spp.filename, spp.startTime, spp.fileSize, percentage, spp.outputStream, spp.showProgressBar)
		}
	}
}

// percent returns transfer progress in [0, 1]; a non-positive size or having
// seen everything counts as complete.
func (spp *snowflakeProgressPercentage) percent(seenSoFar int64, size float64) float64 {
	if float64(seenSoFar) >= size || size <= 0 {
		return 1.0
	}
	return float64(seenSoFar) / size
}

// updateProgress writes a one-line progress bar to outputStream and reports
// whether the transfer has completed.
func (spp *snowflakeProgressPercentage) updateProgress(filename string, startTime time.Time, totalSize float64, progress float64, outputStream *io.Writer, showProgressBar bool) bool {
	barLength := 10
	totalSize /= mb
	status := ""
	elapsedTime := time.Since(startTime)
	var throughput float64
	if elapsedTime != 0.0 {
		throughput = totalSize / elapsedTime.Seconds()
	}
	if progress < 0 {
		progress = 0
		status = "Halt...\r\n"
	}
	if progress >= 1 {
		status = fmt.Sprintf("Done (%.3fs, %.2fMB/s)", elapsedTime.Seconds(), throughput)
	}
	if status == "" && showProgressBar {
		// fixed format-string typo: "%.3fsm" -> "%.3fs," to match the
		// "Done (%.3fs, %.2fMB/s)" branch above
		status = fmt.Sprintf("(%.3fs, %.2fMB/s)", elapsedTime.Seconds(), throughput)
	}
	if status != "" {
		block := int(math.Round(float64(barLength) * progress))
		text := fmt.Sprintf("\r%v(%.2fMB): [%v] %.2f%% %v ",
			filename, totalSize,
			strings.Repeat("#", block)+strings.Repeat("-", barLength-block),
			progress*100, status)
		_, err := (*outputStream).Write([]byte(text))
		if err != nil {
			logger.Warnf("cannot write status of progress. 
%v", err)
		}
	}
	return progress == 1.0
}

// compressDataIfRequired gzips the source (stream or file) when the transfer
// plan marked it for compression; otherwise it is a no-op.
func compressDataIfRequired(meta *fileMetadata, fileUtil *snowflakeFileUtil, tmpDir string) error {
	var err error
	if meta.requireCompress {
		if meta.srcStream != nil {
			meta.realSrcStream, _, err = fileUtil.compressFileWithGzipFromStream(&meta.srcStream)
		} else {
			meta.realSrcFileName, _, err = fileUtil.compressFileWithGzip(meta.srcFileName, tmpDir)
		}
	}
	return err
}

// updateUploadSize computes the SHA-256 digest and final upload size from the
// stream or from the (possibly compressed) source file.
func updateUploadSize(meta *fileMetadata, fileUtil *snowflakeFileUtil) error {
	var err error
	if meta.fileStream != nil {
		meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(meta.fileStream)
	} else {
		meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForFile(meta.realSrcFileName)
	}
	return err
}

// encryptDataIfRequired applies client-side CBC encryption for remote stages
// when encryption material is present; local stages pass through untouched.
func encryptDataIfRequired(meta *fileMetadata, ct cloudType) error {
	if ct != local && meta.encryptionMaterial != nil {
		var err error
		if meta.srcStream != nil {
			var encryptedStream bytes.Buffer
			// Prefer the compressed stream when one was produced.
			srcStream := cmp.Or(meta.realSrcStream, meta.srcStream)
			meta.encryptMeta, err = encryptStreamCBC(meta.encryptionMaterial, srcStream, &encryptedStream, 0)
			if err != nil {
				return err
			}
			meta.realSrcStream = &encryptedStream
		} else {
			var dataFile string
			meta.encryptMeta, dataFile, err = encryptFileCBC(meta.encryptionMaterial, meta.realSrcFileName, 0, meta.tmpDir)
			if err != nil {
				return err
			}
			meta.realSrcFileName = dataFile
		}
	}
	return nil
}

================================================ FILE: file_transfer_agent_test.go ================================================

package gosnowflake

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/smithy-go"
)

// tcFilePath is a PUT-command → expected-local-path test case.
type tcFilePath struct {
	command string
	path    string
}

func TestGetBucketAccelerateConfiguration(t *testing.T) {
	if runningOnGithubAction() {
		t.Skip("Should be run against an account in AWS EU North1 region.")
	}
	runSnowflakeConnTest(t, func(sct *SCTest) {
		sfa := &snowflakeFileTransferAgent{
			ctx:         context.Background(),
			sc:          sct.sc,
			commandType: uploadCommand,
			srcFiles:    make([]string, 0),
			data: &execResponseData{
				SrcLocations: make([]string, 0),
			},
		}
		// A 405 (MethodNotAllowed) from the accelerate probe must be ignored.
		if err := sfa.transferAccelerateConfig(); err != nil {
			var ae smithy.APIError
			if errors.As(err, &ae) {
				if ae.ErrorCode() == "MethodNotAllowed" {
					t.Fatalf("should have ignored 405 error: %v", err)
				}
			}
		}
	})
}

// s3ClientCreatorMock lets tests stub bucket parsing and client creation.
type s3ClientCreatorMock struct {
	extract func(string) (*s3Location, error)
	create  func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config, telemetry *snowflakeTelemetry) (cloudClient, error)
}

func (mock *s3ClientCreatorMock) extractBucketNameAndPath(location string) (*s3Location, error) {
	return mock.extract(location)
}

func (mock *s3ClientCreatorMock) createClientWithConfig(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config, telemetry *snowflakeTelemetry) (cloudClient, error) {
	return mock.create(info, useAccelerateEndpoint, cfg, telemetry)
}

// s3BucketAccelerateConfigGetterMock always fails with the configured error.
type s3BucketAccelerateConfigGetterMock struct {
	err error
}

func (mock *s3BucketAccelerateConfigGetterMock) GetBucketAccelerateConfiguration(ctx context.Context, params *s3.GetBucketAccelerateConfigurationInput, optFns ...func(*s3.Options)) (*s3.GetBucketAccelerateConfigurationOutput, error) {
	return nil, mock.err
}

func TestGetBucketAccelerateConfigurationTooManyRetries(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		// Capture warn-level logs to assert the probe failure is only logged.
		buf := &bytes.Buffer{}
		logger.SetOutput(buf)
		err := logger.SetLogLevel("warn")
		if err != nil {
			return
		}
		sfa := &snowflakeFileTransferAgent{
			ctx:         context.Background(),
			sc:          sct.sc,
			commandType: uploadCommand,
			srcFiles:    make([]string, 0),
			data: &execResponseData{
				SrcLocations: make([]string, 0),
			},
			stageInfo: &execResponseStageInfo{
				Location: "test",
			},
		}
		err = sfa.transferAccelerateConfigWithUtil(&s3ClientCreatorMock{
			extract: func(s string) (*s3Location, error) {
				return &s3Location{bucketName: "test", s3Path: "test"}, nil
			},
			create: func(info *execResponseStageInfo,
useAccelerateEndpoint bool, cfg *Config, _ *snowflakeTelemetry) (cloudClient, error) {
				return &s3BucketAccelerateConfigGetterMock{err: errors.New("testing")}, nil
			},
		})
		// The probe error must be swallowed and surfaced only as a warning.
		assertNilE(t, err)
		assertStringContainsE(t, buf.String(), "msg=\"An error occurred when getting accelerate config: testing\"")
	})
}

func TestGetBucketAccelerateConfigurationFailedExtractBucketNameAndPath(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		sfa := &snowflakeFileTransferAgent{
			ctx:         context.Background(),
			sc:          sct.sc,
			commandType: uploadCommand,
			srcFiles:    make([]string, 0),
			data: &execResponseData{
				SrcLocations: make([]string, 0),
			},
			stageInfo: &execResponseStageInfo{
				Location: "test",
			},
		}
		// Location parsing errors must propagate.
		err := sfa.transferAccelerateConfigWithUtil(&s3ClientCreatorMock{
			extract: func(s string) (*s3Location, error) {
				return nil, errors.New("failed extraction")
			},
		})
		assertNotNilE(t, err)
	})
}

func TestGetBucketAccelerateConfigurationFailedCreateClient(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		sfa := &snowflakeFileTransferAgent{
			ctx:         context.Background(),
			sc:          sct.sc,
			commandType: uploadCommand,
			srcFiles:    make([]string, 0),
			data: &execResponseData{
				SrcLocations: make([]string, 0),
			},
			stageInfo: &execResponseStageInfo{
				Location: "test",
			},
		}
		// Client construction errors must propagate.
		err := sfa.transferAccelerateConfigWithUtil(&s3ClientCreatorMock{
			extract: func(s string) (*s3Location, error) {
				return &s3Location{bucketName: "test", s3Path: "test"}, nil
			},
			create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config, _ *snowflakeTelemetry) (cloudClient, error) {
				return nil, errors.New("failed creation")
			},
		})
		assertNotNilE(t, err)
	})
}

func TestGetBucketAccelerateConfigurationInvalidClient(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		sfa := &snowflakeFileTransferAgent{
			ctx:         context.Background(),
			sc:          sct.sc,
			commandType: uploadCommand,
			srcFiles:    make([]string, 0),
			data: &execResponseData{
				SrcLocations: make([]string, 0),
			},
			stageInfo: &execResponseStageInfo{
				Location: "test",
			},
		}
		// A client that is not an s3BucketAccelerateConfigGetter must error.
		err := sfa.transferAccelerateConfigWithUtil(&s3ClientCreatorMock{
			extract: func(s string) (*s3Location, error) {
				return &s3Location{bucketName: "test", s3Path: "test"}, nil
			},
			create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config, _ *snowflakeTelemetry) (cloudClient, error) {
				return 1, nil
			},
		})
		assertNotNilE(t, err)
	})
}

func TestUnitDownloadWithInvalidLocalPath(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "data")
	if err != nil {
		t.Error(err)
	}
	defer func() {
		assertNilF(t, os.RemoveAll(tmpDir))
	}()
	testData := filepath.Join(tmpDir, "data.txt")
	f, err := os.Create(testData)
	if err != nil {
		t.Error(err)
	}
	_, err = f.WriteString("test1,test2\ntest3,test4\n")
	assertNilF(t, err)
	assertNilF(t, f.Close())
	runDBTest(t, func(dbt *DBTest) {
		if _, err = dbt.exec("use role sysadmin"); err != nil {
			t.Skip("snowflake admin account not accessible")
		}
		dbt.mustExec("rm @~/test_get")
		sqlText := fmt.Sprintf("put file://%v @~/test_get", testData)
		sqlText = strings.ReplaceAll(sqlText, "\\", "\\\\")
		dbt.mustExec(sqlText)
		// GET into a backslash-suffixed (invalid) local path must fail.
		sqlText = fmt.Sprintf("get @~/test_get/data.txt file://%v\\get", tmpDir)
		if _, err = dbt.query(sqlText); err == nil {
			t.Fatalf("should return local path not directory error.")
		}
		dbt.mustExec("rm @~/test_get")
	})
}

func TestUnitGetLocalFilePathFromCommand(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		sfa := &snowflakeFileTransferAgent{
			ctx:         context.Background(),
			sc:          sct.sc,
			commandType: uploadCommand,
			srcFiles:    make([]string, 0),
			data: &execResponseData{
				SrcLocations: make([]string, 0),
			},
		}
		testcases := []tcFilePath{
			{"PUT file:///tmp/my_data_file.txt @~ overwrite=true", "/tmp/my_data_file.txt"},
			{"PUT 'file:///tmp/my_data_file.txt' @~ overwrite=true", "/tmp/my_data_file.txt"},
			{"PUT file:///tmp/sub_dir/my_data_file.txt\n @~ overwrite=true", "/tmp/sub_dir/my_data_file.txt"},
			{"PUT file:///tmp/my_data_file.txt @~ overwrite=true", "/tmp/my_data_file.txt"},
			{"", ""},
			{"PUT 'file2:///tmp/my_data_file.txt' @~ overwrite=true", ""},
		}
		for
_, test := range testcases {
			t.Run(test.command, func(t *testing.T) {
				path := sfa.getLocalFilePathFromCommand(test.command)
				if path != test.path {
					t.Fatalf("unexpected file path. expected: %v, but got: %v", test.path, path)
				}
			})
		}
	})
}

func TestUnitProcessFileCompressionType(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		sfa := &snowflakeFileTransferAgent{
			ctx:         context.Background(),
			sc:          sct.sc,
			commandType: uploadCommand,
			srcFiles:    make([]string, 0),
		}
		// Valid SOURCE_COMPRESSION values must be accepted.
		testcases := []struct {
			srcCompression string
		}{
			{"none"},
			{"auto_detect"},
			{"gzip"},
		}
		for _, test := range testcases {
			t.Run(test.srcCompression, func(t *testing.T) {
				sfa.srcCompression = test.srcCompression
				err := sfa.processFileCompressionType()
				if err != nil {
					t.Fatalf("failed to process file compression")
				}
			})
		}
		// test invalid compression type error
		sfa.srcCompression = "gz"
		data := &execResponseData{
			SQLState: "S00087",
			QueryID:  "01aa2e8b-0405-ab7c-0000-53b10632f626",
		}
		sfa.data = data
		err := sfa.processFileCompressionType()
		if err == nil {
			t.Fatal("should have failed")
		}
		driverErr, ok := err.(*SnowflakeError)
		if !ok {
			t.Fatalf("should be snowflake error. err: %v", err)
		}
		if driverErr.Number != ErrCompressionNotSupported {
			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCompressionNotSupported, driverErr.Number)
		}
	})
}

func TestParseCommandWithInvalidStageLocation(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		sfa := &snowflakeFileTransferAgent{
			ctx:         context.Background(),
			sc:          sct.sc,
			commandType: uploadCommand,
			srcFiles:    make([]string, 0),
			data: &execResponseData{
				SrcLocations: make([]string, 0),
			},
		}
		err := sfa.parseCommand()
		if err == nil {
			t.Fatal("should have raised an error")
		}
		driverErr, ok := err.(*SnowflakeError)
		if !ok || driverErr.Number != ErrInvalidStageLocation {
			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrInvalidStageLocation, driverErr.Number)
		}
	})
}

func TestParseCommandEncryptionMaterialMismatchError(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		// Two materials for one source location must be rejected.
		mockEncMaterial1 := snowflakeFileEncryption{
			QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==",
			QueryID:             "01abc874-0406-1bf0-0000-53b10668e056",
			SMKID:               92019681909886,
		}
		mockEncMaterial2 := snowflakeFileEncryption{
			QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==",
			QueryID:             "01abc874-0406-1bf0-0000-53b10668e056",
			SMKID:               92019681909886,
		}
		sfa := &snowflakeFileTransferAgent{
			ctx:         context.Background(),
			sc:          sct.sc,
			commandType: uploadCommand,
			srcFiles:    make([]string, 0),
			data: &execResponseData{
				SrcLocations: []string{"/tmp/uploads"},
				EncryptionMaterial: encryptionWrapper{
					snowflakeFileEncryption: mockEncMaterial1,
					EncryptionMaterials:     []snowflakeFileEncryption{mockEncMaterial1, mockEncMaterial2},
				},
			},
		}
		err := sfa.parseCommand()
		if err == nil {
			t.Fatal("should have raised an error")
		}
		driverErr, ok := err.(*SnowflakeError)
		if !ok || driverErr.Number != ErrInternalNotMatchEncryptMaterial {
			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrInternalNotMatchEncryptMaterial, driverErr.Number)
		}
	})
}

func TestParseCommandInvalidStorageClientException(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		// NOTE(review): tmpDir is never removed — a `defer os.RemoveAll(tmpDir)`
		// appears to be missing here, unlike sibling tests.
		tmpDir, err := os.MkdirTemp("", "abc")
		if err != nil {
			t.Error(err)
		}
		mockEncMaterial1 := snowflakeFileEncryption{
			QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==",
			QueryID:             "01abc874-0406-1bf0-0000-53b10668e056",
			SMKID:               92019681909886,
		}
		sfa := &snowflakeFileTransferAgent{
			ctx:         context.Background(),
			sc:          sct.sc,
			commandType: uploadCommand,
			srcFiles:    make([]string, 0),
			data: &execResponseData{
				SrcLocations:  []string{"/tmp/uploads"},
				LocalLocation: tmpDir,
				EncryptionMaterial: encryptionWrapper{
					snowflakeFileEncryption: mockEncMaterial1,
					EncryptionMaterials:     []snowflakeFileEncryption{mockEncMaterial1},
				},
			},
		}
		err = sfa.parseCommand()
		if err == nil {
			t.Fatal("should have raised an error")
		}
		driverErr, ok := err.(*SnowflakeError)
		if !ok || driverErr.Number != ErrInvalidStageFs {
			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrInvalidStageFs, driverErr.Number)
		}
	})
}

func TestInitFileMetadataError(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		sfa := &snowflakeFileTransferAgent{
			ctx:         context.Background(),
			sc:          sct.sc,
			commandType: uploadCommand,
			srcFiles:    []string{"fileDoesNotExist.txt"},
			data: &execResponseData{
				SQLState: "123456",
				QueryID:  "01aa2e8b-0405-ab7c-0000-53b10632f626",
			},
		}
		err := sfa.initFileMetadata()
		if err == nil {
			t.Fatal("should have raised an error")
		}
		driverErr, ok := err.(*SnowflakeError)
		if !ok || driverErr.Number != ErrFileNotExists {
			t.Fatalf("unexpected error code. 
expected: %v, got: %v", ErrFileNotExists, driverErr.Number) } tmpDir, err := os.MkdirTemp("", "data") if err != nil { t.Error(err) } defer os.RemoveAll(tmpDir) sfa.srcFiles = []string{tmpDir} err = sfa.initFileMetadata() if err == nil { t.Fatal("should have raised an error") } driverErr, ok = err.(*SnowflakeError) if !ok || driverErr.Number != ErrFileNotExists { t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFileNotExists, driverErr.Number) } }) } func TestUpdateMetadataWithPresignedUrl(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { info := execResponseStageInfo{ Location: "gcs-blob/storage/users/456/", LocationType: "GCS", } dir, err := os.Getwd() if err != nil { t.Error(err) } testURL := "https://storage.google.com/gcs-blob/storage/users/456?Signature=testsignature123" presignedURLMock := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, requestID UUID, _ *Config) (*execResponse, error) { // ensure the same requestID from context is used if len(requestID) == 0 { t.Fatal("requestID is empty") } dd := &execResponseData{ QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", Command: string(uploadCommand), StageInfo: execResponseStageInfo{ LocationType: "GCS", Location: "gcspuscentral1-4506459564-stage/users/456", Path: "users/456", Region: "US_CENTRAL1", PresignedURL: testURL, }, } return &execResponse{ Data: *dd, Message: "", Code: "0", Success: true, }, nil } gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "GCS", noSleepingTime: true, client: gcsCli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, "/test_data/data1.txt"), overwrite: true, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, } sct.sc.rest.FuncPostQuery = presignedURLMock sfa := 
&snowflakeFileTransferAgent{ ctx: context.Background(), sc: sct.sc, commandType: uploadCommand, command: "put file:///tmp/test_data/data1.txt @~", stageLocationType: gcsClient, fileMetadata: []*fileMetadata{&uploadMeta}, } err = sfa.updateFileMetadataWithPresignedURL() if err != nil { t.Error(err) } if testURL != sfa.fileMetadata[0].presignedURL.String() { t.Fatalf("failed to update metadata with presigned url. expected: %v. got: %v", testURL, sfa.fileMetadata[0].presignedURL.String()) } }) } func TestUpdateMetadataWithPresignedUrlForDownload(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { info := execResponseStageInfo{ Location: "gcs-blob/storage/users/456/", LocationType: "GCS", } dir, err := os.Getwd() if err != nil { t.Error(err) } testURL := "https://storage.google.com/gcs-blob/storage/users/456?Signature=testsignature123" gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } downloadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "GCS", noSleepingTime: true, client: gcsCli, stageInfo: &info, dstFileName: "data1.txt.gz", overwrite: true, srcFileName: "data1.txt.gz", localLocation: dir, } sfa := &snowflakeFileTransferAgent{ ctx: context.Background(), sc: sct.sc, commandType: downloadCommand, command: "get @~/data1.txt.gz file:///tmp/testData", stageLocationType: gcsClient, fileMetadata: []*fileMetadata{&downloadMeta}, presignedURLs: []string{testURL}, } err = sfa.updateFileMetadataWithPresignedURL() if err != nil { t.Error(err) } if testURL != sfa.fileMetadata[0].presignedURL.String() { t.Fatalf("failed to update metadata with presigned url. expected: %v. 
got: %v", testURL, sfa.fileMetadata[0].presignedURL.String()) } }) } func TestUpdateMetadataWithPresignedUrlError(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { sfa := &snowflakeFileTransferAgent{ ctx: context.Background(), sc: sct.sc, command: "get @~/data1.txt.gz file:///tmp/testData", stageLocationType: gcsClient, data: &execResponseData{ SQLState: "123456", QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", }, } err := sfa.updateFileMetadataWithPresignedURL() if err == nil { t.Fatal("should have raised an error") } driverErr, ok := err.(*SnowflakeError) if !ok || driverErr.Number != ErrCommandNotRecognized { t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCommandNotRecognized, driverErr.Number) } }) } func TestUpdateMetadataSkipsSecondQueryWithGcsDownscopedToken(t *testing.T) { info := execResponseStageInfo{ Location: "gcs-blob/storage/users/456/", LocationType: "GCS", Creds: execResponseCredentials{ GcsAccessToken: "ya29.downscoped-token-test", }, } dir, err := os.Getwd() assertNilF(t, err, fmt.Sprintf("os.Getwd was unsuccessful, error: %v", err)) postQueryCalled := false presignedURLMock := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, _ UUID, _ *Config) (*execResponse, error) { postQueryCalled = true t.Fatal("FuncPostQuery should not be called when a downscoped token is present") return nil, nil } gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{}) assertNilF(t, err, fmt.Sprintf("could not create gcsCli, error: %v", err)) uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "GCS", noSleepingTime: true, client: gcsCli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: filepath.Join(dir, "test_data", "data1.txt"), overwrite: true, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, } sr := &snowflakeRestful{ FuncPostQuery: presignedURLMock, } sfa := 
&snowflakeFileTransferAgent{
		ctx: context.Background(),
		sc: &snowflakeConn{
			cfg:  &Config{},
			rest: sr,
		},
		commandType:       uploadCommand,
		command:           "put file:///tmp/test_data/data1.txt @~",
		stageLocationType: gcsClient,
		stageInfo:         &info,
		fileMetadata:      []*fileMetadata{&uploadMeta},
	}
	err = sfa.updateFileMetadataWithPresignedURL()
	// With a downscoped GCS token present, no extra round trip to GS may happen.
	assertNilF(t, err, fmt.Sprintf("unexpected error in updateFileMetadataWithPresignedURL, error: %v", err))
	assertFalseF(t, postQueryCalled, "should not have issued a second query when downscoped token is available")
	assertEqualF(t, uploadMeta.stageInfo, &info, "stageInfo on metadata should remain unchanged")
}

// TestUpdateMetadataStillQueriesWithPresignedUrlOnGcs verifies the
// complementary presigned-URL flow: when the GCS stage info carries no
// downscoped access token, updateFileMetadataWithPresignedURL must issue a
// second query (FuncPostQuery) and store the returned presigned URL on each
// file's metadata.
func TestUpdateMetadataStillQueriesWithPresignedUrlOnGcs(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs-blob/storage/users/456/",
		LocationType: "GCS",
	}
	dir, err := os.Getwd()
	assertNilF(t, err, fmt.Sprintf("os.Getwd was unsuccessful, error: %v", err))
	testURL := "https://storage.google.com/gcs-blob/storage/users/456?Signature=testsignature456"
	postQueryCalled := false
	// Mocked GS response carrying the presigned URL; also records that the
	// second query was actually issued.
	presignedURLMock := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, _ UUID, _ *Config) (*execResponse, error) {
		postQueryCalled = true
		dd := &execResponseData{
			QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626",
			Command: string(uploadCommand),
			StageInfo: execResponseStageInfo{
				LocationType: "GCS",
				Location:     "gcspuscentral1-4506459564-stage/users/456",
				Path:         "users/456",
				Region:       "US_CENTRAL1",
				PresignedURL: testURL,
			},
		}
		return &execResponse{
			Data:    *dd,
			Message: "",
			Code:    "0",
			Success: true,
		}, nil
	}
	gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{})
	assertNilF(t, err, fmt.Sprintf("could not create gcsCli, error: %v", err))
	uploadMeta := fileMetadata{
		name:              "data1.txt.gz",
		stageLocationType: "GCS",
		noSleepingTime:    true,
		client:            gcsCli,
		sha256Digest:      "123456789abcdef",
		stageInfo:         &info,
		dstFileName:       "data1.txt.gz",
		srcFileName:       filepath.Join(dir, "test_data", "data1.txt"),
		overwrite:         true,
		options: &SnowflakeFileTransferOptions{
			MultiPartThreshold: multiPartThreshold,
		},
	}
	sr := &snowflakeRestful{
		FuncPostQuery: presignedURLMock,
	}
	sfa := &snowflakeFileTransferAgent{
		ctx: context.Background(),
		sc: &snowflakeConn{
			cfg:  &Config{},
			rest: sr,
		},
		commandType:       uploadCommand,
		command:           "put file:///tmp/test_data/data1.txt @~",
		stageLocationType: gcsClient,
		// Stage info WITHOUT a downscoped token: forces the presigned-URL path.
		stageInfo: &execResponseStageInfo{
			Location:     "gcs-blob/storage/users/456/",
			LocationType: "GCS",
		},
		fileMetadata: []*fileMetadata{&uploadMeta},
	}
	err = sfa.updateFileMetadataWithPresignedURL()
	assertNilF(t, err, fmt.Sprintf("unexpected error in updateFileMetadataWithPresignedURL: %v", err))
	assertTrueF(t, postQueryCalled, "FuncPostQuery should have been called for presigned URL flow (no downscoped token)")
	assertNotNilF(t, uploadMeta.presignedURL, "presignedURL should have been set on metadata")
	assertEqualF(t, testURL, uploadMeta.presignedURL.String(), fmt.Sprintf("presigned URL %v does not match testUrl %v", uploadMeta.presignedURL.String(), testURL))
}

// TestUploadWhenFilesystemReadOnlyError ensures uploadFilesParallel surfaces
// an error when the temporary directory cannot be created because the
// filesystem backing TMPDIR is read-only.
func TestUploadWhenFilesystemReadOnlyError(t *testing.T) {
	if isWindows {
		t.Skip("permission model is different")
	}
	roPath := t.TempDir()
	// Set the temp directory to read only
	err := os.Chmod(roPath, 0444)
	if err != nil {
		t.Fatal(err)
	}
	info := execResponseStageInfo{
		Location:     "gcs-blob/storage/users/456/",
		LocationType: "GCS",
	}
	dir, err := os.Getwd()
	if err != nil {
		t.Error(err)
	}
	// Make sure that the test uses read only directory
	t.Setenv("TMPDIR", roPath)
	uploadMeta := fileMetadata{
		name:              "data1.txt.gz",
		stageLocationType: "GCS",
		noSleepingTime:    true,
		client:            gcsClient,
		sha256Digest:      "123456789abcdef",
		stageInfo:         &info,
		dstFileName:       "data1.txt.gz",
		srcFileName:       path.Join(dir, "/test_data/data1.txt"),
		overwrite:         true,
		options: &SnowflakeFileTransferOptions{
			MultiPartThreshold: multiPartThreshold,
		},
	}
	sfa := &snowflakeFileTransferAgent{
		ctx: context.Background(),
		sc: &snowflakeConn{
			cfg: &Config{},
		},
		commandType:
uploadCommand, command: "put file:///tmp/test_data/data1.txt @~", stageLocationType: gcsClient, fileMetadata: []*fileMetadata{&uploadMeta}, parallel: 1, } err = sfa.uploadFilesParallel([]*fileMetadata{&uploadMeta}) if err == nil { t.Fatal("should error when the filesystem is read only") } if !strings.Contains(err.Error(), "errors during file upload:\nmkdir") { t.Fatalf("should error when creating the temporary directory. Instead errored with: %v", err) } } func TestUploadWhenErrorWithResultIsReturned(t *testing.T) { if isWindows { t.Skip("permission model is different") } var err error dir, err := os.Getwd() assertNilF(t, err) err = createWriteonlyFile(path.Join(dir, "test_data"), "writeonly.csv") assertNilF(t, err) uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "GCS", noSleepingTime: true, client: local, sha256Digest: "123456789abcdef", stageInfo: &execResponseStageInfo{ Location: dir, LocationType: "local", }, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, "test_data/writeonly.csv"), overwrite: true, } sfa := &snowflakeFileTransferAgent{ ctx: context.Background(), sc: &snowflakeConn{ cfg: &Config{ TmpDirPath: dir, }, }, data: &execResponseData{ SrcLocations: []string{path.Join(dir, "/test_data/writeonly.csv")}, Command: "UPLOAD", SourceCompression: "none", StageInfo: execResponseStageInfo{ LocationType: "LOCAL_FS", Location: dir, }, }, commandType: uploadCommand, command: fmt.Sprintf("put file://%v/test_data/data1.txt @~", dir), stageLocationType: local, fileMetadata: []*fileMetadata{&uploadMeta}, parallel: 1, } err = sfa.execute() assertNilF(t, err) // execute should not propagate errors, it should be returned by sfa.result only _, err = sfa.result() assertNotNilE(t, err) } func createWriteonlyFile(dir, filename string) error { path := path.Join(dir, filename) if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) { if _, err := os.Create(path); err != nil { return err } } if err := os.Chmod(path, 0222); err != nil { return 
err } return nil } func TestUnitUpdateProgress(t *testing.T) { var b bytes.Buffer buf := io.Writer(&b) _, err := buf.Write([]byte("testing")) assertNilF(t, err) spp := &snowflakeProgressPercentage{ filename: "test.txt", fileSize: float64(1500), outputStream: &buf, showProgressBar: true, done: false, } spp.call(0) if spp.done != false { t.Fatal("should not be done.") } if spp.seenSoFar != 0 { t.Fatalf("expected seenSoFar to be 0 but was %v", spp.seenSoFar) } spp.call(1516) if spp.done != true { t.Fatal("should be done after updating progess") } } func TestCustomTmpDirPath(t *testing.T) { tmpDir, err := os.MkdirTemp("", "") if err != nil { t.Fatalf("cannot create temp directory: %v", err) } defer func() { assertNilF(t, os.RemoveAll(tmpDir)) }() uploadFile := filepath.Join(tmpDir, "data.txt") f, err := os.Create(uploadFile) if err != nil { t.Error(err) } _, err = f.WriteString("test1,test2\ntest3,test4\n") assertNilF(t, err) assertNilF(t, f.Close()) uploadMeta := &fileMetadata{ name: "data.txt.gz", stageLocationType: "local", noSleepingTime: true, client: local, sha256Digest: "123456789abcdef", stageInfo: &execResponseStageInfo{ Location: tmpDir, LocationType: "local", }, dstFileName: "data.txt.gz", srcFileName: uploadFile, overwrite: true, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, } downloadFile := filepath.Join(tmpDir, "download.txt") downloadMeta := &fileMetadata{ name: "data.txt.gz", stageLocationType: "local", noSleepingTime: true, client: local, sha256Digest: "123456789abcdef", stageInfo: &execResponseStageInfo{ Location: tmpDir, LocationType: "local", }, srcFileName: "data.txt.gz", dstFileName: downloadFile, overwrite: true, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, } sfa := snowflakeFileTransferAgent{ ctx: context.Background(), sc: &snowflakeConn{ cfg: &Config{ TmpDirPath: tmpDir, }, }, stageLocationType: local, } _, err = sfa.uploadOneFile(uploadMeta) if err != nil { 
t.Fatal(err) } _, err = sfa.downloadOneFile(context.Background(), downloadMeta) if err != nil { t.Fatal(err) } defer os.Remove("download.txt") } func TestReadonlyTmpDirPathShouldFail(t *testing.T) { if isWindows { t.Skip("permission model is different") } tmpDir, err := os.MkdirTemp("", "") if err != nil { t.Fatalf("cannot create temp directory: %v", err) } defer func() { assertNilF(t, os.RemoveAll(tmpDir)) }() uploadFile := filepath.Join(tmpDir, "data.txt") f, err := os.Create(uploadFile) if err != nil { t.Error(err) } _, err = f.WriteString("test1,test2\ntest3,test4\n") assertNilF(t, err) assertNilF(t, f.Close()) err = os.Chmod(tmpDir, 0500) if err != nil { t.Fatalf("cannot mark directory as readonly: %v", err) } defer func() { assertNilF(t, os.Chmod(tmpDir, 0700)) }() uploadMeta := &fileMetadata{ name: "data.txt.gz", stageLocationType: "local", noSleepingTime: true, client: local, sha256Digest: "123456789abcdef", stageInfo: &execResponseStageInfo{ Location: tmpDir, LocationType: "local", }, dstFileName: "data.txt.gz", srcFileName: uploadFile, overwrite: true, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, } sfa := snowflakeFileTransferAgent{ ctx: context.Background(), sc: &snowflakeConn{ cfg: &Config{ TmpDirPath: tmpDir, }, }, stageLocationType: local, } _, err = sfa.uploadOneFile(uploadMeta) if err == nil { t.Fatalf("should not upload file as temporary directory is not readable") } } func TestUploadDownloadOneFileRequireCompress(t *testing.T) { testUploadDownloadOneFile(t, false) } func TestUploadDownloadOneFileRequireCompressStream(t *testing.T) { testUploadDownloadOneFile(t, true) } func testUploadDownloadOneFile(t *testing.T, isStream bool) { tmpDir, err := os.MkdirTemp("", "data") if err != nil { t.Fatalf("cannot create temp directory: %v", err) } defer os.RemoveAll(tmpDir) uploadFile := filepath.Join(tmpDir, "data.txt") f, err := os.Create(uploadFile) if err != nil { t.Error(err) } _, err = 
f.WriteString("test1,test2\ntest3,test4\n") assertNilF(t, err) assertNilF(t, f.Close()) uploadMeta := &fileMetadata{ name: "data.txt.gz", stageLocationType: "local", noSleepingTime: true, client: local, sha256Digest: "123456789abcdef", stageInfo: &execResponseStageInfo{ Location: tmpDir, LocationType: "local", }, dstFileName: "data.txt.gz", srcFileName: uploadFile, overwrite: true, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, requireCompress: true, } downloadFile := filepath.Join(tmpDir, "download.txt") downloadMeta := &fileMetadata{ name: "data.txt.gz", stageLocationType: "local", noSleepingTime: true, client: local, sha256Digest: "123456789abcdef", stageInfo: &execResponseStageInfo{ Location: tmpDir, LocationType: "local", }, srcFileName: "data.txt.gz", dstFileName: downloadFile, overwrite: true, parallel: int64(10), options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, } sfa := snowflakeFileTransferAgent{ ctx: context.Background(), sc: &snowflakeConn{ cfg: &Config{ TmpDirPath: tmpDir, }, }, stageLocationType: local, } if isStream { fileStream, _ := os.Open(uploadFile) ctx := WithFilePutStream(context.Background(), fileStream) uploadMeta.fileStream, err = getFileStream(ctx) assertNilF(t, err) } _, err = sfa.uploadOneFile(uploadMeta) if err != nil { t.Fatal(err) } if uploadMeta.resStatus != uploaded { t.Fatalf("failed to upload file") } _, err = sfa.downloadOneFile(context.Background(), downloadMeta) if err != nil { t.Fatal(err) } defer func() { assertNilF(t, os.Remove("download.txt")) }() if downloadMeta.resStatus != downloaded { t.Fatalf("failed to download file") } } func TestPutGetRegexShouldIgnoreWhitespaceAtTheBeginning(t *testing.T) { for _, test := range []struct { regex string query string }{ { regex: putRegexp, query: "PUT abc", }, { regex: putRegexp, query: " PUT abc", }, { regex: putRegexp, query: "\tPUT abc", }, { regex: putRegexp, query: "\nPUT abc", }, { regex: putRegexp, query: 
"\r\nPUT abc", }, { regex: getRegexp, query: "GET abc", }, { regex: getRegexp, query: " GET abc", }, { regex: getRegexp, query: "\tGET abc", }, { regex: getRegexp, query: "\nGET abc", }, { regex: getRegexp, query: "\r\nGET abc", }, } { { t.Run(test.regex+" "+test.query, func(t *testing.T) { regex := regexp.MustCompile(test.regex) assertTrueE(t, regex.Match([]byte(test.query))) assertFalseE(t, regex.Match([]byte("prefix "+test.query))) }) } } } func TestEncryptStream(t *testing.T) { srcBytes := []byte{63, 64, 65} initStr := bytes.NewBuffer(srcBytes) for _, tc := range []struct { ct cloudType encrypt bool realSrcStream bool encryptMat bool }{ { ct: s3Client, encrypt: true, realSrcStream: true, encryptMat: true, }, { ct: s3Client, encrypt: true, realSrcStream: false, encryptMat: true, }, { ct: s3Client, encrypt: false, realSrcStream: false, encryptMat: false, }, { ct: azureClient, encrypt: true, realSrcStream: true, encryptMat: true, }, { ct: azureClient, encrypt: true, realSrcStream: false, encryptMat: true, }, { ct: azureClient, encrypt: false, realSrcStream: false, encryptMat: false, }, { ct: gcsClient, encrypt: true, realSrcStream: true, encryptMat: true, }, { ct: gcsClient, encrypt: true, realSrcStream: false, encryptMat: true, }, { ct: gcsClient, encrypt: false, realSrcStream: false, encryptMat: false, }, { ct: local, encrypt: false, realSrcStream: true, encryptMat: true, }, { ct: local, encrypt: false, realSrcStream: true, encryptMat: false, }, { ct: local, encrypt: false, realSrcStream: false, encryptMat: true, }, { ct: local, encrypt: false, realSrcStream: false, encryptMat: false, }, } { { var encMat *snowflakeFileEncryption = nil if tc.encryptMat { encMat = &snowflakeFileEncryption{ QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", QueryID: "01abc874-0406-1bf0-0000-53b10668e056", SMKID: 92019681909886, } } var realSrcStr *bytes.Buffer = nil if tc.realSrcStream { realSrcStr = initStr } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: 
tc.ct, noSleepingTime: true, parallel: int64(100), client: nil, sha256Digest: "123456789abcdef", stageInfo: nil, dstFileName: "data1.txt.gz", srcStream: initStr, realSrcStream: realSrcStr, overwrite: true, options: nil, encryptionMaterial: encMat, mockUploader: nil, sfa: nil, } t.Run(string(tc.ct)+" encrypt "+strconv.FormatBool(tc.encrypt)+" realSrcStream "+strconv.FormatBool(tc.realSrcStream)+" encryptMat "+strconv.FormatBool(tc.encryptMat), func(t *testing.T) { err := encryptDataIfRequired(&uploadMeta, tc.ct) assertNilF(t, err) if tc.encrypt { assertNotNilF(t, uploadMeta.encryptMeta, "encryption metadata should be present") if tc.realSrcStream { assertNotEqualF(t, uploadMeta.realSrcStream, realSrcStr, "stream should be encrypted") } else { assertNotEqualF(t, uploadMeta.realSrcStream, initStr, "stream should not be encrypted") } } else { assertNilF(t, uploadMeta.encryptMeta, "encryption metadata should be empty") assertEqualF(t, uploadMeta.realSrcStream, realSrcStr, "stream should not be encrypted") } }) } } } func TestEncryptFile(t *testing.T) { for _, tc := range []struct { ct cloudType encrypt bool encryptMat bool }{ { ct: s3Client, encrypt: true, encryptMat: true, }, { ct: s3Client, encrypt: false, encryptMat: false, }, { ct: azureClient, encrypt: true, encryptMat: true, }, { ct: azureClient, encrypt: false, encryptMat: false, }, { ct: gcsClient, encrypt: true, encryptMat: true, }, { ct: gcsClient, encrypt: false, encryptMat: false, }, { ct: local, encrypt: false, encryptMat: true, }, { ct: local, encrypt: false, encryptMat: false, }, } { dir, err := os.Getwd() srcF := path.Join(dir, "/test_data/put_get_1.txt") assertNilF(t, err, "error getting current directory") var encMat *snowflakeFileEncryption = nil if tc.encryptMat { encMat = &snowflakeFileEncryption{ QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", QueryID: "01abc874-0406-1bf0-0000-53b10668e056", SMKID: 92019681909886, } } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: tc.ct, 
noSleepingTime: true, parallel: int64(100), client: nil, sha256Digest: "123456789abcdef", stageInfo: nil, dstFileName: "data1.txt.gz", srcFileName: srcF, realSrcFileName: srcF, overwrite: true, options: nil, encryptionMaterial: encMat, mockUploader: nil, sfa: nil, } t.Run(string(tc.ct)+" encrypt "+strconv.FormatBool(tc.encrypt)+" encryptMat "+strconv.FormatBool(tc.encryptMat), func(t *testing.T) { err := encryptDataIfRequired(&uploadMeta, tc.ct) assertNilF(t, err) if tc.encrypt { assertNotNilF(t, uploadMeta.encryptMeta, "encryption metadata should be present") assertNotEqualF(t, uploadMeta.realSrcFileName, srcF, "file should be encrypted") srcBytes, err := os.ReadFile(srcF) assertNilF(t, err) encBytes, err := os.ReadFile(uploadMeta.realSrcFileName) assertNilF(t, err) assertFalseF(t, bytes.Equal(srcBytes, encBytes), "file contents should differ") } else { assertNilF(t, uploadMeta.encryptMeta, "encryption metadata should be empty") assertEqualF(t, uploadMeta.realSrcFileName, srcF, "file should not be encrypted") } }) } } ================================================ FILE: file_util.go ================================================ package gosnowflake import ( "bytes" "compress/gzip" "crypto/sha256" "encoding/base64" "io" "net/url" "os" "path/filepath" "strings" ) type snowflakeFileUtil struct { } const ( fileChunkSize = 16 * 4 * 1024 readWriteFileMode os.FileMode = 0666 ) func (util *snowflakeFileUtil) compressFileWithGzipFromStream(srcStream **bytes.Buffer) (*bytes.Buffer, int, error) { r := getReaderFromBuffer(srcStream) buf, err := io.ReadAll(r) if err != nil { return nil, -1, err } var c bytes.Buffer w := gzip.NewWriter(&c) if _, err := w.Write(buf); err != nil { // write buf to gzip writer return nil, -1, err } if err := w.Close(); err != nil { return nil, -1, err } return &c, c.Len(), nil } func (util *snowflakeFileUtil) compressFileWithGzip(fileName string, tmpDir string) (gzipFileName string, size int64, err error) { basename := baseName(fileName) 
gzipFileName = filepath.Join(tmpDir, basename+"_c.gz") fr, err := os.Open(fileName) if err != nil { return "", -1, err } defer func() { if tmpErr := fr.Close(); tmpErr != nil { err = tmpErr } }() fw, err := os.OpenFile(gzipFileName, os.O_WRONLY|os.O_CREATE, readWriteFileMode) if err != nil { return "", -1, err } gzw := gzip.NewWriter(fw) defer func() { if tmpErr := gzw.Close(); tmpErr != nil { err = tmpErr } }() _, err = io.Copy(gzw, fr) if err != nil { return "", -1, err } stat, err := os.Stat(gzipFileName) if err != nil { return "", -1, err } return gzipFileName, stat.Size(), err } func (util *snowflakeFileUtil) getDigestAndSizeForStream(stream io.Reader) (string, int64, error) { m := sha256.New() chunk := make([]byte, fileChunkSize) var total int64 for { n, err := stream.Read(chunk) if err == io.EOF { break } else if err != nil { return "", 0, err } total += int64(n) m.Write(chunk[:n]) } return base64.StdEncoding.EncodeToString(m.Sum(nil)), total, nil } func (util *snowflakeFileUtil) getDigestAndSizeForFile(fileName string) (digest string, size int64, err error) { f, err := os.Open(fileName) if err != nil { return "", 0, err } defer func() { if tmpErr := f.Close(); tmpErr != nil { err = tmpErr } }() var total int64 m := sha256.New() chunk := make([]byte, fileChunkSize) for { n, err := f.Read(chunk) if err == io.EOF { break } else if err != nil { return "", 0, err } total += int64(n) m.Write(chunk[:n]) } if _, err = f.Seek(0, io.SeekStart); err != nil { return "", -1, err } return base64.StdEncoding.EncodeToString(m.Sum(nil)), total, err } // file metadata for PUT/GET type fileMetadata struct { name string sfa *snowflakeFileTransferAgent stageLocationType cloudType resStatus resultStatus stageInfo *execResponseStageInfo encryptionMaterial *snowflakeFileEncryption encryptMeta *encryptMetadata srcFileName string realSrcFileName string srcFileSize int64 srcCompressionType *compressionType uploadSize int64 dstFileSize int64 dstFileName string dstCompressionType 
*compressionType client cloudClient // *s3.Client (S3), *azblob.ContainerURL (Azure), string (GCS) requireCompress bool parallel int64 sha256Digest string overwrite bool tmpDir string errorDetails error lastError error noSleepingTime bool lastMaxConcurrency int localLocation string options *SnowflakeFileTransferOptions /* streaming PUT */ fileStream io.Reader srcStream *bytes.Buffer realSrcStream *bytes.Buffer /* streaming GET */ dstStream *bytes.Buffer /* GCS */ presignedURL *url.URL gcsFileHeaderDigest string gcsFileHeaderContentLength int64 gcsFileHeaderEncryptionMeta *encryptMetadata /* mock */ mockUploader s3UploadAPI mockDownloader s3DownloadAPI mockHeader s3HeaderAPI mockGcsClient gcsAPI mockAzureClient azureAPI } type fileTransferResultType struct { name string srcFileName string dstFileName string srcFileSize int64 dstFileSize int64 srcCompressionType *compressionType dstCompressionType *compressionType resStatus resultStatus errorDetails error } type fileHeader struct { digest string contentLength int64 encryptionMetadata *encryptMetadata } func getReaderFromBuffer(src **bytes.Buffer) io.Reader { var b bytes.Buffer tee := io.TeeReader(*src, &b) // read src to buf *src = &b // revert pointer back return tee } // baseName returns the pathname of the path provided func baseName(path string) string { base := filepath.Base(path) if base == "." || base == "/" { return "" } if len(base) > 1 && (path[len(path)-1:] == "." 
|| path[len(path)-1:] == "/") { return "" } return base } // expandUser returns the argument with an initial component of ~ func expandUser(path string) (string, error) { if !strings.HasPrefix(path, "~") { return path, nil } homeDir, err := os.UserHomeDir() if err != nil { return "", err } if path == "~" { path = homeDir } else if strings.HasPrefix(path, "~/") { path = filepath.Join(homeDir, path[2:]) } return path, nil } // getDirectory retrieves the current working directory func getDirectory() (string, error) { ex, err := os.Executable() if err != nil { return "", err } return filepath.Dir(ex), nil } ================================================ FILE: file_util_test.go ================================================ package gosnowflake import ( "os/user" "path/filepath" "testing" ) func TestGetDigestAndSizeForInvalidDir(t *testing.T) { fileUtil := new(snowflakeFileUtil) digest, size, err := fileUtil.getDigestAndSizeForFile("/home/file.txt") if digest != "" { t.Fatal("should be empty") } if size != 0 { t.Fatal("should be 0") } if err == nil { t.Fatal("should have failed") } } type tcBaseName struct { in string out string } func TestBaseName(t *testing.T) { testcases := []tcBaseName{ {"/tmp", "tmp"}, {"/home/desktop/.", ""}, {"/home/desktop/..", ""}, } for _, test := range testcases { t.Run(test.in, func(t *testing.T) { base := baseName(test.in) if test.out != base { t.Errorf("Failed to get base, input %v, expected: %v, got: %v", test.in, test.out, base) } }) } } func TestExpandUser(t *testing.T) { skipOnMissingHome(t) usr, err := user.Current() if err != nil { t.Fatal(err) } homeDir := usr.HomeDir user, err := expandUser("~") if err != nil { t.Fatal(err) } if homeDir != user { t.Fatalf("failed to expand user, expected: %v, got: %v", homeDir, user) } user, err = expandUser("~/storage") if err != nil { t.Fatal(err) } expectedPath := filepath.Join(homeDir, "storage") if expectedPath != user { t.Fatalf("failed to expand user, expected: %v, got: %v", expectedPath, 
user) } } ================================================ FILE: function_wrapper_test.go ================================================ package gosnowflake import ( "context" "sync" "testing" ) func TestGoWrapper(t *testing.T) { var ( goWrapperCalled = false testGoRoutineWrapperLock sync.Mutex ) setGoWrapperCalled := func(value bool) { testGoRoutineWrapperLock.Lock() defer testGoRoutineWrapperLock.Unlock() goWrapperCalled = value } getGoWrapperCalled := func() bool { testGoRoutineWrapperLock.Lock() defer testGoRoutineWrapperLock.Unlock() return goWrapperCalled } // this is the go wrapper function we are going to pass into GoroutineWrapper. // we will know that this has been called if the channel is closed var closeGoWrapperCalledChannel = func(ctx context.Context, f func()) { setGoWrapperCalled(true) f() } runDBTest(t, func(dbt *DBTest) { oldGoroutineWrapper := GoroutineWrapper t.Cleanup(func() { GoroutineWrapper = oldGoroutineWrapper }) GoroutineWrapper = closeGoWrapperCalledChannel ctx := WithAsyncMode(context.Background()) rows := dbt.mustQueryContext(ctx, "SELECT 1") assertTrueE(t, rows.Next()) var i int assertNilF(t, rows.Scan(&i)) rows.Close() assertTrueF(t, getGoWrapperCalled(), "channel should be closed, indicating our wrapper worked") }) } ================================================ FILE: function_wrappers.go ================================================ package gosnowflake import "context" // GoroutineWrapperFunc is used to wrap goroutines. This is useful if the caller wants // to recover panics, rather than letting panics cause a system crash. A suggestion would be to // use use the recover functionality, and log the panic as is most useful to you type GoroutineWrapperFunc func(ctx context.Context, f func()) // The default GoroutineWrapperFunc; this does nothing. 
With this default wrapper
// panics will take down binary as expected
var noopGoroutineWrapper = func(_ context.Context, f func()) {
	f()
}

// GoroutineWrapper is used to hold the GoroutineWrapperFunc set by the client, or to
// store the default goroutine wrapper which does nothing
var GoroutineWrapper GoroutineWrapperFunc = noopGoroutineWrapper

================================================
FILE: gcs_storage_client.go
================================================

package gosnowflake

import (
	"cmp"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"strconv"
	"strings"
)

const (
	// Custom GCS object metadata headers used by Snowflake stages.
	gcsMetadataPrefix             = "x-goog-meta-"
	gcsMetadataSfcDigest          = gcsMetadataPrefix + sfcDigest
	gcsMetadataMatdescKey         = gcsMetadataPrefix + "matdesc"
	gcsMetadataEncryptionDataProp = gcsMetadataPrefix + "encryptiondata"
	gcsFileHeaderDigest           = "gcs-file-header-digest"
	gcsRegionMeCentral2           = "me-central2"
	minimumDownloadPartSize       = 1024 * 1024 * 5 // 5MB
)

// snowflakeGcsClient implements the cloud storage operations for GCS stages.
type snowflakeGcsClient struct {
	cfg       *Config
	telemetry *snowflakeTelemetry
}

// gcsLocation is a parsed stage location: bucket name plus object path prefix.
type gcsLocation struct {
	bucketName string
	path       string
}

// createClient returns the GCS "client" value stored in fileMetadata.client:
// the downscoped access token string when GS provided one, otherwise an empty
// string, meaning requests fall back to the presigned URL.
func (util *snowflakeGcsClient) createClient(info *execResponseStageInfo, _ bool, telemetry *snowflakeTelemetry) (cloudClient, error) {
	if info.Creds.GcsAccessToken != "" {
		logger.Debug("Using GCS downscoped token")
		return info.Creds.GcsAccessToken, nil
	}
	logger.Debugf("No access token received from GS, using presigned url: %s", info.PresignedURL)
	return "", nil
}

// cloudUtil implementation
// getFileHeader issues a HEAD request for the staged file and extracts the
// Snowflake digest, content length and encryption metadata headers. When the
// file was already uploaded/downloaded in this transfer, cached values from
// meta are returned instead of making a request.
func (util *snowflakeGcsClient) getFileHeader(ctx context.Context, meta *fileMetadata, filename string) (*fileHeader, error) {
	if meta.resStatus == uploaded || meta.resStatus == downloaded {
		return &fileHeader{
			digest:             meta.gcsFileHeaderDigest,
			contentLength:      meta.gcsFileHeaderContentLength,
			encryptionMetadata: meta.gcsFileHeaderEncryptionMeta,
		}, nil
	}
	if meta.presignedURL != nil {
		// Presigned-URL mode cannot HEAD the object; report not-found.
		meta.resStatus = notFoundFile
	} else {
		URL, err := util.generateFileURL(meta.stageInfo, strings.TrimLeft(filename, "/"))
		if err != nil {
return nil, err
		}
		accessToken, ok := meta.client.(string)
		if !ok {
			return nil, fmt.Errorf("interface convertion. expected type string but got %T", meta.client)
		}
		gcsHeaders := map[string]string{
			"Authorization": "Bearer " + accessToken,
		}
		resp, err := withCloudStorageTimeout(ctx, util.cfg, func(ctx context.Context) (*http.Response, error) {
			req, err := http.NewRequestWithContext(ctx, "HEAD", URL.String(), nil)
			if err != nil {
				return nil, err
			}
			for k, v := range gcsHeaders {
				req.Header.Add(k, v)
			}
			client, err := newGcsClient(util.cfg, util.telemetry)
			if err != nil {
				return nil, err
			}
			// for testing only
			if meta.mockGcsClient != nil {
				client = meta.mockGcsClient
			}
			resp, err := client.Do(req)
			if err != nil && strings.HasSuffix(err.Error(), "EOF") {
				// A dropped keep-alive connection surfaces as EOF; retry once.
				logger.Debug("Retrying HEAD request because of EOF")
				resp, err = client.Do(req)
			}
			return resp, err
		})
		if err != nil {
			return nil, err
		}
		defer func() {
			if resp.Body != nil {
				if err := resp.Body.Close(); err != nil {
					logger.Warnf("failed to close response body: %v", err)
				}
			}
		}()
		if resp.StatusCode != http.StatusOK {
			meta.lastError = fmt.Errorf("%v", resp.Status)
			meta.resStatus = errStatus
			// Transient statuses are flagged for retry by the caller.
			if resp.StatusCode == 403 || resp.StatusCode == 408 || resp.StatusCode == 429 || resp.StatusCode == 500 || resp.StatusCode == 503 {
				meta.lastError = fmt.Errorf("%v", resp.Status)
				meta.resStatus = needRetry
				return nil, meta.lastError
			}
			if resp.StatusCode == 404 {
				meta.resStatus = notFoundFile
			} else if util.isTokenExpired(resp) {
				meta.lastError = fmt.Errorf("%v", resp.Status)
				meta.resStatus = renewToken
			}
			return nil, meta.lastError
		}
		digest := resp.Header.Get(gcsMetadataSfcDigest)
		contentLength, err := strconv.Atoi(resp.Header.Get("content-length"))
		if err != nil {
			return nil, err
		}
		var encryptionMeta *encryptMetadata
		if resp.Header.Get(gcsMetadataEncryptionDataProp) != "" {
			var encryptData *encryptionData
			err := json.Unmarshal([]byte(resp.Header.Get(gcsMetadataEncryptionDataProp)), &encryptData)
			if err != nil {
				return nil, fmt.Errorf("cannot unmarshal encryption data: %v", err)
			}
			if encryptData != nil {
				encryptionMeta = &encryptMetadata{
					key: encryptData.WrappedContentKey.EncryptionKey,
					iv:  encryptData.ContentEncryptionIV,
				}
				if resp.Header.Get(gcsMetadataMatdescKey) != "" {
					encryptionMeta.matdesc = resp.Header.Get(gcsMetadataMatdescKey)
				}
			}
		}
		meta.resStatus = uploaded
		return &fileHeader{
			digest:             digest,
			contentLength:      int64(contentLength),
			encryptionMetadata: encryptionMeta,
		}, nil
	}
	// Presigned-URL branch: no header information available.
	return nil, nil
}

// gcsAPI abstracts the HTTP client so tests can inject a mock transport.
type gcsAPI interface {
	Do(req *http.Request) (*http.Response, error)
}

// cloudUtil implementation
// uploadFile PUTs the source file (or in-memory stream) to GCS, attaching the
// Snowflake digest and client-side encryption metadata headers, and updates
// meta with the upload result.
func (util *snowflakeGcsClient) uploadFile(
	ctx context.Context,
	dataFile string,
	meta *fileMetadata,
	maxConcurrency int,
	multiPartThreshold int64) error {
	uploadURL := meta.presignedURL
	var accessToken string
	var err error

	if uploadURL == nil {
		uploadURL, err = util.generateFileURL(meta.stageInfo, strings.TrimLeft(meta.dstFileName, "/"))
		if err != nil {
			return err
		}
		var ok bool
		accessToken, ok = meta.client.(string)
		if !ok {
			return fmt.Errorf("interface convertion. 
expected type string but got %T", meta.client)
		}
	}

	var contentEncoding string
	if meta.dstCompressionType != nil {
		contentEncoding = strings.ToLower(meta.dstCompressionType.name)
	}
	if contentEncoding == "gzip" {
		// Do not advertise gzip Content-Encoding: GCS would transcode the
		// object on download; store the compressed bytes verbatim instead.
		contentEncoding = ""
	}
	gcsHeaders := make(map[string]string)
	gcsHeaders[httpHeaderContentEncoding] = contentEncoding
	gcsHeaders[gcsMetadataSfcDigest] = meta.sha256Digest
	if accessToken != "" {
		gcsHeaders["Authorization"] = "Bearer " + accessToken
	}

	// Attach client-side encryption metadata (key, IV, matdesc) as custom
	// object metadata when this upload is encrypted.
	if meta.encryptMeta != nil {
		encryptData := encryptionData{
			"FullBlob",
			contentKey{
				"symmKey1",
				meta.encryptMeta.key,
				"AES_CBC_256",
			},
			encryptionAgent{
				"1.0",
				"AES_CBC_256",
			},
			meta.encryptMeta.iv,
			keyMetadata{
				"Java 5.3.0",
			},
		}
		b, err := json.Marshal(&encryptData)
		if err != nil {
			return err
		}
		gcsHeaders[gcsMetadataEncryptionDataProp] = string(b)
		gcsHeaders[gcsMetadataMatdescKey] = meta.encryptMeta.matdesc
	}

	// Prefer the post-encryption stream when present; otherwise open the file.
	var uploadSrc io.Reader
	if meta.srcStream != nil {
		uploadSrc = meta.srcStream
		if meta.realSrcStream != nil {
			uploadSrc = meta.realSrcStream
		}
	} else {
		var err error
		uploadSrc, err = os.Open(dataFile)
		if err != nil {
			return err
		}
		defer func(src io.Closer) {
			if err := src.Close(); err != nil {
				logger.Warnf("failed to close %v file: %v", dataFile, err)
			}
		}(uploadSrc.(io.Closer))
	}

	resp, err := withCloudStorageTimeout(ctx, util.cfg, func(ctx context.Context) (*http.Response, error) {
		req, err := http.NewRequestWithContext(ctx, "PUT", uploadURL.String(), uploadSrc)
		if err != nil {
			return nil, err
		}
		for k, v := range gcsHeaders {
			req.Header.Add(k, v)
		}
		client, err := newGcsClient(util.cfg, util.telemetry)
		if err != nil {
			return nil, err
		}
		// for testing only
		if meta.mockGcsClient != nil {
			client = meta.mockGcsClient
		}
		return client.Do(req)
	})
	if err != nil {
		return err
	}
	defer func() {
		if resp.Body != nil {
			if err := resp.Body.Close(); err != nil {
				logger.Warnf("failed to close response body: %v", err)
			}
		}
	}()
	if resp.StatusCode != http.StatusOK {
		// Map the failure onto a result status so the transfer agent knows
		// whether to retry, renew the presigned URL, or renew the token.
		if resp.StatusCode == 403 || resp.StatusCode == 408 || resp.StatusCode == 429 || resp.StatusCode == 500 || resp.StatusCode == 503 {
			meta.lastError = fmt.Errorf("%v", resp.Status)
			meta.resStatus = needRetry
		} else if accessToken == "" && resp.StatusCode == 400 && meta.lastError == nil {
			meta.lastError = fmt.Errorf("%v", resp.Status)
			meta.resStatus = renewPresignedURL
		} else if accessToken != "" && util.isTokenExpired(resp) {
			meta.lastError = fmt.Errorf("%v", resp.Status)
			meta.resStatus = renewToken
		} else {
			meta.lastError = fmt.Errorf("%v", resp.Status)
		}
		return meta.lastError
	}

	if meta.options.putCallback != nil {
		meta.options.putCallback = &snowflakeProgressPercentage{
			filename:        dataFile,
			fileSize:        float64(meta.srcFileSize),
			outputStream:    meta.options.putCallbackOutputStream,
			showProgressBar: meta.options.showProgressBar,
		}
	}
	meta.dstFileSize = meta.uploadSize
	meta.resStatus = uploaded
	meta.gcsFileHeaderDigest = gcsHeaders[gcsFileHeaderDigest]
	meta.gcsFileHeaderContentLength = meta.uploadSize
	// BUGFIX: only parse encryption metadata when it was actually attached.
	// Unconditionally unmarshaling the missing-key empty string fails an
	// unencrypted upload with "unexpected end of JSON input" even though the
	// PUT itself succeeded.
	if encData := gcsHeaders[gcsMetadataEncryptionDataProp]; encData != "" {
		if err = json.Unmarshal([]byte(encData), &meta.encryptMeta); err != nil {
			return err
		}
	}
	meta.gcsFileHeaderEncryptionMeta = meta.encryptMeta
	return nil
}

// cloudUtil implementation
// nativeDownloadFile GETs the staged file, splitting it into concurrent range
// requests when it exceeds partSize, and records digest/length/encryption
// metadata from the file header onto meta.
func (util *snowflakeGcsClient) nativeDownloadFile(
	ctx context.Context,
	meta *fileMetadata,
	fullDstFileName string,
	maxConcurrency int64,
	partSize int64) error {
	partSize = int64Max(partSize, minimumDownloadPartSize)
	downloadURL := meta.presignedURL
	var accessToken string
	var err error
	gcsHeaders := make(map[string]string)

	if downloadURL == nil || downloadURL.String() == "" {
		downloadURL, err = util.generateFileURL(meta.stageInfo, strings.TrimLeft(meta.srcFileName, "/"))
		if err != nil {
			return err
		}
		var ok bool
		accessToken, ok = meta.client.(string)
		if !ok {
			return fmt.Errorf("interface convertion. 
expected type string but got %T", meta.client)
		}
		if accessToken != "" {
			gcsHeaders["Authorization"] = "Bearer " + accessToken
		}
	}
	logger.Debugf("GCS Client: Send Get Request to %v", downloadURL.String())

	// First, get file size with a HEAD request to determine if multi-part download is needed
	// Also extract metadata during this request
	fileHeader, err := util.getFileHeaderForDownload(ctx, downloadURL, gcsHeaders, accessToken, meta)
	if err != nil {
		return err
	}
	fileSize := fileHeader.ContentLength

	// Use multi-part download for files larger than partSize or when maxConcurrency > 1
	if fileSize > partSize && maxConcurrency > 1 {
		err = util.downloadFileInParts(ctx, downloadURL, gcsHeaders, accessToken, meta, fullDstFileName, fileSize, maxConcurrency, partSize)
	} else {
		// Fall back to single-part download for smaller files
		err = util.downloadFileSinglePart(ctx, downloadURL, gcsHeaders, accessToken, meta, fullDstFileName)
	}
	if err != nil {
		return err
	}

	// Decode client-side encryption metadata from the HEAD response headers.
	var encryptMeta encryptMetadata
	if fileHeader.Header.Get(gcsMetadataEncryptionDataProp) != "" {
		var encryptData *encryptionData
		if err = json.Unmarshal([]byte(fileHeader.Header.Get(gcsMetadataEncryptionDataProp)), &encryptData); err != nil {
			return err
		}
		if encryptData != nil {
			encryptMeta = encryptMetadata{
				encryptData.WrappedContentKey.EncryptionKey,
				encryptData.ContentEncryptionIV,
				"",
			}
			if key := fileHeader.Header.Get(gcsMetadataMatdescKey); key != "" {
				encryptMeta.matdesc = key
			}
		}
	}
	meta.resStatus = downloaded
	meta.gcsFileHeaderDigest = fileHeader.Header.Get(gcsMetadataSfcDigest)
	meta.gcsFileHeaderContentLength = fileSize
	meta.gcsFileHeaderEncryptionMeta = &encryptMeta
	return nil
}

// getFileHeaderForDownload gets the file header using a HEAD request.
// NOTE(review): the deferred close runs before the caller sees the response,
// so the returned *http.Response has an already-closed Body; callers must only
// read Header and ContentLength from it.
func (util *snowflakeGcsClient) getFileHeaderForDownload(ctx context.Context, downloadURL *url.URL, gcsHeaders map[string]string, accessToken string, meta *fileMetadata) (*http.Response, error) {
	resp, err := withCloudStorageTimeout(ctx, util.cfg, func(ctx context.Context) (*http.Response, error) {
		req, err := http.NewRequestWithContext(ctx, "HEAD", downloadURL.String(), nil)
		if err != nil {
			return nil, err
		}
		for k, v := range gcsHeaders {
			req.Header.Add(k, v)
		}
		client, err := newGcsClient(util.cfg, util.telemetry)
		if err != nil {
			return nil, err
		}
		// for testing only
		if meta.mockGcsClient != nil {
			client = meta.mockGcsClient
		}
		return client.Do(req)
	})
	if err != nil {
		return nil, err
	}
	defer func() {
		if resp.Body != nil {
			if err := resp.Body.Close(); err != nil {
				logger.Warnf("Failed to close response body: %v", err)
			}
		}
	}()
	if resp.StatusCode != http.StatusOK {
		return nil, util.handleHTTPError(resp, meta, accessToken)
	}
	return resp, nil
}

// downloadPart is a struct for downloading a part of a file in memory
type downloadPart struct {
	data  []byte
	index int64
	err   error
}

// downloadPartStream is a struct for downloading a part of a file in a stream
type downloadPartStream struct {
	stream io.ReadCloser
	index  int64
	err    error
}

// downloadJob describes one byte range [start, end] and its part index.
type downloadJob struct {
	index int64
	start int64
	end   int64
}

// downloadFileInParts splits the object into partSize ranges and dispatches to
// the streaming or file-based multi-part implementation.
func (util *snowflakeGcsClient) downloadFileInParts(
	ctx context.Context,
	downloadURL *url.URL,
	gcsHeaders map[string]string,
	accessToken string,
	meta *fileMetadata,
	fullDstFileName string,
	fileSize int64,
	maxConcurrency int64,
	partSize int64) error {
	// Calculate number of parts based on desired part size
	numParts := (fileSize + partSize - 1) / partSize

	// For streaming, use batched approach to avoid buffering all parts in memory
	if isFileGetStream(ctx) {
		return util.downloadInPartsForStream(ctx, downloadURL, gcsHeaders, accessToken, meta, fileSize, numParts, maxConcurrency, partSize)
	}
	return util.downloadInPartsForFile(ctx, downloadURL, gcsHeaders, accessToken, meta, fullDstFileName, fileSize, numParts, maxConcurrency, partSize)
}

// downloadInPartsForStream downloads file in batches, streaming parts sequentially
func (util *snowflakeGcsClient) downloadInPartsForStream(
	ctx context.Context,
	downloadURL *url.URL,
	gcsHeaders 
map[string]string, accessToken string, meta *fileMetadata,
	fileSize, numParts, maxConcurrency, partSize int64) error {
	// Create a single HTTP client for all downloads to reuse connections
	client, err := newGcsClient(util.cfg, util.telemetry)
	if err != nil {
		return err
	}
	// for testing only
	if meta.mockGcsClient != nil {
		client = meta.mockGcsClient
	}

	// The first part's index for each batch
	var nextPartIndex int64 = 0
	for nextPartIndex < numParts {
		// Calculate this batch size
		batchSize := maxConcurrency
		if nextPartIndex+batchSize > numParts {
			batchSize = numParts - nextPartIndex
		}

		// Download this batch
		jobs := make(chan downloadJob, batchSize)
		results := make(chan downloadPartStream, batchSize)

		// Start workers for this batch
		for i := int64(0); i < batchSize; i++ {
			go func() {
				for job := range jobs {
					stream, err := util.downloadRangeStream(ctx, downloadURL, gcsHeaders, accessToken, meta, client, job.start, job.end)
					results <- downloadPartStream{stream: stream, index: job.index, err: err}
				}
			}()
		}

		// Send jobs for this batch. Note: job.index is the index WITHIN the
		// batch (0..batchSize-1); start/end are absolute file offsets.
		for i := int64(0); i < batchSize; i++ {
			partIndex := nextPartIndex + i
			start := partIndex * partSize
			end := start + partSize - 1
			if end >= fileSize {
				end = fileSize - 1
			}
			jobs <- downloadJob{index: i, start: start, end: end}
		}
		close(jobs) // Signal no more jobs

		// Collect results for this batch. The results channel is buffered to
		// batchSize, so workers never block even if we return early on error.
		batchResults := make([]downloadPartStream, batchSize)
		for i := int64(0); i < batchSize; i++ {
			result := <-results
			if result.err != nil {
				// Close any successful streams before returning error
				for j := int64(0); j < i; j++ {
					if batchResults[j].stream != nil {
						if closeErr := batchResults[j].stream.Close(); closeErr != nil {
							logger.Warnf("Failed to close stream: %v", closeErr)
						}
					}
				}
				return result.err
			}
			batchResults[result.index] = result
		}

		// Stream parts sequentially in order, closing streams as we go
		for i := int64(0); i < batchSize; i++ {
			part := batchResults[i]
			if part.stream != nil {
				// Stream directly from HTTP response to destination stream
				_, err := io.Copy(meta.dstStream, part.stream)
				// Close the stream immediately after copying
				if closeErr := part.stream.Close(); closeErr != nil {
					logger.Warnf("Failed to close stream: %v", closeErr)
				}
				if err != nil {
					// Close remaining streams before returning error
					for j := i + 1; j < batchSize; j++ {
						if batchResults[j].stream != nil {
							if closeErr := batchResults[j].stream.Close(); closeErr != nil {
								logger.Warnf("Failed to close stream: %v", closeErr)
							}
						}
					}
					return err
				}
			}
		}
		nextPartIndex += batchSize
	}
	return nil
}

// downloadInPartsForFile downloads all parts and writes to file
func (util *snowflakeGcsClient) downloadInPartsForFile(
	ctx context.Context,
	downloadURL *url.URL,
	gcsHeaders map[string]string,
	accessToken string,
	meta *fileMetadata,
	fullDstFileName string,
	fileSize, numParts, maxConcurrency, partSize int64) error {
	// Create a single HTTP client for all downloads to reuse connections
	client, err := newGcsClient(util.cfg, util.telemetry)
	if err != nil {
		return err
	}
	// for testing only
	if meta.mockGcsClient != nil {
		client = meta.mockGcsClient
	}

	// Start all workers and download all parts
	jobs := make(chan downloadJob, numParts)
	results := make(chan downloadPart, numParts)

	// Start worker pool with maxConcurrency workers
	for range maxConcurrency {
		go func() {
			for job := range jobs {
				data, err := util.downloadRangeBytes(ctx, downloadURL, gcsHeaders, accessToken, meta, client, job.start, job.end)
				results <- downloadPart{data: data, index: job.index, err: err}
			}
		}()
	}

	// Send all jobs to workers
	for i := range numParts {
		start := i * partSize
		end := start + partSize - 1
		if end >= fileSize {
			end = fileSize - 1
		}
		jobs <- downloadJob{index: i, start: start, end: end}
	}
	close(jobs) // Signal no more jobs

	// Collect results and store in order. All parts are buffered in memory
	// before the file is written.
	parts := make([][]byte, numParts)
	for range numParts {
		result := <-results
		if result.err != nil {
			return result.err
		}
		parts[result.index] = result.data
	}

	f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, 
readWriteFileMode)
	if err != nil {
		return err
	}
	defer func() {
		if err := f.Close(); err != nil {
			logger.Warnf("Failed to close file: %v", err)
		}
	}()

	// Write all parts sequentially in index order.
	for _, part := range parts {
		if _, err := f.Write(part); err != nil {
			return err
		}
	}

	fi, err := os.Stat(fullDstFileName)
	if err != nil {
		return err
	}
	meta.srcFileSize = fi.Size()
	return nil
}

// downloadRangeStream downloads a specific byte range and returns the response stream
func (util *snowflakeGcsClient) downloadRangeStream(
	ctx context.Context,
	downloadURL *url.URL,
	gcsHeaders map[string]string,
	accessToken string,
	meta *fileMetadata,
	client gcsAPI,
	start, end int64) (io.ReadCloser, error) {
	resp, err := withCloudStorageTimeout(ctx, util.cfg, func(ctx context.Context) (*http.Response, error) {
		req, err := http.NewRequestWithContext(ctx, "GET", downloadURL.String(), nil)
		if err != nil {
			return nil, err
		}
		// Add range header for partial content
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
		for k, v := range gcsHeaders {
			req.Header.Add(k, v)
		}
		return client.Do(req)
	})
	if err != nil {
		return nil, err
	}
	// Accept both 200 (full content) and 206 (partial content) status codes
	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
		_ = resp.Body.Close()
		return nil, util.handleHTTPError(resp, meta, accessToken)
	}
	// Return the response body stream directly - caller is responsible for closing
	return resp.Body, nil
}

// downloadRangeBytes downloads a specific byte range and returns the bytes
func (util *snowflakeGcsClient) downloadRangeBytes(
	ctx context.Context,
	downloadURL *url.URL,
	gcsHeaders map[string]string,
	accessToken string,
	meta *fileMetadata,
	client gcsAPI,
	start, end int64) ([]byte, error) {
	stream, err := util.downloadRangeStream(ctx, downloadURL, gcsHeaders, accessToken, meta, client, start, end)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err := stream.Close(); err != nil {
			logger.Warnf("Failed to close stream: %v", err)
		}
	}()
	// Download the data into memory
	data, err := io.ReadAll(stream)
	if err != nil {
		return nil, err
	}
	return data, nil
}

// downloadFileSinglePart downloads a file using a single request (original implementation)
func (util *snowflakeGcsClient) downloadFileSinglePart(
	ctx context.Context,
	downloadURL *url.URL,
	gcsHeaders map[string]string,
	accessToken string,
	meta *fileMetadata,
	fullDstFileName string) error {
	resp, err := withCloudStorageTimeout(ctx, util.cfg, func(ctx context.Context) (*http.Response, error) {
		req, err := http.NewRequestWithContext(ctx, "GET", downloadURL.String(), nil)
		if err != nil {
			return nil, err
		}
		for k, v := range gcsHeaders {
			req.Header.Add(k, v)
		}
		client, err := newGcsClient(util.cfg, util.telemetry)
		if err != nil {
			return nil, err
		}
		// for testing only
		if meta.mockGcsClient != nil {
			client = meta.mockGcsClient
		}
		return client.Do(req)
	})
	if err != nil {
		return err
	}
	defer func() {
		if resp.Body != nil {
			if err := resp.Body.Close(); err != nil {
				logger.Warnf("Failed to close response body: %v", err)
			}
		}
	}()
	if resp.StatusCode != http.StatusOK {
		return util.handleHTTPError(resp, meta, accessToken)
	}

	// Stream to the caller-provided writer in GET-stream mode; otherwise
	// write to the destination file and record its size.
	if isFileGetStream(ctx) {
		if _, err := io.Copy(meta.dstStream, resp.Body); err != nil {
			return err
		}
	} else {
		f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
		if err != nil {
			return err
		}
		defer func() {
			if err = f.Close(); err != nil {
				logger.Warnf("Failed to close the file: %v", err)
			}
		}()
		if _, err = io.Copy(f, resp.Body); err != nil {
			return err
		}
		fi, err := os.Stat(fullDstFileName)
		if err != nil {
			return err
		}
		meta.srcFileSize = fi.Size()
	}
	return nil
}

// handleHTTPError handles HTTP error responses consistently
func (util *snowflakeGcsClient) handleHTTPError(resp *http.Response, meta *fileMetadata, accessToken string) error {
	if resp.StatusCode == 403 || resp.StatusCode == 408 || resp.StatusCode == 429 || resp.StatusCode == 500 || resp.StatusCode == 503 {
		meta.lastError = fmt.Errorf("%v", resp.Status)
		meta.resStatus = needRetry
	} else
if resp.StatusCode == 404 {
		meta.lastError = fmt.Errorf("%v", resp.Status)
		meta.resStatus = notFoundFile
	} else if accessToken == "" && resp.StatusCode == 400 && meta.lastError == nil {
		// Presigned-URL mode: a 400 on the first failure means the URL expired.
		meta.lastError = fmt.Errorf("%v", resp.Status)
		meta.resStatus = renewPresignedURL
	} else if accessToken != "" && util.isTokenExpired(resp) {
		meta.lastError = fmt.Errorf("%v", resp.Status)
		meta.resStatus = renewToken
	} else {
		meta.lastError = fmt.Errorf("%v", resp.Status)
	}
	return meta.lastError
}

// extractBucketNameAndPath splits a stage location ("bucket/prefix") on the
// first "/" into bucket name and path; a non-empty path is normalized to end
// with "/".
func (util *snowflakeGcsClient) extractBucketNameAndPath(location string) *gcsLocation {
	containerName := location
	var path string
	if strings.Contains(location, "/") {
		containerName = location[:strings.Index(location, "/")]
		path = location[strings.Index(location, "/")+1:]
		if path != "" && !strings.HasSuffix(path, "/") {
			path += "/"
		}
	}
	return &gcsLocation{containerName, path}
}

// generateFileURL builds the GCS object URL for a staged file, honoring custom
// endpoints, virtual-hosted-style URLs, and regional endpoints.
func (util *snowflakeGcsClient) generateFileURL(stageInfo *execResponseStageInfo, filename string) (result *url.URL, err error) {
	gcsLoc := util.extractBucketNameAndPath(stageInfo.Location)
	fullFilePath := gcsLoc.path + filename
	endPoint := "https://storage.googleapis.com"
	// TODO: SNOW-1789759 hardcoded region will be replaced in the future
	isRegionalURLEnabled := (strings.ToLower(stageInfo.Region) == gcsRegionMeCentral2) || stageInfo.UseRegionalURL
	if stageInfo.EndPoint != "" {
		endPoint = fmt.Sprintf("https://%s", stageInfo.EndPoint)
	} else if stageInfo.UseVirtualURL {
		endPoint = fmt.Sprintf("https://%s.storage.googleapis.com", gcsLoc.bucketName)
	} else if stageInfo.Region != "" && isRegionalURLEnabled {
		endPoint = fmt.Sprintf("https://storage.%s.rep.googleapis.com", strings.ToLower(stageInfo.Region))
	}
	// Virtual-hosted style already embeds the bucket in the host name.
	if stageInfo.UseVirtualURL {
		result, err = url.Parse(endPoint + "/" + url.PathEscape(fullFilePath))
	} else {
		result, err = url.Parse(endPoint + "/" + gcsLoc.bucketName + "/" + url.PathEscape(fullFilePath))
	}
	logger.Debugf("generated file URL from location=%v, path=%v, fileName=%v, endpoint=%v, useVirtualUrl=%v, 
result=%v, err=%v",
		stageInfo.Location, gcsLoc.path, filename, stageInfo.EndPoint, stageInfo.UseVirtualURL, cmp.Or(result, &url.URL{}).String(), err)
	// NOTE(review): the "endpoint" field logs the raw stageInfo.EndPoint, not
	// the resolved endPoint value computed above.
	return result, err
}

// isTokenExpired reports whether the response indicates an expired
// downscoped token (HTTP 401).
func (util *snowflakeGcsClient) isTokenExpired(resp *http.Response) bool {
	return resp.StatusCode == 401
}

// newGcsClient builds an HTTP client with the driver's cloud-provider
// transport (telemetry, proxying, etc.).
func newGcsClient(cfg *Config, telemetry *snowflakeTelemetry) (gcsAPI, error) {
	transport, err := newTransportFactory(cfg, telemetry).createTransport(transportConfigFor(transportTypeCloudProvider))
	if err != nil {
		return nil, err
	}
	return &http.Client{
		Transport: transport,
	}, nil
}

================================================
FILE: gcs_storage_client_test.go
================================================

package gosnowflake

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path"
	"strings"
	"testing"
)

// tcFileURL is a generateFileURL test case: stage location, file name, and the
// expected bucket and escaped object path.
type tcFileURL struct {
	location string
	fname    string
	bucket   string
	filepath string
}

func TestExtractBucketAndPath(t *testing.T) {
	gcsUtil := new(snowflakeGcsClient)
	testcases := []tcBucketPath{
		{"sfc-eng-regression/test_sub_dir/", "sfc-eng-regression", "test_sub_dir/"},
		{"sfc-eng-regression/dir/test_stg/test_sub_dir/", "sfc-eng-regression", "dir/test_stg/test_sub_dir/"},
		{"sfc-eng-regression/", "sfc-eng-regression", ""},
		{"sfc-eng-regression//", "sfc-eng-regression", "/"},
		{"sfc-eng-regression///", "sfc-eng-regression", "//"},
	}
	for _, test := range testcases {
		t.Run(test.in, func(t *testing.T) {
			gcsLoc := gcsUtil.extractBucketNameAndPath(test.in)
			if gcsLoc.bucketName != test.bucket {
				t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.bucket, gcsLoc.bucketName)
			}
			if gcsLoc.path != test.path {
				t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.path, gcsLoc.path)
			}
		})
	}
}

// TestIsTokenExpiredWith401 checks that a 401 response is treated as token expiry.
func TestIsTokenExpiredWith401(t *testing.T) {
	gcsUtil := new(snowflakeGcsClient)
	dd := &execResponseData{}
	execResp := &execResponse{
		Data:    *dd,
		Message: "token expired",
		Code:    "401",
		Success: true,
	}
	ba, err := json.Marshal(execResp)
	if err != nil {
		panic(err)
	}
	resp := &http.Response{StatusCode: http.StatusUnauthorized, Body: &fakeResponseBody{body: ba}}
	if !gcsUtil.isTokenExpired(resp) {
		t.Fatalf("expected true for token expired")
	}
}

// TestIsTokenExpiredWith404 checks that only 401 — not 404 or 200 — counts as expiry.
func TestIsTokenExpiredWith404(t *testing.T) {
	gcsUtil := new(snowflakeGcsClient)
	dd := &execResponseData{}
	execResp := &execResponse{
		Data:    *dd,
		Message: "file not found",
		Code:    "404",
		Success: true,
	}
	ba, err := json.Marshal(execResp)
	if err != nil {
		panic(err)
	}
	resp := &http.Response{StatusCode: http.StatusNotFound, Body: &fakeResponseBody{body: ba}}
	if gcsUtil.isTokenExpired(resp) {
		t.Fatalf("should be false")
	}
	resp = &http.Response{
		StatusCode: http.StatusOK,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}}}
	if gcsUtil.isTokenExpired(resp) {
		t.Fatalf("should be false")
	}
	resp = &http.Response{
		StatusCode: http.StatusUnauthorized,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}}}
	if !gcsUtil.isTokenExpired(resp) {
		t.Fatalf("should be true")
	}
}

// TestGenerateFileURL covers path-style and virtual-hosted-style URLs.
// NOTE(review): the first two loops are byte-for-byte identical; one looks
// like an accidental duplicate.
func TestGenerateFileURL(t *testing.T) {
	gcsUtil := new(snowflakeGcsClient)
	testcases := []tcFileURL{
		{"sfc-eng-regression/test_sub_dir/", "file1", "sfc-eng-regression", "test_sub_dir/file1"},
		{"sfc-eng-regression/dir/test_stg/test_sub_dir/", "file2", "sfc-eng-regression", "dir/test_stg/test_sub_dir/file2"},
		{"sfc-eng-regression/dir/test_stg/test sub dir/", "file2", "sfc-eng-regression", "dir/test_stg/test sub dir/file2"},
		{"sfc-eng-regression/", "file3", "sfc-eng-regression", "file3"},
		{"sfc-eng-regression//", "file4", "sfc-eng-regression", "/file4"},
		{"sfc-eng-regression///", "file5", "sfc-eng-regression", "//file5"},
	}
	for _, test := range testcases {
		t.Run(test.location, func(t *testing.T) {
			stageInfo := &execResponseStageInfo{}
			stageInfo.Location = test.location
			gcsURL, err := gcsUtil.generateFileURL(stageInfo, test.fname)
			assertNilF(t, err, "error should be nil")
			expectedURL, err := url.Parse("https://storage.googleapis.com/" + test.bucket + "/" + url.PathEscape(test.filepath))
			assertNilF(t, err, "error should be nil")
			assertEqualE(t, gcsURL.String(), expectedURL.String(), "failed. expected: %v but got: %v", expectedURL.String(), gcsURL.String())
		})
	}
	for _, test := range testcases {
		t.Run(test.location, func(t *testing.T) {
			stageInfo := &execResponseStageInfo{}
			stageInfo.Location = test.location
			gcsURL, err := gcsUtil.generateFileURL(stageInfo, test.fname)
			assertNilF(t, err, "error should be nil")
			expectedURL, err := url.Parse("https://storage.googleapis.com/" + test.bucket + "/" + url.PathEscape(test.filepath))
			assertNilF(t, err, "error should be nil")
			assertEqualE(t, gcsURL.String(), expectedURL.String(), "failed. expected: %v but got: %v", expectedURL.String(), gcsURL.String())
		})
	}
	for _, test := range testcases {
		t.Run(test.location, func(t *testing.T) {
			stageInfo := &execResponseStageInfo{}
			stageInfo.Location = test.location
			stageInfo.UseVirtualURL = true
			gcsURL, err := gcsUtil.generateFileURL(stageInfo, test.fname)
			assertNilF(t, err, "error should be nil")
			expectedURL, err := url.Parse("https://sfc-eng-regression.storage.googleapis.com/" + url.PathEscape(test.filepath))
			assertNilF(t, err, "error should be nil")
			assertEqualE(t, gcsURL.String(), expectedURL.String(), "failed. 
expected: %v but got: %v", expectedURL.String(), gcsURL.String()) }) } } type clientMock struct { DoFunc func(req *http.Request) (*http.Response, error) } func (c *clientMock) Do(req *http.Request) (*http.Response, error) { return c.DoFunc(req) } func TestUploadFileWithGcsUploadFailedError(t *testing.T) { info := execResponseStageInfo{ Location: "gcs-blob/storage/users/456/", LocationType: "GCS", } initialParallel := int64(100) dir, err := os.Getwd() if err != nil { t.Error(err) } gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "GCS", noSleepingTime: true, parallel: initialParallel, client: gcsCli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), overwrite: true, dstCompressionType: compressionTypes["GZIP"], options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockGcsClient: &clientMock{ DoFunc: func(req *http.Request) (*http.Response, error) { return nil, errors.New("unexpected error uploading file") }, }, sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName fi, err := os.Stat(uploadMeta.srcFileName) if err != nil { t.Error(err) } uploadMeta.uploadSize = fi.Size() err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta) if err == nil { t.Fatal("should have failed") } } func TestUploadFileWithGcsUploadFailedWithRetry(t *testing.T) { info := execResponseStageInfo{ Location: "gcs-blob/storage/users/456/", LocationType: "GCS", } encMat := snowflakeFileEncryption{ QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", QueryID: "01abc874-0406-1bf0-0000-53b10668e056", SMKID: 92019681909886, } initialParallel := int64(100) dir, err := os.Getwd() if err != nil { t.Error(err) } gcsCli, err := 
new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "GCS", noSleepingTime: true, parallel: initialParallel, client: gcsCli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), overwrite: true, dstCompressionType: compressionTypes["GZIP"], encryptionMaterial: &encMat, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockGcsClient: &clientMock{ DoFunc: func(req *http.Request) (*http.Response, error) { return &http.Response{ Status: "403 Forbidden", StatusCode: 403, Header: make(http.Header), Body: io.NopCloser(strings.NewReader("")), }, nil }, }, sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName fi, err := os.Stat(uploadMeta.srcFileName) if err != nil { t.Error(err) } uploadMeta.uploadSize = fi.Size() err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta) if err == nil { t.Error("should have raised an error") } if uploadMeta.resStatus != needRetry { t.Fatalf("expected %v result status, got: %v", needRetry, uploadMeta.resStatus) } } func TestUploadFileWithGcsUploadFailedWithTokenExpired(t *testing.T) { info := execResponseStageInfo{ Location: "gcs-blob/storage/users/456/", LocationType: "GCS", Creds: execResponseCredentials{ GcsAccessToken: "test-token-124456577", }, } initialParallel := int64(100) dir, err := os.Getwd() if err != nil { t.Error(err) } gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "GCS", noSleepingTime: true, parallel: initialParallel, client: gcsCli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, 
"/test_data/put_get_1.txt"),
		overwrite:   true,
		options: &SnowflakeFileTransferOptions{
			MultiPartThreshold: multiPartThreshold,
		},
		// Mock transport: every request answers 401 so the upload path treats
		// the GCS token as expired rather than retrying.
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return &http.Response{
					Status:     "401 Unauthorized",
					StatusCode: 401,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
	}
	uploadMeta.realSrcFileName = uploadMeta.srcFileName
	fi, err := os.Stat(uploadMeta.srcFileName)
	if err != nil {
		t.Error(err)
	}
	uploadMeta.uploadSize = fi.Size()
	err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta)
	if err != nil {
		t.Error(err)
	}
	// A 401 must surface as renewToken so the agent refreshes credentials.
	if uploadMeta.resStatus != renewToken {
		t.Fatalf("expected %v result status, got: %v", renewToken, uploadMeta.resStatus)
	}
}

// TestDownloadOneFileFromGcsFailed verifies that a transport-level error from
// the mocked GCS client propagates out of downloadOneFile as an error.
func TestDownloadOneFileFromGcsFailed(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
	}
	dir, err := os.Getwd()
	if err != nil {
		t.Error(err)
	}
	gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{})
	if err != nil {
		t.Error(err)
	}
	downloadMeta := fileMetadata{
		name:              "data1.txt.gz",
		stageLocationType: "GCS",
		noSleepingTime:    true,
		client:            gcsCli,
		stageInfo:         &info,
		dstFileName:       "data1.txt.gz",
		overwrite:         true,
		srcFileName:       "data1.txt.gz",
		localLocation:     dir,
		options: &SnowflakeFileTransferOptions{
			MultiPartThreshold: multiPartThreshold,
		},
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return nil, errors.New("unexpected error downloading file")
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
		resStatus: downloaded, // bypass file header request
	}
	err = new(remoteStorageUtil).downloadOneFile(context.Background(), &downloadMeta)
	if err == nil {
		t.Error("should have raised an error")
	}
}

// TestDownloadOneFileFromGcsFailedWithRetry verifies that a 403 response from
// GCS marks the download metadata with the needRetry result status.
func TestDownloadOneFileFromGcsFailedWithRetry(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
	}
	dir, err := os.Getwd()
	if err != nil {
		t.Error(err)
	}
	gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{})
	if err != nil {
		t.Error(err)
	}
	downloadMeta := fileMetadata{
		name:              "data1.txt.gz",
		stageLocationType: "GCS",
		noSleepingTime:    true,
		client:            gcsCli,
		stageInfo:         &info,
		dstFileName:       "data1.txt.gz",
		overwrite:         true,
		srcFileName:       "data1.txt.gz",
		localLocation:     dir,
		options: &SnowflakeFileTransferOptions{
			MultiPartThreshold: multiPartThreshold,
		},
		// 403 is treated as a retryable failure.
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return &http.Response{
					Status:     "403 Forbidden",
					StatusCode: 403,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
		resStatus: downloaded, // bypass file header request
	}
	err = new(remoteStorageUtil).downloadOneFile(context.Background(), &downloadMeta)
	if err == nil {
		t.Error("should have raised an error")
	}
	if downloadMeta.resStatus != needRetry {
		t.Fatalf("expected %v result status, got: %v", needRetry, downloadMeta.resStatus)
	}
}

// TestDownloadOneFileFromGcsFailedWithTokenExpired verifies that a 401
// response sets the renewToken result status so credentials get refreshed.
func TestDownloadOneFileFromGcsFailedWithTokenExpired(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
		Creds: execResponseCredentials{
			GcsAccessToken: "test-token-124456577",
		},
	}
	dir, err := os.Getwd()
	if err != nil {
		t.Error(err)
	}
	gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{})
	if err != nil {
		t.Error(err)
	}
	downloadMeta := fileMetadata{
		name:              "data1.txt.gz",
		stageLocationType: "GCS",
		noSleepingTime:    true,
		client:            gcsCli,
		stageInfo:         &info,
		dstFileName:       "data1.txt.gz",
		overwrite:         true,
		srcFileName:       "data1.txt.gz",
		localLocation:     dir,
		options: &SnowflakeFileTransferOptions{
			MultiPartThreshold: multiPartThreshold,
		},
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return &http.Response{
					Status:     "401 Unauthorized",
					StatusCode: 401,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
		resStatus: downloaded, // bypass file header request
	}
	err = new(remoteStorageUtil).downloadOneFile(context.Background(), &downloadMeta)
	if err == nil {
		t.Error("should have raised an error")
	}
	if downloadMeta.resStatus != renewToken {
		t.Fatalf("expected %v result status, got: %v", renewToken, downloadMeta.resStatus)
	}
}

// TestDownloadOneFileFromGcsFailedWithFileNotFound verifies that a 404
// response sets the notFoundFile result status on the download metadata.
func TestDownloadOneFileFromGcsFailedWithFileNotFound(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
		Creds: execResponseCredentials{
			GcsAccessToken: "test-token-124456577",
		},
	}
	dir, err := os.Getwd()
	if err != nil {
		t.Error(err)
	}
	gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{})
	if err != nil {
		t.Error(err)
	}
	downloadMeta := fileMetadata{
		name:              "data1.txt.gz",
		stageLocationType: "GCS",
		noSleepingTime:    true,
		client:            gcsCli,
		stageInfo:         &info,
		dstFileName:       "data1.txt.gz",
		overwrite:         true,
		srcFileName:       "data1.txt.gz",
		localLocation:     dir,
		options: &SnowflakeFileTransferOptions{
			MultiPartThreshold: multiPartThreshold,
		},
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return &http.Response{
					Status:     "404 Not Found",
					StatusCode: 404,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
		resStatus: downloaded, // bypass file header request
	}
	err = new(remoteStorageUtil).downloadOneFile(context.Background(), &downloadMeta)
	if err == nil {
		t.Error("should have raised an error")
	}
	if downloadMeta.resStatus != notFoundFile {
		t.Fatalf("expected %v result status, got: %v", notFoundFile, downloadMeta.resStatus)
	}
}

// TestGetHeaderTokenExpiredError verifies that getFileHeader maps a 401
// response to a nil header, an error, and the renewToken result status.
func TestGetHeaderTokenExpiredError(t *testing.T) {
	info := execResponseStageInfo{
		Location:
"gcs/teststage/users/34/",
		LocationType: "GCS",
		Creds: execResponseCredentials{
			GcsAccessToken: "test-token-124456577",
		},
	}
	meta := fileMetadata{
		client:    info.Creds.GcsAccessToken,
		stageInfo: &info,
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return &http.Response{
					Status:     "401 Unauthorized",
					StatusCode: 401,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
	}
	if header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(context.Background(), &meta, "file.txt"); header != nil || err == nil {
		t.Fatalf("expected null header, got: %v", header)
	}
	if meta.resStatus != renewToken {
		t.Fatalf("expected %v result status, got: %v", renewToken, meta.resStatus)
	}
}

// TestGetHeaderFileNotFound verifies that getFileHeader maps a 404 response
// to a nil header, an error, and the notFoundFile result status.
func TestGetHeaderFileNotFound(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
		Creds: execResponseCredentials{
			GcsAccessToken: "test-token-124456577",
		},
	}
	meta := fileMetadata{
		client:    info.Creds.GcsAccessToken,
		stageInfo: &info,
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return &http.Response{
					Status:     "404 Not Found",
					StatusCode: 404,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
	}
	if header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(context.Background(), &meta, "file.txt"); header != nil || err == nil {
		t.Fatalf("expected null header, got: %v", header)
	}
	if meta.resStatus != notFoundFile {
		t.Fatalf("expected %v result status, got: %v", notFoundFile, meta.resStatus)
	}
}

// TestGetHeaderPresignedUrlReturns404 verifies the presigned-URL path:
// getFileHeader returns no header and no error, and records notFoundFile.
func TestGetHeaderPresignedUrlReturns404(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
		Creds: execResponseCredentials{
			GcsAccessToken: "test-token-124456577",
		},
	}
	presignedURL, err := url.Parse("https://google-cloud.test.com")
	if err != nil {
		t.Error(err)
	}
	meta := fileMetadata{
		client:       info.Creds.GcsAccessToken,
		stageInfo:    &info,
		presignedURL: presignedURL,
	}
	header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(context.Background(), &meta, "file.txt")
	if header != nil {
		t.Fatalf("expected null header, got: %v", header)
	}
	if err != nil {
		t.Error(err)
	}
	if meta.resStatus != notFoundFile {
		t.Fatalf("expected %v result status, got: %v", notFoundFile, meta.resStatus)
	}
}

// TestGetHeaderReturnsError verifies that a transport-level error from the
// mocked GCS client propagates out of getFileHeader.
func TestGetHeaderReturnsError(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
		Creds: execResponseCredentials{
			GcsAccessToken: "test-token-124456577",
		},
	}
	meta := fileMetadata{
		client:    info.Creds.GcsAccessToken,
		stageInfo: &info,
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return nil, errors.New("unexpected exception getting file header")
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
	}
	if header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(context.Background(), &meta, "file.txt"); header != nil || err == nil {
		t.Fatalf("expected null header, got: %v", header)
	}
}

// TestGetHeaderBadRequest verifies that a 400 response maps to a nil header,
// an error, and the errStatus result status.
func TestGetHeaderBadRequest(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
		Creds: execResponseCredentials{
			GcsAccessToken: "test-token-124456577",
		},
	}
	meta := fileMetadata{
		client:    info.Creds.GcsAccessToken,
		stageInfo: &info,
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return &http.Response{
					Status:     "400 Bad Request",
					StatusCode: 400,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
	}
	if header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(context.Background(), &meta, "file.txt"); header != nil || err == nil {
		t.Fatalf("expected null header, got: %v", header)
	}
	if meta.resStatus != errStatus {
		t.Fatalf("expected %v result status, got: %v", errStatus, meta.resStatus)
	}
}

// TestGetHeaderRetryableError verifies that a 403 response maps to a nil
// header, an error, and the needRetry result status.
func TestGetHeaderRetryableError(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
		Creds: execResponseCredentials{
			GcsAccessToken: "test-token-124456577",
		},
	}
	meta := fileMetadata{
		client:    info.Creds.GcsAccessToken,
		stageInfo: &info,
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return &http.Response{
					Status:     "403 Forbidden",
					StatusCode: 403,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
	}
	if header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(context.Background(), &meta, "file.txt"); header != nil || err == nil {
		t.Fatalf("expected null header, got: %v", header)
	}
	if meta.resStatus != needRetry {
		t.Fatalf("expected %v result status, got: %v", needRetry, meta.resStatus)
	}
}

// TestUploadStreamFailed verifies that uploading from an in-memory stream
// fails when the mocked GCS client returns a transport-level error.
func TestUploadStreamFailed(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs-blob/storage/users/456/",
		LocationType: "GCS",
	}
	initialParallel := int64(100)
	src := []byte{65, 66, 67} // "ABC" payload streamed instead of a file
	gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{})
	if err != nil {
		t.Error(err)
	}
	uploadMeta := fileMetadata{
		name:              "data1.txt.gz",
		stageLocationType: "GCS",
		noSleepingTime:    true,
		parallel:          initialParallel,
		client:            gcsCli,
		sha256Digest:      "123456789abcdef",
		stageInfo:         &info,
		dstFileName:       "data1.txt.gz",
		srcStream:         bytes.NewBuffer(src),
		overwrite:         true,
		options: &SnowflakeFileTransferOptions{
			MultiPartThreshold: multiPartThreshold,
		},
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return nil, errors.New("unexpected error uploading file")
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
	}
	uploadMeta.realSrcStream = uploadMeta.srcStream
	err =
new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta)
	if err == nil {
		t.Fatal("should have failed")
	}
}

// TestUploadFileWithBadRequest verifies that a 400 response during upload is
// reported via the renewPresignedURL result status without surfacing an error.
func TestUploadFileWithBadRequest(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs-blob/storage/users/456/",
		LocationType: "GCS",
	}
	initialParallel := int64(100)
	dir, err := os.Getwd()
	if err != nil {
		t.Error(err)
	}
	gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{})
	if err != nil {
		t.Error(err)
	}
	uploadMeta := fileMetadata{
		name:              "data1.txt.gz",
		stageLocationType: "GCS",
		noSleepingTime:    true,
		parallel:          initialParallel,
		client:            gcsCli,
		sha256Digest:      "123456789abcdef",
		stageInfo:         &info,
		dstFileName:       "data1.txt.gz",
		srcFileName:       path.Join(dir, "/test_data/put_get_1.txt"),
		overwrite:         true,
		lastError:         nil,
		options: &SnowflakeFileTransferOptions{
			MultiPartThreshold: multiPartThreshold,
		},
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return &http.Response{
					StatusCode: 400,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
	}
	uploadMeta.realSrcFileName = uploadMeta.srcFileName
	fi, err := os.Stat(uploadMeta.srcFileName)
	if err != nil {
		t.Error(err)
	}
	uploadMeta.uploadSize = fi.Size()
	err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta)
	if err != nil {
		t.Error(err)
	}
	if uploadMeta.resStatus != renewPresignedURL {
		t.Fatalf("expected %v result status, got: %v", renewPresignedURL, uploadMeta.resStatus)
	}
}

// TestGetFileHeaderEncryptionData verifies that getFileHeader parses the
// Azure-style encryption-data JSON and matdesc from the GCS object metadata
// headers into a fileHeader with the expected digest, length, key, and IV.
func TestGetFileHeaderEncryptionData(t *testing.T) {
	mockEncDataResp := "{\"EncryptionMode\":\"FullBlob\",\"WrappedContentKey\": {\"KeyId\":\"symmKey1\",\"EncryptedKey\":\"testencryptedkey12345678910==\",\"Algorithm\":\"AES_CBC_256\"},\"EncryptionAgent\": {\"Protocol\":\"1.0\",\"EncryptionAlgorithm\":\"AES_CBC_256\"},\"ContentEncryptionIV\":\"testIVkey12345678910==\",\"KeyWrappingMetadata\":{\"EncryptionLibrary\":\"Java 5.3.0\"}}"
	mockMatDesc := "{\"queryid\":\"01abc874-0406-1bf0-0000-53b10668e056\",\"smkid\":\"92019681909886\",\"key\":\"128\"}"
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
		Creds: execResponseCredentials{
			GcsAccessToken: "test-token-124456577",
		},
	}
	meta := fileMetadata{
		client:    info.Creds.GcsAccessToken,
		stageInfo: &info,
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return &http.Response{
					Status:     "200 OK",
					StatusCode: 200,
					Header: http.Header{
						"X-Goog-Meta-Encryptiondata": []string{mockEncDataResp},
						"Content-Length":             []string{"4256"},
						"X-Goog-Meta-Sfc-Digest":     []string{"123456789abcdef"},
						"X-Goog-Meta-Matdesc":        []string{mockMatDesc},
					},
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
	}
	header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(context.Background(), &meta, "file.txt")
	if err != nil {
		t.Fatal(err)
	}
	expectedFileHeader := &fileHeader{
		digest:        "123456789abcdef",
		contentLength: 4256,
		encryptionMetadata: &encryptMetadata{
			key:     "testencryptedkey12345678910==",
			iv:      "testIVkey12345678910==",
			matdesc: mockMatDesc,
		},
	}
	if header.contentLength != expectedFileHeader.contentLength || header.digest != expectedFileHeader.digest || header.encryptionMetadata.iv != expectedFileHeader.encryptionMetadata.iv || header.encryptionMetadata.key != expectedFileHeader.encryptionMetadata.key || header.encryptionMetadata.matdesc != expectedFileHeader.encryptionMetadata.matdesc {
		t.Fatalf("unexpected file header. expected: %v, got: %v", expectedFileHeader, header)
	}
}

// TestGetFileHeaderEncryptionDataInterfaceConversionError verifies that
// getFileHeader fails when meta.client holds a non-token value (an int),
// exercising the interface-conversion error path.
func TestGetFileHeaderEncryptionDataInterfaceConversionError(t *testing.T) {
	mockEncDataResp := "{\"EncryptionMode\":\"FullBlob\",\"WrappedContentKey\": {\"KeyId\":\"symmKey1\",\"EncryptedKey\":\"testencryptedkey12345678910==\",\"Algorithm\":\"AES_CBC_256\"},\"EncryptionAgent\": {\"Protocol\":\"1.0\",\"EncryptionAlgorithm\":\"AES_CBC_256\"},\"ContentEncryptionIV\":\"testIVkey12345678910==\",\"KeyWrappingMetadata\":{\"EncryptionLibrary\":\"Java 5.3.0\"}}"
	mockMatDesc := "{\"queryid\":\"01abc874-0406-1bf0-0000-53b10668e056\",\"smkid\":\"92019681909886\",\"key\":\"128\"}"
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
		Creds: execResponseCredentials{
			GcsAccessToken: "test-token-124456577",
		},
	}
	meta := fileMetadata{
		client:    1, // deliberately not a string token
		stageInfo: &info,
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return &http.Response{
					Status:     "200 OK",
					StatusCode: 200,
					Header: http.Header{
						"X-Goog-Meta-Encryptiondata": []string{mockEncDataResp},
						"Content-Length":             []string{"4256"},
						"X-Goog-Meta-Sfc-Digest":     []string{"123456789abcdef"},
						"X-Goog-Meta-Matdesc":        []string{mockMatDesc},
					},
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
	}
	_, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(context.Background(), &meta, "file.txt")
	if err == nil {
		t.Error("should have raised an error")
	}
}

// TestUploadFileToGcsNoStatus verifies that an encrypted upload answered with
// 401 surfaces an error from uploadOneFile.
func TestUploadFileToGcsNoStatus(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs-blob/storage/users/456/",
		LocationType: "GCS",
	}
	encMat := snowflakeFileEncryption{
		QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==",
		QueryID:             "01abc874-0406-1bf0-0000-53b10668e056",
		SMKID:               92019681909886,
	}
	initialParallel := int64(100)
	dir, err := os.Getwd()
	if err != nil {
		t.Error(err)
	}
	gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{})
	if err != nil {
		t.Error(err)
	}
	uploadMeta := fileMetadata{
name:               "data1.txt.gz",
		stageLocationType:  "GCS",
		noSleepingTime:     true,
		parallel:           initialParallel,
		client:             gcsCli,
		sha256Digest:       "123456789abcdef",
		stageInfo:          &info,
		dstFileName:        "data1.txt.gz",
		srcFileName:        path.Join(dir, "/test_data/put_get_1.txt"),
		overwrite:          true,
		dstCompressionType: compressionTypes["GZIP"],
		encryptionMaterial: &encMat,
		options: &SnowflakeFileTransferOptions{
			MultiPartThreshold: multiPartThreshold,
		},
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return &http.Response{
					Status:     "401 Unauthorized",
					StatusCode: 401,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
	}
	uploadMeta.realSrcFileName = uploadMeta.srcFileName
	fi, err := os.Stat(uploadMeta.srcFileName)
	if err != nil {
		t.Error(err)
	}
	uploadMeta.uploadSize = fi.Size()
	err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta)
	if err == nil {
		t.Error("should have raised an error")
	}
}

// TestDownloadFileFromGcsError verifies that a 401 response during download
// surfaces an error from downloadOneFile.
func TestDownloadFileFromGcsError(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
	}
	dir, err := os.Getwd()
	if err != nil {
		t.Error(err)
	}
	gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{})
	if err != nil {
		t.Error(err)
	}
	downloadMeta := fileMetadata{
		name:              "data1.txt.gz",
		stageLocationType: "GCS",
		noSleepingTime:    true,
		client:            gcsCli,
		stageInfo:         &info,
		dstFileName:       "data1.txt.gz",
		overwrite:         true,
		srcFileName:       "data1.txt.gz",
		localLocation:     dir,
		options: &SnowflakeFileTransferOptions{
			MultiPartThreshold: multiPartThreshold,
		},
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				// Fixed: the mock previously declared Status "403 Unauthorized"
				// alongside StatusCode 401; the reason-phrase now matches the
				// code the client actually branches on.
				return &http.Response{
					Status:     "401 Unauthorized",
					StatusCode: 401,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
		resStatus: downloaded, // bypass file header request
	}
	err = new(remoteStorageUtil).downloadOneFile(context.Background(), &downloadMeta)
	if err == nil {
		t.Error("should have raised an error")
	}
}

// TestDownloadFileWithBadRequest verifies that a 400 response during download
// surfaces an error and sets the renewPresignedURL result status.
func TestDownloadFileWithBadRequest(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
	}
	dir, err := os.Getwd()
	if err != nil {
		t.Error(err)
	}
	gcsCli, err := new(snowflakeGcsClient).createClient(&info, false, &snowflakeTelemetry{})
	if err != nil {
		t.Error(err)
	}
	downloadMeta := fileMetadata{
		name:              "data1.txt.gz",
		stageLocationType: "GCS",
		noSleepingTime:    true,
		client:            gcsCli,
		stageInfo:         &info,
		dstFileName:       "data1.txt.gz",
		overwrite:         true,
		srcFileName:       "data1.txt.gz",
		localLocation:     dir,
		options: &SnowflakeFileTransferOptions{
			MultiPartThreshold: multiPartThreshold,
		},
		mockGcsClient: &clientMock{
			DoFunc: func(req *http.Request) (*http.Response, error) {
				return &http.Response{
					Status:     "400 Bad Request",
					StatusCode: 400,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
				}, nil
			},
		},
		sfa: &snowflakeFileTransferAgent{
			sc: &snowflakeConn{
				cfg: &Config{},
			},
		},
		resStatus: downloaded, // bypass file header request
	}
	err = new(remoteStorageUtil).downloadOneFile(context.Background(), &downloadMeta)
	if err == nil {
		t.Error("should have raised an error")
	}
	if downloadMeta.resStatus != renewPresignedURL {
		t.Fatalf("expected %v result status, got: %v", renewPresignedURL, downloadMeta.resStatus)
	}
}

// Test_snowflakeGcsClient_uploadFile verifies that uploadFile rejects a
// metadata whose client field is not a GCS token (an int here).
func Test_snowflakeGcsClient_uploadFile(t *testing.T) {
	info := execResponseStageInfo{
		Location:     "gcs/teststage/users/34/",
		LocationType: "GCS",
		Creds: execResponseCredentials{
			GcsAccessToken: "test-token-124456577",
		},
	}
	meta := fileMetadata{
		client:    1, // deliberately not a string token
		stageInfo: &info,
	}
	err := new(snowflakeGcsClient).uploadFile(context.Background(), "somedata", &meta, 1, 1)
	if err == nil {
		t.Error("should have raised an error")
	}
}

// Test_snowflakeGcsClient_nativeDownloadFile verifies that nativeDownloadFile
// rejects a metadata whose client field is not a GCS token (an int here).
func Test_snowflakeGcsClient_nativeDownloadFile(t *testing.T) {
	info := execResponseStageInfo{
		Location:
"gcs/teststage/users/34/",
		LocationType: "GCS",
		Creds: execResponseCredentials{
			GcsAccessToken: "test-token-124456577",
		},
	}
	meta := fileMetadata{
		client:    1, // deliberately not a string token
		stageInfo: &info,
	}
	err := new(snowflakeGcsClient).nativeDownloadFile(context.Background(), &meta, "dummy data", 1, multiPartThreshold)
	if err == nil {
		t.Error("should have raised an error")
	}
}

// TestGetGcsCustomEndpoint is a table test for generateFileURL covering the
// interaction of EndPoint, UseRegionalURL, the me-central2 region special
// case, and UseVirtualURL when building a GCS file URL.
func TestGetGcsCustomEndpoint(t *testing.T) {
	testcases := []struct {
		desc            string
		in              execResponseStageInfo
		expectedFileURL string
	}{
		{
			desc: "when the endPoint is not specified and UseRegionalURL is false",
			in: execResponseStageInfo{
				UseRegionalURL: false,
				Location:       "my-travel-maps/mock_directory/mock_path/",
				EndPoint:       "",
				Region:         "WEST-1",
				UseVirtualURL:  false,
			},
			expectedFileURL: "https://storage.googleapis.com/my-travel-maps",
		},
		{
			desc: "when the useRegionalURL is only enabled",
			in: execResponseStageInfo{
				UseRegionalURL: true,
				Location:       "my-travel-maps/mock_directory/mock_path/",
				EndPoint:       "",
				Region:         "mockLocation",
				UseVirtualURL:  false,
			},
			expectedFileURL: "https://storage.mocklocation.rep.googleapis.com/my-travel-maps",
		},
		{
			desc: "when the region is me-central2",
			in: execResponseStageInfo{
				UseRegionalURL: false,
				Location:       "my-travel-maps/mock_directory/mock_path/",
				EndPoint:       "",
				Region:         "me-central2",
				UseVirtualURL:  false,
			},
			expectedFileURL: "https://storage.me-central2.rep.googleapis.com/my-travel-maps",
		},
		{
			desc: "when the region is me-central2 (mixed case)",
			in: execResponseStageInfo{
				UseRegionalURL: false,
				Location:       "my-travel-maps/mock_directory/mock_path/",
				EndPoint:       "",
				Region:         "ME-cEntRal2",
				UseVirtualURL:  false,
			},
			expectedFileURL: "https://storage.me-central2.rep.googleapis.com/my-travel-maps",
		},
		{
			desc: "when the region is me-central2 (uppercase)",
			in: execResponseStageInfo{
				UseRegionalURL: false,
				Location:       "my-travel-maps/mock_directory/mock_path/",
				EndPoint:       "",
				Region:         "ME-CENTRAL2",
				UseVirtualURL:  false,
			},
			expectedFileURL: "https://storage.me-central2.rep.googleapis.com/my-travel-maps",
		},
		{
			desc: "when the endPoint is specified",
			in: execResponseStageInfo{
				UseRegionalURL: false,
				Location:       "my-travel-maps/mock_directory/mock_path/",
				EndPoint:       "storage.specialEndPoint.rep.googleapis.com",
				Region:         "ME-cEntRal1",
				UseVirtualURL:  false,
			},
			expectedFileURL: "https://storage.specialEndPoint.rep.googleapis.com/my-travel-maps",
		},
		{
			desc: "when both the endPoint and the useRegionalUrl are specified",
			in: execResponseStageInfo{
				UseRegionalURL: true,
				Location:       "my-travel-maps/mock_directory/mock_path/",
				EndPoint:       "storage.specialEndPoint.rep.googleapis.com",
				Region:         "ME-cEntRal1",
				UseVirtualURL:  false,
			},
			expectedFileURL: "https://storage.specialEndPoint.rep.googleapis.com/my-travel-maps",
		},
		{
			desc: "when both the endPoint is specified and the region is me-central2",
			in: execResponseStageInfo{
				UseRegionalURL: true,
				Location:       "my-travel-maps/mock_directory/mock_path/",
				EndPoint:       "storage.specialEndPoint.rep.googleapis.com",
				Region:         "ME-CENTRAL2",
				UseVirtualURL:  false,
			},
			expectedFileURL: "https://storage.specialEndPoint.rep.googleapis.com/my-travel-maps",
		},
		{
			desc: "when only the useVirtualUrl is enabled",
			in: execResponseStageInfo{
				Location:       "my-travel-maps/mock_directory/mock_path/",
				UseRegionalURL: false,
				EndPoint:       "",
				Region:         "WEST-1",
				UseVirtualURL:  true,
			},
			expectedFileURL: "https://my-travel-maps.storage.googleapis.com",
		},
		{
			desc: "when both the useRegionalURL and useVirtualUrl are enabled",
			in: execResponseStageInfo{
				Location:       "my-travel-maps/mock_directory/mock_path/",
				UseRegionalURL: true,
				EndPoint:       "",
				Region:         "ME-CENTRAL2",
				UseVirtualURL:  true,
			},
			expectedFileURL: "https://my-travel-maps.storage.googleapis.com",
		},
		{
			desc: "when all the options are enabled",
			in: execResponseStageInfo{
				Location:       "my-travel-maps/mock_directory/mock_path/",
				UseRegionalURL: true,
				EndPoint:       "storage.specialEndPoint.rep.googleapis.com",
				Region:         "ME-CENTRAL2",
				UseVirtualURL:  true,
			},
			expectedFileURL: "https://storage.specialEndPoint.rep.googleapis.com",
		},
	}
	for _, test :=
range testcases { t.Run(test.desc, func(t *testing.T) { gcs := new(snowflakeGcsClient) fileURL, err := gcs.generateFileURL(&test.in, "mock_file") assertNilF(t, err, "Should not fail") expectedURL, err := url.Parse(test.expectedFileURL + "/" + url.QueryEscape("mock_directory/mock_path/mock_file")) assertNilF(t, err, "Should not fail") assertEqualF(t, fileURL.String(), expectedURL.String(), "failed. in: %v, expected: %v, got: %v", fmt.Sprintf("%v", test.in), expectedURL.String(), fileURL.String()) }) } } ================================================ FILE: go.mod ================================================ module github.com/snowflakedb/gosnowflake/v2 go 1.24.0 require ( github.com/99designs/keyring v1.2.2 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0 github.com/BurntSushi/toml v1.4.0 github.com/apache/arrow-go/v18 v18.4.0 github.com/aws/aws-sdk-go-v2 v1.38.1 github.com/aws/aws-sdk-go-v2/config v1.27.11 github.com/aws/aws-sdk-go-v2/credentials v1.17.11 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.16.15 github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1 github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 github.com/aws/smithy-go v1.22.5 github.com/gabriel-vasile/mimetype v1.4.7 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 go.opentelemetry.io/otel v1.40.0 go.opentelemetry.io/otel/sdk v1.40.0 golang.org/x/crypto v0.46.0 golang.org/x/net v0.48.0 golang.org/x/oauth2 v0.34.0 golang.org/x/sys v0.40.0 ) require ( github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/apache/thrift v0.22.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect 
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.7 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.5 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/danieljoos/wincred v1.2.2 // indirect github.com/dvsekhvalnov/jose2go v1.7.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect github.com/golang/snappy v1.0.0 // indirect github.com/google/flatbuffers v25.2.10+incompatible // indirect github.com/google/uuid v1.6.0 // indirect github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect github.com/klauspost/asmfmt v1.3.2 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.2.11 // indirect github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 // indirect github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 // indirect github.com/mtibben/percent v0.2.1 // indirect github.com/pierrec/lz4/v4 v4.1.22 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel/metric v1.40.0 // indirect go.opentelemetry.io/otel/trace v1.40.0 // indirect golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect golang.org/x/mod v0.30.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54 // indirect 
golang.org/x/term v0.38.0 // indirect golang.org/x/text v0.32.0 // indirect golang.org/x/tools v0.39.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect google.golang.org/grpc v1.79.3 // indirect google.golang.org/protobuf v1.36.10 // indirect ) ================================================ FILE: go.sum ================================================ github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 h1:/vQbFIOMbk2FiG/kXiLl8BRyzTWDw7gX/Hz7Dd5eDMs= github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4/go.mod h1:hN7oaIRCjzsZ2dE+yG5k+rsdt3qcwykqK6HVGcKwsw4= github.com/99designs/keyring v1.2.2 h1:pZd3neh/EmUzWONb35LxQfvuY7kiSXAq3HQd97+XBn0= github.com/99designs/keyring v1.2.2/go.mod h1:wes/FrByc8j7lFOAGLGSNEg8f/PaI3cgTBqhFkHUrPk= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 h1:rTnT/Jrcm+figWlYz4Ixzt0SJVR2cMC8lvZcimipiEY= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0/go.mod h1:ON4tFdPTwRcgWEaVDrN3584Ef+b7GgSJaXxe5fW9t4M= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0 h1:QkAcEIAKbNL4KoFr4SathZPhDhF4mVwpBMFlYjyAqy8= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0/go.mod h1:bhXu1AjYL+wutSL/kpSq6s7733q2Rb0yuot9Zgfqa/0= github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 h1:+5VZ72z0Qan5Bog5C+ZkgSqUbeVUd9wgtHOrIKuc5b8= github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0 h1:u/LLAOFgsMv7HmNL4Qufg58y+qElGOt5qv0z1mURkRY= github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0/go.mod h1:2e8rMJtl2+2j+HXbTBwnyGpm5Nou7KhvSfxOq8JpTag= github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1 h1:BWe8a+f/t+7KY7zH2mqygeUD0t8hNFXe08p1Pb3/jKE= github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1/go.mod h1:Vt9sXTKwMyGcOxSmLDMnGPgqsUg7m8pe215qMLrDXw4= github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= 
github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/apache/arrow-go/v18 v18.4.0 h1:/RvkGqH517iY8bZKc4FD5/kkdwXJGjxf28JIXbJ/oB0= github.com/apache/arrow-go/v18 v18.4.0/go.mod h1:Aawvwhj8x2jURIzD9Moy72cF0FyJXOpkYpdmGRHcw14= github.com/apache/thrift v0.22.0 h1:r7mTJdj51TMDe6RtcmNdQxgn9XcyfGDOzegMDRg47uc= github.com/apache/thrift v0.22.0/go.mod h1:1e7J/O1Ae6ZQMTYdy9xa3w9k+XHWPfRvdPyJeynQ+/g= github.com/aws/aws-sdk-go-v2 v1.38.1 h1:j7sc33amE74Rz0M/PoCpsZQ6OunLqys/m5antM0J+Z8= github.com/aws/aws-sdk-go-v2 v1.38.1/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= github.com/aws/aws-sdk-go-v2/config v1.27.11 h1:f47rANd2LQEYHda2ddSCKYId18/8BhSRM4BULGmfgNA= github.com/aws/aws-sdk-go-v2/config v1.27.11/go.mod h1:SMsV78RIOYdve1vf36z8LmnszlRWkwMQtomCAI0/mIE= github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 h1:FVJ0r5XTHSmIHJV6KuDmdYhEpvlHpiSd38RQWhut5J4= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1/go.mod h1:zusuAeqezXzAB24LGuzuekqMAEgWkVYukBec3kr3jUg= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.16.15 h1:7Zwtt/lP3KNRkeZre7soMELMGNoBrutx8nobg1jKWmo= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.16.15/go.mod h1:436h2adoHb57yd+8W+gYPrrA9U/R/SuAuOO42Ushzhw= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= 
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.5 h1:81KE7vaZzrl7yHBYHVEzYB8sypz11NMOZ40YlWvPxsU= github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.5/go.mod h1:LIt2rg7Mcgn09Ygbdh/RdIm0rQ+3BNkbP1gyVMFtRK0= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.7 h1:ZMeFZ5yk+Ek+jNr1+uwCd2tG89t6oTS5yVWpa6yy2es= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.7/go.mod h1:mxV05U+4JiHqIpGqqYXOHLPKUC6bDXC44bsUhNjOEwY= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 h1:ogRAwT1/gxJBcSWDMZlgyFUM962F51A5CRhDLbxLdmo= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7/go.mod h1:YCsIZhXfRPLFFCl5xxY+1T9RKzOKjCut+28JSX2DnAk= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.5 h1:f9RyWNtS8oH7cZlbn+/JNPpjUk5+5fLd5lM9M0i49Ys= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.5/go.mod h1:h5CoMZV2VF297/VLhRhO1WF+XYWOzXo+4HsObA4HjBQ= github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1 h1:6cnno47Me9bRykw9AEv9zkXE+5or7jz8TsskTTccbgc= github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1/go.mod h1:qmdkIIAC+GCLASF7R2whgNrJADz0QZPX+Seiw/i4S3o= github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 h1:vN8hEbpRnL7+Hopy9dzmRle1xmDc7o8tmY0klsr175w= 
github.com/aws/aws-sdk-go-v2/service/sso v1.20.5/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 h1:Jux+gDDyi1Lruk+KHF91tK2KCuY61kzoCpvtvJJBtOE= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4/go.mod h1:mUYPBhaF2lGiukDEjJX2BLRRKTmoUSitGDUgM4tRxak= github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 h1:cwIxeBttqPN3qkaAjcEcsh8NYr8n2HZPkcKgPAi1phU= github.com/aws/aws-sdk-go-v2/service/sts v1.28.6/go.mod h1:FZf1/nKNEkHdGGJP/cI2MoIMquumuRK6ol3QQJNDxmw= github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/danieljoos/wincred v1.2.2 h1:774zMFJrqaeYCK2W57BgAem/MLi6mtSE47MB6BOJ0i0= github.com/danieljoos/wincred v1.2.2/go.mod h1:w7w4Utbrz8lqeMbDAK0lkNJUv5sAOkFi7nd/ogr0Uh8= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dnaeon/go-vcr v1.1.0 h1:ReYa/UBrRyQdant9B4fNHGoCNKw6qh6P0fsdGmZpR7c= github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= github.com/dvsekhvalnov/jose2go v1.7.0 h1:bnQc8+GMnidJZA8zc6lLEAb4xNrIqHwO+9TzqvtQZPo= github.com/dvsekhvalnov/jose2go v1.7.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= github.com/gabriel-vasile/mimetype v1.4.7 h1:SKFKl7kD0RiPdbht0s7hFtjl489WcQ1VyPW8ZzUMYCA= github.com/gabriel-vasile/mimetype v1.4.7/go.mod h1:GDlAgAyIRT27BhFl53XNAFtfjzOkLaF35JdEG0P7LtU= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr 
v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 h1:ZpnhV/YsD2/4cESfV5+Hoeu/iUR3ruzNvZ+yQfO03a0= github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= github.com/golang-jwt/jwt v3.2.1+incompatible h1:73Z+4BJcrTC+KczS6WvTPvRGOp1WmfEP4Q1lOd9Z/+c= github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c h1:6rhixN/i8ZofjG1Y75iExal34USq5p+wiN1tpie8IrU= github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c/go.mod h1:NMPJylDgVpX0MLRlPy15sqSwOFv/U1GZ2m21JhFfek0= github.com/klauspost/asmfmt v1.3.2 
h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/cpuid/v2 v2.2.11 h1:0OwqZRYI2rFrjS4kvkDnqJkKHdHaRnCm68/DY4OxRzU= github.com/klauspost/cpuid/v2 v2.2.11/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= github.com/mtibben/percent v0.2.1 h1:5gssi8Nqo8QU/r2pynCm+hBQHpkB/uNK7BJCFogWdzs= github.com/mtibben/percent v0.2.1/go.mod h1:KG9uO+SZkUp+VkRHsCdYQV3XSZrrSpR3O9ibNBTZrns= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU= github.com/pierrec/lz4/v4 
v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms= go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g= go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g= 
go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc= go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8= go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE= go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw= go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg= go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw= go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod 
h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54 h1:E2/AqCUMZGgd73TQkxUMcMla25GB9i/5HOdLr+uH7Vo= golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54/go.mod h1:hKdjCMrbv9skySur+Nek8Hd0uJ0GuxJIoIX2payrIdQ= golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod 
h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= ================================================ FILE: gosnowflake.mak ================================================ ## Setup SHELL := /bin/bash SRC = $(shell find . -type f -name '*.go' -not -path "./vendor/*") setup: @which golint &> /dev/null || go install golang.org/x/lint/golint@latest @which make2help &> /dev/null || go install github.com/Songmu/make2help/cmd/make2help@latest ## Install dependencies deps: setup go mod tidy ## Show help help: @make2help $(MAKEFILE_LIST) # Format source codes (internally used) cfmt: setup @gofmt -l -w $(SRC) # Lint (internally used) clint: deps @echo "Running go vet and lint" @for pkg in $$(go list ./... 
| grep -v /vendor/); do \ echo "Verifying $$pkg"; \ go vet $$pkg; \ golint -set_exit_status $$pkg || exit $$?; \ done # Install (internally used) cinstall: @export GOBIN=$$GOPATH/bin; \ go install -tags=sfdebug $(CMD_TARGET).go # Run (internally used) crun: install $(CMD_TARGET) .PHONY: setup help cfmt clint cinstall crun ================================================ FILE: heartbeat.go ================================================ package gosnowflake import ( "context" "encoding/json" "fmt" "io" "net/http" "net/url" "time" ) const ( minHeartBeatInterval = 900 * time.Second maxHeartBeatInterval = 3600 * time.Second defaultHeartBeatInterval = 3600 * time.Second ) func newDefaultHeartBeat(restful *snowflakeRestful) *heartbeat { return newHeartBeat(restful, defaultHeartBeatInterval) } func newHeartBeat(restful *snowflakeRestful, heartbeatInterval time.Duration) *heartbeat { logger.Debugf("Using heartbeat with custom interval: %v", heartbeatInterval) if heartbeatInterval < minHeartBeatInterval { logger.Warnf("Heartbeat interval %v is less than minimum %v, using minimum", heartbeatInterval, minHeartBeatInterval) heartbeatInterval = minHeartBeatInterval } else if heartbeatInterval > maxHeartBeatInterval { logger.Warnf("Heartbeat interval %v is greater than maximum %v, using maximum", heartbeatInterval, maxHeartBeatInterval) heartbeatInterval = maxHeartBeatInterval } return &heartbeat{ restful: restful, heartbeatInterval: heartbeatInterval, } } type heartbeat struct { restful *snowflakeRestful shutdownChan chan bool heartbeatInterval time.Duration } func (hc *heartbeat) run() { _, _, sessionID := safeGetTokens(hc.restful) ctx := context.WithValue(context.Background(), SFSessionIDKey, sessionID) hbTicker := time.NewTicker(hc.heartbeatInterval) defer hbTicker.Stop() for { select { case <-hbTicker.C: err := hc.heartbeatMain() if err != nil { logger.WithContext(ctx).Errorf("failed to heartbeat: %v", err) } case <-hc.shutdownChan: logger.WithContext(ctx).Info("stopping 
heartbeat") return } } } func (hc *heartbeat) start() { _, _, sessionID := safeGetTokens(hc.restful) ctx := context.WithValue(context.Background(), SFSessionIDKey, sessionID) hc.shutdownChan = make(chan bool) go hc.run() logger.WithContext(ctx).Info("heartbeat started") } func (hc *heartbeat) stop() { _, _, sessionID := safeGetTokens(hc.restful) ctx := context.WithValue(context.Background(), SFSessionIDKey, sessionID) hc.shutdownChan <- true close(hc.shutdownChan) logger.WithContext(ctx).Info("heartbeat stopped") } func (hc *heartbeat) heartbeatMain() error { params := &url.Values{} params.Set(requestIDKey, NewUUID().String()) params.Set(requestGUIDKey, NewUUID().String()) headers := getHeaders() token, _, sessionID := safeGetTokens(hc.restful) ctx := context.WithValue(context.Background(), SFSessionIDKey, sessionID) logger.WithContext(ctx).Info("Heartbeating!") headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) fullURL := hc.restful.getFullURL(heartBeatPath, params) timeout := hc.restful.RequestTimeout resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, defaultTimeProvider, nil) if err != nil { return err } defer func() { if err = resp.Body.Close(); err != nil { logger.WithContext(ctx).Warnf("failed to close response body for %v. err: %v", fullURL, err) } }() if resp.StatusCode == http.StatusOK { logger.WithContext(ctx).Debugf("heartbeatMain: resp: %v", resp) var respd execResponse err = json.NewDecoder(resp.Body).Decode(&respd) if err != nil { logger.WithContext(ctx).Errorf("failed to decode heartbeat response JSON. 
err: %v", err) return err } if respd.Code == sessionExpiredCode { logger.WithContext(ctx).Info("Snowflake returned 'session expired', trying to renew expired token.") err = hc.restful.renewExpiredSessionToken(context.Background(), timeout, token) if err != nil { return err } } return nil } b, err := io.ReadAll(resp.Body) if err != nil { logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err) return err } logger.WithContext(ctx).Debugf("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b) logger.WithContext(ctx).Debugf("Header: %v", resp.Header) return &SnowflakeError{ Number: ErrFailedToHeartbeat, SQLState: SQLStateConnectionFailure, Message: "Failed to heartbeat.", } } ================================================ FILE: heartbeat_test.go ================================================ package gosnowflake import ( "context" "testing" "time" ) func TestUnitPostHeartbeat(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { // send heartbeat call and renew expired session sr := &snowflakeRestful{ FuncPost: postTestRenew, FuncRenewSession: renewSessionTest, TokenAccessor: getSimpleTokenAccessor(), RequestTimeout: 0, } heartbeat := newDefaultHeartBeat(sr) err := heartbeat.heartbeatMain() assertNilF(t, err, "failed to heartbeat and renew session") heartbeat.restful.FuncPost = postTestError err = heartbeat.heartbeatMain() assertNotNilF(t, err, "should have failed to start heartbeat") assertEqualE(t, err.Error(), "failed to run post method") heartbeat.restful.FuncPost = postTestSuccessButInvalidJSON err = heartbeat.heartbeatMain() assertNotNilF(t, err, "should have failed to start heartbeat") assertHasPrefixE(t, err.Error(), "invalid character") heartbeat.restful.FuncPost = postTestAppForbiddenError err = heartbeat.heartbeatMain() assertNotNilF(t, err, "should have failed to start heartbeat") driverErr, ok := err.(*SnowflakeError) assertTrueF(t, ok, "connection should be snowflakeConn") assertEqualE(t, driverErr.Number, 
ErrFailedToHeartbeat) }) } func TestHeartbeatStartAndStop(t *testing.T) { customDsn := dsn + "&client_session_keep_alive=true" config, err := ParseDSN(customDsn) assertNilF(t, err, "failed to parse dsn") driver := SnowflakeDriver{} db, err := driver.OpenWithConfig(context.Background(), *config) assertNilF(t, err, "failed to open with config") conn, ok := db.(*snowflakeConn) assertTrueF(t, ok, "connection should be snowflakeConn") assertNotNilF(t, conn.rest, "heartbeat should not be nil") assertNotNilF(t, conn.rest.HeartBeat, "heartbeat should not be nil") err = db.Close() assertNilF(t, err, "should not cause error in Close") assertNilF(t, conn.rest.HeartBeat, "heartbeat should be nil") } func TestHeartbeatIntervalLowerThanMin(t *testing.T) { sr := &snowflakeRestful{ FuncPost: postTestRenew, FuncRenewSession: renewSessionTest, TokenAccessor: getSimpleTokenAccessor(), RequestTimeout: 0, } heartbeat := newHeartBeat(sr, minHeartBeatInterval-1*time.Second) assertEqualF(t, heartbeat.heartbeatInterval, minHeartBeatInterval, "heartbeat interval should be set to min") } func TestHeartbeatIntervalHigherThanMax(t *testing.T) { sr := &snowflakeRestful{ FuncPost: postTestRenew, FuncRenewSession: renewSessionTest, TokenAccessor: getSimpleTokenAccessor(), RequestTimeout: 0, } heartbeat := newHeartBeat(sr, maxHeartBeatInterval+1*time.Second) assertEqualF(t, heartbeat.heartbeatInterval, maxHeartBeatInterval, "heartbeat interval should be set to max") } ================================================ FILE: htap.go ================================================ package gosnowflake import ( "sort" "strconv" "sync" ) const ( queryContextCacheSizeParamName = "QUERY_CONTEXT_CACHE_SIZE" defaultQueryContextCacheSize = 5 ) type queryContext struct { Entries []queryContextEntry `json:"entries,omitempty"` } type queryContextEntry struct { ID int `json:"id"` Timestamp int64 `json:"timestamp"` Priority int `json:"priority"` Context string `json:"context,omitempty"` } type queryContextCache 
struct { mutex sync.Mutex entries []queryContextEntry } func (qcc *queryContextCache) add(sc *snowflakeConn, qces ...queryContextEntry) { qcc.mutex.Lock() defer qcc.mutex.Unlock() if len(qces) == 0 { qcc.prune(0) } else { for _, newQce := range qces { logger.Debugf("adding query context: %v", newQce) newQceProcessed := false for existingQceIdx, existingQce := range qcc.entries { if newQce.ID == existingQce.ID { newQceProcessed = true if newQce.Timestamp > existingQce.Timestamp { qcc.entries[existingQceIdx] = newQce } else if newQce.Timestamp == existingQce.Timestamp { if newQce.Priority != existingQce.Priority { qcc.entries[existingQceIdx] = newQce } } } } if !newQceProcessed { for existingQceIdx, existingQce := range qcc.entries { if newQce.Priority == existingQce.Priority { qcc.entries[existingQceIdx] = newQce newQceProcessed = true } } } if !newQceProcessed { qcc.entries = append(qcc.entries, newQce) } } sort.Slice(qcc.entries, func(idx1, idx2 int) bool { return qcc.entries[idx1].Priority < qcc.entries[idx2].Priority }) qcc.prune(qcc.getQueryContextCacheSize(sc)) } } func (qcc *queryContextCache) prune(size int) { if len(qcc.entries) > size { qcc.entries = qcc.entries[0:size] } } func (qcc *queryContextCache) getQueryContextCacheSize(sc *snowflakeConn) int { sizeStr, ok := sc.syncParams.get(queryContextCacheSizeParamName) if ok { size, err := strconv.Atoi(*sizeStr) if err != nil { logger.Warnf("cannot parse %v as int as query context cache size: %v", sizeStr, err) } else { return size } } return defaultQueryContextCacheSize } ================================================ FILE: htap_test.go ================================================ package gosnowflake import ( "context" "database/sql/driver" "encoding/json" "fmt" "net/url" "reflect" "strconv" "strings" "testing" "time" ) func TestSortingByPriority(t *testing.T) { qcc := queryContextCache{} sc := htapTestSnowflakeConn() qceA := queryContextEntry{ID: 12, Timestamp: 123, Priority: 7, Context: "a"} qceB := 
queryContextEntry{ID: 13, Timestamp: 124, Priority: 9, Context: "b"} qceC := queryContextEntry{ID: 14, Timestamp: 125, Priority: 6, Context: "c"} qceD := queryContextEntry{ID: 15, Timestamp: 126, Priority: 8, Context: "d"} t.Run("Add to empty cache", func(t *testing.T) { qcc.add(sc, qceA) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA}) { t.Fatalf("no entries added to cache. %v", qcc.entries) } }) t.Run("Add another entry with different id, timestamp and priority - greater priority", func(t *testing.T) { qcc.add(sc, qceB) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { t.Fatalf("unexpected qcc entries. %v", qcc.entries) } }) t.Run("Add another entry with different id, timestamp and priority - lesser priority", func(t *testing.T) { qcc.add(sc, qceC) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceC, qceA, qceB}) { t.Fatalf("unexpected qcc entries. %v", qcc.entries) } }) t.Run("Add another entry with different id, timestamp and priority - priority in the middle", func(t *testing.T) { qcc.add(sc, qceD) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceC, qceA, qceD, qceB}) { t.Fatalf("unexpected qcc entries. %v", qcc.entries) } }) } func TestAddingQcesWithTheSameIdAndLaterTimestamp(t *testing.T) { qcc := queryContextCache{} sc := htapTestSnowflakeConn() qceA := queryContextEntry{ID: 12, Timestamp: 123, Priority: 7, Context: "a"} qceB := queryContextEntry{ID: 13, Timestamp: 124, Priority: 9, Context: "b"} qceC := queryContextEntry{ID: 12, Timestamp: 125, Priority: 6, Context: "c"} qceD := queryContextEntry{ID: 12, Timestamp: 126, Priority: 6, Context: "d"} t.Run("Add to empty cache", func(t *testing.T) { qcc.add(sc, qceA) qcc.add(sc, qceB) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { t.Fatalf("no entries added to cache. 
%v", qcc.entries) } }) t.Run("Add another entry with different priority", func(t *testing.T) { qcc.add(sc, qceC) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceC, qceB}) { t.Fatalf("unexpected qcc entries. %v", qcc.entries) } }) t.Run("Add another entry with same priority", func(t *testing.T) { qcc.add(sc, qceD) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceD, qceB}) { t.Fatalf("unexpected qcc entries. %v", qcc.entries) } }) } func TestAddingQcesWithTheSameIdAndSameTimestamp(t *testing.T) { qcc := queryContextCache{} sc := htapTestSnowflakeConn() qceA := queryContextEntry{ID: 12, Timestamp: 123, Priority: 7, Context: "a"} qceB := queryContextEntry{ID: 13, Timestamp: 124, Priority: 9, Context: "b"} qceC := queryContextEntry{ID: 12, Timestamp: 123, Priority: 6, Context: "c"} qceD := queryContextEntry{ID: 12, Timestamp: 123, Priority: 6, Context: "d"} t.Run("Add to empty cache", func(t *testing.T) { qcc.add(sc, qceA) qcc.add(sc, qceB) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { t.Fatalf("no entries added to cache. %v", qcc.entries) } }) t.Run("Add another entry with different priority", func(t *testing.T) { qcc.add(sc, qceC) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceC, qceB}) { t.Fatalf("unexpected qcc entries. %v", qcc.entries) } }) t.Run("Add another entry with same priority", func(t *testing.T) { qcc.add(sc, qceD) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceC, qceB}) { t.Fatalf("unexpected qcc entries. 
%v", qcc.entries) } }) } func TestAddingQcesWithTheSameIdAndEarlierTimestamp(t *testing.T) { qcc := queryContextCache{} sc := htapTestSnowflakeConn() qceA := queryContextEntry{ID: 12, Timestamp: 123, Priority: 7, Context: "a"} qceB := queryContextEntry{ID: 13, Timestamp: 124, Priority: 9, Context: "b"} qceC := queryContextEntry{ID: 12, Timestamp: 122, Priority: 6, Context: "c"} qceD := queryContextEntry{ID: 12, Timestamp: 122, Priority: 7, Context: "d"} t.Run("Add to empty cache", func(t *testing.T) { qcc.add(sc, qceA) qcc.add(sc, qceB) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { t.Fatalf("unexpected qcc entries. %v", qcc.entries) } }) t.Run("Add another entry with different priority", func(t *testing.T) { qcc.add(sc, qceC) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { t.Fatalf("unexpected qcc entries. %v", qcc.entries) } }) t.Run("Add another entry with same priority", func(t *testing.T) { qcc.add(sc, qceD) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { t.Fatalf("unexpected qcc entries. %v", qcc.entries) } }) } func TestAddingQcesWithDifferentId(t *testing.T) { qcc := queryContextCache{} sc := htapTestSnowflakeConn() qceA := queryContextEntry{ID: 12, Timestamp: 123, Priority: 7, Context: "a"} qceB := queryContextEntry{ID: 13, Timestamp: 124, Priority: 9, Context: "b"} qceC := queryContextEntry{ID: 14, Timestamp: 122, Priority: 7, Context: "c"} qceD := queryContextEntry{ID: 15, Timestamp: 122, Priority: 6, Context: "d"} t.Run("Add to empty cache", func(t *testing.T) { qcc.add(sc, qceA) qcc.add(sc, qceB) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { t.Fatalf("unexpected qcc entries. %v", qcc.entries) } }) t.Run("Add another entry with same priority", func(t *testing.T) { qcc.add(sc, qceC) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceC, qceB}) { t.Fatalf("unexpected qcc entries. 
%v", qcc.entries) } }) t.Run("Add another entry with different priority", func(t *testing.T) { qcc.add(sc, qceD) if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceD, qceC, qceB}) { t.Fatalf("unexpected qcc entries. %v", qcc.entries) } }) } func TestAddingQueryContextCacheEntry(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { t.Run("First query (may be on empty cache)", func(t *testing.T) { entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries)) copy(entriesBefore, sct.sc.queryContextCache.entries) sct.mustQuery("SELECT 1", nil) entriesAfter := sct.sc.queryContextCache.entries if !containsNewEntries(entriesAfter, entriesBefore) { t.Error("no new entries added to the query context cache") } }) t.Run("Second query (cache should not be empty)", func(t *testing.T) { entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries)) copy(entriesBefore, sct.sc.queryContextCache.entries) if len(entriesBefore) == 0 { t.Fatalf("cache should not be empty after first query") } sct.mustQuery("SELECT 2", nil) entriesAfter := sct.sc.queryContextCache.entries if !containsNewEntries(entriesAfter, entriesBefore) { t.Error("no new entries added to the query context cache") } }) }) } func containsNewEntries(entriesAfter []queryContextEntry, entriesBefore []queryContextEntry) bool { if len(entriesAfter) > len(entriesBefore) { return true } for _, entryAfter := range entriesAfter { for _, entryBefore := range entriesBefore { if !reflect.DeepEqual(entryBefore, entryAfter) { return true } } } return false } func TestPruneBySessionValue(t *testing.T) { qce1 := queryContextEntry{1, 1, 1, ""} qce2 := queryContextEntry{2, 2, 2, ""} qce3 := queryContextEntry{3, 3, 3, ""} testcases := []struct { size string expected []queryContextEntry }{ { size: "1", expected: []queryContextEntry{qce1}, }, { size: "2", expected: []queryContextEntry{qce1, qce2}, }, { size: "3", expected: []queryContextEntry{qce1, qce2, qce3}, }, { size: "4", expected: 
[]queryContextEntry{qce1, qce2, qce3}, }, } for _, tc := range testcases { t.Run(tc.size, func(t *testing.T) { params := map[string]*string{ queryContextCacheSizeParamName: &tc.size, } sc := &snowflakeConn{ cfg: &Config{}, syncParams: syncParams{params: params}, } qcc := queryContextCache{} qcc.add(sc, qce1) qcc.add(sc, qce2) qcc.add(sc, qce3) if !reflect.DeepEqual(qcc.entries, tc.expected) { t.Errorf("unexpected cache entries. expected: %v, got: %v", tc.expected, qcc.entries) } }) } } func TestPruneByDefaultValue(t *testing.T) { qce1 := queryContextEntry{1, 1, 1, ""} qce2 := queryContextEntry{2, 2, 2, ""} qce3 := queryContextEntry{3, 3, 3, ""} qce4 := queryContextEntry{4, 4, 4, ""} qce5 := queryContextEntry{5, 5, 5, ""} qce6 := queryContextEntry{6, 6, 6, ""} sc := &snowflakeConn{ cfg: &Config{}, } qcc := queryContextCache{} qcc.add(sc, qce1) qcc.add(sc, qce2) qcc.add(sc, qce3) qcc.add(sc, qce4) qcc.add(sc, qce5) if len(qcc.entries) != 5 { t.Fatalf("Expected 5 elements, got: %v", len(qcc.entries)) } qcc.add(sc, qce6) if len(qcc.entries) != 5 { t.Fatalf("Expected 5 elements, got: %v", len(qcc.entries)) } } func TestNoQcesClearsCache(t *testing.T) { qce1 := queryContextEntry{1, 1, 1, ""} sc := &snowflakeConn{ cfg: &Config{}, } qcc := queryContextCache{} qcc.add(sc, qce1) if len(qcc.entries) != 1 { t.Fatalf("improperly inited cache") } qcc.add(sc) if len(qcc.entries) != 0 { t.Errorf("after adding empty context list cache should be cleared") } } func TestQCCUpdatedAfterQueryResponse(t *testing.T) { // Create initial QCC entry initialEntry := queryContextEntry{ID: 1, Timestamp: 100, Priority: 1, Context: "initial"} // Create query context that would be returned in the response newEntry := queryContextEntry{ID: 2, Timestamp: 200, Priority: 2, Context: "new"} queryContextJSON := fmt.Sprintf(`{"entries":[{"id":%d,"timestamp":%d,"priority":%d,"context":"%s"}]}`, newEntry.ID, newEntry.Timestamp, newEntry.Priority, newEntry.Context) testCases := []bool{true, false} for _, 
success := range testCases { t.Run(fmt.Sprintf("success=%v", success), func(t *testing.T) { // Mock response with query context postQueryMock := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, _ UUID, _ *Config) (*execResponse, error) { code := "0" message := "" if !success { code = "1234" message = "Query failed" } return &execResponse{ Data: execResponseData{ QueryContext: json.RawMessage(queryContextJSON), }, Message: message, Code: code, Success: success, }, nil } sr := &snowflakeRestful{ FuncPostQuery: postQueryMock, } sc := &snowflakeConn{ cfg: &Config{}, rest: sr, } sc.queryContextCache.add(sc, initialEntry) // Execute query _, err := sc.ExecContext(context.Background(), "SELECT 1", nil) if !success { assertNotNilF(t, err, "expected error for failed query") } else { assertNilF(t, err, "unexpected error for successful query") } // Verify QCC WAS updated in both cases - should now contain both entries assertEqualE(t, len(sc.queryContextCache.entries), 2, "expected 2 entries in QCC") // Verify new entry was added (entries are sorted by priority) found := false for _, entry := range sc.queryContextCache.entries { if entry.ID == newEntry.ID { found = true break } } assertTrueE(t, found, "new QCC entry not found after query") }) } } func htapTestSnowflakeConn() *snowflakeConn { return &snowflakeConn{ cfg: &Config{}, } } func TestQueryContextCacheDisabled(t *testing.T) { customDsn := dsn + "&disableQueryContextCache=true" runSnowflakeConnTestWithConfig(t, &testConfig{dsn: customDsn}, func(sct *SCTest) { sct.mustExec("SELECT 1", nil) if len(sct.sc.queryContextCache.entries) > 0 { t.Error("should not contain any entries") } }) } func TestHybridTablesE2E(t *testing.T) { skipOnJenkins(t, "HTAP is not enabled on environment") if runningOnGithubAction() && !runningOnAWS() { t.Skip("HTAP is enabled only on AWS") } runID := time.Now().UnixMilli() testDb1 := fmt.Sprintf("hybrid_db_test_%v", runID) testDb2 := 
fmt.Sprintf("hybrid_db_test_%v_2", runID)
	runSnowflakeConnTest(t, func(sct *SCTest) {
		dbQuery := sct.mustQuery("SELECT CURRENT_DATABASE()", nil)
		defer func() {
			assertNilF(t, dbQuery.Close())
		}()
		currentDb := make([]driver.Value, 1)
		assertNilF(t, dbQuery.Next(currentDb))
		// Restore the original database and drop the temporary ones on exit.
		defer func() {
			sct.mustExec(fmt.Sprintf("USE DATABASE %v", currentDb[0]), nil)
			sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb1), nil)
			sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb2), nil)
		}()
		t.Run("Run tests on first database", func(t *testing.T) {
			sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb1), nil)
			sct.mustExec("CREATE HYBRID TABLE test_hybrid_table (id INT PRIMARY KEY, text VARCHAR)", nil)
			sct.mustExec("INSERT INTO test_hybrid_table VALUES (1, 'a')", nil)
			rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
			defer func() {
				assertNilF(t, rows.Close())
			}()
			row := make([]driver.Value, 2)
			assertNilF(t, rows.Next(row))
			if row[0] != "1" || row[1] != "a" {
				t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
			}
			sct.mustExec("INSERT INTO test_hybrid_table VALUES (2, 'b')", nil)
			rows2 := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
			defer func() {
				assertNilF(t, rows2.Close())
			}()
			assertNilF(t, rows2.Next(row))
			if row[0] != "1" || row[1] != "a" {
				t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
			}
			assertNilF(t, rows2.Next(row))
			if row[0] != "2" || row[1] != "b" {
				t.Errorf("expected 2, got %v and expected b, got %v", row[0], row[1])
			}
			if len(sct.sc.queryContextCache.entries) != 2 {
				t.Errorf("expected two entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
			}
		})
		t.Run("Run tests on second database", func(t *testing.T) {
			sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb2), nil)
			sct.mustExec("CREATE HYBRID TABLE test_hybrid_table_2 (id INT PRIMARY KEY, text VARCHAR)", nil)
			sct.mustExec("INSERT INTO test_hybrid_table_2 VALUES (3, 'c')", nil)
			rows := sct.mustQuery("SELECT * FROM test_hybrid_table_2", nil)
			defer func() {
				assertNilF(t, rows.Close())
			}()
			row := make([]driver.Value, 2)
			assertNilF(t, rows.Next(row))
			if row[0] != "3" || row[1] != "c" {
				t.Errorf("expected 3, got %v and expected c, got %v", row[0], row[1])
			}
			if len(sct.sc.queryContextCache.entries) != 3 {
				t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
			}
		})
		t.Run("Run tests on first database again", func(t *testing.T) {
			sct.mustExec(fmt.Sprintf("USE DATABASE %v", testDb1), nil)
			sct.mustExec("INSERT INTO test_hybrid_table VALUES (4, 'd')", nil)
			rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
			defer func() {
				assertNilF(t, rows.Close())
			}()
			if len(sct.sc.queryContextCache.entries) != 3 {
				t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
			}
		})
	})
}

// TestHTAPOptimizations verifies that session state (schema, database,
// warehouse, role, session parameters) tracked in the connection config stays
// in sync with server-side changes, with and without HTAP optimizations.
func TestHTAPOptimizations(t *testing.T) {
	if runningOnGithubAction() {
		t.Skip("insufficient permissions")
	}
	for _, useHtapOptimizations := range []bool{true, false} {
		runSnowflakeConnTest(t, func(sct *SCTest) {
			t.Run("useHtapOptimizations="+strconv.FormatBool(useHtapOptimizations), func(t *testing.T) {
				if useHtapOptimizations {
					sct.mustExec("ALTER SESSION SET ENABLE_SNOW_654741_FOR_TESTING = true", nil)
				}
				runID := time.Now().UnixMilli()
				t.Run("Schema", func(t *testing.T) {
					newSchema := fmt.Sprintf("test_schema_%v", runID)
					if strings.EqualFold(sct.sc.cfg.Schema, newSchema) {
						t.Errorf("schema should not be switched")
					}
					sct.mustExec(fmt.Sprintf("CREATE SCHEMA %v", newSchema), nil)
					defer sct.mustExec(fmt.Sprintf("DROP SCHEMA %v", newSchema), nil)
					if !strings.EqualFold(sct.sc.cfg.Schema, newSchema) {
						t.Errorf("schema should be switched, expected %v, got %v", newSchema, sct.sc.cfg.Schema)
					}
					query := sct.mustQuery("SELECT 1", nil)
					query.Close()
					if !strings.EqualFold(sct.sc.cfg.Schema, newSchema) {
						t.Errorf("schema should be switched, expected %v, got %v", newSchema, sct.sc.cfg.Schema)
					}
				})
t.Run("Database", func(t *testing.T) { newDatabase := fmt.Sprintf("test_database_%v", runID) if strings.EqualFold(sct.sc.cfg.Database, newDatabase) { t.Errorf("database should not be switched") } sct.mustExec(fmt.Sprintf("CREATE DATABASE %v", newDatabase), nil) defer sct.mustExec(fmt.Sprintf("DROP DATABASE %v", newDatabase), nil) if !strings.EqualFold(sct.sc.cfg.Database, newDatabase) { t.Errorf("database should be switched, expected %v, got %v", newDatabase, sct.sc.cfg.Database) } query := sct.mustQuery("SELECT 1", nil) query.Close() if !strings.EqualFold(sct.sc.cfg.Database, newDatabase) { t.Errorf("database should be switched, expected %v, got %v", newDatabase, sct.sc.cfg.Database) } }) t.Run("Warehouse", func(t *testing.T) { newWarehouse := fmt.Sprintf("test_warehouse_%v", runID) if strings.EqualFold(sct.sc.cfg.Warehouse, newWarehouse) { t.Errorf("warehouse should not be switched") } sct.mustExec(fmt.Sprintf("CREATE WAREHOUSE %v", newWarehouse), nil) defer sct.mustExec(fmt.Sprintf("DROP WAREHOUSE %v", newWarehouse), nil) if !strings.EqualFold(sct.sc.cfg.Warehouse, newWarehouse) { t.Errorf("warehouse should be switched, expected %v, got %v", newWarehouse, sct.sc.cfg.Warehouse) } query := sct.mustQuery("SELECT 1", nil) query.Close() if !strings.EqualFold(sct.sc.cfg.Warehouse, newWarehouse) { t.Errorf("warehouse should be switched, expected %v, got %v", newWarehouse, sct.sc.cfg.Warehouse) } }) t.Run("Role", func(t *testing.T) { if strings.EqualFold(sct.sc.cfg.Role, "PUBLIC") { t.Errorf("role should not be public for this test") } sct.mustExec("USE ROLE public", nil) if !strings.EqualFold(sct.sc.cfg.Role, "PUBLIC") { t.Errorf("role should be switched, expected public, got %v", sct.sc.cfg.Role) } query := sct.mustQuery("SELECT 1", nil) query.Close() if !strings.EqualFold(sct.sc.cfg.Role, "PUBLIC") { t.Errorf("role should be switched, expected public, got %v", sct.sc.cfg.Role) } }) t.Run("Session param - DATE_OUTPUT_FORMAT", func(t *testing.T) { dateFormat, _ := 
sct.sc.syncParams.get("date_output_format") if !strings.EqualFold(*dateFormat, "YYYY-MM-DD") { t.Errorf("should use default date_output_format, but got: %v", *dateFormat) } sct.mustExec("ALTER SESSION SET DATE_OUTPUT_FORMAT = 'DD-MM-YYYY'", nil) defer sct.mustExec("ALTER SESSION SET DATE_OUTPUT_FORMAT = 'YYYY-MM-DD'", nil) dateFormat, _ = sct.sc.syncParams.get("date_output_format") if !strings.EqualFold(*dateFormat, "DD-MM-YYYY") { t.Errorf("date output format should be switched, expected DD-MM-YYYY, got %v", *dateFormat) } query := sct.mustQuery("SELECT 1", nil) query.Close() dateFormat, _ = sct.sc.syncParams.get("date_output_format") if !strings.EqualFold(*dateFormat, "DD-MM-YYYY") { t.Errorf("date output format should be switched, expected DD-MM-YYYY, got %v", *dateFormat) } }) }) }) } } func TestConnIsCleanAfterClose(t *testing.T) { // We create a new db here to not use the default pool as we can leave it in dirty state. t.Skip("Fails, because connection is returned to a pool dirty") ctx := context.Background() runID := time.Now().UnixMilli() db := openDB(t) defer db.Close() db.SetMaxOpenConns(1) conn, err := db.Conn(ctx) if err != nil { t.Fatal(err) } defer conn.Close() dbt := DBTest{t, conn} dbt.mustExec(forceJSON) var dbName string rows1 := dbt.mustQuery("SELECT CURRENT_DATABASE()") rows1.Next() assertNilF(t, rows1.Scan(&dbName)) newDbName := fmt.Sprintf("test_database_%v", runID) dbt.mustExec("CREATE DATABASE " + newDbName) assertNilF(t, rows1.Close()) assertNilF(t, conn.Close()) conn2, err := db.Conn(ctx) if err != nil { t.Fatal(err) } dbt2 := DBTest{t, conn2} var dbName2 string rows2 := dbt2.mustQuery("SELECT CURRENT_DATABASE()") defer func() { assertNilF(t, rows2.Close()) }() rows2.Next() assertNilF(t, rows2.Scan(&dbName2)) if !strings.EqualFold(dbName, dbName2) { t.Errorf("fresh connection from pool should have original database") } } ================================================ FILE: internal/arrow/arrow.go 
================================================ package arrow import ( "context" "time" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/memory" "github.com/snowflakedb/gosnowflake/v2/internal/query" ) // contextKey is a private type for context keys used by this package. type contextKey string // Context keys for arrow batches configuration. const ( ctxArrowBatches contextKey = "ARROW_BATCHES" ctxArrowBatchesTimestampOpt contextKey = "ARROW_BATCHES_TIMESTAMP_OPTION" ctxArrowBatchesUtf8Validate contextKey = "ENABLE_ARROW_BATCHES_UTF8_VALIDATION" ctxHigherPrecision contextKey = "ENABLE_HIGHER_PRECISION" ) // --- Timestamp option --- // TimestampOption controls how Snowflake timestamps are converted in arrow batches. type TimestampOption int const ( // UseNanosecondTimestamp converts Snowflake timestamps to arrow timestamps with nanosecond precision. UseNanosecondTimestamp TimestampOption = iota // UseMicrosecondTimestamp converts Snowflake timestamps to arrow timestamps with microsecond precision. UseMicrosecondTimestamp // UseMillisecondTimestamp converts Snowflake timestamps to arrow timestamps with millisecond precision. UseMillisecondTimestamp // UseSecondTimestamp converts Snowflake timestamps to arrow timestamps with second precision. UseSecondTimestamp // UseOriginalTimestamp leaves Snowflake timestamps in their original format without conversion. UseOriginalTimestamp ) // --- Context accessors --- // EnableArrowBatches sets the arrow batches mode flag in the context. func EnableArrowBatches(ctx context.Context) context.Context { return context.WithValue(ctx, ctxArrowBatches, true) } // BatchesEnabled checks if arrow batches mode is enabled. func BatchesEnabled(ctx context.Context) bool { v := ctx.Value(ctxArrowBatches) if v == nil { return false } d, ok := v.(bool) return ok && d } // WithTimestampOption sets the arrow batches timestamp option in the context. 
func WithTimestampOption(ctx context.Context, option TimestampOption) context.Context { return context.WithValue(ctx, ctxArrowBatchesTimestampOpt, option) } // GetTimestampOption returns the timestamp option from the context. func GetTimestampOption(ctx context.Context) TimestampOption { v := ctx.Value(ctxArrowBatchesTimestampOpt) if v == nil { return UseNanosecondTimestamp } o, ok := v.(TimestampOption) if !ok { return UseNanosecondTimestamp } return o } // EnableUtf8Validation enables UTF-8 validation for arrow batch string columns. func EnableUtf8Validation(ctx context.Context) context.Context { return context.WithValue(ctx, ctxArrowBatchesUtf8Validate, true) } // Utf8ValidationEnabled checks if UTF-8 validation is enabled. func Utf8ValidationEnabled(ctx context.Context) bool { v := ctx.Value(ctxArrowBatchesUtf8Validate) if v == nil { return false } d, ok := v.(bool) return ok && d } // WithHigherPrecision enables higher precision mode in the context. func WithHigherPrecision(ctx context.Context) context.Context { return context.WithValue(ctx, ctxHigherPrecision, true) } // HigherPrecisionEnabled checks if higher precision is enabled. func HigherPrecisionEnabled(ctx context.Context) bool { v := ctx.Value(ctxHigherPrecision) if v == nil { return false } d, ok := v.(bool) return ok && d } // BatchRaw holds raw (untransformed) arrow records for a single batch. type BatchRaw struct { Records *[]arrow.Record Index int RowCount int Location *time.Location Download func(ctx context.Context) (*[]arrow.Record, int, error) } // BatchDataInfo contains all information needed to build arrow batches. type BatchDataInfo struct { Batches []BatchRaw RowTypes []query.ExecResponseRowType Allocator memory.Allocator Ctx context.Context QueryID string } // BatchDataProvider is implemented by SnowflakeRows to expose raw arrow batch data. 
type BatchDataProvider interface { GetArrowBatches() (*BatchDataInfo, error) } ================================================ FILE: internal/compilation/cgo_disabled.go ================================================ //go:build !cgo package compilation // CgoEnabled is set to false if CGO is disabled. var CgoEnabled = false ================================================ FILE: internal/compilation/cgo_enabled.go ================================================ //go:build cgo package compilation // CgoEnabled is set to true if CGO is enabled. var CgoEnabled = true ================================================ FILE: internal/compilation/linking_mode.go ================================================ package compilation import ( "debug/elf" "fmt" "runtime" "sync" ) // LinkingMode describes what linking mode was detected for the current binary. type LinkingMode int const ( // StaticLinking means the static linking. StaticLinking LinkingMode = iota // DynamicLinking means the dynamic linking. DynamicLinking // UnknownLinking means driver couldn't determine linking or it is not relevant (it is relevant on Linux only). UnknownLinking ) func (lm *LinkingMode) String() string { switch *lm { case StaticLinking: return "static" case DynamicLinking: return "dynamic" default: return "unknown" } } // CheckDynamicLinking checks whether the current binary has a dynamic linker (PT_INTERP). // A statically linked glibc binary will crash with SIGFPE if dlopen is called, // so this check allows us to skip minicore loading gracefully. // The result is cached so the ELF parsing only happens once. 
func CheckDynamicLinking() (LinkingMode, error) { linkingModeOnce.Do(func() { if runtime.GOOS != "linux" { linkingModeCached = UnknownLinking return } f, err := elf.Open("/proc/self/exe") if err != nil { linkingModeErr = fmt.Errorf("cannot open /proc/self/exe: %v", err) return } defer func() { _ = f.Close() }() for _, p := range f.Progs { if p.Type == elf.PT_INTERP { linkingModeCached = DynamicLinking return } } linkingModeCached = StaticLinking }) return linkingModeCached, linkingModeErr } var ( linkingModeOnce sync.Once linkingModeCached LinkingMode linkingModeErr error ) ================================================ FILE: internal/compilation/minicore_disabled.go ================================================ //go:build minicore_disabled package compilation // MinicoreEnabled is set to false when building with -tags minicore_disabled. // This disables minicore at compile time, which is useful for statically linked binaries // that cannot use dynamic library loading (dlopen). // // Example: go build -tags minicore_disabled ./... var MinicoreEnabled = false ================================================ FILE: internal/compilation/minicore_enabled.go ================================================ //go:build !minicore_disabled package compilation // MinicoreEnabled is set to true by default. Build with -tags minicore_disabled to disable // minicore at compile time. This is useful when building statically linked binaries, // as minicore requires dynamic library loading (dlopen) which is incompatible with static linking. // // Example: go build -tags minicore_disabled ./... 
var MinicoreEnabled = true ================================================ FILE: internal/config/assert_test.go ================================================ package config import ( "fmt" "reflect" "slices" "strings" "testing" "time" sflogger "github.com/snowflakedb/gosnowflake/v2/internal/logger" ) // TODO temporary - move this to a common test utils package when we have one func maskSecrets(text string) string { return sflogger.MaskSecrets(text) } func assertNilE(t *testing.T, actual any, descriptions ...string) { t.Helper() errorOnNonEmpty(t, validateNil(actual, descriptions...)) } func assertNilF(t *testing.T, actual any, descriptions ...string) { t.Helper() fatalOnNonEmpty(t, validateNil(actual, descriptions...)) } func assertNotNilF(t *testing.T, actual any, descriptions ...string) { t.Helper() fatalOnNonEmpty(t, validateNotNil(actual, descriptions...)) } func assertEqualE(t *testing.T, actual any, expected any, descriptions ...string) { t.Helper() errorOnNonEmpty(t, validateEqual(actual, expected, descriptions...)) } func assertEqualF(t *testing.T, actual any, expected any, descriptions ...string) { t.Helper() fatalOnNonEmpty(t, validateEqual(actual, expected, descriptions...)) } func assertTrueE(t *testing.T, actual bool, descriptions ...string) { t.Helper() errorOnNonEmpty(t, validateEqual(actual, true, descriptions...)) } func assertTrueF(t *testing.T, actual bool, descriptions ...string) { t.Helper() fatalOnNonEmpty(t, validateEqual(actual, true, descriptions...)) } func assertFalseE(t *testing.T, actual bool, descriptions ...string) { t.Helper() errorOnNonEmpty(t, validateEqual(actual, false, descriptions...)) } func fatalOnNonEmpty(t *testing.T, errMsg string) { if errMsg != "" { t.Helper() t.Fatal(formatErrorMessage(errMsg)) } } func errorOnNonEmpty(t *testing.T, errMsg string) { if errMsg != "" { t.Helper() t.Error(formatErrorMessage(errMsg)) } } func formatErrorMessage(errMsg string) string { return fmt.Sprintf("[%s] %s", 
time.Now().Format(time.RFC3339Nano), maskSecrets(errMsg)) } func validateNil(actual any, descriptions ...string) string { if isNil(actual) { return "" } desc := joinDescriptions(descriptions...) return fmt.Sprintf("expected \"%s\" to be nil but was not. %s", maskSecrets(fmt.Sprintf("%v", actual)), desc) } func validateNotNil(actual any, descriptions ...string) string { if !isNil(actual) { return "" } desc := joinDescriptions(descriptions...) return fmt.Sprintf("expected to be not nil but was not. %s", desc) } func validateEqual(actual any, expected any, descriptions ...string) string { if expected == actual { return "" } desc := joinDescriptions(descriptions...) return fmt.Sprintf("expected \"%s\" to be equal to \"%s\" but was not. %s", maskSecrets(fmt.Sprintf("%v", actual)), maskSecrets(fmt.Sprintf("%v", expected)), desc) } func joinDescriptions(descriptions ...string) string { return strings.Join(descriptions, " ") } func isNil(value any) bool { if value == nil { return true } val := reflect.ValueOf(value) return slices.Contains([]reflect.Kind{reflect.Pointer, reflect.Slice, reflect.Map, reflect.Interface, reflect.Func}, val.Kind()) && val.IsNil() } ================================================ FILE: internal/config/auth_type.go ================================================ package config import ( "net/url" "strings" sferrors "github.com/snowflakedb/gosnowflake/v2/internal/errors" ) // AuthType indicates the type of authentication in Snowflake type AuthType int const ( // AuthTypeSnowflake is the general username password authentication AuthTypeSnowflake AuthType = iota // AuthTypeOAuth is the OAuth authentication AuthTypeOAuth // AuthTypeExternalBrowser is to use a browser to access an Fed and perform SSO authentication AuthTypeExternalBrowser // AuthTypeOkta is to use a native okta URL to perform SSO authentication on Okta AuthTypeOkta // AuthTypeJwt is to use Jwt to perform authentication AuthTypeJwt // AuthTypeTokenAccessor is to use the provided token 
accessor and bypass authentication AuthTypeTokenAccessor // AuthTypeUsernamePasswordMFA is to use username and password with mfa AuthTypeUsernamePasswordMFA // AuthTypePat is to use programmatic access token AuthTypePat // AuthTypeOAuthAuthorizationCode is to use browser-based OAuth2 flow AuthTypeOAuthAuthorizationCode // AuthTypeOAuthClientCredentials is to use non-interactive OAuth2 flow AuthTypeOAuthClientCredentials // AuthTypeWorkloadIdentityFederation is to use CSP identity for authentication AuthTypeWorkloadIdentityFederation ) func (authType AuthType) String() string { switch authType { case AuthTypeSnowflake: return "SNOWFLAKE" case AuthTypeOAuth: return "OAUTH" case AuthTypeExternalBrowser: return "EXTERNALBROWSER" case AuthTypeOkta: return "OKTA" case AuthTypeJwt: return "SNOWFLAKE_JWT" case AuthTypeTokenAccessor: return "TOKENACCESSOR" case AuthTypeUsernamePasswordMFA: return "USERNAME_PASSWORD_MFA" case AuthTypePat: return "PROGRAMMATIC_ACCESS_TOKEN" case AuthTypeOAuthAuthorizationCode: return "OAUTH_AUTHORIZATION_CODE" case AuthTypeOAuthClientCredentials: return "OAUTH_CLIENT_CREDENTIALS" case AuthTypeWorkloadIdentityFederation: return "WORKLOAD_IDENTITY" default: return "UNKNOWN" } } // DetermineAuthenticatorType parses the authenticator string and sets the Config.Authenticator field. 
func DetermineAuthenticatorType(cfg *Config, value string) error { upperCaseValue := strings.ToUpper(value) lowerCaseValue := strings.ToLower(value) if strings.Trim(value, " ") == "" || upperCaseValue == AuthTypeSnowflake.String() { cfg.Authenticator = AuthTypeSnowflake return nil } else if upperCaseValue == AuthTypeOAuth.String() { cfg.Authenticator = AuthTypeOAuth return nil } else if upperCaseValue == AuthTypeJwt.String() { cfg.Authenticator = AuthTypeJwt return nil } else if upperCaseValue == AuthTypeExternalBrowser.String() { cfg.Authenticator = AuthTypeExternalBrowser return nil } else if upperCaseValue == AuthTypeUsernamePasswordMFA.String() { cfg.Authenticator = AuthTypeUsernamePasswordMFA return nil } else if upperCaseValue == AuthTypeTokenAccessor.String() { cfg.Authenticator = AuthTypeTokenAccessor return nil } else if upperCaseValue == AuthTypePat.String() { cfg.Authenticator = AuthTypePat return nil } else if upperCaseValue == AuthTypeOAuthAuthorizationCode.String() { cfg.Authenticator = AuthTypeOAuthAuthorizationCode return nil } else if upperCaseValue == AuthTypeOAuthClientCredentials.String() { cfg.Authenticator = AuthTypeOAuthClientCredentials return nil } else if upperCaseValue == AuthTypeWorkloadIdentityFederation.String() { cfg.Authenticator = AuthTypeWorkloadIdentityFederation return nil } else { // possibly Okta case oktaURLString, err := url.QueryUnescape(lowerCaseValue) if err != nil { return &sferrors.SnowflakeError{ Number: sferrors.ErrCodeFailedToParseAuthenticator, Message: sferrors.ErrMsgFailedToParseAuthenticator, MessageArgs: []any{lowerCaseValue}, } } oktaURL, err := url.Parse(oktaURLString) if err != nil { return &sferrors.SnowflakeError{ Number: sferrors.ErrCodeFailedToParseAuthenticator, Message: sferrors.ErrMsgFailedToParseAuthenticator, MessageArgs: []any{oktaURLString}, } } if oktaURL.Scheme != "https" { return &sferrors.SnowflakeError{ Number: sferrors.ErrCodeFailedToParseAuthenticator, Message: 
sferrors.ErrMsgFailedToParseAuthenticator, MessageArgs: []any{oktaURLString}, } } cfg.OktaURL = oktaURL cfg.Authenticator = AuthTypeOkta } return nil } ================================================ FILE: internal/config/config.go ================================================ // Package config provides the Config struct which contains all configuration parameters for the driver and a Validate method to check if the configuration is correct. package config import ( "crypto/rsa" "errors" "net/http" "net/url" "os" "strings" "time" ) // Config is a set of configuration parameters type Config struct { Account string // Account name User string // Username Password string // Password (requires User) Database string // Database name Schema string // Schema Warehouse string // Warehouse Role string // Role Region string // Region OauthClientID string // Client id for OAuth2 external IdP OauthClientSecret string // Client secret for OAuth2 external IdP OauthAuthorizationURL string // Authorization URL of Auth2 external IdP OauthTokenRequestURL string // Token request URL of Auth2 external IdP OauthRedirectURI string // Redirect URI registered in IdP. The default is http://127.0.0.1: OauthScope string // Comma separated list of scopes. If empty it is derived from role. 
EnableSingleUseRefreshTokens bool // Enables single use refresh tokens for Snowflake IdP // ValidateDefaultParameters disable the validation checks for Database, Schema, Warehouse and Role // at the time a connection is established ValidateDefaultParameters Bool Params map[string]*string // other connection parameters Protocol string // http or https (optional) Host string // hostname (optional) Port int // port (optional) Authenticator AuthType // The authenticator type SingleAuthenticationPrompt Bool // If enabled prompting for authentication will only occur for the first authentication challenge Passcode string PasscodeInPassword bool OktaURL *url.URL // Deprecated: timeouts may be reorganized in a future release. LoginTimeout time.Duration // Login retry timeout EXCLUDING network roundtrip and read out http response // Deprecated: timeouts may be reorganized in a future release. RequestTimeout time.Duration // request retry timeout EXCLUDING network roundtrip and read out http response // Deprecated: timeouts may be reorganized in a future release. JWTExpireTimeout time.Duration // JWT expire after timeout // Deprecated: timeouts may be reorganized in a future release. ClientTimeout time.Duration // Timeout for network round trip + read out http response // Deprecated: timeouts may be reorganized in a future release. JWTClientTimeout time.Duration // Timeout for network round trip + read out http response used when JWT token auth is taking place // Deprecated: timeouts may be reorganized in a future release. ExternalBrowserTimeout time.Duration // Timeout for external browser login // Deprecated: timeouts may be reorganized in a future release. CloudStorageTimeout time.Duration // Timeout for a single call to a cloud storage provider MaxRetryCount int // Specifies how many times non-periodic HTTP request can be retried Application string // application name. 
DisableOCSPChecks bool // driver doesn't check certificate revocation status OCSPFailOpen OCSPFailOpenMode // OCSP Fail Open Token string // Token to use for OAuth other forms of token based auth TokenFilePath string // TokenFilePath defines a file where to read token from TokenAccessor TokenAccessor // TokenAccessor Optional token accessor to use ServerSessionKeepAlive bool // ServerSessionKeepAlive enables the session to persist even after the driver connection is closed PrivateKey *rsa.PrivateKey // Private key used to sign JWT Transporter http.RoundTripper // RoundTripper to intercept HTTP requests and responses TLSConfigName string // Name of the TLS config to use // Deprecated: may be removed in a future release with logging reorganization. Tracing string // sets logging level LogQueryText bool // indicates whether query text should be logged. LogQueryParameters bool // indicates whether query parameters should be logged. TmpDirPath string // sets temporary directory used by a driver for operations like encrypting, compressing etc ClientRequestMfaToken Bool // When true the MFA token is cached in the credential manager. True by default in Windows/OSX. False for Linux. ClientStoreTemporaryCredential Bool // When true the ID token is cached in the credential manager. True by default in Windows/OSX. False for Linux. 
DisableQueryContextCache bool // Should HTAP query context cache be disabled IncludeRetryReason Bool // Should retried request contain retry reason ClientConfigFile string // File path to the client configuration json file DisableConsoleLogin Bool // Indicates whether console login should be disabled DisableSamlURLCheck Bool // Indicates whether the SAML URL check should be disabled WorkloadIdentityProvider string // The workload identity provider to use for WIF authentication WorkloadIdentityEntraResource string // The resource to use for WIF authentication on Azure environment WorkloadIdentityImpersonationPath []string // The components to use for WIF impersonation. CertRevocationCheckMode CertRevocationCheckMode // revocation check mode for CRLs CrlAllowCertificatesWithoutCrlURL Bool // Allow certificates (not short-lived) without CRL DP included to be treated as correct ones CrlInMemoryCacheDisabled bool // Should the in-memory cache be disabled CrlOnDiskCacheDisabled bool // Should the on-disk cache be disabled CrlDownloadMaxSize int // Max size in bytes of CRL to download. 0 means use default (20MB). CrlHTTPClientTimeout time.Duration // Timeout for HTTP client used to download CRL ConnectionDiagnosticsEnabled bool // Indicates whether connection diagnostics should be enabled ConnectionDiagnosticsAllowlistFile string // File path to the allowlist file for connection diagnostics. If not specified, the allowlist.json file in the current directory will be used. ProxyHost string // Proxy host ProxyPort int // Proxy port ProxyUser string // Proxy user ProxyPassword string // Proxy password ProxyProtocol string // Proxy protocol (http or https) NoProxy string // No proxy for this host list } var errTokenConfigConflict = errors.New("token and tokenFilePath cannot be specified at the same time") // Validate enables testing if config is correct. // A driver client may call it manually, but it is also called during opening first connection. 
func (c *Config) Validate() error { if c.TmpDirPath != "" { if _, err := os.Stat(c.TmpDirPath); err != nil { return err } } if strings.EqualFold(c.WorkloadIdentityProvider, "azure") && len(c.WorkloadIdentityImpersonationPath) > 0 { return errors.New("WorkloadIdentityImpersonationPath is not supported for Azure") } if c.Token != "" && c.TokenFilePath != "" { return errTokenConfigConflict } return nil } // Param binds Config field names to environment variable names. type Param struct { Name string EnvName string FailOnMissing bool } ================================================ FILE: internal/config/config_bool.go ================================================ package config // Bool is a type to represent true or false in the Config type Bool uint8 const ( // BoolNotSet represents the default value for the config field which is not set BoolNotSet Bool = iota // Reserved for unset to let default value fall into this category // BoolTrue represents true for the config field BoolTrue // BoolFalse represents false for the config field BoolFalse ) func (cb Bool) String() string { switch cb { case BoolTrue: return "true" case BoolFalse: return "false" default: return "not set" } } ================================================ FILE: internal/config/connection_configuration.go ================================================ package config import ( "encoding/base64" "errors" "os" path "path/filepath" "runtime" "strconv" "strings" "time" "github.com/BurntSushi/toml" sferrors "github.com/snowflakedb/gosnowflake/v2/internal/errors" ) const ( snowflakeConnectionName = "SNOWFLAKE_DEFAULT_CONNECTION_NAME" snowflakeHome = "SNOWFLAKE_HOME" defaultTokenPath = "/snowflake/session/token" othersCanReadFilePermission = os.FileMode(0044) othersCanWriteFilePermission = os.FileMode(0022) executableFilePermission = os.FileMode(0111) skipWarningForReadPermissionsEnv = "SF_SKIP_WARNING_FOR_READ_PERMISSIONS_ON_CONFIG_FILE" ) // LoadConnectionConfig returns connection configs loaded 
from the toml file. // By default, SNOWFLAKE_HOME(toml file path) is os.snowflakeHome/.snowflake // and SNOWFLAKE_DEFAULT_CONNECTION_NAME(DSN) is 'default' func LoadConnectionConfig() (*Config, error) { logger.Trace("Loading connection configuration from the local files.") cfg := &Config{ Params: make(map[string]*string), Authenticator: AuthTypeSnowflake, // Default to snowflake } dsn := getConnectionDSN(os.Getenv(snowflakeConnectionName)) snowflakeConfigDir, err := GetTomlFilePath(os.Getenv(snowflakeHome)) if err != nil { return nil, err } logger.Debugf("Looking for connection file in directory %v", snowflakeConfigDir) tomlFilePath := path.Join(snowflakeConfigDir, "connections.toml") err = ValidateFilePermission(tomlFilePath) if err != nil { return nil, err } tomlInfo := make(map[string]any) _, err = toml.DecodeFile(tomlFilePath, &tomlInfo) if err != nil { return nil, err } dsnMap, exist := tomlInfo[dsn] if !exist { return nil, &sferrors.SnowflakeError{ Number: sferrors.ErrCodeFailedToFindDSNInToml, Message: sferrors.ErrMsgFailedToFindDSNInTomlFile, } } connectionConfig, ok := dsnMap.(map[string]any) if !ok { return nil, err } logger.Trace("Trying to parse the config file") err = ParseToml(cfg, connectionConfig) if err != nil { return nil, err } err = FillMissingConfigParameters(cfg) if err != nil { return nil, err } return cfg, err } // ParseToml parses a TOML connection map into a Config. func ParseToml(cfg *Config, connectionMap map[string]any) error { for key, value := range connectionMap { if err := HandleSingleParam(cfg, key, value); err != nil { return err } } return nil } // HandleSingleParam processes a single TOML parameter into a Config. func HandleSingleParam(cfg *Config, key string, value any) error { var err error // We normalize the key to handle both snake_case and camelCase. 
normalizedKey := strings.ReplaceAll(strings.ToLower(key), "_", "") // the cases in switch statement should be in lower case and no _ switch normalizedKey { case "user", "username": cfg.User, err = parseString(value) case "password": cfg.Password, err = parseString(value) case "host": cfg.Host, err = parseString(value) case "account": cfg.Account, err = parseString(value) case "warehouse": cfg.Warehouse, err = parseString(value) case "database": cfg.Database, err = parseString(value) case "schema": cfg.Schema, err = parseString(value) case "role": cfg.Role, err = parseString(value) case "region": cfg.Region, err = parseString(value) case "protocol": cfg.Protocol, err = parseString(value) case "passcode": cfg.Passcode, err = parseString(value) case "port": cfg.Port, err = ParseInt(value) case "passcodeinpassword": cfg.PasscodeInPassword, err = ParseBool(value) case "clienttimeout": cfg.ClientTimeout, err = ParseDuration(value) case "jwtclienttimeout": cfg.JWTClientTimeout, err = ParseDuration(value) case "logintimeout": cfg.LoginTimeout, err = ParseDuration(value) case "requesttimeout": cfg.RequestTimeout, err = ParseDuration(value) case "jwttimeout": cfg.JWTExpireTimeout, err = ParseDuration(value) case "externalbrowsertimeout": cfg.ExternalBrowserTimeout, err = ParseDuration(value) case "maxretrycount": cfg.MaxRetryCount, err = ParseInt(value) case "application": cfg.Application, err = parseString(value) case "authenticator": var v string v, err = parseString(value) if err = checkParsingError(err, key, value); err != nil { return err } err = DetermineAuthenticatorType(cfg, v) case "disableocspchecks": cfg.DisableOCSPChecks, err = ParseBool(value) case "ocspfailopen": var vv Bool vv, err = parseConfigBool(value) if err := checkParsingError(err, key, value); err != nil { return err } cfg.OCSPFailOpen = OCSPFailOpenMode(vv) case "token": cfg.Token, err = parseString(value) case "privatekey": var v string v, err = parseString(value) if err = checkParsingError(err, key, 
value); err != nil { return err } block, decodeErr := base64.URLEncoding.DecodeString(v) if decodeErr != nil { return &sferrors.SnowflakeError{ Number: sferrors.ErrCodePrivateKeyParseError, Message: "Base64 decode failed", } } cfg.PrivateKey, err = ParsePKCS8PrivateKey(block) case "validatedefaultparameters": cfg.ValidateDefaultParameters, err = parseConfigBool(value) case "clientrequestmfatoken": cfg.ClientRequestMfaToken, err = parseConfigBool(value) case "clientstoretemporarycredential": cfg.ClientStoreTemporaryCredential, err = parseConfigBool(value) case "tracing": cfg.Tracing, err = parseString(value) case "logquerytext": cfg.LogQueryText, err = ParseBool(value) case "logqueryparameters": cfg.LogQueryParameters, err = ParseBool(value) case "tmpdirpath": cfg.TmpDirPath, err = parseString(value) case "disablequerycontextcache": cfg.DisableQueryContextCache, err = ParseBool(value) case "includeretryreason": cfg.IncludeRetryReason, err = parseConfigBool(value) case "clientconfigfile": cfg.ClientConfigFile, err = parseString(value) case "disableconsolelogin": cfg.DisableConsoleLogin, err = parseConfigBool(value) case "disablesamlurlcheck": cfg.DisableSamlURLCheck, err = parseConfigBool(value) case "oauthauthorizationurl": cfg.OauthAuthorizationURL, err = parseString(value) case "oauthclientid": cfg.OauthClientID, err = parseString(value) case "oauthclientsecret": cfg.OauthClientSecret, err = parseString(value) case "oauthtokenrequesturl": cfg.OauthTokenRequestURL, err = parseString(value) case "oauthredirecturi": cfg.OauthRedirectURI, err = parseString(value) case "oauthscope": cfg.OauthScope, err = parseString(value) case "workloadidentityprovider": cfg.WorkloadIdentityProvider, err = parseString(value) case "workloadidentityentraresource": cfg.WorkloadIdentityEntraResource, err = parseString(value) case "workloadidentityimpersonatinpath": cfg.WorkloadIdentityImpersonationPath, err = parseStrings(value) case "tokenfilepath": cfg.TokenFilePath, err = 
parseString(value) if err = checkParsingError(err, key, value); err != nil { return err } case "connectiondiagnosticsenabled": cfg.ConnectionDiagnosticsEnabled, err = ParseBool(value) case "connectiondiagnosticsallowlistfile": cfg.ConnectionDiagnosticsAllowlistFile, err = parseString(value) case "proxyhost": cfg.ProxyHost, err = parseString(value) case "proxyport": cfg.ProxyPort, err = ParseInt(value) case "proxyuser": cfg.ProxyUser, err = parseString(value) case "proxypassword": cfg.ProxyPassword, err = parseString(value) case "proxyprotocol": cfg.ProxyProtocol, err = parseString(value) case "noproxy": cfg.NoProxy, err = parseString(value) default: param, err := parseString(value) if err = checkParsingError(err, key, value); err != nil { return err } cfg.Params[urlDecodeIfNeeded(key)] = ¶m } return checkParsingError(err, key, value) } func checkParsingError(err error, key string, value any) error { if err != nil { err = &sferrors.SnowflakeError{ Number: sferrors.ErrCodeTomlFileParsingFailed, Message: sferrors.ErrMsgFailedToParseTomlFile, MessageArgs: []any{key, value}, } logger.Errorf("Parsed key: %s, value: %v is not an option for the connection config", key, value) return err } logger.Warnf("Parsed key: %s, value: %v — cannot be parsed as string", key, value) return nil } // ParseInt parses an interface value to int. func ParseInt(i any) (int, error) { v, ok := i.(string) if !ok { num, ok := i.(int) if !ok { return 0, errors.New("failed to parse the value to integer") } return num, nil } return strconv.Atoi(v) } // ParseBool parses an interface value to bool. 
func ParseBool(i any) (bool, error) { v, ok := i.(string) if !ok { vv, ok := i.(bool) if !ok { return false, errors.New("failed to parse the value to boolean") } return vv, nil } return strconv.ParseBool(v) } func parseConfigBool(i any) (Bool, error) { vv, err := ParseBool(i) if err != nil { return BoolFalse, err } if vv { return BoolTrue, nil } return BoolFalse, nil } // ParseDuration parses an interface value to time.Duration. func ParseDuration(i any) (time.Duration, error) { v, ok := i.(string) if !ok { num, err := ParseInt(i) if err != nil { return time.Duration(0), err } t := int64(num) return time.Duration(t * int64(time.Second)), nil } return parseTimeout(v) } // ReadToken reads a token from the given path (or default path if empty). func ReadToken(tokenPath string) (string, error) { if tokenPath == "" { tokenPath = defaultTokenPath } if !path.IsAbs(tokenPath) { var err error tokenPath, err = path.Abs(tokenPath) if err != nil { return "", err } } err := ValidateFilePermission(tokenPath) if err != nil { return "", err } token, err := os.ReadFile(tokenPath) if err != nil { return "", err } return string(token), nil } func parseString(i any) (string, error) { v, ok := i.(string) if !ok { return "", errors.New("failed to convert the value to string") } return v, nil } func parseStrings(i any) ([]string, error) { s, ok := i.(string) if !ok { return nil, errors.New("failed to convert the value to string") } return strings.Split(s, ","), nil } // GetTomlFilePath returns the path to the TOML file directory. func GetTomlFilePath(filePath string) (string, error) { if len(filePath) == 0 { homeDir, err := os.UserHomeDir() if err != nil { return "", err } filePath = path.Join(homeDir, ".snowflake") } absDir, err := path.Abs(filePath) if err != nil { return "", err } return absDir, nil } func getConnectionDSN(dsn string) string { if len(dsn) != 0 { return dsn } return "default" } // ValidateFilePermission checks that a file does not have overly permissive permissions. 
func ValidateFilePermission(filePath string) error { if runtime.GOOS == "windows" { return nil } fileInfo, err := os.Stat(filePath) if err != nil { return err } permission := fileInfo.Mode().Perm() if !shouldSkipWarningForReadPermissions() && permission&othersCanReadFilePermission != 0 { logger.Warnf("file '%v' is readable by someone other than the owner. Your Permission: %v. If you want "+ "to disable this warning, either remove read permissions from group and others or set the environment "+ "variable %v to true", filePath, permission, skipWarningForReadPermissionsEnv) } if permission&executableFilePermission != 0 { return &sferrors.SnowflakeError{ Number: sferrors.ErrCodeInvalidFilePermission, Message: sferrors.ErrMsgInvalidExecutablePermissionToFile, MessageArgs: []any{filePath, permission}, } } if permission&othersCanWriteFilePermission != 0 { return &sferrors.SnowflakeError{ Number: sferrors.ErrCodeInvalidFilePermission, Message: sferrors.ErrMsgInvalidWritablePermissionToFile, MessageArgs: []any{filePath, permission}, } } return nil } func shouldSkipWarningForReadPermissions() bool { return os.Getenv(skipWarningForReadPermissionsEnv) != "" } ================================================ FILE: internal/config/connection_configuration_test.go ================================================ package config import ( "bytes" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/base64" "fmt" "os" path "path/filepath" "runtime" "strings" "testing" "time" sferrors "github.com/snowflakedb/gosnowflake/v2/internal/errors" sflogger "github.com/snowflakedb/gosnowflake/v2/internal/logger" ) func TestTokenFilePermission(t *testing.T) { if runtime.GOOS == "windows" { return } os.Setenv(snowflakeHome, "../../test_data") connectionsStat, err := os.Stat("../../test_data/connections.toml") if err != nil { t.Fatalf("Failed to stat connections.toml file: %v", err) } tokenStat, err := os.Stat("../../test_data/snowflake/session/token") if err != nil { t.Fatalf("Failed to stat token 
file: %v", err) } defer func() { err = os.Chmod("../../test_data/connections.toml", connectionsStat.Mode()) if err != nil { t.Errorf("Failed to restore connections.toml file permission: %v", err) } err = os.Chmod("../../test_data/snowflake/session/token", tokenStat.Mode()) if err != nil { t.Errorf("Failed to restore token file permission: %v", err) } }() t.Run("test warning logger for readable outside owner", func(t *testing.T) { originalGlobalLogger := sflogger.GetLogger() newLogger := sflogger.CreateDefaultLogger() sflogger.SetLogger(newLogger) buf := &bytes.Buffer{} sflogger.GetLogger().SetOutput(buf) defer func() { sflogger.SetLogger(originalGlobalLogger) }() err = os.Chmod("../../test_data/connections.toml", 0644) if err != nil { t.Fatalf("Failed to change connections.toml file permission: %v", err) } _, err = LoadConnectionConfig() if err != nil { t.Fatalf("Failed to load connection config: %v", err) } connectionsAbsolutePath, err := path.Abs("../../test_data/connections.toml") if err != nil { t.Fatalf("Failed to get absolute path of connections.toml file: %v", err) } expectedWarn := fmt.Sprintf("msg=\"file '%v' is readable by someone other than the owner. "+ "Your Permission: -rw-r--r--. 
If you want to disable this warning, either remove read permissions from group "+ "and others or set the environment variable SF_SKIP_WARNING_FOR_READ_PERMISSIONS_ON_CONFIG_FILE to true\"", connectionsAbsolutePath) if !strings.Contains(buf.String(), expectedWarn) { t.Errorf("Expected warning message not found in logs.\nGot: %v\nWant substring: %v", buf.String(), expectedWarn) } }) t.Run("test warning skipped logger for readable outside owner", func(t *testing.T) { os.Setenv(skipWarningForReadPermissionsEnv, "true") defer func() { os.Unsetenv(skipWarningForReadPermissionsEnv) }() originalGlobalLogger := sflogger.GetLogger() newLogger := sflogger.CreateDefaultLogger() sflogger.SetLogger(newLogger) buf := &bytes.Buffer{} sflogger.GetLogger().SetOutput(buf) defer func() { sflogger.SetLogger(originalGlobalLogger) }() err = os.Chmod("../../test_data/connections.toml", 0644) if err != nil { t.Fatalf("Failed to change connections.toml file permission: %v", err) } _, err = LoadConnectionConfig() if err != nil { t.Fatalf("Failed to load connection config: %v", err) } }) t.Run("test writable connection file other than owner", func(t *testing.T) { err = os.Chmod("../../test_data/connections.toml", 0666) if err != nil { t.Fatalf("The error occurred because you cannot change the file permission: %v", err) } _, err := LoadConnectionConfig() if err == nil { t.Fatal("The error should occur because the file is writable by anyone but the owner") } driverErr, ok := err.(*sferrors.SnowflakeError) if !ok { t.Fatalf("This should be a Snowflake Error, got: %T", err) } if driverErr.Number != sferrors.ErrCodeInvalidFilePermission { t.Fatalf("Expected error code %d, got %d", sferrors.ErrCodeInvalidFilePermission, driverErr.Number) } }) t.Run("test writable token file other than owner", func(t *testing.T) { err = os.Chmod("../../test_data/snowflake/session/token", 0666) if err != nil { t.Fatalf("The error occurred because you cannot change the file permission: %v", err) } _, err := 
ReadToken("../../test_data/snowflake/session/token") if err == nil { t.Fatal("The error should occur because the file is writable by anyone but the owner") } driverErr, ok := err.(*sferrors.SnowflakeError) if !ok { t.Fatalf("This should be a Snowflake Error, got: %T", err) } if driverErr.Number != sferrors.ErrCodeInvalidFilePermission { t.Fatalf("Expected error code %d, got %d", sferrors.ErrCodeInvalidFilePermission, driverErr.Number) } }) t.Run("test executable connection file", func(t *testing.T) { err = os.Chmod("../../test_data/connections.toml", 0100) if err != nil { t.Fatalf("The error occurred because you cannot change the file permission: %v", err) } _, err := LoadConnectionConfig() if err == nil { t.Fatal("The error should occur because the file is executable") } driverErr, ok := err.(*sferrors.SnowflakeError) if !ok { t.Fatalf("This should be a Snowflake Error, got: %T", err) } if driverErr.Number != sferrors.ErrCodeInvalidFilePermission { t.Fatalf("Expected error code %d, got %d", sferrors.ErrCodeInvalidFilePermission, driverErr.Number) } }) t.Run("test executable token file", func(t *testing.T) { err = os.Chmod("../../test_data/snowflake/session/token", 0010) if err != nil { t.Fatalf("The error occurred because you cannot change the file permission: %v", err) } _, err := ReadToken("../../test_data/snowflake/session/token") if err == nil { t.Fatal("The error should occur because the file is executable") } driverErr, ok := err.(*sferrors.SnowflakeError) if !ok { t.Fatalf("This should be a Snowflake Error, got: %T", err) } if driverErr.Number != sferrors.ErrCodeInvalidFilePermission { t.Fatalf("Expected error code %d, got %d", sferrors.ErrCodeInvalidFilePermission, driverErr.Number) } }) t.Run("test valid file permission for connection config and token file", func(t *testing.T) { err = os.Chmod("../../test_data/connections.toml", 0600) if err != nil { t.Fatalf("The error occurred because you cannot change the file permission: %v", err) } err = 
os.Chmod("../../test_data/snowflake/session/token", 0600) if err != nil { t.Fatalf("The error occurred because you cannot change the file permission: %v", err) } _, err := LoadConnectionConfig() if err != nil { t.Fatalf("The error occurred because the permission is not 0600: %v", err) } _, err = ReadToken("../../test_data/snowflake/session/token") if err != nil { t.Fatalf("The error occurred because the permission is not 0600: %v", err) } }) } func TestLoadConnectionConfigForStandardAuth(t *testing.T) { err := os.Chmod("../../test_data/connections.toml", 0600) if err != nil { t.Fatalf("The error occurred because you cannot change the file permission: %v", err) } os.Setenv(snowflakeHome, "../../test_data") cfg, err := LoadConnectionConfig() if err != nil { t.Fatalf("The error should not occur: %v", err) } assertEqual(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") assertEqual(t, cfg.User, "test_default_user") assertEqual(t, cfg.Password, "test_default_pass") assertEqual(t, cfg.Warehouse, "testw_default") assertEqual(t, cfg.Database, "test_default_db") assertEqual(t, cfg.Schema, "test_default_go") assertEqual(t, cfg.Protocol, "https") if cfg.Port != 300 { t.Fatalf("Expected port 300, got %d", cfg.Port) } } func TestLoadConnectionConfigForOAuth(t *testing.T) { err := os.Chmod("../../test_data/connections.toml", 0600) if err != nil { t.Fatalf("The error occurred because you cannot change the file permission: %v", err) } os.Setenv(snowflakeHome, "../../test_data") os.Setenv(snowflakeConnectionName, "aws-oauth") cfg, err := LoadConnectionConfig() if err != nil { t.Fatalf("The error should not occur: %v", err) } assertEqual(t, cfg.Account, "snowdriverswarsaw.us-west-2.aws") assertEqual(t, cfg.User, "test_oauth_user") assertEqual(t, cfg.Password, "test_oauth_pass") assertEqual(t, cfg.Warehouse, "testw_oauth") assertEqual(t, cfg.Database, "test_oauth_db") assertEqual(t, cfg.Schema, "test_oauth_go") assertEqual(t, cfg.Protocol, "https") if cfg.Authenticator != 
AuthTypeOAuth { t.Fatalf("Expected authenticator %v, got %v", AuthTypeOAuth, cfg.Authenticator) } assertEqual(t, cfg.Token, "token_value") if cfg.Port != 443 { t.Fatalf("Expected port 443, got %d", cfg.Port) } if cfg.DisableOCSPChecks != true { t.Fatalf("Expected DisableOCSPChecks true, got %v", cfg.DisableOCSPChecks) } } func TestLoadConnectionConfigForSnakeCaseConfiguration(t *testing.T) { err := os.Chmod("../../test_data/connections.toml", 0600) if err != nil { t.Fatalf("The error occurred because you cannot change the file permission: %v", err) } os.Setenv(snowflakeHome, "../../test_data") os.Setenv(snowflakeConnectionName, "snake-case") cfg, err := LoadConnectionConfig() if err != nil { t.Fatalf("The error should not occur: %v", err) } if cfg.OCSPFailOpen != OCSPFailOpenTrue { t.Fatalf("Expected OCSPFailOpen %v, got %v", OCSPFailOpenTrue, cfg.OCSPFailOpen) } } func TestReadTokenValueWithTokenFilePath(t *testing.T) { err := os.Chmod("../../test_data/connections.toml", 0600) if err != nil { t.Fatalf("The error occurred because you cannot change the file permission: %v", err) } err = os.Chmod("../../test_data/snowflake/session/token", 0600) if err != nil { t.Fatalf("The error occurred because you cannot change the file permission: %v", err) } os.Setenv(snowflakeHome, "../../test_data") os.Setenv(snowflakeConnectionName, "read-token") cfg, err := LoadConnectionConfig() if err != nil { t.Fatalf("The error should not occur: %v", err) } if cfg.Authenticator != AuthTypeOAuth { t.Fatalf("Expected authenticator %v, got %v", AuthTypeOAuth, cfg.Authenticator) } // The token_file_path in the TOML is relative ("./test_data/snowflake/session/token"), // so GetToken resolves it relative to CWD. Use an absolute path instead. 
absTokenPath, err := path.Abs("../../test_data/snowflake/session/token") if err != nil { t.Fatalf("Failed to get absolute path: %v", err) } cfg.TokenFilePath = absTokenPath token, err := GetToken(cfg) if err != nil { t.Fatalf("Failed to get token: %v", err) } assertEqual(t, token, "mock_token123456") if cfg.DisableOCSPChecks != true { t.Fatalf("Expected DisableOCSPChecks true, got %v", cfg.DisableOCSPChecks) } } func TestLoadConnectionConfigWitNonExistingDSN(t *testing.T) { err := os.Chmod("../../test_data/connections.toml", 0600) if err != nil { t.Fatalf("The error occurred because you cannot change the file permission: %v", err) } os.Setenv(snowflakeHome, "../../test_data") os.Setenv(snowflakeConnectionName, "unavailableDSN") _, err = LoadConnectionConfig() if err == nil { t.Fatal("The error should occur") } driverErr, ok := err.(*sferrors.SnowflakeError) if !ok { t.Fatalf("This should be a Snowflake Error, got: %T", err) } if driverErr.Number != sferrors.ErrCodeFailedToFindDSNInToml { t.Fatalf("Expected error code %d, got %d", sferrors.ErrCodeFailedToFindDSNInToml, driverErr.Number) } } func TestParseInt(t *testing.T) { var i any i = 20 num, err := ParseInt(i) if err != nil { t.Fatalf("This value should be parsed: %v", err) } if num != 20 { t.Fatalf("Expected 20, got %d", num) } i = "40" num, err = ParseInt(i) if err != nil { t.Fatalf("This value should be parsed: %v", err) } if num != 40 { t.Fatalf("Expected 40, got %d", num) } i = "wrong_num" _, err = ParseInt(i) if err == nil { t.Fatal("should have failed") } } func TestParseBool(t *testing.T) { var i any i = true b, err := ParseBool(i) if err != nil { t.Fatalf("This value should be parsed: %v", err) } if b != true { t.Fatalf("Expected true, got %v", b) } i = "false" b, err = ParseBool(i) if err != nil { t.Fatalf("This value should be parsed: %v", err) } if b != false { t.Fatalf("Expected false, got %v", b) } i = "wrong_bool" _, err = ParseBool(i) if err == nil { t.Fatal("should have failed") } } func 
TestParseDuration(t *testing.T) { var i any i = 300 dur, err := ParseDuration(i) if err != nil { t.Fatalf("This value should be parsed: %v", err) } if dur != time.Duration(300*int64(time.Second)) { t.Fatalf("Expected %v, got %v", time.Duration(300*int64(time.Second)), dur) } i = "30" dur, err = ParseDuration(i) if err != nil { t.Fatalf("This value should be parsed: %v", err) } if dur != time.Duration(int64(time.Minute)/2) { t.Fatalf("Expected %v, got %v", time.Duration(int64(time.Minute)/2), dur) } i = false _, err = ParseDuration(i) if err == nil { t.Fatal("should have failed") } } type paramList struct { testParams []string values []any } func testGeneratePKCS8String(key *rsa.PrivateKey) string { tmpBytes, _ := x509.MarshalPKCS8PrivateKey(key) return base64.URLEncoding.EncodeToString(tmpBytes) } func TestParseToml(t *testing.T) { localTestKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("Failed to generate test private key: %s", err.Error()) } testCases := []paramList{ { testParams: []string{"user", "password", "host", "account", "warehouse", "database", "schema", "role", "region", "protocol", "passcode", "application", "token", "tracing", "tmpDirPath", "tmp_dir_path", "clientConfigFile", "client_config_file", "oauth_authorization_url", "oauth_client_id", "oauth_client_secret", "oauth_token_request_url", "oauth_redirect_uri", "oauth_scope", "workload_identity_provider", "workload_identity_entra_resource", "proxyHost", "noProxy", "proxyUser", "proxyPassword", "proxyProtocol"}, values: []any{"value"}, }, { testParams: []string{"privatekey", "private_key"}, values: []any{testGeneratePKCS8String(localTestKey)}, }, { testParams: []string{"port", "maxRetryCount", "max_retry_count", "clientTimeout", "client_timeout", "jwtClientTimeout", "jwt_client_timeout", "loginTimeout", "login_timeout", "requestTimeout", "request_timeout", "jwtTimeout", "jwt_timeout", "externalBrowserTimeout", "external_browser_timeout", "proxyPort"}, values: []any{"300", 
500}, }, { testParams: []string{"ocspFailOpen", "ocsp_fail_open", "PasscodeInPassword", "passcode_in_password", "validateDEFAULTParameters", "validate_default_parameters", "clientRequestMFAtoken", "client_request_mfa_token", "clientStoreTemporaryCredential", "client_store_temporary_credential", "disableQueryContextCache", "disable_query_context_cache", "disable_ocsp_checks", "includeRetryReason", "include_retry_reason", "disableConsoleLogin", "disable_console_login", "disableSamlUrlCheck", "disable_saml_url_check"}, values: []any{true, "true", false, "false"}, }, { testParams: []string{"connectionDiagnosticsEnabled", "connection_diagnostics_enabled"}, values: []any{true, false}, }, { testParams: []string{"connectionDiagnosticsAllowlistFile", "connection_diagnostics_allowlist_file"}, values: []any{"myallowlist.json"}, }, } for _, testCase := range testCases { for _, param := range testCase.testParams { for _, value := range testCase.values { t.Run(param, func(t *testing.T) { cfg := &Config{} connectionMap := make(map[string]any) connectionMap[param] = value err := ParseToml(cfg, connectionMap) if err != nil { t.Fatalf("The value should be parsed: %v", err) } }) } } } } func TestParseTomlWithWrongValue(t *testing.T) { testCases := []paramList{ { testParams: []string{"user", "password", "host", "account", "warehouse", "database", "schema", "role", "region", "protocol", "passcode", "application", "token", "privateKey", "tracing", "tmpDirPath", "clientConfigFile", "wrongParams", "token_file_path", "proxyhost", "noproxy", "proxyUser", "proxyPassword", "proxyProtocol"}, values: []any{1, false}, }, { testParams: []string{"port", "maxRetryCount", "clientTimeout", "jwtClientTimeout", "loginTimeout", "requestTimeout", "jwtTimeout", "externalBrowserTimeout", "authenticator"}, values: []any{"wrong_value", false}, }, { testParams: []string{"ocspFailOpen", "PasscodeInPassword", "validateDEFAULTParameters", "clientRequestMFAtoken", "clientStoreTemporaryCredential", 
"disableQueryContextCache", "includeRetryReason", "disableConsoleLogin", "disableSamlUrlCheck"}, values: []any{"wrong_value", 1}, }, } for _, testCase := range testCases { for _, param := range testCase.testParams { for _, value := range testCase.values { t.Run(param, func(t *testing.T) { cfg := &Config{} connectionMap := make(map[string]any) connectionMap[param] = value err := ParseToml(cfg, connectionMap) if err == nil { t.Fatal("should have failed") } driverErr, ok := err.(*sferrors.SnowflakeError) if !ok { t.Fatalf("This should be a Snowflake Error, got: %T", err) } if driverErr.Number != sferrors.ErrCodeTomlFileParsingFailed { t.Fatalf("Expected error code %d, got %d", sferrors.ErrCodeTomlFileParsingFailed, driverErr.Number) } }) } } } } func TestGetTomlFilePath(t *testing.T) { if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && os.Getenv("HOME") == "" { t.Skip("skipping on missing HOME environment variable") } dir, err := GetTomlFilePath("") if err != nil { t.Fatalf("should not have failed: %v", err) } homeDir, err := os.UserHomeDir() if err != nil { t.Fatalf("The connection cannot find the user home directory: %v", err) } assertEqual(t, dir, path.Join(homeDir, ".snowflake")) location := "../user//somelocation///b" dir, err = GetTomlFilePath(location) if err != nil { t.Fatalf("should not have failed: %v", err) } result, err := path.Abs(location) if err != nil { t.Fatalf("should not have failed: %v", err) } assertEqual(t, dir, result) //Absolute path for windows can be varied depend on which disk the driver is located. // As a result, this test is available on non-Window machines. if !(runtime.GOOS == "windows") { result = "/user/somelocation/b" location = "/user//somelocation///b" dir, err = GetTomlFilePath(location) if err != nil { t.Fatalf("should not have failed: %v", err) } assertEqual(t, dir, result) } } // assertEqual is a simple test helper for string comparison. 
func assertEqual[T comparable](t *testing.T, got, want T) { t.Helper() if got != want { t.Fatalf("got %v, want %v", got, want) } } ================================================ FILE: internal/config/crl_mode.go ================================================ package config import ( "fmt" "strings" ) // CertRevocationCheckMode defines the modes for certificate revocation checks. type CertRevocationCheckMode int const ( // CertRevocationCheckDisabled means that certificate revocation checks are disabled. CertRevocationCheckDisabled CertRevocationCheckMode = iota // CertRevocationCheckAdvisory means that certificate revocation checks are advisory, and the driver will not fail if the checks end with error (cannot verify revocation status). // Driver will fail only if a certicate is revoked. CertRevocationCheckAdvisory // CertRevocationCheckEnabled means that every certificate revocation check must pass, otherwise the driver will fail. CertRevocationCheckEnabled ) func (m CertRevocationCheckMode) String() string { switch m { case CertRevocationCheckDisabled: return "DISABLED" case CertRevocationCheckAdvisory: return "ADVISORY" case CertRevocationCheckEnabled: return "ENABLED" default: return fmt.Sprintf("unknown CertRevocationCheckMode: %d", m) } } // ParseCertRevocationCheckMode parses a string into a CertRevocationCheckMode. 
// ParseCertRevocationCheckMode parses a string into a CertRevocationCheckMode.
// Matching is case-insensitive; an unrecognized value yields an error and the
// zero mode.
func ParseCertRevocationCheckMode(s string) (CertRevocationCheckMode, error) {
	switch strings.ToLower(s) {
	case "disabled":
		return CertRevocationCheckDisabled, nil
	case "advisory":
		return CertRevocationCheckAdvisory, nil
	case "enabled":
		return CertRevocationCheckEnabled, nil
	}
	return 0, fmt.Errorf("unknown CertRevocationCheckMode: %s", s)
}
================================================ FILE: internal/config/dsn.go ================================================
package config

import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"errors"
	"fmt"
	"net/url"
	"os"
	"strconv"
	"strings"
	"time"

	sferrors "github.com/snowflakedb/gosnowflake/v2/internal/errors"
	loggerinternal "github.com/snowflakedb/gosnowflake/v2/internal/logger"
)

var logger = loggerinternal.NewLoggerProxy()

const (
	// DefaultClientTimeout is the timeout for network round trip + read out http response
	DefaultClientTimeout = 900 * time.Second
	// DefaultJWTClientTimeout is the timeout for network round trip + read out http response but used for JWT auth
	DefaultJWTClientTimeout = 10 * time.Second
	// DefaultLoginTimeout is the timeout for retry for login EXCLUDING clientTimeout
	DefaultLoginTimeout = 300 * time.Second
	// DefaultRequestTimeout is the timeout for retry for request EXCLUDING clientTimeout
	DefaultRequestTimeout = 0 * time.Second
	// DefaultJWTTimeout is the timeout for JWT token expiration
	DefaultJWTTimeout = 60 * time.Second
	// DefaultExternalBrowserTimeout is the timeout for external browser login
	DefaultExternalBrowserTimeout = 120 * time.Second
	defaultCloudStorageTimeout    = -1 // Timeout for calling cloud storage.
	defaultMaxRetryCount          = 7  // specifies maximum number of subsequent retries
	// DefaultDomain is the default domain for Snowflake accounts
	DefaultDomain = ".snowflakecomputing.com"
	// CnDomain is the domain for Snowflake accounts in China
	CnDomain = ".snowflakecomputing.cn"
	// used to extract the domain from host
	topLevelDomainPrefix = ".snowflakecomputing."
)

const clientType = "Go"

// GetFromEnv retrieves the value of an environment variable.
// If failOnMissing is true and the variable is not set, an error is returned.
func GetFromEnv(name string, failOnMissing bool) (string, error) {
	if value := os.Getenv(name); value != "" {
		return value, nil
	}
	if failOnMissing {
		return "", fmt.Errorf("%v environment variable is not set", name)
	}
	return "", nil
}

// DSN constructs a DSN for Snowflake db.
// It normalizes Region/Account (dropping the default us-west-2 region),
// derives Host when absent, fills remaining defaults via
// FillMissingConfigParameters, and serializes every non-default field as a
// query parameter. Note: cfg is mutated in place by the normalization steps.
func DSN(cfg *Config) (dsn string, err error) {
	if strings.ToLower(cfg.Region) == "us-west-2" {
		cfg.Region = ""
	}
	// in case account includes region
	region, posDot := extractRegionFromAccount(cfg.Account)
	if strings.ToLower(region) == "us-west-2" {
		region = ""
		cfg.Account = cfg.Account[:posDot]
		logger.Info("Ignoring default region .us-west-2 in DSN from Account configuration.")
	}
	if region != "" {
		// region supplied both inside Account and via cfg.Region is ambiguous
		if cfg.Region != "" {
			return "", sferrors.ErrRegionConflict()
		}
		cfg.Region = region
		cfg.Account = cfg.Account[:posDot]
	}
	hasHost := true
	if cfg.Host == "" {
		hasHost = false
		if cfg.Region == "" {
			cfg.Host = cfg.Account + DefaultDomain
		} else {
			cfg.Host = buildHostFromAccountAndRegion(cfg.Account, cfg.Region)
		}
	}
	err = FillMissingConfigParameters(cfg)
	if err != nil {
		return "", err
	}
	params := &url.Values{}
	if hasHost && cfg.Account != "" {
		// account may not be included in a Host string
		params.Add("account", cfg.Account)
	}
	if cfg.Database != "" {
		params.Add("database", cfg.Database)
	}
	if cfg.Schema != "" {
		params.Add("schema", cfg.Schema)
	}
	if cfg.Warehouse != "" {
		params.Add("warehouse", cfg.Warehouse)
	}
	if cfg.Role != "" {
		params.Add("role", cfg.Role)
	}
	if cfg.Region != "" {
		params.Add("region", cfg.Region)
	}
	if cfg.OauthClientID != "" {
		params.Add("oauthClientId", cfg.OauthClientID)
	}
	if cfg.OauthClientSecret != "" {
		params.Add("oauthClientSecret", cfg.OauthClientSecret)
	}
	if cfg.OauthAuthorizationURL != "" {
		params.Add("oauthAuthorizationUrl", cfg.OauthAuthorizationURL)
	}
	if cfg.OauthTokenRequestURL != "" {
		params.Add("oauthTokenRequestUrl", cfg.OauthTokenRequestURL)
	}
	if cfg.OauthRedirectURI != "" {
		params.Add("oauthRedirectUri", cfg.OauthRedirectURI)
	}
	if cfg.OauthScope != "" {
		params.Add("oauthScope", cfg.OauthScope)
	}
	if cfg.EnableSingleUseRefreshTokens {
		params.Add("enableSingleUseRefreshTokens", strconv.FormatBool(cfg.EnableSingleUseRefreshTokens))
	}
	if cfg.WorkloadIdentityProvider != "" {
		params.Add("workloadIdentityProvider", cfg.WorkloadIdentityProvider)
	}
	if cfg.WorkloadIdentityEntraResource != "" {
		params.Add("workloadIdentityEntraResource", cfg.WorkloadIdentityEntraResource)
	}
	if len(cfg.WorkloadIdentityImpersonationPath) > 0 {
		params.Add("workloadIdentityImpersonationPath", strings.Join(cfg.WorkloadIdentityImpersonationPath, ","))
	}
	if cfg.Authenticator != AuthTypeSnowflake {
		// Okta is represented by its URL rather than a symbolic name
		if cfg.Authenticator == AuthTypeOkta {
			params.Add("authenticator", strings.ToLower(cfg.OktaURL.String()))
		} else {
			params.Add("authenticator", strings.ToLower(cfg.Authenticator.String()))
		}
	}
	if cfg.SingleAuthenticationPrompt != BoolNotSet {
		if cfg.SingleAuthenticationPrompt == BoolTrue {
			params.Add("singleAuthenticationPrompt", "true")
		} else {
			params.Add("singleAuthenticationPrompt", "false")
		}
	}
	if cfg.Passcode != "" {
		params.Add("passcode", cfg.Passcode)
	}
	if cfg.PasscodeInPassword {
		params.Add("passcodeInPassword", strconv.FormatBool(cfg.PasscodeInPassword))
	}
	// timeouts are serialized in whole seconds, only when non-default
	if cfg.ClientTimeout != DefaultClientTimeout {
		params.Add("clientTimeout", strconv.FormatInt(int64(cfg.ClientTimeout/time.Second), 10))
	}
	if cfg.JWTClientTimeout != DefaultJWTClientTimeout {
		params.Add("jwtClientTimeout", strconv.FormatInt(int64(cfg.JWTClientTimeout/time.Second), 10))
	}
	if cfg.LoginTimeout != DefaultLoginTimeout {
		params.Add("loginTimeout", strconv.FormatInt(int64(cfg.LoginTimeout/time.Second), 10))
	}
	if cfg.RequestTimeout != DefaultRequestTimeout {
		params.Add("requestTimeout", strconv.FormatInt(int64(cfg.RequestTimeout/time.Second), 10))
	}
	if cfg.JWTExpireTimeout != DefaultJWTTimeout {
		params.Add("jwtTimeout", strconv.FormatInt(int64(cfg.JWTExpireTimeout/time.Second), 10))
	}
	if cfg.ExternalBrowserTimeout != DefaultExternalBrowserTimeout {
		params.Add("externalBrowserTimeout", strconv.FormatInt(int64(cfg.ExternalBrowserTimeout/time.Second), 10))
	}
	if cfg.CloudStorageTimeout != defaultCloudStorageTimeout {
		params.Add("cloudStorageTimeout", strconv.FormatInt(int64(cfg.CloudStorageTimeout/time.Second), 10))
	}
	if cfg.MaxRetryCount != defaultMaxRetryCount {
		params.Add("maxRetryCount", strconv.Itoa(cfg.MaxRetryCount))
	}
	if cfg.Application != clientType {
		params.Add("application", cfg.Application)
	}
	if cfg.Protocol != "" && cfg.Protocol != "https" {
		params.Add("protocol", cfg.Protocol)
	}
	if cfg.Token != "" {
		params.Add("token", cfg.Token)
	}
	if cfg.TokenFilePath != "" {
		params.Add("tokenFilePath", cfg.TokenFilePath)
	}
	if cfg.CertRevocationCheckMode != CertRevocationCheckDisabled {
		params.Add("certRevocationCheckMode", cfg.CertRevocationCheckMode.String())
	}
	if cfg.CrlAllowCertificatesWithoutCrlURL == BoolTrue {
		params.Add("crlAllowCertificatesWithoutCrlURL", "true")
	}
	if cfg.CrlInMemoryCacheDisabled {
		params.Add("crlInMemoryCacheDisabled", "true")
	}
	if cfg.CrlOnDiskCacheDisabled {
		params.Add("crlOnDiskCacheDisabled", "true")
	}
	if cfg.CrlDownloadMaxSize != 0 {
		params.Add("crlDownloadMaxSize", strconv.Itoa(cfg.CrlDownloadMaxSize))
	}
	if cfg.CrlHTTPClientTimeout != 0 {
		params.Add("crlHttpClientTimeout", strconv.FormatInt(int64(cfg.CrlHTTPClientTimeout/time.Second), 10))
	}
	if cfg.Params != nil {
		// arbitrary session parameters pass through verbatim
		for k, v := range cfg.Params {
			params.Add(k, *v)
		}
	}
	if cfg.PrivateKey != nil {
		// private key travels as URL-safe base64 of its PKCS#8 DER encoding
		privateKeyInBytes, err := MarshalPKCS8PrivateKey(cfg.PrivateKey)
		if err != nil {
			return "", err
		}
		keyBase64 := base64.URLEncoding.EncodeToString(privateKeyInBytes)
		params.Add("privateKey", keyBase64)
	}
	if cfg.DisableOCSPChecks {
		params.Add("disableOCSPChecks", strconv.FormatBool(cfg.DisableOCSPChecks))
	}
	if cfg.Tracing != "" {
		params.Add("tracing", cfg.Tracing)
	}
	if cfg.LogQueryText {
		params.Add("logQueryText", strconv.FormatBool(cfg.LogQueryText))
	}
	if cfg.LogQueryParameters {
		params.Add("logQueryParameters", strconv.FormatBool(cfg.LogQueryParameters))
	}
	if cfg.TmpDirPath != "" {
		params.Add("tmpDirPath", cfg.TmpDirPath)
	}
	if cfg.DisableQueryContextCache {
		params.Add("disableQueryContextCache", "true")
	}
	if cfg.IncludeRetryReason == BoolFalse {
		params.Add("includeRetryReason", "false")
	}
	if cfg.ServerSessionKeepAlive {
		params.Add("serverSessionKeepAlive", "true")
	}
	// these two are always emitted; NotSet collapses to the default (true)
	params.Add("ocspFailOpen", strconv.FormatBool(cfg.OCSPFailOpen != OCSPFailOpenFalse))
	params.Add("validateDefaultParameters", strconv.FormatBool(cfg.ValidateDefaultParameters != BoolFalse))
	if cfg.ClientRequestMfaToken != BoolNotSet {
		params.Add("clientRequestMfaToken", strconv.FormatBool(cfg.ClientRequestMfaToken != BoolFalse))
	}
	if cfg.ClientStoreTemporaryCredential != BoolNotSet {
		params.Add("clientStoreTemporaryCredential", strconv.FormatBool(cfg.ClientStoreTemporaryCredential != BoolFalse))
	}
	if cfg.ClientConfigFile != "" {
		params.Add("clientConfigFile", cfg.ClientConfigFile)
	}
	if cfg.DisableConsoleLogin != BoolNotSet {
		params.Add("disableConsoleLogin", strconv.FormatBool(cfg.DisableConsoleLogin != BoolFalse))
	}
	if cfg.DisableSamlURLCheck != BoolNotSet {
		params.Add("disableSamlURLCheck", strconv.FormatBool(cfg.DisableSamlURLCheck != BoolFalse))
	}
	if cfg.ConnectionDiagnosticsEnabled {
		params.Add("connectionDiagnosticsEnabled", strconv.FormatBool(cfg.ConnectionDiagnosticsEnabled))
	}
	if cfg.ConnectionDiagnosticsAllowlistFile != "" {
		params.Add("connectionDiagnosticsAllowlistFile", cfg.ConnectionDiagnosticsAllowlistFile)
	}
	if cfg.TLSConfigName != "" {
		params.Add("tlsConfigName", cfg.TLSConfigName)
	}
	if cfg.ProxyHost != "" {
		params.Add("proxyHost", cfg.ProxyHost)
	}
	if cfg.ProxyPort != 0 {
		params.Add("proxyPort", strconv.Itoa(cfg.ProxyPort))
	}
	if cfg.ProxyProtocol != "" {
		params.Add("proxyProtocol", cfg.ProxyProtocol)
	}
	if cfg.ProxyUser != "" {
		params.Add("proxyUser", cfg.ProxyUser)
	}
	if cfg.ProxyPassword != "" {
		params.Add("proxyPassword", cfg.ProxyPassword)
	}
	if cfg.NoProxy != "" {
		params.Add("noProxy", cfg.NoProxy)
	}
	dsn = fmt.Sprintf("%v:%v@%v:%v", url.QueryEscape(cfg.User), url.QueryEscape(cfg.Password), cfg.Host, cfg.Port)
	if params.Encode() != "" {
		dsn += "?" + params.Encode()
	}
	return
}

// ParseDSN parses the DSN string to a Config.
// The scan runs right-to-left looking for '/' and '?' separators; see the
// accepted forms below.
func ParseDSN(dsn string) (cfg *Config, err error) {
	// New config with some default values
	cfg = &Config{
		Params:        make(map[string]*string),
		Authenticator: AuthTypeSnowflake, // Default to snowflake
	}

	// user[:password]@account/database/schema[?param1=value1&paramN=valueN]
	// or
	// user[:password]@account/database[?param1=value1&paramN=valueN]
	// or
	// user[:password]@host:port/database/schema?account=user_account[?param1=value1&paramN=valueN]
	// or
	// host:port/database/schema?account=user_account[?param1=value1&paramN=valueN]
	foundSlash := false
	secondSlash := false
	done := false
	var i int
	posQuestion := len(dsn)
	for i = len(dsn) - 1; i >= 0; i-- {
		switch dsn[i] {
		case '/':
			foundSlash = true
			// left part is empty if i <= 0
			var j int
			posSecondSlash := i
			if i > 0 {
				for j = i - 1; j >= 0; j-- {
					switch dsn[j] {
					case '/':
						// second slash
						secondSlash = true
						posSecondSlash = j
					case '@':
						// username[:password]@...
						cfg.User, cfg.Password = parseUserPassword(j, dsn)
					}
					if dsn[j] == '@' {
						break
					}
				}
				// account or host:port
				err = parseAccountHostPort(cfg, j, posSecondSlash, dsn)
				if err != nil {
					return nil, err
				}
			}
			// [?param1=value1&...&paramN=valueN]
			// Find the first '?' in dsn[i+1:]
			err = parseParams(cfg, i, dsn)
			if err != nil {
				return
			}
			if secondSlash {
				cfg.Database = dsn[posSecondSlash+1 : i]
				cfg.Schema = dsn[i+1 : posQuestion]
			} else {
				cfg.Database = dsn[posSecondSlash+1 : posQuestion]
			}
			done = true
		case '?':
			posQuestion = i
		}
		if done {
			break
		}
	}
	if !foundSlash {
		// no db or schema is specified
		var j int
		for j = len(dsn) - 1; j >= 0; j-- {
			switch dsn[j] {
			case '@':
				cfg.User, cfg.Password = parseUserPassword(j, dsn)
			case '?':
				posQuestion = j
			}
			if dsn[j] == '@' {
				break
			}
		}
		err = parseAccountHostPort(cfg, j, posQuestion, dsn)
		if err != nil {
			return nil, err
		}
		err = parseParams(cfg, posQuestion-1, dsn)
		if err != nil {
			return
		}
	}
	// a dotted account carries a region suffix; keep only the account name
	if posDot := strings.Index(cfg.Account, "."); posDot >= 0 {
		cfg.Account = cfg.Account[:posDot]
	}
	err = FillMissingConfigParameters(cfg)
	if err != nil {
		return nil, err
	}
	// unescape parameters
	var s string
	s, err = url.QueryUnescape(cfg.User)
	if err != nil {
		return nil, err
	}
	cfg.User = s
	s, err = url.QueryUnescape(cfg.Password)
	if err != nil {
		return nil, err
	}
	cfg.Password = s
	s, err = url.QueryUnescape(cfg.Database)
	if err != nil {
		return nil, err
	}
	cfg.Database = s
	s, err = url.QueryUnescape(cfg.Schema)
	if err != nil {
		return nil, err
	}
	cfg.Schema = s
	s, err = url.QueryUnescape(cfg.Role)
	if err != nil {
		return nil, err
	}
	cfg.Role = s
	s, err = url.QueryUnescape(cfg.Warehouse)
	if err != nil {
		return nil, err
	}
	cfg.Warehouse = s
	return cfg, nil
}

// applyAccountFromHostIfMissing sets Account to the first DNS label of Host when Account is empty
// and Host matches the Snowflake hostname heuristic (hostIncludesTopLevelDomain). FillMissingConfigParameters
// invokes this so programmatic Config (e.g. database/sql.Connector) matches behavior that DSN users
// already got via ParseDSN plus FillMissingConfigParameters. ParseDSN still truncates dotted account
// values from parameters before FillMissingConfigParameters; that step does not apply to non-empty
// Account set directly on a programmatic Config.
func applyAccountFromHostIfMissing(cfg *Config) {
	if strings.TrimSpace(cfg.Account) != "" {
		return
	}
	if !hostIncludesTopLevelDomain(cfg.Host) {
		return
	}
	posDot := strings.Index(cfg.Host, ".")
	if posDot <= 0 {
		return
	}
	// first DNS label of the host becomes the account name
	cfg.Account = cfg.Host[:posDot]
}

// FillMissingConfigParameters fills in default values for missing config parameters.
// It validates required credentials for the chosen authenticator, derives Host
// from Account/Region when absent, and applies default protocol, port,
// timeouts and retry settings. cfg is mutated in place.
func FillMissingConfigParameters(cfg *Config) error {
	applyAccountFromHostIfMissing(cfg)
	posDash := strings.LastIndex(cfg.Account, "-")
	if posDash > 0 {
		// global-URL accounts (account-suffix.global.*) keep only the part before the dash
		if strings.Contains(strings.ToLower(cfg.Host), ".global.") {
			cfg.Account = cfg.Account[:posDash]
		}
	}
	if strings.Trim(cfg.Account, " ") == "" {
		return sferrors.ErrEmptyAccount()
	}
	if authRequiresUser(cfg) && strings.TrimSpace(cfg.User) == "" {
		return sferrors.ErrEmptyUsername()
	}
	if authRequiresPassword(cfg) && strings.TrimSpace(cfg.Password) == "" {
		return sferrors.ErrEmptyPassword()
	}
	if authRequiresEitherPasswordOrToken(cfg) && strings.TrimSpace(cfg.Password) == "" && strings.TrimSpace(cfg.Token) == "" {
		return sferrors.ErrEmptyPasswordAndToken()
	}
	if authRequiresClientIDAndSecret(cfg) && (strings.TrimSpace(cfg.OauthClientID) == "" || strings.TrimSpace(cfg.OauthClientSecret) == "") {
		return sferrors.ErrEmptyOAuthParameters()
	}
	if strings.Trim(cfg.Protocol, " ") == "" {
		cfg.Protocol = "https"
	}
	if cfg.Port == 0 {
		cfg.Port = 443
	}
	cfg.Region = strings.Trim(cfg.Region, " ")
	if cfg.Region != "" {
		// region is specified but not included in Host
		domain, i := extractDomainFromHost(cfg.Host)
		if i >= 1 {
			hostPrefix := cfg.Host[0:i]
			// splice the region between the host prefix and the domain suffix
			if !strings.HasSuffix(hostPrefix, cfg.Region) {
				cfg.Host = fmt.Sprintf("%v.%v%v", hostPrefix, cfg.Region, domain)
			}
		}
	}
	if cfg.Host == "" {
		if cfg.Region != "" {
			cfg.Host = cfg.Account + "." + cfg.Region + getDomainBasedOnRegion(cfg.Region)
		} else {
			region, _ := extractRegionFromAccount(cfg.Account)
			if region != "" {
				cfg.Host = cfg.Account + getDomainBasedOnRegion(region)
			} else {
				cfg.Host = cfg.Account + DefaultDomain
			}
		}
	}
	if cfg.LoginTimeout == 0 {
		cfg.LoginTimeout = DefaultLoginTimeout
	}
	if cfg.RequestTimeout == 0 {
		cfg.RequestTimeout = DefaultRequestTimeout
	}
	if cfg.JWTExpireTimeout == 0 {
		cfg.JWTExpireTimeout = DefaultJWTTimeout
	}
	if cfg.ClientTimeout == 0 {
		cfg.ClientTimeout = DefaultClientTimeout
	}
	if cfg.JWTClientTimeout == 0 {
		cfg.JWTClientTimeout = DefaultJWTClientTimeout
	}
	if cfg.ExternalBrowserTimeout == 0 {
		cfg.ExternalBrowserTimeout = DefaultExternalBrowserTimeout
	}
	if cfg.CloudStorageTimeout == 0 {
		cfg.CloudStorageTimeout = defaultCloudStorageTimeout
	}
	if cfg.MaxRetryCount == 0 {
		cfg.MaxRetryCount = defaultMaxRetryCount
	}
	if strings.Trim(cfg.Application, " ") == "" {
		cfg.Application = clientType
	}
	if cfg.OCSPFailOpen == OCSPFailOpenNotSet {
		cfg.OCSPFailOpen = OCSPFailOpenTrue
	}
	if cfg.ValidateDefaultParameters == BoolNotSet {
		cfg.ValidateDefaultParameters = BoolTrue
	}
	if cfg.IncludeRetryReason == BoolNotSet {
		cfg.IncludeRetryReason = BoolTrue
	}
	if cfg.ProxyHost != "" && cfg.ProxyProtocol == "" {
		cfg.ProxyProtocol = "http" // Default to http if not specified
	}
	// a host that is nothing but the domain suffix has no account label
	domain, _ := extractDomainFromHost(cfg.Host)
	if len(cfg.Host) == len(domain) {
		return &sferrors.SnowflakeError{
			Number:      sferrors.ErrCodeFailedToParseHost,
			Message:     sferrors.ErrMsgFailedToParseHost,
			MessageArgs: []any{cfg.Host},
		}
	}
	if cfg.TLSConfigName != "" {
		if _, ok := GetTLSConfig(cfg.TLSConfigName); !ok {
			return &sferrors.SnowflakeError{
				Number:  sferrors.ErrCodeMissingTLSConfig,
				Message: fmt.Sprintf(sferrors.ErrMsgMissingTLSConfig, cfg.TLSConfigName),
			}
		}
	}
	return nil
}

// extractDomainFromHost returns the ".snowflakecomputing.*" suffix of host and
// the index where it starts, or ("", index) when no suffix is present.
func extractDomainFromHost(host string) (domain string, index int) {
	i := strings.LastIndex(strings.ToLower(host), topLevelDomainPrefix)
	if i >= 1 {
		domain = host[i:]
		return domain, i
	}
	return "", i
}

// getDomainBasedOnRegion picks the China domain for "cn-" regions, otherwise
// the default domain.
func getDomainBasedOnRegion(region string) string {
	if strings.HasPrefix(strings.ToLower(region), "cn-") {
		return CnDomain
	}
	return DefaultDomain
}

// extractRegionFromAccount splits "account.region" and returns the region part
// with the dot position, or ("", posDot) when account has no dot.
func extractRegionFromAccount(account string) (region string, posDot int) {
	posDot = strings.Index(strings.ToLower(account), ".")
	if posDot > 0 {
		return account[posDot+1:], posDot
	}
	return "", posDot
}

// hostIncludesTopLevelDomain reports whether host already carries a
// ".snowflakecomputing." suffix (case-insensitive).
func hostIncludesTopLevelDomain(host string) bool {
	return strings.Contains(strings.ToLower(host), topLevelDomainPrefix)
}

// buildHostFromAccountAndRegion assembles "account.region<domain>".
func buildHostFromAccountAndRegion(account, region string) string {
	return account + "." + region + getDomainBasedOnRegion(region)
}

// authRequiresUser reports whether the configured authenticator needs a User.
func authRequiresUser(cfg *Config) bool {
	return cfg.Authenticator != AuthTypeOAuth &&
		cfg.Authenticator != AuthTypeTokenAccessor &&
		cfg.Authenticator != AuthTypeExternalBrowser &&
		cfg.Authenticator != AuthTypePat &&
		cfg.Authenticator != AuthTypeOAuthAuthorizationCode &&
		cfg.Authenticator != AuthTypeOAuthClientCredentials &&
		cfg.Authenticator != AuthTypeWorkloadIdentityFederation
}

// authRequiresPassword reports whether the configured authenticator needs a
// Password (JWT additionally exempt compared to authRequiresUser).
func authRequiresPassword(cfg *Config) bool {
	return cfg.Authenticator != AuthTypeOAuth &&
		cfg.Authenticator != AuthTypeTokenAccessor &&
		cfg.Authenticator != AuthTypeExternalBrowser &&
		cfg.Authenticator != AuthTypeJwt &&
		cfg.Authenticator != AuthTypePat &&
		cfg.Authenticator != AuthTypeOAuthAuthorizationCode &&
		cfg.Authenticator != AuthTypeOAuthClientCredentials &&
		cfg.Authenticator != AuthTypeWorkloadIdentityFederation
}

// authRequiresEitherPasswordOrToken reports whether either a Password or a
// Token must be present (PAT authentication).
func authRequiresEitherPasswordOrToken(cfg *Config) bool {
	return cfg.Authenticator == AuthTypePat
}

// authRequiresClientIDAndSecret reports whether OAuth client credentials must
// be present.
func authRequiresClientIDAndSecret(cfg *Config) bool {
	return cfg.Authenticator == AuthTypeOAuthAuthorizationCode
}

// transformAccountToHost transforms account to host
func transformAccountToHost(cfg *Config) (err error) {
	if cfg.Port == 0 && cfg.Host != "" && !hostIncludesTopLevelDomain(cfg.Host) {
		// account name is specified instead of host:port
		cfg.Account = cfg.Host
		region, posDot := extractRegionFromAccount(cfg.Account)
		if strings.ToLower(region) == "us-west-2" {
			region = ""
			cfg.Account = cfg.Account[:posDot]
			logger.Info("Ignoring default region .us-west-2 from Account configuration.")
		}
		if region != "" {
			cfg.Region = region
			cfg.Account = cfg.Account[:posDot]
			cfg.Host = buildHostFromAccountAndRegion(cfg.Account, cfg.Region)
		} else {
			cfg.Host = cfg.Account + DefaultDomain
		}
		cfg.Port = 443
	}
	return nil
}

// parseAccountHostPort parses the DSN string to attempt to get account or host and port.
// posAt is the index of '@' (or -1-equivalent start) and posSlash the end of
// the host[:port] segment.
func parseAccountHostPort(cfg *Config, posAt, posSlash int, dsn string) (err error) {
	// account or host:port
	var k int
	for k = posAt + 1; k < posSlash; k++ {
		if dsn[k] == ':' {
			cfg.Port, err = strconv.Atoi(dsn[k+1 : posSlash])
			if err != nil {
				err = &sferrors.SnowflakeError{
					Number:      sferrors.ErrCodeFailedToParsePort,
					Message:     sferrors.ErrMsgFailedToParsePort,
					MessageArgs: []any{dsn[k+1 : posSlash]},
				}
				return
			}
			break
		}
	}
	cfg.Host = dsn[posAt+1 : k]
	return transformAccountToHost(cfg)
}

// parseUserPassword parses the DSN string for username and password
// (everything before posAt, split once on ':').
func parseUserPassword(posAt int, dsn string) (user, password string) {
	var k int
	for k = 0; k < posAt; k++ {
		if dsn[k] == ':' {
			password = dsn[k+1 : posAt]
			break
		}
	}
	user = dsn[:k]
	return
}

// parseParams parse parameters
// starting the '?' search just past posQuestion.
func parseParams(cfg *Config, posQuestion int, dsn string) (err error) {
	for j := posQuestion + 1; j < len(dsn); j++ {
		if dsn[j] == '?' {
			if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
				return
			}
			break
		}
	}
	return
}

// parseDSNParams parses the DSN "query string".
Values must be url.QueryEscape'ed func parseDSNParams(cfg *Config, params string) (err error) { logger.Infof("Query String: %v\n", params) paramsSlice := strings.SplitSeq(params, "&") for v := range paramsSlice { param := strings.SplitN(v, "=", 2) if len(param) != 2 { continue } var value string value, err = url.QueryUnescape(param[1]) if err != nil { return err } switch param[0] { // Disable INFILE whitelist / enable all files case "account": cfg.Account = value case "warehouse": cfg.Warehouse = value case "database": cfg.Database = value case "schema": cfg.Schema = value case "role": cfg.Role = value case "region": cfg.Region = value case "protocol": cfg.Protocol = value case "singleAuthenticationPrompt": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } if vv { cfg.SingleAuthenticationPrompt = BoolTrue } else { cfg.SingleAuthenticationPrompt = BoolFalse } case "passcode": cfg.Passcode = value case "oauthClientId": cfg.OauthClientID = value case "oauthClientSecret": cfg.OauthClientSecret = value case "oauthAuthorizationUrl": cfg.OauthAuthorizationURL = value case "oauthTokenRequestUrl": cfg.OauthTokenRequestURL = value case "oauthRedirectUri": cfg.OauthRedirectURI = value case "oauthScope": cfg.OauthScope = value case "enableSingleUseRefreshTokens": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } cfg.EnableSingleUseRefreshTokens = vv case "passcodeInPassword": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } cfg.PasscodeInPassword = vv case "clientTimeout": cfg.ClientTimeout, err = parseTimeout(value) if err != nil { return } case "jwtClientTimeout": cfg.JWTClientTimeout, err = parseTimeout(value) if err != nil { return } case "loginTimeout": cfg.LoginTimeout, err = parseTimeout(value) if err != nil { return } case "requestTimeout": cfg.RequestTimeout, err = parseTimeout(value) if err != nil { return } case "jwtTimeout": cfg.JWTExpireTimeout, err = parseTimeout(value) if err != nil { return err 
} case "externalBrowserTimeout": cfg.ExternalBrowserTimeout, err = parseTimeout(value) if err != nil { return err } case "cloudStorageTimeout": cfg.CloudStorageTimeout, err = parseTimeout(value) if err != nil { return err } case "maxRetryCount": cfg.MaxRetryCount, err = strconv.Atoi(value) if err != nil { return err } case "serverSessionKeepAlive": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } cfg.ServerSessionKeepAlive = vv case "application": cfg.Application = value case "authenticator": err := DetermineAuthenticatorType(cfg, value) if err != nil { return err } case "disableOCSPChecks": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } cfg.DisableOCSPChecks = vv case "ocspFailOpen": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } if vv { cfg.OCSPFailOpen = OCSPFailOpenTrue } else { cfg.OCSPFailOpen = OCSPFailOpenFalse } case "token": cfg.Token = value case "tokenFilePath": cfg.TokenFilePath = value case "tlsConfigName": cfg.TLSConfigName = value case "workloadIdentityProvider": cfg.WorkloadIdentityProvider = value case "workloadIdentityEntraResource": cfg.WorkloadIdentityEntraResource = value case "workloadIdentityImpersonationPath": cfg.WorkloadIdentityImpersonationPath = strings.Split(value, ",") case "privateKey": var decodeErr error block, decodeErr := base64.URLEncoding.DecodeString(value) if decodeErr != nil { err = &sferrors.SnowflakeError{ Number: sferrors.ErrCodePrivateKeyParseError, Message: "Base64 decode failed", } return } cfg.PrivateKey, err = ParsePKCS8PrivateKey(block) if err != nil { return err } case "validateDefaultParameters": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } if vv { cfg.ValidateDefaultParameters = BoolTrue } else { cfg.ValidateDefaultParameters = BoolFalse } case "clientRequestMfaToken": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } if vv { cfg.ClientRequestMfaToken = BoolTrue } else { 
cfg.ClientRequestMfaToken = BoolFalse } case "clientStoreTemporaryCredential": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } if vv { cfg.ClientStoreTemporaryCredential = BoolTrue } else { cfg.ClientStoreTemporaryCredential = BoolFalse } case "tracing": cfg.Tracing = value case "logQueryText": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } cfg.LogQueryText = vv case "logQueryParameters": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } cfg.LogQueryParameters = vv case "tmpDirPath": cfg.TmpDirPath = value case "disableQueryContextCache": var b bool b, err = strconv.ParseBool(value) if err != nil { return } cfg.DisableQueryContextCache = b case "includeRetryReason": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } if vv { cfg.IncludeRetryReason = BoolTrue } else { cfg.IncludeRetryReason = BoolFalse } case "clientConfigFile": cfg.ClientConfigFile = value case "disableConsoleLogin": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } if vv { cfg.DisableConsoleLogin = BoolTrue } else { cfg.DisableConsoleLogin = BoolFalse } case "disableSamlURLCheck": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } if vv { cfg.DisableSamlURLCheck = BoolTrue } else { cfg.DisableSamlURLCheck = BoolFalse } case "certRevocationCheckMode": var certRevocationCheckMode CertRevocationCheckMode certRevocationCheckMode, err = ParseCertRevocationCheckMode(value) if err != nil { return } cfg.CertRevocationCheckMode = certRevocationCheckMode case "crlAllowCertificatesWithoutCrlURL": var vv bool vv, err = strconv.ParseBool(value) if vv { cfg.CrlAllowCertificatesWithoutCrlURL = BoolTrue } else { cfg.CrlAllowCertificatesWithoutCrlURL = BoolFalse } case "crlInMemoryCacheDisabled": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } if vv { cfg.CrlInMemoryCacheDisabled = true } else { cfg.CrlInMemoryCacheDisabled = false } case 
"crlOnDiskCacheDisabled": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } if vv { cfg.CrlOnDiskCacheDisabled = true } else { cfg.CrlOnDiskCacheDisabled = false } case "crlDownloadMaxSize": cfg.CrlDownloadMaxSize, err = strconv.Atoi(value) if err != nil { return } case "crlHttpClientTimeout": var vv int64 vv, err = strconv.ParseInt(value, 10, 64) if err != nil { return } cfg.CrlHTTPClientTimeout = time.Duration(vv * int64(time.Second)) case "connectionDiagnosticsEnabled": var vv bool vv, err = strconv.ParseBool(value) if err != nil { return } cfg.ConnectionDiagnosticsEnabled = vv case "connectionDiagnosticsAllowlistFile": cfg.ConnectionDiagnosticsAllowlistFile = value case "proxyHost": cfg.ProxyHost, err = parseString(value) case "proxyPort": cfg.ProxyPort, err = ParseInt(value) case "proxyUser": cfg.ProxyUser, err = parseString(value) case "proxyPassword": cfg.ProxyPassword, err = parseString(value) case "noProxy": cfg.NoProxy, err = parseString(value) case "proxyProtocol": cfg.ProxyProtocol, err = parseString(value) default: if cfg.Params == nil { cfg.Params = make(map[string]*string) } // handle session variables $variable=value cfg.Params[urlDecodeIfNeeded(param[0])] = &value } } return } func parseTimeout(value string) (time.Duration, error) { var vv int64 var err error vv, err = strconv.ParseInt(value, 10, 64) if err != nil { return time.Duration(0), err } return time.Duration(vv * int64(time.Second)), nil } // GetConfigFromEnv is used to parse the environment variable values to specific fields of the Config func GetConfigFromEnv(properties []*Param) (*Config, error) { var account, user, password, token, tokenFilePath, role, host, portStr, protocol, warehouse, database, schema, region, passcode, application string var oauthClientID, oauthClientSecret, oauthAuthorizationURL, oauthTokenRequestURL, oauthRedirectURI, oauthScope string var privateKey *rsa.PrivateKey var err error if len(properties) == 0 || properties == nil { return nil, 
errors.New("missing configuration parameters for the connection") } for _, prop := range properties { value, err := GetFromEnv(prop.EnvName, prop.FailOnMissing) if err != nil { return nil, err } switch prop.Name { case "Account": account = value case "User": user = value case "Password": password = value case "Token": token = value case "TokenFilePath": tokenFilePath = value case "Role": role = value case "Host": host = value case "Port": portStr = value case "Protocol": protocol = value case "Warehouse": warehouse = value case "Database": database = value case "Region": region = value case "Passcode": passcode = value case "Schema": schema = value case "Application": application = value case "PrivateKey": privateKey, err = parsePrivateKeyFromFile(value) if err != nil { return nil, err } case "OAuthClientId": oauthClientID = value case "OAuthClientSecret": oauthClientSecret = value case "OAuthAuthorizationURL": oauthAuthorizationURL = value case "OAuthTokenRequestURL": oauthTokenRequestURL = value case "OAuthRedirectURI": oauthRedirectURI = value case "OAuthScope": oauthScope = value default: return nil, errors.New("unknown property: " + prop.Name) } } port := 443 // snowflake default port if len(portStr) > 0 { port, err = strconv.Atoi(portStr) if err != nil { return nil, err } } cfg := &Config{ Account: account, User: user, Password: password, Token: token, TokenFilePath: tokenFilePath, Role: role, Host: host, Port: port, Protocol: protocol, Warehouse: warehouse, Database: database, Schema: schema, PrivateKey: privateKey, Region: region, Passcode: passcode, Application: application, OauthClientID: oauthClientID, OauthClientSecret: oauthClientSecret, OauthAuthorizationURL: oauthAuthorizationURL, OauthTokenRequestURL: oauthTokenRequestURL, OauthRedirectURI: oauthRedirectURI, OauthScope: oauthScope, Params: map[string]*string{}, } return cfg, nil } func parsePrivateKeyFromFile(path string) (*rsa.PrivateKey, error) { bytes, err := os.ReadFile(path) if err != nil { 
return nil, err } block, _ := pem.Decode(bytes) if block == nil { return nil, errors.New("failed to parse PEM block containing the private key") } privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { return nil, err } pk, ok := privateKey.(*rsa.PrivateKey) if !ok { return nil, fmt.Errorf("interface convertion. expected type *rsa.PrivateKey, but got %T", privateKey) } return pk, nil } // ExtractAccountName extract an account name from a raw account. func ExtractAccountName(rawAccount string) string { posDot := strings.Index(rawAccount, ".") if posDot > 0 { return strings.ToUpper(rawAccount[:posDot]) } return strings.ToUpper(rawAccount) } func urlDecodeIfNeeded(param string) (decodedParam string) { unescaped, err := url.QueryUnescape(param) if err != nil { return param } return unescaped } // GetToken retrieves the token from the Config, reading from file if TokenFilePath is set. func GetToken(c *Config) (string, error) { if c.TokenFilePath != "" { return ReadToken(c.TokenFilePath) } return c.Token, nil } // DescribeIdentityAttributes returns a string describing the identity attributes of the Config. func DescribeIdentityAttributes(c *Config) string { return fmt.Sprintf("host: %v, account: %v, user: %v, password existed: %v, role: %v, database: %v, schema: %v, warehouse: %v, %v", c.Host, c.Account, c.User, (c.Password != ""), c.Role, c.Database, c.Schema, c.Warehouse, DescribeProxy(c)) } // DescribeProxy returns a string describing the proxy configuration. 
func DescribeProxy(c *Config) string {
	if c.ProxyHost != "" {
		// NOTE(review): this format string appears malformed — missing a comma
		// after proxyPort's %v and a colon after "proxyPassword". Confirm before
		// changing, since tests or log scrapers may match this exact text.
		// Only the presence of a proxy password is reported, never its value.
		return fmt.Sprintf("proxyHost: %v, proxyPort: %v proxyUser: %v, proxyPassword %v, proxyProtocol: %v, noProxy: %v",
			c.ProxyHost, c.ProxyPort, c.ProxyUser, c.ProxyPassword != "", c.ProxyProtocol, c.NoProxy)
	}
	return "proxy was not configured"
}

================================================
FILE: internal/config/dsn_test.go
================================================
package config

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	cr "crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"errors"
	"fmt"
	"net/url"
	"os"
	"reflect"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/aws/smithy-go/rand"

	sferrors "github.com/snowflakedb/gosnowflake/v2/internal/errors"
)

// tcParseDSN is a single ParseDSN test case: the input DSN string, the
// expected parsed Config, the expected OCSP mode string, and the expected
// error (nil for success cases).
type tcParseDSN struct {
	dsn      string
	config   *Config
	ocspMode string
	err      error
}

func TestParseDSN(t *testing.T) {
	testPrivKey, _ := rsa.GenerateKey(cr.Reader, 2048)
	privKeyPKCS8 := generatePKCS8StringSupress(testPrivKey)
	privKeyPKCS1 := generatePKCS1String(testPrivKey)
	testcases := []tcParseDSN{
		{
			dsn: "user:pass@ac-1-laksdnflaf.global/db/schema",
			config: &Config{
				Account: "ac-1", User: "user", Password: "pass",
				Region: "global", Protocol: "https",
				Host: "ac-1-laksdnflaf.global.snowflakecomputing.com", Port: 443,
				Database: "db", Schema: "schema",
				OCSPFailOpen:              OCSPFailOpenTrue,
				ValidateDefaultParameters: BoolTrue,
				ClientTimeout:             time.Duration(DefaultClientTimeout),
				JWTClientTimeout:          time.Duration(DefaultJWTClientTimeout),
				ExternalBrowserTimeout:    time.Duration(DefaultExternalBrowserTimeout),
				CloudStorageTimeout:       defaultCloudStorageTimeout,
				IncludeRetryReason:        BoolTrue,
			},
			ocspMode: ocspModeFailOpen,
			err:      nil,
		},
		{
			dsn: "user:pass@ac-laksdnflaf.global/db/schema",
			config: &Config{
				Account: "ac", User: "user", Password: "pass",
				Region: "global", Protocol: "https",
				Host: "ac-laksdnflaf.global.snowflakecomputing.com", Port: 443,
				Database: "db", Schema: "schema",
				OCSPFailOpen: OCSPFailOpenTrue,
				ValidateDefaultParameters:
BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@asnowflakecomputing.com/db/pa?account=a&protocol=https&role=r&timezone=UTC&aehouse=w", config: &Config{Account: "a", User: "u", Password: "p", Database: "db", Schema: "pa", Protocol: "https", Role: "r", Host: "asnowflakecomputing.com.snowflakecomputing.com", Port: 443, Region: "com", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@/db?account=ac", config: &Config{ Account: "ac", User: "u", Password: "p", Database: "db", Protocol: "https", Host: "ac.snowflakecomputing.com", Port: 443, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@/db?account=ac&workloadIdentityEntraResource=https%3A%2F%2Fexample.com%2F.default&workloadIdentityProvider=azure&workloadIdentityImpersonationPath=%2Fdefault,%2Fdefault2", config: &Config{ Account: "ac", User: "u", Password: "p", Database: "db", Protocol: "https", Host: "ac.snowflakecomputing.com", Port: 443, WorkloadIdentityProvider: "azure", WorkloadIdentityEntraResource: "https://example.com/.default", WorkloadIdentityImpersonationPath: []string{"/default", 
"/default2"}, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@/db?account=ac®ion=cn-region", config: &Config{ Account: "ac", User: "u", Password: "p", Database: "db", Region: "cn-region", Protocol: "https", Host: "ac.cn-region.snowflakecomputing.cn", Port: 443, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "user:pass@account-hfdw89q748ew9gqf48w9qgf.global/db/s", config: &Config{ Account: "account", User: "user", Password: "pass", Region: "global", Protocol: "https", Host: "account-hfdw89q748ew9gqf48w9qgf.global.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: BoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "user:pass@account-hfdw89q748ew9gqf48w9qgf/db/s", config: &Config{ Account: "account-hfdw89q748ew9gqf48w9qgf", User: "user", Password: "pass", Region: "", Protocol: "https", Host: "account-hfdw89q748ew9gqf48w9qgf.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: BoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: 
time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "user:pass@account", config: &Config{ Account: "account", User: "user", Password: "pass", Region: "", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "user:pass@account.cn-region", config: &Config{ Account: "account", User: "user", Password: "pass", Region: "cn-region", Protocol: "https", Host: "account.cn-region.snowflakecomputing.cn", Port: 443, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "user:pass@account.eu-faraway", config: &Config{ Account: "account", User: "user", Password: "pass", Region: "eu-faraway", Protocol: "https", Host: "account.eu-faraway.snowflakecomputing.com", Port: 443, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, 
}, { dsn: "user:pass@account?region=eu-faraway", config: &Config{ Account: "account", User: "user", Password: "pass", Region: "eu-faraway", Protocol: "https", Host: "account.eu-faraway.snowflakecomputing.com", Port: 443, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "user:pass@account/db", config: &Config{ Account: "account", User: "user", Password: "pass", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, Database: "db", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "user:pass@account?oauthRedirectUri=http:%2F%2Flocalhost:8001%2Fsome-path&oauthClientId=testClientId&oauthClientSecret=testClientSecret&oauthAuthorizationUrl=http:%2F%2Fsomehost.com&oauthTokenRequestUrl=https:%2F%2Fsomehost2.com%2Fsomepath&oauthScope=test+scope", config: &Config{ Account: "account", User: "user", Password: "pass", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, OauthClientID: "testClientId", OauthClientSecret: "testClientSecret", OauthAuthorizationURL: "http://somehost.com", OauthTokenRequestURL: "https://somehost2.com/somepath", OauthRedirectURI: "http://localhost:8001/some-path", OauthScope: "test scope", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), 
ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "user:pass@account?oauthRedirectUri=http:%2F%2Flocalhost:8001%2Fsome-path&oauthClientId=testClientId&oauthClientSecret=testClientSecret&oauthAuthorizationUrl=http:%2F%2Fsomehost.com&oauthTokenRequestUrl=https:%2F%2Fsomehost2.com%2Fsomepath&oauthScope=test+scope&enableSingleUseRefreshTokens=true", config: &Config{ Account: "account", User: "user", Password: "pass", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, OauthClientID: "testClientId", OauthClientSecret: "testClientSecret", OauthAuthorizationURL: "http://somehost.com", OauthTokenRequestURL: "https://somehost2.com/somepath", OauthRedirectURI: "http://localhost:8001/some-path", OauthScope: "test scope", EnableSingleUseRefreshTokens: true, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "user:pass@host:123/db/schema?account=ac&protocol=http", config: &Config{ Account: "ac", User: "user", Password: "pass", Protocol: "http", Host: "host", Port: 123, Database: "db", Schema: "schema", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "user@host:123/db/schema?account=ac&protocol=http", config: &Config{ Account: "ac", User: "user", Password: "pass", 
Protocol: "http", Host: "host", Port: 123, Database: "db", Schema: "schema", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: sferrors.ErrEmptyPassword(), }, { dsn: "@host:123/db/schema?account=ac&protocol=http", config: &Config{ Account: "ac", User: "user", Password: "pass", Protocol: "http", Host: "host", Port: 123, Database: "db", Schema: "schema", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: sferrors.ErrEmptyUsername(), }, { dsn: "@host:123/db/schema?account=ac&protocol=http&authenticator=oauth_authorization_code", config: &Config{ Account: "ac", User: "user", Password: "pass", Protocol: "http", Host: "host", Port: 123, Database: "db", Schema: "schema", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: sferrors.ErrEmptyOAuthParameters(), }, { dsn: "user:pass@host:123/db/schema?protocol=http", config: &Config{ Account: "ac", User: "user", Password: "pass", Protocol: "http", Host: "host", Port: 123, Database: "db", Schema: "schema", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: 
time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: sferrors.ErrEmptyAccount(), }, { dsn: "user:@host:123/db/schema?protocol=http&authenticator=programmatic_access_token&account=ac", config: &Config{ Account: "ac", User: "user", Password: "pass", Protocol: "http", Host: "host", Port: 123, Database: "db", Schema: "schema", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: sferrors.ErrEmptyPasswordAndToken(), }, { dsn: "u:p@a.snowflakecomputing.com/db/pa?account=a&protocol=https&role=r&timezone=UTC&warehouse=w", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, Database: "db", Schema: "pa", Role: "r", Warehouse: "w", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.snowflakecomputing.mil/db/pa?account=a", config: &Config{ Account: "a", User: "u", Password: "p", Region: "", Protocol: "https", Host: "a.snowflakecomputing.mil", Port: 443, Database: "db", Schema: "pa", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: 
time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.eu-faraway.snowflakecomputing.mil/db/pa?account=a®ion=eu-faraway", config: &Config{ Account: "a", User: "u", Password: "p", Region: "eu-faraway", Protocol: "https", Host: "a.eu-faraway.snowflakecomputing.mil", Port: 443, Database: "db", Schema: "pa", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.snowflakecomputing.gov.pl/db/pa?account=a", config: &Config{ Account: "a", User: "u", Password: "p", Region: "", Protocol: "https", Host: "a.snowflakecomputing.gov.pl", Port: 443, Database: "db", Schema: "pa", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.snowflakecomputing.cn/db/pa?account=a", config: &Config{ Account: "a", User: "u", Password: "p", Region: "", Protocol: "https", Host: "a.snowflakecomputing.cn", Port: 443, Database: "db", Schema: "pa", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: 
BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.cn-region.snowflakecomputing.mil/db/pa?account=a®ion=cn-region", config: &Config{ Account: "a", User: "u", Password: "p", Region: "cn-region", Protocol: "https", Host: "a.cn-region.snowflakecomputing.mil", Port: 443, Database: "db", Schema: "pa", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.cn-region.snowflakecomputing.cn/db/pa?account=a®ion=cn-region&protocol=https&role=r&timezone=UTC&warehouse=w", config: &Config{ Account: "a", User: "u", Password: "p", Region: "cn-region", Protocol: "https", Host: "a.cn-region.snowflakecomputing.cn", Port: 443, Database: "db", Schema: "pa", Role: "r", Warehouse: "w", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@snowflake.local:9876?account=a&protocol=http", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "http", Host: "snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: 
"snowflake.local:9876?account=a&protocol=http&authenticator=OAUTH", config: &Config{ Account: "a", Authenticator: AuthTypeOAuth, Protocol: "http", Host: "snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "snowflake.local:9876?account=a&protocol=http&authenticator=OAUTH_AUTHORIZATION_CODE&oauthClientId=testClientId&oauthClientSecret=testClientSecret", config: &Config{ Account: "a", Authenticator: AuthTypeOAuthAuthorizationCode, Protocol: "http", Host: "snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, OauthClientID: "testClientId", OauthClientSecret: "testClientSecret", }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "snowflake.local:9876?account=a&protocol=http&authenticator=OAUTH_CLIENT_CREDENTIALS", config: &Config{ Account: "a", Authenticator: AuthTypeOAuthClientCredentials, Protocol: "http", Host: "snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:@a.snowflake.local:9876?account=a&protocol=http&authenticator=SNOWFLAKE_JWT", config: &Config{ Account: 
"a", User: "u", Authenticator: AuthTypeJwt, Protocol: "http", Host: "a.snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a?database=d&jwtTimeout=20", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, Database: "d", Schema: "", JWTExpireTimeout: 20 * time.Second, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, }, { dsn: "u:p@a?database=d&externalBrowserTimeout=20&cloudStorageTimeout=7", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, Database: "d", Schema: "", ExternalBrowserTimeout: 20 * time.Second, CloudStorageTimeout: 7 * time.Second, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), IncludeRetryReason: BoolTrue, MaxRetryCount: defaultMaxRetryCount, }, ocspMode: ocspModeFailOpen, }, { dsn: "u:p@a?database=d&maxRetryCount=20", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, Database: "d", Schema: "", ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: 
time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, MaxRetryCount: 20, }, ocspMode: ocspModeFailOpen, }, { dsn: "u:p@a?database=d", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, Database: "d", Schema: "", JWTExpireTimeout: time.Duration(DefaultJWTTimeout), OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, }, { dsn: "u:p@snowflake.local:NNNN?account=a&protocol=http", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "http", Host: "snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: &sferrors.SnowflakeError{ Message: sferrors.ErrMsgFailedToParsePort, MessageArgs: []any{"NNNN"}, Number: sferrors.ErrCodeFailedToParsePort, }, }, { dsn: "u:p@a?database=d&schema=s&role=r&application=aa&authenticator=snowflake&disableOCSPChecks=true&passcode=pp&passcodeInPassword=true", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, Database: "d", Schema: "s", Role: "r", Authenticator: AuthTypeSnowflake, Application: "aa", DisableOCSPChecks: true, Passcode: "pp", PasscodeInPassword: true, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, 
ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeDisabled, err: nil, }, { dsn: "u:p@a?database=d&schema=s&role=r&application=aa&authenticator=snowflake&disableOCSPChecks=true&passcode=pp&passcodeInPassword=true", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, Database: "d", Schema: "s", Role: "r", Authenticator: AuthTypeSnowflake, Application: "aa", DisableOCSPChecks: true, Passcode: "pp", PasscodeInPassword: true, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeDisabled, err: nil, }, { // schema should be ignored as no value is specified. 
dsn: "u:p@a?database=d&schema", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, Database: "d", Schema: "", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a?database= %Sd", config: &Config{}, err: url.EscapeError(`invalid URL escape`), }, { dsn: "u:p@a?schema= %Sd", config: &Config{}, err: url.EscapeError(`invalid URL escape`), }, { dsn: "u:p@a?warehouse= %Sd", config: &Config{}, err: url.EscapeError(`invalid URL escape`), }, { dsn: "u:p@a?role= %Sd", config: &Config{}, err: url.EscapeError(`invalid URL escape`), }, { dsn: ":/", config: &Config{}, err: &sferrors.SnowflakeError{ Number: sferrors.ErrCodeFailedToParsePort, }, }, { dsn: "u:u@/+/+?account=+&=0", config: &Config{}, err: sferrors.ErrEmptyAccount(), }, { dsn: "u:u@/+/+?account=+&=+&=+", config: &Config{}, err: sferrors.ErrEmptyAccount(), }, { dsn: "user%40%2F1:p%3A%40s@/db%2F?account=ac", config: &Config{ Account: "ac", User: "user@/1", Password: "p:@s", Database: "db/", Protocol: "https", Host: "ac.snowflakecomputing.com", Port: 443, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: fmt.Sprintf("u:p@ac.snowflake.local:9876?account=ac&protocol=http&authenticator=SNOWFLAKE_JWT&privateKey=%v", privKeyPKCS8), config: &Config{ Account: "ac", User: "u", Password: "p", Authenticator: 
AuthTypeJwt, PrivateKey: testPrivKey, Protocol: "http", Host: "ac.snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: fmt.Sprintf("u:p@ac.snowflake.local:9876?account=ac&protocol=http&authenticator=%v", url.QueryEscape("https://ac.okta.com")), config: &Config{ Account: "ac", User: "u", Password: "p", Authenticator: AuthTypeOkta, OktaURL: &url.URL{ Scheme: "https", Host: "ac.okta.com", }, PrivateKey: testPrivKey, Protocol: "http", Host: "ac.snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: fmt.Sprintf("u:p@ac.snowflake.local:9876?account=ac&protocol=http&authenticator=%v", url.QueryEscape("https://ac.some-host.com/custom-okta-url")), config: &Config{ Account: "ac", User: "u", Password: "p", Authenticator: AuthTypeOkta, OktaURL: &url.URL{ Scheme: "https", Host: "ac.some-host.com", Path: "/custom-okta-url", }, PrivateKey: testPrivKey, Protocol: "http", Host: "ac.snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: 
fmt.Sprintf("u:p@a.snowflake.local:9876?account=a&protocol=http&authenticator=SNOWFLAKE_JWT&privateKey=%v", privKeyPKCS1), config: &Config{ Account: "a", User: "u", Password: "p", Authenticator: AuthTypeJwt, PrivateKey: testPrivKey, Protocol: "http", Host: "a.snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: &sferrors.SnowflakeError{Number: sferrors.ErrCodePrivateKeyParseError}, }, { dsn: "user:pass@account/db/s?ocspFailOpen=true", config: &Config{ Account: "account", User: "user", Password: "pass", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "user:pass@account/db/s?ocspFailOpen=false", config: &Config{ Account: "account", User: "user", Password: "pass", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", OCSPFailOpen: OCSPFailOpenFalse, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailClosed, err: nil, }, { dsn: "user:pass@account/db/s?validateDefaultParameters=true", config: &Config{ 
Account: "account", User: "user", Password: "pass", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: BoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "user:pass@account/db/s?validateDefaultParameters=false", config: &Config{ Account: "account", User: "user", Password: "pass", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: BoolFalse, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&validateDefaultParameters=false", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: BoolFalse, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&clientTimeout=300&jwtClientTimeout=45&includeRetryReason=false", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, 
Database: "db", Schema: "s", ValidateDefaultParameters: BoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: 300 * time.Second, JWTClientTimeout: 45 * time.Second, ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, DisableQueryContextCache: false, IncludeRetryReason: BoolFalse, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&serverSessionKeepAlive=false", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: BoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&serverSessionKeepAlive=true", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: BoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ServerSessionKeepAlive: true, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&tmpDirPath=%2Ftmp", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: BoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: 
time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, TmpDirPath: "/tmp", IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&disableQueryContextCache=true", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: BoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, DisableQueryContextCache: true, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&includeRetryReason=true", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: BoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&includeRetryReason=true&clientConfigFile=%2FUsers%2Fuser%2Fconfig.json", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: BoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: 
time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, ClientConfigFile: "/Users/user/config.json", }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&includeRetryReason=true&clientConfigFile=c%3A%5CUsers%5Cuser%5Cconfig.json", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: BoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, ClientConfigFile: "c:\\Users\\user\\config.json", }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.snowflakecomputing.com:443?authenticator=http%3A%2F%2Fsc.okta.com&ocspFailOpen=true&validateDefaultParameters=true", err: sferrors.ErrFailedToParseAuthenticator(), }, { dsn: "u:p@a.snowflake.local:9876?account=a&protocol=http&authenticator=EXTERNALBROWSER&disableConsoleLogin=true", config: &Config{ Account: "a", User: "u", Password: "p", Authenticator: AuthTypeExternalBrowser, Protocol: "http", Host: "a.snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, DisableConsoleLogin: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.snowflake.local:9876?account=a&protocol=http&authenticator=EXTERNALBROWSER&disableConsoleLogin=false", config: &Config{ Account: "a", User: 
"u", Password: "p", Authenticator: AuthTypeExternalBrowser, Protocol: "http", Host: "a.snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, DisableConsoleLogin: BoolFalse, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.snowflake.local:9876?account=a&protocol=http&authenticator=EXTERNALBROWSER&disableSamlURLCheck=true", config: &Config{ Account: "a", User: "u", Password: "p", Authenticator: AuthTypeExternalBrowser, Protocol: "http", Host: "a.snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, DisableSamlURLCheck: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.snowflake.local:9876?account=a&protocol=http&authenticator=EXTERNALBROWSER&disableSamlURLCheck=false", config: &Config{ Account: "a", User: "u", Password: "p", Authenticator: AuthTypeExternalBrowser, Protocol: "http", Host: "a.snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, DisableSamlURLCheck: BoolFalse, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: 
"u:p@a.snowflake.local:9876?account=a&protocol=http&authenticator=PROGRAMMATIC_ACCESS_TOKEN&disableSamlURLCheck=false&token=t", config: &Config{ Account: "a", User: "u", Password: "p", Authenticator: AuthTypePat, Protocol: "http", Host: "a.snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, DisableSamlURLCheck: BoolFalse, Token: "t", }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.snowflake.local:9876?account=a&protocol=http&authenticator=PROGRAMMATIC_ACCESS_TOKEN&disableSamlURLCheck=false&tokenFilePath=..%2F..%2Ftest_data%2Fsnowflake%2Fsession%2Ftoken", config: &Config{ Account: "a", User: "u", Password: "p", Authenticator: AuthTypePat, Protocol: "http", Host: "a.snowflake.local", Port: 9876, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, DisableSamlURLCheck: BoolFalse, TokenFilePath: "../../test_data/snowflake/session/token", }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.snowflake.local:9876?account=a&certRevocationCheckMode=enabled&crlAllowCertificatesWithoutCrlURL=true&crlInMemoryCacheDisabled=true&crlOnDiskCacheDisabled=true&crlDownloadMaxSize=10&crlHttpClientTimeout=10", config: &Config{ Account: "a", User: "u", Password: "p", Host: "a.snowflake.local", Port: 9876, Protocol: "https", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), 
ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, CertRevocationCheckMode: CertRevocationCheckEnabled, CrlAllowCertificatesWithoutCrlURL: BoolTrue, CrlInMemoryCacheDisabled: true, CrlOnDiskCacheDisabled: true, CrlDownloadMaxSize: 10, CrlHTTPClientTimeout: 10 * time.Second, }, ocspMode: ocspModeFailOpen, }, { dsn: "user:pass@account/db?tlsConfigName=custom", err: &sferrors.SnowflakeError{ Number: sferrors.ErrCodeMissingTLSConfig, Message: fmt.Sprintf(sferrors.ErrMsgMissingTLSConfig, "custom"), }, }, { dsn: "u:p@a.snowflake.local:9876?account=a&&singleAuthenticationPrompt=true", config: &Config{ Account: "a", User: "u", Password: "p", Host: "a.snowflake.local", Port: 9876, Protocol: "https", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, SingleAuthenticationPrompt: BoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.snowflake.local:9876?account=a&&singleAuthenticationPrompt=false", config: &Config{ Account: "a", User: "u", Password: "p", Host: "a.snowflake.local", Port: 9876, Protocol: "https", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, SingleAuthenticationPrompt: BoolFalse, }, ocspMode: ocspModeFailOpen, err: nil, }, { dsn: "u:p@a.snowflake.local:9876?account=a&tracing=debug&logQueryText=true&logQueryParameters=true", config: &Config{ Account: "a", User: "u", Password: "p", Host: 
"a.snowflake.local", Port: 9876, Protocol: "https", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, Tracing: "debug", IncludeRetryReason: BoolTrue, LogQueryText: true, LogQueryParameters: true, }, ocspMode: ocspModeFailOpen, }, } for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} { testcases = append(testcases, tcParseDSN{ dsn: fmt.Sprintf("@host:777/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())), config: &Config{ Account: "ac", User: "", Password: "", Protocol: "http", Host: "host", Port: 777, Database: "db", Schema: "schema", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, Authenticator: at, }, ocspMode: ocspModeFailOpen, err: nil, }) } for _, at := range []AuthType{AuthTypeSnowflake, AuthTypeUsernamePasswordMFA, AuthTypeJwt} { testcases = append(testcases, tcParseDSN{ dsn: fmt.Sprintf("@host:888/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())), config: &Config{ Account: "ac", User: "", Password: "", Protocol: "http", Host: "host", Port: 888, Database: "db", Schema: "schema", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, Authenticator: at, }, ocspMode: ocspModeFailOpen, err: 
sferrors.ErrEmptyUsername(), }) } for _, at := range []AuthType{AuthTypeSnowflake, AuthTypeUsernamePasswordMFA} { testcases = append(testcases, tcParseDSN{ dsn: fmt.Sprintf("user@host:888/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())), config: &Config{ Account: "ac", User: "user", Password: "", Protocol: "http", Host: "host", Port: 888, Database: "db", Schema: "schema", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: BoolTrue, ClientTimeout: time.Duration(DefaultClientTimeout), JWTClientTimeout: time.Duration(DefaultJWTClientTimeout), ExternalBrowserTimeout: time.Duration(DefaultExternalBrowserTimeout), CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: BoolTrue, Authenticator: at, }, ocspMode: ocspModeFailOpen, err: sferrors.ErrEmptyPassword(), }) } for i, test := range testcases { t.Run(maskSecrets(test.dsn), func(t *testing.T) { cfg, err := ParseDSN(test.dsn) switch { case test.err == nil: assertNilF(t, err, fmt.Sprintf("%d: Failed to parse the DSN. 
dsn: %v", i, test.dsn)) assertEqualE(t, cfg.Host, test.config.Host, fmt.Sprintf("Test %d: Host mismatch", i)) assertEqualE(t, cfg.Account, test.config.Account, fmt.Sprintf("Test %d: Account mismatch", i)) assertEqualE(t, cfg.User, test.config.User, fmt.Sprintf("Test %d: User mismatch", i)) assertEqualE(t, cfg.Password, test.config.Password, fmt.Sprintf("Test %d: Password mismatch", i)) assertEqualE(t, cfg.Database, test.config.Database, fmt.Sprintf("Test %d: Database mismatch", i)) assertEqualE(t, cfg.Schema, test.config.Schema, fmt.Sprintf("Test %d: Schema mismatch", i)) assertEqualE(t, cfg.Warehouse, test.config.Warehouse, fmt.Sprintf("Test %d: Warehouse mismatch", i)) assertEqualE(t, cfg.Role, test.config.Role, fmt.Sprintf("Test %d: Role mismatch", i)) assertEqualE(t, cfg.Region, test.config.Region, fmt.Sprintf("Test %d: Region mismatch", i)) assertEqualE(t, cfg.Protocol, test.config.Protocol, fmt.Sprintf("Test %d: Protocol mismatch", i)) assertEqualE(t, cfg.Passcode, test.config.Passcode, fmt.Sprintf("Test %d: Passcode mismatch", i)) assertEqualE(t, cfg.PasscodeInPassword, test.config.PasscodeInPassword, fmt.Sprintf("Test %d: PasscodeInPassword mismatch", i)) assertEqualE(t, cfg.Authenticator, test.config.Authenticator, fmt.Sprintf("Test %d: Authenticator mismatch", i)) assertEqualE(t, cfg.SingleAuthenticationPrompt, test.config.SingleAuthenticationPrompt, fmt.Sprintf("Test %d: SingleAuthenticationPrompt mismatch", i)) if test.config.Authenticator == AuthTypeOkta { assertEqualE(t, *cfg.OktaURL, *test.config.OktaURL, fmt.Sprintf("Test %d: OktaURL mismatch", i)) } assertEqualE(t, cfg.OCSPFailOpen, test.config.OCSPFailOpen, fmt.Sprintf("Test %d: OCSPFailOpen mismatch", i)) assertEqualE(t, OcspMode(cfg), test.ocspMode, fmt.Sprintf("Test %d: OCSPMode mismatch", i)) assertEqualE(t, cfg.ValidateDefaultParameters, test.config.ValidateDefaultParameters, fmt.Sprintf("Test %d: ValidateDefaultParameters mismatch", i)) assertEqualE(t, cfg.ClientTimeout, 
test.config.ClientTimeout, fmt.Sprintf("Test %d: ClientTimeout mismatch", i)) assertEqualE(t, cfg.JWTClientTimeout, test.config.JWTClientTimeout, fmt.Sprintf("Test %d: JWTClientTimeout mismatch", i)) assertEqualE(t, cfg.ExternalBrowserTimeout, test.config.ExternalBrowserTimeout, fmt.Sprintf("Test %d: ExternalBrowserTimeout mismatch", i)) assertEqualE(t, cfg.CloudStorageTimeout, test.config.CloudStorageTimeout, fmt.Sprintf("Test %d: CloudStorageTimeout mismatch", i)) assertEqualE(t, cfg.TmpDirPath, test.config.TmpDirPath, fmt.Sprintf("Test %d: TmpDirPath mismatch", i)) assertEqualE(t, cfg.DisableQueryContextCache, test.config.DisableQueryContextCache, fmt.Sprintf("Test %d: DisableQueryContextCache mismatch", i)) assertEqualE(t, cfg.IncludeRetryReason, test.config.IncludeRetryReason, fmt.Sprintf("Test %d: IncludeRetryReason mismatch", i)) assertEqualE(t, cfg.ServerSessionKeepAlive, test.config.ServerSessionKeepAlive, fmt.Sprintf("Test %d: ServerSessionKeepAlive mismatch", i)) assertEqualE(t, cfg.DisableConsoleLogin, test.config.DisableConsoleLogin, fmt.Sprintf("Test %d: DisableConsoleLogin mismatch", i)) assertEqualE(t, cfg.DisableSamlURLCheck, test.config.DisableSamlURLCheck, fmt.Sprintf("Test %d: DisableSamlURLCheck mismatch", i)) assertEqualE(t, cfg.OauthClientID, test.config.OauthClientID, fmt.Sprintf("Test %d: OauthClientID mismatch", i)) assertEqualE(t, cfg.OauthClientSecret, test.config.OauthClientSecret, fmt.Sprintf("Test %d: OauthClientSecret mismatch", i)) assertEqualE(t, cfg.OauthAuthorizationURL, test.config.OauthAuthorizationURL, fmt.Sprintf("Test %d: OauthAuthorizationURL mismatch", i)) assertEqualE(t, cfg.OauthTokenRequestURL, test.config.OauthTokenRequestURL, fmt.Sprintf("Test %d: OauthTokenRequestURL mismatch", i)) assertEqualE(t, cfg.OauthRedirectURI, test.config.OauthRedirectURI, fmt.Sprintf("Test %d: OauthRedirectURI mismatch", i)) assertEqualE(t, cfg.OauthScope, test.config.OauthScope, fmt.Sprintf("Test %d: OauthScope mismatch", i)) 
assertEqualE(t, cfg.EnableSingleUseRefreshTokens, test.config.EnableSingleUseRefreshTokens, fmt.Sprintf("Test %d: EnableSingleUseRefreshTokens mismatch", i)) assertEqualE(t, cfg.Token, test.config.Token, "token") assertEqualE(t, cfg.ClientConfigFile, test.config.ClientConfigFile, "client config file") assertEqualE(t, cfg.CertRevocationCheckMode, test.config.CertRevocationCheckMode, "cert revocation check mode") assertEqualE(t, cfg.CrlAllowCertificatesWithoutCrlURL, test.config.CrlAllowCertificatesWithoutCrlURL, "crl allow certificates without crl url") assertEqualE(t, cfg.CrlInMemoryCacheDisabled, test.config.CrlInMemoryCacheDisabled, "crl in memory cache disabled") assertEqualE(t, cfg.CrlOnDiskCacheDisabled, test.config.CrlOnDiskCacheDisabled, "crl on disk cache disabled") assertEqualE(t, cfg.CrlHTTPClientTimeout, test.config.CrlHTTPClientTimeout, "crl http client timeout") case test.err != nil: driverErrE, okE := test.err.(*sferrors.SnowflakeError) driverErrG, okG := err.(*sferrors.SnowflakeError) if okE && !okG || !okE && okG { t.Fatalf("%d: Wrong error. expected: %v, got: %v", i, test.err, err) } if okE && okG { if driverErrE.Number != driverErrG.Number { t.Fatalf("%d: Wrong error number. expected: %v, got: %v", i, driverErrE.Number, driverErrG.Number) } } else { t1 := reflect.TypeOf(err) t2 := reflect.TypeOf(test.err) if t1 != t2 { t.Fatalf("%d: Wrong error. 
expected: %T:%v, got: %T:%v", i, test.err, test.err, err, err) } } } }) } } /* tcDSN holds one DSN-serialization test case for TestDSN: the input Config, the DSN string expected to be produced from it, and the expected error (nil when serialization should succeed). */ type tcDSN struct { cfg *Config /* input configuration to serialize */ dsn string /* expected DSN output for cfg */ err error /* expected error; nil on success */ } func TestDSN(t *testing.T) { tmfmt := "MM-DD-YYYY" testcases := []tcDSN{ { cfg: &Config{ User: "u", Password: "p", Account: "a-aofnadsf.somewhere.azure", }, dsn: "u:p@a-aofnadsf.somewhere.azure.snowflakecomputing.com:443?ocspFailOpen=true®ion=somewhere.azure&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a-aofnadsf.global", }, dsn: "u:p@a-aofnadsf.global.snowflakecomputing.com:443?ocspFailOpen=true®ion=global&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a-aofnadsf.global", Region: "us-west-2", }, dsn: "u:p@a-aofnadsf.global.snowflakecomputing.com:443?ocspFailOpen=true®ion=global&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "account-name", Region: "cn-region", }, dsn: "u:p@account-name.cn-region.snowflakecomputing.cn:443?ocspFailOpen=true®ion=cn-region&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "account-name.cn-region", }, dsn: "u:p@account-name.cn-region.snowflakecomputing.cn:443?ocspFailOpen=true®ion=cn-region&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "account-name.cn-region", Host: "account-name.cn-region.snowflakecomputing.cn", }, dsn: "u:p@account-name.cn-region.snowflakecomputing.cn:443?account=account-name&ocspFailOpen=true®ion=cn-region&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "account.us-west-2", }, dsn: "u:p@account.snowflakecomputing.com:443?ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "account_us-west-2", }, dsn: "u:p@account_us-west-2.snowflakecomputing.com:443?ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "account-name", Host: 
"account-name.snowflakecomputing.mil", }, dsn: "u:p@account-name.snowflakecomputing.mil:443?account=account-name&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "account-name", Host: "account-name.snowflakecomputing.gov.pl", }, dsn: "u:p@account-name.snowflakecomputing.gov.pl:443?account=account-name&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a-aofnadsf.global", Region: "r", }, err: sferrors.ErrRegionConflict(), }, { cfg: &Config{ User: "u", Password: "p", Account: "a", }, dsn: "u:p@a.snowflakecomputing.com:443?ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", Region: "us-west-2", }, dsn: "u:p@a.snowflakecomputing.com:443?ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", Region: "r", }, dsn: "u:p@a.r.snowflakecomputing.com:443?ocspFailOpen=true®ion=r&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", Region: "r", OauthClientID: "testClientId", OauthClientSecret: "testClientSecret", OauthAuthorizationURL: "http://somehost.com", OauthTokenRequestURL: "https://somehost2.com/somepath", OauthRedirectURI: "http://localhost:8001/some-path", OauthScope: "test scope", }, dsn: "u:p@a.r.snowflakecomputing.com:443?oauthAuthorizationUrl=http%3A%2F%2Fsomehost.com&oauthClientId=testClientId&oauthClientSecret=testClientSecret&oauthRedirectUri=http%3A%2F%2Flocalhost%3A8001%2Fsome-path&oauthScope=test+scope&oauthTokenRequestUrl=https%3A%2F%2Fsomehost2.com%2Fsomepath&ocspFailOpen=true®ion=r&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", Region: "r", OauthClientID: "testClientId", OauthClientSecret: "testClientSecret", OauthAuthorizationURL: "http://somehost.com", OauthTokenRequestURL: "https://somehost2.com/somepath", OauthRedirectURI: 
"http://localhost:8001/some-path", OauthScope: "test scope", EnableSingleUseRefreshTokens: true, }, dsn: "u:p@a.r.snowflakecomputing.com:443?enableSingleUseRefreshTokens=true&oauthAuthorizationUrl=http%3A%2F%2Fsomehost.com&oauthClientId=testClientId&oauthClientSecret=testClientSecret&oauthRedirectUri=http%3A%2F%2Flocalhost%3A8001%2Fsome-path&oauthScope=test+scope&oauthTokenRequestUrl=https%3A%2F%2Fsomehost2.com%2Fsomepath&ocspFailOpen=true®ion=r&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", Region: "r", ExternalBrowserTimeout: 20 * time.Second, CloudStorageTimeout: 7 * time.Second, }, dsn: "u:p@a.r.snowflakecomputing.com:443?cloudStorageTimeout=7&externalBrowserTimeout=20&ocspFailOpen=true®ion=r&validateDefaultParameters=true", }, { cfg: &Config{ User: "", Password: "p", Account: "a", }, err: sferrors.ErrEmptyUsername(), }, { cfg: &Config{ User: "u", Password: "", Account: "a", }, err: sferrors.ErrEmptyPassword(), }, { cfg: &Config{ User: "u", Password: "p", Account: "", }, err: sferrors.ErrEmptyAccount(), }, { cfg: &Config{ User: "u", Password: "p", Account: "ac", Authenticator: AuthTypeOAuthAuthorizationCode, }, err: sferrors.ErrEmptyOAuthParameters(), }, { cfg: &Config{ User: "u", Password: "p", Account: "a.e", }, dsn: "u:p@a.e.snowflakecomputing.com:443?ocspFailOpen=true®ion=e&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.e", Region: "us-west-2", }, dsn: "u:p@a.e.snowflakecomputing.com:443?ocspFailOpen=true®ion=e&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.e", Region: "r", }, err: sferrors.ErrRegionConflict(), }, { cfg: &Config{ User: "u", Password: "p", Account: "a", Database: "db", Schema: "sc", Role: "ro", Region: "b", Authenticator: AuthTypeSnowflake, Passcode: "db", PasscodeInPassword: true, LoginTimeout: 10 * time.Second, RequestTimeout: 300 * time.Second, Application: "special go", }, dsn: 
"u:p@a.b.snowflakecomputing.com:443?application=special+go&database=db&loginTimeout=10&ocspFailOpen=true&passcode=db&passcodeInPassword=true®ion=b&requestTimeout=300&role=ro&schema=sc&validateDefaultParameters=true", }, { cfg: &Config{ Account: "ac", User: "u", Password: "p", Database: "db", Authenticator: AuthTypeWorkloadIdentityFederation, Host: "ac.snowflakecomputing.com", WorkloadIdentityProvider: "azure", WorkloadIdentityEntraResource: "https://example.com/default", WorkloadIdentityImpersonationPath: []string{"/default", "/default2"}, }, dsn: "u:p@ac.snowflakecomputing.com:443?account=ac&authenticator=workload_identity&database=db&ocspFailOpen=true&validateDefaultParameters=true&workloadIdentityEntraResource=https%3A%2F%2Fexample.com%2Fdefault&workloadIdentityImpersonationPath=%2Fdefault%2C%2Fdefault2&workloadIdentityProvider=azure", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", Authenticator: AuthTypeExternalBrowser, ClientStoreTemporaryCredential: BoolTrue, }, dsn: "u:p@a.snowflakecomputing.com:443?authenticator=externalbrowser&clientStoreTemporaryCredential=true&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", Authenticator: AuthTypeExternalBrowser, ClientStoreTemporaryCredential: BoolFalse, }, dsn: "u:p@a.snowflakecomputing.com:443?authenticator=externalbrowser&clientStoreTemporaryCredential=false&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", Token: "t", Authenticator: AuthTypePat, ClientStoreTemporaryCredential: BoolFalse, }, dsn: "u:p@a.snowflakecomputing.com:443?authenticator=programmatic_access_token&clientStoreTemporaryCredential=false&ocspFailOpen=true&token=t&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", TokenFilePath: "../../test_data/snowflake/session/token", Authenticator: AuthTypePat, ClientStoreTemporaryCredential: BoolFalse, }, dsn: 
"u:p@a.snowflakecomputing.com:443?authenticator=programmatic_access_token&clientStoreTemporaryCredential=false&ocspFailOpen=true&tokenFilePath=..%2F..%2Ftest_data%2Fsnowflake%2Fsession%2Ftoken&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", Authenticator: AuthTypeOAuthAuthorizationCode, OauthClientID: "testClientId", OauthClientSecret: "testClientSecret", ClientStoreTemporaryCredential: BoolFalse, }, dsn: "u:p@a.snowflakecomputing.com:443?authenticator=oauth_authorization_code&clientStoreTemporaryCredential=false&oauthClientId=testClientId&oauthClientSecret=testClientSecret&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", Authenticator: AuthTypeOAuthClientCredentials, ClientStoreTemporaryCredential: BoolFalse, }, dsn: "u:p@a.snowflakecomputing.com:443?authenticator=oauth_client_credentials&clientStoreTemporaryCredential=false&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", Authenticator: AuthTypeOkta, OktaURL: &url.URL{ Scheme: "https", Host: "sc.okta.com", }, }, dsn: "u:p@a.snowflakecomputing.com:443?authenticator=https%3A%2F%2Fsc.okta.com&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.e", Params: map[string]*string{ "TIMESTAMP_OUTPUT_FORMAT": &tmfmt, }, }, dsn: "u:p@a.e.snowflakecomputing.com:443?TIMESTAMP_OUTPUT_FORMAT=MM-DD-YYYY&ocspFailOpen=true®ion=e&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: ":@abc", Account: "a.e", Params: map[string]*string{ "TIMESTAMP_OUTPUT_FORMAT": &tmfmt, }, }, dsn: "u:%3A%40abc@a.e.snowflakecomputing.com:443?TIMESTAMP_OUTPUT_FORMAT=MM-DD-YYYY&ocspFailOpen=true®ion=e&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", OCSPFailOpen: OCSPFailOpenTrue, }, dsn: "u:p@a.snowflakecomputing.com:443?ocspFailOpen=true&validateDefaultParameters=true", }, 
{ cfg: &Config{ User: "u", Password: "p", Account: "a", OCSPFailOpen: OCSPFailOpenFalse, }, dsn: "u:p@a.snowflakecomputing.com:443?ocspFailOpen=false&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", ValidateDefaultParameters: BoolFalse, }, dsn: "u:p@a.snowflakecomputing.com:443?ocspFailOpen=true&validateDefaultParameters=false", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", ValidateDefaultParameters: BoolTrue, }, dsn: "u:p@a.snowflakecomputing.com:443?ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", DisableOCSPChecks: true, }, dsn: "u:p@a.snowflakecomputing.com:443?disableOCSPChecks=true&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", DisableOCSPChecks: true, }, dsn: "u:p@a.snowflakecomputing.com:443?disableOCSPChecks=true&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a", DisableOCSPChecks: true, ConnectionDiagnosticsEnabled: true, }, dsn: "u:p@a.snowflakecomputing.com:443?connectionDiagnosticsEnabled=true&disableOCSPChecks=true&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "account.snowflakecomputing.com", }, dsn: "u:p@account.snowflakecomputing.com.snowflakecomputing.com:443?ocspFailOpen=true®ion=snowflakecomputing.com&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", Region: "us-west-2", }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", Region: "r", }, err: sferrors.ErrRegionConflict(), }, { cfg: &Config{ User: "u", Password: "p", Account: 
"a.b.c", ClientTimeout: 400 * time.Second, JWTClientTimeout: 60 * time.Second, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=400&jwtClientTimeout=60&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", ClientTimeout: 400 * time.Second, JWTExpireTimeout: 30 * time.Second, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=400&jwtTimeout=30&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", Protocol: "http", }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true&protocol=http®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", Tracing: "debug", LogQueryText: true, LogQueryParameters: true, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?logQueryParameters=true&logQueryText=true&ocspFailOpen=true®ion=b.c&tracing=debug&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", Authenticator: AuthTypeUsernamePasswordMFA, ClientRequestMfaToken: BoolTrue, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=username_password_mfa&clientRequestMfaToken=true&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", Authenticator: AuthTypeUsernamePasswordMFA, ClientRequestMfaToken: BoolFalse, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=username_password_mfa&clientRequestMfaToken=false&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", Warehouse: "wh", }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&validateDefaultParameters=true&warehouse=wh", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", Token: "t", }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&token=t&validateDefaultParameters=true", }, { cfg: 
&Config{ User: "u", Password: "p", Account: "a.b.c", Authenticator: AuthTypeTokenAccessor, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=tokenaccessor&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", TmpDirPath: "/tmp", }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&tmpDirPath=%2Ftmp&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", IncludeRetryReason: BoolFalse, MaxRetryCount: 30, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?includeRetryReason=false&maxRetryCount=30&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", ServerSessionKeepAlive: true, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&serverSessionKeepAlive=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", DisableQueryContextCache: true, IncludeRetryReason: BoolTrue, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?disableQueryContextCache=true&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", IncludeRetryReason: BoolFalse, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?includeRetryReason=false&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", IncludeRetryReason: BoolTrue, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", IncludeRetryReason: BoolTrue, ClientConfigFile: "/Users/user/config.json", }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientConfigFile=%2FUsers%2Fuser%2Fconfig.json&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", IncludeRetryReason: BoolTrue, ClientConfigFile: 
"c:\\Users\\user\\config.json", }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientConfigFile=c%3A%5CUsers%5Cuser%5Cconfig.json&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", Authenticator: AuthTypeExternalBrowser, DisableConsoleLogin: BoolTrue, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=externalbrowser&disableConsoleLogin=true&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", Authenticator: AuthTypeExternalBrowser, DisableConsoleLogin: BoolFalse, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=externalbrowser&disableConsoleLogin=false&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", Authenticator: AuthTypeExternalBrowser, DisableSamlURLCheck: BoolTrue, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=externalbrowser&disableSamlURLCheck=true&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", Authenticator: AuthTypeExternalBrowser, DisableSamlURLCheck: BoolFalse, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=externalbrowser&disableSamlURLCheck=false&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", CertRevocationCheckMode: CertRevocationCheckEnabled, CrlAllowCertificatesWithoutCrlURL: BoolTrue, CrlInMemoryCacheDisabled: true, CrlOnDiskCacheDisabled: true, CrlDownloadMaxSize: 10, CrlHTTPClientTimeout: 5 * time.Second, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?certRevocationCheckMode=ENABLED&crlAllowCertificatesWithoutCrlURL=true&crlDownloadMaxSize=10&crlHttpClientTimeout=5&crlInMemoryCacheDisabled=true&crlOnDiskCacheDisabled=true&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", 
TLSConfigName: "custom", }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&tlsConfigName=custom&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", SingleAuthenticationPrompt: BoolTrue, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&singleAuthenticationPrompt=true&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", SingleAuthenticationPrompt: BoolFalse, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&singleAuthenticationPrompt=false&validateDefaultParameters=true", }, } for _, test := range testcases { t.Run(maskSecrets(test.dsn), func(t *testing.T) { if test.cfg.TLSConfigName != "" && test.err == nil { err := RegisterTLSConfig(test.cfg.TLSConfigName, &tls.Config{}) assertNilF(t, err, "Failed to register test TLS config") defer func() { _ = DeregisterTLSConfig(test.cfg.TLSConfigName) }() } dsn, err := DSN(test.cfg) if test.err == nil && err == nil { assertEqualF(t, dsn, test.dsn, fmt.Sprintf("failed to get DSN. expected: %v, got:\n %v", maskSecrets(test.dsn), maskSecrets(dsn))) _, err := ParseDSN(dsn) assertNilF(t, err, "failed to parse DSN. dsn:", dsn) } if test.err != nil { assertNotNilF(t, err, fmt.Sprintf("expected error. 
dsn: %v, expected err: %v", maskSecrets(test.dsn), maskSecrets(test.err.Error())))
			}
			if test.err == nil {
				assertNilF(t, err, "failed to match")
			}
		})
	}
}

// TestParsePrivateKeyFromFileMissingFile verifies that a nonexistent path is rejected.
func TestParsePrivateKeyFromFileMissingFile(t *testing.T) {
	_, err := parsePrivateKeyFromFile("nonexistent")
	if err == nil {
		t.Error("should report error for nonexistent file")
	}
}

// TestParsePrivateKeyFromFileIncorrectData verifies that non-PEM content is rejected.
func TestParsePrivateKeyFromFileIncorrectData(t *testing.T) {
	pemFile := createTmpFile(t, "exampleKey.pem", []byte("gibberish"))
	_, err := parsePrivateKeyFromFile(pemFile)
	if err == nil {
		t.Error("should report error for wrong data in file")
	}
}

// TestParsePrivateKeyFromFileNotRSAPrivateKey verifies that a valid but
// non-RSA (ECDSA) private key is rejected.
func TestParsePrivateKeyFromFileNotRSAPrivateKey(t *testing.T) {
	// Generate an ECDSA private key for testing
	ecdsaPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("failed to generate ECDSA private key: %v", err)
	}
	ecdsaPrivateKeyBytes, err := x509.MarshalECPrivateKey(ecdsaPrivateKey)
	if err != nil {
		t.Fatalf("failed to marshal ECDSA private key: %v", err)
	}
	pemBlock := &pem.Block{
		Type:  "EC PRIVATE KEY",
		Bytes: ecdsaPrivateKeyBytes,
	}
	pemData := pem.EncodeToMemory(pemBlock)
	// Write the PEM data to a temporary file
	pemFile := createTmpFile(t, "ecdsaKey.pem", pemData)
	// Attempt to parse the private key
	_, err = parsePrivateKeyFromFile(pemFile)
	if err == nil {
		t.Error("expected an error when trying to parse an ECDSA private key as RSA")
	}
}

// TestParsePrivateKeyFromFile round-trips an RSA key through a PEM file and
// checks that the parsed key equals the generated one.
func TestParsePrivateKeyFromFile(t *testing.T) {
	// Check the generation/marshaling errors instead of silently discarding
	// them as the original code did.
	generatedKey, err := rsa.GenerateKey(cr.Reader, 1024)
	if err != nil {
		t.Fatalf("failed to generate RSA private key: %v", err)
	}
	pemKey, err := x509.MarshalPKCS8PrivateKey(generatedKey)
	if err != nil {
		t.Fatalf("failed to marshal RSA private key: %v", err)
	}
	pemData := pem.EncodeToMemory(
		&pem.Block{
			Type:  "RSA PRIVATE KEY",
			Bytes: pemKey,
		},
	)
	keyFile := createTmpFile(t, "exampleKey.pem", pemData)
	defer os.Remove(keyFile)
	parsedKey, err := parsePrivateKeyFromFile(keyFile)
	if err != nil {
		// fixed typo: "pam file" -> "pem file"
		t.Errorf("unable to parse pem file from path: %v, err: %v", keyFile, err)
	} else if !parsedKey.Equal(generatedKey) {
		t.Errorf("generated key does not equal to parsed key from file\ngeneratedKey=%v\nparsedKey=%v", generatedKey,
parsedKey) } } func createTmpFile(t *testing.T, fileName string, content []byte) string { tempFile, _ := os.CreateTemp("", fileName) _, err := tempFile.Write(content) assertNilF(t, err) absolutePath := tempFile.Name() return absolutePath } type configParamToValue struct { configParam string value string } func TestGetConfigFromEnv(t *testing.T) { envMap := map[string]configParamToValue{ "SF_TEST_ACCOUNT": {"Account", "account"}, "SF_TEST_USER": {"User", "user"}, "SF_TEST_PASSWORD": {"Password", "password"}, "SF_TEST_ROLE": {"Role", "role"}, "SF_TEST_HOST": {"Host", "host"}, "SF_TEST_PORT": {"Port", "8080"}, "SF_TEST_PROTOCOL": {"Protocol", "http"}, "SF_TEST_WAREHOUSE": {"Warehouse", "warehouse"}, "SF_TEST_DATABASE": {"Database", "database"}, "SF_TEST_REGION": {"Region", "region"}, "SF_TEST_PASSCODE": {"Passcode", "passcode"}, "SF_TEST_SCHEMA": {"Schema", "schema"}, "SF_TEST_APPLICATION": {"Application", "application"}, } var properties = make([]*Param, len(envMap)) i := 0 for key, ctv := range envMap { os.Setenv(key, ctv.value) cfgParam := Param{Name: ctv.configParam, EnvName: key, FailOnMissing: true} properties[i] = &cfgParam i++ } defer func() { for key := range envMap { os.Unsetenv(key) } }() cfg, err := GetConfigFromEnv(properties) if err != nil { t.Errorf("unable to parse env variables to Config, err: %v", err) } err = checkConfig(*cfg, envMap) if err != nil { t.Error(err) } } func checkConfig(cfg Config, envMap map[string]configParamToValue) error { appendError := func(errArray []string, envName string, expected string, received string) []string { errArray = append(errArray, fmt.Sprintf("field %v expected value: %v, received value: %v", envName, expected, received)) return errArray } value := reflect.ValueOf(cfg) typeOfCfg := value.Type() cfgValues := make(map[string]any, value.NumField()) for i := 0; i < value.NumField(); i++ { if value.Field(i).CanInterface() { cfgValues[typeOfCfg.Field(i).Name] = value.Field(i).Interface() } } var errArray []string for 
key, ctv := range envMap { if ctv.configParam == "Port" { if portStr := strconv.Itoa(cfgValues[ctv.configParam].(int)); portStr != ctv.value { errArray = appendError(errArray, key, ctv.value, cfgValues[ctv.configParam].(string)) } } else if cfgValues[ctv.configParam] != ctv.value { errArray = appendError(errArray, key, ctv.value, cfgValues[ctv.configParam].(string)) } } if errArray != nil { return errors.New(strings.Join(errArray, "\n")) } return nil } func TestConfigValidateTmpDirPath(t *testing.T) { cfg := &Config{ TmpDirPath: "/not/existing", } if err := cfg.Validate(); err == nil { t.Fatalf("Should fail on not existing TmpDirPath") } } func TestExtractAccountName(t *testing.T) { testcases := map[string]string{ "myaccount": "MYACCOUNT", "myaccount.eu-central-1": "MYACCOUNT", "myaccount.eu-central-1.privatelink": "MYACCOUNT", "myorg-myaccount": "MYORG-MYACCOUNT", "myorg-myaccount.privatelink": "MYORG-MYACCOUNT", "myorg-my-account": "MYORG-MY-ACCOUNT", "myorg-my-account.privatelink": "MYORG-MY-ACCOUNT", "myorg-my_account": "MYORG-MY_ACCOUNT", "myorg-my_account.privatelink": "MYORG-MY_ACCOUNT", } for account, expected := range testcases { t.Run(account, func(t *testing.T) { accountPart := ExtractAccountName(account) if accountPart != expected { t.Fatalf("extractAccountName returned unexpected response (%v), should be %v", accountPart, expected) } }) } } func TestUrlDecodeIfNeeded(t *testing.T) { testcases := map[string]string{ "query_tag": "query_tag", "%24my_custom_variable": "$my_custom_variable", } for param, expected := range testcases { t.Run(param, func(t *testing.T) { decodedParam := urlDecodeIfNeeded(param) assertEqualE(t, decodedParam, expected) }) } } func TestDSNParsingWithTLSConfig(t *testing.T) { // Clean up any existing registry ResetTLSConfigRegistry() // Register test TLS config testTLSConfig := tls.Config{ InsecureSkipVerify: true, ServerName: "custom.test.com", } err := RegisterTLSConfig("custom", &testTLSConfig) assertNilF(t, err, "Failed to 
register test TLS config")
	defer func() {
		err := DeregisterTLSConfig("custom")
		assertNilF(t, err, "Failed to deregister test TLS config")
	}()
	testCases := []struct {
		name     string
		dsn      string
		expected string
		err      bool
	}{
		{
			name:     "Basic TLS config parameter",
			dsn:      "user:pass@account/db?tlsConfigName=custom",
			expected: "custom",
			err:      false,
		},
		{
			name:     "TLS config with other parameters",
			dsn:      "user:pass@account/db?tlsConfigName=custom&warehouse=wh&role=admin",
			expected: "custom",
			err:      false,
		},
		{
			name: "No TLS config parameter",
			dsn:  "user:pass@account/db?warehouse=wh",
			err:  false,
		},
		{
			name: "Nonexistent TLS config",
			dsn:  "user:pass@account/db?tlsConfigName=nonexistent",
			err:  true,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			cfg, err := ParseDSN(tc.dsn)
			if tc.err {
				assertNotNilF(t, err, "ParseDSN should have failed but did not")
			} else {
				assertNilF(t, err, "ParseDSN failed")
				// For DSN parsing, the TLS config should be resolved and set directly
				assertEqualF(t, cfg.TLSConfigName, tc.expected, "TLSConfigName mismatch")
			}
		})
	}
}

// TestTokenAndTokenFilePathValidation checks that Token and TokenFilePath are
// mutually exclusive while each is accepted on its own.
func TestTokenAndTokenFilePathValidation(t *testing.T) {
	conflicting := &Config{
		Account:       "a",
		User:          "u",
		Password:      "p",
		Token:         "direct-token",
		TokenFilePath: "test_data/snowflake/session/token",
	}
	// Both token sources set at once must be rejected with the dedicated sentinel.
	if err := conflicting.Validate(); !errors.Is(err, errTokenConfigConflict) {
		t.Error("Expected validation error when both Token and TokenFilePath are set")
	}
	// Dropping the file path leaves a valid, token-only configuration.
	conflicting.TokenFilePath = ""
	assertNilE(t, conflicting.Validate(), "Should have accepted Token on its own")
	// And vice versa: file path alone is also valid.
	conflicting.Token = ""
	conflicting.TokenFilePath = "test_data/snowflake/session/token"
	assertNilE(t, conflicting.Validate(), "Should have accepted TokenFilePath on its own")
}

// TestFillMissingConfigParametersDerivesAccountFromHost verifies the account
// name is derived from a *.snowflakecomputing.com host when Account is empty.
func TestFillMissingConfigParametersDerivesAccountFromHost(t *testing.T) {
	cfg := &Config{
		User:          "u",
		Password:      "p",
		Host:          "myacct.us-east-1.snowflakecomputing.com",
		Port:          443,
		Account:       "",
		Authenticator: AuthTypeSnowflake,
	}
	assertNilE(t, FillMissingConfigParameters(cfg), "FillMissingConfigParameters")
	if cfg.Account != "myacct" {
t.Fatalf("Account: want myacct, got %q", cfg.Account) } } func TestFillMissingConfigParametersDerivesAccountFromCNHost(t *testing.T) { cfg := &Config{ User: "u", Password: "p", Host: "myacct.cn-north-1.snowflakecomputing.cn", Port: 443, Account: "", Authenticator: AuthTypeSnowflake, } assertNilE(t, FillMissingConfigParameters(cfg), "FillMissingConfigParameters") if cfg.Account != "myacct" { t.Fatalf("Account: want myacct, got %q", cfg.Account) } } func TestFillMissingConfigParametersNonSnowflakeHostRequiresAccount(t *testing.T) { cfg := &Config{ User: "u", Password: "p", Host: "snowflake.internal.example.com", Port: 443, Account: "", Authenticator: AuthTypeSnowflake, } err := FillMissingConfigParameters(cfg) assertNotNilF(t, err, "expected error for empty Account with non-Snowflake host") sfErr, ok := err.(*sferrors.SnowflakeError) assertTrueF(t, ok, "expected SnowflakeError") assertEqualE(t, sfErr.Number, sferrors.ErrCodeEmptyAccountCode, "error number") } // helper function to generate PKCS8 encoded base64 string of a private key func generatePKCS8StringSupress(key *rsa.PrivateKey) string { // Error would only be thrown when the private key type is not supported // We would be safe as long as we are using rsa.PrivateKey tmpBytes, _ := x509.MarshalPKCS8PrivateKey(key) privKeyPKCS8 := base64.URLEncoding.EncodeToString(tmpBytes) return privKeyPKCS8 } // helper function to generate PKCS1 encoded base64 string of a private key func generatePKCS1String(key *rsa.PrivateKey) string { tmpBytes := x509.MarshalPKCS1PrivateKey(key) privKeyPKCS1 := base64.URLEncoding.EncodeToString(tmpBytes) return privKeyPKCS1 } ================================================ FILE: internal/config/ocsp_mode.go ================================================ package config // OCSPFailOpenMode is OCSP fail open mode. 
OCSPFailOpenTrue by default and may // set to ocspModeFailClosed for fail closed mode type OCSPFailOpenMode uint32 const ( // OCSPFailOpenNotSet represents OCSP fail open mode is not set, which is the default value. OCSPFailOpenNotSet OCSPFailOpenMode = iota // OCSPFailOpenTrue represents OCSP fail open mode. OCSPFailOpenTrue // OCSPFailOpenFalse represents OCSP fail closed mode. OCSPFailOpenFalse ) const ( ocspModeFailOpen = "FAIL_OPEN" ocspModeFailClosed = "FAIL_CLOSED" ocspModeDisabled = "INSECURE" ) // OcspMode returns the OCSP mode in string INSECURE, FAIL_OPEN, FAIL_CLOSED func OcspMode(c *Config) string { if c.DisableOCSPChecks { return ocspModeDisabled } else if c.OCSPFailOpen == OCSPFailOpenNotSet || c.OCSPFailOpen == OCSPFailOpenTrue { // by default or set to true return ocspModeFailOpen } return ocspModeFailClosed } ================================================ FILE: internal/config/priv_key.go ================================================ package config import ( "crypto/rsa" "crypto/x509" sferrors "github.com/snowflakedb/gosnowflake/v2/internal/errors" ) // ParsePKCS8PrivateKey parses a PKCS8 encoded private key. func ParsePKCS8PrivateKey(block []byte) (*rsa.PrivateKey, error) { privKey, err := x509.ParsePKCS8PrivateKey(block) if err != nil { return nil, &sferrors.SnowflakeError{ Number: sferrors.ErrCodePrivateKeyParseError, Message: "Error decoding private key using PKCS8.", } } return privKey.(*rsa.PrivateKey), nil } // MarshalPKCS8PrivateKey marshals a private key to PKCS8 format. 
func MarshalPKCS8PrivateKey(key *rsa.PrivateKey) ([]byte, error) { keyInBytes, err := x509.MarshalPKCS8PrivateKey(key) if err != nil { return nil, &sferrors.SnowflakeError{ Number: sferrors.ErrCodePrivateKeyParseError, Message: "Error encoding private key using PKCS8."} } return keyInBytes, nil } ================================================ FILE: internal/config/tls_config.go ================================================ package config import ( "crypto/tls" "sync" ) var ( tlsConfigLock sync.RWMutex tlsConfigRegistry = make(map[string]*tls.Config) ) // ResetTLSConfigRegistry clears the TLS config registry. Used in tests. func ResetTLSConfigRegistry() { tlsConfigLock.Lock() tlsConfigRegistry = make(map[string]*tls.Config) tlsConfigLock.Unlock() } // RegisterTLSConfig registers the tls.Config in configs registry. // Use the key as a value in the DSN where tlsConfigName=value. func RegisterTLSConfig(key string, config *tls.Config) error { tlsConfigLock.Lock() logger.Infof("Registering TLS config for key: %s", key) tlsConfigRegistry[key] = config.Clone() tlsConfigLock.Unlock() return nil } // DeregisterTLSConfig removes the tls.Config associated with key. func DeregisterTLSConfig(key string) error { tlsConfigLock.Lock() logger.Infof("Deregistering TLS config for key: %s", key) delete(tlsConfigRegistry, key) tlsConfigLock.Unlock() return nil } // GetTLSConfig returns a TLS config from the registry. 
func GetTLSConfig(key string) (*tls.Config, bool) { tlsConfigLock.RLock() tlsConfig, ok := tlsConfigRegistry[key] tlsConfigLock.RUnlock() if !ok { return nil, false } return tlsConfig.Clone(), true } ================================================ FILE: internal/config/tls_config_test.go ================================================ package config import ( "crypto/tls" "crypto/x509" "testing" ) func TestRegisterTLSConfig(t *testing.T) { // Clean up any existing configs after testing defer ResetTLSConfigRegistry() testConfig := tls.Config{ InsecureSkipVerify: true, ServerName: "test-server", } // Test successful registration err := RegisterTLSConfig("test", &testConfig) assertNilE(t, err, "RegisterTLSConfig failed") // Verify config was registered retrieved, exists := GetTLSConfig("test") assertTrueE(t, exists, "TLS config was not registered") // Verify the retrieved config matches the original assertEqualE(t, retrieved.InsecureSkipVerify, testConfig.InsecureSkipVerify, "InsecureSkipVerify mismatch") assertEqualE(t, retrieved.ServerName, testConfig.ServerName, "ServerName mismatch") } func TestDeregisterTLSConfig(t *testing.T) { // Clean up any existing configs after testing defer ResetTLSConfigRegistry() testConfig := tls.Config{ InsecureSkipVerify: true, ServerName: "test-server", } // Register a config err := RegisterTLSConfig("test", &testConfig) assertNilE(t, err, "RegisterTLSConfig failed") // Verify it exists _, exists := GetTLSConfig("test") assertTrueE(t, exists, "TLS config should exist after registration") // Deregister it err = DeregisterTLSConfig("test") assertNilE(t, err, "DeregisterTLSConfig failed") // Verify it's gone _, exists = GetTLSConfig("test") assertFalseE(t, exists, "TLS config should not exist after deregistration") } func TestGetTLSConfigNonExistent(t *testing.T) { _, exists := GetTLSConfig("nonexistent") assertFalseE(t, exists, "getTLSConfig should return false for non-existent config") } func TestRegisterTLSConfigWithCustomRootCAs(t 
*testing.T) { // Clean up any existing configs after testing defer ResetTLSConfigRegistry() // Create a test cert pool certPool := x509.NewCertPool() testConfig := tls.Config{ RootCAs: certPool, InsecureSkipVerify: false, } err := RegisterTLSConfig("custom-ca", &testConfig) assertNilE(t, err, "RegisterTLSConfig failed") // Retrieve and verify retrieved, exists := GetTLSConfig("custom-ca") assertTrueE(t, exists, "TLS config should exist") // The retrieved should have the same certificates as the original assertTrueE(t, retrieved.RootCAs.Equal(testConfig.RootCAs), "RootCAs should match") } func TestMultipleTLSConfigs(t *testing.T) { // Clean up any existing configs after testing defer ResetTLSConfigRegistry() configs := map[string]*tls.Config{ "insecure": {InsecureSkipVerify: true}, "secure": {InsecureSkipVerify: false, ServerName: "secure.example.com"}, } // Register multiple configs for name, config := range configs { err := RegisterTLSConfig(name, config) assertNilE(t, err, "RegisterTLSConfig failed for "+name) } // Verify all can be retrieved for name, original := range configs { retrieved, exists := GetTLSConfig(name) assertTrueE(t, exists, "Config "+name+" should exist") assertEqualE(t, retrieved.InsecureSkipVerify, original.InsecureSkipVerify, name+" InsecureSkipVerify mismatch") assertEqualE(t, retrieved.ServerName, original.ServerName, name+" ServerName mismatch") } // Test overwriting newConfig := tls.Config{InsecureSkipVerify: false, ServerName: "new.example.com"} err := RegisterTLSConfig("insecure", &newConfig) assertNilE(t, err, "RegisterTLSConfig should allow overwriting") retrieved, _ := GetTLSConfig("insecure") assertEqualE(t, retrieved.ServerName, "new.example.com", "Config should have been overwritten") } ================================================ FILE: internal/config/token_accessor.go ================================================ package config // TokenAccessor manages the session token and master token type TokenAccessor interface { 
GetTokens() (token string, masterToken string, sessionID int64) SetTokens(token string, masterToken string, sessionID int64) Lock() error Unlock() } ================================================ FILE: internal/errors/errors.go ================================================ // Package errors defines error types and error codes for the Snowflake driver. // It includes both errors returned by the Snowflake server and errors generated by the driver itself. // The SnowflakeError type includes various fields to capture detailed information about an error, such as the error number, // SQL state, query ID, and a message with optional arguments for formatting. The package also defines a set of constants // for common error codes and message templates for consistent error reporting throughout the driver. package errors import "fmt" // SnowflakeError is an error type including various Snowflake specific information. type SnowflakeError struct { Number int SQLState string QueryID string Message string MessageArgs []any IncludeQueryID bool // TODO: populate this in connection } func (se *SnowflakeError) Error() string { message := se.Message if len(se.MessageArgs) > 0 { message = fmt.Sprintf(se.Message, se.MessageArgs...) 
} if se.SQLState != "" { if se.IncludeQueryID { return fmt.Sprintf("%06d (%s): %s: %s", se.Number, se.SQLState, se.QueryID, message) } return fmt.Sprintf("%06d (%s): %s", se.Number, se.SQLState, message) } if se.IncludeQueryID { return fmt.Sprintf("%06d: %s: %s", se.Number, se.QueryID, message) } return fmt.Sprintf("%06d: %s", se.Number, message) } // Snowflake Server Error code const ( QueryNotExecutingCode = "000605" QueryInProgressCode = "333333" QueryInProgressAsyncCode = "333334" SessionExpiredCode = "390112" InvalidOAuthAccessTokenCode = "390303" ExpiredOAuthAccessTokenCode = "390318" ) // Driver return errors const ( /* connection */ // ErrCodeEmptyAccountCode is an error code for the case where a DSN doesn't include account parameter ErrCodeEmptyAccountCode = 260000 // ErrCodeEmptyUsernameCode is an error code for the case where a DSN doesn't include user parameter ErrCodeEmptyUsernameCode = 260001 // ErrCodeEmptyPasswordCode is an error code for the case where a DSN doesn't include password parameter ErrCodeEmptyPasswordCode = 260002 // ErrCodeFailedToParseHost is an error code for the case where a DSN includes an invalid host name ErrCodeFailedToParseHost = 260003 // ErrCodeFailedToParsePort is an error code for the case where a DSN includes an invalid port number ErrCodeFailedToParsePort = 260004 // ErrCodeIdpConnectionError is an error code for the case where a IDP connection failed ErrCodeIdpConnectionError = 260005 // ErrCodeSSOURLNotMatch is an error code for the case where a SSO URL doesn't match ErrCodeSSOURLNotMatch = 260006 // ErrCodeServiceUnavailable is an error code for the case where service is unavailable. 
ErrCodeServiceUnavailable = 260007
// ErrCodeFailedToConnect is an error code for the case where a DB connection failed due to wrong account name
ErrCodeFailedToConnect = 260008
// ErrCodeRegionOverlap is an error code for the case where a region is specified despite an account region present
ErrCodeRegionOverlap = 260009
// ErrCodePrivateKeyParseError is an error code for the case where the private key is not parsed correctly
ErrCodePrivateKeyParseError = 260010
// ErrCodeFailedToParseAuthenticator is an error code for the case where a DSN includes an invalid authenticator
ErrCodeFailedToParseAuthenticator = 260011
// ErrCodeClientConfigFailed is an error code for the case where clientConfigFile is invalid or applying client configuration fails
ErrCodeClientConfigFailed = 260012
// ErrCodeTomlFileParsingFailed is an error code for the case where parsing the toml file failed because of an invalid value.
ErrCodeTomlFileParsingFailed = 260013
// ErrCodeFailedToFindDSNInToml is an error code for the case where the DSN does not exist in the toml file.
ErrCodeFailedToFindDSNInToml = 260014
// ErrCodeInvalidFilePermission is an error code for the case where the user does not have 0600 permission to the toml file.
ErrCodeInvalidFilePermission = 260015
// ErrCodeEmptyPasswordAndToken is an error code for the case where a DSN includes neither password nor token
ErrCodeEmptyPasswordAndToken = 260016
// ErrCodeEmptyOAuthParameters is an error code for the case where the client ID or client secret are not provided for OAuth flows.
ErrCodeEmptyOAuthParameters = 260017
// ErrMissingAccessATokenButRefreshTokenPresent is an error code for the case when the access token is not found in cache, but the refresh token is present.
ErrMissingAccessATokenButRefreshTokenPresent = 260018
// ErrCodeMissingTLSConfig is an error code for the case where the TLS config is missing.
ErrCodeMissingTLSConfig = 260019

/* network */

// ErrFailedToPostQuery is an error code for the case where HTTP POST failed.
ErrFailedToPostQuery = 261000
// ErrFailedToRenewSession is an error code for the case where session renewal failed.
ErrFailedToRenewSession = 261001
// ErrFailedToCancelQuery is an error code for the case where cancel query failed.
ErrFailedToCancelQuery = 261002
// ErrFailedToCloseSession is an error code for the case where close session failed.
ErrFailedToCloseSession = 261003
// ErrFailedToAuth is an error code for the case where authentication failed for unknown reason.
ErrFailedToAuth = 261004
// ErrFailedToAuthSAML is an error code for the case where authentication via SAML failed for unknown reason.
ErrFailedToAuthSAML = 261005
// ErrFailedToAuthOKTA is an error code for the case where authentication via OKTA failed for unknown reason.
ErrFailedToAuthOKTA = 261006
// ErrFailedToGetSSO is an error code for the case where getting the SSO URL via OKTA failed for unknown reason.
ErrFailedToGetSSO = 261007
// ErrFailedToParseResponse is an error code for when we cannot parse an external browser response from Snowflake.
ErrFailedToParseResponse = 261008
// ErrFailedToGetExternalBrowserResponse is an error code for when there's an error reading from the open socket.
ErrFailedToGetExternalBrowserResponse = 261009
// ErrFailedToHeartbeat is an error code when a heartbeat fails.
ErrFailedToHeartbeat = 261010

/* rows */

// ErrFailedToGetChunk is an error code for the case where it failed to get chunk of result set
ErrFailedToGetChunk = 262000
// ErrNonArrowResponseInArrowBatches is an error code for case where ArrowBatches mode is enabled, but response is not Arrow-based
ErrNonArrowResponseInArrowBatches = 262001

/* transaction */

// ErrNoReadOnlyTransaction is an error code for the case where readonly mode is specified.
ErrNoReadOnlyTransaction = 263000
// ErrNoDefaultTransactionIsolationLevel is an error code for the case where non default isolation level is specified.
ErrNoDefaultTransactionIsolationLevel = 263001

/* file transfer */

// ErrInvalidStageFs is an error code denoting an invalid stage in the file system
ErrInvalidStageFs = 264001
// ErrFailedToDownloadFromStage is an error code denoting the failure to download a file from the stage
ErrFailedToDownloadFromStage = 264002
// ErrFailedToUploadToStage is an error code denoting the failure to upload a file to the stage
ErrFailedToUploadToStage = 264003
// ErrInvalidStageLocation is an error code denoting an invalid stage location
ErrInvalidStageLocation = 264004
// ErrLocalPathNotDirectory is an error code denoting a local path that is not a directory
ErrLocalPathNotDirectory = 264005
// ErrFileNotExists is an error code denoting the file to be transferred does not exist
ErrFileNotExists = 264006
// ErrCompressionNotSupported is an error code denoting the user specified compression type is not supported
ErrCompressionNotSupported = 264007
// ErrInternalNotMatchEncryptMaterial is an error code denoting the encryption material specified does not match
ErrInternalNotMatchEncryptMaterial = 264008
// ErrCommandNotRecognized is an error code denoting the PUT/GET command was not recognized
ErrCommandNotRecognized = 264009
// ErrFailedToConvertToS3Client is an error code denoting the failure of an interface to s3.Client conversion
ErrFailedToConvertToS3Client = 264010
// ErrNotImplemented is an error code denoting the file transfer feature is not implemented
ErrNotImplemented = 264011
// ErrInvalidPadding is an error code denoting the invalid padding of decryption key
ErrInvalidPadding = 264012

/* binding */

// ErrBindSerialization is an error code for a failed serialization of bind variables
ErrBindSerialization = 265001
// ErrBindUpload is an error code for the uploading process of bind elements to the stage
ErrBindUpload = 265002

/* async */

// ErrAsync is an error code for an unknown async error
ErrAsync = 266001

/* multi-statement */

// ErrNoResultIDs is an error code for empty result IDs for multi statement queries
ErrNoResultIDs = 267001

/* converter */

// ErrInvalidTimestampTz is an error code for the case where a returned TIMESTAMP_TZ internal value is invalid
ErrInvalidTimestampTz = 268000
// ErrInvalidOffsetStr is an error code for the case where an offset string is invalid. The input string must
// consist of sHHMI where one sign character '+'/'-' followed by zero filled hours and minutes
ErrInvalidOffsetStr = 268001
// ErrInvalidBinaryHexForm is an error code for the case where a binary data in hex form is invalid.
ErrInvalidBinaryHexForm = 268002
// ErrTooHighTimestampPrecision is an error code for the case where cannot convert Snowflake timestamp to arrow.Timestamp
ErrTooHighTimestampPrecision = 268003
// ErrNullValueInArray is an error code for the case where there are null values in an array without arrayValuesNullable set to true
ErrNullValueInArray = 268004
// ErrNullValueInMap is an error code for the case where there are null values in a map without mapValuesNullable set to true
ErrNullValueInMap = 268005

/* OCSP */

// ErrOCSPStatusRevoked is an error code for the case where the certificate is revoked.
ErrOCSPStatusRevoked = 269001
// ErrOCSPStatusUnknown is an error code for the case where the certificate revocation status is unknown.
ErrOCSPStatusUnknown = 269002
// ErrOCSPInvalidValidity is an error code for the case where the OCSP response validity is invalid.
ErrOCSPInvalidValidity = 269003
// ErrOCSPNoOCSPResponderURL is an error code for the case where the OCSP responder URL is not attached.
ErrOCSPNoOCSPResponderURL = 269004 /* query Status*/ // ErrQueryStatus when check the status of a query, receive error or no status ErrQueryStatus = 279001 // ErrQueryIDFormat the query ID given to fetch its result is not valid ErrQueryIDFormat = 279101 // ErrQueryReportedError server side reports the query failed with error ErrQueryReportedError = 279201 // ErrQueryIsRunning the query is still running ErrQueryIsRunning = 279301 /* GS error code */ // ErrSessionGone is an GS error code for the case that session is already closed ErrSessionGone = 390111 // ErrRoleNotExist is a GS error code for the case that the role specified does not exist ErrRoleNotExist = 390189 // ErrObjectNotExistOrAuthorized is a GS error code for the case that the server-side object specified does not exist ErrObjectNotExistOrAuthorized = 390201 ) // Error message templates const ( ErrMsgFailedToParseHost = "failed to parse a host name. host: %v" ErrMsgFailedToParsePort = "failed to parse a port number. port: %v" ErrMsgFailedToParseAuthenticator = "failed to parse an authenticator: %v" ErrMsgInvalidOffsetStr = "offset must be a string consist of sHHMI where one sign character '+'/'-' followed by zero filled hours and minutes: %v" ErrMsgInvalidByteArray = "invalid byte array: %v" ErrMsgIdpConnectionError = "failed to verify URLs. authenticator: %v, token URL:%v, SSO URL:%v" ErrMsgSSOURLNotMatch = "SSO URL didn't match. expected: %v, got: %v" ErrMsgFailedToGetChunk = "failed to get a chunk of result sets. idx: %v" ErrMsgFailedToPostQuery = "failed to POST. HTTP: %v, URL: %v" ErrMsgFailedToRenew = "failed to renew session. HTTP: %v, URL: %v" ErrMsgFailedToCancelQuery = "failed to cancel query. HTTP: %v, URL: %v" ErrMsgFailedToCloseSession = "failed to close session. HTTP: %v, URL: %v" ErrMsgFailedToAuth = "failed to auth for unknown reason. HTTP: %v, URL: %v" ErrMsgFailedToAuthSAML = "failed to auth via SAML for unknown reason. 
HTTP: %v, URL: %v" ErrMsgFailedToAuthOKTA = "failed to auth via OKTA for unknown reason. HTTP: %v, URL: %v" ErrMsgFailedToGetSSO = "failed to auth via OKTA for unknown reason. HTTP: %v, URL: %v" ErrMsgFailedToParseResponse = "failed to parse a response from Snowflake. Response: %v" ErrMsgFailedToGetExternalBrowserResponse = "failed to get an external browser response from Snowflake, err: %s" ErrMsgNoReadOnlyTransaction = "no readonly mode is supported" ErrMsgNoDefaultTransactionIsolationLevel = "no default isolation transaction level is supported" ErrMsgServiceUnavailable = "service is unavailable. check your connectivity. you may need a proxy server. HTTP: %v, URL: %v" ErrMsgFailedToConnect = "failed to connect to db. verify account name is correct. HTTP: %v, URL: %v" ErrMsgOCSPStatusRevoked = "OCSP revoked: reason:%v, at:%v" ErrMsgOCSPStatusUnknown = "OCSP unknown" ErrMsgOCSPInvalidValidity = "invalid validity: producedAt: %v, thisUpdate: %v, nextUpdate: %v" ErrMsgOCSPNoOCSPResponderURL = "no OCSP server is attached to the certificate. %v" ErrMsgBindColumnMismatch = "column %v has a different number of binds (%v) than column 1 (%v)" ErrMsgNotImplemented = "not implemented" ErrMsgFeatureNotSupported = "feature is not supported: %v" ErrMsgCommandNotRecognized = "%v command not recognized" ErrMsgLocalPathNotDirectory = "the local path is not a directory: %v" ErrMsgFileNotExists = "file does not exist: %v" ErrMsgFailToReadDataFromBuffer = "failed to read data from buffer. err: %v" ErrMsgInvalidStageFs = "destination location type is not valid: %v" ErrMsgInternalNotMatchEncryptMaterial = "number of downloading files doesn't match the encryption materials. 
files=%v, encmat=%v" ErrMsgFailedToConvertToS3Client = "failed to convert interface to s3 client" ErrMsgNoResultIDs = "no result IDs returned with the multi-statement query" ErrMsgQueryStatus = "server ErrorCode=%s, ErrorMessage=%s" ErrMsgInvalidPadding = "invalid padding on input" ErrMsgClientConfigFailed = "client configuration failed: %v" ErrMsgNullValueInArray = "for handling null values in arrays use WithArrayValuesNullable(ctx)" ErrMsgNullValueInMap = "for handling null values in maps use WithMapValuesNullable(ctx)" ErrMsgFailedToParseTomlFile = "failed to parse toml file. the params %v occurred error with value %v" ErrMsgFailedToFindDSNInTomlFile = "failed to find DSN in toml file." ErrMsgInvalidWritablePermissionToFile = "file '%v' is writable by group or others — this poses a security risk because it allows unauthorized users to modify sensitive settings. Your Permission: %v" ErrMsgInvalidExecutablePermissionToFile = "file '%v' is executable — this poses a security risk because the file could be misused as a script or executed unintentionally. Your Permission: %v" ErrMsgNonArrowResponseInArrowBatches = "arrow batches enabled, but the response is not Arrow based" ErrMsgMissingTLSConfig = "TLS config not found: %v" ) // ErrEmptyAccount is returned if a DSN doesn't include account parameter. func ErrEmptyAccount() *SnowflakeError { return &SnowflakeError{ Number: ErrCodeEmptyAccountCode, Message: "account is empty", } } // ErrEmptyUsername is returned if a DSN doesn't include user parameter. func ErrEmptyUsername() *SnowflakeError { return &SnowflakeError{ Number: ErrCodeEmptyUsernameCode, Message: "user is empty", } } // ErrEmptyPassword is returned if a DSN doesn't include password parameter. func ErrEmptyPassword() *SnowflakeError { return &SnowflakeError{ Number: ErrCodeEmptyPasswordCode, Message: "password is empty", } } // ErrEmptyPasswordAndToken is returned if a DSN includes neither password nor token. 
func ErrEmptyPasswordAndToken() *SnowflakeError {
	return &SnowflakeError{
		Number:  ErrCodeEmptyPasswordAndToken,
		Message: "both password and token are empty",
	}
}

// ErrEmptyOAuthParameters is returned if OAuth is used but required fields are missing.
func ErrEmptyOAuthParameters() *SnowflakeError {
	return &SnowflakeError{
		Number:  ErrCodeEmptyOAuthParameters,
		Message: "client ID or client secret are empty",
	}
}

// ErrRegionConflict is returned if a DSN's implicit and explicit region parameters conflict.
func ErrRegionConflict() *SnowflakeError {
	return &SnowflakeError{
		Number:  ErrCodeRegionOverlap,
		Message: "two regions specified",
	}
}

// ErrFailedToParseAuthenticator is returned if a DSN includes an invalid authenticator.
func ErrFailedToParseAuthenticator() *SnowflakeError {
	return &SnowflakeError{
		Number:  ErrCodeFailedToParseAuthenticator,
		Message: "failed to parse an authenticator",
	}
}

// ErrUnknownError is returned if the server side returns an error without meaningful message.
// All fields are set to sentinel "-1" values since no server information is available.
func ErrUnknownError() *SnowflakeError {
	return &SnowflakeError{
		Number:   -1,
		SQLState: "-1",
		Message:  "an unknown server side error occurred",
		QueryID:  "-1",
	}
}

// ErrNullValueInArrayError is returned for null values in array without arrayValuesNullable.
func ErrNullValueInArrayError() *SnowflakeError {
	return &SnowflakeError{
		Number:  ErrNullValueInArray,
		Message: ErrMsgNullValueInArray,
	}
}

// ErrNullValueInMapError is returned for null values in map without mapValuesNullable.
func ErrNullValueInMapError() *SnowflakeError {
	return &SnowflakeError{
		Number:  ErrNullValueInMap,
		Message: ErrMsgNullValueInMap,
	}
}

// ErrNonArrowResponseForArrowBatches is returned when arrow batches mode is enabled but response is not Arrow-based.
func ErrNonArrowResponseForArrowBatches(queryID string) *SnowflakeError {
	return &SnowflakeError{
		QueryID: queryID,
		Number:  ErrNonArrowResponseInArrowBatches,
		Message: ErrMsgNonArrowResponseInArrowBatches,
	}
}

================================================
FILE: internal/logger/accessor.go
================================================
package logger

import (
	"errors"
	"log"
	"sync"

	"github.com/snowflakedb/gosnowflake/v2/sflog"
)

// LoggerAccessor allows internal packages to access the global logger
// without importing the main gosnowflake package (avoiding circular dependencies)
var (
	// loggerAccessorMu guards all reads/writes of globalLogger.
	loggerAccessorMu sync.Mutex
	// globalLogger is the actual logger that provides all features (secret masking, level filtering, etc.)
	globalLogger sflog.SFLogger
)

// GetLogger returns the global logger for use by internal packages
func GetLogger() sflog.SFLogger {
	loggerAccessorMu.Lock()
	defer loggerAccessorMu.Unlock()
	return globalLogger
}

// SetLogger sets the raw (base) logger implementation and wraps it with the standard protection layers.
// This function ALWAYS wraps the provided logger with:
// 1. Secret masking (to protect sensitive data)
// 2. Level filtering (for performance optimization)
//
// There is no way to bypass these protective layers. The globalLogger structure is:
//
// globalLogger = levelFilteringLogger → secretMaskingLogger → rawLogger
//
// If the provided logger is already wrapped (e.g., from CreateDefaultLogger), this function
// automatically extracts the raw logger to prevent double-wrapping.
//
// Internal wrapper types that would cause issues are rejected:
// - Proxy (would cause infinite recursion)
func SetLogger(providedLogger SFLogger) error {
	loggerAccessorMu.Lock()
	defer loggerAccessorMu.Unlock()
	// Reject Proxy to prevent infinite recursion
	if _, isProxy := providedLogger.(*Proxy); isProxy {
		return errors.New("cannot set Proxy as raw logger - it would create infinite recursion")
	}
	// Unwrap if the logger is one of our own wrapper types
	// This allows SetLogger to accept both raw loggers and fully-wrapped loggers
	rawLogger := providedLogger
	// If it's a level filtering logger, unwrap to get the secret masking layer
	if levelFiltering, ok := rawLogger.(*levelFilteringLogger); ok {
		rawLogger = levelFiltering.inner
	}
	// If it's a secret masking logger, unwrap to get the raw logger
	if secretMasking, ok := rawLogger.(*secretMaskingLogger); ok {
		rawLogger = secretMasking.inner
	}
	// Build the standard protection chain: levelFiltering → secretMasking → rawLogger
	masked := newSecretMaskingLogger(rawLogger)
	filtered := newLevelFilteringLogger(masked)
	globalLogger = filtered
	return nil
}

// init installs the default logger at package load so GetLogger never returns nil.
func init() {
	rawLogger := newRawLogger()
	if err := SetLogger(rawLogger); err != nil {
		log.Panicf("cannot set default logger. %v", err)
	}
}

// CreateDefaultLogger function creates a new instance of the default logger with the standard protection layers.
func CreateDefaultLogger() sflog.SFLogger {
	return newLevelFilteringLogger(newSecretMaskingLogger(newRawLogger()))
}

================================================
FILE: internal/logger/accessor_test.go
================================================
package logger_test

import (
	"bytes"
	"context"
	"strings"
	"testing"

	"github.com/snowflakedb/gosnowflake/v2/internal/logger"
)

// TestLoggerConfiguration verifies configuration methods work
func TestLoggerConfiguration(t *testing.T) {
	log := logger.CreateDefaultLogger()
	// Get current level
	level := log.GetLogLevel()
	if level == "" {
		t.Error("Expected non-empty log level")
	}
	t.Logf("Current log level: %s", level)
	// Set log level
	err := log.SetLogLevel("debug")
	if err != nil {
		t.Errorf("SetLogLevel failed: %v", err)
	}
	// Verify it changed (GetLogLevel reports the level in upper case)
	newLevel := log.GetLogLevel()
	if newLevel != "DEBUG" {
		t.Errorf("Expected 'debug', got '%s'", newLevel)
	}
}

// TestLoggerSecretMasking verifies secret masking works
func TestLoggerSecretMasking(t *testing.T) {
	log := logger.CreateDefaultLogger()
	var buf bytes.Buffer
	log.SetOutput(&buf)
	// Reset log level to ensure info is logged
	_ = log.SetLogLevel("info")
	// Log a secret
	log.Infof("password=%s", "secret12345")
	output := buf.String()
	t.Logf("Output: %s", output) // Debug output
	// The output should have a masked secret
	if strings.Contains(output, "secret12345") {
		t.Errorf("Secret masking FAILED: secret leaked in: %s", output)
	}
	// Verify the message was logged (check for "password=")
	if !strings.Contains(output, "password=") {
		t.Errorf("Message not logged: %s", output)
	}
	t.Log("Secret masking works with GetLogger")
}

// TestLoggerAllMethods verifies all logging methods are available and produce output
func TestLoggerAllMethods(t *testing.T) {
	log := logger.CreateDefaultLogger()
	var buf bytes.Buffer
	log.SetOutput(&buf)
	_ = log.SetLogLevel("trace")
	// Test all formatted methods
	log.Tracef("trace %s", "formatted")
	log.Debugf("debug %s", "formatted")
	log.Infof("info %s", "formatted")
	log.Warnf("warn %s", "formatted")
	log.Errorf("error %s", "formatted")
	// Fatalf would exit, so skip in test
	// Test all direct methods
	log.Trace("trace direct")
	log.Debug("debug direct")
	log.Info("info direct")
	log.Warn("warn direct")
	log.Error("error direct")
	// Fatal would exit, so skip in test
	output := buf.String()
	// Verify all messages appear in output
	expectedMessages := []string{
		"trace formatted", "debug formatted", "info formatted", "warn formatted", "error formatted",
		"trace direct", "debug direct", "info direct", "warn direct", "error direct",
	}
	for _, msg := range expectedMessages {
		if !strings.Contains(output, msg) {
			t.Errorf("Expected output to contain '%s', got: %s", msg, output)
		}
	}
}

// TestLoggerLevelFiltering verifies log level filtering works correctly
func TestLoggerLevelFiltering(t *testing.T) {
	log := logger.CreateDefaultLogger()
	var buf bytes.Buffer
	log.SetOutput(&buf)
	// Set to INFO level
	_ = log.SetLogLevel("info")
	// Log at different levels
	log.Debug("this should not appear")
	log.Info("this should appear")
	log.Warn("this should also appear")
	output := buf.String()
	// Debug should not appear
	if strings.Contains(output, "this should not appear") {
		t.Errorf("Debug message appeared when log level is INFO: %s", output)
	}
	// Info and Warn should appear
	if !strings.Contains(output, "this should appear") {
		t.Errorf("Info message did not appear: %s", output)
	}
	if !strings.Contains(output, "this should also appear") {
		t.Errorf("Warn message did not appear: %s", output)
	}
	t.Log("Log level filtering works correctly")
}

// TestLogEntry verifies log entry methods and field inclusion
func TestLogEntry(t *testing.T) {
	log := logger.CreateDefaultLogger()
	var buf bytes.Buffer
	log.SetOutput(&buf)
	_ = log.SetLogLevel("info")
	// Get entry with field
	entry := log.WithField("module", "test")
	// Log with the entry
	entry.Infof("info with field %s", "formatted")
	entry.Info("info with field direct")
	output := buf.String()
	// Verify messages appear
	if !strings.Contains(output, "info with field formatted") {
		t.Errorf("Expected formatted message in output: %s", output)
	}
	if !strings.Contains(output, "info with field direct") {
		t.Errorf("Expected direct message in output: %s", output)
	}
	// Verify field appears in output
	if !strings.Contains(output, "module") || !strings.Contains(output, "test") {
		t.Errorf("Expected field 'module=test' in output: %s", output)
	}
	t.Log("LogEntry methods work correctly")
}

// TestLogEntryWithFields verifies WithFields works correctly
func TestLogEntryWithFields(t *testing.T) {
	log := logger.CreateDefaultLogger()
	var buf bytes.Buffer
	log.SetOutput(&buf)
	_ = log.SetLogLevel("info")
	// Get entry with multiple fields
	entry := log.WithFields(map[string]any{
		"requestId": "123-456",
		"userId":    42,
	})
	entry.Info("processing request")
	output := buf.String()
	// Verify message appears
	if !strings.Contains(output, "processing request") {
		t.Errorf("Expected message in output: %s", output)
	}
	// Verify both fields appear
	if !strings.Contains(output, "requestId") {
		t.Errorf("Expected 'requestId' field in output: %s", output)
	}
	if !strings.Contains(output, "123-456") {
		t.Errorf("Expected '123-456' value in output: %s", output)
	}
	if !strings.Contains(output, "userId") {
		t.Errorf("Expected 'userId' field in output: %s", output)
	}
	t.Log("WithFields works correctly")
}

// TestSetOutput verifies output redirection works correctly
func TestSetOutput(t *testing.T) {
	log := logger.CreateDefaultLogger()
	// Test with first buffer
	var buf1 bytes.Buffer
	log.SetOutput(&buf1)
	_ = log.SetLogLevel("info")
	log.Info("message to buffer 1")
	if !strings.Contains(buf1.String(), "message to buffer 1") {
		t.Errorf("Expected message in buffer 1: %s", buf1.String())
	}
	// Switch to second buffer
	var buf2 bytes.Buffer
	log.SetOutput(&buf2)
	log.Info("message to buffer 2")
	// Should appear only in buf2
	if !strings.Contains(buf2.String(), "message to buffer 2") {
		t.Errorf("Expected message in buffer 2: %s", buf2.String())
	}
	// Should NOT appear in buf1
	if strings.Contains(buf1.String(), "message to buffer 2") {
		t.Errorf("Message should not appear in buffer 1: %s", buf1.String())
	}
	t.Log("SetOutput correctly redirects log output")
}

// TestLogEntryWithContext verifies WithContext works correctly
func TestLogEntryWithContext(t *testing.T) {
	log := logger.CreateDefaultLogger()
	var buf bytes.Buffer
	log.SetOutput(&buf)
	_ = log.SetLogLevel("info")
	// Create type to avoid collisions
	type contextKey string
	// Create context with values
	ctx := context.WithValue(context.Background(), contextKey("traceId"), "trace-123")
	// Get entry with context
	entry := log.WithContext(ctx)
	entry.Info("message with context")
	output := buf.String()
	// Verify message appears
	if !strings.Contains(output, "message with context") {
		t.Errorf("Expected message in output: %s", output)
	}
}

================================================
FILE: internal/logger/context.go
================================================
package logger

import (
	"context"
	"fmt"
	"log/slog"
	"maps"
	"sync"
)

// Storage for log keys and hooks (single source of truth)
var (
	contextConfigMu       sync.RWMutex
	logKeys               []any
	clientLogContextHooks map[string]ClientLogContextHook
)

// SetLogKeys sets the context keys to be extracted from context
// This function is thread-safe and can be called at runtime.
func SetLogKeys(keys []any) {
	contextConfigMu.Lock()
	defer contextConfigMu.Unlock()
	// Copy so the caller's slice cannot be mutated behind our back.
	logKeys = make([]any, len(keys))
	copy(logKeys, keys)
}

// GetLogKeys returns a copy of the current log keys
func GetLogKeys() []any {
	contextConfigMu.RLock()
	defer contextConfigMu.RUnlock()
	keysCopy := make([]any, len(logKeys))
	copy(keysCopy, logKeys)
	return keysCopy
}

// RegisterLogContextHook registers a hook for extracting context fields
// This function is thread-safe and can be called at runtime.
func RegisterLogContextHook(key string, hook ClientLogContextHook) {
	contextConfigMu.Lock()
	defer contextConfigMu.Unlock()
	// Lazily initialize; writing to a nil map would panic.
	if clientLogContextHooks == nil {
		clientLogContextHooks = make(map[string]ClientLogContextHook)
	}
	clientLogContextHooks[key] = hook
}

// GetClientLogContextHooks returns a copy of registered hooks
func GetClientLogContextHooks() map[string]ClientLogContextHook {
	contextConfigMu.RLock()
	defer contextConfigMu.RUnlock()
	hooksCopy := make(map[string]ClientLogContextHook, len(clientLogContextHooks))
	maps.Copy(hooksCopy, clientLogContextHooks)
	return hooksCopy
}

// extractContextFields extracts log fields from context using LogKeys and ClientLogContextHooks
func extractContextFields(ctx context.Context) []slog.Attr {
	if ctx == nil {
		return nil
	}
	contextConfigMu.RLock()
	defer contextConfigMu.RUnlock()
	attrs := make([]slog.Attr, 0)
	// Built-in LogKeys: every extracted value is passed through MaskSecrets before logging.
	for _, key := range logKeys {
		if val := ctx.Value(key); val != nil {
			keyStr := fmt.Sprint(key)
			if strVal, ok := val.(string); ok {
				attrs = append(attrs, slog.String(keyStr, MaskSecrets(strVal)))
			} else {
				masked := MaskSecrets(fmt.Sprint(val))
				attrs = append(attrs, slog.String(keyStr, masked))
			}
		}
	}
	// Custom hooks: empty hook results are skipped.
	for key, hook := range clientLogContextHooks {
		if val := hook(ctx); val != "" {
			attrs = append(attrs, slog.String(key, MaskSecrets(val)))
		}
	}
	return attrs
}

================================================
FILE: internal/logger/easy_logging_support.go
================================================
package logger

import (
	"fmt"
	"os"
)

// CloseFileOnLoggerReplace closes a log file when the logger is replaced.
// This is used by the easy logging feature to manage log file handles.
func CloseFileOnLoggerReplace(sflog any, file *os.File) error {
	// Try to get the underlying default logger
	if ell, ok := unwrapToEasyLoggingLogger(sflog); ok {
		return ell.CloseFileOnLoggerReplace(file)
	}
	return fmt.Errorf("logger does not support closeFileOnLoggerReplace")
}

// IsEasyLoggingLogger checks if the given logger is based on the default logger implementation.
// This is used by easy logging to determine if reconfiguration is allowed.
func IsEasyLoggingLogger(sflog any) bool {
	_, ok := unwrapToEasyLoggingLogger(sflog)
	return ok
}

// unwrapToEasyLoggingLogger unwraps a logger to get to the underlying default logger if present
func unwrapToEasyLoggingLogger(sflog any) (EasyLoggingSupport, bool) {
	current := sflog
	// Special case: if this is a Proxy, get the actual global logger
	if _, isProxy := current.(*Proxy); isProxy {
		current = GetLogger()
	}
	// Unwrap all layers
	for {
		if u, ok := current.(Unwrapper); ok {
			current = u.Unwrap()
			continue
		}
		break
	}
	// Check if it's a default logger by checking if it has EasyLoggingSupport
	if ell, ok := current.(EasyLoggingSupport); ok {
		return ell, true
	}
	return nil, false
}

================================================
FILE: internal/logger/interfaces.go
================================================
package logger

import (
	"github.com/snowflakedb/gosnowflake/v2/sflog"
)

// Re-export types from sflog package to avoid circular dependencies
// while maintaining a clean internal API
type (
	// LogEntry reexports the LogEntry interface from sflog package.
	LogEntry = sflog.LogEntry
	// SFLogger reexports the SFLogger interface from sflog package.
	SFLogger = sflog.SFLogger
	// ClientLogContextHook reexports the ClientLogContextHook type from sflog package.
ClientLogContextHook = sflog.ClientLogContextHook
)

================================================
FILE: internal/logger/level_filtering.go
================================================
package logger

import (
	"context"
	"errors"
	"github.com/snowflakedb/gosnowflake/v2/sflog"
	"io"
	"log/slog"
)

// levelFilteringLogger wraps any logger and filters log messages based on log level.
// This prevents expensive operations (like secret masking and formatting) from running
// when the message wouldn't be logged anyway.
type levelFilteringLogger struct {
	inner SFLogger
}

// Compile-time verification that levelFilteringLogger implements SFLogger
var _ SFLogger = (*levelFilteringLogger)(nil)

// Unwrap returns the inner logger (for introspection by easy_logging)
func (l *levelFilteringLogger) Unwrap() any {
	return l.inner
}

// shouldLog determines if a message at messageLevel should be logged
// given the current configured level
func (l *levelFilteringLogger) shouldLog(messageLevel sflog.Level) bool {
	return messageLevel >= l.inner.GetLogLevelInt()
}

// newLevelFilteringLogger creates a new level filtering wrapper around the provided logger
func newLevelFilteringLogger(inner SFLogger) SFLogger {
	if inner == nil {
		panic("inner logger cannot be nil")
	}
	return &levelFilteringLogger{inner: inner}
}

// Implement all formatted logging methods (*f variants)
func (l *levelFilteringLogger) Tracef(format string, args ...any) {
	if !l.shouldLog(sflog.LevelTrace) {
		return
	}
	l.inner.Tracef(format, args...)
}
func (l *levelFilteringLogger) Debugf(format string, args ...any) {
	if !l.shouldLog(sflog.LevelDebug) {
		return
	}
	l.inner.Debugf(format, args...)
}
func (l *levelFilteringLogger) Infof(format string, args ...any) {
	if !l.shouldLog(sflog.LevelInfo) {
		return
	}
	l.inner.Infof(format, args...)
}
func (l *levelFilteringLogger) Warnf(format string, args ...any) {
	if !l.shouldLog(sflog.LevelWarn) {
		return
	}
	l.inner.Warnf(format, args...)
}
func (l *levelFilteringLogger) Errorf(format string, args ...any) {
	if !l.shouldLog(sflog.LevelError) {
		return
	}
	l.inner.Errorf(format, args...)
}

// Fatalf is never filtered: a fatal message must always reach the inner logger.
func (l *levelFilteringLogger) Fatalf(format string, args ...any) {
	l.inner.Fatalf(format, args...)
}

// Implement all direct logging methods
func (l *levelFilteringLogger) Trace(msg string) {
	if !l.shouldLog(sflog.LevelTrace) {
		return
	}
	l.inner.Trace(msg)
}
func (l *levelFilteringLogger) Debug(msg string) {
	if !l.shouldLog(sflog.LevelDebug) {
		return
	}
	l.inner.Debug(msg)
}
func (l *levelFilteringLogger) Info(msg string) {
	if !l.shouldLog(sflog.LevelInfo) {
		return
	}
	l.inner.Info(msg)
}
func (l *levelFilteringLogger) Warn(msg string) {
	if !l.shouldLog(sflog.LevelWarn) {
		return
	}
	l.inner.Warn(msg)
}
func (l *levelFilteringLogger) Error(msg string) {
	if !l.shouldLog(sflog.LevelError) {
		return
	}
	l.inner.Error(msg)
}

// Fatal is never filtered: a fatal message must always reach the inner logger.
func (l *levelFilteringLogger) Fatal(msg string) {
	l.inner.Fatal(msg)
}

// Implement structured logging methods - these return wrapped entries
func (l *levelFilteringLogger) WithField(key string, value any) sflog.LogEntry {
	innerEntry := l.inner.WithField(key, value)
	return &levelFilteringEntry{
		parent: l,
		inner:  innerEntry,
	}
}
func (l *levelFilteringLogger) WithFields(fields map[string]any) sflog.LogEntry {
	innerEntry := l.inner.WithFields(fields)
	return &levelFilteringEntry{
		parent: l,
		inner:  innerEntry,
	}
}
func (l *levelFilteringLogger) WithContext(ctx context.Context) sflog.LogEntry {
	innerEntry := l.inner.WithContext(ctx)
	return &levelFilteringEntry{
		parent: l,
		inner:  innerEntry,
	}
}

// Delegate configuration methods to inner logger
func (l *levelFilteringLogger) SetLogLevel(level string) error {
	return l.inner.SetLogLevel(level)
}
func (l *levelFilteringLogger) SetLogLevelInt(level sflog.Level) error {
	return l.inner.SetLogLevelInt(level)
}
func (l *levelFilteringLogger) GetLogLevel() string {
	return l.inner.GetLogLevel()
}
func (l *levelFilteringLogger) GetLogLevelInt() sflog.Level {
	return l.inner.GetLogLevelInt()
}
func (l *levelFilteringLogger) SetOutput(output io.Writer) {
	l.inner.SetOutput(output)
}

// SetHandler implements SFSlogLogger interface for advanced slog handler configuration
func (l *levelFilteringLogger) SetHandler(handler slog.Handler) error {
	if sh, ok := l.inner.(sflog.SFSlogLogger); ok {
		return sh.SetHandler(handler)
	}
	return errors.New("underlying logger does not support SetHandler")
}

// levelFilteringEntry wraps a log entry and filters by level
type levelFilteringEntry struct {
	parent *levelFilteringLogger
	inner  sflog.LogEntry
}

// Implement all formatted logging methods for entry
func (e *levelFilteringEntry) Tracef(format string, args ...any) {
	if !e.parent.shouldLog(sflog.LevelTrace) {
		return
	}
	e.inner.Tracef(format, args...)
}
func (e *levelFilteringEntry) Debugf(format string, args ...any) {
	if !e.parent.shouldLog(sflog.LevelDebug) {
		return
	}
	e.inner.Debugf(format, args...)
}
func (e *levelFilteringEntry) Infof(format string, args ...any) {
	if !e.parent.shouldLog(sflog.LevelInfo) {
		return
	}
	e.inner.Infof(format, args...)
}
func (e *levelFilteringEntry) Warnf(format string, args ...any) {
	if !e.parent.shouldLog(sflog.LevelWarn) {
		return
	}
	e.inner.Warnf(format, args...)
}
func (e *levelFilteringEntry) Errorf(format string, args ...any) {
	if !e.parent.shouldLog(sflog.LevelError) {
		return
	}
	e.inner.Errorf(format, args...)
}

// Fatalf is never filtered: a fatal message must always reach the inner entry.
func (e *levelFilteringEntry) Fatalf(format string, args ...any) {
	e.inner.Fatalf(format, args...)
}

// Implement all direct logging methods for entry
func (e *levelFilteringEntry) Trace(msg string) {
	if !e.parent.shouldLog(sflog.LevelTrace) {
		return
	}
	e.inner.Trace(msg)
}
func (e *levelFilteringEntry) Debug(msg string) {
	if !e.parent.shouldLog(sflog.LevelDebug) {
		return
	}
	e.inner.Debug(msg)
}
func (e *levelFilteringEntry) Info(msg string) {
	if !e.parent.shouldLog(sflog.LevelInfo) {
		return
	}
	e.inner.Info(msg)
}
func (e *levelFilteringEntry) Warn(msg string) {
	if !e.parent.shouldLog(sflog.LevelWarn) {
		return
	}
	e.inner.Warn(msg)
}
func (e *levelFilteringEntry) Error(msg string) {
	if !e.parent.shouldLog(sflog.LevelError) {
		return
	}
	e.inner.Error(msg)
}

// Fatal is never filtered: a fatal message must always reach the inner entry.
func (e *levelFilteringEntry) Fatal(msg string) {
	e.inner.Fatal(msg)
}

================================================
FILE: internal/logger/optional_interfaces.go
================================================
package logger

import "os"

// EasyLoggingSupport is an optional interface for loggers that support easy_logging.go
// functionality. This is used for file-based logging configuration.
type EasyLoggingSupport interface {
	// CloseFileOnLoggerReplace closes the logger's file handle when logger is replaced
	CloseFileOnLoggerReplace(file *os.File) error
}

// Unwrapper is a common interface for unwrapping wrapped loggers
type Unwrapper interface {
	Unwrap() any
}

================================================
FILE: internal/logger/proxy.go
================================================
package logger

import (
	"context"
	"fmt"
	"io"
	"log/slog"

	"github.com/snowflakedb/gosnowflake/v2/sflog"
)

// Proxy is a proxy that delegates all calls to the global logger.
// This ensures a single source of truth for the current logger.
type Proxy struct{}

// Compile-time verification that Proxy implements SFLogger
var _ sflog.SFLogger = (*Proxy)(nil)

// Tracef implements the Tracef method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) Tracef(format string, args ...any) {
	GetLogger().Tracef(format, args...)
}

// Debugf implements the Debugf method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) Debugf(format string, args ...any) {
	GetLogger().Debugf(format, args...)
}

// Infof implements the Infof method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) Infof(format string, args ...any) {
	GetLogger().Infof(format, args...)
}

// Warnf implements the Warnf method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) Warnf(format string, args ...any) {
	GetLogger().Warnf(format, args...)
}

// Errorf implements the Errorf method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) Errorf(format string, args ...any) {
	GetLogger().Errorf(format, args...)
}

// Fatalf implements the Fatalf method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) Fatalf(format string, args ...any) {
	GetLogger().Fatalf(format, args...)
}

// Trace implements the Trace method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) Trace(msg string) {
	GetLogger().Trace(msg)
}

// Debug implements the Debug method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) Debug(msg string) {
	GetLogger().Debug(msg)
}

// Info implements the Info method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) Info(msg string) {
	GetLogger().Info(msg)
}

// Warn implements the Warn method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) Warn(msg string) {
	GetLogger().Warn(msg)
}

// Error implements the Error method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) Error(msg string) {
	GetLogger().Error(msg)
}

// Fatal implements the Fatal method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) Fatal(msg string) {
	GetLogger().Fatal(msg)
}

// WithField implements the WithField method of the SFLogger interface by delegating to the global logger.
// Note: the returned entry is bound to whichever logger is global at call time.
func (p *Proxy) WithField(key string, value any) sflog.LogEntry {
	return GetLogger().WithField(key, value)
}

// WithFields implements the WithFields method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) WithFields(fields map[string]any) sflog.LogEntry {
	return GetLogger().WithFields(fields)
}

// WithContext implements the WithContext method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) WithContext(ctx context.Context) sflog.LogEntry {
	return GetLogger().WithContext(ctx)
}

// SetLogLevel implements the SetLogLevel method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) SetLogLevel(level string) error {
	return GetLogger().SetLogLevel(level)
}

// SetLogLevelInt implements the SetLogLevelInt method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) SetLogLevelInt(level sflog.Level) error {
	return GetLogger().SetLogLevelInt(level)
}

// GetLogLevel implements the GetLogLevel method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) GetLogLevel() string {
	return GetLogger().GetLogLevel()
}

// GetLogLevelInt implements the GetLogLevelInt method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) GetLogLevelInt() sflog.Level {
	return GetLogger().GetLogLevelInt()
}

// SetOutput implements the SetOutput method of the SFLogger interface by delegating to the global logger.
func (p *Proxy) SetOutput(output io.Writer) {
	GetLogger().SetOutput(output)
}

// SetHandler implements SFSlogLogger interface for advanced slog handler configuration.
// This delegates to the underlying logger if it supports SetHandler.
func (p *Proxy) SetHandler(handler slog.Handler) error { logger := GetLogger() if sl, ok := logger.(sflog.SFSlogLogger); ok { return sl.SetHandler(handler) } return fmt.Errorf("underlying logger does not support SetHandler") } // NewLoggerProxy creates a new logger proxy that delegates all calls // to the global logger managed by the internal package. func NewLoggerProxy() sflog.SFLogger { return &Proxy{} } ================================================ FILE: internal/logger/secret_detector.go ================================================ package logger import ( "regexp" ) const ( awsKeyPattern = `(?i)(aws_key_id|aws_secret_key|access_key_id|secret_access_key)\s*=\s*'([^']+)'` awsTokenPattern = `(?i)(accessToken|tempToken|keySecret)"\s*:\s*"([a-z0-9/+]{32,}={0,2})"` sasTokenPattern = `(?i)(sig|signature|AWSAccessKeyId|password|passcode)=(?P[a-z0-9%/+]{16,})` privateKeyPattern = `(?im)-----BEGIN PRIVATE KEY-----\\n([a-z0-9/+=\\n]{32,})\\n-----END PRIVATE KEY-----` // pragma: allowlist secret privateKeyDataPattern = `(?i)"privateKeyData": "([a-z0-9/+=\\n]{10,})"` privateKeyParamPattern = `(?i)privateKey=([A-Za-z0-9/+=_%-]+)(&|$|\s)` connectionTokenPattern = `(?i)(token|assertion content)([\'\"\s:=]+)([a-z0-9=/_\-\+]{8,})` passwordPattern = `(?i)(password|pwd)([\'\"\s:=]+)([a-z0-9!\"#\$%&\\\'\(\)\*\+\,-\./:;<=>\?\@\[\]\^_\{\|\}~]{8,})` dsnPasswordPattern = `([^/:]+):([^@/:]{3,})@` // Matches user:password@host format in DSN strings clientSecretPattern = `(?i)(clientSecret)([\'\"\s:= ]+)([a-z0-9!\"#\$%&\\\'\(\)\*\+\,-\./:;<=>\?\@\[\]\^_\{\|\}~]+)` jwtTokenPattern = `(?i)(jwt|bearer)[\s:=]*([a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+)` // pragma: allowlist secret ) type patternAndReplace struct { regex *regexp.Regexp replacement string } var secretDetectorPatterns = []patternAndReplace{ {regexp.MustCompile(awsKeyPattern), "$1=****$2"}, {regexp.MustCompile(awsTokenPattern), "${1}XXXX$2"}, {regexp.MustCompile(sasTokenPattern), "${1}****$2"}, 
{regexp.MustCompile(privateKeyPattern), "-----BEGIN PRIVATE KEY-----\\\\\\\\nXXXX\\\\\\\\n-----END PRIVATE KEY-----"}, // pragma: allowlist secret {regexp.MustCompile(privateKeyDataPattern), `"privateKeyData": "XXXX"`}, {regexp.MustCompile(privateKeyParamPattern), "privateKey=****$2"}, {regexp.MustCompile(connectionTokenPattern), "$1${2}****"}, {regexp.MustCompile(passwordPattern), "$1${2}****"}, {regexp.MustCompile(dsnPasswordPattern), "$1:****@"}, {regexp.MustCompile(clientSecretPattern), "$1${2}****"}, {regexp.MustCompile(jwtTokenPattern), "$1 ****"}, } // MaskSecrets masks secrets in text (exported for use by main package and secret masking logger) func MaskSecrets(text string) (masked string) { res := text for _, pattern := range secretDetectorPatterns { res = pattern.regex.ReplaceAllString(res, pattern.replacement) } return res } ================================================ FILE: internal/logger/secret_detector_test.go ================================================ package logger import ( "fmt" "testing" "time" "github.com/golang-jwt/jwt/v5" ) const ( longToken = "_Y1ZNETTn5/qfUWj3Jedby7gipDzQs=UKyJH9DS=nFzzWnfZKGV+C7GopWC" + // pragma: allowlist secret "GD4LjOLLFZKOE26LXHDt3pTi4iI1qwKuSpf/FmClCMBSissVsU3Ei590FP0lPQQhcSG" + // pragma: allowlist secret "cDu69ZL_1X6e9h5z62t/iY7ZkII28n2qU=nrBJUgPRCIbtJQkVJXIuOHjX4G5yUEKjZ" + // pragma: allowlist secret "BAx4w6=_lqtt67bIA=o7D=oUSjfywsRFoloNIkBPXCwFTv+1RVUHgVA2g8A9Lw5XdJY" + // pragma: allowlist secret "uI8vhg=f0bKSq7AhQ2Bh" randomPassword = `Fh[+2J~AcqeqW%?` falsePositiveToken = "2020-04-30 23:06:04,069 - MainThread auth.py:397" + " - write_temporary_credential() - DEBUG - no ID token is given when " + "try to store temporary credential" ) // generateTestJWT creates a test JWT token for masking tests using the JWT library func generateTestJWT(t *testing.T) string { // Create claims for the test JWT claims := jwt.MapClaims{ "sub": "test123", "name": "Test User", "exp": time.Now().Add(time.Hour).Unix(), 
		"iat": time.Now().Unix(),
	}

	// Create the token with HS256 signing method
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)

	// Sign the token with a test secret
	testSecret := []byte("test-secret-for-masking-validation")
	tokenString, err := token.SignedString(testSecret)
	if err != nil {
		// Fallback to a simple test JWT if signing fails
		t.Fatalf("Failed to generate test JWT: %s", err)
	}
	return tokenString
}

// TestSecretsDetector drives MaskSecrets through token, password, client-secret,
// and false-positive inputs, asserting the exact masked output for each.
func TestSecretsDetector(t *testing.T) {
	testCases := []struct {
		name     string
		input    string
		expected string
	}{
		// Token masking tests
		{"Token with equals",
			fmt.Sprintf("Token =%s", longToken), "Token =****"},
		{"idToken with colon space",
			fmt.Sprintf("idToken : %s", longToken), "idToken : ****"},
		{"sessionToken with colon space",
			fmt.Sprintf("sessionToken : %s", longToken), "sessionToken : ****"},
		{"masterToken with colon space",
			fmt.Sprintf("masterToken : %s", longToken), "masterToken : ****"},
		{"accessToken with colon space",
			fmt.Sprintf("accessToken : %s", longToken), "accessToken : ****"},
		{"refreshToken with colon space",
			fmt.Sprintf("refreshToken : %s", longToken), "refreshToken : ****"},
		{"programmaticAccessToken with colon space",
			fmt.Sprintf("programmaticAccessToken : %s", longToken), "programmaticAccessToken : ****"},
		{"programmatic_access_token with colon space",
			fmt.Sprintf("programmatic_access_token : %s", longToken), "programmatic_access_token : ****"},
		{"JWT - with Bearer prefix",
			fmt.Sprintf("Bearer %s", generateTestJWT(t)), "Bearer ****"},
		{"JWT - with JWT prefix",
			fmt.Sprintf("JWT %s", generateTestJWT(t)), "JWT ****"},
		// Password masking tests
		{"password with colon",
			fmt.Sprintf("password:%s", randomPassword), "password:****"},
		{"PASSWORD uppercase with colon",
			fmt.Sprintf("PASSWORD:%s", randomPassword), "PASSWORD:****"},
		{"PaSsWoRd mixed case with colon",
			fmt.Sprintf("PaSsWoRd:%s", randomPassword), "PaSsWoRd:****"},
		{"password with equals and spaces",
			fmt.Sprintf("password = %s", randomPassword), "password = ****"},
		{"pwd with colon",
			fmt.Sprintf("pwd:%s", randomPassword), "pwd:****"},
		// Mixed token and password tests
		{
			"token and password mixed",
			fmt.Sprintf("token=%s foo bar baz password:%s", longToken, randomPassword),
			"token=**** foo bar baz password:****",
		},
		{
			"PWD and TOKEN mixed",
			fmt.Sprintf("PWD = %s blah blah blah TOKEN:%s", randomPassword, longToken),
			"PWD = **** blah blah blah TOKEN:****",
		},
		// Client secret tests
		{"clientSecret with values",
			"clientSecret abc oauthClientSECRET=def",
			"clientSecret **** oauthClientSECRET=****"},
		// False positive test
		{"false positive should not be masked",
			falsePositiveToken, falsePositiveToken},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			result := MaskSecrets(tc.input)
			if result != tc.expected {
				t.Errorf("expected %q to be equal to %q but was not", result, tc.expected)
			}
		})
	}
}

================================================ FILE: internal/logger/secret_masking.go ================================================

package logger

import (
	"context"
	"fmt"
	"github.com/snowflakedb/gosnowflake/v2/sflog"
	"io"
	"log/slog"
)

// secretMaskingLogger wraps any logger implementation and ensures
// all log messages have secrets masked before being passed to the inner logger.
type secretMaskingLogger struct {
	inner SFLogger
}

// Compile-time verification that secretMaskingLogger implements SFLogger
var _ SFLogger = (*secretMaskingLogger)(nil)

// Unwrap returns the inner logger (for introspection by easy_logging)
func (l *secretMaskingLogger) Unwrap() any {
	return l.inner
}

// newSecretMaskingLogger creates a new secret masking wrapper around the provided logger.
func newSecretMaskingLogger(inner SFLogger) *secretMaskingLogger {
	if inner == nil {
		panic("inner logger cannot be nil")
	}
	return &secretMaskingLogger{inner: inner}
}

// Helper methods for masking

// maskValue masks a field value. Strings are masked directly; other types are
// rendered with fmt.Sprint, and the original (typed) value is returned when no
// secret was found so non-string field types are preserved.
func (l *secretMaskingLogger) maskValue(value any) any {
	if str, ok := value.(string); ok {
		return l.maskString(str)
	}
	// For other types, convert to string, mask, but return original type if no secrets
	strVal := fmt.Sprint(value)
	masked := l.maskString(strVal)
	if masked != strVal {
		return masked // Secrets found and masked
	}
	return value // No secrets, return original
}

func (l *secretMaskingLogger) maskString(value string) string {
	return MaskSecrets(value)
}

// Implement all formatted logging methods (*f variants).
// Each formats first, masks the result, and hands it to the inner logger's
// non-formatted method so the masked text is not re-interpreted as a format string.
func (l *secretMaskingLogger) Tracef(format string, args ...any) {
	message := fmt.Sprintf(format, args...)
	maskedMessage := l.maskString(message)
	l.inner.Trace(maskedMessage)
}

func (l *secretMaskingLogger) Debugf(format string, args ...any) {
	message := fmt.Sprintf(format, args...)
	maskedMessage := l.maskString(message)
	l.inner.Debug(maskedMessage)
}

func (l *secretMaskingLogger) Infof(format string, args ...any) {
	message := fmt.Sprintf(format, args...)
	maskedMessage := l.maskString(message)
	l.inner.Info(maskedMessage)
}

func (l *secretMaskingLogger) Warnf(format string, args ...any) {
	message := fmt.Sprintf(format, args...)
	maskedMessage := l.maskString(message)
	l.inner.Warn(maskedMessage)
}

func (l *secretMaskingLogger) Errorf(format string, args ...any) {
	message := fmt.Sprintf(format, args...)
	maskedMessage := l.maskString(message)
	l.inner.Error(maskedMessage)
}

func (l *secretMaskingLogger) Fatalf(format string, args ...any) {
	message := fmt.Sprintf(format, args...)
	maskedMessage := l.maskString(message)
	l.inner.Fatal(maskedMessage)
}

// Implement all direct logging methods
func (l *secretMaskingLogger) Trace(msg string) {
	l.inner.Trace(l.maskString(msg))
}

func (l *secretMaskingLogger) Debug(msg string) {
	l.inner.Debug(l.maskString(msg))
}

func (l *secretMaskingLogger) Info(msg string) {
	l.inner.Info(l.maskString(msg))
}

func (l *secretMaskingLogger) Warn(msg string) {
	l.inner.Warn(l.maskString(msg))
}

func (l *secretMaskingLogger) Error(msg string) {
	l.inner.Error(l.maskString(msg))
}

func (l *secretMaskingLogger) Fatal(msg string) {
	l.inner.Fatal(l.maskString(msg))
}

// Implement structured logging methods
// Note: These return LogEntry to maintain compatibility with the adapter layer
func (l *secretMaskingLogger) WithField(key string, value any) LogEntry {
	maskedValue := l.maskValue(value)
	result := l.inner.WithField(key, maskedValue)
	return &secretMaskingEntry{
		inner:  result,
		parent: l,
	}
}

func (l *secretMaskingLogger) WithFields(fields map[string]any) LogEntry {
	maskedFields := make(map[string]any, len(fields))
	for k, v := range fields {
		maskedFields[k] = l.maskValue(v)
	}
	result := l.inner.WithFields(maskedFields)
	return &secretMaskingEntry{
		inner:  result,
		parent: l,
	}
}

// WithContext wraps the inner entry; context values themselves are not masked
// here — only messages/fields logged through the returned entry are.
func (l *secretMaskingLogger) WithContext(ctx context.Context) LogEntry {
	result := l.inner.WithContext(ctx)
	return &secretMaskingEntry{
		inner:  result,
		parent: l,
	}
}

// Delegate configuration methods
func (l *secretMaskingLogger) SetLogLevel(level string) error {
	return l.inner.SetLogLevel(level)
}

func (l *secretMaskingLogger) SetLogLevelInt(level sflog.Level) error {
	return l.inner.SetLogLevelInt(level)
}

func (l *secretMaskingLogger) GetLogLevel() string {
	return l.inner.GetLogLevel()
}

func (l *secretMaskingLogger) GetLogLevelInt() sflog.Level {
	return l.inner.GetLogLevelInt()
}

func (l *secretMaskingLogger) SetOutput(output io.Writer) {
	l.inner.SetOutput(output)
}

// SetHandler delegates to inner logger's SetHandler (for slog handler
// configuration)
func (l *secretMaskingLogger) SetHandler(handler slog.Handler) error {
	if logger, ok := l.inner.(sflog.SFSlogLogger); ok {
		return logger.SetHandler(handler)
	}
	return fmt.Errorf("inner logger does not support SetHandler")
}

// secretMaskingEntry wraps a log entry and masks all secrets.
type secretMaskingEntry struct {
	inner  LogEntry
	parent *secretMaskingLogger
}

// Compile-time verification that secretMaskingEntry implements LogEntry
var _ LogEntry = (*secretMaskingEntry)(nil)

// Implement all formatted logging methods (*f variants)
func (e *secretMaskingEntry) Tracef(format string, args ...any) {
	message := fmt.Sprintf(format, args...)
	maskedMessage := MaskSecrets(message)
	e.inner.Trace(maskedMessage)
}

func (e *secretMaskingEntry) Debugf(format string, args ...any) {
	message := fmt.Sprintf(format, args...)
	maskedMessage := MaskSecrets(message)
	e.inner.Debug(maskedMessage)
}

func (e *secretMaskingEntry) Infof(format string, args ...any) {
	message := fmt.Sprintf(format, args...)
	maskedMessage := MaskSecrets(message)
	e.inner.Info(maskedMessage)
}

func (e *secretMaskingEntry) Warnf(format string, args ...any) {
	message := fmt.Sprintf(format, args...)
	maskedMessage := MaskSecrets(message)
	e.inner.Warn(maskedMessage)
}

func (e *secretMaskingEntry) Errorf(format string, args ...any) {
	message := fmt.Sprintf(format, args...)
	maskedMessage := MaskSecrets(message)
	e.inner.Error(maskedMessage)
}

func (e *secretMaskingEntry) Fatalf(format string, args ...any) {
	message := fmt.Sprintf(format, args...)
	maskedMessage := MaskSecrets(message)
	e.inner.Fatal(maskedMessage)
}

// Implement all direct logging methods
func (e *secretMaskingEntry) Trace(msg string) {
	e.inner.Trace(e.parent.maskString(msg))
}

func (e *secretMaskingEntry) Debug(msg string) {
	e.inner.Debug(e.parent.maskString(msg))
}

func (e *secretMaskingEntry) Info(msg string) {
	e.inner.Info(e.parent.maskString(msg))
}

func (e *secretMaskingEntry) Warn(msg string) {
	e.inner.Warn(e.parent.maskString(msg))
}

func (e *secretMaskingEntry) Error(msg string) {
	e.inner.Error(e.parent.maskString(msg))
}

func (e *secretMaskingEntry) Fatal(msg string) {
	e.inner.Fatal(e.parent.maskString(msg))
}

================================================ FILE: internal/logger/secret_masking_test.go ================================================

package logger

import (
	"context"
	"github.com/snowflakedb/gosnowflake/v2/sflog"
	"io"
	"testing"
)

// mockLogger is a simple logger implementation for testing.
// Only Info records the message; all other methods are no-ops.
type mockLogger struct {
	lastMessage string
}

func (m *mockLogger) Tracef(format string, args ...any) {}
func (m *mockLogger) Debugf(format string, args ...any) {}
func (m *mockLogger) Infof(format string, args ...any)  {}
func (m *mockLogger) Warnf(format string, args ...any)  {}
func (m *mockLogger) Errorf(format string, args ...any) {}
func (m *mockLogger) Fatalf(format string, args ...any) {}
func (m *mockLogger) Trace(msg string)                  {}
func (m *mockLogger) Debug(msg string)                  {}
func (m *mockLogger) Info(msg string)                   { m.lastMessage = msg }
func (m *mockLogger) Warn(msg string)                   {}
func (m *mockLogger) Error(msg string)                  {}
func (m *mockLogger) Fatal(msg string)                  {}
func (m *mockLogger) WithField(key string, value any) LogEntry {
	return m
}
func (m *mockLogger) WithFields(fields map[string]any) LogEntry {
	return m
}
func (m *mockLogger) WithContext(ctx context.Context) LogEntry {
	return m
}
func (m *mockLogger) SetLogLevel(level string) error          { return nil }
func (m *mockLogger) SetLogLevelInt(level sflog.Level) error  { return nil }
func (m *mockLogger) GetLogLevel() string                     { return "info" }
func (m *mockLogger) GetLogLevelInt() sflog.Level             { return sflog.LevelInfo }
func (m *mockLogger) SetOutput(output io.Writer)              {}

// Compile-time verification that mockLogger implements SFLogger
var _ SFLogger = (*mockLogger)(nil)

// TestSecretMaskingLogger checks that a formatted message passing through the
// masking wrapper arrives at the inner logger with the secret replaced.
func TestSecretMaskingLogger(t *testing.T) {
	mock := &mockLogger{}
	logger := newSecretMaskingLogger(mock)

	// Use a real password pattern that will be masked
	logger.Infof("test message with %s", "password:secret123")

	// Secret masking logger formats the message, masks it, then passes with "%s" format
	if mock.lastMessage != "test message with password:****" {
		t.Errorf("Expected format string to be '%%s', got %s", mock.lastMessage)
	}

	// The masked message should have been passed as the first arg
	// (We can't check this with the current mock, but we verified it works in other tests)
}

================================================ FILE: internal/logger/slog_handler.go ================================================

package logger

import (
	"context"
	"github.com/snowflakedb/gosnowflake/v2/sflog"
	"log/slog"
)

// snowflakeHandler wraps slog.Handler and adds context field extraction
type snowflakeHandler struct {
	inner    slog.Handler
	levelVar *slog.LevelVar
}

func newSnowflakeHandler(inner slog.Handler, level sflog.Level) *snowflakeHandler {
	levelVar := &slog.LevelVar{}
	levelVar.Set(slog.Level(level))
	return &snowflakeHandler{
		inner:    inner,
		levelVar: levelVar,
	}
}

// Enabled checks if the handler is enabled for the given level
func (h *snowflakeHandler) Enabled(ctx context.Context, level slog.Level) bool {
	return h.inner.Enabled(ctx, level)
}

// Handle processes a log record
func (h *snowflakeHandler) Handle(ctx context.Context, r slog.Record) error {
	// NOTE: Context field extraction is NOT done here because:
	// - If WithContext() was used, fields are already added to the logger via .With()
	// - If WithContext() was not used, the context passed here is typically context.Background()
	// and
	// wouldn't have any fields anyway
	// Secret masking is already done in secretMaskingLogger wrapper
	return h.inner.Handle(ctx, r)
}

// WithAttrs creates a new handler with additional attributes
func (h *snowflakeHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	return &snowflakeHandler{
		inner:    h.inner.WithAttrs(attrs),
		levelVar: h.levelVar,
	}
}

// WithGroup creates a new handler with a group
func (h *snowflakeHandler) WithGroup(name string) slog.Handler {
	return &snowflakeHandler{
		inner:    h.inner.WithGroup(name),
		levelVar: h.levelVar,
	}
}

================================================ FILE: internal/logger/slog_logger.go ================================================

package logger

import (
	"context"
	"fmt"
	"github.com/snowflakedb/gosnowflake/v2/sflog"
	"io"
	"log/slog"
	"os"
	"path"
	"runtime"
	"strings"
	"sync"
	"time"
)

// formatSource formats caller information for logging
func formatSource(frame *runtime.Frame) (string, string) {
	return path.Base(frame.Function), fmt.Sprintf("%s:%d", path.Base(frame.File), frame.Line)
}

// rawLogger implements SFLogger using slog
type rawLogger struct {
	inner   *slog.Logger
	handler *snowflakeHandler
	level   sflog.Level
	enabled bool // For OFF level support
	file    *os.File
	output  io.Writer
	mu      sync.Mutex
}

// Compile-time verification that rawLogger implements SFLogger
var _ SFLogger = (*rawLogger)(nil)

// newRawLogger creates the internal default logger using slog
func newRawLogger() SFLogger {
	level := sflog.LevelInfo
	opts := createOpts(slog.Level(level))
	textHandler := slog.NewTextHandler(os.Stderr, opts)
	handler := newSnowflakeHandler(textHandler, level)
	slogLogger := slog.New(handler)
	return &rawLogger{
		inner:   slogLogger,
		handler: handler,
		level:   level,
		enabled: true,
		output:  os.Stderr,
	}
}

// isEnabled checks if logging is enabled (for OFF level)
func (log *rawLogger) isEnabled() bool {
	log.mu.Lock()
	defer log.mu.Unlock()
	return log.enabled
}

// SetLogLevel sets the log level
func (log *rawLogger) SetLogLevel(level string) error {
	upperLevel, err := sflog.ParseLevel(strings.ToUpper(level))
	if err != nil {
		return fmt.Errorf("error while setting log level. %v", err)
	}
	if upperLevel == sflog.LevelOff {
		log.mu.Lock()
		log.level = sflog.LevelOff
		log.enabled = false
		log.mu.Unlock()
		return nil
	}
	log.mu.Lock()
	log.enabled = true
	log.level = upperLevel
	log.mu.Unlock()
	return nil
}

// SetLogLevelInt sets the log level from a numeric level.
// NOTE(review): unlike SetLogLevel, this does not touch the enabled flag — it
// cannot re-enable a logger turned off via SetLogLevel("OFF"), nor does it
// disable output for LevelOff. Confirm whether this asymmetry is intended.
func (log *rawLogger) SetLogLevelInt(level sflog.Level) error {
	log.mu.Lock()
	defer log.mu.Unlock()
	_, err := sflog.LevelToString(level)
	if err != nil {
		return fmt.Errorf("invalid log level: %d", level)
	}
	log.level = level
	return nil
}

// GetLogLevel returns the current log level
// NOTE(review): log.level is read without acquiring log.mu here, unlike
// GetLogLevelInt — possible data race with concurrent SetLogLevel calls.
func (log *rawLogger) GetLogLevel() string {
	if levelStr, err := sflog.LevelToString(log.level); err == nil {
		return levelStr
	}
	return "unknown"
}

func (log *rawLogger) GetLogLevelInt() sflog.Level {
	log.mu.Lock()
	defer log.mu.Unlock()
	return log.level
}

// SetOutput sets the output writer
// NOTE(review): this always rebuilds a TextHandler, discarding any custom
// handler previously installed via SetHandler — confirm this is intended.
func (log *rawLogger) SetOutput(output io.Writer) {
	log.mu.Lock()
	defer log.mu.Unlock()
	log.output = output
	// Create new handler with new output
	opts := createOpts(slog.Level(log.level))
	textHandler := slog.NewTextHandler(output, opts)
	log.handler = newSnowflakeHandler(textHandler, log.level)
	log.inner = slog.New(log.handler)
}

// createOpts builds handler options: RFC3339Nano timestamps and source
// locations shortened to "file.go:line" via formatSource.
func createOpts(level slog.Level) *slog.HandlerOptions {
	opts := &slog.HandlerOptions{
		Level:     level,
		AddSource: true,
		ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
			if a.Key == slog.TimeKey {
				if t, ok := a.Value.Any().(time.Time); ok {
					return slog.String(slog.TimeKey, t.Format(time.RFC3339Nano))
				}
			}
			if a.Key == slog.SourceKey {
				if src, ok := a.Value.Any().(*slog.Source); ok {
					frame := &runtime.Frame{
						File:     src.File,
						Line:     src.Line,
						Function: src.Function,
					}
					_, location := formatSource(frame)
					return slog.String(slog.SourceKey, location)
				}
			}
			return a
		},
	}
	return opts
}

// SetHandler sets a custom slog handler (implements SFSlogLogger interface)
// The provided handler will be wrapped with snowflakeHandler to preserve context
// extraction.
// Secret masking is handled at a higher level (secretMaskingLogger wrapper).
func (log *rawLogger) SetHandler(handler slog.Handler) error {
	log.mu.Lock()
	defer log.mu.Unlock()
	// Wrap user's handler with snowflakeHandler to preserve context extraction
	log.handler = newSnowflakeHandler(handler, log.level)
	log.inner = slog.New(log.handler)
	return nil
}

// logWithSkip logs a message at the given level, skipping 'skip' frames when determining source location.
// This is used internally to skip wrapper frames (levelFilteringLogger -> secretMaskingLogger -> rawLogger)
// and report the actual caller's location.
func (log *rawLogger) logWithSkip(skip int, level sflog.Level, msg string) {
	if !log.isEnabled() {
		return
	}
	var pcs [1]uintptr
	// Skip: runtime.Callers itself + logWithSkip + specified skip
	runtime.Callers(skip+2, pcs[:])
	r := slog.NewRecord(time.Now(), slog.Level(level), msg, pcs[0])
	_ = log.handler.Handle(context.Background(), r)
}

// Implement all formatted logging methods (*f variants)
// Skip depth = 3 assumes standard wrapper chain: levelFilteringLogger -> secretMaskingLogger -> rawLogger
// If wrapper chain changes, update this value. See TestSkipDepthWarning test.
func (log *rawLogger) Tracef(format string, args ...any) {
	log.logWithSkip(3, sflog.LevelTrace, fmt.Sprintf(format, args...))
}

func (log *rawLogger) Debugf(format string, args ...any) {
	log.logWithSkip(3, sflog.LevelDebug, fmt.Sprintf(format, args...))
}

func (log *rawLogger) Infof(format string, args ...any) {
	log.logWithSkip(3, sflog.LevelInfo, fmt.Sprintf(format, args...))
}

func (log *rawLogger) Warnf(format string, args ...any) {
	log.logWithSkip(3, sflog.LevelWarn, fmt.Sprintf(format, args...))
}

func (log *rawLogger) Errorf(format string, args ...any) {
	log.logWithSkip(3, sflog.LevelError, fmt.Sprintf(format, args...))
}

// Fatalf logs at fatal level and terminates the process.
func (log *rawLogger) Fatalf(format string, args ...any) {
	log.logWithSkip(3, sflog.LevelFatal, fmt.Sprintf(format, args...))
	os.Exit(1)
}

// Implement all direct logging methods
// Skip depth = 3 assumes standard wrapper chain: levelFilteringLogger -> secretMaskingLogger -> rawLogger
// If wrapper chain changes, update this value. See TestSkipDepthWarning test.
func (log *rawLogger) Trace(msg string) {
	log.logWithSkip(3, sflog.LevelTrace, msg)
}

func (log *rawLogger) Debug(msg string) {
	log.logWithSkip(3, sflog.LevelDebug, msg)
}

func (log *rawLogger) Info(msg string) {
	log.logWithSkip(3, sflog.LevelInfo, msg)
}

func (log *rawLogger) Warn(msg string) {
	log.logWithSkip(3, sflog.LevelWarn, msg)
}

func (log *rawLogger) Error(msg string) {
	log.logWithSkip(3, sflog.LevelError, msg)
}

// Fatal logs at fatal level and terminates the process.
func (log *rawLogger) Fatal(msg string) {
	log.logWithSkip(3, sflog.LevelFatal, msg)
	os.Exit(1)
}

// Structured logging methods
// Entries share the parent's enabled flag and mutex by pointer so a later
// SetLogLevel("OFF") on the logger also silences already-created entries.
func (log *rawLogger) WithField(key string, value any) LogEntry {
	return &slogEntry{
		logger:  log.inner.With(slog.Any(key, value)),
		enabled: &log.enabled,
		mu:      &log.mu,
	}
}

func (log *rawLogger) WithFields(fields map[string]any) LogEntry {
	attrs := make([]any, 0, len(fields)*2)
	for k, v := range fields {
		attrs = append(attrs, k, v)
	}
	return &slogEntry{
		logger:  log.inner.With(attrs...),
		enabled: &log.enabled,
		mu:      &log.mu,
	}
}

func (log *rawLogger) WithContext(ctx context.Context) LogEntry {
	if ctx == nil {
		return log
	}
	// Extract fields from context
	attrs := extractContextFields(ctx)
	if len(attrs) == 0 {
		return log
	}
	// Convert []slog.Attr to []any for With()
	// slog.Logger.With() can accept slog.Attr directly
	args := make([]any, len(attrs))
	for i, attr := range attrs {
		args[i] = attr
	}
	newLogger := log.inner.With(args...)
	return &slogEntry{
		logger:  newLogger,
		enabled: &log.enabled,
		mu:      &log.mu,
	}
}

// slogEntry implements LogEntry
type slogEntry struct {
	logger  *slog.Logger
	enabled *bool       // shared with the owning rawLogger
	mu      *sync.Mutex // guards *enabled; shared with the owning rawLogger
}

// Compile-time verification that slogEntry implements LogEntry
var _ LogEntry = (*slogEntry)(nil)

func (e *slogEntry) isEnabled() bool {
	e.mu.Lock()
	defer e.mu.Unlock()
	return *e.enabled
}

// logWithSkip logs a message at the given level, skipping 'skip' frames when determining source location.
func (e *slogEntry) logWithSkip(skip int, level sflog.Level, msg string) {
	if !e.isEnabled() {
		return
	}
	var pcs [1]uintptr
	runtime.Callers(skip+2, pcs[:]) // +2: runtime.Callers itself + logWithSkip
	r := slog.NewRecord(time.Now(), slog.Level(level), msg, pcs[0])
	_ = e.logger.Handler().Handle(context.Background(), r)
}

// Implement all formatted logging methods (*f variants)
// Skip depth = 3 assumes standard wrapper chain: levelFilteringEntry -> secretMaskingEntry -> slogEntry
// If wrapper chain changes, update this value. See TestSkipDepthWarning test.
func (e *slogEntry) Tracef(format string, args ...any) {
	e.logWithSkip(3, sflog.LevelTrace, fmt.Sprintf(format, args...))
}

func (e *slogEntry) Debugf(format string, args ...any) {
	e.logWithSkip(3, sflog.LevelDebug, fmt.Sprintf(format, args...))
}

func (e *slogEntry) Infof(format string, args ...any) {
	e.logWithSkip(3, sflog.LevelInfo, fmt.Sprintf(format, args...))
}

func (e *slogEntry) Warnf(format string, args ...any) {
	e.logWithSkip(3, sflog.LevelWarn, fmt.Sprintf(format, args...))
}

func (e *slogEntry) Errorf(format string, args ...any) {
	e.logWithSkip(3, sflog.LevelError, fmt.Sprintf(format, args...))
}

// Fatalf logs at fatal level and terminates the process.
func (e *slogEntry) Fatalf(format string, args ...any) {
	e.logWithSkip(3, sflog.LevelFatal, fmt.Sprintf(format, args...))
	os.Exit(1)
}

// Implement all direct logging methods
// Skip depth = 3 assumes standard wrapper chain: levelFilteringEntry -> secretMaskingEntry -> slogEntry
// If wrapper chain changes, update this value. See TestSkipDepthWarning test.
func (e *slogEntry) Trace(msg string) {
	e.logWithSkip(3, sflog.LevelTrace, msg)
}

func (e *slogEntry) Debug(msg string) {
	e.logWithSkip(3, sflog.LevelDebug, msg)
}

func (e *slogEntry) Info(msg string) {
	e.logWithSkip(3, sflog.LevelInfo, msg)
}

func (e *slogEntry) Warn(msg string) {
	e.logWithSkip(3, sflog.LevelWarn, msg)
}

func (e *slogEntry) Error(msg string) {
	e.logWithSkip(3, sflog.LevelError, msg)
}

// Fatal logs at fatal level and terminates the process.
func (e *slogEntry) Fatal(msg string) {
	e.logWithSkip(3, sflog.LevelFatal, msg)
	os.Exit(1)
}

// Helper methods for internal use and easy_logging support
func (log *rawLogger) closeFileOnLoggerReplace(file *os.File) error {
	log.mu.Lock()
	defer log.mu.Unlock()
	if log.file != nil && log.file != file {
		return fmt.Errorf("could not set a file to close on logger reset because there were already set one")
	}
	log.file = file
	return nil
}

// CloseFileOnLoggerReplace is exported for easy_logging support
func (log *rawLogger) CloseFileOnLoggerReplace(file *os.File) error {
	return log.closeFileOnLoggerReplace(file)
}

//
// ReplaceGlobalLogger closes the current logger's file (for easy_logging support)
// The actual global logger replacement is handled by the main package
func (log *rawLogger) ReplaceGlobalLogger(newLogger any) {
	// NOTE(review): newLogger is unused here; only the previously registered
	// file handle is closed.
	if log.file != nil {
		_ = log.file.Close()
	}
}

// Ensure rawLogger implements SFLogger
var _ SFLogger = (*rawLogger)(nil)

================================================ FILE: internal/logger/source_location_test.go ================================================

package logger

import (
	"bytes"
	"strings"
	"testing"
)

// IMPORTANT: The skip depth values in rawLogger and slogEntry assume the standard wrapper chain:
//   For logger methods: levelFilteringLogger -> secretMaskingLogger -> rawLogger (skip=3)
//   For entry methods: levelFilteringEntry -> secretMaskingEntry -> slogEntry (skip=3)
//
// These tests verify the standard configuration. If you add or remove wrapper layers, you MUST update:
//   - internal/logger/slog_logger.go: rawLogger methods (currently skip=3)
//   - internal/logger/slog_logger.go: slogEntry methods (currently skip=3)

// TestSourceLocationWithLevelFiltering verifies that source location is correct
// with the standard wrapper chain: levelFilteringLogger -> secretMaskingLogger -> rawLogger
func TestSourceLocationWithLevelFiltering(t *testing.T) {
	innerLogger := newRawLogger()
	var buf bytes.Buffer
	innerLogger.SetOutput(&buf)
	_ = innerLogger.SetLogLevel("debug")

	// Build the standard wrapper chain
	masked := newSecretMaskingLogger(innerLogger)
	filtered := newLevelFilteringLogger(masked)

	filtered.Debug("test message") // Line 31 - This line should appear in source location

	output := buf.String()
	// Check that the source location points to this test file, not the wrappers
	if !strings.Contains(output, "source_location_test.go") {
		t.Errorf("Expected source location to contain 'source_location_test.go', got: %s", output)
	}
	if strings.Contains(output, "level_filtering.go") {
		t.Errorf("Source location should not contain 'level_filtering.go', got: %s", output)
	}
	if strings.Contains(output, "secret_masking.go") {
		t.Errorf("Source location should not contain 'secret_masking.go', got: %s", output)
	}
}

// TestSourceLocationWithDebugf verifies formatted logging also reports correct source
func TestSourceLocationWithDebugf(t *testing.T) {
	innerLogger := newRawLogger()
	var buf bytes.Buffer
	innerLogger.SetOutput(&buf)
	_ = innerLogger.SetLogLevel("debug")

	// Build the standard wrapper chain
	masked := newSecretMaskingLogger(innerLogger)
	filtered := newLevelFilteringLogger(masked)

	filtered.Debugf("formatted message: %s", "test") // Line 58 - This line should appear

	output := buf.String()
	if !strings.Contains(output, "source_location_test.go") {
		t.Errorf("Expected source location to contain 'source_location_test.go', got: %s", output)
	}
	if strings.Contains(output, "level_filtering.go") || strings.Contains(output, "secret_masking.go") {
		t.Errorf("Source location should not contain wrapper files, got: %s", output)
	}
}

// TestSourceLocationWithEntry verifies that structured logging (WithField) also works correctly
func TestSourceLocationWithEntry(t *testing.T) {
	innerLogger := newRawLogger()
	var buf bytes.Buffer
	innerLogger.SetOutput(&buf)
	_ = innerLogger.SetLogLevel("debug")

	// Build the standard wrapper chain
	masked := newSecretMaskingLogger(innerLogger)
	filtered := newLevelFilteringLogger(masked)

	filtered.WithField("key", "value").Debug("entry message") // Line 82 - This line should appear

	output := buf.String()
	if !strings.Contains(output, "source_location_test.go") {
		t.Errorf("Expected source location to contain 'source_location_test.go', got: %s", output)
	}
	// Also verify the field is present
	if !strings.Contains(output, "key=value") {
		t.Errorf("Expected output to contain 'key=value', got: %s", output)
	}
}

// TestSkipDepthWarning documents the skip depth assumption and fails if wrappers change
// This test intentionally checks implementation details to warn developers when skip depths need updating.
func TestSkipDepthWarning(t *testing.T) {
	innerLogger := newRawLogger()
	var buf bytes.Buffer
	innerLogger.SetOutput(&buf)
	_ = innerLogger.SetLogLevel("debug")

	// Build the expected standard wrapper chain
	masked := newSecretMaskingLogger(innerLogger)
	filtered := newLevelFilteringLogger(masked)

	// Log from this test
	filtered.Debug("skip depth test") // Line 102 - This line should appear in source location

	output := buf.String()
	if !strings.Contains(output, "source_location_test.go:102") {
		t.Errorf(`
Skip depth appears incorrect!
Expected source location: source_location_test.go:102
Got: %s

If you added/removed a wrapper layer, update the skip values in:
- internal/logger/slog_logger.go: rawLogger methods (currently skip=3)
- internal/logger/slog_logger.go: slogEntry methods (currently skip=3)

Current wrapper chain for logger methods:
Driver code -> levelFilteringLogger -> secretMaskingLogger -> rawLogger
Current wrapper chain for entry methods:
Driver code -> levelFilteringEntry -> secretMaskingEntry -> slogEntry
`, output)
	}
}

================================================ FILE: internal/os/libc_info.go ================================================

package os

import (
	"bufio"
	"debug/elf"
	"io"
	"os"
	"regexp"
	"strconv"
	"strings"
	"sync"
)

var (
	libcInfo     LibcInfo
	libcInfoOnce sync.Once
)

// LibcInfo contains information about the C standard library in use.
type LibcInfo struct {
	Family  string // "glibc", "musl", or "" if not detected
	Version string // e.g., "2.31", "1.2.4", or "" if not determined
}

// parseProcMapsForLibc scans the contents of /proc/self/maps and returns
// the libc family ("glibc" or "musl") and the filesystem path to the mapped library.
func parseProcMapsForLibc(r io.Reader) (family string, libcPath string) { scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() // /proc/self/maps format: addr perms offset dev inode pathname fields := strings.Fields(line) if len(fields) < 6 { continue } path := fields[len(fields)-1] if strings.Contains(path, "musl") { return "musl", path } if strings.Contains(path, "libc.so.6") { return "glibc", path } } return "", "" } var glibcVersionPattern = regexp.MustCompile(`^GLIBC_(\d+\.\d+(?:\.\d+)?)$`) // glibcVersionFromELF opens the given ELF file (libc.so.6) and extracts the // glibc version from its SHT_GNU_verdef section via DynamicVersions(). // It returns the highest GLIBC_x.y[.z] version found. func glibcVersionFromELF(path string) string { f, err := elf.Open(path) if err != nil { return "" } defer func() { _ = f.Close() }() versions, err := f.DynamicVersions() if err != nil { return "" } var best string for _, v := range versions { m := glibcVersionPattern.FindStringSubmatch(v.Name) if m != nil { if best == "" || compareVersions(m[1], best) > 0 { best = m[1] } } } return best } var muslVersionPattern = regexp.MustCompile(`Version (\d+\.\d+\.\d+)`) // muslVersionFromBinary reads the musl library binary and searches for the // embedded version string pattern "Version X.Y.Z". func muslVersionFromBinary(path string) string { f, err := os.Open(path) if err != nil { return "" } defer func() { _ = f.Close() }() buf := make([]byte, 1<<20) // 1MB limit n, _ := io.ReadFull(f, buf) content := string(buf[:n]) m := muslVersionPattern.FindStringSubmatch(content) if m != nil { return m[1] } return "" } // compareVersions compares two dotted version strings numerically. // Returns -1 if a < b, 0 if a == b, 1 if a > b. 
func compareVersions(a, b string) int { partsA := strings.Split(a, ".") partsB := strings.Split(b, ".") maxLen := max(len(partsB), len(partsA)) for i := range maxLen { var va, vb int if i < len(partsA) { va, _ = strconv.Atoi(partsA[i]) } if i < len(partsB) { vb, _ = strconv.Atoi(partsB[i]) } if va < vb { return -1 } if va > vb { return 1 } } return 0 } ================================================ FILE: internal/os/libc_info_linux.go ================================================ //go:build linux package os import "os" // GetLibcInfo returns the libc family and version on Linux. // The result is cached so the detection only runs once. func GetLibcInfo() LibcInfo { libcInfoOnce.Do(func() { libcInfo = detectLibcInfo() }) return libcInfo } func detectLibcInfo() LibcInfo { fd, err := os.Open("/proc/self/maps") if err != nil { return LibcInfo{} } defer func() { _ = fd.Close() }() family, libcPath := parseProcMapsForLibc(fd) if family == "" { return LibcInfo{} } var version string switch family { case "glibc": version = glibcVersionFromELF(libcPath) case "musl": version = muslVersionFromBinary(libcPath) } return LibcInfo{Family: family, Version: version} } ================================================ FILE: internal/os/libc_info_notlinux.go ================================================ //go:build !linux package os // GetLibcInfo returns an empty LibcInfo on non-Linux platforms. 
func GetLibcInfo() LibcInfo { return LibcInfo{} }


================================================
FILE: internal/os/libc_info_test.go
================================================
package os

import (
	"runtime"
	"strings"
	"testing"
)

func TestParseProcMapsGlibc(t *testing.T) {
	maps := `7f1234560000-7f1234580000 r-xp 00000000 08:01 12345 /usr/lib/x86_64-linux-gnu/libc.so.6
7f1234580000-7f1234590000 r--p 00020000 08:01 12345 /usr/lib/x86_64-linux-gnu/libc.so.6`
	family, path := parseProcMapsForLibc(strings.NewReader(maps))
	if family != "glibc" {
		t.Errorf("expected glibc, got %q", family)
	}
	if path != "/usr/lib/x86_64-linux-gnu/libc.so.6" {
		t.Errorf("unexpected path: %q", path)
	}
}

func TestParseProcMapsMusl(t *testing.T) {
	maps := `7f1234560000-7f1234580000 r-xp 00000000 08:01 12345 /lib/ld-musl-x86_64.so.1`
	family, path := parseProcMapsForLibc(strings.NewReader(maps))
	if family != "musl" {
		t.Errorf("expected musl, got %q", family)
	}
	if path != "/lib/ld-musl-x86_64.so.1" {
		t.Errorf("unexpected path: %q", path)
	}
}

func TestParseProcMapsMuslLibc(t *testing.T) {
	maps := `7f1234560000-7f1234580000 r-xp 00000000 08:01 12345 /lib/libc.musl-x86_64.so.1`
	family, path := parseProcMapsForLibc(strings.NewReader(maps))
	if family != "musl" {
		t.Errorf("expected musl, got %q", family)
	}
	if path != "/lib/libc.musl-x86_64.so.1" {
		t.Errorf("unexpected path: %q", path)
	}
}

func TestParseProcMapsEmpty(t *testing.T) {
	family, path := parseProcMapsForLibc(strings.NewReader(""))
	if family != "" || path != "" {
		t.Errorf("expected empty, got family=%q path=%q", family, path)
	}
}

func TestParseProcMapsNoLibc(t *testing.T) {
	maps := `7f1234560000-7f1234580000 r-xp 00000000 08:01 12345 /usr/lib/libpthread.so.0
7fff12340000-7fff12360000 rw-p 00000000 00:00 0 [stack]`
	family, path := parseProcMapsForLibc(strings.NewReader(maps))
	if family != "" || path != "" {
		t.Errorf("expected empty, got family=%q path=%q", family, path)
	}
}

func TestParseProcMapsShortLines(t *testing.T) {
	maps := `7f1234560000-7f1234580000 r-xp 00000000 08:01 12345
7fff12340000-7fff12360000 rw-p 00000000 00:00 0`
	family, path := parseProcMapsForLibc(strings.NewReader(maps))
	if family != "" || path != "" {
		t.Errorf("expected empty for short lines, got family=%q path=%q", family, path)
	}
}

func TestCompareVersions(t *testing.T) {
	cases := []struct {
		a, b string
		want int
	}{
		{"2.31", "2.17", 1},
		{"2.17", "2.31", -1},
		{"2.31", "2.31", 0},
		{"2.31.1", "2.31", 1},
		{"2.31", "2.31.1", -1},
		{"1.2.3", "1.2.3", 0},
		{"10.0", "9.99", 1},
	}
	for _, c := range cases {
		got := compareVersions(c.a, c.b)
		if got != c.want {
			t.Errorf("compareVersions(%q, %q) = %d, want %d", c.a, c.b, got, c.want)
		}
	}
}

func TestGetLibcInfoNonLinux(t *testing.T) {
	if runtime.GOOS == "linux" {
		t.Skip("this test is for non-Linux platforms")
	}
	info := GetLibcInfo()
	if info.Family != "" || info.Version != "" {
		t.Errorf("expected empty LibcInfo on non-Linux, got %+v", info)
	}
}


================================================
FILE: internal/os/os_details.go
================================================
package os

import (
	"bufio"
	"os"
	"strings"
	"sync"
)

var (
	// osDetails caches the parsed /etc/os-release content; guarded by osDetailsOnce.
	osDetails     map[string]string
	osDetailsOnce sync.Once
)

// allowedOsReleaseKeys defines the keys we want to extract from /etc/os-release
var allowedOsReleaseKeys = map[string]bool{
	"NAME":          true,
	"PRETTY_NAME":   true,
	"ID":            true,
	"IMAGE_ID":      true,
	"IMAGE_VERSION": true,
	"BUILD_ID":      true,
	"VERSION":       true,
	"VERSION_ID":    true,
}

// readOsRelease reads and parses an os-release file from the given path.
// Returns nil on any error.
func readOsRelease(filename string) map[string]string { file, err := os.Open(filename) if err != nil { return nil } defer func() { _ = file.Close() }() result := make(map[string]string) scanner := bufio.NewScanner(file) for scanner.Scan() { line := scanner.Text() line = strings.TrimSpace(line) // Skip empty lines if line == "" { continue } // Parse KEY=VALUE format parts := strings.SplitN(line, "=", 2) if len(parts) != 2 { continue } key := strings.TrimSpace(parts[0]) value := strings.TrimSpace(parts[1]) // Only include allowed keys if !allowedOsReleaseKeys[key] { continue } value = unquoteOsReleaseValue(value) result[key] = value } if len(result) == 0 { return nil } return result } // unquoteOsReleaseValue extracts the value from a possibly quoted string. // If the value is wrapped in matching single or double quotes, the content // between the quotes is returned (ignoring anything after the closing quote). // Otherwise the raw value is returned. func unquoteOsReleaseValue(s string) string { if len(s) >= 2 && (s[0] == '"' || s[0] == '\'') { quote := s[0] if end := strings.IndexByte(s[1:], quote); end >= 0 { return s[1 : 1+end] } } return s } ================================================ FILE: internal/os/os_details_linux.go ================================================ //go:build linux package os // GetOsDetails returns OS details from /etc/os-release on Linux. // The result is cached so it's only read once. func GetOsDetails() map[string]string { osDetailsOnce.Do(func() { osDetails = readOsRelease("/etc/os-release") }) return osDetails } ================================================ FILE: internal/os/os_details_notlinux.go ================================================ //go:build !linux package os // GetOsDetails returns nil on non-Linux platforms. 
func GetOsDetails() map[string]string { return nil }


================================================
FILE: internal/os/os_details_test.go
================================================
package os

import (
	"testing"
)

func TestReadOsRelease(t *testing.T) {
	result := readOsRelease("test_data/sample_os_release")
	if result == nil {
		t.Fatal("expected non-nil result from sample_os_release")
	}

	// Verify only allowed keys are parsed (8 keys expected)
	// Note: test file also contains lines with spaces only, spaces+tabs,
	// and comments - all should be ignored
	expectedEntries := map[string]string{
		"NAME":          "Ubuntu",
		"PRETTY_NAME":   "Ubuntu 22.04.3 LTS",
		"ID":            "ubuntu",
		"VERSION_ID":    "22.04",
		"VERSION":       "22.04.3 LTS (Jammy Jellyfish)",
		"BUILD_ID":      "20231115",
		"IMAGE_ID":      "ubuntu-jammy",
		"IMAGE_VERSION": "1.0.0",
	}

	// Check correct number of entries (no extra keys parsed)
	if len(result) != len(expectedEntries) {
		t.Errorf("expected %d entries, got %d. Result: %v", len(expectedEntries), len(result), result)
	}

	// Verify each expected entry
	for key, expectedValue := range expectedEntries {
		actualValue, exists := result[key]
		if !exists {
			t.Errorf("expected key %q not found in result", key)
			continue
		}
		if actualValue != expectedValue {
			t.Errorf("key %q: expected %q, got %q", key, expectedValue, actualValue)
		}
	}

	// Verify all keys are expected
	for key := range result {
		_, exists := expectedEntries[key]
		if !exists {
			t.Errorf("expected to not contain key %v", key)
		}
	}
}


================================================
FILE: internal/os/test_data/sample_os_release
================================================
# This is a comment and should be ignored
NAME="Ubuntu"
PRETTY_NAME='Ubuntu 22.04.3 LTS' #this is pretty name
ID=ubuntu
VERSION_ID="22.04"
VERSION="22.04.3 LTS (Jammy Jellyfish)"
BUILD_ID=20231115
IMAGE_ID=ubuntu-jammy
IMAGE_VERSION=1.0.0
# These keys should be ignored (not in allowed list)
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL=https://help.ubuntu.com/
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
UBUNTU_CODENAME=jammy
VARIANT="Server"
# Empty lines should be ignored

# Line with spaces only should be ignored
   
# Line with spaces and tabs should be ignored
  	 
# Lines without = should be ignored
INVALID_LINE_WITHOUT_EQUALS


================================================
FILE: internal/query/response_types.go
================================================
package query

// ExecResponseRowType describes column metadata from a query response.
type ExecResponseRowType struct {
	Name       string          `json:"name"`
	Fields     []FieldMetadata `json:"fields"`
	ByteLength int64           `json:"byteLength"`
	Length     int64           `json:"length"`
	Type       string          `json:"type"`
	Precision  int64           `json:"precision"`
	Scale      int64           `json:"scale"`
	Nullable   bool            `json:"nullable"`
}

// FieldMetadata describes metadata for a field, including nested fields for complex types.
type FieldMetadata struct {
	Name      string          `json:"name,omitempty"`
	Type      string          `json:"type"`
	Nullable  bool            `json:"nullable"`
	Length    int             `json:"length"`
	Scale     int             `json:"scale"`
	Precision int             `json:"precision"`
	Fields    []FieldMetadata `json:"fields,omitempty"`
}

// ExecResponseChunk describes metadata for a chunk of query results, including URL and size information.
type ExecResponseChunk struct {
	URL              string `json:"url"`
	RowCount         int    `json:"rowCount"`
	UncompressedSize int64  `json:"uncompressedSize"`
	CompressedSize   int64  `json:"compressedSize"`
}


================================================
FILE: internal/query/transform.go
================================================
package query

// ToFieldMetadata transforms ExecResponseRowType to FieldMetadata.
func (ex *ExecResponseRowType) ToFieldMetadata() FieldMetadata { return FieldMetadata{ ex.Name, ex.Type, ex.Nullable, int(ex.Length), int(ex.Scale), int(ex.Precision), ex.Fields, } } ================================================ FILE: internal/types/types.go ================================================ package types import ( "strings" ) // SnowflakeType represents the various data types supported by Snowflake, including both standard and internal types used by the driver. type SnowflakeType int const ( // FixedType represents the FIXED data type in Snowflake, which is a numeric type with a specified precision and scale. FixedType SnowflakeType = iota // RealType represents the REAL data type in Snowflake, which is a floating-point numeric type. RealType // DecfloatType represents the DECFLOAT data type in Snowflake, which is a decimal floating-point numeric type with high precision. DecfloatType // TextType represents the TEXT data type in Snowflake, which is a variable-length string type. TextType // DateType represents the DATE data type in Snowflake, which is used to store calendar dates (year, month, day). DateType // VariantType represents the VARIANT data type in Snowflake, which is a semi-structured data type that can store values of various types. VariantType // TimestampLtzType represents the TIMESTAMP_LTZ data type in Snowflake, which is a timestamp with local time zone information. TimestampLtzType // TimestampNtzType represents the TIMESTAMP_NTZ data type in Snowflake, which is a timestamp without time zone information. TimestampNtzType // TimestampTzType represents the TIMESTAMP_TZ data type in Snowflake, which is a timestamp with time zone information. TimestampTzType // ObjectType represents the OBJECT data type in Snowflake, which is a semi-structured data type that can store key-value pairs. ObjectType // ArrayType represents the ARRAY data type in Snowflake, which is a semi-structured data type that can store ordered lists of values. 
ArrayType // MapType represents the MAP data type in Snowflake, which is a semi-structured data type that can store key-value pairs with unique keys. MapType // BinaryType represents the BINARY data type in Snowflake, which is used to store binary data (byte arrays). BinaryType // TimeType represents the TIME data type in Snowflake, which is used to store time values (hour, minute, second). TimeType // BooleanType represents the BOOLEAN data type in Snowflake, which is used to store boolean values (true/false). BooleanType // NullType represents a null value type, used internally to represent null values in Snowflake. NullType // SliceType represents a slice type, used internally to represent slices of data in Snowflake. SliceType // ChangeType represents a change type, used internally to represent changes in data in Snowflake. ChangeType // UnSupportedType represents an unsupported type, used internally to represent types that are not supported by the driver. UnSupportedType // NilObjectType represents a nil object type, used internally to represent null objects in Snowflake. NilObjectType // NilArrayType represents a nil array type, used internally to represent null arrays in Snowflake. NilArrayType // NilMapType represents a nil map type, used internally to represent null maps in Snowflake. NilMapType ) // SnowflakeToDriverType maps Snowflake data type names (as strings) to their corresponding SnowflakeType constants used internally by the driver. // This mapping allows for easy conversion between the string representation of Snowflake types and the internal enumeration used by the driver for type handling. 
var SnowflakeToDriverType = map[string]SnowflakeType{ "FIXED": FixedType, "REAL": RealType, "DECFLOAT": DecfloatType, "TEXT": TextType, "DATE": DateType, "VARIANT": VariantType, "TIMESTAMP_LTZ": TimestampLtzType, "TIMESTAMP_NTZ": TimestampNtzType, "TIMESTAMP_TZ": TimestampTzType, "OBJECT": ObjectType, "ARRAY": ArrayType, "MAP": MapType, "BINARY": BinaryType, "TIME": TimeType, "BOOLEAN": BooleanType, "NULL": NullType, "SLICE": SliceType, "CHANGE_TYPE": ChangeType, "NOT_SUPPORTED": UnSupportedType} // DriverTypeToSnowflake is the inverse mapping of SnowflakeToDriverType, allowing for conversion from SnowflakeType constants back to their string representations. var DriverTypeToSnowflake = invertMap(SnowflakeToDriverType) func invertMap(m map[string]SnowflakeType) map[SnowflakeType]string { inv := make(map[SnowflakeType]string) for k, v := range m { if _, ok := inv[v]; ok { panic("failed to create DriverTypeToSnowflake map due to duplicated values") } inv[v] = k } return inv } // Byte returns the byte representation of the SnowflakeType, which can be used for efficient type handling and comparisons within the driver. func (st SnowflakeType) Byte() byte { return byte(st) } func (st SnowflakeType) String() string { return DriverTypeToSnowflake[st] } // GetSnowflakeType takes a string representation of a Snowflake data type and returns the corresponding SnowflakeType constant used internally by the driver. 
func GetSnowflakeType(typ string) SnowflakeType { return SnowflakeToDriverType[strings.ToUpper(typ)] } ================================================ FILE: local_storage_client.go ================================================ package gosnowflake import ( "bufio" "cmp" "context" "fmt" "io" "os" "path" "path/filepath" "strings" ) type localUtil struct { } func (util *localUtil) createClient(_ *execResponseStageInfo, _ bool, _ *Config, _ *snowflakeTelemetry) (cloudClient, error) { return nil, nil } func (util *localUtil) uploadOneFileWithRetry(_ context.Context, meta *fileMetadata) error { var frd *bufio.Reader if meta.srcStream != nil { b := cmp.Or(meta.realSrcStream, meta.srcStream) frd = bufio.NewReader(b) } else { f, err := os.Open(meta.realSrcFileName) if err != nil { return err } defer func() { if err = f.Close(); err != nil { logger.Warnf("failed to close the file %v: %v", meta.realSrcFileName, err) } }() frd = bufio.NewReader(f) } user, err := expandUser(meta.stageInfo.Location) if err != nil { return err } if !meta.overwrite { if _, err := os.Stat(filepath.Join(user, meta.dstFileName)); err == nil { meta.dstFileSize = 0 meta.resStatus = skipped return nil } } output, err := os.OpenFile(filepath.Join(user, meta.dstFileName), os.O_CREATE|os.O_WRONLY, readWriteFileMode) if err != nil { return err } defer func() { if err = output.Close(); err != nil { logger.Warnf("failed to close the file %v: %v", meta.dstFileName, err) } }() data := make([]byte, meta.uploadSize) for { n, err := frd.Read(data) if err != nil && err != io.EOF { return err } if n == 0 { break } if _, err = output.Write(data); err != nil { return err } } meta.dstFileSize = meta.uploadSize meta.resStatus = uploaded return nil } func (util *localUtil) downloadOneFile(_ context.Context, meta *fileMetadata) error { srcFileName := meta.srcFileName if strings.HasPrefix(meta.srcFileName, fmt.Sprintf("%b", os.PathSeparator)) { srcFileName = srcFileName[1:] } user, err := 
expandUser(meta.stageInfo.Location) if err != nil { return err } fullSrcFileName := path.Join(user, srcFileName) user, err = expandUser(meta.localLocation) if err != nil { return err } fullDstFileName := path.Join(user, baseName(meta.dstFileName)) baseDir, err := getDirectory() if err != nil { return err } if _, err = os.Stat(baseDir); os.IsNotExist(err) { if err = os.MkdirAll(baseDir, os.ModePerm); err != nil { return err } } data, err := os.ReadFile(fullSrcFileName) if err != nil { return err } if err = os.WriteFile(fullDstFileName, data, readWriteFileMode); err != nil { return err } fi, err := os.Stat(fullDstFileName) if err != nil { return err } meta.dstFileSize = fi.Size() meta.resStatus = downloaded return nil } ================================================ FILE: local_storage_client_test.go ================================================ package gosnowflake import ( "bytes" "compress/gzip" "context" "os" "path" "path/filepath" "testing" ) func TestLocalUpload(t *testing.T) { tmpDir, err := os.MkdirTemp("", "local_put") if err != nil { t.Error(err) } defer os.RemoveAll(tmpDir) fname := filepath.Join(tmpDir, "test_put_get.txt.gz") originalContents := "123,test1\n456,test2\n" var b bytes.Buffer gzw := gzip.NewWriter(&b) _, err = gzw.Write([]byte(originalContents)) assertNilF(t, err) assertNilF(t, gzw.Close()) if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { t.Fatal("could not write to gzip file") } putDir, err := os.MkdirTemp("", "put") if err != nil { t.Error(err) } info := execResponseStageInfo{ Location: putDir, LocationType: "LOCAL_FS", } localUtil := new(localUtil) localCli, err := localUtil.createClient(&info, false, nil, nil) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "LOCAL_FS", noSleepingTime: true, parallel: 4, client: localCli, stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(tmpDir, "/test_put_get.txt.gz"), overwrite: true, options: 
&SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName err = localUtil.uploadOneFileWithRetry(context.Background(), &uploadMeta) if err != nil { t.Error(err) } if uploadMeta.resStatus != uploaded { t.Fatalf("failed to upload file") } uploadMeta.overwrite = false err = localUtil.uploadOneFileWithRetry(context.Background(), &uploadMeta) if err != nil { t.Error(err) } if uploadMeta.resStatus != skipped { t.Fatal("overwrite is false. should have skipped") } fileStream, _ := os.Open(fname) ctx := WithFilePutStream(context.Background(), fileStream) uploadMeta.fileStream, err = getFileStream(ctx) assertNilF(t, err) err = localUtil.uploadOneFileWithRetry(context.Background(), &uploadMeta) if err != nil { t.Error(err) } if uploadMeta.resStatus != skipped { t.Fatalf("overwrite is false. should have skipped") } uploadMeta.overwrite = true err = localUtil.uploadOneFileWithRetry(context.Background(), &uploadMeta) if err != nil { t.Error(err) } if uploadMeta.resStatus != uploaded { t.Fatalf("failed to upload file") } uploadMeta.realSrcStream = uploadMeta.srcStream err = localUtil.uploadOneFileWithRetry(context.Background(), &uploadMeta) if err != nil { t.Error(err) } if uploadMeta.resStatus != uploaded { t.Fatalf("failed to upload file") } } func TestDownloadLocalFile(t *testing.T) { tmpDir, err := os.MkdirTemp("", "local_put") if err != nil { t.Error(err) } defer func() { assertNilF(t, os.RemoveAll(tmpDir)) }() fname := filepath.Join(tmpDir, "test_put_get.txt.gz") originalContents := "123,test1\n456,test2\n" var b bytes.Buffer gzw := gzip.NewWriter(&b) _, err = gzw.Write([]byte(originalContents)) assertNilF(t, err) assertNilF(t, gzw.Close()) if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { t.Fatal("could not write to gzip file") } putDir, err := os.MkdirTemp("", "put") if err != nil { t.Error(err) } info := execResponseStageInfo{ Location: tmpDir, LocationType: "LOCAL_FS", } 
localUtil := new(localUtil) localCli, err := localUtil.createClient(&info, false, nil, nil) if err != nil { t.Error(err) } downloadMeta := fileMetadata{ name: "test_put_get.txt.gz", stageLocationType: "LOCAL_FS", noSleepingTime: true, client: localCli, stageInfo: &info, dstFileName: "test_put_get.txt.gz", overwrite: true, srcFileName: "test_put_get.txt.gz", localLocation: putDir, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, } err = localUtil.downloadOneFile(context.Background(), &downloadMeta) if err != nil { t.Error(err) } if downloadMeta.resStatus != downloaded { t.Fatalf("failed to get file in local storage") } downloadMeta.srcFileName = "test_put_get.txt.gz" err = localUtil.downloadOneFile(context.Background(), &downloadMeta) if err != nil { t.Error(err) } if downloadMeta.resStatus != downloaded { t.Fatalf("failed to get file in local storage") } downloadMeta.srcFileName = "local://test_put_get.txt.gz" err = localUtil.downloadOneFile(context.Background(), &downloadMeta) if err == nil { t.Error("file name is invalid. should have returned an error") } } ================================================ FILE: location.go ================================================ package gosnowflake import ( "fmt" "github.com/snowflakedb/gosnowflake/v2/internal/errors" "strconv" "sync" "time" ) var ( timezones map[int]*time.Location updateTimezoneMutex *sync.Mutex ) // Location returns an offset (minutes) based Location object for Snowflake database. func Location(offset int) *time.Location { updateTimezoneMutex.Lock() defer updateTimezoneMutex.Unlock() loc := timezones[offset] if loc != nil { return loc } loc = genTimezone(offset) timezones[offset] = loc return loc } // LocationWithOffsetString returns an offset based Location object. The offset string must consist of sHHMI where one sign // character '+'/'-' followed by zero filled hours and minutes. 
func LocationWithOffsetString(offsets string) (loc *time.Location, err error) { if len(offsets) != 5 { return nil, &SnowflakeError{ Number: ErrInvalidOffsetStr, SQLState: SQLStateInvalidDataTimeFormat, Message: errors.ErrMsgInvalidOffsetStr, MessageArgs: []any{offsets}, } } if offsets[0] != '-' && offsets[0] != '+' { return nil, &SnowflakeError{ Number: ErrInvalidOffsetStr, SQLState: SQLStateInvalidDataTimeFormat, Message: errors.ErrMsgInvalidOffsetStr, MessageArgs: []any{offsets}, } } s := 1 if offsets[0] == '-' { s = -1 } var h, m int64 h, err = strconv.ParseInt(offsets[1:3], 10, 64) if err != nil { return } m, err = strconv.ParseInt(offsets[3:], 10, 64) if err != nil { return } offset := s * (int(h)*60 + int(m)) loc = Location(offset) return } func genTimezone(offset int) *time.Location { var offsetSign string var toffset int if offset < 0 { offsetSign = "-" toffset = -offset } else { offsetSign = "+" toffset = offset } logger.Debugf("offset: %v", offset) return time.FixedZone( fmt.Sprintf("%v%02d%02d", offsetSign, toffset/60, toffset%60), int(offset)*60) } func init() { updateTimezoneMutex = &sync.Mutex{} timezones = make(map[int]*time.Location, 48) // pre-generate all common timezones for i := -720; i <= 720; i += 30 { logger.Debugf("offset: %v", i) timezones[i] = genTimezone(i) } } // retrieve current location based on connection func getCurrentLocation(sp *syncParams) *time.Location { loc := time.Now().Location() if sp == nil { return loc } var err error if tz, ok := sp.get("timezone"); ok && tz != nil { loc, err = time.LoadLocation(*tz) if err != nil { loc = time.Now().Location() } } return loc } ================================================ FILE: location_test.go ================================================ package gosnowflake import ( "errors" "fmt" errors2 "github.com/snowflakedb/gosnowflake/v2/internal/errors" "reflect" "testing" "time" ) type tcLocation struct { ss string tt string err error } func TestWithOffsetString(t *testing.T) { testcases 
:= []tcLocation{ { ss: "+0700", tt: "+0700", err: nil, }, { ss: "-1200", tt: "-1200", err: nil, }, { ss: "+0710", tt: "+0710", err: nil, }, { ss: "1200", tt: "", err: &SnowflakeError{ Number: ErrInvalidOffsetStr, Message: errors2.ErrMsgInvalidOffsetStr, MessageArgs: []any{"1200"}, }, }, { ss: "x1200", tt: "", err: &SnowflakeError{ Number: ErrInvalidOffsetStr, Message: errors2.ErrMsgInvalidOffsetStr, MessageArgs: []any{"x1200"}, }, }, { ss: "+12001", tt: "", err: &SnowflakeError{ Number: ErrInvalidOffsetStr, Message: errors2.ErrMsgInvalidOffsetStr, MessageArgs: []any{"+12001"}, }, }, { ss: "x12001", tt: "", err: &SnowflakeError{ Number: ErrInvalidOffsetStr, Message: errors2.ErrMsgInvalidOffsetStr, MessageArgs: []any{"x12001"}, }, }, { ss: "-12CD", tt: "", err: errors.New("parse int error"), // can this be more specific? }, { ss: "+ABCD", tt: "", err: errors.New("parse int error"), // can this be more specific? }, } for _, t0 := range testcases { t.Run(t0.ss, func(t *testing.T) { loc, err := LocationWithOffsetString(t0.ss) if t0.err != nil { if t0.err != err { driverError1, ok1 := t0.err.(*SnowflakeError) driverError2, ok2 := err.(*SnowflakeError) if ok1 && ok2 && driverError1.Number != driverError2.Number { t.Fatalf("error expected: %v, got: %v", t0.err, err) } } } else { if err != nil { t.Fatalf("%v", err) } if t0.tt != loc.String() { t.Fatalf("location string didn't match. 
expected: %v, got: %v", t0.tt, loc) } } }) } } func TestGetCurrentLocation(t *testing.T) { specificTz := "Pacific/Honolulu" specificLoc, err := time.LoadLocation(specificTz) if err != nil { t.Fatalf("Cannot initialize specific timezone location") } incorrectTz := "Not/exists" testcases := []struct { params syncParams loc *time.Location }{ { params: newSyncParams(map[string]*string{}), loc: time.Now().Location(), }, { params: newSyncParams(map[string]*string{ "timezone": nil, }), loc: time.Now().Location(), }, { params: newSyncParams(map[string]*string{ "timezone": &specificTz, }), loc: specificLoc, }, { params: newSyncParams(map[string]*string{ "timezone": &incorrectTz, }), loc: time.Now().Location(), }, } for i := range testcases { tc := &testcases[i] t.Run(fmt.Sprintf("%v", tc.loc), func(t *testing.T) { loc := getCurrentLocation(&tc.params) if !reflect.DeepEqual(*loc, *tc.loc) { t.Fatalf("location mismatch. expected: %v, got: %v", tc.loc, loc) } }) } } ================================================ FILE: locker.go ================================================ package gosnowflake import "sync" // ---------- API ---------- type lockKeyType interface { lockID() string } type locker interface { lock(lockKey lockKeyType) unlocker } type unlocker interface { Unlock() } func getValueWithLock[T any](locker locker, lockKey lockKeyType, f func() (T, error)) (T, error) { unlock := locker.lock(lockKey) defer unlock.Unlock() return f() } // ---------- Locking implementation ---------- type exclusiveLockerType struct { m sync.Map } var exclusiveLocker = newExclusiveLocker() func (e *exclusiveLockerType) lock(lockKey lockKeyType) unlocker { logger.Debugf("Acquiring lock for %s", lockKey.lockID()) // We can ignore clearing up the map because the number of unique lockID is very limited, and they will be probably reused during the lifetime of the app. 
mu, _ := e.m.LoadOrStore(lockKey.lockID(), &sync.Mutex{}) mu.(*sync.Mutex).Lock() return mu.(*sync.Mutex) } func newExclusiveLocker() *exclusiveLockerType { return &exclusiveLockerType{} } // ---------- No locking implementation ---------- type noopLockerType struct{} var noopLocker = &noopLockerType{} type noopUnlocker struct{} func (n noopUnlocker) Unlock() { } func (n *noopLockerType) lock(_ lockKeyType) unlocker { logger.Debug("No lock is acquired") return noopUnlocker{} } ================================================ FILE: log.go ================================================ package gosnowflake import ( loggerinternal "github.com/snowflakedb/gosnowflake/v2/internal/logger" "github.com/snowflakedb/gosnowflake/v2/sflog" ) // SFSessionIDKey is context key of session id const SFSessionIDKey ContextKey = "LOG_SESSION_ID" // SFSessionUserKey is context key of user id of a session const SFSessionUserKey ContextKey = "LOG_USER" func init() { // Set default log keys in internal package SetLogKeys(SFSessionIDKey, SFSessionUserKey) } // Re-export types from sflog package for backward compatibility type ( // ClientLogContextHook is a client-defined hook that can be used to insert log // fields based on the Context. ClientLogContextHook = sflog.ClientLogContextHook // LogEntry allows for logging using a snapshot of field values. // No implementation-specific logging details should be placed into this interface. LogEntry = sflog.LogEntry // SFLogger Snowflake logger interface which abstracts away the underlying logging mechanism. // No implementation-specific logging details should be placed into this interface. SFLogger = sflog.SFLogger // SFSlogLogger is an optional interface for advanced slog handler configuration. // This interface is separate from SFLogger to maintain framework-agnostic design. // Users can type-assert the logger to check if slog handler configuration is supported. SFSlogLogger = sflog.SFSlogLogger // Level is the log level. Info is set to 0. 
For more details, see sflog.Level. Level = sflog.Level ) // SetLogKeys sets the context keys to be written to logs when logger.WithContext is used. // This function is thread-safe and can be called at runtime. func SetLogKeys(keys ...ContextKey) { // Convert ContextKey to []any for internal package ikeys := make([]any, len(keys)) for i, k := range keys { ikeys[i] = k } loggerinternal.SetLogKeys(ikeys) } // GetLogKeys returns the currently configured context keys. func GetLogKeys() []ContextKey { ikeys := loggerinternal.GetLogKeys() // Convert []any back to []ContextKey keys := make([]ContextKey, 0, len(ikeys)) for _, k := range ikeys { if ck, ok := k.(ContextKey); ok { keys = append(keys, ck) } } return keys } // RegisterLogContextHook registers a hook that can be used to extract fields // from the Context and associated with log messages using the provided key. // This function is thread-safe and can be called at runtime. func RegisterLogContextHook(contextKey string, ctxExtractor ClientLogContextHook) { // Delegate directly to internal package loggerinternal.RegisterLogContextHook(contextKey, ctxExtractor) } // GetClientLogContextHooks returns the registered log context hooks. func GetClientLogContextHooks() map[string]ClientLogContextHook { return loggerinternal.GetClientLogContextHooks() } // logger is a proxy that delegates all calls to the internal global logger. // This ensures a single source of truth for the current logger. // This variable is private and should only be used internally within the main package. var logger SFLogger = loggerinternal.NewLoggerProxy() // SetLogger sets a custom logger implementation for gosnowflake. // The provided logger will be used as the base logger and automatically wrapped with: // - Secret masking (to protect sensitive data like passwords and tokens) // - Level filtering (for performance optimization) // // You cannot bypass these protective layers. 
If you need to configure them, use the // returned logger's methods (SetLogLevel, etc.). // // Example: // // customLogger := mylogger.New() // gosnowflake.SetLogger(customLogger) func SetLogger(logger SFLogger) error { return loggerinternal.SetLogger(logger) } // GetLogger returns the current global logger with all protective layers applied // (secret masking and level filtering). This is the actual wrapped logger instance, // not a proxy. // // Example: // // logger := gosnowflake.GetLogger() // logger.Info("message") func GetLogger() SFLogger { return loggerinternal.GetLogger() } // CreateDefaultLogger creates and returns a new instance of SFLogger with default config. // The returned logger is automatically wrapped with secret masking and level filtering. // This is a pure factory function and does NOT modify global state. // If you want to set it as the global logger, call SetLogger(newLogger). // // The wrapping chain is: levelFilteringLogger → secretMaskingLogger → rawLogger func CreateDefaultLogger() SFLogger { return loggerinternal.CreateDefaultLogger() } ================================================ FILE: log_client_test.go ================================================ package gosnowflake_test import ( "bytes" "context" "encoding/json" "fmt" "github.com/snowflakedb/gosnowflake/v2/sflog" "io" "log/slog" "maps" "strings" "sync" "testing" "github.com/snowflakedb/gosnowflake/v2" ) // customLogger is a simple implementation of gosnowflake.SFLogger for testing type customLogger struct { buf *bytes.Buffer level string fields map[string]any mu sync.Mutex } func newCustomLogger() *customLogger { return &customLogger{ buf: &bytes.Buffer{}, level: "info", fields: make(map[string]any), } } func (l *customLogger) formatMessage(level, format string, args ...any) { l.mu.Lock() defer l.mu.Unlock() msg := fmt.Sprintf(format, args...) 
// Include fields if any fieldStr := "" if len(l.fields) > 0 { parts := []string{} for k, v := range l.fields { parts = append(parts, fmt.Sprintf("%s=%v", k, v)) } fieldStr = " " + strings.Join(parts, " ") } fmt.Fprintf(l.buf, "%s: %s%s\n", level, msg, fieldStr) } func (l *customLogger) Tracef(format string, args ...any) { l.formatMessage("TRACE", format, args...) } func (l *customLogger) Debugf(format string, args ...any) { l.formatMessage("DEBUG", format, args...) } func (l *customLogger) Infof(format string, args ...any) { l.formatMessage("INFO", format, args...) } func (l *customLogger) Warnf(format string, args ...any) { l.formatMessage("WARN", format, args...) } func (l *customLogger) Errorf(format string, args ...any) { l.formatMessage("ERROR", format, args...) } func (l *customLogger) Fatalf(format string, args ...any) { l.formatMessage("FATAL", format, args...) } func (l *customLogger) Trace(msg string) { l.formatMessage("TRACE", "%s", fmt.Sprint(msg)) } func (l *customLogger) Debug(msg string) { l.formatMessage("DEBUG", "%s", fmt.Sprint(msg)) } func (l *customLogger) Info(msg string) { l.formatMessage("INFO", "%s", fmt.Sprint(msg)) } func (l *customLogger) Warn(msg string) { l.formatMessage("WARN", "%s", fmt.Sprint(msg)) } func (l *customLogger) Error(msg string) { l.formatMessage("ERROR", "%s", fmt.Sprint(msg)) } func (l *customLogger) Fatal(msg string) { l.formatMessage("FATAL", "%s", fmt.Sprint(msg)) } func (l *customLogger) WithField(key string, value any) gosnowflake.LogEntry { newFields := make(map[string]any) maps.Copy(newFields, l.fields) newFields[key] = value return &customLogEntry{ logger: l, fields: newFields, } } func (l *customLogger) WithFields(fields map[string]any) gosnowflake.LogEntry { newFields := make(map[string]any) maps.Copy(newFields, l.fields) maps.Copy(newFields, fields) return &customLogEntry{ logger: l, fields: newFields, } } func (l *customLogger) WithContext(ctx context.Context) gosnowflake.LogEntry { newFields := 
make(map[string]any) maps.Copy(newFields, l.fields) // Extract context fields if sessionID := ctx.Value(gosnowflake.SFSessionIDKey); sessionID != nil { newFields["LOG_SESSION_ID"] = sessionID } if user := ctx.Value(gosnowflake.SFSessionUserKey); user != nil { newFields["LOG_USER"] = user } return &customLogEntry{ logger: l, fields: newFields, } } func (l *customLogger) SetLogLevel(level string) error { l.mu.Lock() defer l.mu.Unlock() l.level = strings.ToLower(level) return nil } func (l *customLogger) SetLogLevelInt(level gosnowflake.Level) error { l.mu.Lock() defer l.mu.Unlock() levelStr, err := sflog.LevelToString(level) if err != nil { return err } l.level = levelStr return nil } func (l *customLogger) GetLogLevel() string { l.mu.Lock() defer l.mu.Unlock() return l.level } func (l *customLogger) GetLogLevelInt() gosnowflake.Level { l.mu.Lock() defer l.mu.Unlock() level, _ := sflog.ParseLevel(l.level) return level } func (l *customLogger) SetOutput(output io.Writer) { // For this test logger, we keep using our internal buffer } func (l *customLogger) GetOutput() string { l.mu.Lock() defer l.mu.Unlock() return l.buf.String() } func (l *customLogger) Reset() { l.mu.Lock() defer l.mu.Unlock() l.buf.Reset() } // customLogEntry implements gosnowflake.LogEntry type customLogEntry struct { logger *customLogger fields map[string]any } func (e *customLogEntry) formatMessage(level, format string, args ...any) { e.logger.mu.Lock() defer e.logger.mu.Unlock() msg := fmt.Sprintf(format, args...) // Include fields fieldStr := "" if len(e.fields) > 0 { parts := []string{} for k, v := range e.fields { parts = append(parts, fmt.Sprintf("%s=%v", k, v)) } fieldStr = " " + strings.Join(parts, " ") } fmt.Fprintf(e.logger.buf, "%s: %s%s\n", level, msg, fieldStr) } func (e *customLogEntry) Tracef(format string, args ...any) { e.formatMessage("TRACE", format, args...) } func (e *customLogEntry) Debugf(format string, args ...any) { e.formatMessage("DEBUG", format, args...) 
} func (e *customLogEntry) Infof(format string, args ...any) { e.formatMessage("INFO", format, args...) } func (e *customLogEntry) Warnf(format string, args ...any) { e.formatMessage("WARN", format, args...) } func (e *customLogEntry) Errorf(format string, args ...any) { e.formatMessage("ERROR", format, args...) } func (e *customLogEntry) Fatalf(format string, args ...any) { e.formatMessage("FATAL", format, args...) } func (e *customLogEntry) Trace(msg string) { e.formatMessage("TRACE", "%s", fmt.Sprint(msg)) } func (e *customLogEntry) Debug(msg string) { e.formatMessage("DEBUG", "%s", fmt.Sprint(msg)) } func (e *customLogEntry) Info(msg string) { e.formatMessage("INFO", "%s", fmt.Sprint(msg)) } func (e *customLogEntry) Warn(msg string) { e.formatMessage("WARN", "%s", fmt.Sprint(msg)) } func (e *customLogEntry) Error(msg string) { e.formatMessage("ERROR", "%s", fmt.Sprint(msg)) } func (e *customLogEntry) Fatal(msg string) { e.formatMessage("FATAL", "%s", fmt.Sprint(msg)) } // Helper functions func assertContains(t *testing.T, output, expected string) { t.Helper() if !strings.Contains(output, expected) { t.Errorf("Expected output to contain %q, got:\n%s", expected, output) } } func assertNotContains(t *testing.T, output, unexpected string) { t.Helper() if strings.Contains(output, unexpected) { t.Errorf("Expected output to NOT contain %q, got:\n%s", unexpected, output) } } func assertJSONFormat(t *testing.T, output string) { t.Helper() lines := strings.SplitSeq(strings.TrimSpace(output), "\n") for line := range lines { if line == "" { continue } var js map[string]any if err := json.Unmarshal([]byte(line), &js); err != nil { t.Errorf("Expected valid JSON, got error: %v, line: %s", err, line) } } } func TestCustomSlogHandler(t *testing.T) { // Save original logger originalLogger := gosnowflake.GetLogger() defer func() { gosnowflake.SetLogger(originalLogger) }() // Create a new default logger logger := gosnowflake.CreateDefaultLogger() // Set it as global logger first 
gosnowflake.SetLogger(logger) // Get the logger and try to set custom handler currentLogger := gosnowflake.GetLogger() // Type assert to SFSlogLogger slogLogger, ok := currentLogger.(gosnowflake.SFSlogLogger) if !ok { t.Fatal("Logger does not implement SFSlogLogger interface") } // Create custom JSON handler with buffer buf := &bytes.Buffer{} jsonHandler := slog.NewJSONHandler(buf, &slog.HandlerOptions{ Level: slog.LevelInfo, }) // Set the custom handler err := slogLogger.SetHandler(jsonHandler) if err != nil { t.Fatalf("Failed to set custom handler: %v", err) } // Log some messages _ = currentLogger.SetLogLevel("info") currentLogger.Info("Test message from custom JSON handler") currentLogger.Infof("Formatted message: %d", 42) // Verify output is in JSON format output := buf.String() assertJSONFormat(t, output) assertContains(t, output, "Test message from custom JSON handler") assertContains(t, output, "Formatted message: 42") } func TestCustomLoggerImplementation(t *testing.T) { // Save original logger originalLogger := gosnowflake.GetLogger() defer func() { gosnowflake.SetLogger(originalLogger) }() // Create custom logger customLog := newCustomLogger() var sfLogger gosnowflake.SFLogger = customLog // Set as global logger gosnowflake.SetLogger(sfLogger) // Get logger (should be proxied) logger := gosnowflake.GetLogger() // Log various messages logger.Info("Test info message") logger.Infof("Formatted: %s", "value") logger.Warn("Warning message") // Verify output output := customLog.GetOutput() assertContains(t, output, "INFO: Test info message") assertContains(t, output, "INFO: Formatted: value") assertContains(t, output, "WARN: Warning message") } func TestCustomLoggerSecretMasking(t *testing.T) { // Save original logger originalLogger := gosnowflake.GetLogger() defer func() { gosnowflake.SetLogger(originalLogger) }() // Create custom logger customLog := newCustomLogger() var sfLogger gosnowflake.SFLogger = customLog // Set as global logger 
gosnowflake.SetLogger(sfLogger) // Get logger logger := gosnowflake.GetLogger() // Log messages with secrets (use 8+ char secrets for detection) logger.Infof("Connection string: password='secret123'") logger.Info("Token: idToken:abc12345678") logger.Infof("Auth: token=def12345678") // Verify secrets are masked output := customLog.GetOutput() assertContains(t, output, "****") assertNotContains(t, output, "secret123") assertNotContains(t, output, "abc12345678") // pragma: allowlist secret assertNotContains(t, output, "def12345678") // pragma: allowlist secret } func TestCustomHandlerWithContext(t *testing.T) { // Save original logger originalLogger := gosnowflake.GetLogger() defer func() { gosnowflake.SetLogger(originalLogger) }() // Create a new default logger with JSON handler logger := gosnowflake.CreateDefaultLogger() gosnowflake.SetLogger(logger) currentLogger := gosnowflake.GetLogger() // Set custom JSON handler buf := &bytes.Buffer{} jsonHandler := slog.NewJSONHandler(buf, &slog.HandlerOptions{ Level: slog.LevelInfo, }) if slogLogger, ok := currentLogger.(gosnowflake.SFSlogLogger); ok { _ = slogLogger.SetHandler(jsonHandler) } // Create context with session info ctx := context.Background() ctx = context.WithValue(ctx, gosnowflake.SFSessionIDKey, "session-123") ctx = context.WithValue(ctx, gosnowflake.SFSessionUserKey, "test-user") // Log with context _ = currentLogger.SetLogLevel("info") currentLogger.WithContext(ctx).Info("Message with context") // Verify context fields in JSON output output := buf.String() assertJSONFormat(t, output) assertContains(t, output, "session-123") assertContains(t, output, "test-user") } func TestCustomLoggerWithFields(t *testing.T) { // Save original logger originalLogger := gosnowflake.GetLogger() defer func() { gosnowflake.SetLogger(originalLogger) }() // Create custom logger customLog := newCustomLogger() var sfLogger gosnowflake.SFLogger = customLog // Set as global logger gosnowflake.SetLogger(sfLogger) // Get logger logger 
:= gosnowflake.GetLogger() // Use WithField logger.WithField("key1", "value1").Info("Message with field") // Use WithFields logger.WithFields(map[string]any{ "key2": "value2", "key3": 123, }).Info("Message with multiple fields") // Verify fields in output output := customLog.GetOutput() assertContains(t, output, "key1=value1") assertContains(t, output, "key2=value2") assertContains(t, output, "key3=123") } func TestCustomLoggerLevelConfiguration(t *testing.T) { // Save original logger originalLogger := gosnowflake.GetLogger() defer func() { gosnowflake.SetLogger(originalLogger) }() // Create custom logger customLog := newCustomLogger() var sfLogger gosnowflake.SFLogger = customLog // Set as global logger gosnowflake.SetLogger(sfLogger) // Get logger logger := gosnowflake.GetLogger() // Set level to info err := logger.SetLogLevel("info") if err != nil { t.Fatalf("Failed to set log level: %v", err) } // Verify level if level := logger.GetLogLevel(); level != "info" { t.Errorf("Expected level 'info', got %q", level) } // Log at different levels logger.Debug("Debug message - should not appear at info level") logger.Info("Info message - should appear") // Check output output := customLog.GetOutput() // Note: Our custom logger doesn't implement level filtering // This test validates that the API works, actual filtering // would be implemented in a production custom logger assertContains(t, output, "INFO: Info message") } func TestCustomHandlerRestore(t *testing.T) { // Save original logger originalLogger := gosnowflake.GetLogger() defer func() { gosnowflake.SetLogger(originalLogger) }() // Create logger with JSON handler logger1 := gosnowflake.CreateDefaultLogger() gosnowflake.SetLogger(logger1) buf1 := &bytes.Buffer{} if slogLogger, ok := gosnowflake.GetLogger().(gosnowflake.SFSlogLogger); ok { jsonHandler := slog.NewJSONHandler(buf1, &slog.HandlerOptions{ Level: slog.LevelInfo, }) _ = slogLogger.SetHandler(jsonHandler) } // Log with JSON handler _ = 
gosnowflake.GetLogger().SetLogLevel("info") gosnowflake.GetLogger().Info("JSON format message") // Verify JSON format output1 := buf1.String() assertJSONFormat(t, output1) assertContains(t, output1, "JSON format message") // Create new default logger (text format) logger2 := gosnowflake.CreateDefaultLogger() buf2 := &bytes.Buffer{} logger2.SetOutput(buf2) gosnowflake.SetLogger(logger2) // Log with default text handler _ = gosnowflake.GetLogger().SetLogLevel("info") gosnowflake.GetLogger().Info("Text format message") // Verify text format (not JSON) output2 := buf2.String() assertContains(t, output2, "Text format message") // Text format should have "level=" in it assertContains(t, output2, "level=") } ================================================ FILE: log_test.go ================================================ package gosnowflake import ( "bytes" "context" "fmt" "strings" "testing" ) func TestLogLevelEnabled(t *testing.T) { log := CreateDefaultLogger() // via the SFLogger interface. 
err := log.SetLogLevel("info") if err != nil { t.Fatalf("log level could not be set %v", err) } if log.GetLogLevel() != "INFO" { t.Fatalf("log level should be info but is %v", log.GetLogLevel()) } } func TestSetLogLevelError(t *testing.T) { logger := CreateDefaultLogger() err := logger.SetLogLevel("unknown") if err == nil { t.Fatal("should have thrown an error") } } func TestDefaultLogLevel(t *testing.T) { logger := CreateDefaultLogger() buf := &bytes.Buffer{} logger.SetOutput(buf) // default logger level is info logger.Info("info") logger.Infof("info%v", "f") // debug and trace won't write to log since they are higher than info level logger.Debug("debug") logger.Debugf("debug%v", "f") logger.Trace("trace") logger.Tracef("trace%v", "f") logger.Warn("warn") logger.Warnf("warn%v", "f") logger.Error("error") logger.Errorf("error%v", "f") // verify output var strbuf = buf.String() if !strings.Contains(strbuf, "info") || !strings.Contains(strbuf, "warn") || !strings.Contains(strbuf, "error") { t.Fatalf("unexpected output in log: %v", strbuf) } if strings.Contains(strbuf, "debug") || strings.Contains(strbuf, "trace") { t.Fatalf("debug/trace should not be in log: %v", strbuf) } } func TestOffLogLevel(t *testing.T) { logger := CreateDefaultLogger() buf := &bytes.Buffer{} logger.SetOutput(buf) err := logger.SetLogLevel("OFF") assertNilF(t, err) logger.Info("info") logger.Infof("info%v", "f") logger.Debug("debug") logger.Debugf("debug%v", "f") logger.Trace("trace") logger.Tracef("trace%v", "f") logger.Warn("warn") logger.Warnf("warn%v", "f") logger.Error("error") logger.Errorf("error%v", "f") assertEqualE(t, buf.Len(), 0, "log messages count") assertEqualE(t, logger.GetLogLevel(), "OFF", "log level") } func TestLogSetLevel(t *testing.T) { logger := CreateDefaultLogger() buf := &bytes.Buffer{} logger.SetOutput(buf) _ = logger.SetLogLevel("trace") logger.Trace("should print at trace level") logger.Debug("should print at debug level") var strbuf = buf.String() if 
!strings.Contains(strbuf, "trace level") || !strings.Contains(strbuf, "debug level") { t.Fatalf("unexpected output in log: %v", strbuf) } } func TestLowerLevelsAreSuppressed(t *testing.T) { logger := CreateDefaultLogger() buf := &bytes.Buffer{} logger.SetOutput(buf) _ = logger.SetLogLevel("info") logger.Trace("should print at trace level") logger.Debug("should print at debug level") logger.Info("should print at info level") logger.Warn("should print at warn level") logger.Error("should print at error level") var strbuf = buf.String() if strings.Contains(strbuf, "trace level") || strings.Contains(strbuf, "debug level") { t.Fatalf("unexpected debug and trace are not present in log: %v", strbuf) } if !strings.Contains(strbuf, "info level") || !strings.Contains(strbuf, "warn level") || !strings.Contains(strbuf, "error level") { t.Fatalf("expected info, warn, error output in log: %v", strbuf) } } func TestLogWithField(t *testing.T) { logger := CreateDefaultLogger() buf := &bytes.Buffer{} logger.SetOutput(buf) logger.WithField("field", "test").Info("hello") var strbuf = buf.String() if !strings.Contains(strbuf, "field") || !strings.Contains(strbuf, "test") { t.Fatalf("expected field and test in output: %v", strbuf) } } type testRequestIDCtxKey struct{} func TestLogKeysDefault(t *testing.T) { logger := CreateDefaultLogger() buf := &bytes.Buffer{} logger.SetOutput(buf) ctx := context.Background() // set the sessionID on the context to see if we have it in the logs sessionIDContextValue := "sessionID" ctx = context.WithValue(ctx, SFSessionIDKey, sessionIDContextValue) userContextValue := "madison" ctx = context.WithValue(ctx, SFSessionUserKey, userContextValue) // base case (not using RegisterContextVariableToLog to add additional types ) logger.WithContext(ctx).Info("test") var strbuf = buf.String() if !strings.Contains(strbuf, string(SFSessionIDKey)) || !strings.Contains(strbuf, sessionIDContextValue) { t.Fatalf("expected that sfSessionIdKey would be in logs if 
logger.WithContext was used, but got: %v", strbuf) } if !strings.Contains(strbuf, string(SFSessionUserKey)) || !strings.Contains(strbuf, userContextValue) { t.Fatalf("expected that SFSessionUserKey would be in logs if logger.WithContext was used, but got: %v", strbuf) } } func TestLogKeysWithRegisterContextVariableToLog(t *testing.T) { logger := CreateDefaultLogger() buf := &bytes.Buffer{} logger.SetOutput(buf) ctx := context.Background() // set the sessionID on the context to see if we have it in the logs sessionIDContextValue := "sessionID" ctx = context.WithValue(ctx, SFSessionIDKey, sessionIDContextValue) userContextValue := "testUser" ctx = context.WithValue(ctx, SFSessionUserKey, userContextValue) // test that RegisterContextVariableToLog works with non string keys logKey := "REQUEST_ID" contextIntVal := 123 ctx = context.WithValue(ctx, testRequestIDCtxKey{}, contextIntVal) getRequestKeyFunc := func(ctx context.Context) string { if requestContext, ok := ctx.Value(testRequestIDCtxKey{}).(int); ok { return fmt.Sprint(requestContext) } return "" } RegisterLogContextHook(logKey, getRequestKeyFunc) // base case (not using RegisterContextVariableToLog to add additional types ) logger.WithContext(ctx).Info("test") var strbuf = buf.String() if !strings.Contains(strbuf, string(SFSessionIDKey)) || !strings.Contains(strbuf, sessionIDContextValue) { t.Fatalf("expected that sfSessionIdKey would be in logs if logger.WithContext and RegisterContextVariableToLog was used, but got: %v", strbuf) } if !strings.Contains(strbuf, string(SFSessionUserKey)) || !strings.Contains(strbuf, userContextValue) { t.Fatalf("expected that SFSessionUserKey would be in logs if logger.WithContext and RegisterContextVariableToLog was used, but got: %v", strbuf) } if !strings.Contains(strbuf, logKey) || !strings.Contains(strbuf, fmt.Sprint(contextIntVal)) { t.Fatalf("expected that REQUEST_ID would be in logs if logger.WithContext and RegisterContextVariableToLog was used, but got: %v", strbuf) } } 
// TestLogMaskSecrets verifies that a password embedded in a logged query text
// is masked (replaced with "****") by the logger before reaching the output.
func TestLogMaskSecrets(t *testing.T) {
	logger := CreateDefaultLogger()
	buf := &bytes.Buffer{}
	logger.SetOutput(buf)
	ctx := context.Background()
	query := "create user testuser password='testpassword'"
	logger.WithContext(ctx).Infof("Query: %#v", query)
	// verify output
	expected := "create user testuser password='****"
	var strbuf = buf.String()
	if !strings.Contains(strbuf, expected) {
		t.Fatalf("expected that password would be masked. WithContext was used, but got: %v", strbuf)
	}
}

================================================
FILE: minicore.go
================================================
package gosnowflake

import (
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"sync"
	"time"

	"github.com/snowflakedb/gosnowflake/v2/internal/compilation"
	internalos "github.com/snowflakedb/gosnowflake/v2/internal/os"
)

// Name of the environment variable used to disable minicore.
// NOTE(review): the variable is only named here; its read site is not visible in this chunk.
const disableMinicoreEnv = "SF_DISABLE_MINICORE"

var miniCoreOnce sync.Once
var miniCoreMutex sync.RWMutex // guards corePlatformConfig access (see getMiniCoreFileName)
var miniCoreInstance miniCore

// minicoreLoadLogs accumulates load-time log lines under its own mutex.
var minicoreLoadLogs = struct {
	mu        sync.Mutex
	logs      []string
	startTime time.Time
}{}

// minicoreDirCandidate describes one directory that may hold the minicore
// library, together with an optional hook to run before the directory is used.
type minicoreDirCandidate struct {
	dirType    string       // human-readable kind of directory ("temp", "home", "cwd")
	path       string       // absolute path of the candidate directory
	preUseFunc func() error // optional preparation step (e.g. mkdir + chmod); may be nil
}

// newMinicoreDirCandidate builds a candidate with no preUseFunc.
func newMinicoreDirCandidate(dirType, path string) minicoreDirCandidate {
	return minicoreDirCandidate{
		dirType: dirType,
		path:    path,
	}
}

// String returns only the directory type, so candidate lists can be logged
// without exposing filesystem paths.
func (m minicoreDirCandidate) String() string {
	return m.dirType
}

// getMiniCoreFileName returns the filename of the loaded minicore library.
func getMiniCoreFileName() string {
	miniCoreMutex.RLock()
	defer miniCoreMutex.RUnlock()
	return corePlatformConfig.coreLibFileName
}

// miniCoreErrorType represents the category of minicore error that occurred.
type miniCoreErrorType int

// Error type constants for categorizing minicore failures.
const ( miniCoreErrorTypeLoad miniCoreErrorType = iota // Library loading failed miniCoreErrorTypeSymbol // Symbol lookup failed miniCoreErrorTypeCall // Function call failed miniCoreErrorTypeInit // Initialization failed miniCoreErrorTypeWrite // File write failed ) // String returns a human-readable string representation of the error type. func (et miniCoreErrorType) String() string { switch et { case miniCoreErrorTypeLoad: return "load" case miniCoreErrorTypeSymbol: return "symbol" case miniCoreErrorTypeCall: return "call" case miniCoreErrorTypeInit: return "init" case miniCoreErrorTypeWrite: return "write" default: return "unknown" } } // miniCoreError represents a structured error from minicore operations. // It provides detailed context about what went wrong, where, and why. type miniCoreError struct { errorType miniCoreErrorType // errorType categorizes the kind of error platform string // platform identifies the OS where error occurred path string // path to the library file, if applicable err error // err wraps the underlying error cause } // Error returns a formatted error message with context about the failure. func (e *miniCoreError) Error() string { if e.path != "" { return fmt.Sprintf("minicore %s on %s (path: %s): %v", e.errorType, e.platform, e.path, e.err) } return fmt.Sprintf("minicore %s on %s: %v", e.errorType, e.platform, e.err) } // Unwrap returns the underlying error for error chain inspection. func (e *miniCoreError) Unwrap() error { return e.err } // newMiniCoreError creates a new structured minicore error with full context. func newMiniCoreError(errType miniCoreErrorType, platform, path string, err error) *miniCoreError { return &miniCoreError{ errorType: errType, platform: platform, path: path, err: err, } } // corePlatformConfigType holds platform-specific minicore configuration. 
type corePlatformConfigType struct { initialized bool // initialized indicates if the platform is supported coreLib []byte // coreLib contains the embedded native library coreLibFileName string // coreLibFileName is the filename from the go:embed directive } // corePlatformConfig holds platform-specific configuration. If not initialized, minicore is unsupported. var corePlatformConfig = corePlatformConfigType{} type miniCore interface { // FullVersion returns the version string from the native library. FullVersion() (string, error) } // erroredMiniCore implements miniCore but always returns an error. // It's used when minicore initialization fails. type erroredMiniCore struct { err error } // newErroredMiniCore creates a miniCore implementation that always returns the given error. func newErroredMiniCore(err error) *erroredMiniCore { minicoreDebugf("minicore error: %v", err) return &erroredMiniCore{err: err} } // FullVersion always returns an empty string and the stored error. func (emc erroredMiniCore) FullVersion() (string, error) { return "", emc.err } // miniCoreLoaderType manages the loading and initialization of the minicore native library. type miniCoreLoaderType struct { searchDirs []minicoreDirCandidate // searchDirs contains directories to search for the library } // newMiniCoreLoader creates a new minicore miniCoreLoaderType with platform-appropriate search directories. func newMiniCoreLoader() *miniCoreLoaderType { return &miniCoreLoaderType{ searchDirs: buildMiniCoreSearchDirs(), } } // buildMiniCoreSearchDirs constructs the list of directories to search for the minicore library. 
func buildMiniCoreSearchDirs() []minicoreDirCandidate { var dirs []minicoreDirCandidate // Add temp directory if tempDir, err := os.MkdirTemp("", "gosnowflake-cgo"); err == nil && tempDir != "" { minicoreDebugf("created temp directory for minicore loading") switch runtime.GOOS { case "linux", "darwin": if err = os.Chmod(tempDir, 0700); err == nil { minicoreDebugf("configured permissions to temp as 0700") dirs = append(dirs, newMinicoreDirCandidate("temp", tempDir)) } else { minicoreDebugf("cannot change minicore directory permissions to 0700") } default: dirs = append(dirs, newMinicoreDirCandidate("temp", tempDir)) } } else { minicoreDebugf("cannot create temp directory for gosnowflakecore: %v", err) } // Add platform-specific cache directory if cacheDir := getMiniCoreCacheDirInHome(); cacheDir != "" { dirCandidate := newMinicoreDirCandidate("home", cacheDir) dirCandidate.preUseFunc = func() error { minicoreDebugf("using cache directory: %v", cacheDir) if err := os.MkdirAll(cacheDir, 0700); err != nil { minicoreDebugf("cannot create %v: %v", cacheDir, err) return err } minicoreDebugf("created cache directory: %v, configured permissions to 0700", cacheDir) if runtime.GOOS == "linux" || runtime.GOOS == "darwin" { if err := os.Chmod(cacheDir, 0700); err != nil { minicoreDebugf("cannot change minicore cache directory permissions to 0700. %v", err) return err } } minicoreDebugf("configured permissions to cache directory as 0700") return nil } dirs = append(dirs, dirCandidate) } // Add current working directory if cwd, err := os.Getwd(); err == nil { dirs = append(dirs, newMinicoreDirCandidate("cwd", cwd)) } else { minicoreDebugf("cannot get current working directory: %v", err) } minicoreDebugf("candidate directories for minicore loading: %v", dirs) return dirs } // getMiniCoreCacheDirInHome returns the platform-specific cache directory for storing the minicore library. 
// getMiniCoreCacheDirInHome returns the platform-specific cache directory for storing the minicore library.
// It returns "" when the user home directory cannot be determined.
func getMiniCoreCacheDirInHome() string {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		minicoreDebugf("cannot get user home directory: %v", err)
		return ""
	}
	switch runtime.GOOS {
	case "windows":
		return filepath.Join(homeDir, "AppData", "Local", "Snowflake", "Caches", "minicore")
	case "darwin":
		return filepath.Join(homeDir, "Library", "Caches", "Snowflake", "minicore")
	default:
		// Other systems use an XDG-style ~/.cache layout.
		return filepath.Join(homeDir, ".cache", "snowflake", "minicore")
	}
}

// loadCore loads and initializes the minicore native library.
// It extracts the embedded library to disk, hands the path to the
// platform-specific loader, and removes the extracted file afterwards.
func (l *miniCoreLoaderType) loadCore() miniCore {
	// Unsupported platform: no provider file initialized the config.
	if !corePlatformConfig.initialized {
		return newErroredMiniCore(newMiniCoreError(miniCoreErrorTypeInit, runtime.GOOS, "",
			fmt.Errorf("minicore is not supported on %v/%v platform", runtime.GOOS, runtime.GOARCH)))
	}
	// dlopen requires a dynamic linker, so a statically linked binary cannot load minicore.
	if linkingMode, err := compilation.CheckDynamicLinking(); err != nil || linkingMode == compilation.UnknownLinking {
		minicoreDebugf("cannot determine linking mode: %v, proceeding anyway", err)
	} else if linkingMode == compilation.StaticLinking {
		return newErroredMiniCore(newMiniCoreError(miniCoreErrorTypeLoad, runtime.GOOS, "",
			fmt.Errorf("binary is statically linked (no dynamic linker); dlopen is unavailable")))
	}
	libDir, libPath, err := l.writeLibraryToFile()
	if err != nil {
		return newErroredMiniCore(err)
	}
	// Cleanup: remove the extracted library file after the load attempt, and
	// the containing directory too if it was a temp directory created for this.
	defer func(libDir minicoreDirCandidate, libPath string) {
		if err = os.Remove(libPath); err != nil {
			minicoreDebugf("cannot remove library. %v", err)
		}
		if libDir.dirType == "temp" {
			if err = os.Remove(libDir.path); err != nil {
				minicoreDebugf("cannot remove temp directory. %v", err)
			}
		}
	}(libDir, libPath)
	minicoreDebugf("Loading minicore library from: %s", libDir)
	return osSpecificLoadFromPath(libPath)
}

// osSpecificLoadFromPath is replaced at init time by the platform-specific
// loaders (see minicore_posix.go / minicore_windows.go); this default reports
// that no loader implementation exists for the current platform.
var osSpecificLoadFromPath = func(libPath string) miniCore {
	return newErroredMiniCore(fmt.Errorf("minicore loader is not available on %v/%v", runtime.GOOS, runtime.GOARCH))
}

// writeLibraryToFile writes the embedded library to the first available directory
// in searchDirs, returning the chosen candidate and the written file path.
func (l *miniCoreLoaderType) writeLibraryToFile() (minicoreDirCandidate, string, error) {
	var errs []error
	for _, dir := range l.searchDirs {
		// Run the candidate's preparation hook (e.g. creating the cache dir) first.
		if dir.preUseFunc != nil {
			if err := dir.preUseFunc(); err != nil {
				minicoreDebugf("Failed to prepare directory %q: %v", dir.path, err)
				errs = append(errs, fmt.Errorf("failed to prepare directory %q: %v", dir.path, err))
				continue
			}
		}
		libPath := filepath.Join(dir.path, corePlatformConfig.coreLibFileName)
		// 0600: readable/writable by the current user only.
		if err := os.WriteFile(libPath, corePlatformConfig.coreLib, 0600); err != nil {
			minicoreDebugf("Failed to write embedded library to %q: %v", libPath, err)
			errs = append(errs, fmt.Errorf("failed to write to %q: %v", libPath, err))
			continue
		}
		minicoreDebugf("Successfully wrote embedded library to %s", dir)
		return dir, libPath, nil
	}
	// Every candidate failed; report all collected errors at once.
	return minicoreDirCandidate{}, "", newMiniCoreError(miniCoreErrorTypeWrite, runtime.GOOS, "",
		fmt.Errorf("failed to write embedded library to any directory (errors: %v)", errs))
}

// getMiniCore returns the minicore instance, loading it asynchronously if needed.
func getMiniCore() miniCore { miniCoreOnce.Do(func() { minicoreDebugf("minicore enabled at compile time: %v", compilation.MinicoreEnabled) minicoreDebugf("cgo enabled: %v", compilation.CgoEnabled) if !compilation.MinicoreEnabled { logger.Debugf("minicore disabled at compile time (built with -tags minicore_disabled)") return } if strings.EqualFold(os.Getenv(disableMinicoreEnv), "true") { logger.Debugf("minicore loading disabled") return } go func() { minicoreLoadLogs.mu.Lock() minicoreLoadLogs.startTime = time.Now() minicoreLoadLogs.mu.Unlock() minicoreDebugf("Starting asynchronous minicore loading") miniCoreLoader := newMiniCoreLoader() core := miniCoreLoader.loadCore() miniCoreMutex.Lock() miniCoreInstance = core miniCoreMutex.Unlock() if v, err := core.FullVersion(); err != nil { minicoreDebugf("Minicore version not available: %v", err) } else { minicoreDebugf("Minicore loading completed, version: %s", v) } }() }) // Return current instance (may be nil initially) miniCoreMutex.RLock() defer miniCoreMutex.RUnlock() return miniCoreInstance } func init() { // Start async minicore loading but don't block initialization. // This allows the application to start quickly while minicore loads in the background. getMiniCore() } func minicoreDebugf(format string, args ...any) { minicoreLoadLogs.mu.Lock() defer minicoreLoadLogs.mu.Unlock() var finalArgs []any finalArgs = append(finalArgs, time.Since(minicoreLoadLogs.startTime)) finalArgs = append(finalArgs, args...) finalFormat := "[%v] " + format logger.Debugf(finalFormat, finalArgs...) 
minicoreLoadLogs.logs = append(minicoreLoadLogs.logs, maskSecrets(fmt.Sprintf(finalFormat, finalArgs...))) } // libcType represents the type of C library in use type libcType string const ( libcTypeGlibc libcType = "glibc" libcTypeMusl libcType = "musl" libcTypeIgnored libcType = "" ) // detectLibc detects whether glibc or musl is in use func detectLibc() libcType { if runtime.GOOS != "linux" { return libcTypeIgnored } info := internalos.GetLibcInfo() switch info.Family { case "glibc": minicoreDebugf("detected glibc environment") if info.Version != "" { minicoreDebugf("glibc version: %s", info.Version) } return libcTypeGlibc case "musl": minicoreDebugf("detected musl environment") if info.Version != "" { minicoreDebugf("musl version: %s", info.Version) } return libcTypeMusl default: minicoreDebugf("Could not detect libc type, assuming glibc") return libcTypeGlibc } } ================================================ FILE: minicore_disabled_test.go ================================================ //go:build minicore_disabled package gosnowflake import ( "database/sql" "testing" "github.com/snowflakedb/gosnowflake/v2/internal/compilation" ) func TestMiniCoreDisabledAtCompileTime(t *testing.T) { assertFalseF(t, compilation.MinicoreEnabled, "MinicoreEnabled should be false when built with -tags minicore_disabled") } func TestMiniCoreDisabledE2E(t *testing.T) { wiremock.registerMappings(t, newWiremockMapping("minicore/auth/disabled_flow.json"), newWiremockMapping("select1.json")) cfg := wiremock.connectionConfig() connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) runSmokeQuery(t, db) } ================================================ FILE: minicore_posix.go ================================================ //go:build !windows && !minicore_disabled package gosnowflake /* #cgo LDFLAGS: -ldl #include #include #include static void* dlOpen(const char* path) { return dlopen(path, RTLD_LAZY); } static void* dlSym(void* handle, const char* name) { 
	return dlsym(handle, name);
}

static int dlClose(void* handle) {
	return dlclose(handle);
}

static char* dlError() {
	return dlerror();
}

typedef const char* (*coreFullVersion)();

static const char* callCoreFullVersion(coreFullVersion f) {
	return f();
}
*/
import "C"

import (
	"errors"
	"fmt"
	"unsafe"
)

// posixMiniCore is the POSIX (dlopen-based) implementation of miniCore.
type posixMiniCore struct {
	// fullVersion holds the version string returned from Rust, just to not invoke it multiple times.
	fullVersion string
	// coreInitError holds any error that occurred during initialization.
	coreInitError error
}

// newPosixMiniCore wraps an already-retrieved version string.
func newPosixMiniCore(fullVersion string) *posixMiniCore {
	return &posixMiniCore{
		fullVersion: fullVersion,
	}
}

// FullVersion returns the cached version string and any init error.
func (pmc *posixMiniCore) FullVersion() (string, error) {
	return pmc.fullVersion, pmc.coreInitError
}

// Register the POSIX loader as the platform-specific implementation.
var _ = func() any {
	osSpecificLoadFromPath = loadFromPath
	return nil
}()

// loadFromPath dlopens the library at libPath, resolves and calls
// sf_core_full_version once, and returns a miniCore caching the result.
// The library handle is closed before returning.
func loadFromPath(libPath string) miniCore {
	cLibPath := C.CString(libPath)
	defer C.free(unsafe.Pointer(cLibPath))
	// Loading library
	minicoreDebugf("Calling dlOpen")
	handle := C.dlOpen(cLibPath)
	minicoreDebugf("Calling dlOpen finished")
	if handle == nil {
		err := C.dlError()
		mcErr := newMiniCoreError(miniCoreErrorTypeLoad, "posix", libPath, fmt.Errorf("failed to load shared library: %v", C.GoString(err)))
		return newErroredMiniCore(mcErr)
	}
	// Unloading library at the end
	defer func() {
		minicoreDebugf("Calling dlClose")
		defer minicoreDebugf("Calling dlClose finished")
		if ret := C.dlClose(handle); ret != 0 {
			err := C.dlError()
			minicoreDebugf("Error when closing dynamic library: %v", C.GoString(err))
		}
	}()
	// Loading symbol
	symbolName := C.CString("sf_core_full_version")
	defer C.free(unsafe.Pointer(symbolName))
	minicoreDebugf("Loading sf_core_full_version symbol")
	coreFullVersionSymbol := C.dlSym(handle, symbolName)
	minicoreDebugf("Loading sf_core_full_version symbol finished")
	if coreFullVersionSymbol == nil {
		err := C.dlError()
		mcErr := newMiniCoreError(miniCoreErrorTypeSymbol, "posix", libPath, fmt.Errorf("symbol 'sf_core_full_version' not found: %v", C.GoString(err)))
		return newErroredMiniCore(mcErr)
	}
	// Calling minicore
	var coreFullVersionFunc C.coreFullVersion = (C.coreFullVersion)(coreFullVersionSymbol)
	minicoreDebugf("Calling sf_core_full_version")
	fullVersion := C.GoString(C.callCoreFullVersion(coreFullVersionFunc))
	minicoreDebugf("Calling sf_core_full_version finished")
	if fullVersion == "" {
		return newErroredMiniCore(newMiniCoreError(miniCoreErrorTypeCall, "posix", libPath, errors.New("failed to get version from core library function")))
	}
	return newPosixMiniCore(fullVersion)
}

================================================
FILE: minicore_provider_darwin_amd64.go
================================================
//go:build !minicore_disabled

package gosnowflake

import (
	// embed is used only to initialize go:embed directive
	_ "embed"
)

//go:embed libsf_mini_core_darwin_amd64.dylib
var coreLibDarwinAmd64 []byte

var _ = initMinicoreProvider()

// initMinicoreProvider registers the embedded darwin/amd64 library in the global platform config.
func initMinicoreProvider() any {
	corePlatformConfig.coreLib = coreLibDarwinAmd64
	corePlatformConfig.coreLibFileName = "libsf_mini_core_darwin_amd64.dylib"
	corePlatformConfig.initialized = true
	return nil
}

================================================
FILE: minicore_provider_darwin_arm64.go
================================================
//go:build !minicore_disabled

package gosnowflake

import (
	// embed is used only to initialize go:embed directive
	_ "embed"
)

//go:embed libsf_mini_core_darwin_arm64.dylib
var coreLibDarwinArm64 []byte

var _ = initMinicoreProvider()

// initMinicoreProvider registers the embedded darwin/arm64 library in the global platform config.
func initMinicoreProvider() any {
	corePlatformConfig.coreLib = coreLibDarwinArm64
	corePlatformConfig.coreLibFileName = "libsf_mini_core_darwin_arm64.dylib"
	corePlatformConfig.initialized = true
	return nil
}

================================================
FILE: minicore_provider_linux_amd64.go
================================================
//go:build !minicore_disabled

package gosnowflake

import (
	// embed is used only to initialize go:embed directive
	_ "embed"
)

//go:embed libsf_mini_core_linux_amd64_glibc.so
var coreLibLinuxAmd64Glibc []byte

//go:embed libsf_mini_core_linux_amd64_musl.so
var coreLibLinuxAmd64Musl []byte

var _ = initMinicoreProvider()

// initMinicoreProvider registers the glibc or musl linux/amd64 library
// depending on the detected C library; stays uninitialized when unknown.
func initMinicoreProvider() any {
	switch detectLibc() {
	case libcTypeGlibc:
		corePlatformConfig.coreLib = coreLibLinuxAmd64Glibc
		corePlatformConfig.coreLibFileName = "libsf_mini_core_linux_amd64_glibc.so"
	case libcTypeMusl:
		corePlatformConfig.coreLib = coreLibLinuxAmd64Musl
		corePlatformConfig.coreLibFileName = "libsf_mini_core_linux_amd64_musl.so"
	default:
		minicoreDebugf("unknown libc")
		return nil
	}
	corePlatformConfig.initialized = true
	return nil
}

================================================
FILE: minicore_provider_linux_arm64.go
================================================
//go:build !minicore_disabled

package gosnowflake

import (
	// embed is used only to initialize go:embed directive
	_ "embed"
)

//go:embed libsf_mini_core_linux_arm64_glibc.so
var coreLibLinuxArm64Glibc []byte

//go:embed libsf_mini_core_linux_arm64_musl.so
var coreLibLinuxArm64Musl []byte

var _ = initMinicoreProvider()

// initMinicoreProvider registers the glibc or musl linux/arm64 library
// depending on the detected C library; stays uninitialized when unknown.
func initMinicoreProvider() any {
	switch detectLibc() {
	case libcTypeGlibc:
		corePlatformConfig.coreLib = coreLibLinuxArm64Glibc
		corePlatformConfig.coreLibFileName = "libsf_mini_core_linux_arm64_glibc.so"
	case libcTypeMusl:
		corePlatformConfig.coreLib = coreLibLinuxArm64Musl
		corePlatformConfig.coreLibFileName = "libsf_mini_core_linux_arm64_musl.so"
	default:
		minicoreDebugf("unknown libc")
		return nil
	}
	corePlatformConfig.initialized = true
	return nil
}

================================================
FILE: minicore_provider_windows_amd64.go
================================================
//go:build !minicore_disabled

package gosnowflake

import (
	// embed is used only to initialize go:embed directive
	_ "embed"
)

//go:embed libsf_mini_core_windows_amd64.dll
var coreLibWindowsAmd64Glibc []byte

var _ = initMinicoreProvider()

// initMinicoreProvider registers the embedded windows/amd64 library in the global platform config.
func initMinicoreProvider() any {
	corePlatformConfig.coreLib = coreLibWindowsAmd64Glibc
	corePlatformConfig.coreLibFileName = "libsf_mini_core_windows_amd64.dll"
	corePlatformConfig.initialized = true
	return nil
}

================================================
FILE: minicore_provider_windows_arm64.go
================================================
//go:build !minicore_disabled

package gosnowflake

import (
	// embed is used only to initialize go:embed directive
	_ "embed"
)

//go:embed libsf_mini_core_windows_arm64.dll
var coreLibWindowsArm64Glibc []byte

var _ = initMinicoreProvider()

// initMinicoreProvider registers the embedded windows/arm64 library in the global platform config.
func initMinicoreProvider() any {
	corePlatformConfig.coreLib = coreLibWindowsArm64Glibc
	corePlatformConfig.coreLibFileName = "libsf_mini_core_windows_arm64.dll"
	corePlatformConfig.initialized = true
	return nil
}

================================================
FILE: minicore_test.go
================================================
//go:build !minicore_disabled

package gosnowflake

import (
	"database/sql"
	"os"
	"runtime"
	"strings"
	"testing"
	"time"

	"github.com/snowflakedb/gosnowflake/v2/internal/compilation"
)

func TestMiniCoreLoadSuccess(t *testing.T) {
	mcl := newMiniCoreLoader()
	checkLoadCore(t, mcl)
}

// checkLoadCore asserts that the loader produces a working core reporting version 0.0.1.
func checkLoadCore(t *testing.T, mcl *miniCoreLoaderType) {
	core := mcl.loadCore()
	assertNotNilF(t, core)
	fullVersion, err := core.FullVersion()
	assertNilF(t, err)
	assertEqualE(t, fullVersion, "0.0.1")
}

func TestMiniCoreLoaderChoosesCorrectCandidates(t *testing.T) {
	skipOnMissingHome(t)
	assertNilF(t, os.RemoveAll(getMiniCoreCacheDirInHome()))
	mcl := newMiniCoreLoader()
	checkAllLoadDirsAvailable(t, mcl)
}

func TestMiniCoreLoaderChoosesCorrectCandidatesWhenHomeCacheDirAlreadyExists(t *testing.T) {
	skipOnMissingHome(t)
	mcl := newMiniCoreLoader()
	checkAllLoadDirsAvailable(t, mcl)
	mcl = newMiniCoreLoader()
	checkAllLoadDirsAvailable(t, mcl)
}

// checkAllLoadDirsAvailable asserts the loader found all three candidate directory types in priority order.
func checkAllLoadDirsAvailable(t *testing.T, mcl *miniCoreLoaderType) {
	assertEqualF(t, len(mcl.searchDirs), 3)
	assertEqualE(t, mcl.searchDirs[0].dirType, "temp")
	assertEqualE(t, mcl.searchDirs[1].dirType, "home")
	assertEqualE(t, mcl.searchDirs[2].dirType, "cwd")
}

func TestMiniCoreNoFolderCandidate(t *testing.T) {
	mcl := newMiniCoreLoader()
	mcl.searchDirs = []minicoreDirCandidate{}
	core := mcl.loadCore()
	version, err := core.FullVersion()
	assertNotNilF(t, err)
	assertStringContainsE(t, err.Error(), "failed to write embedded library to any directory")
	assertEqualE(t, version, "")
}

func TestMiniCoreNoWritableFolder(t *testing.T) {
	skipOnWindows(t, "permission system is different")
	tempDir := t.TempDir()
	err := os.Chmod(tempDir, 0000)
	assertNilF(t, err)
	defer os.Chmod(tempDir, 0700)
	mcl := newMiniCoreLoader()
	mcl.searchDirs = []minicoreDirCandidate{newMinicoreDirCandidate("test", tempDir)}
	core := mcl.loadCore()
	assertNotNilF(t, core)
	_, err = core.FullVersion()
	assertNotNilF(t, err)
	assertStringContainsE(t, err.Error(), "failed to write embedded library to any directory")
}

func TestMiniCoreNoWritableFirstFolder(t *testing.T) {
	tempDir := t.TempDir()
	err := os.Chmod(tempDir, 0000)
	defer os.Chmod(tempDir, 0700)
	tempDir2 := t.TempDir()
	assertNilF(t, err)
	mcl := newMiniCoreLoader()
	// First candidate is unwritable; the loader should fall through to the second.
	mcl.searchDirs = []minicoreDirCandidate{newMinicoreDirCandidate("test", tempDir), newMinicoreDirCandidate("test", tempDir2)}
	checkLoadCore(t, mcl)
}

func TestMiniCoreInvalidDynamicLibrary(t *testing.T) {
	origCoreLib := corePlatformConfig.coreLib
	defer func() {
		corePlatformConfig.coreLib = origCoreLib
	}()
	// Replace the embedded library bytes with garbage so dlopen/LoadLibrary fails.
	corePlatformConfig.coreLib = []byte("invalid content")
	mcl := newMiniCoreLoader()
	core := mcl.loadCore()
	assertNotNilF(t, core)
	_, err := core.FullVersion()
	assertNotNilF(t, err)
	assertStringContainsE(t, err.Error(), "failed to load shared library")
}

func TestMiniCoreNotInitialized(t *testing.T) {
	defer func() {
		corePlatformConfig.initialized = true
	}()
	corePlatformConfig.initialized = false
	mcl := newMiniCoreLoader()
	core := mcl.loadCore()
	assertNotNilF(t, core)
	_, err := core.FullVersion()
	assertNotNilF(t, err)
	assertStringContainsE(t, err.Error(), "minicore is not supported on")
}

func TestMiniCoreLoadLogsVersion(t *testing.T) {
	// Reset the collected load logs before exercising the loader.
	minicoreLoadLogs.mu.Lock()
	minicoreLoadLogs.logs = nil
	minicoreLoadLogs.startTime = time.Now()
	minicoreLoadLogs.mu.Unlock()
	mcl := newMiniCoreLoader()
	core := mcl.loadCore()
	assertNotNilF(t, core)
	v, err := core.FullVersion()
	assertNilF(t, err)
	minicoreDebugf("Minicore loading completed, version: %s", v)
	minicoreLoadLogs.mu.Lock()
	joined := strings.Join(minicoreLoadLogs.logs, "\n")
	minicoreLoadLogs.mu.Unlock()
	assertStringContainsE(t, joined, "Minicore loading completed, version: 0.0.1")
}

func TestIsDynamicallyLinked(t *testing.T) {
	linkingMode, err := compilation.CheckDynamicLinking()
	if runtime.GOOS == "linux" {
		assertNilF(t, err, "should be able to read /proc/self/exe")
		assertEqualE(t, linkingMode, compilation.DynamicLinking, "go test binaries should be dynamically linked")
	} else {
		assertEqualE(t, linkingMode, compilation.UnknownLinking, "linking mode should be unknown on non-linux OS")
	}
}

func TestMiniCoreLoadedE2E(t *testing.T) {
	logger.SetLogLevel("debug")
	mappingFile := "minicore/auth/successful_flow.json"
	if runtime.GOOS == "linux" {
		mappingFile = "minicore/auth/successful_flow_linux.json"
	}
	wiremock.registerMappings(t, newWiremockMapping(mappingFile), newWiremockMapping("select1.json"))
	cfg := wiremock.connectionConfig()
	connector := NewConnector(SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	runSmokeQuery(t, db)
}

================================================
FILE: minicore_windows.go
================================================
//go:build windows && !minicore_disabled

package gosnowflake

import (
	_ "embed"
	"fmt"
	"golang.org/x/sys/windows"
	"syscall"
	"unsafe"
)

// windowsMiniCore is the Windows (LoadLibrary-based) implementation of miniCore.
type windowsMiniCore struct {
	// fullVersion holds the version string returned from the library
	fullVersion string
	// coreInitError holds any error that occurred during initialization
	coreInitError error
}

// Register the Windows loader as the platform-specific implementation.
var _ = func() any {
	osSpecificLoadFromPath = loadFromPath
	return nil
}()

// FullVersion returns the cached version string and any init error.
func (wmc *windowsMiniCore) FullVersion() (string, error) {
	return wmc.fullVersion, wmc.coreInitError
}

// loadFromPath loads the DLL at libPath, resolves and calls
// sf_core_full_version once, and returns a miniCore caching the result.
func loadFromPath(libPath string) miniCore {
	minicoreDebugf("Calling LoadLibrary")
	dllHandle, err := windows.LoadLibrary(libPath)
	minicoreDebugf("Calling LoadLibrary finished")
	if err != nil {
		mcErr := newMiniCoreError(miniCoreErrorTypeLoad, "windows", libPath, fmt.Errorf("failed to load shared library: %v", err))
		return newErroredMiniCore(mcErr)
	}
	// Release the DLL handle, because we cache minicore fullVersion result.
	defer windows.FreeLibrary(dllHandle)
	// Get the address of the function
	minicoreDebugf("getting procedure address")
	procAddr, err := windows.GetProcAddress(dllHandle, "sf_core_full_version")
	if err != nil {
		mcErr := newMiniCoreError(miniCoreErrorTypeSymbol, "windows", libPath, fmt.Errorf("procedure sf_core_full_version not found: %v", err))
		return newErroredMiniCore(mcErr)
	}
	minicoreDebugf("Invoking system call")
	// Second return value - omitted, required for syscalls that returns more values
	ret, _, callErr := syscall.Syscall(
		procAddr,
		0, // nargs: Number of arguments is ZERO
		0, // a1: Argument 1 (unused)
		0, // a2: Argument 2 (unused)
		0, // a3: Argument 3 (unused)
	)
	minicoreDebugf("Invoking system call finished")
	if callErr != 0 {
		mcErr := newMiniCoreError(miniCoreErrorTypeCall, "windows", libPath, fmt.Errorf("system call failed with error code: %v", callErr))
		return newErroredMiniCore(mcErr)
	}
	// The native function returns a C string pointer; nil means failure.
	cStrPtr := (*byte)(unsafe.Pointer(ret))
	if cStrPtr == nil {
		mcErr := newMiniCoreError(miniCoreErrorTypeCall, "windows", libPath, fmt.Errorf("native function returned null pointer (error code: %v)", callErr))
		return newErroredMiniCore(mcErr)
	}
	goStr := windows.BytePtrToString(cStrPtr)
	return &windowsMiniCore{
		fullVersion: goStr,
	}
}

================================================
FILE: monitoring.go
================================================
package gosnowflake

import (
	"context"
	"database/sql/driver"
	"encoding/json"
	"fmt"
	"github.com/snowflakedb/gosnowflake/v2/internal/errors"
	"net/url"
	"strconv"
	"time"
)
const urlQueriesResultFmt = "/queries/%s/result" // queryResultStatus is status returned from server type queryResultStatus int // Query Status defined at server side const ( // Deprecated: will be unexported in the future releases. SFQueryRunning queryResultStatus = iota // Deprecated: will be unexported in the future releases. SFQueryAborting // Deprecated: will be unexported in the future releases. SFQuerySuccess // Deprecated: will be unexported in the future releases. SFQueryFailedWithError // Deprecated: will be unexported in the future releases. SFQueryAborted // Deprecated: will be unexported in the future releases. SFQueryQueued // Deprecated: will be unexported in the future releases. SFQueryFailedWithIncident // Deprecated: will be unexported in the future releases. SFQueryDisconnected // Deprecated: will be unexported in the future releases. SFQueryResumingWarehouse // SFQueryQueueRepairingWarehouse present in QueryDTO.java. // Deprecated: will be unexported in the future releases. SFQueryQueueRepairingWarehouse // Deprecated: will be unexported in the future releases. SFQueryRestarted // SFQueryBlocked is when a statement is waiting on a lock on resource held // by another statement. // Deprecated: will be unexported in the future releases. SFQueryBlocked // Deprecated: will be unexported in the future releases. 
SFQueryNoData ) func (qs queryResultStatus) String() string { return [...]string{"RUNNING", "ABORTING", "SUCCESS", "FAILED_WITH_ERROR", "ABORTED", "QUEUED", "FAILED_WITH_INCIDENT", "DISCONNECTED", "RESUMING_WAREHOUSE", "QUEUED_REPAIRING_WAREHOUSE", "RESTARTED", "BLOCKED", "NO_DATA"}[qs] } func (qs queryResultStatus) isRunning() bool { switch qs { case SFQueryRunning, SFQueryResumingWarehouse, SFQueryQueued, SFQueryQueueRepairingWarehouse, SFQueryNoData: return true default: return false } } func (qs queryResultStatus) isError() bool { switch qs { case SFQueryAborting, SFQueryFailedWithError, SFQueryAborted, SFQueryFailedWithIncident, SFQueryDisconnected, SFQueryBlocked: return true default: return false } } var strQueryStatusMap = map[string]queryResultStatus{"RUNNING": SFQueryRunning, "ABORTING": SFQueryAborting, "SUCCESS": SFQuerySuccess, "FAILED_WITH_ERROR": SFQueryFailedWithError, "ABORTED": SFQueryAborted, "QUEUED": SFQueryQueued, "FAILED_WITH_INCIDENT": SFQueryFailedWithIncident, "DISCONNECTED": SFQueryDisconnected, "RESUMING_WAREHOUSE": SFQueryResumingWarehouse, "QUEUED_REPAIRING_WAREHOUSE": SFQueryQueueRepairingWarehouse, "RESTARTED": SFQueryRestarted, "BLOCKED": SFQueryBlocked, "NO_DATA": SFQueryNoData} type retStatus struct { Status string `json:"status"` SQLText string `json:"sqlText"` StartTime int64 `json:"startTime"` EndTime int64 `json:"endTime"` ErrorCode string `json:"errorCode"` ErrorMessage string `json:"errorMessage"` Stats retStats `json:"stats"` } type retStats struct { ScanBytes int64 `json:"scanBytes"` ProducedRows int64 `json:"producedRows"` } type statusResponse struct { Data struct { Queries []retStatus `json:"queries"` } `json:"data"` Message string `json:"message"` Code string `json:"code"` Success bool `json:"success"` } func strToQueryStatus(in string) queryResultStatus { return strQueryStatusMap[in] } // SnowflakeQueryStatus is the query status metadata of a snowflake query type SnowflakeQueryStatus struct { SQLText string StartTime 
int64 EndTime int64 ErrorCode string ErrorMessage string ScanBytes int64 ProducedRows int64 } // SnowflakeConnection is a wrapper to snowflakeConn that exposes API functions type SnowflakeConnection interface { GetQueryStatus(ctx context.Context, queryID string) (*SnowflakeQueryStatus, error) AddTelemetryData(ctx context.Context, eventDate time.Time, data map[string]string) error } // checkQueryStatus returns the status given the query ID. If successful, // the error will be nil, indicating there is a complete query result to fetch. // Other than nil, there are three error types that can be returned: // 1. ErrQueryStatus, if GS cannot return any kind of status due to any reason, // i.e. connection, permission, if a query was just submitted, etc. // 2, ErrQueryReportedError, if the requested query was terminated or aborted // and GS returned an error status included in query. SFQueryFailedWithError // 3, ErrQueryIsRunning, if the requested query is still running and might have // a complete result later, these statuses were listed in query. SFQueryRunning func (sc *snowflakeConn) checkQueryStatus( ctx context.Context, qid string) ( *retStatus, error) { headers := make(map[string]string) param := make(url.Values) param.Set(requestGUIDKey, NewUUID().String()) if tok, _, _ := sc.rest.TokenAccessor.GetTokens(); tok != "" { headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, tok) } resultPath := fmt.Sprintf("%s/%s", monitoringQueriesPath, qid) url := sc.rest.getFullURL(resultPath, ¶m) res, err := sc.rest.FuncGet(ctx, sc.rest, url, headers, sc.rest.RequestTimeout) if err != nil { logger.WithContext(ctx).Errorf("failed to get response. err: %v", err) return nil, err } defer func() { if err = res.Body.Close(); err != nil { logger.WithContext(ctx).Warnf("failed to close response body. 
err: %v", err) } }() var statusResp = statusResponse{} if err = json.NewDecoder(res.Body).Decode(&statusResp); err != nil { logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) return nil, err } if !statusResp.Success || len(statusResp.Data.Queries) == 0 { logger.WithContext(ctx).Errorf("status query returned not-success or no status returned.") return nil, exceptionTelemetry(&SnowflakeError{ Number: ErrQueryStatus, Message: "status query returned not-success or no status returned. Please retry", }, sc) } queryRet := statusResp.Data.Queries[0] if queryRet.ErrorCode != "" { return &queryRet, exceptionTelemetry(&SnowflakeError{ Number: ErrQueryStatus, Message: errors.ErrMsgQueryStatus, MessageArgs: []any{queryRet.ErrorCode, queryRet.ErrorMessage}, IncludeQueryID: true, QueryID: qid, }, sc) } // returned errorCode is 0. Now check what is the returned status of the query. qStatus := strToQueryStatus(queryRet.Status) if qStatus.isError() { return &queryRet, exceptionTelemetry(&SnowflakeError{ Number: ErrQueryReportedError, Message: fmt.Sprintf("%s: status from server: [%s]", queryRet.ErrorMessage, queryRet.Status), IncludeQueryID: true, QueryID: qid, }, sc) } if qStatus.isRunning() { return &queryRet, exceptionTelemetry(&SnowflakeError{ Number: ErrQueryIsRunning, Message: fmt.Sprintf("%s: status from server: [%s]", queryRet.ErrorMessage, queryRet.Status), IncludeQueryID: true, QueryID: qid, }, sc) } //success return &queryRet, nil } func (sc *snowflakeConn) getQueryResultResp( ctx context.Context, resultPath string) ( *execResponse, error) { headers := getHeaders() if sn, ok := sc.syncParams.get(serviceName); ok { headers[httpHeaderServiceName] = *sn } param := make(url.Values) param.Set(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String()) param.Set("clientStartTime", strconv.FormatInt(sc.currentTimeProvider.currentTime(), 10)) param.Set(requestGUIDKey, NewUUID().String()) token, _, _ := sc.rest.TokenAccessor.GetTokens() if token != "" { 
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) } url := sc.rest.getFullURL(resultPath, ¶m) respd, err := getQueryResultWithRetriesForAsyncMode(ctx, sc.rest, url, headers, sc.rest.RequestTimeout) if err != nil { logger.WithContext(ctx).Errorf("error: %v", err) return nil, err } return respd, nil } // Fetch query result for a query id from /queries//result endpoint. func (sc *snowflakeConn) rowsForRunningQuery( ctx context.Context, qid string, rows *snowflakeRows) error { resultPath := fmt.Sprintf(urlQueriesResultFmt, qid) resp, err := sc.getQueryResultResp(ctx, resultPath) if err != nil { logger.WithContext(ctx).Errorf("error: %v", err) return err } if !resp.Success { code, err := strconv.Atoi(resp.Code) if err != nil { return err } return exceptionTelemetry(&SnowflakeError{ Number: code, SQLState: resp.Data.SQLState, Message: resp.Message, QueryID: resp.Data.QueryID, }, sc) } rows.addDownloader(populateChunkDownloader(ctx, sc, resp.Data)) return nil } // prepare a Rows object to return for query of 'qid' func (sc *snowflakeConn) buildRowsForRunningQuery( ctx context.Context, qid string) ( driver.Rows, error) { rows := new(snowflakeRows) rows.sc = sc rows.queryID = qid rows.ctx = ctx if err := sc.rowsForRunningQuery(ctx, qid, rows); err != nil { return nil, err } err := rows.ChunkDownloader.start() return rows, err } ================================================ FILE: multistatement.go ================================================ package gosnowflake import ( "context" "database/sql/driver" "fmt" "github.com/snowflakedb/gosnowflake/v2/internal/errors" "strconv" "strings" ) type childResult struct { id string typ string } func getChildResults(IDs string, types string) []childResult { if IDs == "" { return nil } queryIDs := strings.Split(IDs, ",") resultTypes := strings.Split(types, ",") res := make([]childResult, len(queryIDs)) for i, id := range queryIDs { res[i] = childResult{id, resultTypes[i]} } return res } func (sc 
*snowflakeConn) handleMultiExec( ctx context.Context, data execResponseData) ( driver.Result, error) { if data.ResultIDs == "" { return nil, exceptionTelemetry(&SnowflakeError{ Number: ErrNoResultIDs, SQLState: data.SQLState, Message: errors.ErrMsgNoResultIDs, QueryID: data.QueryID, }, sc) } var updatedRows int64 childResults := getChildResults(data.ResultIDs, data.ResultTypes) for _, child := range childResults { resultPath := fmt.Sprintf(urlQueriesResultFmt, child.id) childResultType, err := strconv.ParseInt(child.typ, 10, 64) if err != nil { return nil, err } if isDml(childResultType) { childData, err := sc.getQueryResultResp(ctx, resultPath) if err != nil { logger.WithContext(ctx).Errorf("error: %v", err) return nil, err } if childData != nil && !childData.Success { code, err := strconv.Atoi(childData.Code) if err != nil { return nil, err } return nil, exceptionTelemetry(&SnowflakeError{ Number: code, SQLState: childData.Data.SQLState, Message: childData.Message, QueryID: childData.Data.QueryID, }, sc) } count, err := updateRows(childData.Data) if err != nil { logger.WithContext(ctx).Errorf("error: %v", err) return nil, err } updatedRows += count } } logger.WithContext(ctx).Infof("number of updated rows: %#v", updatedRows) return &snowflakeResult{ affectedRows: updatedRows, insertID: -1, queryID: data.QueryID, }, nil } // Fill the corresponding rows and add chunk downloader into the rows when // iterating across the childResults func (sc *snowflakeConn) handleMultiQuery( ctx context.Context, data execResponseData, rows *snowflakeRows) error { if data.ResultIDs == "" { return exceptionTelemetry(&SnowflakeError{ Number: ErrNoResultIDs, SQLState: data.SQLState, Message: errors.ErrMsgNoResultIDs, QueryID: data.QueryID, }, sc) } childResults := getChildResults(data.ResultIDs, data.ResultTypes) for _, child := range childResults { if err := sc.rowsForRunningQuery(ctx, child.id, rows); err != nil { return err } } return nil } 
================================================
FILE: multistatement_test.go
================================================
package gosnowflake

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/url"
	"os"
	"reflect"
	"testing"

	ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"

	"time"
)

// TestMultiStatementExecuteNoResultSet runs a 4-statement transactional batch
// via ExecContext and verifies the aggregated affected-row count
// (1 deleted + 2 inserted = 3).
func TestMultiStatementExecuteNoResultSet(t *testing.T) {
	ctx := WithMultiStatement(context.Background(), 4)
	multiStmtQuery := "begin;\n" +
		"delete from test_multi_statement_txn;\n" +
		"insert into test_multi_statement_txn values (1, 'a'), (2, 'b');\n" +
		"commit;"
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec(`create or replace table test_multi_statement_txn(c1 number, c2 string) as select 10, 'z'`)
		res := dbt.mustExecContext(ctx, multiStmtQuery)
		count, err := res.RowsAffected()
		if err != nil {
			t.Fatalf("res.RowsAffected() returned error: %v", err)
		}
		if count != 3 {
			t.Fatalf("expected 3 affected rows, got %d", count)
		}
	})
}

// TestMultiStatementQueryResultSet runs four SELECTs in one batch and walks
// each result set with NextResultSet, checking every scanned value.
func TestMultiStatementQueryResultSet(t *testing.T) {
	ctx := WithMultiStatement(context.Background(), 4)
	multiStmtQuery := "select 123;\n" +
		"select 456;\n" +
		"select 789;\n" +
		"select '000';"
	var v1, v2, v3 int64
	var v4 string
	runDBTest(t, func(dbt *DBTest) {
		rows := dbt.mustQueryContext(ctx, multiStmtQuery)
		defer rows.Close()
		// first statement
		if rows.Next() {
			if err := rows.Scan(&v1); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if v1 != 123 {
				t.Fatalf("failed to fetch. value: %v", v1)
			}
		} else {
			t.Error("failed to query")
		}
		// second statement
		if !rows.NextResultSet() {
			t.Error("failed to retrieve next result set")
		}
		if rows.Next() {
			if err := rows.Scan(&v2); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if v2 != 456 {
				t.Fatalf("failed to fetch. value: %v", v2)
			}
		} else {
			t.Error("failed to query")
		}
		// third statement
		if !rows.NextResultSet() {
			t.Error("failed to retrieve next result set")
		}
		if rows.Next() {
			if err := rows.Scan(&v3); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if v3 != 789 {
				t.Fatalf("failed to fetch. value: %v", v3)
			}
		} else {
			t.Error("failed to query")
		}
		// fourth statement
		if !rows.NextResultSet() {
			t.Error("failed to retrieve next result set")
		}
		if rows.Next() {
			if err := rows.Scan(&v4); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if v4 != "000" {
				t.Fatalf("failed to fetch. value: %v", v4)
			}
		} else {
			t.Error("failed to query")
		}
	})
}

// TestMultistatementQueryLargeResultSet validates multi-statement queries with
// chunked results. The 1,000,000 row count per statement is required to trigger
// Snowflake's chunked result delivery. A bug in HasNextResultSet/NextResultSet
// (SNOW-1646792) only manifested with large, multi-chunk result sets. Do not
// reduce the row count — smaller values may fit in a single chunk and miss the
// bug class this test guards against.
func TestMultistatementQueryLargeResultSet(t *testing.T) {
	ctx := WithMultiStatement(context.Background(), 2)
	runDBTest(t, func(dbt *DBTest) {
		rows := dbt.mustQueryContextT(ctx, t, "SELECT 'abc' FROM TABLE(GENERATOR(ROWCOUNT => 1000000)); SELECT 'abc' FROM TABLE(GENERATOR(ROWCOUNT => 1000000))")
		totalRows := 0
		// Drain every result set; the loop condition advances to the next set.
		for hasNextResultSet := true; hasNextResultSet; hasNextResultSet = rows.NextResultSet() {
			for rows.Next() {
				var s string
				rows.mustScan(&s)
				assertEqualE(t, s, "abc")
				totalRows++
			}
		}
		assertEqualE(t, totalRows, 2000000)
	})
}

// TestMultiStatementExecuteResultSet mixes DML and SELECTs in a rolled-back
// transaction and checks that only the DML rows are counted (1 + 2 = 3).
func TestMultiStatementExecuteResultSet(t *testing.T) {
	ctx := WithMultiStatement(context.Background(), 6)
	multiStmtQuery := "begin;\n" +
		"delete from test_multi_statement_txn_rb;\n" +
		"insert into test_multi_statement_txn_rb values (1, 'a'), (2, 'b');\n" +
		"select 1;\n" +
		"select 2;\n" +
		"rollback;"
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("drop table if exists test_multi_statement_txn_rb")
		dbt.mustExec(`create or replace table test_multi_statement_txn_rb( c1 number, c2 string) as select 10, 'z'`)
		defer dbt.mustExec("drop table if exists test_multi_statement_txn_rb")
		res := dbt.mustExecContext(ctx, multiStmtQuery)
		count, err := res.RowsAffected()
		if err != nil {
			t.Fatalf("res.RowsAffected() returned error: %v", err)
		}
		if count != 3 {
			t.Fatalf("expected 3 affected rows, got %d", count)
		}
	})
}

// TestMultiStatementQueryNoResultSet verifies a DML-only batch issued through
// QueryContext does not fail even though it produces no result sets.
func TestMultiStatementQueryNoResultSet(t *testing.T) {
	ctx := WithMultiStatement(context.Background(), 4)
	multiStmtQuery := "begin;\n" +
		"delete from test_multi_statement_txn;\n" +
		"insert into test_multi_statement_txn values (1, 'a'), (2, 'b');\n" +
		"commit;"
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("drop table if exists test_multi_statement_txn")
		dbt.mustExec(`create or replace table test_multi_statement_txn( c1 number, c2 string) as select 10, 'z'`)
		// NOTE(review): the table name below has a typo ("tfmuest_..."), so this
		// deferred drop is a no-op; cleanup effectively relies on the
		// create-or-replace above. Confirm and fix the name upstream.
		defer dbt.mustExec("drop table if exists tfmuest_multi_statement_txn")
		rows := dbt.mustQueryContext(ctx, multiStmtQuery)
		defer rows.Close()
	})
}

// TestMultiStatementExecuteMix runs DDL + DML + SELECT via ExecContext and
// checks only the insert contributes to the affected-row count.
func TestMultiStatementExecuteMix(t *testing.T) {
	ctx := WithMultiStatement(context.Background(), 3)
	multiStmtQuery := "create or replace temporary table test_multi (cola int);\n" +
		"insert into test_multi values (1), (2);\n" +
		"select cola from test_multi order by cola asc;"
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("drop table if exists test_multi_statement_txn")
		dbt.mustExec(`create or replace table test_multi_statement_txn( c1 number, c2 string) as select 10, 'z'`)
		defer dbt.mustExec("drop table if exists test_multi_statement_txn")
		res := dbt.mustExecContext(ctx, multiStmtQuery)
		count, err := res.RowsAffected()
		if err != nil {
			t.Fatalf("res.RowsAffected() returned error: %v", err)
		}
		if count != 2 {
			t.Fatalf("expected 2 affected rows, got %d", count)
		}
	})
}

// TestMultiStatementQueryMix runs the same DDL + DML + SELECT batch via
// QueryContext and walks through all three result sets.
func TestMultiStatementQueryMix(t *testing.T) {
	ctx := WithMultiStatement(context.Background(), 3)
	multiStmtQuery := "create or replace temporary table test_multi (cola int);\n" +
		"insert into test_multi values (1), (2);\n" +
		"select cola from test_multi order by cola asc;"
	var count, v int
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("drop table if exists test_multi_statement_txn")
		dbt.mustExec(`create or replace table test_multi_statement_txn( c1 number, c2 string) as select 10, 'z'`)
		defer dbt.mustExec("drop table if exists test_multi_statement_txn")
		rows := dbt.mustQueryContext(ctx, multiStmtQuery)
		defer rows.Close()
		// first statement
		if !rows.Next() {
			t.Error("failed to query")
		}
		// second statement
		rows.NextResultSet()
		if rows.Next() {
			if err := rows.Scan(&count); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if count != 2 {
				t.Fatalf("expected 2 affected rows, got %d", count)
			}
		}
		expected := 1
		// third statement
		rows.NextResultSet()
		for rows.Next() {
			if err := rows.Scan(&v); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if v != expected {
				t.Fatalf("failed to fetch. value: %v", v)
			}
			expected++
		}
	})
}

// TestMultiStatementCountZero exercises MULTI_STATEMENT_COUNT = 0 (unlimited):
// two batches with differing statement counts and scan types must both work.
func TestMultiStatementCountZero(t *testing.T) {
	ctx := WithMultiStatement(context.Background(), 0)
	var v1 int
	var v2 string
	var v3 float64
	var v4 bool
	runDBTest(t, func(dbt *DBTest) {
		// first query
		multiStmtQuery1 := "select 123;\n" +
			"select '456';"
		rows1 := dbt.mustQueryContext(ctx, multiStmtQuery1)
		defer rows1.Close()
		// first statement
		if rows1.Next() {
			if err := rows1.Scan(&v1); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if v1 != 123 {
				t.Fatalf("failed to fetch. value: %v", v1)
			}
		} else {
			t.Error("failed to query")
		}
		// second statement
		if !rows1.NextResultSet() {
			t.Error("failed to retrieve next result set")
		}
		if rows1.Next() {
			if err := rows1.Scan(&v2); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if v2 != "456" {
				t.Fatalf("failed to fetch. value: %v", v2)
			}
		} else {
			t.Error("failed to query")
		}
		// second query
		multiStmtQuery2 := "select 789;\n" +
			"select 'foo';\n" +
			"select 0.123;\n" +
			"select true;"
		rows2 := dbt.mustQueryContext(ctx, multiStmtQuery2)
		defer rows2.Close()
		// first statement
		if rows2.Next() {
			if err := rows2.Scan(&v1); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if v1 != 789 {
				t.Fatalf("failed to fetch. value: %v", v1)
			}
		} else {
			t.Error("failed to query")
		}
		// second statement
		if !rows2.NextResultSet() {
			t.Error("failed to retrieve next result set")
		}
		if rows2.Next() {
			if err := rows2.Scan(&v2); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if v2 != "foo" {
				t.Fatalf("failed to fetch. value: %v", v2)
			}
		} else {
			t.Error("failed to query")
		}
		// third statement
		if !rows2.NextResultSet() {
			t.Error("failed to retrieve next result set")
		}
		if rows2.Next() {
			if err := rows2.Scan(&v3); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if v3 != 0.123 {
				t.Fatalf("failed to fetch. value: %v", v3)
			}
		} else {
			t.Error("failed to query")
		}
		// fourth statement
		if !rows2.NextResultSet() {
			t.Error("failed to retrieve next result set")
		}
		if rows2.Next() {
			if err := rows2.Scan(&v4); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if v4 != true {
				t.Fatalf("failed to fetch. value: %v", v4)
			}
		} else {
			t.Error("failed to query")
		}
	})
}

// TestMultiStatementCountMismatch confirms that a declared statement count
// smaller than the actual number of statements is rejected by the server.
func TestMultiStatementCountMismatch(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		multiStmtQuery := "select 123;\n" +
			"select 456;\n" +
			"select 789;\n" +
			"select '000';"
		ctx := WithMultiStatement(context.Background(), 3)
		if _, err := dbt.conn.QueryContext(ctx, multiStmtQuery); err == nil {
			t.Fatal("should have failed to query multiple statements")
		}
	})
}

// TestMultiStatementVaryingColumnCount checks result sets with different
// column counts can be scanned back-to-back within one batch.
func TestMultiStatementVaryingColumnCount(t *testing.T) {
	multiStmtQuery := "select c1 from test_tbl;\n" +
		"select c1,c2 from test_tbl;"
	ctx := WithMultiStatement(context.Background(), 0)
	var v1, v2 int
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("create or replace table test_tbl(c1 int, c2 int)")
		dbt.mustExec("insert into test_tbl values(1, 0)")
		defer dbt.mustExec("drop table if exists test_tbl")
		rows := dbt.mustQueryContext(ctx, multiStmtQuery)
		defer rows.Close()
		if rows.Next() {
			if err := rows.Scan(&v1); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if v1 != 1 {
				t.Fatalf("failed to fetch. value: %v", v1)
			}
		} else {
			t.Error("failed to query")
		}
		if !rows.NextResultSet() {
			t.Error("failed to retrieve next result set")
		}
		if rows.Next() {
			if err := rows.Scan(&v1, &v2); err != nil {
				t.Errorf("failed to scan: %#v", err)
			}
			if v1 != 1 || v2 != 0 {
				t.Fatalf("failed to fetch. value: %v, %v", v1, v2)
			}
		} else {
			t.Error("failed to query")
		}
	})
}

// The total completion time should be similar to the duration of the query on Snowflake UI.
func TestMultiStatementExecutePerformance(t *testing.T) {
	ctx := WithMultiStatement(context.Background(), 100)
	runDBTest(t, func(dbt *DBTest) {
		file, err := os.Open("test_data/multistatements.sql")
		if err != nil {
			t.Fatalf("failed opening file: %s", err)
		}
		defer file.Close()
		statements, err := io.ReadAll(file)
		if err != nil {
			t.Fatalf("failed reading file: %s", err)
		}
		sql := string(statements)
		start := time.Now()
		res := dbt.mustExecContext(ctx, sql)
		duration := time.Since(start)
		count, err := res.RowsAffected()
		if err != nil {
			t.Fatalf("res.RowsAffected() returned error: %v", err)
		}
		if count != 0 {
			t.Fatalf("expected 0 affected rows, got %d", count)
		}
		t.Logf("The total completion time was %v", duration)
		// Clean up the objects created by the first script.
		file, err = os.Open("test_data/multistatements_drop.sql")
		if err != nil {
			t.Fatalf("failed opening file: %s", err)
		}
		defer file.Close()
		statements, err = io.ReadAll(file)
		if err != nil {
			t.Fatalf("failed reading file: %s", err)
		}
		sql = string(statements)
		dbt.mustExecContext(ctx, sql)
	})
}

// TestUnitGetChildResults covers empty inputs and matched ID/type lists.
func TestUnitGetChildResults(t *testing.T) {
	testcases := []struct {
		ids   string
		types string
		out   []childResult
	}{
		{"", "", nil},
		{"", "4096", nil},
		{"01aa3265-0405-ab7c-0000-53b106343aba,02aa3265-0405-ab7c-0000-53b106343aba", "12544,12544", []childResult{
			{"01aa3265-0405-ab7c-0000-53b106343aba", "12544"},
			{"02aa3265-0405-ab7c-0000-53b106343aba", "12544"}}},
		{"01aa3265-0405-ab7c-0000-53b106343aba,02aa3265-0405-ab7c-0000-53b106343aba,03aa3265-0405-ab7c-0000-53b106343aba", "25344,4096,12544", []childResult{
			{"01aa3265-0405-ab7c-0000-53b106343aba", "25344"},
			{"02aa3265-0405-ab7c-0000-53b106343aba", "4096"},
			{"03aa3265-0405-ab7c-0000-53b106343aba", "12544"}}},
	}
	for _, test := range testcases {
		t.Run(test.ids, func(t *testing.T) {
			res := getChildResults(test.ids, test.types)
			if !reflect.DeepEqual(res, test.out) {
				t.Fatalf("Child result should be equal, expected %v, actual %v", res, test.out)
			}
		})
	}
}

// funcGetQueryRespFail is a FuncGet stub that always fails the HTTP request.
func funcGetQueryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) {
	return nil, errors.New("failed to get query response")
}

// funcGetQueryRespError is a FuncGet stub returning HTTP 200 with a
// Success=false exec response (error code 261000).
func funcGetQueryRespError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) {
	dd := &execResponseData{}
	er := &execResponse{
		Data:    *dd,
		Message: "query failed",
		Code:    "261000",
		Success: false,
	}
	ba, err := json.Marshal(er)
	if err != nil {
		panic(err)
	}
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       &fakeResponseBody{body: ba},
	}, nil
}

// TestUnitHandleMultiExec drives handleMultiExec through its failure paths:
// missing result IDs, transport failure, and a not-success child response.
func TestUnitHandleMultiExec(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		data := execResponseData{
			ResultIDs:   "",
			ResultTypes: "",
		}
		_, err := sct.sc.handleMultiExec(context.Background(), data)
		if err == nil {
			t.Fatalf("should have failed")
		}
		driverErr, ok := err.(*SnowflakeError)
		if !ok {
			t.Fatalf("should be snowflake error. err: %v", err)
		}
		if driverErr.Number != ErrNoResultIDs {
			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNoResultIDs, driverErr.Number)
		}
		data = execResponseData{
			ResultIDs:   "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre",
			ResultTypes: "12544,12544",
		}
		sct.sc.rest = &snowflakeRestful{
			FuncGet:          funcGetQueryRespFail,
			FuncCloseSession: closeSessionMock,
			TokenAccessor:    getSimpleTokenAccessor(),
		}
		_, err = sct.sc.handleMultiExec(context.Background(), data)
		if err == nil {
			t.Fatalf("should have failed")
		}
		sct.sc.rest.FuncGet = funcGetQueryRespError
		data.SQLState = "01112"
		_, err = sct.sc.handleMultiExec(context.Background(), data)
		if err == nil {
			t.Fatalf("should have failed")
		}
		driverErr, ok = err.(*SnowflakeError)
		if !ok {
			t.Fatalf("should be snowflake error. err: %v", err)
		}
		if driverErr.Number != ErrFailedToPostQuery {
			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number)
		}
	})
}

// TestUnitHandleMultiQuery mirrors TestUnitHandleMultiExec for the query path.
func TestUnitHandleMultiQuery(t *testing.T) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		data := execResponseData{
			ResultIDs:   "",
			ResultTypes: "",
		}
		rows := new(snowflakeRows)
		err := sct.sc.handleMultiQuery(context.Background(), data, rows)
		if err == nil {
			t.Fatalf("should have failed")
		}
		driverErr, ok := err.(*SnowflakeError)
		if !ok {
			t.Fatalf("should be snowflake error. err: %v", err)
		}
		if driverErr.Number != ErrNoResultIDs {
			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNoResultIDs, driverErr.Number)
		}
		data = execResponseData{
			ResultIDs:   "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre",
			ResultTypes: "12544,12544",
		}
		sct.sc.rest = &snowflakeRestful{
			FuncGet:          funcGetQueryRespFail,
			FuncCloseSession: closeSessionMock,
			TokenAccessor:    getSimpleTokenAccessor(),
		}
		err = sct.sc.handleMultiQuery(context.Background(), data, rows)
		if err == nil {
			t.Fatalf("should have failed")
		}
		sct.sc.rest.FuncGet = funcGetQueryRespError
		data.SQLState = "01112"
		err = sct.sc.handleMultiQuery(context.Background(), data, rows)
		if err == nil {
			t.Fatalf("should have failed")
		}
		driverErr, ok = err.(*SnowflakeError)
		if !ok {
			t.Fatalf("should be snowflake error. err: %v", err)
		}
		if driverErr.Number != ErrFailedToPostQuery {
			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number)
		}
	})
}

// TestMultiStatementArrowFormat verifies multi-statement batches work under
// both forced JSON and forced Arrow result formats, by checking the server
// response validation log lines for each row type.
func TestMultiStatementArrowFormat(t *testing.T) {
	ctx := WithMultiStatement(context.Background(), 4)
	multiStmtQuery := "select 123;\n" +
		"select 456;\n" +
		"select 789;\n" +
		"select '000';"
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("ALTER SESSION SET ENABLE_FIX_1758055_ADD_ARROW_SUPPORT_FOR_MULTI_STMTS = TRUE")
		testCases := []struct {
			name       string
			formatType string
			forceQuery string
		}{
			{name: "forceJSON", formatType: "json", forceQuery: forceJSON},
			{name: "forceArrow", formatType: "arrow", forceQuery: forceARROW},
		}
		rowTypes := []string{"123", "456", "789", "'000'"}
		for _, testCase := range testCases {
			t.Run("with "+testCase.name, func(t *testing.T) {
				dbt.mustExec(testCase.forceQuery)
				buffer, cleanup := setupTestLogger()
				defer cleanup()
				rows := dbt.mustQueryContext(ia.EnableArrowBatches(ctx), multiStmtQuery)
				defer rows.Close()
				logOutput := buffer.String()
				for _, rowType := range rowTypes {
					assertStringContainsE(t, logOutput, "[Server Response Validation]: RowType: "+rowType+", QueryResultFormat: "+testCase.formatType)
				}
			})
		}
	})
}

================================================
FILE: ocsp.go
================================================
package gosnowflake

import (
	"bufio"
	"context"
	"crypto"
	"crypto/fips140"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/asn1"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config"
	sferrors "github.com/snowflakedb/gosnowflake/v2/internal/errors"
	"io"
	"math/big"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/ocsp"
)

var (
	ocspModuleInitialized  = false
	ocspModuleMu           sync.Mutex
	ocspCacheClearer       = &ocspCacheClearerType{}
	ocspCacheServerEnabled = true
)

var (
	// cacheDir is the location of OCSP response cache file
	cacheDir = ""
	// cacheFileName is the file name of OCSP response cache file
	cacheFileName = ""
	// cacheUpdated is true if the memory cache
	// is updated
	cacheUpdated = true
)

// OCSPFailOpenMode is OCSP fail open mode. OCSPFailOpenTrue by default and may
// set to ocspModeFailClosed for fail closed mode
// Deprecated: will be moved to Config/DSN in the future releases.
type OCSPFailOpenMode = sfconfig.OCSPFailOpenMode

const (
	// OCSPFailOpenTrue represents OCSP fail open mode.
	OCSPFailOpenTrue = sfconfig.OCSPFailOpenTrue
	// OCSPFailOpenFalse represents OCSP fail closed mode.
	OCSPFailOpenFalse = sfconfig.OCSPFailOpenFalse
)

const (
	// defaultOCSPCacheServerTimeout is the total timeout for OCSP cache server.
	defaultOCSPCacheServerTimeout = 5 * time.Second
	// defaultOCSPResponderTimeout is the total timeout for OCSP responder.
	defaultOCSPResponderTimeout = 10 * time.Second
	// defaultOCSPMaxRetryCount specifies maximum number of subsequent retries to OCSP (cache and server)
	defaultOCSPMaxRetryCount = 2
	// defaultOCSPResponseCacheClearingInterval is the default value for clearing OCSP response cache
	defaultOCSPResponseCacheClearingInterval = 15 * time.Minute
)

var (
	// OcspCacheServerTimeout is a timeout for OCSP cache server.
	// Deprecated: will be moved to Config/DSN in the future releases.
	OcspCacheServerTimeout = defaultOCSPCacheServerTimeout
	// OcspResponderTimeout is a timeout for OCSP responders.
	// Deprecated: will be moved to Config/DSN in the future releases.
	OcspResponderTimeout = defaultOCSPResponderTimeout
	// OcspMaxRetryCount is a number of retries to OCSP (cache server and responders).
	// Deprecated: will be moved to Config/DSN in the future releases.
	OcspMaxRetryCount = defaultOCSPMaxRetryCount
)

const (
	cacheFileBaseName = "ocsp_response_cache.json"
	// cacheExpire specifies cache data expiration time in seconds.
	cacheExpire                                   = float64(24 * 60 * 60)
	defaultCacheServerHost                        = "http://ocsp.snowflakecomputing.com"
	cacheServerEnabledEnv                         = "SF_OCSP_RESPONSE_CACHE_SERVER_ENABLED"
	cacheServerURLEnv                             = "SF_OCSP_RESPONSE_CACHE_SERVER_URL"
	cacheDirEnv                                   = "SF_OCSP_RESPONSE_CACHE_DIR"
	ocspResponseCacheClearingIntervalInSecondsEnv = "SF_OCSP_RESPONSE_CACHE_CLEARING_INTERVAL_IN_SECONDS"
)

const (
	ocspTestResponderURLEnv = "SF_OCSP_TEST_RESPONDER_URL"
	ocspTestNoOCSPURLEnv    = "SF_OCSP_TEST_NO_OCSP_RESPONDER_URL"
)

const (
	tolerableValidityRatio = 100               // buffer for certificate revocation update time
	maxClockSkew           = 900 * time.Second // buffer for clock skew
)

// ocspStatusCode classifies the outcome of an OCSP check.
type ocspStatusCode int

// ocspStatus carries an outcome code plus the error describing it, if any.
type ocspStatus struct {
	code ocspStatusCode
	err  error
}

const (
	ocspSuccess                ocspStatusCode = 0
	ocspStatusGood             ocspStatusCode = -1
	ocspStatusRevoked          ocspStatusCode = -2
	ocspStatusUnknown          ocspStatusCode = -3
	ocspStatusOthers           ocspStatusCode = -4
	ocspNoServer               ocspStatusCode = -5
	ocspFailedParseOCSPHost    ocspStatusCode = -6
	ocspFailedComposeRequest   ocspStatusCode = -7
	ocspFailedDecomposeRequest ocspStatusCode = -8
	ocspFailedSubmit           ocspStatusCode = -9
	ocspFailedResponse         ocspStatusCode = -10
	ocspFailedExtractResponse  ocspStatusCode = -11
	ocspFailedParseResponse    ocspStatusCode = -12
	ocspInvalidValidity        ocspStatusCode = -13
	ocspMissedCache            ocspStatusCode = -14
	ocspCacheExpired           ocspStatusCode = -15
	ocspFailedDecodeResponse   ocspStatusCode = -16
)

// copied from crypto/ocsp.go
type certID struct {
	HashAlgorithm pkix.AlgorithmIdentifier
	NameHash      []byte
	IssuerKeyHash []byte
	SerialNumber  *big.Int
}

// cache key
type certIDKey struct {
	HashAlgorithm crypto.Hash
	NameHash      string
	IssuerKeyHash string
	SerialNumber  string
}

// certCacheValue is a cached base64 OCSP response with its store timestamp.
type certCacheValue struct {
	ts             float64
	ocspRespBase64 string
}

// parsedOcspRespKey keys the parsed-response cache.
type parsedOcspRespKey struct {
	ocspRespBase64 string
	certIDBase64   string
}

var (
	ocspResponseCache       map[certIDKey]*certCacheValue
	ocspParsedRespCache     map[parsedOcspRespKey]*ocspStatus
	ocspResponseCacheLock   = &sync.RWMutex{}
	ocspParsedRespCacheLock = &sync.Mutex{}
)

// ocspValidator performs OCSP revocation checks for one connection config.
type ocspValidator struct {
	mode           OCSPFailOpenMode
	cacheServerURL string
	isPrivateLink  bool
	retryURL       string
	cfg            *Config
}

// newOcspValidator derives the OCSP cache-server URL (and, for PrivateLink,
// the retry-proxy URL) from the configured host: env override first, then
// PrivateLink, then non-default domains, then the global default server.
func newOcspValidator(cfg *Config) *ocspValidator {
	isPrivateLink := checkIsPrivateLink(cfg.Host)
	var cacheServerURL, retryURL string
	var ok bool
	logger.Debug("initializing OCSP module")
	if cacheServerURL, ok = os.LookupEnv(cacheServerURLEnv); ok {
		logger.Debugf("OCSP Cache Server already set by user for %v: %v", cfg.Host, cacheServerURL)
	} else if isPrivateLink {
		cacheServerURL = fmt.Sprintf("http://ocsp.%v/%v", cfg.Host, cacheFileBaseName)
		logger.Debugf("Using PrivateLink host (%v), setting up OCSP cache server to %v", cfg.Host, cacheServerURL)
		retryURL = fmt.Sprintf("http://ocsp.%v/retry/", cfg.Host) + "%v/%v"
		logger.Debugf("Using PrivateLink retry proxy %v", retryURL)
	} else if !strings.HasSuffix(cfg.Host, sfconfig.DefaultDomain) {
		cacheServerURL = fmt.Sprintf("http://ocsp.%v/%v", cfg.Host, cacheFileBaseName)
		logger.Debugf("Using not global host (%v), setting up OCSP cache server to %v", cfg.Host, cacheServerURL)
	} else {
		cacheServerURL = fmt.Sprintf("%v/%v", defaultCacheServerHost, cacheFileBaseName)
		logger.Debugf("OCSP Cache Server not set by user for %v, setting it up to %v", cfg.Host, cacheServerURL)
	}
	return &ocspValidator{
		mode:           cfg.OCSPFailOpen,
		cacheServerURL: strings.ToLower(cacheServerURL),
		isPrivateLink:  isPrivateLink,
		retryURL:       strings.ToLower(retryURL),
		cfg:            cfg,
	}
}

// copied from crypto/ocsp
var hashOIDs = map[crypto.Hash]asn1.ObjectIdentifier{
	crypto.SHA1:   asn1.ObjectIdentifier([]int{1, 3, 14, 3, 2, 26}),
	crypto.SHA256: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 1}),
	crypto.SHA384: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 2}),
	crypto.SHA512: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 3}),
}

// copied from crypto/ocsp
func getOIDFromHashAlgorithm(target crypto.Hash) asn1.ObjectIdentifier {
	for hash, oid := range hashOIDs {
		if hash == target {
			return oid
		}
	}
	logger.Errorf("no valid OID is found for the hash algorithm. %#v", target)
	return nil
}

// getHashAlgorithmFromOID maps an algorithm OID back to a crypto.Hash,
// falling back to SHA1 when the OID is unknown.
func getHashAlgorithmFromOID(target pkix.AlgorithmIdentifier) crypto.Hash {
	for hash, oid := range hashOIDs {
		if oid.Equal(target.Algorithm) {
			return hash
		}
	}
	logger.Errorf("no valid hash algorithm is found for the oid. Falling back to SHA1: %#v", target)
	return crypto.SHA1
}

// calcTolerableValidity returns the maximum validity buffer
func calcTolerableValidity(thisUpdate, nextUpdate time.Time) time.Duration {
	return durationMax(time.Duration(nextUpdate.Sub(thisUpdate)/tolerableValidityRatio), maxClockSkew)
}

// isInValidityRange checks the validity
func isInValidityRange(currTime, thisUpdate, nextUpdate time.Time) bool {
	if currTime.Sub(thisUpdate.Add(-maxClockSkew)) < 0 {
		return false
	}
	if nextUpdate.Add(calcTolerableValidity(thisUpdate, nextUpdate)).Sub(currTime) < 0 {
		return false
	}
	return true
}

// extractCertIDKeyFromRequest parses an encoded OCSP request and derives the
// cache key identifying the certificate it asks about.
func extractCertIDKeyFromRequest(ocspReq []byte) (*certIDKey, *ocspStatus) {
	r, err := ocsp.ParseRequest(ocspReq)
	if err != nil {
		return nil, &ocspStatus{
			code: ocspFailedDecomposeRequest,
			err:  err,
		}
	}
	// encode CertID, used as a key in the cache
	encodedCertID := &certIDKey{
		r.HashAlgorithm,
		base64.StdEncoding.EncodeToString(r.IssuerNameHash),
		base64.StdEncoding.EncodeToString(r.IssuerKeyHash),
		r.SerialNumber.String(),
	}
	return encodedCertID, &ocspStatus{
		code: ocspSuccess,
	}
}

// decodeCertIDKey decodes a base64 DER CertID into a cache key; returns nil
// on any decoding/parsing problem or trailing bytes.
func decodeCertIDKey(certIDKeyBase64 string) *certIDKey {
	r, err := base64.StdEncoding.DecodeString(certIDKeyBase64)
	if err != nil {
		return nil
	}
	var c certID
	rest, err := asn1.Unmarshal(r, &c)
	if err != nil {
		// error in parsing
		return nil
	}
	if len(rest) > 0 {
		// extra bytes to the end
		return nil
	}
	return &certIDKey{
		getHashAlgorithmFromOID(c.HashAlgorithm),
		base64.StdEncoding.EncodeToString(c.NameHash),
		base64.StdEncoding.EncodeToString(c.IssuerKeyHash),
		c.SerialNumber.String(),
	}
}

// encodeCertIDKey re-encodes a cache key into a base64 DER CertID; returns ""
// on any encoding failure.
func encodeCertIDKey(k *certIDKey) string {
	serialNumber := new(big.Int)
	serialNumber.SetString(k.SerialNumber, 10)
	nameHash, err := base64.StdEncoding.DecodeString(k.NameHash)
	if err != nil {
		return ""
	}
	issuerKeyHash, err := base64.StdEncoding.DecodeString(k.IssuerKeyHash)
	if err != nil {
		return ""
	}
	encodedCertID, err := asn1.Marshal(certID{
		pkix.AlgorithmIdentifier{
			Algorithm:  getOIDFromHashAlgorithm(k.HashAlgorithm),
			Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */},
		},
		nameHash,
		issuerKeyHash,
		serialNumber,
	})
	if err != nil {
		return ""
	}
	return base64.StdEncoding.EncodeToString(encodedCertID)
}

// checkOCSPResponseCache looks the certificate up in the in-memory response
// cache; entries that no longer parse to a valid status are evicted.
func (ov *ocspValidator) checkOCSPResponseCache(certIDKey *certIDKey, subject, issuer *x509.Certificate) *ocspStatus {
	if !ocspCacheServerEnabled {
		return &ocspStatus{code: ocspNoServer}
	}
	gotValueFromCache, ok := func() (*certCacheValue, bool) {
		ocspResponseCacheLock.RLock()
		defer ocspResponseCacheLock.RUnlock()
		valueFromCache, ok := ocspResponseCache[*certIDKey]
		return valueFromCache, ok
	}()
	if !ok {
		return &ocspStatus{
			code: ocspMissedCache,
			err:  fmt.Errorf("miss cache data. subject: %v", subject),
		}
	}
	status := extractOCSPCacheResponseValue(certIDKey, gotValueFromCache, subject, issuer)
	if !isValidOCSPStatus(status.code) {
		deleteOCSPCache(certIDKey)
	}
	return status
}

// deleteOCSPCache evicts one entry and marks the cache dirty.
func deleteOCSPCache(encodedCertID *certIDKey) {
	ocspResponseCacheLock.Lock()
	defer ocspResponseCacheLock.Unlock()
	delete(ocspResponseCache, *encodedCertID)
	cacheUpdated = true
}

// validateOCSP checks the response's validity window before mapping its
// certificate status into an ocspStatus.
func validateOCSP(ocspRes *ocsp.Response) *ocspStatus {
	curTime := time.Now()
	if ocspRes == nil {
		return &ocspStatus{
			code: ocspFailedDecomposeRequest,
			err:  errors.New("OCSP Response is nil"),
		}
	}
	if !isInValidityRange(curTime, ocspRes.ThisUpdate, ocspRes.NextUpdate) {
		return &ocspStatus{
			code: ocspInvalidValidity,
			err: &SnowflakeError{
				Number:      ErrOCSPInvalidValidity,
				Message:     sferrors.ErrMsgOCSPInvalidValidity,
				MessageArgs: []any{ocspRes.ProducedAt, ocspRes.ThisUpdate, ocspRes.NextUpdate},
			},
		}
	}
	return returnOCSPStatus(ocspRes)
}

// returnOCSPStatus maps an ocsp.Response status to the driver's ocspStatus.
func returnOCSPStatus(ocspRes *ocsp.Response) *ocspStatus {
	switch ocspRes.Status {
	case ocsp.Good:
		return
		&ocspStatus{
			code: ocspStatusGood,
			err:  nil,
		}
	case ocsp.Revoked:
		return &ocspStatus{
			code: ocspStatusRevoked,
			err: &SnowflakeError{
				Number:      ErrOCSPStatusRevoked,
				Message:     sferrors.ErrMsgOCSPStatusRevoked,
				MessageArgs: []any{ocspRes.RevocationReason, ocspRes.RevokedAt},
			},
		}
	case ocsp.Unknown:
		return &ocspStatus{
			code: ocspStatusUnknown,
			err: &SnowflakeError{
				Number:  ErrOCSPStatusUnknown,
				Message: sferrors.ErrMsgOCSPStatusUnknown,
			},
		}
	default:
		return &ocspStatus{
			code: ocspStatusOthers,
			err:  fmt.Errorf("OCSP others. %v", ocspRes.Status),
		}
	}
}

// checkOCSPCacheServer downloads the shared OCSP response cache from the
// Snowflake OCSP cache server and converts it into certCacheValue entries.
func checkOCSPCacheServer(
	ctx context.Context,
	client clientInterface,
	req requestFunc,
	ocspServerHost *url.URL,
	totalTimeout time.Duration) (
	cacheContent *map[string]*certCacheValue,
	ocspS *ocspStatus) {
	var respd map[string][]any
	headers := make(map[string]string)
	res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout, OcspMaxRetryCount, defaultTimeProvider, nil).execute()
	if err != nil {
		logger.WithContext(ctx).Errorf("failed to get OCSP cache from OCSP Cache Server. %v", err)
		return nil, &ocspStatus{
			code: ocspFailedSubmit,
			err:  err,
		}
	}
	defer func() {
		if err = res.Body.Close(); err != nil {
			logger.Warnf("failed to close response body: %v", err)
		}
	}()
	logger.WithContext(ctx).Debugf("StatusCode from OCSP Cache Server: %v", res.StatusCode)
	if res.StatusCode != http.StatusOK {
		return nil, &ocspStatus{
			code: ocspFailedResponse,
			err:  fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status),
		}
	}
	logger.WithContext(ctx).Debugf("reading contents")
	// Stream-decode the JSON body; io.EOF terminates the loop normally.
	dec := json.NewDecoder(res.Body)
	for {
		if err := dec.Decode(&respd); err == io.EOF {
			break
		} else if err != nil {
			logger.WithContext(ctx).Errorf("failed to decode OCSP cache. %v", err)
			return nil, &ocspStatus{
				code: ocspFailedExtractResponse,
				err:  err,
			}
		}
	}
	// Skip entries whose timestamp/response pair cannot be extracted.
	buf := make(map[string]*certCacheValue)
	for key, value := range respd {
		ok, ts, ocspRespBase64 := extractTsAndOcspRespBase64(value)
		if !ok {
			continue
		}
		buf[key] = &certCacheValue{ts, ocspRespBase64}
	}
	return &buf, &ocspStatus{
		code: ocspSuccess,
	}
}

// retryOCSP is the second level of retry method if the returned contents are corrupted. It often happens with OCSP
// server and retry helps.
func (ov *ocspValidator) retryOCSP(
	ctx context.Context,
	client clientInterface,
	req requestFunc,
	ocspHost *url.URL,
	headers map[string]string,
	reqBody []byte,
	issuer *x509.Certificate,
	totalTimeout time.Duration) (
	ocspRes *ocsp.Response,
	ocspResBytes []byte,
	ocspS *ocspStatus) {
	// fail-closed mode triples the time budget so transient responder issues
	// do not immediately fail the connection.
	multiplier := 1
	if ov.mode == OCSPFailOpenFalse {
		multiplier = 3
	}
	res, err := newRetryHTTP(
		ctx, client, req, ocspHost, headers,
		totalTimeout*time.Duration(multiplier), OcspMaxRetryCount, defaultTimeProvider, nil).doPost().setBody(reqBody).execute()
	if err != nil {
		return ocspRes, ocspResBytes, &ocspStatus{
			code: ocspFailedSubmit,
			err:  err,
		}
	}
	defer func() {
		if err = res.Body.Close(); err != nil {
			logger.WithContext(ctx).Warnf("failed to close response body: %v", err)
		}
	}()
	logger.WithContext(ctx).Debugf("StatusCode from OCSP Server: %v\n", res.StatusCode)
	if res.StatusCode != http.StatusOK {
		return ocspRes, ocspResBytes, &ocspStatus{
			code: ocspFailedResponse,
			err:  fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status),
		}
	}
	ocspResBytes, err = io.ReadAll(res.Body)
	if err != nil {
		return ocspRes, ocspResBytes, &ocspStatus{
			code: ocspFailedExtractResponse,
			err:  err,
		}
	}
	ocspRes, err = ocsp.ParseResponse(ocspResBytes, issuer)
	if err != nil {
		// ASN.1 structural/syntax errors typically mean the responder rejected
		// the POST ("malformed request"); fall back to a GET request.
		_, ok1 := err.(asn1.StructuralError)
		_, ok2 := err.(asn1.SyntaxError)
		if ok1 || ok2 {
			logger.WithContext(ctx).Warnf("error when parsing ocsp response: %v", err)
			logger.WithContext(ctx).Warnf("performing GET fallback request to OCSP")
			return ov.fallbackRetryOCSPToGETRequest(ctx, client, req, ocspHost, headers, issuer, totalTimeout)
		}
		logger.Warnf("Unknown response status from OCSP responder: %v", err)
		return nil, nil, &ocspStatus{
			code: ocspStatusUnknown,
			err:  err,
		}
	}
	logger.WithContext(ctx).Debugf("OCSP Status from server: %v", printStatus(ocspRes))
	return ocspRes, ocspResBytes, &ocspStatus{
		code: ocspSuccess,
	}
}

// fallbackRetryOCSPToGETRequest is the third level of retry method. Some OCSP responders do not support POST requests
// and will return with a "malformed" request error.
// In that case we also try to perform a GET request
func (ov *ocspValidator) fallbackRetryOCSPToGETRequest(
	ctx context.Context,
	client clientInterface,
	req requestFunc,
	ocspHost *url.URL,
	headers map[string]string,
	issuer *x509.Certificate,
	totalTimeout time.Duration) (
	ocspRes *ocsp.Response,
	ocspResBytes []byte,
	ocspS *ocspStatus) {
	// fail-closed mode gets a larger time budget, mirroring retryOCSP.
	multiplier := 1
	if ov.mode == OCSPFailOpenFalse {
		multiplier = 3
	}
	res, err := newRetryHTTP(ctx, client, req, ocspHost, headers,
		totalTimeout*time.Duration(multiplier), OcspMaxRetryCount, defaultTimeProvider, nil).execute()
	if err != nil {
		return ocspRes, ocspResBytes, &ocspStatus{
			code: ocspFailedSubmit,
			err:  err,
		}
	}
	defer func() {
		if err = res.Body.Close(); err != nil {
			logger.Warnf("failed to close response body: %v", err)
		}
	}()
	logger.WithContext(ctx).Debugf("GET fallback StatusCode from OCSP Server: %v", res.StatusCode)
	if res.StatusCode != http.StatusOK {
		return ocspRes, ocspResBytes, &ocspStatus{
			code: ocspFailedResponse,
			err:  fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status),
		}
	}
	ocspResBytes, err = io.ReadAll(res.Body)
	if err != nil {
		return ocspRes, ocspResBytes, &ocspStatus{
			code: ocspFailedExtractResponse,
			err:  err,
		}
	}
	ocspRes, err = ocsp.ParseResponse(ocspResBytes, issuer)
	if err != nil {
		return ocspRes, ocspResBytes, &ocspStatus{
			code: ocspFailedParseResponse,
			err:  err,
		}
	}
	logger.WithContext(ctx).Debugf("GET fallback OCSP Status from server: %v", printStatus(ocspRes))
	return ocspRes, ocspResBytes, &ocspStatus{
		code: ocspSuccess,
	}
}

// printStatus renders an OCSP status for log output.
func printStatus(response *ocsp.Response) string {
	switch response.Status {
	case ocsp.Good:
		return "Good"
	case ocsp.Revoked:
		return "Revoked"
	case ocsp.Unknown:
		return "Unknown"
	default:
		return fmt.Sprintf("%d", response.Status)
	}
}

// fullOCSPURL joins the hostname and path of an OCSP responder URL.
func fullOCSPURL(url *url.URL) string {
	fullURL := url.Hostname()
	if url.Path != "" {
		if !strings.HasPrefix(url.Path, "/") {
			fullURL += "/"
		}
		fullURL += url.Path
	}
	return fullURL
}

// getRevocationStatus checks the certificate revocation status for subject using issuer certificate.
func (ov *ocspValidator) getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate) *ocspStatus { logger.WithContext(ctx).Tracef("Subject: %v, Issuer: %v", subject.Subject, issuer.Subject) status, ocspReq, encodedCertID := ov.validateWithCache(subject, issuer) if isValidOCSPStatus(status.code) { return status } if ocspReq == nil || encodedCertID == nil { return status } logger.WithContext(ctx).Infof("cache missed") logger.WithContext(ctx).Infof("OCSP Server: %v", subject.OCSPServer) testResponderURL := os.Getenv(ocspTestResponderURLEnv) if (len(subject.OCSPServer) == 0 || isTestNoOCSPURL()) && testResponderURL == "" { return &ocspStatus{ code: ocspNoServer, err: &SnowflakeError{ Number: ErrOCSPNoOCSPResponderURL, Message: sferrors.ErrMsgOCSPNoOCSPResponderURL, MessageArgs: []any{subject.Subject}, }, } } ocspHost := testResponderURL if ocspHost == "" && len(subject.OCSPServer) > 0 { ocspHost = subject.OCSPServer[0] } u, err := url.Parse(ocspHost) if err != nil { return &ocspStatus{ code: ocspFailedParseOCSPHost, err: fmt.Errorf("failed to parse OCSP server host. 
%v", ocspHost), } } var hostname string if retryURL := ov.retryURL; retryURL != "" { hostname = fmt.Sprintf(retryURL, fullOCSPURL(u), base64.StdEncoding.EncodeToString(ocspReq)) u0, err := url.Parse(hostname) if err == nil { hostname = u0.Hostname() u = u0 } } else { hostname = fullOCSPURL(u) } logger.WithContext(ctx).Debugf("Fetching OCSP response from server: %v", u) logger.WithContext(ctx).Debugf("Host in headers: %v", hostname) headers := make(map[string]string) headers[httpHeaderContentType] = "application/ocsp-request" headers[httpHeaderAccept] = "application/ocsp-response" headers[httpHeaderContentLength] = strconv.Itoa(len(ocspReq)) headers[httpHeaderHost] = hostname timeout := OcspResponderTimeout ocspClient := &http.Client{ Timeout: timeout, Transport: newTransportFactory(ov.cfg, nil).createNoRevocationTransport(defaultTransportConfigs.forTransportType(transportTypeOCSP)), } ocspRes, ocspResBytes, ocspS := ov.retryOCSP( ctx, ocspClient, http.NewRequest, u, headers, ocspReq, issuer, timeout) if ocspS.code != ocspSuccess { return ocspS } ret := validateOCSP(ocspRes) if !isValidOCSPStatus(ret.code) { return ret // return invalid } v := &certCacheValue{float64(time.Now().UTC().Unix()), base64.StdEncoding.EncodeToString(ocspResBytes)} ocspResponseCacheLock.Lock() ocspResponseCache[*encodedCertID] = v cacheUpdated = true ocspResponseCacheLock.Unlock() return ret } func isTestNoOCSPURL() bool { return strings.EqualFold(os.Getenv(ocspTestNoOCSPURLEnv), "true") } func isValidOCSPStatus(status ocspStatusCode) bool { return status == ocspStatusGood || status == ocspStatusRevoked || status == ocspStatusUnknown } // verifyPeerCertificate verifies all of certificate revocation status func (ov *ocspValidator) verifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate) (err error) { for _, chain := range verifiedChains { results := ov.getAllRevocationStatus(ctx, chain) if r := ov.canEarlyExitForOCSP(results, chain); r != nil { return r.err } } 
ocspResponseCacheLock.Lock() if cacheUpdated { ov.writeOCSPCacheFile() } cacheUpdated = false ocspResponseCacheLock.Unlock() return nil } func (ov *ocspValidator) canEarlyExitForOCSP(results []*ocspStatus, verifiedChain []*x509.Certificate) *ocspStatus { var msg strings.Builder if ov.mode == OCSPFailOpenFalse { // Fail closed. any error is returned to stop connection for _, r := range results { if r.err != nil { return r } } } else { // Fail open and all results are valid. allValid := len(results) == len(verifiedChain)-1 // root certificate is not checked for _, r := range results { if !isValidOCSPStatus(r.code) { allValid = false break } } for _, r := range results { if allValid && r.code == ocspStatusRevoked { return r } if r != nil && r.code != ocspStatusGood && r.err != nil { msg.WriteString("\n" + r.err.Error()) } } } if len(msg.String()) > 0 { logger.Debugf("OCSP responder didn't respond correctly. Assuming certificate is not revoked. Detail: %v", msg.String()[1:]) } return nil } func (ov *ocspValidator) validateWithCacheForAllCertificates(verifiedChains []*x509.Certificate) bool { n := len(verifiedChains) - 1 for j := range n { subject := verifiedChains[j] issuer := verifiedChains[j+1] status, _, _ := ov.validateWithCache(subject, issuer) if !isValidOCSPStatus(status.code) { return false } } return true } func (ov *ocspValidator) validateWithCache(subject, issuer *x509.Certificate) (*ocspStatus, []byte, *certIDKey) { reqOpts := &ocsp.RequestOptions{} if fips140.Enabled() { logger.Debug("FIPS 140 mode is enabled. 
Using SHA256 for OCSP request.") reqOpts.Hash = crypto.SHA256 } ocspReq, err := ocsp.CreateRequest(subject, issuer, reqOpts) if err != nil { logger.Errorf("failed to create OCSP request from the certificates.\n") return &ocspStatus{ code: ocspFailedComposeRequest, err: errors.New("failed to create a OCSP request"), }, nil, nil } encodedCertID, ocspS := extractCertIDKeyFromRequest(ocspReq) if ocspS.code != ocspSuccess { logger.Errorf("failed to extract CertID from OCSP Request.\n") return &ocspStatus{ code: ocspFailedComposeRequest, err: errors.New("failed to extract cert ID Key"), }, ocspReq, nil } status := ov.checkOCSPResponseCache(encodedCertID, subject, issuer) return status, ocspReq, encodedCertID } func (ov *ocspValidator) downloadOCSPCacheServer() { // TODO if !ocspCacheServerEnabled { logger.Debugf("OCSP Cache Server is disabled by user. Skipping download.") return } ocspCacheServerURL := ov.cacheServerURL u, err := url.Parse(ocspCacheServerURL) if err != nil { return } logger.Infof("downloading OCSP Cache from server %v", ocspCacheServerURL) timeout := OcspCacheServerTimeout ocspClient := &http.Client{ Timeout: timeout, Transport: newTransportFactory(ov.cfg, nil).createNoRevocationTransport(defaultTransportConfigs.forTransportType(transportTypeOCSP)), } ret, ocspStatus := checkOCSPCacheServer(context.Background(), ocspClient, http.NewRequest, u, timeout) if ocspStatus.code != ocspSuccess { return } ocspResponseCacheLock.Lock() for k, cacheValue := range *ret { cacheKey := decodeCertIDKey(k) status := extractOCSPCacheResponseValueWithoutSubject(cacheKey, cacheValue) if !isValidOCSPStatus(status.code) { continue } ocspResponseCache[*cacheKey] = cacheValue } cacheUpdated = true ocspResponseCacheLock.Unlock() } func (ov *ocspValidator) getAllRevocationStatus(ctx context.Context, verifiedChains []*x509.Certificate) []*ocspStatus { cached := ov.validateWithCacheForAllCertificates(verifiedChains) if !cached { ov.downloadOCSPCacheServer() } n := 
len(verifiedChains) - 1 results := make([]*ocspStatus, n) for j := range n { results[j] = ov.getRevocationStatus(ctx, verifiedChains[j], verifiedChains[j+1]) if !isValidOCSPStatus(results[j].code) { return results } } return results } // verifyPeerCertificateSerial verifies the certificate revocation status in serial. func (ov *ocspValidator) verifyPeerCertificateSerial(_ [][]byte, verifiedChains [][]*x509.Certificate) (err error) { func() { ocspModuleMu.Lock() defer ocspModuleMu.Unlock() if !ocspModuleInitialized { initOcspModule() } }() overrideCacheDir() return ov.verifyPeerCertificate(context.Background(), verifiedChains) } func overrideCacheDir() { if os.Getenv(cacheDirEnv) != "" { ocspResponseCacheLock.Lock() defer ocspResponseCacheLock.Unlock() createOCSPCacheDir() } } // initOCSPCache initializes OCSP Response cache file. func initOCSPCache() { if !ocspCacheServerEnabled { return } func() { ocspResponseCacheLock.Lock() defer ocspResponseCacheLock.Unlock() ocspResponseCache = make(map[certIDKey]*certCacheValue) }() func() { ocspParsedRespCacheLock.Lock() defer ocspParsedRespCacheLock.Unlock() ocspParsedRespCache = make(map[parsedOcspRespKey]*ocspStatus) }() logger.Infof("reading OCSP Response cache file. %v\n", cacheFileName) f, err := os.OpenFile(cacheFileName, os.O_CREATE|os.O_RDONLY, readWriteFileMode) if err != nil { logger.Debugf("failed to open. Ignored. %v\n", err) return } defer func() { if err = f.Close(); err != nil { logger.Warnf("failed to close file: %v. ignored.\n", err) } }() buf := make(map[string][]any) r := bufio.NewReader(f) dec := json.NewDecoder(r) for { if err = dec.Decode(&buf); err == io.EOF { break } else if err != nil { logger.Debugf("failed to read. Ignored. 
%v\n", err) return } } for k, cacheValue := range buf { ok, ts, ocspRespBase64 := extractTsAndOcspRespBase64(cacheValue) if !ok { continue } certValue := &certCacheValue{ts, ocspRespBase64} cacheKey := decodeCertIDKey(k) status := extractOCSPCacheResponseValueWithoutSubject(cacheKey, certValue) if !isValidOCSPStatus(status.code) { continue } ocspResponseCache[*cacheKey] = certValue } cacheUpdated = false } func extractTsAndOcspRespBase64(value []any) (bool, float64, string) { ts, ok := value[0].(float64) if !ok { logger.Warnf("cannot cast %v as float64", value[0]) return false, -1, "" } ocspRespBase64, ok := value[1].(string) if !ok { logger.Warnf("cannot cast %v as string", value[1]) return false, -1, "" } return true, ts, ocspRespBase64 } func extractOCSPCacheResponseValueWithoutSubject(cacheKey *certIDKey, cacheValue *certCacheValue) *ocspStatus { return extractOCSPCacheResponseValue(cacheKey, cacheValue, nil, nil) } func extractOCSPCacheResponseValue(certIDKey *certIDKey, certCacheValue *certCacheValue, subject, issuer *x509.Certificate) *ocspStatus { subjectName := "Unknown" if subject != nil { subjectName = subject.Subject.CommonName } curTime := time.Now() currentTime := float64(curTime.UTC().Unix()) if currentTime-certCacheValue.ts >= cacheExpire { return &ocspStatus{ code: ocspCacheExpired, err: fmt.Errorf("cache expired. 
current: %v, cache: %v", time.Unix(int64(currentTime), 0).UTC(), time.Unix(int64(certCacheValue.ts), 0).UTC()), } } ocspParsedRespCacheLock.Lock() defer ocspParsedRespCacheLock.Unlock() var cacheKey parsedOcspRespKey if certIDKey != nil { cacheKey = parsedOcspRespKey{certCacheValue.ocspRespBase64, encodeCertIDKey(certIDKey)} } else { cacheKey = parsedOcspRespKey{certCacheValue.ocspRespBase64, ""} } status, ok := ocspParsedRespCache[cacheKey] if !ok { logger.Tracef("OCSP status not found in cache; certIdKey: %v", certIDKey) var err error var b []byte b, err = base64.StdEncoding.DecodeString(certCacheValue.ocspRespBase64) if err != nil { return &ocspStatus{ code: ocspFailedDecodeResponse, err: fmt.Errorf("failed to decode OCSP Response value in a cache. subject: %v, err: %v", subjectName, err), } } // check the revocation status here ocspResponse, err := ocsp.ParseResponse(b, issuer) if err != nil { logger.Warnf("the second cache element is not a valid OCSP Response. Ignored. subject: %v\n", subjectName) return &ocspStatus{ code: ocspFailedParseResponse, err: fmt.Errorf("failed to parse OCSP Respose. subject: %v, err: %v", subjectName, err), } } status = validateOCSP(ocspResponse) ocspParsedRespCache[cacheKey] = status } logger.Tracef("OCSP status found in cache: %v; certIdKey: %v", status, certIDKey) return status } // writeOCSPCacheFile writes a OCSP Response cache file. This is called if all revocation status is success. // lock file is used to mitigate race condition with other process. func (ov *ocspValidator) writeOCSPCacheFile() { if !ocspCacheServerEnabled { return } logger.Infof("writing OCSP Response cache file. %v\n", cacheFileName) cacheLockFileName := cacheFileName + ".lck" err := os.Mkdir(cacheLockFileName, 0600) switch { case os.IsExist(err): statinfo, err := os.Stat(cacheLockFileName) if err != nil { logger.Debugf("failed to get file info for cache lock file. file: %v, err: %v. 
ignored.\n", cacheLockFileName, err) return } if time.Since(statinfo.ModTime()) < 15*time.Minute { logger.Debugf("other process locks the cache file. %v. ignored.\n", cacheLockFileName) return } if err = os.Remove(cacheLockFileName); err != nil { logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", cacheLockFileName, err) return } if err = os.Mkdir(cacheLockFileName, 0600); err != nil { logger.Debugf("failed to create lock file. file: %v, err: %v. ignored.\n", cacheLockFileName, err) return } } // if mkdir fails for any other reason: permission denied, operation not permitted, I/O error, too many open files, etc. if err != nil { logger.Debugf("failed to create lock file. file %v, err: %v. ignored.\n", cacheLockFileName, err) return } defer func() { if err = os.RemoveAll(cacheLockFileName); err != nil { logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", cacheLockFileName, err) } }() buf := make(map[string][]any) for k, v := range ocspResponseCache { cacheKeyInBase64 := encodeCertIDKey(&k) buf[cacheKeyInBase64] = []any{v.ts, v.ocspRespBase64} } j, err := json.Marshal(buf) if err != nil { logger.Debugf("failed to convert OCSP Response cache to JSON. ignored.") return } if err = os.WriteFile(cacheFileName, j, 0644); err != nil { logger.Debugf("failed to write OCSP Response cache. err: %v. ignored.\n", err) } } // createOCSPCacheDir creates OCSP response cache directory and set the cache file name. func createOCSPCacheDir() { if !ocspCacheServerEnabled { logger.Info(`OCSP Cache Server disabled. 
All further access and use of OCSP Cache will be disabled for this OCSP Status Query`) return } cacheDir = os.Getenv(cacheDirEnv) if cacheDir == "" { cacheDir = os.Getenv("SNOWFLAKE_TEST_WORKSPACE") } if cacheDir == "" { switch runtime.GOOS { case "windows": cacheDir = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local", "Snowflake", "Caches") case "darwin": home := os.Getenv("HOME") if home == "" { logger.Info("HOME is blank.") } cacheDir = filepath.Join(home, "Library", "Caches", "Snowflake") default: home := os.Getenv("HOME") if home == "" { logger.Info("HOME is blank") } cacheDir = filepath.Join(home, ".cache", "snowflake") } } if _, err := os.Stat(cacheDir); os.IsNotExist(err) { if err = os.MkdirAll(cacheDir, os.ModePerm); err != nil { logger.Debugf("failed to create cache directory. %v, err: %v. ignored\n", cacheDir, err) } } cacheFileName = filepath.Join(cacheDir, cacheFileBaseName) logger.Infof("reset OCSP cache file. %v", cacheFileName) } // StartOCSPCacheClearer starts the job that clears OCSP caches func StartOCSPCacheClearer() { ocspCacheClearer.start() } // StopOCSPCacheClearer stops the job that clears OCSP caches. 
func StopOCSPCacheClearer() { ocspCacheClearer.stop() } func clearOCSPCaches() { logger.Debugf("clearing OCSP caches") func() { ocspResponseCacheLock.Lock() defer ocspResponseCacheLock.Unlock() ocspResponseCache = make(map[certIDKey]*certCacheValue) }() func() { ocspParsedRespCacheLock.Lock() defer ocspParsedRespCacheLock.Unlock() ocspParsedRespCache = make(map[parsedOcspRespKey]*ocspStatus) }() } func initOcspModule() { createOCSPCacheDir() initOCSPCache() if cacheServerEnabledStr, ok := os.LookupEnv(cacheServerEnabledEnv); ok { logger.Debugf("OCSP Cache Server enabled by user: %v", cacheServerEnabledStr) ocspCacheServerEnabled = strings.EqualFold(cacheServerEnabledStr, "true") } ocspModuleInitialized = true } type ocspCacheClearerType struct { mu sync.Mutex running bool cancel context.CancelFunc } func (occ *ocspCacheClearerType) start() { occ.mu.Lock() defer occ.mu.Unlock() if occ.running { return } ctx, cancel := context.WithCancel(context.Background()) occ.cancel = cancel interval := defaultOCSPResponseCacheClearingInterval if intervalFromEnv := os.Getenv(ocspResponseCacheClearingIntervalInSecondsEnv); intervalFromEnv != "" { intervalAsSeconds, err := strconv.Atoi(intervalFromEnv) if err != nil { logger.Warnf("unparsable %v value: %v", ocspResponseCacheClearingIntervalInSecondsEnv, intervalFromEnv) } else { interval = time.Duration(intervalAsSeconds) * time.Second } } logger.Debugf("initializing OCSP cache clearer to %v", interval) go GoroutineWrapper(context.Background(), func() { ticker := time.NewTicker(interval) for { select { case <-ticker.C: clearOCSPCaches() case <-ctx.Done(): occ.mu.Lock() defer occ.mu.Unlock() logger.Debug("stopped clearing OCSP cache") ticker.Stop() occ.running = false return } } }) occ.running = true } func (occ *ocspCacheClearerType) stop() { occ.mu.Lock() defer occ.mu.Unlock() if occ.running { occ.cancel() } } ================================================ FILE: ocsp_test.go ================================================ 
package gosnowflake

import (
	"bytes"
	"context"
	"crypto"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"os"
	"testing"
	"time"

	"golang.org/x/crypto/ocsp"
)

// TestOCSP performs real HTTPS GETs against several endpoints with both the
// no-revocation and OCSP transports, with the cache server both enabled and
// disabled.
func TestOCSP(t *testing.T) {
	cacheServerEnabled := []string{
		"true",
		"false",
	}
	targetURL := []string{
		"https://sfctest0.snowflakecomputing.com/",
		"https://s3-us-west-2.amazonaws.com/sfc-snowsql-updates/?prefix=1.1/windows_x86_64",
		"https://sfcdev2.blob.core.windows.net/",
	}
	ocspTransport, err := newTransportFactory(&Config{}, nil).createOCSPTransport(defaultTransportConfigs.forTransportType(transportTypeSnowflake))
	assertNilF(t, err)
	transports := []http.RoundTripper{
		createTestNoRevocationTransport(),
		ocspTransport,
	}
	for _, enabled := range cacheServerEnabled {
		for _, tgt := range targetURL {
			_ = os.Setenv(cacheServerEnabledEnv, enabled)
			_ = os.Remove(cacheFileName) // clear cache file
			syncUpdateOcspResponseCache(func() {
				ocspResponseCache = make(map[certIDKey]*certCacheValue)
			})
			for _, tr := range transports {
				t.Run(fmt.Sprintf("%v_%v", tgt, enabled), func(t *testing.T) {
					c := &http.Client{
						Transport: tr,
						Timeout:   30 * time.Second,
					}
					req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil))
					if err != nil {
						t.Fatalf("fail to create a request. err: %v", err)
					}
					res, err := c.Do(req)
					if err != nil {
						t.Fatalf("failed to GET contents. err: %v", err)
					}
					defer res.Body.Close()
					_, err = io.ReadAll(res.Body)
					if err != nil {
						t.Fatalf("failed to read content body for %v", tgt)
					}
				})
			}
		}
	}
	_ = os.Unsetenv(cacheServerEnabledEnv)
}

// tcValidityRange is a test case for isInValidityRange.
type tcValidityRange struct {
	thisTime time.Time
	nextTime time.Time
	ret      bool
}

// TestUnitIsInValidityRange exercises clock-skew boundaries of
// isInValidityRange.
func TestUnitIsInValidityRange(t *testing.T) {
	currentTime := time.Now()
	testcases := []tcValidityRange{
		{
			// basic tests
			thisTime: currentTime.Add(-100 * time.Second),
			nextTime: currentTime.Add(maxClockSkew),
			ret:      true,
		},
		{
			// on the border
			thisTime: currentTime.Add(maxClockSkew),
			nextTime: currentTime.Add(maxClockSkew),
			ret:      true,
		},
		{
			// 1 earlier late
			thisTime: currentTime.Add(maxClockSkew + 1*time.Second),
			nextTime: currentTime.Add(maxClockSkew),
			ret:      false,
		},
		{
			// on the border
			thisTime: currentTime.Add(-maxClockSkew),
			nextTime: currentTime.Add(-maxClockSkew),
			ret:      true,
		},
		{
			// around the border
			thisTime: currentTime.Add(-24*time.Hour - 40*time.Second),
			nextTime: currentTime.Add(-24*time.Hour/time.Duration(100) - 40*time.Second),
			ret:      false,
		},
		{
			// on the border
			thisTime: currentTime.Add(-48*time.Hour - 29*time.Minute),
			nextTime: currentTime.Add(-48 * time.Hour / time.Duration(100)),
			ret:      true,
		},
	}
	for _, tc := range testcases {
		t.Run(fmt.Sprintf("%v_%v", tc.thisTime, tc.nextTime), func(t *testing.T) {
			if tc.ret != isInValidityRange(currentTime, tc.thisTime, tc.nextTime) {
				t.Fatalf("failed to check validity. should be: %v, currentTime: %v, thisTime: %v, nextTime: %v", tc.ret, currentTime, tc.thisTime, tc.nextTime)
			}
		})
	}
}

// TestUnitEncodeCertIDGood extracts cert IDs from real certificate chains,
// with the default and SHA-512 hashes, and verifies a corrupted request fails.
func TestUnitEncodeCertIDGood(t *testing.T) {
	targetURLs := []string{
		"faketestaccount.snowflakecomputing.com:443",
		"s3-us-west-2.amazonaws.com:443",
		"sfcdev2.blob.core.windows.net:443",
	}
	for _, tt := range targetURLs {
		t.Run(tt, func(t *testing.T) {
			chainedCerts := getCert(tt)
			for i := 0; i < len(chainedCerts)-1; i++ {
				subject := chainedCerts[i]
				issuer := chainedCerts[i+1]
				ocspServers := subject.OCSPServer
				if len(ocspServers) == 0 {
					t.Fatalf("no OCSP server is found. cert: %v", subject.Subject)
				}
				ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{})
				if err != nil {
					t.Fatalf("failed to create OCSP request. err: %v", err)
				}
				var ost *ocspStatus
				_, ost = extractCertIDKeyFromRequest(ocspReq)
				if ost.err != nil {
					t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err)
				}
				// better hash. Not sure if the actual OCSP server accepts this, though.
				ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512})
				if err != nil {
					t.Fatalf("failed to create OCSP request. err: %v", err)
				}
				_, ost = extractCertIDKeyFromRequest(ocspReq)
				if ost.err != nil {
					t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err)
				}
				// tweaked request binary
				ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512})
				if err != nil {
					t.Fatalf("failed to create OCSP request. err: %v", err)
				}
				ocspReq[10] = 0 // random change
				_, ost = extractCertIDKeyFromRequest(ocspReq)
				if ost.err == nil {
					t.Fatal("should have failed")
				}
			}
		})
	}
}

// TestUnitCheckOCSPResponseCache covers cache miss, expired/future timestamps,
// unparsable responses and invalid validity.
func TestUnitCheckOCSPResponseCache(t *testing.T) {
	ocspCacheServerEnabled = true
	ov := newOcspValidator(&Config{OCSPFailOpen: OCSPFailOpenTrue})
	dummyKey0 := certIDKey{
		HashAlgorithm: crypto.SHA1,
		NameHash:      "dummy0",
		IssuerKeyHash: "dummy0",
		SerialNumber:  "dummy0",
	}
	dummyKey := certIDKey{
		HashAlgorithm: crypto.SHA1,
		NameHash:      "dummy1",
		IssuerKeyHash: "dummy1",
		SerialNumber:  "dummy1",
	}
	b64Key := base64.StdEncoding.EncodeToString([]byte("DUMMY_VALUE"))
	currentTime := float64(time.Now().UTC().Unix())
	syncUpdateOcspResponseCache(func() {
		ocspResponseCache[dummyKey0] = &certCacheValue{currentTime, b64Key}
	})
	subject := &x509.Certificate{}
	issuer := &x509.Certificate{}
	ost := ov.checkOCSPResponseCache(&dummyKey, subject, issuer)
	if ost.code != ocspMissedCache {
		t.Fatalf("should have failed. expected: %v, got: %v", ocspMissedCache, ost.code)
	}
	// old timestamp
	syncUpdateOcspResponseCache(func() {
		ocspResponseCache[dummyKey] = &certCacheValue{float64(1395054952), b64Key}
	})
	ost = ov.checkOCSPResponseCache(&dummyKey, subject, issuer)
	if ost.code != ocspCacheExpired {
		t.Fatalf("should have failed. expected: %v, got: %v", ocspCacheExpired, ost.code)
	}
	// future timestamp
	// NOTE(review): the message below prints ocspFailedDecodeResponse while the
	// check compares against ocspFailedParseResponse — pre-existing mismatch.
	syncUpdateOcspResponseCache(func() {
		ocspResponseCache[dummyKey] = &certCacheValue{float64(1805054952), b64Key}
	})
	ost = ov.checkOCSPResponseCache(&dummyKey, subject, issuer)
	if ost.code != ocspFailedParseResponse {
		t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedDecodeResponse, ost.code)
	}
	// actual OCSP but it fails to parse, because an invalid issuer certificate is given.
	actualOcspResponse := "MIIB0woBAKCCAcwwggHIBgkrBgEFBQcwAQEEggG5MIIBtTCBnqIWBBSxPsNpA/i/RwHUmCYaCALvY2QrwxgPMjAxNz" + // pragma: allowlist secret
		"A1MTYyMjAwMDBaMHMwcTBJMAkGBSsOAwIaBQAEFN+qEuMosQlBk+KfQoLOR0BClVijBBSxPsNpA/i/RwHUmCYaCALvY2QrwwIQBOHnp" + // pragma: allowlist secret
		"Nxc8vNtwCtCuF0Vn4AAGA8yMDE3MDUxNjIyMDAwMFqgERgPMjAxNzA1MjMyMjAwMDBaMA0GCSqGSIb3DQEBCwUAA4IBAQCuRGwqQsKy" + // pragma: allowlist secret
		"IAAGHgezTfG0PzMYgGD/XRDhU+2i08WTJ4Zs40Lu88cBeRXWF3iiJSpiX3/OLgfI7iXmHX9/sm2SmeNWc0Kb39bk5Lw1jwezf8hcI9+" + // pragma: allowlist secret
		"mZHt60vhUgtgZk21SsRlTZ+S4VXwtDqB1Nhv6cnSnfrL2A9qJDZS2ltPNOwebWJnznDAs2dg+KxmT2yBXpHM1kb0EOolWvNgORbgIgB" + // pragma: allowlist secret
		"koRzw/UU7zKsqiTB0ZN/rgJp+MocTdqQSGKvbZyR8d4u8eNQqi1x4Pk3yO/pftANFaJKGB+JPgKS3PQAqJaXcipNcEfqtl7y4PO6kqA" + // pragma: allowlist secret
		"Jb4xI/OTXIrRA5TsT4cCioE"
	// issuer is not a true issuer certificate
	syncUpdateOcspResponseCache(func() {
		ocspResponseCache[dummyKey] = &certCacheValue{float64(currentTime - 1000), actualOcspResponse}
	})
	ost = ov.checkOCSPResponseCache(&dummyKey, subject, issuer)
	if ost.code != ocspFailedParseResponse {
		t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedParseResponse, ost.code)
	}
	// invalid validity
	syncUpdateOcspResponseCache(func() {
		ocspResponseCache[dummyKey] = &certCacheValue{float64(currentTime - 1000), actualOcspResponse}
	})
	ost = ov.checkOCSPResponseCache(&dummyKey, subject, nil)
	if ost.code != ocspInvalidValidity {
		t.Fatalf("should have failed. expected: %v, got: %v", ocspInvalidValidity, ost.code)
	}
}

// TestOcspCacheClearer verifies the periodic clearer empties both caches.
func TestOcspCacheClearer(t *testing.T) {
	initOCSPCache()
	origValue := os.Getenv(ocspResponseCacheClearingIntervalInSecondsEnv)
	defer func() {
		StopOCSPCacheClearer()
		os.Setenv(ocspResponseCacheClearingIntervalInSecondsEnv, origValue)
		initOCSPCache()
		StartOCSPCacheClearer()
	}()
	syncUpdateOcspResponseCache(func() {
		ocspResponseCache[certIDKey{}] = nil
	})
	func() {
		ocspParsedRespCacheLock.Lock()
		defer ocspParsedRespCacheLock.Unlock()
		ocspParsedRespCache[parsedOcspRespKey{}] = nil
	}()
	StopOCSPCacheClearer()
	os.Setenv(ocspResponseCacheClearingIntervalInSecondsEnv, "1")
	StartOCSPCacheClearer()
	time.Sleep(2 * time.Second)
	syncUpdateOcspResponseCache(func() {
		assertEqualE(t, len(ocspResponseCache), 0)
	})
	func() {
		ocspParsedRespCacheLock.Lock()
		defer ocspParsedRespCacheLock.Unlock()
		assertEqualE(t, len(ocspParsedRespCache), 0)
	}()
}

// TestUnitValidateOCSP checks status mapping for every ocsp.Response status.
func TestUnitValidateOCSP(t *testing.T) {
	ocspRes := &ocsp.Response{
		ThisUpdate: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
		NextUpdate: time.Date(2020, 1, 5, 0, 0, 0, 0, time.UTC),
	}
	ost := validateOCSP(ocspRes)
	if ost.code != ocspInvalidValidity {
		t.Fatalf("should have failed. expected: %v, got: %v", ocspInvalidValidity, ost.code)
	}
	currentTime := time.Now()
	ocspRes.ThisUpdate = currentTime.Add(-2 * time.Hour)
	ocspRes.NextUpdate = currentTime.Add(2 * time.Hour)
	ocspRes.Status = ocsp.Revoked
	ost = validateOCSP(ocspRes)
	if ost.code != ocspStatusRevoked {
		t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusRevoked, ost.code)
	}
	ocspRes.Status = ocsp.Good
	ost = validateOCSP(ocspRes)
	if ost.code != ocspStatusGood {
		t.Fatalf("should have success. expected: %v, got: %v", ocspStatusGood, ost.code)
	}
	ocspRes.Status = ocsp.Unknown
	ost = validateOCSP(ocspRes)
	if ost.code != ocspStatusUnknown {
		t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusUnknown, ost.code)
	}
	ocspRes.Status = ocsp.ServerFailed
	ost = validateOCSP(ocspRes)
	if ost.code != ocspStatusOthers {
		t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusOthers, ost.code)
	}
}

// TestUnitEncodeCertID verifies garbage request bytes are rejected.
func TestUnitEncodeCertID(t *testing.T) {
	var st *ocspStatus
	_, st = extractCertIDKeyFromRequest([]byte{0x1, 0x2})
	if st.code != ocspFailedDecomposeRequest {
		t.Fatalf("failed to get OCSP status. expected: %v, got: %v", ocspFailedDecomposeRequest, st.code)
	}
}

// getCert dials addr over TLS (verification disabled) and returns the peer
// certificate chain. Panics on any network failure.
func getCert(addr string) []*x509.Certificate {
	tcpConn, err := net.DialTimeout("tcp", addr, 40*time.Second)
	if err != nil {
		panic(err)
	}
	defer tcpConn.Close()
	err = tcpConn.SetDeadline(time.Now().Add(10 * time.Second))
	if err != nil {
		panic(err)
	}
	config := tls.Config{InsecureSkipVerify: true, ServerName: addr}
	conn := tls.Client(tcpConn, &config)
	defer conn.Close()
	err = conn.Handshake()
	if err != nil {
		panic(err)
	}
	state := conn.ConnectionState()
	return state.PeerCertificates
}

// TestOCSPRetry drives retryOCSP against a fake HTTP client that never
// produces a valid OCSP response.
func TestOCSPRetry(t *testing.T) {
	ov := newOcspValidator(&Config{OCSPFailOpen: OCSPFailOpenTrue})
	certs := getCert("s3-us-west-2.amazonaws.com:443")
	dummyOCSPHost := &url.URL{
		Scheme: "https",
		Host:   "dummyOCSPHost",
	}
	client := &fakeHTTPClient{
		cnt:     3,
		success: true,
		body:    []byte{1, 2, 3},
		t:       t,
	}
	res, b, st := ov.retryOCSP(
		context.Background(), client, emptyRequest, dummyOCSPHost,
		make(map[string]string), []byte{0}, certs[len(certs)-1], 10*time.Second)
	if st.err == nil {
		fmt.Printf("should fail: %v, %v, %v\n", res, b, st)
	}
	client = &fakeHTTPClient{
		cnt:     30,
		success: true,
		body:    []byte{1, 2, 3},
		t:       t,
	}
	res, b, st = ov.retryOCSP(
		context.Background(), client, fakeRequestFunc, dummyOCSPHost,
		make(map[string]string), []byte{0}, certs[len(certs)-1], 5*time.Second)
	if st.err == nil {
		fmt.Printf("should fail: %v, %v, %v\n", res, b, st)
	}
}

// TestFullOCSPURL checks host/path joining in fullOCSPURL.
func TestFullOCSPURL(t *testing.T) {
	testcases := []tcFullOCSPURL{
		{
			url:               &url.URL{Host: "some-ocsp-url.com"},
			expectedURLString: "some-ocsp-url.com",
		},
		{
			url: &url.URL{
				Host: "some-ocsp-url.com",
				Path: "/some-path",
			},
			expectedURLString: "some-ocsp-url.com/some-path",
		},
		{
			url: &url.URL{
				Host: "some-ocsp-url.com",
				Path: "some-path",
			},
			expectedURLString: "some-ocsp-url.com/some-path",
		},
	}
	for _, testcase := range testcases {
		t.Run("", func(t *testing.T) {
			returnedStringURL := fullOCSPURL(testcase.url)
			if returnedStringURL != testcase.expectedURLString {
				t.Fatalf("failed to match returned OCSP url string; expected: %v, got: %v",
					testcase.expectedURLString, returnedStringURL)
			}
		})
	}
}

// tcFullOCSPURL is a test case for fullOCSPURL.
type tcFullOCSPURL struct {
	url               *url.URL
	expectedURLString string
}

// TestOCSPCacheServerRetry drives checkOCSPCacheServer against a fake client
// that never returns a usable cache payload.
func TestOCSPCacheServerRetry(t *testing.T) {
	dummyOCSPHost := &url.URL{
		Scheme: "https",
		Host:   "dummyOCSPHost",
	}
	client := &fakeHTTPClient{
		cnt:     3,
		success: true,
		body:    []byte{1, 2, 3},
		t:       t,
	}
	res, st := checkOCSPCacheServer(
		context.Background(), client, fakeRequestFunc, dummyOCSPHost, 20*time.Second)
	if st.err == nil {
		t.Errorf("should fail: %v", res)
	}
	client = &fakeHTTPClient{
		cnt:     30,
		success: true,
		body:    []byte{1, 2, 3},
		t:       t,
	}
	res, st = checkOCSPCacheServer(
		context.Background(), client, fakeRequestFunc, dummyOCSPHost, 10*time.Second)
	if st.err == nil {
		t.Errorf("should fail: %v", res)
	}
}

// tcCanEarlyExit is a test case for canEarlyExitForOCSP: expected early-exit
// result under fail-open and fail-closed modes.
type tcCanEarlyExit struct {
	results       []*ocspStatus
	resultLen     int
	retFailOpen   *ocspStatus
	retFailClosed *ocspStatus
}

// TestCanEarlyExitForOCSP covers revoked/unknown/invalid mixes under both
// fail-open and fail-closed modes.
func TestCanEarlyExitForOCSP(t *testing.T) {
	testcases := []tcCanEarlyExit{
		{ // 0
			results: []*ocspStatus{
				{
					code: ocspStatusGood,
				},
				{
					code: ocspStatusGood,
				},
				{
					code: ocspStatusGood,
				},
			},
			retFailOpen:   nil,
			retFailClosed: nil,
		},
		{ // 1
			results: []*ocspStatus{
				{
					code: ocspStatusRevoked,
					err:  errors.New("revoked"),
				},
				{
					code: ocspStatusGood,
				},
				{
					code: ocspStatusGood,
				},
			},
			retFailOpen:   &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
			retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
		},
		{ // 2
			results: []*ocspStatus{
				{
					code: ocspStatusUnknown,
					err:  errors.New("unknown"),
				},
				{
					code: ocspStatusGood,
				},
				{
					code: ocspStatusGood,
				},
			},
			retFailOpen:   nil,
			retFailClosed: &ocspStatus{ocspStatusUnknown, errors.New("unknown")},
		},
		{ // 3: not taken as revoked if any invalid OCSP response (ocspInvalidValidity) is included.
			results: []*ocspStatus{
				{
					code: ocspStatusRevoked,
					err:  errors.New("revoked"),
				},
				{
					code: ocspInvalidValidity,
				},
				{
					code: ocspStatusGood,
				},
			},
			retFailOpen:   nil,
			retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
		},
		{ // 4: not taken as revoked if the number of results don't match the expected results.
			results: []*ocspStatus{
				{
					code: ocspStatusRevoked,
					err:  errors.New("revoked"),
				},
				{
					code: ocspStatusGood,
				},
			},
			resultLen:     3,
			retFailOpen:   nil,
			retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
		},
	}
	for idx, tt := range testcases {
		t.Run("", func(t *testing.T) {
			ovOpen := newOcspValidator(&Config{OCSPFailOpen: OCSPFailOpenTrue})
			expectedLen := len(tt.results)
			if tt.resultLen > 0 {
				expectedLen = tt.resultLen
			}
			expectedLen++ // add one because normally there is a root certificate that is not included in the results.
			mockVerifiedChain := make([]*x509.Certificate, expectedLen)
			r := ovOpen.canEarlyExitForOCSP(tt.results, mockVerifiedChain)
			if !(tt.retFailOpen == nil && r == nil) && !(tt.retFailOpen != nil && r != nil && tt.retFailOpen.code == r.code) {
				t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailOpen, r)
			}
			ovClosed := newOcspValidator(&Config{OCSPFailOpen: OCSPFailOpenFalse})
			r = ovClosed.canEarlyExitForOCSP(tt.results, mockVerifiedChain)
			if !(tt.retFailClosed == nil && r == nil) && !(tt.retFailClosed != nil && r != nil && tt.retFailClosed.code == r.code) {
				t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailClosed, r)
			}
		})
	}
}

// TestInitOCSPCacheFileCreation verifies initOCSPCache creates the cache file,
// preserving any pre-existing cache via a temp copy.
func TestInitOCSPCacheFileCreation(t *testing.T) {
	if runningOnGithubAction() {
		t.Skip("cannot write to github file system")
	}
	dirName, err := os.UserHomeDir()
	if err != nil {
		t.Error(err)
	}
	srcFileName := dirName + "/.cache/snowflake/ocsp_response_cache.json"
	tmpFileName := srcFileName + "_tmp"
	dst, err := os.Create(tmpFileName)
	if err != nil {
		t.Error(err)
	}
	defer dst.Close()

	var src *os.File
	if _, err = os.Stat(srcFileName); errors.Is(err, os.ErrNotExist) {
		// file does not exist
		if err = os.MkdirAll(dirName+"/.cache/snowflake/", os.ModePerm); err != nil {
			t.Error(err)
		}
		if _, err = os.Create(srcFileName); err != nil {
			t.Error(err)
		}
	} else if err != nil {
		t.Error(err)
	} else {
		// file exists
		src, err = os.Open(srcFileName)
		if err != nil {
			t.Error(err)
		}
		defer src.Close()
		// copy original contents to temporary file
		if _, err = io.Copy(dst, src); err != nil {
			t.Error(err)
		}
		if err = os.Remove(srcFileName); err != nil {
			t.Error(err)
		}
	}

	// cleanup
	defer func() {
		src, _ = os.Open(tmpFileName)
		defer src.Close()
		dst, _ = os.OpenFile(srcFileName, os.O_WRONLY, readWriteFileMode)
		defer dst.Close()
		// copy temporary file contents back to original file
		if _, err = io.Copy(dst, src); err != nil {
			t.Fatal(err)
		}
		if err = os.Remove(tmpFileName); err != nil {
			t.Error(err)
		}
	}()

	initOCSPCache()
	if _, err = os.Stat(srcFileName); errors.Is(err, os.ErrNotExist) {
		t.Error(err)
	} else if err != nil {
		t.Error(err)
	}
}

// syncUpdateOcspResponseCache runs f while holding the response cache lock.
func syncUpdateOcspResponseCache(f func()) {
	ocspResponseCacheLock.Lock()
	defer ocspResponseCacheLock.Unlock()
	f()
}

================================================ FILE: old_driver_test.go ================================================

package gosnowflake

import (
	"bytes"
	"reflect"
	"testing"
)

const (
	forceARROW = "ALTER SESSION SET GO_QUERY_RESULT_FORMAT = ARROW"
	forceJSON  = "ALTER SESSION SET GO_QUERY_RESULT_FORMAT = JSON"
)

func TestJSONInt(t *testing.T) {
	testInt(t, true)
}

func TestJSONFloat32(t *testing.T) {
	testFloat32(t, true)
}

func TestJSONFloat64(t *testing.T) {
	testFloat64(t, true)
}

// TestJSONVariousTypes scans a row of mixed types in JSON result format and
// checks values, column names, scan types, precision/scale, length and
// nullability metadata.
func TestJSONVariousTypes(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec(forceJSON)
		rows := dbt.mustQuery(selectVariousTypes)
		defer rows.Close()
		if !rows.Next() {
			dbt.Error("failed to query")
		}
		cc, err := rows.Columns()
		if err != nil {
			dbt.Errorf("columns: %v", cc)
		}
		ct, err := rows.ColumnTypes()
		if err != nil {
			dbt.Errorf("column types: %v", ct)
		}
		var v1 float32
		var v2, v2a int
		var v3 string
		var v4 float64
		var v5 []byte
		var v6 bool
		err = rows.Scan(&v1, &v2, &v2a, &v3, &v4, &v5, &v6)
		if err != nil {
			dbt.Errorf("failed to scan: %#v", err)
		}
		if v1 != 1.0 {
			dbt.Errorf("failed to scan. %#v", v1)
		}
		if ct[0].Name() != "C1" || ct[1].Name() != "C2" || ct[2].Name() != "C2A" || ct[3].Name() != "C3" || ct[4].Name() != "C4" || ct[5].Name() != "C5" || ct[6].Name() != "C6" {
			dbt.Errorf("failed to get column names: %#v", ct)
		}
		if ct[0].ScanType() != reflect.TypeFor[float64]() {
			dbt.Errorf("failed to get scan type. expected: %v, got: %v", reflect.TypeFor[float64](), ct[0].ScanType())
		}
		if ct[1].ScanType() != reflect.TypeFor[int64]() {
			dbt.Errorf("failed to get scan type. expected: %v, got: %v", reflect.TypeFor[int64](), ct[1].ScanType())
		}
		assertEqualE(t, ct[2].ScanType(), reflect.TypeFor[string]())
		var pr, sc int64
		var cLen int64
		var canNull bool
		pr, sc = dbt.mustDecimalSize(ct[0])
		if pr != 30 || sc != 2 {
			dbt.Errorf("failed to get precision and scale. %#v", ct[0])
		}
		dbt.mustFailLength(ct[0])
		canNull = dbt.mustNullable(ct[0])
		if canNull {
			dbt.Errorf("failed to get nullable. %#v", ct[0])
		}
		if cLen != 0 {
			dbt.Errorf("failed to get length. %#v", ct[0])
		}
		if v2 != 2 {
			dbt.Errorf("failed to scan. %#v", v2)
		}
		pr, sc = dbt.mustDecimalSize(ct[1])
		if pr != 18 || sc != 0 {
			dbt.Errorf("failed to get precision and scale. %#v", ct[1])
		}
		dbt.mustFailLength(ct[1])
		canNull = dbt.mustNullable(ct[1])
		if canNull {
			dbt.Errorf("failed to get nullable. %#v", ct[1])
		}
		if v2a != 22 {
			dbt.Errorf("failed to scan. %#v", v2a)
		}
		pr, sc = dbt.mustDecimalSize(ct[2])
		if pr != 38 || sc != 0 {
			dbt.Errorf("failed to get precision and scale. %#v", ct[2])
		}
		if v3 != "t3" {
			dbt.Errorf("failed to scan. %#v", v3)
		}
		dbt.mustFailDecimalSize(ct[3])
		cLen = dbt.mustLength(ct[3])
		if cLen != 2 {
			dbt.Errorf("failed to get length. %#v", ct[3])
		}
		canNull = dbt.mustNullable(ct[3])
		if canNull {
			dbt.Errorf("failed to get nullable. %#v", ct[3])
		}
		if v4 != 4.2 {
			dbt.Errorf("failed to scan. %#v", v4)
		}
		dbt.mustFailDecimalSize(ct[4])
		dbt.mustFailLength(ct[4])
		canNull = dbt.mustNullable(ct[4])
		if canNull {
			dbt.Errorf("failed to get nullable. %#v", ct[4])
		}
		if !bytes.Equal(v5, []byte{0xab, 0xcd}) {
			dbt.Errorf("failed to scan. %#v", v5)
		}
		dbt.mustFailDecimalSize(ct[5])
		cLen = dbt.mustLength(ct[5]) // BINARY
		if cLen != 8388608 {
			dbt.Errorf("failed to get length. %#v", ct[5])
		}
		canNull = dbt.mustNullable(ct[5])
		if canNull {
			dbt.Errorf("failed to get nullable. %#v", ct[5])
		}
		if !v6 {
			dbt.Errorf("failed to scan. %#v", v6)
		}
		dbt.mustFailDecimalSize(ct[6])
		dbt.mustFailLength(ct[6])
	})
}

func TestJSONString(t *testing.T) {
	testString(t, true)
}

func TestJSONSimpleDateTimeTimestampFetch(t *testing.T) {
	testSimpleDateTimeTimestampFetch(t, true)
}

func TestJSONDateTime(t *testing.T) {
	testDateTime(t, true)
}

func TestJSONTimestampLTZ(t *testing.T) {
	testTimestampLTZ(t, true)
}

func TestJSONTimestampTZ(t *testing.T) {
	testTimestampTZ(t, true)
}

func TestJSONNULL(t *testing.T) {
	testNULL(t, true)
}

func TestJSONVariant(t *testing.T) {
	testVariant(t, true)
}

func TestJSONArray(t *testing.T) {
	testArray(t, true)
}

// TestLargeSetJSONResultWithDecoder and TestLargeSetResultWithCustomJSONDecoder
// validate JSON result decoding with row counts large enough to trigger chunked
// result delivery from Snowflake. The row counts (10,000 and 20,000) are
// calibrated to exercise the chunk download pipeline while staying within CI
// timeout limits.
func TestLargeSetJSONResultWithDecoder(t *testing.T) { testLargeSetResult(t, 10000, true) } // TestLargeSetResultWithCustomJSONDecoder validates chunked JSON decoding using // the custom decoder. Same row count constraints as TestLargeSetJSONResultWithDecoder // apply here — the count must be large enough to trigger chunked delivery. func TestLargeSetResultWithCustomJSONDecoder(t *testing.T) { customJSONDecoderEnabled = true // less number of rows to avoid CI timeout testLargeSetResult(t, 20000, true) } func TestBindingJSONInterface(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.mustExec(forceJSON) rows := dbt.mustQuery(selectVariousTypes) defer rows.Close() if !rows.Next() { dbt.Error("failed to query") } var v1, v2, v2a, v3, v4, v5, v6 any if err := rows.Scan(&v1, &v2, &v2a, &v3, &v4, &v5, &v6); err != nil { dbt.Errorf("failed to scan: %#v", err) } if s, ok := v1.(string); !ok || s != "1.00" { dbt.Fatalf("failed to fetch. ok: %v, value: %v", ok, v1) } if s, ok := v2.(string); !ok || s != "2" { dbt.Fatalf("failed to fetch. ok: %v, value: %v", ok, v2) } if s, ok := v3.(string); !ok || s != "t3" { dbt.Fatalf("failed to fetch. ok: %v, value: %v", ok, v3) } if s, ok := v4.(string); !ok || s != "4.2" { dbt.Fatalf("failed to fetch. 
ok: %v, value: %v", ok, v4) } }) } ================================================ FILE: os_specific_posix.go ================================================ //go:build !windows package gosnowflake import ( "fmt" "golang.org/x/sys/unix" "io" "os" "syscall" ) var osVersion = getOSVersion() func getOSVersion() string { var uts unix.Utsname if err := unix.Uname(&uts); err != nil { panic(err) } sysname := unix.ByteSliceToString(uts.Sysname[:]) release := unix.ByteSliceToString(uts.Release[:]) return sysname + "-" + release } func provideFileOwner(file *os.File) (uint32, error) { info, err := file.Stat() if err != nil { return 0, err } return provideOwnerFromStat(info, file.Name()) } func provideOwnerFromStat(info os.FileInfo, filepath string) (uint32, error) { nativeStat, ok := info.Sys().(*syscall.Stat_t) if !ok { return 0, fmt.Errorf("cannot cast file info for %v to *syscall.Stat_t", filepath) } return nativeStat.Uid, nil } func getFileContents(filePath string, expectedPerm os.FileMode) ([]byte, error) { // open the file with read only and no symlink flags file, err := os.OpenFile(filePath, syscall.O_RDONLY|syscall.O_NOFOLLOW, 0) if err != nil { return nil, err } defer func() { if err = file.Close(); err != nil { logger.Warnf("failed to close the file: %v", err) } }() // validate file permissions and owner if err = validateFilePermissionBits(file, expectedPerm); err != nil { return nil, err } if err = ensureFileOwner(file); err != nil { return nil, err } // read the file fileContents, err := io.ReadAll(file) if err != nil { return nil, err } return fileContents, nil } func validateFilePermissionBits(f *os.File, expectedPerm os.FileMode) error { fileInfo, err := f.Stat() if err != nil { return err } filePerm := fileInfo.Mode() if filePerm&expectedPerm != 0 { return fmt.Errorf("incorrect permissions of %s", f.Name()) } return nil } ================================================ FILE: os_specific_windows.go ================================================ package 
gosnowflake import ( "errors" "fmt" "os" "golang.org/x/sys/windows/registry" ) var osVersion = getWindowsOSVersion() func getWindowsOSVersion() string { k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE) if err != nil { errString := fmt.Sprintf("cannot open Windows registry key: %v", err) logger.Debugf(errString) return errString } defer k.Close() cv, _, err := k.GetStringValue("CurrentVersion") if err != nil { logger.Debugf("cannot find Windows current version: %v", err) cv = "CurrentVersion=unknown" } pn, _, err := k.GetStringValue("ProductName") if err != nil { logger.Debugf("cannot find Windows product name: %v", err) pn = "ProductName=unknown" } maj, _, err := k.GetIntegerValue("CurrentMajorVersionNumber") if err != nil { logger.Debugf("cannot find Windows major version number: %v", err) } min, _, err := k.GetIntegerValue("CurrentMinorVersionNumber") if err != nil { logger.Debugf("cannot find Windows minor version number: %v", err) } cb, _, err := k.GetStringValue("CurrentBuild") if err != nil { logger.Debugf("cannot find Windows current build: %v", err) cb = "CurrentBuild=unknown" } return fmt.Sprintf("CurrentVersion=%s; ProductName=%s; MajorVersion=%d; MinorVersion=%d; CurrentBuild=%s", cv, pn, maj, min, cb) } func provideFileOwner(file *os.File) (uint32, error) { return 0, errors.New("provideFileOwner is unsupported on windows") } func getFileContents(filePath string, expectedPerm os.FileMode) ([]byte, error) { fileContents, err := os.ReadFile(filePath) if err != nil { return nil, err } return fileContents, nil } ================================================ FILE: parameters.json.local ================================================ { "testconnection": { "SNOWFLAKE_TEST_HOST": "snowflake.reg.local", "SNOWFLAKE_TEST_PROTOCOL": "http", "SNOWFLAKE_TEST_PORT": "8082", "SNOWFLAKE_TEST_USER": "snowman", "SNOWFLAKE_TEST_PASSWORD": "test", "SNOWFLAKE_TEST_ACCOUNT": "s3testaccount", 
"SNOWFLAKE_TEST_WAREHOUSE": "regress", "SNOWFLAKE_TEST_DATABASE": "testdb", "SNOWFLAKE_TEST_SCHEMA": "testschema", "SNOWFLAKE_TEST_ROLE": "sysadmin", "SNOWFLAKE_TEST_DEBUG": "false" } } ================================================ FILE: parameters.json.tmpl ================================================ { "testconnection": { "SNOWFLAKE_TEST_USER": "testuser", "SNOWFLAKE_TEST_PASSWORD": "testpass", "SNOWFLAKE_TEST_ACCOUNT": "testaccount", "SNOWFLAKE_TEST_WAREHOUSE": "testwarehouse", "SNOWFLAKE_TEST_DATABASE": "testdatabase", "SNOWFLAKE_TEST_SCHEMA": "testschema", "SNOWFLAKE_TEST_ROLE": "testrole", "SNOWFLAKE_TEST_DEBUG": "false", } } ================================================ FILE: permissions_test.go ================================================ //go:build !windows package gosnowflake import ( "fmt" "os" "path" "testing" "golang.org/x/sys/unix" ) func TestConfigPermissions(t *testing.T) { testCases := []struct { filePerm int isValid bool }{ {filePerm: 0700, isValid: true}, {filePerm: 0600, isValid: true}, {filePerm: 0500, isValid: true}, {filePerm: 0400, isValid: true}, {filePerm: 0707, isValid: false}, {filePerm: 0706, isValid: false}, {filePerm: 0705, isValid: true}, {filePerm: 0704, isValid: true}, {filePerm: 0703, isValid: false}, {filePerm: 0702, isValid: false}, {filePerm: 0701, isValid: true}, {filePerm: 0770, isValid: false}, {filePerm: 0760, isValid: false}, {filePerm: 0750, isValid: true}, {filePerm: 0740, isValid: true}, {filePerm: 0730, isValid: false}, {filePerm: 0720, isValid: false}, {filePerm: 0710, isValid: true}, } oldMask := unix.Umask(0000) defer unix.Umask(oldMask) for _, tc := range testCases { t.Run(fmt.Sprintf("0%o", tc.filePerm), func(t *testing.T) { tempFile := path.Join(t.TempDir(), fmt.Sprintf("filePerm_%o", tc.filePerm)) err := os.WriteFile(tempFile, nil, os.FileMode(tc.filePerm)) assertNilE(t, err) defer os.Remove(tempFile) f, err := os.Open(tempFile) assertNilE(t, err) defer f.Close() expectedPerm := os.FileMode(1<<4 | 
1<<1) err = validateFilePermissionBits(f, expectedPerm) if err != nil && tc.isValid { t.Error(err) } }) } } func TestLogDirectoryPermissions(t *testing.T) { testCases := []struct { dirPerm int limitedToUser bool }{ {dirPerm: 0700, limitedToUser: true}, {dirPerm: 0600, limitedToUser: false}, {dirPerm: 0500, limitedToUser: false}, {dirPerm: 0400, limitedToUser: false}, {dirPerm: 0300, limitedToUser: false}, {dirPerm: 0200, limitedToUser: false}, {dirPerm: 0100, limitedToUser: false}, {dirPerm: 0707, limitedToUser: false}, {dirPerm: 0706, limitedToUser: false}, {dirPerm: 0705, limitedToUser: false}, {dirPerm: 0704, limitedToUser: false}, {dirPerm: 0703, limitedToUser: false}, {dirPerm: 0702, limitedToUser: false}, {dirPerm: 0701, limitedToUser: false}, {dirPerm: 0770, limitedToUser: false}, {dirPerm: 0760, limitedToUser: false}, {dirPerm: 0750, limitedToUser: false}, {dirPerm: 0740, limitedToUser: false}, {dirPerm: 0730, limitedToUser: false}, {dirPerm: 0720, limitedToUser: false}, {dirPerm: 0710, limitedToUser: false}, } oldMask := unix.Umask(0000) defer unix.Umask(oldMask) for _, tc := range testCases { t.Run(fmt.Sprintf("0%o", tc.dirPerm), func(t *testing.T) { tempDir := path.Join(t.TempDir(), fmt.Sprintf("filePerm_%o", tc.dirPerm)) err := os.Mkdir(tempDir, os.FileMode(tc.dirPerm)) assertNilE(t, err) defer os.Remove(tempDir) result, _, err := isDirAccessCorrect(tempDir) if err != nil && tc.limitedToUser { t.Error(err) } assertEqualE(t, result, tc.limitedToUser) }) } } ================================================ FILE: platform_detection.go ================================================ package gosnowflake import ( "context" "errors" "io" "net/http" "net/url" "os" "regexp" "strings" "sync" "time" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/smithy-go/logging" ) type platformDetectionState string const ( platformDetected platformDetectionState = "detected" 
platformNotDetected platformDetectionState = "not_detected" platformDetectionTimeout platformDetectionState = "timeout" ) const disablePlatformDetectionEnv = "SNOWFLAKE_DISABLE_PLATFORM_DETECTION" var ( azureMetadataBaseURL = "http://169.254.169.254" gceMetadataRootURL = "http://metadata.google.internal" gcpMetadataBaseURL = "http://metadata.google.internal/computeMetadata/v1" ) var ( detectedPlatformsCache []string initPlatformDetectionOnce sync.Once platformDetectionDone = make(chan struct{}) ) func initPlatformDetection() { initPlatformDetectionOnce.Do(func() { go func() { detectedPlatformsCache = detectPlatforms(context.Background(), 200*time.Millisecond) defer close(platformDetectionDone) }() }) } func getDetectedPlatforms() []string { logger.Debugf("getDetectedPlatforms: waiting for platform detection to complete") <-platformDetectionDone logger.Debugf("getDetectedPlatforms: returning cached detected platforms: %v", detectedPlatformsCache) return detectedPlatformsCache } func metadataServerHTTPClient(timeout time.Duration) *http.Client { return &http.Client{ Timeout: timeout, Transport: &http.Transport{ Proxy: nil, DisableKeepAlives: true, }, } } type detectorFunc struct { name string fn func(ctx context.Context, timeout time.Duration) platformDetectionState } func detectPlatforms(ctx context.Context, timeout time.Duration) []string { if strings.EqualFold(os.Getenv(disablePlatformDetectionEnv), "true") { return []string{"disabled"} } detectors := []detectorFunc{ {name: "is_aws_lambda", fn: detectAwsLambdaEnv}, {name: "is_azure_function", fn: detectAzureFunctionEnv}, {name: "is_gce_cloud_run_service", fn: detectGceCloudRunServiceEnv}, {name: "is_gce_cloud_run_job", fn: detectGceCloudRunJobEnv}, {name: "is_github_action", fn: detectGithubActionsEnv}, {name: "is_ec2_instance", fn: detectEc2Instance}, {name: "has_aws_identity", fn: detectAwsIdentity}, {name: "is_azure_vm", fn: detectAzureVM}, {name: "has_azure_managed_identity", fn: detectAzureManagedIdentity}, 
{name: "is_gce_vm", fn: detectGceVM}, {name: "has_gcp_identity", fn: detectGcpIdentity}, } detectionStates := make(map[string]platformDetectionState, len(detectors)) var waitGroup sync.WaitGroup var mutex sync.Mutex waitGroup.Add(len(detectors)) for _, detector := range detectors { go func(detector detectorFunc) { defer waitGroup.Done() detectionState := detector.fn(ctx, timeout) mutex.Lock() detectionStates[detector.name] = detectionState mutex.Unlock() }(detector) } waitGroup.Wait() detectedPlatformNames := []string{} for _, detector := range detectors { if detectionStates[detector.name] == platformDetected { detectedPlatformNames = append(detectedPlatformNames, detector.name) } } logger.Debugf("detectPlatforms: completed. Detection states: %v", detectionStates) return detectedPlatformNames } func detectAwsLambdaEnv(_ context.Context, _ time.Duration) platformDetectionState { if os.Getenv("LAMBDA_TASK_ROOT") != "" { return platformDetected } return platformNotDetected } func detectGithubActionsEnv(_ context.Context, _ time.Duration) platformDetectionState { if os.Getenv("GITHUB_ACTIONS") != "" { return platformDetected } return platformNotDetected } func detectAzureFunctionEnv(_ context.Context, _ time.Duration) platformDetectionState { if os.Getenv("FUNCTIONS_WORKER_RUNTIME") != "" && os.Getenv("FUNCTIONS_EXTENSION_VERSION") != "" && os.Getenv("AzureWebJobsStorage") != "" { return platformDetected } return platformNotDetected } func detectGceCloudRunServiceEnv(_ context.Context, _ time.Duration) platformDetectionState { if os.Getenv("K_SERVICE") != "" && os.Getenv("K_REVISION") != "" && os.Getenv("K_CONFIGURATION") != "" { return platformDetected } return platformNotDetected } func detectGceCloudRunJobEnv(_ context.Context, _ time.Duration) platformDetectionState { if os.Getenv("CLOUD_RUN_JOB") != "" && os.Getenv("CLOUD_RUN_EXECUTION") != "" { return platformDetected } return platformNotDetected } func detectEc2Instance(ctx context.Context, timeout 
time.Duration) platformDetectionState {
	timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	cfg, err := config.LoadDefaultConfig(timeoutCtx, config.WithLogger(logging.NewStandardLogger(io.Discard)))
	if err != nil {
		// Report a timeout distinctly from a plain failure, consistent with
		// detectAwsIdentity below (this path previously omitted the
		// DeadlineExceeded check and always returned "not detected").
		if errors.Is(err, context.DeadlineExceeded) {
			return platformDetectionTimeout
		}
		return platformNotDetected
	}
	client := imds.NewFromConfig(cfg)
	result, err := client.GetInstanceIdentityDocument(timeoutCtx, &imds.GetInstanceIdentityDocumentInput{})
	if err != nil {
		if errors.Is(err, context.DeadlineExceeded) {
			return platformDetectionTimeout
		}
		return platformNotDetected
	}
	if result != nil && result.InstanceID != "" {
		return platformDetected
	}
	return platformNotDetected
}

// detectAwsIdentity asks STS who the ambient AWS credentials belong to and
// reports detection only when the caller ARN is usable for workload identity
// federation (see isValidArnForWif).
func detectAwsIdentity(ctx context.Context, timeout time.Duration) platformDetectionState {
	timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	cfg, err := config.LoadDefaultConfig(timeoutCtx, config.WithLogger(logging.NewStandardLogger(io.Discard)))
	if err != nil {
		if errors.Is(err, context.DeadlineExceeded) {
			return platformDetectionTimeout
		}
		return platformNotDetected
	}
	client := sts.NewFromConfig(cfg)
	out, err := client.GetCallerIdentity(timeoutCtx, &sts.GetCallerIdentityInput{})
	if err != nil {
		if errors.Is(err, context.DeadlineExceeded) {
			return platformDetectionTimeout
		}
		return platformNotDetected
	}
	if out == nil || out.Arn == nil || *out.Arn == "" {
		return platformNotDetected
	}
	if isValidArnForWif(*out.Arn) {
		return platformDetected
	}
	return platformNotDetected
}

// detectAzureVM probes the Azure instance metadata endpoint; any HTTP 200
// counts as detection.
func detectAzureVM(ctx context.Context, timeout time.Duration) platformDetectionState {
	client := metadataServerHTTPClient(timeout)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, azureMetadataBaseURL+"/metadata/instance?api-version=2019-03-11", nil)
	if err != nil {
		return platformNotDetected
	}
	req.Header.Set("Metadata", "true")
	resp, err := client.Do(req)
	if err != nil {
		if errors.Is(err, context.DeadlineExceeded) {
			return platformDetectionTimeout
		}
		return platformNotDetected
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode ==
http.StatusOK { return platformDetected } return platformNotDetected } func detectAzureManagedIdentity(ctx context.Context, timeout time.Duration) platformDetectionState { if detectAzureFunctionEnv(ctx, timeout) == platformDetected && os.Getenv("IDENTITY_HEADER") != "" { return platformDetected } client := metadataServerHTTPClient(timeout) values := url.Values{} values.Set("api-version", "2018-02-01") values.Set("resource", "https://management.azure.com") req, err := http.NewRequestWithContext(ctx, http.MethodGet, azureMetadataBaseURL+"/metadata/identity/oauth2/token?"+values.Encode(), nil) if err != nil { return platformNotDetected } req.Header.Set("Metadata", "true") resp, err := client.Do(req) if err != nil { if errors.Is(err, context.DeadlineExceeded) { return platformDetectionTimeout } return platformNotDetected } defer func() { _ = resp.Body.Close() }() if resp.StatusCode == http.StatusOK { return platformDetected } return platformNotDetected } func detectGceVM(ctx context.Context, timeout time.Duration) platformDetectionState { client := metadataServerHTTPClient(timeout) req, err := http.NewRequestWithContext(ctx, http.MethodGet, gceMetadataRootURL, nil) if err != nil { return platformNotDetected } resp, err := client.Do(req) if err != nil { if errors.Is(err, context.DeadlineExceeded) { return platformDetectionTimeout } return platformNotDetected } defer func() { _ = resp.Body.Close() }() if resp.Header.Get(gcpMetadataFlavorHeaderName) == gcpMetadataFlavor { return platformDetected } return platformNotDetected } func detectGcpIdentity(ctx context.Context, timeout time.Duration) platformDetectionState { client := metadataServerHTTPClient(timeout) url := gcpMetadataBaseURL + "/instance/service-accounts/default/email" req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return platformNotDetected } req.Header.Set(gcpMetadataFlavorHeaderName, gcpMetadataFlavor) resp, err := client.Do(req) if err != nil { if errors.Is(err, 
context.DeadlineExceeded) { return platformDetectionTimeout } return platformNotDetected } defer func() { _ = resp.Body.Close() }() if resp.StatusCode == http.StatusOK { return platformDetected } return platformNotDetected } func isValidArnForWif(arn string) bool { patterns := []string{ `^arn:[^:]+:iam::[^:]+:user/.+$`, `^arn:[^:]+:sts::[^:]+:assumed-role/.+$`, } for _, pattern := range patterns { matched, err := regexp.MatchString(pattern, arn) if err == nil && matched { return true } } return false } ================================================ FILE: platform_detection_test.go ================================================ package gosnowflake import ( "context" "fmt" "os" "slices" "sync" "testing" "time" ) type platformDetectionTestCase struct { name string envVars map[string]string wiremockMappings []wiremockMapping expectedResult []string } type envSnapshot map[string]string func setupCleanPlatformEnv() func() { platformEnvVars := []string{ "LAMBDA_TASK_ROOT", "GITHUB_ACTIONS", "FUNCTIONS_WORKER_RUNTIME", "FUNCTIONS_EXTENSION_VERSION", "AzureWebJobsStorage", "K_SERVICE", "K_REVISION", "K_CONFIGURATION", "CLOUD_RUN_JOB", "CLOUD_RUN_EXECUTION", "IDENTITY_HEADER", disablePlatformDetectionEnv, } snapshot := make(envSnapshot) for _, env := range platformEnvVars { snapshot[env] = os.Getenv(env) } for _, env := range platformEnvVars { os.Unsetenv(env) } return func() { for env, value := range snapshot { os.Setenv(env, value) } } } func setupWiremockMetadataEndpoints() func() { originalAzureURL := azureMetadataBaseURL originalGceRootURL := gceMetadataRootURL originalGcpBaseURL := gcpMetadataBaseURL wiremockURL := wiremock.baseURL() azureMetadataBaseURL = wiremockURL gceMetadataRootURL = wiremockURL gcpMetadataBaseURL = wiremockURL + "/computeMetadata/v1" os.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", wiremockURL) os.Setenv("AWS_ENDPOINT_URL_STS", wiremockURL) return func() { azureMetadataBaseURL = originalAzureURL gceMetadataRootURL = originalGceRootURL 
gcpMetadataBaseURL = originalGcpBaseURL os.Unsetenv("AWS_EC2_METADATA_SERVICE_ENDPOINT") os.Unsetenv("AWS_ENDPOINT_URL_STS") } } func TestPlatformDetectionCachingAndSyncOnce(t *testing.T) { cleanup := setupCleanPlatformEnv() defer cleanup() originalDone, originalCache := platformDetectionDone, detectedPlatformsCache initPlatformDetectionOnce, platformDetectionDone, detectedPlatformsCache = sync.Once{}, make(chan struct{}), nil defer func() { platformDetectionDone, detectedPlatformsCache = originalDone, originalCache }() os.Setenv("LAMBDA_TASK_ROOT", "/var/task") initPlatformDetection() platforms1 := getDetectedPlatforms() // Verify caching works and AWS Lambda detected assertDeepEqualE(t, platforms1, detectedPlatformsCache) assertTrueE(t, slices.Contains(platforms1, "is_aws_lambda"), "Should detect AWS Lambda") // Change environment and test sync.Once behavior cleanup() os.Setenv("GITHUB_ACTIONS", "true") initPlatformDetection() platforms2 := getDetectedPlatforms() assertDeepEqualE(t, platforms1, platforms2) assertTrueE(t, slices.Contains(platforms2, "is_aws_lambda"), "Should still show cached AWS Lambda result") assertFalseE(t, slices.Contains(platforms2, "is_github_action"), "Should NOT detect GitHub Actions due to caching") } func TestDetectPlatforms(t *testing.T) { testCases := []platformDetectionTestCase{ { name: "returns disabled when SNOWFLAKE_DISABLE_PLATFORM_DETECTION is set", envVars: map[string]string{ "SNOWFLAKE_DISABLE_PLATFORM_DETECTION": "true", }, expectedResult: []string{"disabled"}, }, { name: "returns empty when no platforms detected", expectedResult: []string{}, }, { name: "detects AWS Lambda", envVars: map[string]string{ "LAMBDA_TASK_ROOT": "/var/task", }, expectedResult: []string{"is_aws_lambda"}, }, { name: "detects GitHub Actions", envVars: map[string]string{ "GITHUB_ACTIONS": "true", }, expectedResult: []string{"is_github_action"}, }, { name: "detects Azure Function", envVars: map[string]string{ "FUNCTIONS_WORKER_RUNTIME": "node", 
"FUNCTIONS_EXTENSION_VERSION": "~4", "AzureWebJobsStorage": "DefaultEndpointsProtocol=https;AccountName=test", }, expectedResult: []string{"is_azure_function"}, }, { name: "detects GCE Cloud Run Service", envVars: map[string]string{ "K_SERVICE": "my-service", "K_REVISION": "my-service-00001", "K_CONFIGURATION": "my-service", }, expectedResult: []string{"is_gce_cloud_run_service"}, }, { name: "detects GCE Cloud Run Job", envVars: map[string]string{ "CLOUD_RUN_JOB": "my-job", "CLOUD_RUN_EXECUTION": "my-job-execution-1", }, expectedResult: []string{"is_gce_cloud_run_job"}, }, { name: "detects EC2 instance", wiremockMappings: []wiremockMapping{ newWiremockMapping("platform_detection/aws_ec2_instance_success.json"), }, expectedResult: []string{"is_ec2_instance"}, }, { name: "detects AWS identity", wiremockMappings: []wiremockMapping{ newWiremockMapping("platform_detection/aws_identity_success.json"), }, expectedResult: []string{"has_aws_identity"}, }, { name: "detects Azure VM", wiremockMappings: []wiremockMapping{ newWiremockMapping("platform_detection/azure_vm_success.json"), }, expectedResult: []string{"is_azure_vm"}, }, { name: "detects Azure Managed Identity using IDENTITY_HEADER", envVars: map[string]string{ "FUNCTIONS_WORKER_RUNTIME": "node", "FUNCTIONS_EXTENSION_VERSION": "~4", "AzureWebJobsStorage": "DefaultEndpointsProtocol=https;AccountName=test", "IDENTITY_HEADER": "test-header", }, expectedResult: []string{"is_azure_function", "has_azure_managed_identity"}, }, { name: "detects Azure Manage Identity using metadata service", wiremockMappings: []wiremockMapping{ newWiremockMapping("platform_detection/azure_managed_identity_success.json"), }, expectedResult: []string{"has_azure_managed_identity"}, }, { name: "detects GCE VM", wiremockMappings: []wiremockMapping{ newWiremockMapping("platform_detection/gce_vm_success.json"), }, expectedResult: []string{"is_gce_vm"}, }, { name: "detects GCP identity", wiremockMappings: []wiremockMapping{ 
newWiremockMapping("platform_detection/gce_identity_success.json"), }, expectedResult: []string{"has_gcp_identity"}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { cleanup := setupCleanPlatformEnv() defer cleanup() for key, value := range tc.envVars { os.Setenv(key, value) } wiremock.registerMappings(t, tc.wiremockMappings) wiremockCleanup := setupWiremockMetadataEndpoints() defer wiremockCleanup() platforms := detectPlatforms(context.Background(), 200*time.Millisecond) assertDeepEqualE(t, platforms, tc.expectedResult) }) } } func TestDetectPlatformsTimeout(t *testing.T) { cleanup := setupCleanPlatformEnv() defer cleanup() wiremock.registerMappings(t, newWiremockMapping("platform_detection/timeout_response.json")) wiremockCleanup := setupWiremockMetadataEndpoints() defer wiremockCleanup() start := time.Now() platforms := detectPlatforms(context.Background(), 200*time.Millisecond) executionTime := time.Since(start) assertEqualE(t, len(platforms), 0, fmt.Sprintf("Expected empty platforms, got: %v", platforms)) assertTrueE(t, executionTime >= 200*time.Millisecond && executionTime < 250*time.Millisecond, fmt.Sprintf("Expected execution time around 200ms, got: %v", executionTime)) } func TestIsValidArnForWif(t *testing.T) { testCases := []struct { arn string expected bool }{ {"arn:aws:iam::123456789012:user/JohnDoe", true}, {"arn:aws:sts::123456789012:assumed-role/RoleName/SessionName", true}, {"invalid-arn-format", false}, {"arn:aws:iam::account:root", false}, {"arn:aws:iam::123456789012:group/Developers", false}, {"arn:aws:iam::123456789012:role/S3Access", false}, {"arn:aws:iam::123456789012:policy/UsersManageOwnCredentials", false}, {"arn:aws:iam::123456789012:instance-profile/Webserver", false}, {"arn:aws:sts::123456789012:federated-user/John", false}, {"arn:aws:sts::account:self", false}, {"arn:aws:iam::123456789012:mfa/JaneMFA", false}, {"arn:aws:iam::123456789012:u2f/user/John/default", false}, 
{"arn:aws:iam::123456789012:server-certificate/ProdServerCert", false}, {"arn:aws:iam::123456789012:saml-provider/ADFSProvider", false}, {"arn:aws:iam::123456789012:oidc-provider/GoogleProvider", false}, {"arn:aws:iam::aws:contextProvider/IdentityCenter", false}, } for _, tc := range testCases { t.Run(tc.arn, func(t *testing.T) { result := isValidArnForWif(tc.arn) assertEqualE(t, result, tc.expected, fmt.Sprintf("ARN validation failed for: %s", tc.arn)) }) } } ================================================ FILE: prepared_statement_test.go ================================================ package gosnowflake import ( "testing" ) // TestPreparedStatement creates a basic prepared statement, inserting values // after the statement has been prepared func TestPreparedStatement(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.mustExec("create or replace table test_prep_statement(c1 INTEGER, c2 FLOAT, c3 BOOLEAN, c4 STRING)") defer dbt.mustExec(deleteTableSQL) intArray := []int{1, 2, 3} fltArray := []float64{0.1, 2.34, 5.678} boolArray := []bool{true, false, true} strArray := []string{"test1", "test2", "test3"} stmt := dbt.mustPrepare("insert into TEST_PREP_STATEMENT values(?, ?, ?, ?)") if _, err := stmt.Exec(mustArray(&intArray), mustArray(&fltArray), mustArray(&boolArray), mustArray(&strArray)); err != nil { t.Fatal(err) } rows := dbt.mustQuery(selectAllSQL) defer rows.Close() var v1 int var v2 float64 var v3 bool var v4 string if rows.Next() { err := rows.Scan(&v1, &v2, &v3, &v4) if err != nil { t.Fatal(err) } if v1 != 1 && v2 != 0.1 && v3 != true && v4 != "test1" { t.Fatalf("failed to fetch. expected: 1, 0.1, true, test1. got: %v, %v, %v, %v", v1, v2, v3, v4) } } else { t.Error("failed to query") } if rows.Next() { err := rows.Scan(&v1, &v2, &v3, &v4) if err != nil { t.Fatal(err) } if v1 != 2 && v2 != 2.34 && v3 != false && v4 != "test2" { t.Fatalf("failed to fetch. expected: 2, 2.34, false, test2. 
got: %v, %v, %v, %v", v1, v2, v3, v4) } } else { t.Error("failed to query") } if rows.Next() { err := rows.Scan(&v1, &v2, &v3, &v4) if err != nil { t.Fatal(err) } if v1 != 3 && v2 != 5.678 && v3 != true && v4 != "test3" { t.Fatalf("failed to fetch. expected: 3, test3. got: %v, %v, %v, %v", v1, v2, v3, v4) } } else { t.Error("failed to query") } }) } ================================================ FILE: priv_key_test.go ================================================ package gosnowflake // For compile concern, should any newly added variables or functions here must also be added with same // name or signature but with default or empty content in the priv_key_test.go(See addParseDSNTest) import ( "context" "crypto/rand" "crypto/rsa" "crypto/x509" "database/sql" "encoding/pem" "fmt" "net/http" "os" "testing" "time" ) // helper function to set up private key for testing func setupPrivateKey() { env := func(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue } privKeyPath := env("SNOWFLAKE_TEST_PRIVATE_KEY", "") if privKeyPath == "" { customPrivateKey = false testPrivKey, _ = rsa.GenerateKey(rand.Reader, 2048) } else { // path to the DER file customPrivateKey = true data, _ := os.ReadFile(privKeyPath) block, _ := pem.Decode(data) if block == nil || block.Type != "PRIVATE KEY" { panic(fmt.Sprintf("%v is not a public key in PEM format.", privKeyPath)) } privKey, _ := x509.ParsePKCS8PrivateKey(block.Bytes) testPrivKey = privKey.(*rsa.PrivateKey) } } func TestJWTTokenTimeout(t *testing.T) { brt := newBlockingRoundTripper(http.DefaultTransport, 2000*time.Millisecond) localTestKey, err := rsa.GenerateKey(rand.Reader, 2048) assertNilF(t, err, "Failed to generate test private key") cfg := &Config{ User: "user", Host: "localhost", Port: wiremock.port, Account: "jwtAuthTokenTimeout", JWTClientTimeout: 10 * time.Millisecond, PrivateKey: localTestKey, Authenticator: AuthTypeJwt, MaxRetryCount: 1, Transporter: brt, } db 
:= sql.OpenDB(NewConnector(SnowflakeDriver{}, *cfg))
	defer db.Close()
	ctx := context.Background()
	_, err = db.Conn(ctx)
	// The blocking transport outlasts JWTClientTimeout, so the connection
	// attempt must surface a deadline error.
	assertNotNilF(t, err)
	assertErrIsE(t, err, context.DeadlineExceeded)
}



================================================
FILE: put_get_test.go
================================================
package gosnowflake

import (
	"bufio"
	"bytes"
	"compress/gzip"
	"context"
	"crypto/sha256"
	"database/sql"
	"fmt"
	"io"
	"math/rand"
	"os"
	"os/user"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"testing"
	"time"
)

const createStageStmt = "CREATE OR REPLACE STAGE %v URL = '%v' CREDENTIALS = (%v)"

// TestPutError uploads an unreadable (mode 0000) local file and expects the
// file transfer agent's result() to report the permission error.
func TestPutError(t *testing.T) {
	if isWindows {
		t.Skip("permission model is different")
	}
	tmpDir := t.TempDir()
	file1 := filepath.Join(tmpDir, "file1")
	remoteLocation := filepath.Join(tmpDir, "remote_loc")
	f, err := os.Create(file1)
	if err != nil {
		t.Error(err)
	}
	defer func() {
		assertNilF(t, f.Close())
	}()
	_, err = f.WriteString("test1")
	assertNilF(t, err)
	// Make the source file unreadable so the upload itself must fail.
	assertNilF(t, os.Chmod(file1, 0000))
	defer func() {
		assertNilF(t, os.Chmod(file1, 0644))
	}()
	data := &execResponseData{
		Command:           string(uploadCommand),
		AutoCompress:      false,
		SrcLocations:      []string{file1},
		SourceCompression: "none",
		StageInfo: execResponseStageInfo{
			Location:     remoteLocation,
			LocationType: string(local),
			Path:         "remote_loc",
		},
	}
	fta := &snowflakeFileTransferAgent{
		ctx:  context.Background(),
		data: data,
		sc: &snowflakeConn{
			cfg: &Config{},
		},
	}
	// execute() itself succeeds; the per-file failure is reported by result().
	if err = fta.execute(); err != nil {
		t.Fatal(err)
	}
	if _, err = fta.result(); err == nil {
		t.Fatalf("should raise permission error")
	}
}

// TestPercentage checks snowflakeProgressPercentage.percent for zero-size and
// partial-progress inputs.
func TestPercentage(t *testing.T) {
	testcases := []struct {
		seen     int64
		size     float64
		expected float64
	}{
		{0, 0, 1.0},
		{20, 0, 1.0},
		{40, 20, 1.0},
		{14, 28, 0.5},
	}
	for _, test := range testcases {
		t.Run(fmt.Sprintf("%v_%v_%v", test.seen, test.size, test.expected), func(t *testing.T) {
			spp := snowflakeProgressPercentage{}
			if spp.percent(test.seen, test.size) != test.expected {
				t.Fatalf("percentage conversion failed. %v/%v, expected: %v, got: %v", test.seen, test.size, test.expected, spp.percent(test.seen, test.size))
			}
		})
	}
}

// tcPutGetData bundles the per-run resources (working dir, AWS credentials,
// stage/warehouse/database names, user bucket) used by the PUT/GET tests.
type tcPutGetData struct {
	dir                string
	awsAccessKeyID     string
	awsSecretAccessKey string
	stage              string
	warehouse          string
	database           string
	userBucket         string
}

// cleanupPut drops the database and warehouse created by createTestData.
func cleanupPut(dbt *DBTest, td *tcPutGetData) {
	dbt.mustExec("drop database " + td.database)
	dbt.mustExec("drop warehouse " + td.warehouse)
}

// getAWSCredentials reads the AWS key pair and the user bucket from the
// environment, falling back to a per-user sfc-eng-regression path for the
// bucket when SF_AWS_USER_BUCKET is unset.
func getAWSCredentials() (string, string, string, error) {
	keyID, ok := os.LookupEnv("AWS_ACCESS_KEY_ID")
	if !ok {
		return "", "", "", fmt.Errorf("key id invalid")
	}
	secretKey, ok := os.LookupEnv("AWS_SECRET_ACCESS_KEY")
	if !ok {
		return keyID, "", "", fmt.Errorf("secret key invalid")
	}
	bucket, present := os.LookupEnv("SF_AWS_USER_BUCKET")
	if !present {
		user, err := user.Current()
		if err != nil {
			return keyID, secretKey, "", err
		}
		bucket = fmt.Sprintf("sfc-eng-regression/%v/reg", user.Username)
	}
	return keyID, secretKey, bucket, nil
}

// createTestData provisions a uniquely named warehouse, database, schema and
// file format under the sysadmin role and returns the resulting fixture.
// Requires sysadmin to be usable on the current connection.
func createTestData(dbt *DBTest) (*tcPutGetData, error) {
	keyID, secretKey, bucket, err := getAWSCredentials()
	if err != nil {
		return nil, err
	}
	uniqueName := randomString(10)
	database := fmt.Sprintf("%v_db", uniqueName)
	wh := fmt.Sprintf("%v_wh", uniqueName)
	dir, err := os.Getwd()
	if err != nil {
		return nil, err
	}
	ret := tcPutGetData{
		dir,
		keyID,
		secretKey,
		fmt.Sprintf("%v_stage", uniqueName),
		wh,
		database,
		bucket,
	}
	if _, err = dbt.exec("use role sysadmin"); err != nil {
		return nil, err
	}
	dbt.mustExec(fmt.Sprintf(
		"create or replace warehouse %v warehouse_size='small' "+
			"warehouse_type='standard' auto_suspend=1800", wh))
	dbt.mustExec("create or replace database " + database)
	dbt.mustExec("create or replace schema gotesting_schema")
	dbt.mustExec("create or replace file format VSV type = 'CSV' " +
		"field_delimiter='|' error_on_column_count_mismatch=false")
	return &ret, nil
}

// TestPutLocalFile PUTs local CSVs into a table stage backed by an external
// S3 location, then COPYies them in and verifies the load status and counts.
func TestPutLocalFile(t *testing.T) {
	if runningOnGithubAction() && !runningOnAWS() {
		t.Skip("skipping non aws environment")
	}
	runDBTest(t, func(dbt *DBTest) {
data, err := createTestData(dbt)
		if err != nil {
			t.Skip("snowflake admin account not accessible")
		}
		defer cleanupPut(dbt, data)
		dbt.mustExec("use warehouse " + data.warehouse)
		dbt.mustExec("alter session set DISABLE_PUT_AND_GET_ON_EXTERNAL_STAGE=false")
		dbt.mustExec("use schema " + data.database + ".gotesting_schema")
		execQuery := fmt.Sprintf(
			`create or replace table gotest_putget_t1 (c1 STRING, c2 STRING, c3 STRING, c4 STRING, c5 STRING, c6 STRING, c7 STRING, c8 STRING, c9 STRING) stage_file_format = ( field_delimiter = '|' error_on_column_count_mismatch=false) stage_copy_options = (purge=false) stage_location = (url = 's3://%v/%v' credentials = (AWS_KEY_ID='%v' AWS_SECRET_KEY='%v'))`,
			data.userBucket, data.stage, data.awsAccessKeyID, data.awsSecretAccessKey)
		dbt.mustExec(execQuery)
		defer dbt.mustExec("drop table if exists gotest_putget_t1")
		// Two orders_10*.csv files are expected to match the glob.
		execQuery = fmt.Sprintf(`put file://%v/test_data/orders_10*.csv @%%gotest_putget_t1`, data.dir)
		dbt.mustExec(execQuery)
		dbt.mustQueryAssertCount("ls @%gotest_putget_t1", 2)
		var s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 sql.NullString
		rows := dbt.mustQuery("copy into gotest_putget_t1")
		defer func() {
			assertNilF(t, rows.Close())
		}()
		for rows.Next() {
			assertNilF(t, rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9))
			if !s1.Valid || s1.String != "LOADED" {
				t.Fatal("not loaded")
			}
		}
		rows2 := dbt.mustQuery("select count(*) from gotest_putget_t1")
		defer func() {
			assertNilF(t, rows2.Close())
		}()
		var i int
		if rows2.Next() {
			assertNilF(t, rows2.Scan(&i))
			if i != 75 {
				t.Fatalf("expected 75 rows, got %v", i)
			}
		}
		rows3 := dbt.mustQuery(`select STATUS from information_schema .load_history where table_name='gotest_putget_t1'`)
		defer func() {
			assertNilF(t, rows3.Close())
		}()
		if rows3.Next() {
			assertNilF(t, rows3.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9))
			if !s1.Valid || s1.String != "LOADED" {
				t.Fatal("not loaded")
			}
		}
	})
}

// TestPutGetWithAutoCompressFalse PUTs with auto_compress=FALSE and verifies
// the staged file keeps its original name (no .gz), then GETs it back into a
// stream and compares contents.
func TestPutGetWithAutoCompressFalse(t *testing.T) {
	tmpDir := t.TempDir()
	testData := filepath.Join(tmpDir, "data.txt")
	f, err := os.Create(testData)
	if err != nil {
		t.Error(err)
	}
	originalContents := "test1,test2\ntest3,test4"
	_, err = f.WriteString(originalContents)
	assertNilF(t, err)
	assertNilF(t, f.Sync())
	defer func() {
		assertNilF(t, f.Close())
	}()
	runDBTest(t, func(dbt *DBTest) {
		stageDir := "test_put_uncompress_file_" + randomString(10)
		dbt.mustExec("rm @~/" + stageDir)
		// PUT test
		sqlText := fmt.Sprintf("put 'file://%v' @~/%v auto_compress=FALSE", testData, stageDir)
		sqlText = strings.ReplaceAll(sqlText, "\\", "\\\\")
		dbt.mustExec(sqlText)
		defer dbt.mustExec("rm @~/" + stageDir)
		rows := dbt.mustQuery("ls @~/" + stageDir)
		defer func() {
			assertNilF(t, rows.Close())
		}()
		var file, s1, s2, s3 string
		if rows.Next() {
			err = rows.Scan(&file, &s1, &s2, &s3)
			assertNilE(t, err)
		}
		assertTrueF(t, strings.Contains(file, stageDir+"/data.txt"), fmt.Sprintf("should contain file. got: %v", file))
		assertFalseF(t, strings.Contains(file, "data.txt.gz"), fmt.Sprintf("should not contain file. got: %v", file))
		// GET test
		var streamBuf bytes.Buffer
		ctx := WithFileGetStream(context.Background(), &streamBuf)
		sql := fmt.Sprintf("get @~/%v/data.txt 'file://%v'", stageDir, tmpDir)
		sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
		rows2 := dbt.mustQueryContext(ctx, sqlText)
		defer func() {
			assertNilF(t, rows2.Close())
		}()
		for rows2.Next() {
			err = rows2.Scan(&file, &s1, &s2, &s3)
			assertNilE(t, err)
			assertTrueE(t, strings.HasPrefix(file, "data.txt"), "a file was not downloaded by GET")
			v, err := strconv.Atoi(s1)
			assertNilE(t, err)
			assertEqualE(t, v, 23, "did not return the right file size")
			assertEqualE(t, s2, "DOWNLOADED", "did not return DOWNLOADED status")
			assertEqualE(t, s3, "")
		}
		// Drain the GET stream chunk by chunk and rebuild the contents.
		var contents string
		r := bytes.NewReader(streamBuf.Bytes())
		for {
			c := make([]byte, defaultChunkBufferSize)
			if n, err := r.Read(c); err != nil {
				if err == io.EOF {
					contents = contents + string(c[:n])
					break
				}
				t.Error(err)
			} else {
				contents = contents + string(c[:n])
			}
		}
		assertEqualE(t, contents, originalContents)
	})
}

// TestPutOverwrite PUTs the same file three times: the second attempt (no
// overwrite) must be SKIPPED with an unchanged MD5, and the third
// (overwrite=true) must be re-UPLOADED with a new MD5.
//
// NOTE(review): `rows` is reassigned several times while each `defer
// rows.Close()` captures the variable, so at function exit every deferred
// call closes only the final result set — earlier ones rely on being drained;
// confirm whether that is intentional.
func TestPutOverwrite(t *testing.T) {
	tmpDir := t.TempDir()
	testData := filepath.Join(tmpDir, "data.txt")
	f, err := os.Create(testData)
	if err != nil {
		t.Error(err)
	}
	_, err = f.WriteString("test1,test2\ntest3,test4\n")
	assertNilF(t, err)
	assertNilF(t, f.Close())
	stageName := "test_put_overwrite_stage_" + randomString(10)
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("CREATE OR REPLACE STAGE " + stageName)
		defer dbt.mustExec("DROP STAGE " + stageName)
		// First PUT: plain upload.
		f, _ = os.Open(testData)
		rows := dbt.mustQueryContext(
			WithFilePutStream(context.Background(), f),
			fmt.Sprintf("put 'file://%v' @"+stageName+"/test_put_overwrite",
				strings.ReplaceAll(testData, "\\", "/")))
		defer rows.Close()
		f.Close()
		var s0, s1, s2, s3, s4, s5, s6, s7 string
		if rows.Next() {
			if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
				t.Fatal(err)
			}
		}
		if s6 != uploaded.String() {
			t.Fatalf("expected UPLOADED, got %v", s6)
		}
		rows = dbt.mustQuery("ls @" + stageName + "/test_put_overwrite")
		defer rows.Close()
		assertTrueF(t, rows.Next(), "expected new rows")
		if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
			t.Fatal(err)
		}
		md5Column := s2
		// Second PUT without overwrite: should be skipped.
		f, _ = os.Open(testData)
		rows = dbt.mustQueryContext(
			WithFilePutStream(context.Background(), f),
			fmt.Sprintf("put 'file://%v' @"+stageName+"/test_put_overwrite",
				strings.ReplaceAll(testData, "\\", "/")))
		defer rows.Close()
		f.Close()
		assertTrueF(t, rows.Next(), "expected new rows")
		if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
			t.Fatal(err)
		}
		if s6 != skipped.String() {
			t.Fatalf("expected SKIPPED, got %v", s6)
		}
		rows = dbt.mustQuery("ls @" + stageName + "/test_put_overwrite")
		defer rows.Close()
		assertTrueF(t, rows.Next(), "expected new rows")
		if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
			t.Fatal(err)
		}
		if s2 != md5Column {
			t.Fatal("The MD5 column should have stayed the same")
		}
		// Third PUT with overwrite=true: must re-upload and change the MD5.
		f, _ = os.Open(testData)
		rows = dbt.mustQueryContext(
			WithFilePutStream(context.Background(), f),
			fmt.Sprintf("put 'file://%v' @"+stageName+"/test_put_overwrite overwrite=true",
				strings.ReplaceAll(testData, "\\", "/")))
		defer rows.Close()
		f.Close()
		assertTrueF(t, rows.Next(), "expected new rows")
		if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
			t.Fatal(err)
		}
		if s6 != uploaded.String() {
			t.Fatalf("expected UPLOADED, got %v", s6)
		}
		rows = dbt.mustQuery("ls @" + stageName + "/test_put_overwrite")
		defer rows.Close()
		assertTrueF(t, rows.Next(), "expected new rows")
		if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
			t.Fatal(err)
		}
		assertEqualE(t, s0, stageName+"/test_put_overwrite/"+baseName(testData)+".gz")
		assertNotEqualE(t, s2, md5Column)
	})
}

func TestPutGetFile(t *testing.T) {
	testPutGet(t, false)
}

func TestPutGetStream(t *testing.T) {
	testPutGet(t, true)
}

// testPutGet does a full PUT -> COPY INTO -> COPY back -> GET round trip on a
// table stage, either via file paths or via put/get streams, and compares the
// decompressed result with the original contents.
func testPutGet(t *testing.T, isStream bool) {
	tmpDir := t.TempDir()
	fname := filepath.Join(tmpDir, "test_put_get.txt.gz")
	originalContents := "123,test1\n456,test2\n"
	tableName := randomString(5)
	var b bytes.Buffer
	gzw := gzip.NewWriter(&b)
	_, err := gzw.Write([]byte(originalContents))
	assertNilF(t, err)
	assertNilF(t, gzw.Close())
	if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil {
		t.Fatal("could not write to gzip file")
	}
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("create or replace table " + tableName + " (a int, b string)")
		defer dbt.mustExec("drop table " + tableName)
		fileStream, err := os.Open(fname)
		assertNilF(t, err)
		defer func() {
			assertNilF(t, fileStream.Close())
		}()
		var sqlText string
		var rows *RowsExtended
		sql := "put 'file://%v' @%%%v auto_compress=true parallel=30"
		ctx := context.Background()
		if isStream {
			sqlText = fmt.Sprintf(
				sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName)
			rows = dbt.mustQueryContextT(WithFilePutStream(ctx, fileStream), t, sqlText)
		} else {
			sqlText = fmt.Sprintf(
				sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName)
			rows = dbt.mustQueryT(t, sqlText)
		}
		defer func() {
			assertNilF(t, rows.Close())
		}()
		var s0, s1, s2, s3, s4, s5, s6, s7 string
		assertTrueF(t, rows.Next(), "expected new rows")
		rows.mustScan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7)
		assertEqualF(t, s6, uploaded.String())
		// check file is PUT
		dbt.mustQueryAssertCount("ls @%"+tableName, 1)
		dbt.mustExecT(t, "copy into "+tableName)
		dbt.mustExecT(t, "rm @%"+tableName)
		dbt.mustQueryAssertCount("ls @%"+tableName, 0)
		dbt.mustExecT(t, fmt.Sprintf(`copy into @%%%v from %v file_format=(type=csv compression='gzip')`, tableName, tableName))
		var streamBuf bytes.Buffer
		if isStream {
			ctx = WithFileGetStream(ctx, &streamBuf)
		}
		sql = fmt.Sprintf("get @%%%v 'file://%v' parallel=10", tableName, tmpDir)
		sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
		rows2 := dbt.mustQueryContextT(ctx, t, sqlText)
		defer func() {
			assertNilF(t, rows2.Close())
		}()
		for rows2.Next() {
			rows2.mustScan(&s0, &s1, &s2, &s3)
			assertHasPrefixF(t, s0, "data_")
			v, err := strconv.Atoi(s1)
			assertNilF(t, err)
			assertEqualE(t, v, 36)
			assertEqualE(t, s2, "DOWNLOADED")
			assertEqualE(t, s3, "")
		}
		// Decompress whichever sink was used (stream buffer or local file)
		// and compare with the original contents.
		var contents string
		if isStream {
			gz, err := gzip.NewReader(&streamBuf)
			assertNilF(t, err)
			defer func() {
				assertNilF(t, gz.Close())
			}()
			for {
				c := make([]byte, defaultChunkBufferSize)
				if n, err := gz.Read(c); err != nil {
					if err == io.EOF {
						contents = contents + string(c[:n])
						break
					}
					t.Fatal(err)
				} else {
					contents = contents + string(c[:n])
				}
			}
		} else {
			files, err := filepath.Glob(filepath.Join(tmpDir, "data_*"))
			assertNilF(t, err)
			fileName := files[0]
			f, err := os.Open(fileName)
			assertNilF(t, err)
			defer func() {
				assertNilF(t, f.Close())
			}()
			gz, err := gzip.NewReader(f)
			assertNilF(t, err)
			defer func() {
				assertNilF(t, gz.Close())
			}()
			for {
				c := make([]byte, defaultChunkBufferSize)
				if n, err := gz.Read(c); err != nil {
					if err == io.EOF {
						contents = contents + string(c[:n])
						break
					}
					t.Fatal(err)
				} else {
					contents = contents + string(c[:n])
				}
			}
		}
		assertEqualE(t, contents, originalContents, "output is different from the original contents")
	})
}

// TestPutGetWithSnowflakeSSE round-trips a file through a stage created with
// server-side encryption (SNOWFLAKE_SSE), in both stream and file modes.
func TestPutGetWithSnowflakeSSE(t *testing.T) {
	tmpDir := t.TempDir()
	cwd, err := os.Getwd()
	assertNilF(t, err)
	sourceFilePath := filepath.Join(cwd, "test_data", "orders_100.csv")
	originalContents, err := os.ReadFile(sourceFilePath)
	assertNilF(t, err)
	runDBTest(t, func(dbt *DBTest) {
		for _, useStream := range []bool{true, false} {
			t.Run(fmt.Sprintf("useStream=%v", useStream), func(t *testing.T) {
				stageName := "test_stage_sse_" + randomString(10)
				dbt.mustExec(fmt.Sprintf("CREATE STAGE %s ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')", stageName))
				defer dbt.mustExec("DROP STAGE " + stageName)
				uploadCtx := context.Background()
				if useStream {
					fileStream, err := os.Open(sourceFilePath)
					assertNilF(t, err)
					defer fileStream.Close()
					uploadCtx = WithFilePutStream(uploadCtx, fileStream)
				}
				rows := dbt.mustQueryContextT(uploadCtx, t, fmt.Sprintf("PUT 'file://%s' @%s", strings.ReplaceAll(sourceFilePath, "\\", "\\\\"), stageName))
				defer rows.Close()
				var s0, s1, s2, s3, s4, s5, s6, s7 string
				assertTrueF(t, rows.Next(), "expected new rows")
				rows.mustScan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7)
				assertEqualF(t, s6, uploaded.String())
				downloadCtx := context.Background()
				var downloadBuf bytes.Buffer
				if useStream {
					downloadCtx = WithFileGetStream(downloadCtx, &downloadBuf)
				}
				rows2 := dbt.mustQueryContextT(downloadCtx, t, fmt.Sprintf("GET @%s 'file://%s'", stageName, strings.ReplaceAll(tmpDir, "\\", "\\\\")))
				defer rows2.Close()
				assertTrueF(t, rows2.Next(), "expected new rows")
				rows2.mustScan(&s0, &s1, &s2, &s3)
				assertEqualF(t, s2, "DOWNLOADED")
				var compressedData []byte
				if useStream {
					compressedData, err = io.ReadAll(&downloadBuf)
					assertNilF(t, err)
				} else {
					downloadedFilePath := filepath.Join(tmpDir, "orders_100.csv.gz")
					compressedData, err = os.ReadFile(downloadedFilePath)
					assertNilF(t, err)
				}
				gzReader, err := gzip.NewReader(bytes.NewReader(compressedData))
				assertNilF(t, err)
				defer gzReader.Close()
				decompressedData, err := io.ReadAll(gzReader)
				assertNilF(t, err)
				assertEqualE(t, string(decompressedData), string(originalContents), "downloaded file content does not match original")
			})
		}
	})
}

// TestPutGetWithSpacesInDirectoryName round-trips a file through a stage path
// containing spaces ('dir with spaces'), in both stream and file modes.
func TestPutGetWithSpacesInDirectoryName(t *testing.T) {
	tmpDir := t.TempDir()
	cwd, err := os.Getwd()
	assertNilF(t, err)
	sourceFilePath := filepath.Join(cwd, "test_data", "orders_100.csv")
	originalContents, err := os.ReadFile(sourceFilePath)
	assertNilF(t, err)
	runDBTest(t, func(dbt *DBTest) {
		for _, useStream := range []bool{true, false} {
			t.Run(fmt.Sprintf("useStream=%v", useStream), func(t *testing.T) {
				stageName := "test_stage_sse_" + randomString(10)
				dbt.mustExec(fmt.Sprintf("CREATE STAGE %s", stageName))
				defer dbt.mustExec("DROP STAGE " + stageName)
				uploadCtx := context.Background()
				if useStream {
					fileStream, err := os.Open(sourceFilePath)
					assertNilF(t, err)
					defer fileStream.Close()
					uploadCtx = WithFilePutStream(uploadCtx, fileStream)
				}
				rows := dbt.mustQueryContextT(uploadCtx, t, fmt.Sprintf("PUT 'file://%s' '@%s/dir with spaces'", strings.ReplaceAll(sourceFilePath, "\\", "\\\\"), stageName))
				defer rows.Close()
				var s0, s1, s2, s3, s4, s5, s6, s7 string
				assertTrueF(t, rows.Next(), "expected new rows")
				rows.mustScan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7)
				assertEqualF(t, s6, uploaded.String())
				downloadCtx := context.Background()
				var downloadBuf bytes.Buffer
				if useStream {
					downloadCtx = WithFileGetStream(downloadCtx, &downloadBuf)
				}
				rows2 := dbt.mustQueryContextT(downloadCtx, t, fmt.Sprintf("GET '@%s/dir with spaces' 'file://%s'", stageName, strings.ReplaceAll(tmpDir, "\\", "\\\\")))
				defer rows2.Close()
				assertTrueF(t, rows2.Next(), "expected new rows")
				rows2.mustScan(&s0, &s1, &s2, &s3)
				assertEqualF(t, s2, "DOWNLOADED")
				var compressedData []byte
				if useStream {
					compressedData, err = io.ReadAll(&downloadBuf)
					assertNilF(t, err)
				} else {
					downloadedFilePath := filepath.Join(tmpDir, "orders_100.csv.gz")
					compressedData, err = os.ReadFile(downloadedFilePath)
					assertNilF(t, err)
				}
				gzReader, err := gzip.NewReader(bytes.NewReader(compressedData))
				assertNilF(t, err)
				defer gzReader.Close()
				decompressedData, err := io.ReadAll(gzReader)
				assertNilF(t, err)
				assertEqualE(t, string(decompressedData), string(originalContents), "downloaded file content does not match original")
			})
		}
	})
}

// TestPutWithNonWritableTemp points tmpDirPath at an unwritable directory:
// non-stream PUT must fail with a permission error, while stream PUT (which
// needs no temp dir) must succeed and round-trip correctly.
func TestPutWithNonWritableTemp(t *testing.T) {
	if isWindows {
		t.Skip("permission system is different")
	}
	tempDir := t.TempDir()
	assertNilF(t, os.Chmod(tempDir, 0000))
	customDsn := dsn + "&tmpDirPath=" + strings.ReplaceAll(tempDir, "/", "%2F")
	runDBTestWithConfig(t, &testConfig{dsn: customDsn}, func(dbt *DBTest) {
		for _, isStream := range []bool{false, true} {
			t.Run(fmt.Sprintf("isStream=%v", isStream), func(t *testing.T) {
				stageName := "test_stage_" + randomString(10)
				cwd, err := os.Getwd()
				assertNilF(t, err)
				filePath := fmt.Sprintf("%v/test_data/orders_100.csv", cwd)
				dbt.mustExecT(t, "CREATE STAGE "+stageName)
				defer dbt.mustExecT(t, "DROP STAGE "+stageName)
				ctx := context.Background()
				if isStream {
					fd, err := os.Open(filePath)
					assertNilF(t, err)
					ctx = WithFilePutStream(ctx, fd)
				}
				_, err = dbt.conn.ExecContext(ctx, fmt.Sprintf("PUT 'file://%v' @%v", filePath, stageName))
				if !isStream {
					assertNotNilF(t, err)
					assertStringContainsE(t, err.Error(), "mkdir")
					assertStringContainsE(t, err.Error(), "permission denied")
				} else {
					// Re-enable the temp dir so GET can write there.
					assertNilF(t, os.Chmod(tempDir, 0755))
					_ = dbt.mustExecContextT(ctx, t, fmt.Sprintf("GET @%v 'file://%v'", stageName, tempDir))
					resultBytesCompressed, err := os.ReadFile(filepath.Join(tempDir, "orders_100.csv.gz"))
					assertNilF(t, err)
					resultBytesReader, err := gzip.NewReader(bytes.NewReader(resultBytesCompressed))
					assertNilF(t, err)
					resultBytes, err := io.ReadAll(resultBytesReader)
					assertNilF(t, err)
					inputBytes, err := os.ReadFile(filePath)
					assertNilF(t, err)
					assertEqualE(t, string(resultBytes), string(inputBytes))
				}
			})
		}
	})
}

// TestGetWithNonWritableTemp mirrors the PUT variant: with an unwritable temp
// dir, non-stream GET fails with a permission error and stream GET succeeds.
func TestGetWithNonWritableTemp(t *testing.T) {
	if isWindows {
		t.Skip("permission system is different")
	}
	tempDir := t.TempDir()
	customDsn := dsn + "&tmpDirPath=" + strings.ReplaceAll(tempDir, "/", "%2F")
	runDBTestWithConfig(t, &testConfig{dsn: customDsn}, func(dbt *DBTest) {
		stageName := "test_stage_" + randomString(10)
		cwd, err := os.Getwd()
		assertNilF(t, err)
		filePath := fmt.Sprintf("%v/test_data/orders_100.csv", cwd)
		dbt.mustExecT(t, "CREATE STAGE "+stageName)
		defer dbt.mustExecT(t, "DROP STAGE "+stageName)
		dbt.mustExecT(t, fmt.Sprintf("PUT 'file://%v' @%v", filePath, stageName))
		// Lock the temp dir only after the PUT has succeeded.
		assertNilF(t, os.Chmod(tempDir, 0000))
		for _, isStream := range []bool{false, true} {
			t.Run(fmt.Sprintf("isStream=%v", isStream), func(t *testing.T) {
				ctx := context.Background()
				var resultBuf bytes.Buffer
				if isStream {
					ctx = WithFileGetStream(ctx, &resultBuf)
				}
				_, err = dbt.conn.ExecContext(ctx, fmt.Sprintf("GET @%v 'file://%v'", stageName, tempDir))
				if !isStream {
					assertNotNilF(t, err)
					assertStringContainsE(t, err.Error(), "mkdir")
					assertStringContainsE(t, err.Error(), "permission denied")
				} else {
					assertNilF(t, err)
					resultBytesReader, err := gzip.NewReader(&resultBuf)
					assertNilF(t, err)
					resultBytes, err := io.ReadAll(resultBytesReader)
					assertNilF(t, err)
					inputBytes, err := os.ReadFile(filePath)
					assertNilF(t, err)
					assertEqualE(t, string(resultBytes), string(inputBytes))
				}
			})
		}
	})
}

// TestPutGetGcsDownscopedCredential repeats the basic PUT/GET round trip with
// GCS_USE_DOWNSCOPED_CREDENTIAL=true appended to the DSN (GCP environments).
func TestPutGetGcsDownscopedCredential(t *testing.T) {
	if runningOnGithubAction() && !runningOnGCP() {
		t.Skip("skipping non GCP environment")
	}
	tmpDir, err := os.MkdirTemp("", "put_get")
	if err != nil {
		t.Error(err)
	}
	defer func() {
		assertNilF(t, os.RemoveAll(tmpDir))
	}()
	fname := filepath.Join(tmpDir, "test_put_get.txt.gz")
	originalContents := "123,test1\n456,test2\n"
	tableName := randomString(5)
	var b bytes.Buffer
	gzw := gzip.NewWriter(&b)
	_, err = gzw.Write([]byte(originalContents))
	assertNilF(t, err)
	assertNilF(t, gzw.Close())
	if err = os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil {
		t.Fatal("could not write to gzip file")
	}
	customDsn := dsn + "&GCS_USE_DOWNSCOPED_CREDENTIAL=true"
	runDBTestWithConfig(t, &testConfig{dsn: customDsn}, func(dbt *DBTest) {
		dbt.mustExec("create or replace table " + tableName + " (a int, b string)")
		fileStream, err := os.Open(fname)
		if err != nil {
			t.Error(err)
		}
		defer func() {
			defer dbt.mustExec("drop table " + tableName)
			if fileStream != nil {
				assertNilF(t, fileStream.Close())
			}
		}()
		var sqlText string
		var rows *RowsExtended
		sql := "put 'file://%v' @%%%v auto_compress=true parallel=30"
		sqlText = fmt.Sprintf(
			sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName)
		rows = dbt.mustQuery(sqlText)
		defer func() {
			assertNilF(t, rows.Close())
		}()
		var s0, s1, s2, s3, s4, s5, s6, s7 string
		if rows.Next() {
			if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
				t.Fatal(err)
			}
		}
		if s6 != uploaded.String() {
			t.Fatalf("expected %v, got: %v", uploaded, s6)
		}
		// check file is PUT
		dbt.mustQueryAssertCount("ls @%"+tableName, 1)
		dbt.mustExec("copy into " + tableName)
		dbt.mustExec("rm @%" + tableName)
		dbt.mustQueryAssertCount("ls @%"+tableName, 0)
		dbt.mustExec(fmt.Sprintf(`copy into @%%%v from %v file_format=(type=csv compression='gzip')`, tableName, tableName))
		sql = fmt.Sprintf("get @%%%v 'file://%v' parallel=10", tableName, tmpDir)
		sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
		rows2 := dbt.mustQuery(sqlText)
		defer func() {
			assertNilF(t, rows2.Close())
		}()
		for rows2.Next() {
			if err = rows2.Scan(&s0, &s1, &s2, &s3); err != nil {
				t.Error(err)
			}
			if !strings.HasPrefix(s0, "data_") {
				t.Error("a file was not downloaded by GET")
			}
			if v, err := strconv.Atoi(s1); err != nil || v != 36 {
				t.Error("did not return the right file size")
			}
			if s2 != "DOWNLOADED" {
				t.Error("did not return DOWNLOADED status")
			}
			if s3 != "" {
				t.Errorf("returned %v", s3)
			}
		}
		// NOTE(review): uses t.Error (not Fatal) above, then indexes files[0]
		// below — if the GET produced nothing this panics; confirm intended.
		files, err := filepath.Glob(filepath.Join(tmpDir, "data_*"))
		if err != nil {
			t.Fatal(err)
		}
		fileName := files[0]
		f, err := os.Open(fileName)
		if err != nil {
			t.Error(err)
		}
		defer f.Close()
		gz, err := gzip.NewReader(f)
		if err != nil {
			t.Error(err)
		}
		var contents string
		for {
			c := make([]byte, defaultChunkBufferSize)
			if n, err := gz.Read(c); err != nil {
				if err == io.EOF {
					contents = contents + string(c[:n])
					break
				}
				t.Error(err)
			} else {
				contents = contents + string(c[:n])
			}
		}
		if contents != originalContents {
			t.Error("output is different from the original file")
		}
	})
}

// TestPutGetLargeFile PUTs a 5 MiB sparse file and streams it back via GET,
// comparing the decompressed stream with the original file contents.
func TestPutGetLargeFile(t *testing.T) {
	testData := createTempLargeFile(t, 5*1024*1024)
	baseName := filepath.Base(testData)
	fnameStage := baseName + ".gz"
	runDBTest(t, func(dbt *DBTest) {
		stageDir := "test_put_largefile_" + randomString(10)
		dbt.mustExec("rm @~/" + stageDir)
		// PUT test
		putQuery := fmt.Sprintf("put 'file://%v' @~/%v", strings.ReplaceAll(testData, "\\", "/"), stageDir)
		sqlText := strings.ReplaceAll(putQuery, "\\", "\\\\")
		dbt.mustExec(sqlText)
		defer dbt.mustExec("rm @~/" + stageDir)
		rows := dbt.mustQuery("ls @~/" + stageDir)
		defer func() {
			assertNilF(t, rows.Close())
		}()
		var file, s1, s2, s3 string
		if rows.Next() {
			err := rows.Scan(&file, &s1, &s2, &s3)
			assertNilF(t, err)
		}
		if !strings.Contains(file, fnameStage) {
			t.Fatalf("should contain file. got: %v", file)
		}
		// GET test with stream
		var streamBuf bytes.Buffer
		ctx := WithFileGetStream(context.Background(), &streamBuf)
		sql := fmt.Sprintf("get @~/%v/%v 'file://%v'", stageDir, fnameStage, t.TempDir())
		sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
		rows2 := dbt.mustQueryContext(ctx, sqlText)
		defer func() {
			assertNilF(t, rows2.Close())
		}()
		for rows2.Next() {
			err := rows2.Scan(&file, &s1, &s2, &s3)
			assertNilE(t, err)
			assertTrueE(t, strings.HasPrefix(file, fnameStage), "a file was not downloaded by GET")
			assertEqualE(t, s2, "DOWNLOADED", "did not return DOWNLOADED status")
			assertEqualE(t, s3, "")
		}
		// convert the compressed stream to string
		var contents string
		gz, err := gzip.NewReader(&streamBuf)
		assertNilE(t, err)
		defer func() {
			assertNilF(t, gz.Close())
		}()
		for {
			c := make([]byte, defaultChunkBufferSize)
			if n, err := gz.Read(c); err != nil {
				if err == io.EOF {
					contents = contents + string(c[:n])
					break
				}
				t.Error(err)
			} else {
				contents = contents + string(c[:n])
			}
		}
		// verify the downloaded stream matches the original file
		originalContents, err := os.ReadFile(testData)
		assertNilE(t, err)
		assertEqualF(t, contents, string(originalContents), "data did not match content")
	})
}

// TestPutGetMaxLOBSize exercises PUT/GET with varchar columns at small and
// large LOB sizes. Currently skipped while backend testing is in progress.
func TestPutGetMaxLOBSize(t *testing.T) {
	t.Skip("fails on CI because of backend testing in progress")
	testCases := [2]int{smallSize, largeSize}
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("alter session set ALLOW_LARGE_LOBS_IN_EXTERNAL_SCAN = false")
		defer dbt.mustExec("alter session unset ALLOW_LARGE_LOBS_IN_EXTERNAL_SCAN")
		for _, tc := range testCases {
			t.Run(strconv.Itoa(tc), func(t *testing.T) {
				// create the data file
				tmpDir := t.TempDir()
				fname := filepath.Join(tmpDir, "test_put_get.txt.gz")
				tableName := randomString(5)
				originalContents := fmt.Sprintf("%v,%s,%v\n", randomString(tc), randomString(tc), rand.Intn(100000))
				var b bytes.Buffer
				gzw := gzip.NewWriter(&b)
				_, err := gzw.Write([]byte(originalContents))
				assertNilF(t, err)
				assertNilF(t, gzw.Close())
				err = os.WriteFile(fname, b.Bytes(), readWriteFileMode)
				assertNilF(t, err, "could not write to gzip file")
				dbt.mustExec(fmt.Sprintf("create or replace table %s (c1 varchar, c2 varchar(%v), c3 int)", tableName, tc))
				defer dbt.mustExec("drop table " + tableName)
				fileStream, err := os.Open(fname)
				assertNilF(t, err)
				defer func() {
					assertNilF(t, fileStream.Close())
				}()
				// test PUT command
				var sqlText string
				var rows *RowsExtended
				sql := "put 'file://%v' @%%%v auto_compress=true parallel=30"
				sqlText = fmt.Sprintf(
					sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName)
				rows = dbt.mustQuery(sqlText)
				defer func() {
					assertNilF(t, rows.Close())
				}()
				var s0, s1, s2, s3, s4, s5, s6, s7 string
				assertTrueF(t, rows.Next(), "expected new rows")
				err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7)
				assertNilF(t, err)
				assertEqualF(t, s6, uploaded.String(), fmt.Sprintf("expected %v, got: %v", uploaded, s6))
				assertNilF(t, err)
				// check file is PUT
				dbt.mustQueryAssertCount("ls @%"+tableName, 1)
				dbt.mustExec("copy into " + tableName)
				dbt.mustExec("rm @%" + tableName)
				dbt.mustQueryAssertCount("ls @%"+tableName, 0)
				dbt.mustExec(fmt.Sprintf(`copy into @%%%v from %v file_format=(type=csv compression='gzip')`, tableName, tableName))
				// test GET command
				sql = fmt.Sprintf("get @%%%v 'file://%v' parallel=10", tableName, tmpDir)
				sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
				rows2 := dbt.mustQuery(sqlText)
				defer func() {
					assertNilF(t, rows2.Close())
				}()
				for rows2.Next() {
					err = rows2.Scan(&s0, &s1, &s2, &s3)
					assertNilE(t, err)
					assertTrueF(t, strings.HasPrefix(s0, "data_"), "a file was not downloaded by GET")
					assertEqualE(t, s2, "DOWNLOADED", "did not return DOWNLOADED status")
					assertEqualE(t, s3, "", fmt.Sprintf("returned %v", s3))
				}
				// verify the content in the file
				files, err := filepath.Glob(filepath.Join(tmpDir, "data_*"))
				assertNilF(t, err)
				fileName := files[0]
				f, err := os.Open(fileName)
				assertNilE(t, err)
				defer func() {
					assertNilF(t, f.Close())
				}()
				gz, err := gzip.NewReader(f)
				assertNilE(t, err)
				defer func() {
					assertNilF(t, gz.Close())
				}()
				var contents string
				for {
					c := make([]byte, defaultChunkBufferSize)
					if n, err := gz.Read(c); err != nil {
						if err == io.EOF {
							contents = contents + string(c[:n])
							break
						}
						t.Error(err)
					} else {
						contents = contents + string(c[:n])
					}
				}
				assertEqualE(t, contents, originalContents, "output is different from the original file")
			})
		}
	})
}

// TestPutCancel starts a large (128 MiB) upload in a goroutine, cancels the
// context shortly after, and expects a context.Canceled error back.
func TestPutCancel(t *testing.T) {
	testData := createTempLargeFile(t, 128*1024*1024)
	stageDir := "test_put_cancel_" + randomString(10)
	runDBTest(t, func(dbt *DBTest) {
		c := make(chan error, 1)
		ctx, cancel := context.WithCancel(context.Background())
		go func() {
			// Use a larger, non-compressed single-part upload so cancellation
			// wins reliably even on faster runners.
			_, err := dbt.conn.ExecContext(
				ctx,
				fmt.Sprintf("put 'file://%v' @~/%v overwrite=true auto_compress=false parallel=1",
					strings.ReplaceAll(testData, "\\", "/"), stageDir))
			c <- err
			close(c)
		}()
		time.Sleep(200 * time.Millisecond)
		cancel()
		ret := <-c
		assertNotNilF(t, ret)
		assertErrIsE(t, ret, context.Canceled)
	})
}

func TestPutGetLargeFileNonStream(t *testing.T) {
	testPutGetLargeFile(t, false, true)
}

func TestPutGetLargeFileNonStreamAutoCompressFalse(t *testing.T) {
	testPutGetLargeFile(t, false, false)
}

func TestPutGetLargeFileStream(t *testing.T) {
	testPutGetLargeFile(t, true, true)
}

func TestPutGetLargeFileStreamAutoCompressFalse(t *testing.T) {
	testPutGetLargeFile(t, true, false)
}

// testPutGetLargeFile round-trips a 5 MiB sparse file through a user stage in
// every combination of stream/file mode and auto_compress on/off, verifying
// integrity with SHA-256 checksums and logging PUT memory usage.
func testPutGetLargeFile(t *testing.T, isStream bool, autoCompress bool) {
	var err error
	fname := createTempLargeFile(t, 5*1024*1024)
	baseName := filepath.Base(fname)
	fnameGet := baseName + ".gz"
	if !autoCompress {
		fnameGet = baseName
	}
	runDBTest(t, func(dbt *DBTest) {
		stageDir := "test_put_largefile_" + randomString(10)
		dbt.mustExec("rm @~/" + stageDir)
		ctx := context.Background()
		if isStream {
			f, err := os.Open(fname)
			assertNilF(t, err)
			defer func() {
				assertNilF(t, f.Close())
			}()
			ctx = WithFilePutStream(ctx, f)
		}
		// PUT test
		escapedFname := strings.ReplaceAll(fname, "\\", "\\\\")
		putQuery := fmt.Sprintf("put 'file://%v' @~/%v auto_compress=true overwrite=true", escapedFname, stageDir)
		if !autoCompress {
			putQuery = fmt.Sprintf("put 'file://%v' @~/%v auto_compress=false overwrite=true", escapedFname, stageDir)
		}
		// Record initial memory stats before PUT
		var startMemStats, endMemStats runtime.MemStats
		runtime.ReadMemStats(&startMemStats)
		// Execute PUT command
		_ = dbt.mustExecContext(ctx, putQuery)
		// Record memory stats after PUT
		runtime.ReadMemStats(&endMemStats)
		fmt.Printf("Memory used for PUT command: %d MB\n", (endMemStats.Alloc-startMemStats.Alloc)/1024/1024)
		defer dbt.mustExec("rm @~/" + stageDir)
		rows := dbt.mustQuery("ls @~/" + stageDir)
		defer func() {
			assertNilF(t, rows.Close())
		}()
		var file, s1, s2, s3 sql.NullString
		if rows.Next() {
			err = rows.Scan(&file, &s1, &s2, &s3)
			assertNilF(t, err)
		}
		if !strings.Contains(file.String, fnameGet) {
			t.Fatalf("should contain file. got: %v", file.String)
		}
		// GET test
		var streamBuf bytes.Buffer
		ctx = context.Background()
		if isStream {
			ctx = WithFileGetStream(ctx, &streamBuf)
		}
		tmpDir := t.TempDir()
		tmpDirURL := strings.ReplaceAll(tmpDir, "\\", "/")
		sql := fmt.Sprintf("get @~/%v/%v 'file://%v'", stageDir, fnameGet, tmpDirURL)
		sqlText := strings.ReplaceAll(sql, "\\", "\\\\")
		rows2 := dbt.mustQueryContext(ctx, sqlText)
		defer func() {
			assertNilF(t, rows2.Close())
		}()
		for rows2.Next() {
			err = rows2.Scan(&file, &s1, &s2, &s3)
			assertNilE(t, err)
			assertTrueE(t, strings.HasPrefix(file.String, fnameGet), "a file was not downloaded by GET")
			assertEqualE(t, s2.String, "DOWNLOADED", "did not return DOWNLOADED status")
			assertEqualE(t, s3.String, "")
		}
		// Pick the reader that matches the download mode/compression combo.
		var r io.Reader
		if autoCompress {
			// convert the compressed contents to string
			if isStream {
				r, err = gzip.NewReader(&streamBuf)
				assertNilE(t, err)
			} else {
				downloadedFile := filepath.Join(tmpDir, fnameGet)
				f, err := os.Open(downloadedFile)
				assertNilE(t, err)
				defer func() {
					assertNilF(t, f.Close())
				}()
				r, err = gzip.NewReader(f)
				assertNilE(t, err)
			}
		} else {
			if isStream {
				r = bytes.NewReader(streamBuf.Bytes())
			} else {
				downloadedFile := filepath.Join(tmpDir, fnameGet)
				f, err := os.Open(downloadedFile)
				assertNilE(t, err)
				defer func() {
					assertNilF(t, f.Close())
				}()
				r = bufio.NewReader(f)
			}
		}
		hash := sha256.New()
		_, err = io.Copy(hash, r)
		assertNilE(t, err)
		downloadedChecksum := fmt.Sprintf("%x", hash.Sum(nil))
		originalFile, err := os.Open(fname)
		assertNilF(t, err)
		defer func() {
			assertNilF(t, originalFile.Close())
		}()
		originalHash := sha256.New()
		_, err = io.Copy(originalHash, originalFile)
		assertNilE(t, err)
		originalChecksum := fmt.Sprintf("%x", originalHash.Sum(nil))
		assertEqualF(t, downloadedChecksum, originalChecksum, "file integrity check failed - checksums don't match")
	})
}

// createTempLargeFile creates a sparse file of sizeBytes in t.TempDir().
// The file is grown with Truncate, so no I/O is needed; sparse-file-capable
// filesystems allocate no real disk space. The extended region reads back as
// zero bytes, which is sufficient for PUT/GET round-trip tests.
func createTempLargeFile(t *testing.T, sizeBytes int64) string { t.Helper() tmpFile, err := os.CreateTemp(t.TempDir(), "large_test_*.bin") assertNilF(t, err, "creating temp large file") assertNilF(t, tmpFile.Truncate(sizeBytes), fmt.Sprintf("truncating temp file to %d bytes", sizeBytes)) assertNilF(t, tmpFile.Close(), "closing temp large file") return tmpFile.Name() } ================================================ FILE: put_get_user_stage_test.go ================================================ package gosnowflake import ( "context" "fmt" "os" "path/filepath" "strconv" "strings" "testing" ) func TestPutGetFileSmallDataViaUserStage(t *testing.T) { if os.Getenv("AWS_ACCESS_KEY_ID") == "" { t.Skip("this test requires to change the internal parameter") } putGetUserStage(t, 5, 1, false) } func TestPutGetStreamSmallDataViaUserStage(t *testing.T) { if os.Getenv("AWS_ACCESS_KEY_ID") == "" { t.Skip("this test requires to change the internal parameter") } putGetUserStage(t, 1, 1, true) } func putGetUserStage(t *testing.T, numberOfFiles int, numberOfLines int, isStream bool) { if os.Getenv("AWS_SECRET_ACCESS_KEY") == "" { t.Fatal("no aws secret access key found") } tmpDir, err := generateKLinesOfNFiles(numberOfLines, numberOfFiles, false, t.TempDir()) if err != nil { t.Error(err) } var files string if isStream { list, err := os.ReadDir(tmpDir) if err != nil { t.Error(err) } file := list[0].Name() files = filepath.Join(tmpDir, file) } else { files = filepath.Join(tmpDir, "file*") } runDBTest(t, func(dbt *DBTest) { stageName := fmt.Sprintf("%v_stage_%v_%v", dbname, numberOfFiles, numberOfLines) sqlText := `create or replace table %v (aa int, dt date, ts timestamp, tsltz timestamp_ltz, tsntz timestamp_ntz, tstz timestamp_tz, pct float, ratio number(6,2))` dbt.mustExec(fmt.Sprintf(sqlText, dbname)) userBucket := os.Getenv("SF_AWS_USER_BUCKET") if userBucket == "" { userBucket = fmt.Sprintf("sfc-eng-regression/%v/reg", username) } sqlText = `create or replace stage %v 
url='s3://%v/%v-%v-%v' credentials = (AWS_KEY_ID='%v' AWS_SECRET_KEY='%v')`
		// Bug fix: the template previously read "s3://%v}/..." — the stray "}"
		// after the bucket placeholder produced an invalid stage URL
		// ("s3://<bucket>}/...") and broke the external stage creation.
		dbt.mustExec(fmt.Sprintf(sqlText, stageName, userBucket, stageName,
			numberOfFiles, numberOfLines, os.Getenv("AWS_ACCESS_KEY_ID"),
			os.Getenv("AWS_SECRET_ACCESS_KEY")))
		dbt.mustExec("alter session set disable_put_and_get_on_external_stage = false")
		dbt.mustExec("rm @" + stageName)
		var fs *os.File
		if isStream {
			// The Open error was previously discarded with "fs, _ ="; a failure
			// here would hand a nil *os.File to the PUT stream and panic later.
			var openErr error
			fs, openErr = os.Open(files)
			assertNilF(t, openErr)
			dbt.mustExecContext(WithFilePutStream(context.Background(), fs),
				fmt.Sprintf("put 'file://%v' @%v", strings.ReplaceAll(
					files, "\\", "\\\\"), stageName))
		} else {
			dbt.mustExec(fmt.Sprintf("put 'file://%v' @%v ",
				strings.ReplaceAll(files, "\\", "\\\\"), stageName))
		}
		defer func() {
			// Guard against a nil handle so cleanup itself cannot panic.
			if isStream && fs != nil {
				assertNilE(t, fs.Close())
			}
			dbt.mustExec("rm @" + stageName)
			dbt.mustExec("drop stage if exists " + stageName)
			dbt.mustExec("drop table if exists " + dbname)
		}()
		dbt.mustExec(fmt.Sprintf("copy into %v from @%v", dbname, stageName))
		rows := dbt.mustQuery("select count(*) from " + dbname)
		defer func() {
			assertNilF(t, rows.Close())
		}()
		var cnt string
		if rows.Next() {
			assertNilF(t, rows.Scan(&cnt))
		}
		count, err := strconv.Atoi(cnt)
		if err != nil {
			t.Error(err)
		}
		// Every line of every uploaded file should have been copied in.
		if count != numberOfFiles*numberOfLines {
			t.Errorf("count did not match expected number. 
count: %v, expected: %v", count, numberOfFiles*numberOfLines) } }) } func TestPutLoadFromUserStage(t *testing.T) { runDBTest(t, func(dbt *DBTest) { data, err := createTestData(dbt) if err != nil { t.Skip("snowflake admin account not accessible") } defer cleanupPut(dbt, data) dbt.mustExec("alter session set DISABLE_PUT_AND_GET_ON_EXTERNAL_STAGE=false") dbt.mustExec("use warehouse " + data.warehouse) dbt.mustExec("use schema " + data.database + ".gotesting_schema") execQuery := fmt.Sprintf( `create or replace stage %v url = 's3://%v/%v' credentials = ( AWS_KEY_ID='%v' AWS_SECRET_KEY='%v')`, data.stage, data.userBucket, data.stage, data.awsAccessKeyID, data.awsSecretAccessKey) dbt.mustExec(execQuery) execQuery = `create or replace table gotest_putget_t2 (c1 STRING, c2 STRING, c3 STRING,c4 STRING, c5 STRING, c6 STRING, c7 STRING, c8 STRING, c9 STRING)` dbt.mustExec(execQuery) defer dbt.mustExec("drop table if exists gotest_putget_t2") defer dbt.mustExec("drop stage if exists " + data.stage) execQuery = fmt.Sprintf("put file://%v/test_data/orders_10*.csv @%v", data.dir, data.stage) dbt.mustExec(execQuery) dbt.mustQueryAssertCount("ls @%gotest_putget_t2", 0) rows := dbt.mustQuery(fmt.Sprintf(`copy into gotest_putget_t2 from @%v file_format = (field_delimiter = '|' error_on_column_count_mismatch =false) purge=true`, data.stage)) defer func() { assertNilF(t, rows.Close()) }() var s0, s1, s2, s3, s4, s5 string var s6, s7, s8, s9 any orders100 := fmt.Sprintf("s3://%v/%v/orders_100.csv.gz", data.userBucket, data.stage) orders101 := fmt.Sprintf("s3://%v/%v/orders_101.csv.gz", data.userBucket, data.stage) for rows.Next() { assertNilF(t, rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9)) if s0 != orders100 && s0 != orders101 { t.Fatalf("copy did not load orders files. 
got: %v", s0) } } dbt.mustQueryAssertCount(fmt.Sprintf("ls @%v", data.stage), 0) }) } ================================================ FILE: put_get_with_aws_test.go ================================================ package gosnowflake import ( "bytes" "compress/gzip" "context" "database/sql" "encoding/json" "fmt" "io" "net/url" "os" "path/filepath" "strconv" "strings" "testing" "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" ) func TestLoadS3(t *testing.T) { if runningOnGithubAction() && !runningOnAWS() { t.Skip("skipping non aws environment") } runDBTest(t, func(dbt *DBTest) { data, err := createTestData(dbt) if err != nil { t.Skip("snowflake admin account not accessible") } defer cleanupPut(dbt, data) dbt.mustExec("use warehouse " + data.warehouse) dbt.mustExec("use schema " + data.database + ".gotesting_schema") execQuery := `create or replace table tweets(created_at timestamp, id number, id_str string, text string, source string, in_reply_to_status_id number, in_reply_to_status_id_str string, in_reply_to_user_id number, in_reply_to_user_id_str string, in_reply_to_screen_name string, user__id number, user__id_str string, user__name string, user__screen_name string, user__location string, user__description string, user__url string, user__entities__description__urls string, user__protected string, user__followers_count number, user__friends_count number, user__listed_count number, user__created_at timestamp, user__favourites_count number, user__utc_offset number, user__time_zone string, user__geo_enabled string, user__verified string, user__statuses_count number, user__lang string, user__contributors_enabled string, user__is_translator string, user__profile_background_color string, user__profile_background_image_url string, user__profile_background_image_url_https string, user__profile_background_tile string, user__profile_image_url string, user__profile_image_url_https string, user__profile_link_color string, 
user__profile_sidebar_border_color string, user__profile_sidebar_fill_color string, user__profile_text_color string, user__profile_use_background_image string, user__default_profile string, user__default_profile_image string, user__following string, user__follow_request_sent string, user__notifications string, geo string, coordinates string, place string, contributors string, retweet_count number, favorite_count number, entities__hashtags string, entities__symbols string, entities__urls string, entities__user_mentions string, favorited string, retweeted string, lang string)` dbt.mustExec(execQuery) defer dbt.mustExec("drop table if exists tweets") dbt.mustQueryAssertCount("ls @%tweets", 0) rows := dbt.mustQuery(fmt.Sprintf(`copy into tweets from s3://sfc-eng-data/twitter/O1k/tweets/ credentials=(AWS_KEY_ID='%v' AWS_SECRET_KEY='%v') file_format=(skip_header=1 null_if=('') field_optionally_enclosed_by='\"')`, data.awsAccessKeyID, data.awsSecretAccessKey)) defer func() { assertNilF(t, rows.Close()) }() var s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 sql.NullString cnt := 0 for rows.Next() { assertNilF(t, rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9)) cnt++ } if cnt != 1 { t.Fatal("copy into tweets did not set row count to 1") } if !s0.Valid || s0.String != "s3://sfc-eng-data/twitter/O1k/tweets/1.csv.gz" { t.Fatalf("got %v as file", s0) } }) } func TestPutWithInvalidToken(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { if !runningOnAWS() { t.Skip("skipping non aws environment") } tmpDir := t.TempDir() fname := filepath.Join(tmpDir, "test_put_get_with_aws.txt.gz") originalContents := "123,test1\n456,test2\n" var b bytes.Buffer gzw := gzip.NewWriter(&b) _, err := gzw.Write([]byte(originalContents)) assertNilF(t, err) assertNilF(t, gzw.Close()) if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { t.Fatal("could not write to gzip file") } tableName := randomString(5) sct.mustExec("create or replace table "+tableName+" (a int, b 
string)", nil) defer sct.mustExec("drop table "+tableName, nil) jsonBody, err := json.Marshal(execRequest{ SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName), }) if err != nil { t.Error(err) } headers := getHeaders() headers[httpHeaderAccept] = headerContentTypeApplicationJSON data, err := sct.sc.rest.FuncPostQuery( sct.sc.ctx, sct.sc.rest, &url.Values{}, headers, jsonBody, sct.sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sct.sc.ctx), sct.sc.cfg) if err != nil { t.Fatal(err) } s3Util := new(snowflakeS3Client) s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } client := s3Cli.(*s3.Client) s3Loc, err := s3Util.extractBucketNameAndPath(data.Data.StageInfo.Location) if err != nil { t.Error(err) } s3Path := s3Loc.s3Path + baseName(fname) + ".gz" f, err := os.Open(fname) if err != nil { t.Error(err) } defer func() { assertNilF(t, f.Close()) }() uploader := manager.NewUploader(client) if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ Bucket: &s3Loc.bucketName, Key: &s3Path, Body: f, }); err != nil { t.Fatal(err) } parentPath := filepath.Dir(filepath.Dir(s3Path)) + "/" if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ Bucket: &s3Loc.bucketName, Key: &parentPath, Body: f, }); err == nil { t.Fatal("should have failed attempting to put file in parent path") } info := execResponseStageInfo{ Creds: execResponseCredentials{ AwsID: data.Data.StageInfo.Creds.AwsID, AwsSecretKey: data.Data.StageInfo.Creds.AwsSecretKey, }, } s3Cli, err = s3Util.createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } client = s3Cli.(*s3.Client) uploader = manager.NewUploader(client) if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ Bucket: &s3Loc.bucketName, Key: &s3Path, Body: f, }); err == nil { t.Fatal("should have failed attempting to put with missing aws token") } }) } func TestPretendToPutButList(t *testing.T) { if 
runningOnGithubAction() && !runningOnAWS() { t.Skip("skipping non aws environment") } tmpDir := t.TempDir() fname := filepath.Join(tmpDir, "test_put_get_with_aws.txt.gz") originalContents := "123,test1\n456,test2\n" var b bytes.Buffer gzw := gzip.NewWriter(&b) _, err := gzw.Write([]byte(originalContents)) assertNilF(t, err) assertNilF(t, gzw.Close()) if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { t.Fatal("could not write to gzip file") } runSnowflakeConnTest(t, func(sct *SCTest) { tableName := randomString(5) sct.mustExec("create or replace table "+tableName+ " (a int, b string)", nil) defer sct.mustExec("drop table "+tableName, nil) jsonBody, err := json.Marshal(execRequest{ SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName), }) if err != nil { t.Error(err) } headers := getHeaders() headers[httpHeaderAccept] = headerContentTypeApplicationJSON data, err := sct.sc.rest.FuncPostQuery( sct.sc.ctx, sct.sc.rest, &url.Values{}, headers, jsonBody, sct.sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sct.sc.ctx), sct.sc.cfg) if err != nil { t.Fatal(err) } s3Util := new(snowflakeS3Client) s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } client := s3Cli.(*s3.Client) if _, err = client.ListBuckets(context.Background(), &s3.ListBucketsInput{}); err == nil { t.Fatal("list buckets should fail") } }) } func TestPutGetAWSStage(t *testing.T) { if runningOnGithubAction() || !runningOnAWS() { t.Skip("skipping non aws environment") } tmpDir := t.TempDir() name := "test_put_get.txt.gz" fname := filepath.Join(tmpDir, name) originalContents := "123,test1\n456,test2\n" stageName := "test_put_get_stage_" + randomString(5) var b bytes.Buffer gzw := gzip.NewWriter(&b) _, err := gzw.Write([]byte(originalContents)) assertNilF(t, err) assertNilF(t, gzw.Close()) if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { t.Fatal("could not write to gzip file") } 
runDBTest(t, func(dbt *DBTest) {
		keyID, secretKey, _, err := getAWSCredentials()
		if err != nil {
			t.Skip("snowflake admin account not accessible")
		}
		createStageQuery := fmt.Sprintf(createStageStmt, stageName,
			"s3://"+stageName,
			fmt.Sprintf("AWS_KEY_ID='%v' AWS_SECRET_KEY='%v'", keyID, secretKey))
		dbt.mustExec(createStageQuery)
		defer dbt.mustExec("DROP STAGE IF EXISTS " + stageName)

		// PUT the local gzip file onto the user stage (no extra compression).
		putSQL := fmt.Sprintf("put 'file://%v' @~/%v auto_compress=false",
			strings.ReplaceAll(fname, "\\", "\\\\"), stageName)
		putRows := dbt.mustQuery(putSQL)
		defer func() {
			assertNilF(t, putRows.Close())
		}()
		var s0, s1, s2, s3, s4, s5, s6, s7 string
		if putRows.Next() {
			if err = putRows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
				t.Fatal(err)
			}
		}
		if s6 != uploaded.String() {
			t.Fatalf("expected %v, got: %v", uploaded, s6)
		}

		// GET the file back into tmpDir. Bug fix: the original reused the
		// single "rows" variable for both result sets while both deferred
		// closures captured that same variable, so at function exit the
		// second result set was closed twice and the first one leaked.
		// Distinct variables give each defer its own result set.
		getSQL := strings.ReplaceAll(
			fmt.Sprintf("get @~/%v 'file://%v'", stageName, tmpDir), "\\", "\\\\")
		getRows := dbt.mustQuery(getSQL)
		defer func() {
			assertNilF(t, getRows.Close())
		}()
		for getRows.Next() {
			if err = getRows.Scan(&s0, &s1, &s2, &s3); err != nil {
				t.Error(err)
			}
			if strings.Compare(s0, name) != 0 {
				t.Error("a file was not downloaded by GET")
			}
			// 41 bytes is the known size of the gzipped test payload.
			if v, err := strconv.Atoi(s1); err != nil || v != 41 {
				t.Error("did not return the right file size")
			}
			if s2 != "DOWNLOADED" {
				t.Error("did not return DOWNLOADED status")
			}
			if s3 != "" {
				t.Errorf("returned %v", s3)
			}
		}
		// Decompress the downloaded file and compare it to the original
		// payload. t.Fatal (not t.Error) below: continuing past a failed
		// Glob/Open/NewReader would dereference a nil handle and panic.
		files, err := filepath.Glob(filepath.Join(tmpDir, "*"))
		if err != nil {
			t.Fatal(err)
		}
		f, err := os.Open(files[0])
		if err != nil {
			t.Fatal(err)
		}
		defer func() {
			assertNilF(t, f.Close())
		}()
		gz, err := gzip.NewReader(f)
		if err != nil {
			t.Fatal(err)
		}
		decompressed, err := io.ReadAll(gz)
		if err != nil {
			t.Fatal(err)
		}
		if string(decompressed) != originalContents {
			t.Error("output is different from the 
original file") } }) } ================================================ FILE: query.go ================================================ package gosnowflake import ( "encoding/json" "github.com/snowflakedb/gosnowflake/v2/internal/query" "time" ) type resultFormat string const ( jsonFormat resultFormat = "json" arrowFormat resultFormat = "arrow" ) type execBindParameter struct { Type string `json:"type"` Value any `json:"value"` Format string `json:"fmt,omitempty"` Schema *bindingSchema `json:"schema,omitempty"` } type execRequest struct { SQLText string `json:"sqlText"` AsyncExec bool `json:"asyncExec"` SequenceID uint64 `json:"sequenceId"` IsInternal bool `json:"isInternal"` DescribeOnly bool `json:"describeOnly,omitempty"` Parameters map[string]any `json:"parameters,omitempty"` Bindings map[string]execBindParameter `json:"bindings,omitempty"` BindStage string `json:"bindStage,omitempty"` QueryContext requestQueryContext `json:"queryContextDTO"` } type requestQueryContext struct { Entries []requestQueryContextEntry `json:"entries,omitempty"` } type requestQueryContextEntry struct { Context contextData `json:"context"` ID int `json:"id"` Priority int `json:"priority"` Timestamp int64 `json:"timestamp,omitempty"` } type contextData struct { Base64Data string `json:"base64Data,omitempty"` } type execResponseCredentials struct { AwsKeyID string `json:"AWS_KEY_ID,omitempty"` AwsSecretKey string `json:"AWS_SECRET_KEY,omitempty"` AwsToken string `json:"AWS_TOKEN,omitempty"` AwsID string `json:"AWS_ID,omitempty"` AwsKey string `json:"AWS_KEY,omitempty"` AzureSasToken string `json:"AZURE_SAS_TOKEN,omitempty"` GcsAccessToken string `json:"GCS_ACCESS_TOKEN,omitempty"` } type execResponseStageInfo struct { LocationType string `json:"locationType,omitempty"` Location string `json:"location,omitempty"` Path string `json:"path,omitempty"` Region string `json:"region,omitempty"` StorageAccount string `json:"storageAccount,omitempty"` IsClientSideEncrypted bool 
`json:"isClientSideEncrypted,omitempty"` Creds execResponseCredentials `json:"creds"` PresignedURL string `json:"presignedUrl,omitempty"` EndPoint string `json:"endPoint,omitempty"` UseS3RegionalURL bool `json:"useS3RegionalUrl,omitempty"` UseRegionalURL bool `json:"useRegionalUrl,omitempty"` UseVirtualURL bool `json:"useVirtualUrl,omitempty"` } // make all data field optional type execResponseData struct { // succeed query response data Parameters []nameValueParameter `json:"parameters,omitempty"` RowType []query.ExecResponseRowType `json:"rowtype,omitempty"` RowSet [][]*string `json:"rowset,omitempty"` RowSetBase64 string `json:"rowsetbase64,omitempty"` Total int64 `json:"total,omitempty"` // java:long Returned int64 `json:"returned,omitempty"` // java:long QueryID string `json:"queryId,omitempty"` SQLState string `json:"sqlState,omitempty"` DatabaseProvider string `json:"databaseProvider,omitempty"` FinalDatabaseName string `json:"finalDatabaseName,omitempty"` FinalSchemaName string `json:"finalSchemaName,omitempty"` FinalWarehouseName string `json:"finalWarehouseName,omitempty"` FinalRoleName string `json:"finalRoleName,omitempty"` NumberOfBinds int `json:"numberOfBinds,omitempty"` // java:int StatementTypeID int64 `json:"statementTypeId,omitempty"` // java:long Version int64 `json:"version,omitempty"` // java:long Chunks []query.ExecResponseChunk `json:"chunks,omitempty"` Qrmk string `json:"qrmk,omitempty"` ChunkHeaders map[string]string `json:"chunkHeaders,omitempty"` // ping pong response data GetResultURL string `json:"getResultUrl,omitempty"` ProgressDesc string `json:"progressDesc,omitempty"` QueryAbortTimeout time.Duration `json:"queryAbortsAfterSecs,omitempty"` ResultIDs string `json:"resultIds,omitempty"` ResultTypes string `json:"resultTypes,omitempty"` QueryResultFormat string `json:"queryResultFormat,omitempty"` // async response placeholders AsyncResult *snowflakeResult `json:"asyncResult,omitempty"` AsyncRows *snowflakeRows 
`json:"asyncRows,omitempty"` // file transfer response data UploadInfo execResponseStageInfo `json:"uploadInfo"` LocalLocation string `json:"localLocation,omitempty"` SrcLocations []string `json:"src_locations,omitempty"` Parallel int64 `json:"parallel,omitempty"` Threshold int64 `json:"threshold,omitempty"` AutoCompress bool `json:"autoCompress,omitempty"` Overwrite bool `json:"overwrite,omitempty"` SourceCompression string `json:"sourceCompression,omitempty"` ShowEncryptionParameter bool `json:"clientShowEncryptionParameter,omitempty"` EncryptionMaterial encryptionWrapper `json:"encryptionMaterial"` PresignedURLs []string `json:"presignedUrls,omitempty"` StageInfo execResponseStageInfo `json:"stageInfo"` Command string `json:"command,omitempty"` Kind string `json:"kind,omitempty"` Operation string `json:"operation,omitempty"` // HTAP QueryContext json.RawMessage `json:"queryContext,omitempty"` } type execResponse struct { Data execResponseData `json:"Data"` Message string `json:"message"` Code string `json:"code"` Success bool `json:"success"` } ================================================ FILE: restful.go ================================================ package gosnowflake import ( "context" "encoding/json" "errors" "fmt" errors2 "github.com/snowflakedb/gosnowflake/v2/internal/errors" "io" "net/http" "net/url" "strconv" "time" ) // HTTP headers const ( headerSnowflakeToken = "Snowflake Token=\"%v\"" headerAuthorizationKey = "Authorization" headerContentTypeApplicationJSON = "application/json" headerAcceptTypeApplicationSnowflake = "application/snowflake" ) // Snowflake Server Endpoints const ( loginRequestPath = "/session/v1/login-request" queryRequestPath = "/queries/v1/query-request" tokenRequestPath = "/session/token-request" abortRequestPath = "/queries/v1/abort-request" authenticatorRequestPath = "/session/authenticator-request" monitoringQueriesPath = "/monitoring/queries" sessionRequestPath = "/session" heartBeatPath = "/session/heartbeat" 
consoleLoginRequestPath = "/console/login" ) type ( funcGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error) funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider, *Config) (*http.Response, error) funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, int) (*http.Response, error) bodyCreatorType func() ([]byte, error) ) var emptyBodyCreator = func() ([]byte, error) { return []byte{}, nil } type snowflakeRestful struct { Host string Port int Protocol string LoginTimeout time.Duration // Login timeout RequestTimeout time.Duration // request timeout MaxRetryCount int Client *http.Client JWTClient *http.Client TokenAccessor TokenAccessor HeartBeat *heartbeat Connection *snowflakeConn FuncPostQuery func(context.Context, *snowflakeRestful, *url.Values, map[string]string, []byte, time.Duration, UUID, *Config) (*execResponse, error) FuncPostQueryHelper func(context.Context, *snowflakeRestful, *url.Values, map[string]string, []byte, time.Duration, UUID, *Config) (*execResponse, error) FuncPost funcPostType FuncGet funcGetType FuncAuthPost funcAuthPostType FuncRenewSession func(context.Context, *snowflakeRestful, time.Duration) error FuncCloseSession func(context.Context, *snowflakeRestful, time.Duration) error FuncCancelQuery func(context.Context, *snowflakeRestful, UUID, time.Duration) error FuncPostAuth func(context.Context, *snowflakeRestful, *http.Client, *url.Values, map[string]string, bodyCreatorType, time.Duration) (*authResponse, error) FuncPostAuthSAML func(context.Context, *snowflakeRestful, map[string]string, []byte, time.Duration) (*authResponse, error) FuncPostAuthOKTA func(context.Context, *snowflakeRestful, map[string]string, []byte, string, time.Duration) (*authOKTAResponse, error) FuncGetSSO func(context.Context, *snowflakeRestful, *url.Values, map[string]string, 
string, time.Duration) ([]byte, error) } func (sr *snowflakeRestful) getURL() *url.URL { return &url.URL{ Scheme: sr.Protocol, Host: sr.Host + ":" + strconv.Itoa(sr.Port), } } func (sr *snowflakeRestful) getFullURL(path string, params *url.Values) *url.URL { ret := &url.URL{ Scheme: sr.Protocol, Host: sr.Host + ":" + strconv.Itoa(sr.Port), Path: path, } if params != nil { ret.RawQuery = params.Encode() } return ret } // We need separate client for JWT, because if token processing takes too long, token may be already expired. func (sr *snowflakeRestful) getClientFor(authType AuthType) *http.Client { switch authType { case AuthTypeJwt: return sr.JWTClient default: return sr.Client } } // Renew the snowflake session if the current token is still the stale token specified func (sr *snowflakeRestful) renewExpiredSessionToken(ctx context.Context, timeout time.Duration, expiredToken string) error { err := sr.TokenAccessor.Lock() if err != nil { return err } defer sr.TokenAccessor.Unlock() currentToken, _, _ := sr.TokenAccessor.GetTokens() if expiredToken == currentToken || currentToken == "" { // Only renew the session if the current token is still the expired token or current token is empty return sr.FuncRenewSession(ctx, sr, timeout) } return nil } type renewSessionResponse struct { Data renewSessionResponseMain `json:"data"` Message string `json:"message"` Code string `json:"code"` Success bool `json:"success"` } type renewSessionResponseMain struct { SessionToken string `json:"sessionToken"` ValidityInSecondsST time.Duration `json:"validityInSecondsST"` MasterToken string `json:"masterToken"` ValidityInSecondsMT time.Duration `json:"validityInSecondsMT"` SessionID int64 `json:"sessionId"` } type cancelQueryResponse struct { Data any `json:"data"` Message string `json:"message"` Code string `json:"code"` Success bool `json:"success"` } type telemetryResponse struct { Data any `json:"data,omitempty"` Message string `json:"message"` Code string `json:"code"` Success bool 
`json:"success"` Headers map[string]string `json:"headers,omitempty"` } func postRestful( ctx context.Context, sr *snowflakeRestful, fullURL *url.URL, headers map[string]string, body []byte, timeout time.Duration, currentTimeProvider currentTimeProvider, cfg *Config) ( *http.Response, error) { return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, sr.MaxRetryCount, currentTimeProvider, cfg). doPost(). setBody(body). execute() } func getRestful( ctx context.Context, sr *snowflakeRestful, fullURL *url.URL, headers map[string]string, timeout time.Duration) ( *http.Response, error) { return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, sr.MaxRetryCount, defaultTimeProvider, nil).execute() } func postAuthRestful( ctx context.Context, client *http.Client, fullURL *url.URL, headers map[string]string, bodyCreator bodyCreatorType, timeout time.Duration, maxRetryCount int) ( *http.Response, error) { return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, maxRetryCount, defaultTimeProvider, nil). doPost(). setBodyCreator(bodyCreator). execute() } func postRestfulQuery( ctx context.Context, sr *snowflakeRestful, params *url.Values, headers map[string]string, body []byte, timeout time.Duration, requestID UUID, cfg *Config) ( data *execResponse, err error) { data, err = sr.FuncPostQueryHelper(ctx, sr, params, headers, body, timeout, requestID, cfg) if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { // For context cancel/timeout cases, a special cancel request needs to be sent. if cancelErr := sr.FuncCancelQuery(context.Background(), sr, requestID, timeout); cancelErr != nil { // Wrap the original error with the cancel error. err = fmt.Errorf("failed to cancel query. 
cancelErr: %w, queryErr: %w", cancelErr, err) } } return data, err } func postRestfulQueryHelper( ctx context.Context, sr *snowflakeRestful, params *url.Values, headers map[string]string, body []byte, timeout time.Duration, requestID UUID, cfg *Config) ( data *execResponse, err error) { logger.WithContext(ctx).Infof("params: %v", params) params.Set(requestIDKey, requestID.String()) params.Set(requestGUIDKey, NewUUID().String()) token, _, _ := sr.TokenAccessor.GetTokens() if token != "" { headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) } var resp *http.Response fullURL := sr.getFullURL(queryRequestPath, params) logger.WithContext(ctx).Infof("postQuery: make a request to Host: %v, Path: %v", fullURL.Host, fullURL.Path) resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider, cfg) if err != nil { return nil, err } defer func(resp *http.Response, url string) { if closeErr := resp.Body.Close(); closeErr != nil { logger.WithContext(ctx).Warnf("failed to close response body for %v. err: %v", url, closeErr) } }(resp, fullURL.String()) if resp.StatusCode == http.StatusOK { respd := &execResponse{} if err = json.NewDecoder(resp.Body).Decode(respd); err != nil { logger.WithContext(ctx).Errorf("failed to decode JSON. 
err: %v", err) return nil, err } if respd.Code == sessionExpiredCode { if err = sr.renewExpiredSessionToken(ctx, timeout, token); err != nil { return nil, err } return sr.FuncPostQuery(ctx, sr, params, headers, body, timeout, requestID, cfg) } if queryIDChan := getQueryIDChan(ctx); queryIDChan != nil { queryIDChan <- respd.Data.QueryID close(queryIDChan) ctx = WithQueryIDChan(ctx, nil) } isSessionRenewed := false // if asynchronous query in progress, kick off retrieval but return object if respd.Code == queryInProgressAsyncCode && isAsyncMode(ctx) { return sr.processAsync(ctx, respd, headers, timeout, cfg) } for isSessionRenewed || respd.Code == queryInProgressCode || respd.Code == queryInProgressAsyncCode { if !isSessionRenewed { fullURL = sr.getFullURL(respd.Data.GetResultURL, nil) } logger.WithContext(ctx).Info("ping pong") token, _, _ = sr.TokenAccessor.GetTokens() headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) respd, err = getExecResponse(ctx, sr, fullURL, headers, timeout) if err != nil { return nil, err } if respd.Code == sessionExpiredCode { if err = sr.renewExpiredSessionToken(ctx, timeout, token); err != nil { return nil, err } isSessionRenewed = true } else { isSessionRenewed = false } } return respd, nil } b, err := io.ReadAll(resp.Body) if err != nil { logger.WithContext(ctx).Errorf("failed to extract HTTP response body. 
err: %v", err)
		return nil, err
	}
	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
	return nil, &SnowflakeError{
		Number:      ErrFailedToPostQuery,
		SQLState:    SQLStateConnectionFailure,
		Message:     errors2.ErrMsgFailedToPostQuery,
		MessageArgs: []any{resp.StatusCode, fullURL},
	}
}

// closeSession terminates the current Snowflake session by issuing a
// "delete" POST against the session endpoint. The POST itself uses a fixed
// 5-second timeout since session teardown should be quick; the timeout
// parameter is kept for signature compatibility with the other session
// helpers. An already-expired session is treated as successfully closed.
func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Duration) error {
	logger.WithContext(ctx).Info("close session")
	params := &url.Values{}
	params.Set("delete", "true")
	params.Set(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
	params.Set(requestGUIDKey, NewUUID().String())
	fullURL := sr.getFullURL(sessionRequestPath, params)

	headers := getHeaders()
	token, _, _ := sr.TokenAccessor.GetTokens()
	headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)

	resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, defaultTimeProvider, nil)
	if err != nil {
		return err
	}
	defer func() {
		// use a dedicated variable so closing the body never clobbers err
		// (consistent with getExecResponse)
		if closeErr := resp.Body.Close(); closeErr != nil {
			logger.WithContext(ctx).Warnf("failed to close response body for %v. err: %v", fullURL, closeErr)
		}
	}()
	if resp.StatusCode == http.StatusOK {
		var respd renewSessionResponse
		if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil {
			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
			return err
		}
		// sessionExpiredCode means there is nothing left to close, so only
		// other failure codes are surfaced as errors
		if !respd.Success && respd.Code != sessionExpiredCode {
			c, err := strconv.Atoi(respd.Code)
			if err != nil {
				return err
			}
			return &SnowflakeError{
				Number:  c,
				Message: respd.Message,
			}
		}
		return nil
	}
	b, err := io.ReadAll(resp.Body)
	if err != nil {
		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
		return err
	}
	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
	return &SnowflakeError{
		Number:      ErrFailedToCloseSession,
		SQLState:    SQLStateConnectionFailure,
		Message:     errors2.ErrMsgFailedToCloseSession,
		MessageArgs: []any{resp.StatusCode, fullURL},
	}
}

// renewRestfulSession exchanges the current session token for a fresh one
// via the token endpoint and stores the renewed tokens in the TokenAccessor.
// The request is authorized with the master token, not the session token.
func renewRestfulSession(ctx context.Context, sr *snowflakeRestful, timeout time.Duration) error {
	params := &url.Values{}
	params.Set(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
	params.Set(requestGUIDKey, NewUUID().String())
	fullURL := sr.getFullURL(tokenRequestPath, params)

	token, masterToken, sessionID := sr.TokenAccessor.GetTokens()
	headers := getHeaders()
	headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, masterToken)

	body := make(map[string]string)
	body["oldSessionToken"] = token
	body["requestType"] = "RENEW"

	ctx = context.WithValue(ctx, SFSessionIDKey, sessionID)
	logger.WithContext(ctx).Info("start renew session")

	var reqBody []byte
	reqBody, err := json.Marshal(body)
	if err != nil {
		return err
	}

	resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, defaultTimeProvider, nil)
	if err != nil {
		return err
	}
	defer func() {
		// use a dedicated variable so closing the body never clobbers err
		// (consistent with getExecResponse)
		if closeErr := resp.Body.Close(); closeErr != nil {
			logger.WithContext(ctx).Warnf("failed to close response body for %v. err: %v", fullURL, closeErr)
		}
	}()
	if resp.StatusCode == http.StatusOK {
		var respd renewSessionResponse
		err = json.NewDecoder(resp.Body).Decode(&respd)
		if err != nil {
			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
			return err
		}
		if !respd.Success {
			c, err := strconv.Atoi(respd.Code)
			if err != nil {
				return err
			}
			return &SnowflakeError{
				Number:  c,
				Message: respd.Message,
			}
		}
		sr.TokenAccessor.SetTokens(respd.Data.SessionToken, respd.Data.MasterToken, respd.Data.SessionID)
		logger.WithContext(ctx).Info("successfully renewed session")
		return nil
	}
	b, err := io.ReadAll(resp.Body)
	if err != nil {
		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
		return err
	}
	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
	return &SnowflakeError{
		Number:      ErrFailedToRenewSession,
		SQLState:    SQLStateConnectionFailure,
		Message:     errors2.ErrMsgFailedToRenew,
		MessageArgs: []any{resp.StatusCode, fullURL},
	}
}

// getCancelRetry reads the remaining cancel-retry budget from the context.
// Defaults to 5 when unset; returns -1 when the stored value has an
// unexpected type.
func getCancelRetry(ctx context.Context) int {
	val := ctx.Value(cancelRetry)
	if val == nil {
		return 5
	}
	cnt, ok := val.(int)
	if !ok {
		return -1
	}
	return cnt
}

// cancelQuery aborts the query identified by requestID. Session expiry
// triggers a renew-then-retry; a "query not executing" response is retried
// up to the context's cancel-retry budget before being treated as success.
func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID UUID, timeout time.Duration) error {
	logger.WithContext(ctx).Info("cancel query")
	params := &url.Values{}
	params.Set(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
	params.Set(requestGUIDKey, NewUUID().String())
	fullURL := sr.getFullURL(abortRequestPath, params)

	headers := getHeaders()
	token, _, _ := sr.TokenAccessor.GetTokens()
	headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)

	req := make(map[string]string)
	req[requestIDKey] = requestID.String()

	reqByte, err := json.Marshal(req)
	if err != nil {
		return err
	}

	resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, defaultTimeProvider, nil)
	if err != nil {
		return err
	}
	defer func() {
		// use a dedicated variable so closing the body never clobbers err
		// (consistent with getExecResponse)
		if closeErr := resp.Body.Close(); closeErr != nil {
			logger.WithContext(ctx).Warnf("failed to close response body for %v. err: %v", fullURL, closeErr)
		}
	}()
	if resp.StatusCode == http.StatusOK {
		var respd cancelQueryResponse
		if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil {
			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
			return err
		}
		ctxRetry := getCancelRetry(ctx)
		if !respd.Success && respd.Code == sessionExpiredCode {
			if err = sr.FuncRenewSession(ctx, sr, timeout); err != nil {
				return err
			}
			return sr.FuncCancelQuery(ctx, sr, requestID, timeout)
		} else if !respd.Success && respd.Code == queryNotExecutingCode {
			if ctxRetry != 0 {
				return sr.FuncCancelQuery(context.WithValue(ctx, cancelRetry, ctxRetry-1), sr, requestID, timeout)
			}
			// After exhausting retries, we can safely treat queryNotExecutingCode as success
			// since it indicates the query has already completed and there's nothing left to cancel
			logger.WithContext(ctx).Info("query has already completed, no cancellation needed")
			return nil
		} else if respd.Success {
			return nil
		} else {
			c, err := strconv.Atoi(respd.Code)
			if err != nil {
				return err
			}
			return &SnowflakeError{
				Number:  c,
				Message: respd.Message,
			}
		}
	}
	b, err := io.ReadAll(resp.Body)
	if err != nil {
		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
		return err
	}
	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
	return &SnowflakeError{
		Number:      ErrFailedToCancelQuery,
		SQLState:    SQLStateConnectionFailure,
		Message:     errors2.ErrMsgFailedToCancelQuery,
		MessageArgs: []any{resp.StatusCode, fullURL},
	}
}

// getQueryIDChan extracts the query-ID channel from the context, or nil
// when the context carries none (or a value of the wrong type).
func getQueryIDChan(ctx context.Context) chan<- string {
	v := ctx.Value(queryIDChannel)
	if v == nil {
		return nil
	}
	c, ok := v.(chan<- string)
	if !ok {
		return nil
	}
	return c
}

// getExecResponse fetches a response using FuncGet and decodes it and returns it.
func getExecResponse(
	ctx context.Context,
	sr *snowflakeRestful,
	fullURL *url.URL,
	headers map[string]string,
	timeout time.Duration) (*execResponse, error) {
	resp, err := sr.FuncGet(ctx, sr, fullURL, headers, timeout)
	if err != nil {
		logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
		return nil, err
	}
	defer func() {
		if closeErr := resp.Body.Close(); closeErr != nil {
			logger.WithContext(ctx).Errorf("failed to close response body for %v. err: %v", fullURL, closeErr)
		}
	}()
	// decode response and fill into an empty execResponse
	respd := &execResponse{}
	err = json.NewDecoder(resp.Body).Decode(respd)
	if err != nil {
		logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
		return nil, err
	}
	return respd, nil
}


================================================
FILE: restful_test.go
================================================
package gosnowflake

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

// postTestError simulates a POST whose transport layer fails even though a
// 200 response object is also returned.
func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
	}, errors.New("failed to run post method")
}

// postAuthTestError is the auth-endpoint analogue of postTestError.
func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
	}, errors.New("failed to run post method")
}

// postTestSuccessButInvalidJSON simulates a 200 response whose body is not
// valid JSON, to exercise decode-failure paths.
func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
	}, nil
}

// postTestAppBadGatewayError simulates a 502 Bad Gateway response.
func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
	return &http.Response{
		StatusCode: http.StatusBadGateway,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
	}, nil
}

// postAuthTestAppBadGatewayError is the auth-endpoint analogue of
// postTestAppBadGatewayError.
func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
	return &http.Response{
		StatusCode: http.StatusBadGateway,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
	}, nil
}

// postTestAppForbiddenError simulates a 403 Forbidden response.
func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
	return &http.Response{
		StatusCode: http.StatusForbidden,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
	}, nil
}

// postAuthTestAppForbiddenError is the auth-endpoint analogue of
// postTestAppForbiddenError.
func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
	return &http.Response{
		StatusCode: http.StatusForbidden,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
	}, nil
}

// postAuthTestAppUnexpectedError simulates an unusual 5xx status
// (507 Insufficient Storage) that the driver has no dedicated handling for.
func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
	return &http.Response{
		StatusCode: http.StatusInsufficientStorage,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
	}, nil
}

// postTestQueryNotExecuting simulates a well-formed unsuccessful response
// carrying queryNotExecutingCode, used to drive cancel-retry logic.
func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
	dd := &execResponseData{}
	er := &execResponse{
		Data:    *dd,
		Message: "",
		Code:    queryNotExecutingCode,
		Success: false,
	}
	ba, err := json.Marshal(er)
	if err != nil {
		panic(err)
	}
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       &fakeResponseBody{body: ba},
	}, nil
}

// postTestRenew simulates a response carrying sessionExpiredCode, which
// forces the caller down the session-renewal path.
func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
	dd := &execResponseData{}
	er := &execResponse{
		Data:    *dd,
		Message: "",
		Code:    sessionExpiredCode,
		Success: true,
	}
	ba, err := json.Marshal(er)
	logger.Infof("encoded JSON: %v", ba)
	if err != nil {
		panic(err)
	}
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       &fakeResponseBody{body: ba},
	}, nil
}

// postAuthTestAfterRenew simulates a plain successful auth response
// (empty code), i.e. the state after a session was renewed.
func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
	dd := &execResponseData{}
	er := &execResponse{
		Data:    *dd,
		Message: "",
		Code:    "",
		Success: true,
	}
	ba, err := json.Marshal(er)
	logger.Infof("encoded JSON: %v", ba)
	if err != nil {
		panic(err)
	}
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       &fakeResponseBody{body: ba},
	}, nil
}

// postTestAfterRenew is the restful analogue of postAuthTestAfterRenew:
// a plain successful response with an empty code.
func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
	dd := &execResponseData{}
	er := &execResponse{
		Data:    *dd,
		Message: "",
		Code:    "",
		Success: true,
	}
	ba, err := json.Marshal(er)
	logger.Infof("encoded JSON: %v", ba)
	if err != nil {
		panic(err)
	}
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       &fakeResponseBody{body: ba},
	}, nil
}

// TestUnitPostQueryHelperError verifies that postRestfulQueryHelper surfaces
// an error for transport failures, 502 responses, and undecodable bodies.
func TestUnitPostQueryHelperError(t *testing.T) {
	sr := &snowflakeRestful{
		FuncPost:      postTestError,
		TokenAccessor: getSimpleTokenAccessor(),
	}
	var err error
	requestID := NewUUID()
	_, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, requestID, &Config{})
	if err == nil {
		t.Fatalf("should have failed to post")
	}
	sr.FuncPost = postTestAppBadGatewayError
	requestID = NewUUID()
	_, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, requestID, &Config{})
	if err == nil {
		t.Fatalf("should have failed to post")
	}
	sr.FuncPost = postTestSuccessButInvalidJSON
requestID = NewUUID()
	_, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, requestID, &Config{})
	if err == nil {
		t.Fatalf("should have failed to post")
	}
}

// TestUnitPostQueryHelperOnRenewSessionKeepsRequestIdButGeneratesNewRequestGuid
// checks that a renew-and-retry keeps the caller's request ID while each
// attempt still carries exactly one request ID and one request GUID.
func TestUnitPostQueryHelperOnRenewSessionKeepsRequestIdButGeneratesNewRequestGuid(t *testing.T) {
	postCount := 0
	requestID := NewUUID()
	sr := &snowflakeRestful{
		FuncPost: func(ctx context.Context, restful *snowflakeRestful, url *url.URL, headers map[string]string, bytes []byte, duration time.Duration, provider currentTimeProvider, config *Config) (*http.Response, error) {
			assertEqualF(t, len((url.Query())[requestIDKey]), 1)
			assertEqualF(t, len((url.Query())[requestGUIDKey]), 1)
			// 390112 is the session-expired code, forcing a renewal
			return &http.Response{
				StatusCode: 200,
				Body:       &fakeResponseBody{body: []byte(`{"data":null,"code":"390112","message":"token expired for testing","success":false,"headers":null}`)},
			}, nil
		},
		FuncPostQuery: func(ctx context.Context, restful *snowflakeRestful, values *url.Values, headers map[string]string, bytes []byte, timeout time.Duration, uuid UUID, config *Config) (*execResponse, error) {
			assertEqualF(t, requestID.String(), uuid.String())
			assertEqualF(t, len((*values)[requestIDKey]), 1)
			assertEqualF(t, len((*values)[requestGUIDKey]), 1)
			if postCount == 0 {
				postCount++
				return postRestfulQueryHelper(ctx, restful, values, headers, bytes, timeout, uuid, config)
			}
			return nil, nil
		},
		FuncRenewSession: renewSessionTest,
		TokenAccessor:    getSimpleTokenAccessor(),
	}
	_, err := postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), make([]byte, 0), time.Second, requestID, nil)
	assertNilE(t, err)
}

// renewSessionTest is a FuncRenewSession stub that always succeeds.
func renewSessionTest(_ context.Context, _ *snowflakeRestful, _ time.Duration) error {
	return nil
}

// renewSessionTestError is a FuncRenewSession stub that always fails.
func renewSessionTestError(_ context.Context, _ *snowflakeRestful, _ time.Duration) error {
	return errors.New("failed to renew session in tests")
}

// TestUnitTokenAccessorDoesNotRenewStaleToken checks that renewal is skipped
// when the token passed in no longer matches the accessor's current token,
// unless the current token is empty.
func TestUnitTokenAccessorDoesNotRenewStaleToken(t *testing.T) {
	accessor := getSimpleTokenAccessor()
	oldToken := "test"
	accessor.SetTokens(oldToken, "master", 123)
	renewSessionCalled := false
	renewSessionDummy := func(_ context.Context, sr *snowflakeRestful, _ time.Duration) error {
		// should not have gotten to actual renewal
		renewSessionCalled = true
		return nil
	}
	sr := &snowflakeRestful{
		FuncRenewSession: renewSessionDummy,
		TokenAccessor:    accessor,
	}
	// try to intentionally renew with stale token
	assertNilE(t, sr.renewExpiredSessionToken(context.Background(), time.Hour, "stale-token"))
	if renewSessionCalled {
		t.Fatal("FuncRenewSession should not have been called")
	}
	// set the current token to empty, should still call renew even if stale token is passed in
	accessor.SetTokens("", "master", 123)
	assertNilE(t, sr.renewExpiredSessionToken(context.Background(), time.Hour, "stale-token"))
	if !renewSessionCalled {
		t.Fatal("FuncRenewSession should have been called because current token is empty")
	}
}

// wrappedAccessor decorates a TokenAccessor, counting Lock/Unlock calls
// so tests can assert on locking behavior.
type wrappedAccessor struct {
	ta              TokenAccessor
	lockCallCount   int32
	unlockCallCount int32
}

func (wa *wrappedAccessor) Lock() error {
	atomic.AddInt32(&wa.lockCallCount, 1)
	err := wa.ta.Lock()
	return err
}

func (wa *wrappedAccessor) Unlock() {
	atomic.AddInt32(&wa.unlockCallCount, 1)
	wa.ta.Unlock()
}

func (wa *wrappedAccessor) GetTokens() (token string, masterToken string, sessionID int64) {
	return wa.ta.GetTokens()
}

func (wa *wrappedAccessor) SetTokens(token string, masterToken string, sessionID int64) {
	wa.ta.SetTokens(token, masterToken, sessionID)
}

// TestUnitTokenAccessorRenewBlocked checks that a renewal blocked on the
// accessor lock re-checks token staleness once it acquires the lock, and
// asserts the exact Lock/Unlock call counts along the way.
func TestUnitTokenAccessorRenewBlocked(t *testing.T) {
	accessor := wrappedAccessor{
		ta: getSimpleTokenAccessor(),
	}
	oldToken := "test"
	accessor.SetTokens(oldToken, "master", 123)
	renewSessionCalled := false
	renewSessionDummy := func(_ context.Context, sr *snowflakeRestful, _ time.Duration) error {
		renewSessionCalled = true
		return nil
	}
	sr := &snowflakeRestful{
		FuncRenewSession: renewSessionDummy,
		TokenAccessor:    &accessor,
	}
	// intentionally lock the accessor first
	assertNilE(t, accessor.Lock())
	// try to intentionally renew with stale token
	var renewalStart sync.WaitGroup
	var renewalDone sync.WaitGroup
	renewalStart.Add(1)
	renewalDone.Add(1)
	go func() {
		renewalStart.Done()
		assertNilE(t, sr.renewExpiredSessionToken(context.Background(), time.Hour, oldToken))
		renewalDone.Done()
	}()
	// wait for renewal to start and get blocked on lock
	renewalStart.Wait()
	// should be blocked and not be able to call renew session
	if renewSessionCalled {
		t.Fail()
	}
	// rotate the token again so that the session token is considered stale
	accessor.SetTokens("new-token", "m", 321)
	// unlock so that renew can happen
	accessor.Unlock()
	renewalDone.Wait()
	// renewal should be done but token should still not
	// have been renewed since we intentionally swapped token while locked
	if renewSessionCalled {
		t.Fail()
	}
	// wait for accessor defer unlock
	assertNilE(t, accessor.Lock())
	if accessor.lockCallCount != 3 {
		t.Fatalf("Expected Lock() to be called thrice, but got %v", accessor.lockCallCount)
	}
	if accessor.unlockCallCount != 2 {
		t.Fatalf("Expected Unlock() to be called twice, but got %v", accessor.unlockCallCount)
	}
}

// TestUnitTokenAccessorRenewSessionContention races many goroutines against
// renewExpiredSessionToken and asserts only the first one performs renewal.
func TestUnitTokenAccessorRenewSessionContention(t *testing.T) {
	accessor := getSimpleTokenAccessor()
	oldToken := "test"
	accessor.SetTokens(oldToken, "master", 123)
	var counter int32 = 0
	expectedToken := "new token"
	expectedMaster := "new master"
	expectedSession := int64(321)
	renewSessionDummy := func(_ context.Context, sr *snowflakeRestful, _ time.Duration) error {
		accessor.SetTokens(expectedToken, expectedMaster, expectedSession)
		atomic.AddInt32(&counter, 1)
		return nil
	}
	sr := &snowflakeRestful{
		FuncRenewSession: renewSessionDummy,
		TokenAccessor:    accessor,
	}
	var renewalsStart sync.WaitGroup
	var renewalsDone sync.WaitGroup
	// NOTE(review): renewalError is written without synchronization; benign
	// here only because renewals are expected to succeed — confirm under -race
	var renewalError error
	numRoutines := 50
	for range numRoutines {
		renewalsDone.Add(1)
		renewalsStart.Add(1)
		go func() {
			// wait for all goroutines to have been created before proceeding to race against each other
			renewalsStart.Wait()
			err := sr.renewExpiredSessionToken(context.Background(), time.Hour, oldToken)
			if err != nil {
				renewalError = err
			}
			renewalsDone.Done()
		}()
	}
	// unlock all of the waiting goroutines simultaneously
	renewalsStart.Add(-numRoutines)
	// wait for all competing goroutines to finish calling renew expired session token
	renewalsDone.Wait()
	if renewalError != nil {
		t.Fatalf("failed to renew session, error %v", renewalError)
	}
	newToken, newMaster, newSession := accessor.GetTokens()
	if newToken != expectedToken {
		t.Fatalf("token %v does not match expected %v", newToken, expectedToken)
	}
	if newMaster != expectedMaster {
		t.Fatalf("master token %v does not match expected %v", newMaster, expectedMaster)
	}
	if newSession != expectedSession {
		t.Fatalf("session %v does not match expected %v", newSession, expectedSession)
	}
	// only the first renewal will go through and FuncRenewSession should be called exactly once
	if counter != 1 {
		t.Fatalf("renew expired session was called more than once: %v", counter)
	}
}

// TestUnitPostQueryHelperUsesToken verifies that the current session token is
// sent in the Authorization header of the query POST.
func TestUnitPostQueryHelperUsesToken(t *testing.T) {
	accessor := getSimpleTokenAccessor()
	token := "token123"
	accessor.SetTokens(token, "", 0)
	var err error
	postQueryTest := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, headers map[string]string, _ []byte, _ time.Duration, _ UUID, _ *Config) (*execResponse, error) {
		if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, token) {
			t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, token))
		}
		dd := &execResponseData{}
		return &execResponse{
			Data:    *dd,
			Message: "",
			Code:    "0",
			Success: true,
		}, nil
	}
	sr := &snowflakeRestful{
		FuncPost:         postTestRenew,
		FuncPostQuery:    postQueryTest,
		FuncRenewSession: renewSessionTest,
		TokenAccessor:    accessor,
	}
	_, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, NewUUID(), &Config{})
	if err != nil {
		t.Fatalf("err: %v", err)
	}
}

func
TestUnitPostQueryHelperRenewSession(t *testing.T) {
	var err error
	origRequestID := NewUUID()
	postQueryTest := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, requestID UUID, _ *Config) (*execResponse, error) {
		// ensure the same requestID is used after the session token is renewed.
		if requestID != origRequestID {
			t.Fatal("requestID doesn't match")
		}
		dd := &execResponseData{}
		return &execResponse{
			Data:    *dd,
			Message: "",
			Code:    "0",
			Success: true,
		}, nil
	}
	sr := &snowflakeRestful{
		FuncPost:         postTestRenew,
		FuncPostQuery:    postQueryTest,
		FuncRenewSession: renewSessionTest,
		TokenAccessor:    getSimpleTokenAccessor(),
	}
	_, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, origRequestID, &Config{})
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	sr.FuncRenewSession = renewSessionTestError
	_, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, origRequestID, &Config{})
	if err == nil {
		t.Fatal("should have failed to renew session")
	}
}

// TestUnitRenewRestfulSession drives renewRestfulSession through success,
// transport failure, 502, invalid JSON, and a successful token rotation.
func TestUnitRenewRestfulSession(t *testing.T) {
	accessor := getSimpleTokenAccessor()
	oldToken, oldMasterToken, oldSessionID := "oldtoken", "oldmaster", int64(100)
	newToken, newMasterToken, newSessionID := "newtoken", "newmaster", int64(200)
	postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
		// renewal must be authorized with the master token
		if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, oldMasterToken) {
			t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, oldMasterToken))
		}
		tr := &renewSessionResponse{
			Data: renewSessionResponseMain{
				SessionToken: newToken,
				MasterToken:  newMasterToken,
				SessionID:    newSessionID,
			},
			Message: "",
			Success: true,
		}
		ba, err := json.Marshal(tr)
		if err != nil {
			t.Fatalf("failed to serialize token response %v", err)
		}
		return &http.Response{
			StatusCode: http.StatusOK,
			Body:       &fakeResponseBody{body: ba},
		}, nil
	}
	sr := &snowflakeRestful{
		FuncPost:      postTestAfterRenew,
		TokenAccessor: accessor,
	}
	err := renewRestfulSession(context.Background(), sr, time.Second)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	sr.FuncPost = postTestError
	err = renewRestfulSession(context.Background(), sr, time.Second)
	if err == nil {
		t.Fatal("should have failed to run post request after the renewal")
	}
	sr.FuncPost = postTestAppBadGatewayError
	err = renewRestfulSession(context.Background(), sr, time.Second)
	if err == nil {
		t.Fatal("should have failed to run post request after the renewal")
	}
	sr.FuncPost = postTestSuccessButInvalidJSON
	err = renewRestfulSession(context.Background(), sr, time.Second)
	if err == nil {
		t.Fatal("should have failed to run post request after the renewal")
	}
	accessor.SetTokens(oldToken, oldMasterToken, oldSessionID)
	sr.FuncPost = postTestSuccessWithNewTokens
	err = renewRestfulSession(context.Background(), sr, time.Second)
	if err != nil {
		t.Fatal("should not have failed to run post request after the renewal")
	}
	token, masterToken, sessionID := accessor.GetTokens()
	if token != newToken {
		t.Fatalf("unexpected new token %v", token)
	}
	if masterToken != newMasterToken {
		t.Fatalf("unexpected new master token %v", masterToken)
	}
	if sessionID != newSessionID {
		t.Fatalf("unexpected new session id %v", sessionID)
	}
}

// TestUnitCloseSession drives closeSession through success, transport
// failure, 502, and invalid-JSON responses.
func TestUnitCloseSession(t *testing.T) {
	sr := &snowflakeRestful{
		FuncPost:      postTestAfterRenew,
		TokenAccessor: getSimpleTokenAccessor(),
	}
	err := closeSession(context.Background(), sr, time.Second)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	sr.FuncPost = postTestError
	err = closeSession(context.Background(), sr, time.Second)
	if err == nil {
		t.Fatal("should have failed to close session")
	}
	sr.FuncPost = postTestAppBadGatewayError
	err = closeSession(context.Background(), sr, time.Second)
	if err == nil {
		t.Fatal("should have failed to close session")
	}
	sr.FuncPost = postTestSuccessButInvalidJSON
	err = closeSession(context.Background(), sr, time.Second)
	if err == nil {
		t.Fatal("should have failed to close session")
	}
}

// TestUnitCancelQuery drives cancelQuery through success, transport failure,
// 502, and invalid-JSON responses.
func TestUnitCancelQuery(t *testing.T) {
	sr := &snowflakeRestful{
		FuncPost:      postTestAfterRenew,
		TokenAccessor: getSimpleTokenAccessor(),
	}
	ctx := context.Background()
	err := cancelQuery(ctx, sr, getOrGenerateRequestIDFromContext(ctx), time.Second)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	sr.FuncPost = postTestError
	err = cancelQuery(ctx, sr, getOrGenerateRequestIDFromContext(ctx), time.Second)
	if err == nil {
		t.Fatal("should have failed to close session")
	}
	sr.FuncPost = postTestAppBadGatewayError
	err = cancelQuery(context.Background(), sr, getOrGenerateRequestIDFromContext(ctx), time.Second)
	if err == nil {
		t.Fatal("should have failed to close session")
	}
	sr.FuncPost = postTestSuccessButInvalidJSON
	err = cancelQuery(context.Background(), sr, getOrGenerateRequestIDFromContext(ctx), time.Second)
	if err == nil {
		t.Fatal("should have failed to close session")
	}
}

// TestCancelRetry verifies that repeated queryNotExecutingCode responses
// eventually resolve to success after the retry budget is exhausted.
func TestCancelRetry(t *testing.T) {
	sr := &snowflakeRestful{
		TokenAccessor:   getSimpleTokenAccessor(),
		FuncPost:        postTestQueryNotExecuting,
		FuncCancelQuery: cancelQuery,
	}
	ctx := context.Background()
	err := cancelQuery(ctx, sr, getOrGenerateRequestIDFromContext(ctx), time.Second)
	if err != nil {
		t.Fatal(err)
	}
}

// TestPostRestfulQueryContextErrors verifies that cancellation of the query
// is triggered only for context-derived errors, and that a cancel failure is
// combined with the original query error.
func TestPostRestfulQueryContextErrors(t *testing.T) {
	var cancelCalled bool
	newRestfulWithError := func(queryErr error) *snowflakeRestful {
		cancelCalled = false
		return &snowflakeRestful{
			FuncPostQueryHelper: func(context.Context, *snowflakeRestful, *url.Values, map[string]string, []byte, time.Duration, UUID, *Config) (*execResponse, error) {
				return nil, queryErr
			},
			FuncCancelQuery: func(context.Context, *snowflakeRestful, UUID, time.Duration) error {
				cancelCalled = true
				return nil
			},
			TokenAccessor: getSimpleTokenAccessor(),
		}
	}
	runPostRestfulQuery := func(sr *snowflakeRestful) (data *execResponse, err error) {
		return postRestfulQuery(context.Background(), sr, &url.Values{}, nil, nil, 0, NewUUID(), nil)
	}
	t.Run("postRestfulQuery error does not trigger cancel", func(t *testing.T) {
		expectedErr := fmt.Errorf("query error")
		sr := newRestfulWithError(expectedErr)
		_, err := runPostRestfulQuery(sr)
		assertFalseE(t, cancelCalled)
		assertErrIsE(t, expectedErr, err)
	})
	t.Run("context.Canceled triggers cancel", func(t *testing.T) {
		sr := newRestfulWithError(context.Canceled)
		_, err := runPostRestfulQuery(sr)
		assertTrueE(t, cancelCalled)
		assertErrIsE(t, context.Canceled, err)
	})
	t.Run("context.DeadlineExceeded triggers cancel", func(t *testing.T) {
		sr := newRestfulWithError(context.DeadlineExceeded)
		_, err := runPostRestfulQuery(sr)
		assertTrueE(t, cancelCalled)
		assertErrIsE(t, context.DeadlineExceeded, err)
	})
	t.Run("cancel failure returns wrapped error", func(t *testing.T) {
		fatalCancelErr := fmt.Errorf("fatal failure")
		sr := newRestfulWithError(context.Canceled)
		sr.FuncCancelQuery = func(context.Context, *snowflakeRestful, UUID, time.Duration) error {
			cancelCalled = true
			return fatalCancelErr
		}
		_, err := runPostRestfulQuery(sr)
		assertTrueE(t, cancelCalled)
		assertErrIsE(t, err, context.Canceled)
		assertErrIsE(t, err, fatalCancelErr)
		assertEqualE(t, "failed to cancel query. cancelErr: fatal failure, queryErr: context canceled", err.Error())
	})
}

// TestErrorReturnedFromLongRunningQuery checks that a query outliving its
// context deadline surfaces context.DeadlineExceeded, both end-to-end
// (skipped by default) and against wiremock.
func TestErrorReturnedFromLongRunningQuery(t *testing.T) {
	t.Run("e2e test", func(t *testing.T) {
		t.Skip("long running test, uncomment to run manually, otherwise the test on mocks should be sufficient")
		db := openDB(t)
		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Second)
		defer cancel()
		_, err := db.ExecContext(ctx, "CALL SYSTEM$WAIT(55, 'SECONDS')")
		assertNotNilF(t, err)
		assertErrIsE(t, err, context.DeadlineExceeded)
	})
	t.Run("mock test", func(t *testing.T) {
		wiremock.registerMappings(t,
			newWiremockMapping("auth/password/successful_flow.json"),
			newWiremockMapping("query/long_running_query.json"),
			newWiremockMapping("query/query_by_id_timeout.json"),
		)
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		db := wiremock.openDb(t)
		_, err := db.QueryContext(ctx, "SELECT 1")
		assertNotNilF(t, err)
		assertErrIsE(t, err, context.DeadlineExceeded)
	})
}


================================================
FILE: result.go
================================================
package gosnowflake

import "errors"

// QueryStatus denotes the status of a query.
type QueryStatus string const ( // QueryStatusInProgress denotes a query execution in progress QueryStatusInProgress QueryStatus = "queryStatusInProgress" // QueryStatusComplete denotes a completed query execution QueryStatusComplete QueryStatus = "queryStatusComplete" // QueryFailed denotes a failed query QueryFailed QueryStatus = "queryFailed" ) // SnowflakeResult provides an API for methods exposed to the clients type SnowflakeResult interface { GetQueryID() string GetStatus() QueryStatus } type snowflakeResult struct { affectedRows int64 insertID int64 // Snowflake doesn't support last insert id queryID string status QueryStatus err error errChannel chan error } func (res *snowflakeResult) LastInsertId() (int64, error) { if err := res.waitForAsyncExecStatus(); err != nil { return -1, err } return res.insertID, nil } func (res *snowflakeResult) RowsAffected() (int64, error) { if err := res.waitForAsyncExecStatus(); err != nil { return -1, err } return res.affectedRows, nil } func (res *snowflakeResult) GetQueryID() string { return res.queryID } func (res *snowflakeResult) GetStatus() QueryStatus { return res.status } func (res *snowflakeResult) waitForAsyncExecStatus() error { // if async exec, block until execution is finished switch res.status { case QueryStatusInProgress: err := <-res.errChannel res.status = QueryStatusComplete if err != nil { res.status = QueryFailed res.err = err return err } return nil case QueryFailed: return res.err default: return nil } } type snowflakeResultNoRows struct { queryID string } func (*snowflakeResultNoRows) LastInsertId() (int64, error) { return 0, errors.New("no LastInsertId available") } func (*snowflakeResultNoRows) RowsAffected() (int64, error) { return 0, errors.New("no RowsAffected available") } func (rnr *snowflakeResultNoRows) GetQueryID() string { return rnr.queryID } ================================================ FILE: retry.go ================================================ package gosnowflake import ( "bytes" 
	"context"
	"fmt"
	"io"
	"math"
	"math/rand"
	"net/http"
	"net/url"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"
)

// waitAlgo computes the sleep time between retries (decorrelated jitter).
type waitAlgo struct {
	mutex  *sync.Mutex // required for *rand.Rand usage
	random *rand.Rand
	base   time.Duration // base wait time
	cap    time.Duration // maximum wait time
}

var random *rand.Rand
var defaultWaitAlgo *waitAlgo

// authEndpoints are request paths treated as login requests by the retry loop.
var authEndpoints = []string{
	loginRequestPath,
	tokenRequestPath,
	authenticatorRequestPath,
}

// clientErrorsStatusCodesEligibleForRetry lists 4xx codes that are still retried.
var clientErrorsStatusCodesEligibleForRetry = []int{
	http.StatusTooManyRequests,
	http.StatusRequestTimeout,
}

func init() {
	random = rand.New(rand.NewSource(time.Now().UnixNano()))
	// sleep time before retrying starts from 1s and the max sleep time is 16s
	defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random, base: 1 * time.Second, cap: 16 * time.Second}
}

const (
	// requestGUIDKey is attached to every request against Snowflake
	requestGUIDKey string = "request_guid"
	// retryCountKey is attached to query-request from the second time
	retryCountKey string = "retryCount"
	// retryReasonKey contains last HTTP status or 0 if timeout
	retryReasonKey string = "retryReason"
	// clientStartTime contains a time when client started request (first request, not retries)
	clientStartTimeKey string = "clientStartTime"
	// requestIDKey is attached to all requests to Snowflake
	requestIDKey string = "requestId"
)

// requestGUIDReplacer takes in a url during construction and replaces the
// value of request_guid every time replace() is called. If the url does not
// contain request_guid, the original url is returned unchanged.
type requestGUIDReplacer interface {
	// replace the url with new ID
	replace() *url.URL
}

// newRequestGUIDReplace makes a requestGUIDReplacer for the given URL.
// Falls back to a no-op replacer when the query string is unparsable or
// carries no request_guid.
func newRequestGUIDReplace(urlPtr *url.URL) requestGUIDReplacer {
	values, err := url.ParseQuery(urlPtr.RawQuery)
	if err != nil {
		// nop if invalid query parameters
		return &transientReplace{urlPtr}
	}
	if len(values.Get(requestGUIDKey)) == 0 {
		// nop if no request_guid is included.
		return &transientReplace{urlPtr}
	}
	return &requestGUIDReplace{urlPtr, values}
}

// transientReplace does nothing but return the url unchanged.
type transientReplace struct {
	urlPtr *url.URL
}

func (replacer *transientReplace) replace() *url.URL {
	return replacer.urlPtr
}

/*
requestGUIDReplace is a one-shot object that is created out of the retry loop and
called with replace to change the request_guid's value upon every retry
*/
type requestGUIDReplace struct {
	urlPtr    *url.URL
	urlValues url.Values
}

/*
* This function replaces the value of the requestGUIDKey in a url with a newly generated UUID
 */
func (replacer *requestGUIDReplace) replace() *url.URL {
	replacer.urlValues.Del(requestGUIDKey)
	replacer.urlValues.Add(requestGUIDKey, NewUUID().String())
	replacer.urlPtr.RawQuery = replacer.urlValues.Encode()
	return replacer.urlPtr
}

// retryCountUpdater sets the retryCount query parameter on query-requests.
type retryCountUpdater interface {
	replaceOrAdd(retry int) *url.URL
}

type retryCountUpdate struct {
	urlPtr    *url.URL
	urlValues url.Values
}

// transientRetryCountUpdater does nothing but return the url unchanged.
type transientRetryCountUpdater struct {
	urlPtr *url.URL
}

func (replaceOrAdder *transientRetryCountUpdater) replaceOrAdd(retry int) *url.URL {
	return replaceOrAdder.urlPtr
}

func (replacer *retryCountUpdate) replaceOrAdd(retry int) *url.URL {
	replacer.urlValues.Del(retryCountKey)
	replacer.urlValues.Add(retryCountKey, strconv.Itoa(retry))
	replacer.urlPtr.RawQuery = replacer.urlValues.Encode()
	return replacer.urlPtr
}

// newRetryCountUpdater makes a retryCountUpdater; no-op for non-query
// requests or unparsable query strings.
func newRetryCountUpdater(urlPtr *url.URL) retryCountUpdater {
	if !isQueryRequest(urlPtr) {
		// nop if not query-request
		return &transientRetryCountUpdater{urlPtr}
	}
	values, err := url.ParseQuery(urlPtr.RawQuery)
	if err != nil {
		// nop if the URL is not valid
		return &transientRetryCountUpdater{urlPtr}
	}
	return &retryCountUpdate{urlPtr, values}
}

// retryReasonUpdater sets the retryReason query parameter (last HTTP status,
// or 0 on timeout) on query-requests.
type retryReasonUpdater interface {
	replaceOrAdd(reason int) *url.URL
}

type retryReasonUpdate struct {
	url *url.URL
}

func (retryReasonUpdater *retryReasonUpdate) replaceOrAdd(reason int) *url.URL {
	query :=
retryReasonUpdater.url.Query()
	query.Del(retryReasonKey)
	query.Add(retryReasonKey, strconv.Itoa(reason))
	retryReasonUpdater.url.RawQuery = query.Encode()
	return retryReasonUpdater.url
}

// transientRetryReasonUpdater does nothing but return the url unchanged.
type transientRetryReasonUpdater struct {
	url *url.URL
}

func (retryReasonUpdater *transientRetryReasonUpdater) replaceOrAdd(_ int) *url.URL {
	return retryReasonUpdater.url
}

// newRetryReasonUpdater makes a retryReasonUpdater; no-op for non-query
// requests or when the config explicitly disables retry reasons.
func newRetryReasonUpdater(url *url.URL, cfg *Config) retryReasonUpdater {
	// not a query request
	if !isQueryRequest(url) {
		return &transientRetryReasonUpdater{url}
	}
	// implicitly disabled retry reason
	if cfg != nil && cfg.IncludeRetryReason == ConfigBoolFalse {
		return &transientRetryReasonUpdater{url}
	}
	return &retryReasonUpdate{url}
}

// ensureClientStartTimeIsSet adds clientStartTime to query-requests once;
// an already-present value is kept so retries carry the first request's time.
func ensureClientStartTimeIsSet(url *url.URL, clientStartTime string) *url.URL {
	if !isQueryRequest(url) {
		// nop if not query-request
		return url
	}
	query := url.Query()
	if query.Has(clientStartTimeKey) {
		return url
	}
	query.Add(clientStartTimeKey, clientStartTime)
	url.RawQuery = query.Encode()
	return url
}

// isQueryRequest reports whether the URL path targets the query-request endpoint.
func isQueryRequest(url *url.URL) bool {
	return strings.HasPrefix(url.Path, queryRequestPath)
}

// jitter backoff in seconds
func (w *waitAlgo) calculateWaitBeforeRetryForAuthRequest(attempt int, currWaitTimeDuration time.Duration) time.Duration {
	w.mutex.Lock()
	defer w.mutex.Unlock()
	currWaitTimeInSeconds := currWaitTimeDuration.Seconds()
	jitterAmount := w.getJitter(currWaitTimeInSeconds)
	// pick between the jittered current wait and an exponential (2^attempt) bound
	jitteredSleepTime := chooseRandomFromRange(currWaitTimeInSeconds+jitterAmount, math.Pow(2, float64(attempt))+jitterAmount)
	return time.Duration(jitteredSleepTime * float64(time.Second))
}

// calculateWaitBeforeRetry computes the next sleep for non-auth requests,
// capped at w.cap.
func (w *waitAlgo) calculateWaitBeforeRetry(sleep time.Duration) time.Duration {
	w.mutex.Lock()
	defer w.mutex.Unlock()
	// use decorrelated jitter in retry time
	randDuration := randMilliSecondDuration(w.base, sleep*3)
	return durationMin(w.cap, randDuration)
}

// randMilliSecondDuration returns a random duration in [base, bound) with
// millisecond granularity.
// NOTE(review): rand.Int63n panics when bound <= base; callers currently
// guarantee bound > base (base=1s, sleep starts at 1s) — confirm before reuse.
func randMilliSecondDuration(base time.Duration, bound time.Duration) time.Duration {
	baseNumber := int64(base / time.Millisecond)
	boundNumber := int64(bound / time.Millisecond)
	randomDuration := random.Int63n(boundNumber-baseNumber) + baseNumber
	return time.Duration(randomDuration) * time.Millisecond
}

// getJitter returns a random offset in [-0.5, +0.5] * currWaitTime.
func (w *waitAlgo) getJitter(currWaitTime float64) float64 {
	multiplicationFactor := chooseRandomFromRange(-1, 1)
	jitterAmount := 0.5 * currWaitTime * multiplicationFactor
	return jitterAmount
}

// requestFunc builds an *http.Request; http.NewRequest in production,
// stubbed in tests.
type requestFunc func(method, urlStr string, body io.Reader) (*http.Request, error)

// clientInterface abstracts *http.Client for testability.
type clientInterface interface {
	Do(req *http.Request) (*http.Response, error)
}

// retryHTTP is a builder for an HTTP request executed with retries.
type retryHTTP struct {
	ctx                 context.Context
	client              clientInterface
	req                 requestFunc
	method              string
	fullURL             *url.URL
	headers             map[string]string
	bodyCreator         bodyCreatorType
	timeout             time.Duration
	maxRetryCount       int
	currentTimeProvider currentTimeProvider
	cfg                 *Config
}

// newRetryHTTP constructs a retryHTTP defaulting to GET with an empty body.
func newRetryHTTP(ctx context.Context, client clientInterface, req requestFunc, fullURL *url.URL, headers map[string]string, timeout time.Duration, maxRetryCount int, currentTimeProvider currentTimeProvider, cfg *Config) *retryHTTP {
	instance := retryHTTP{}
	instance.ctx = ctx
	instance.client = client
	instance.req = req
	instance.method = "GET"
	instance.fullURL = fullURL
	instance.headers = headers
	instance.timeout = timeout
	instance.maxRetryCount = maxRetryCount
	instance.bodyCreator = emptyBodyCreator
	instance.currentTimeProvider = currentTimeProvider
	instance.cfg = cfg
	return &instance
}

// doPost switches the request method to POST.
func (r *retryHTTP) doPost() *retryHTTP {
	r.method = "POST"
	return r
}

// setBody installs a fixed request body (re-sent verbatim on every retry).
func (r *retryHTTP) setBody(body []byte) *retryHTTP {
	r.bodyCreator = func() ([]byte, error) {
		return body, nil
	}
	return r
}

// setBodyCreator installs a callback invoked per attempt to build the body.
func (r *retryHTTP) setBodyCreator(bodyCreator bodyCreatorType) *retryHTTP {
	r.bodyCreator = bodyCreator
	return r
}

// execute runs the request, retrying retryable failures with jittered
// backoff until success, timeout, or maxRetryCount is exceeded.
func (r *retryHTTP) execute() (res *http.Response, err error) {
	totalTimeout := r.timeout
	logger.WithContext(r.ctx).Debugf("retryHTTP.totalTimeout: %v", totalTimeout)
	retryCounter := 0
	sleepTime := time.Duration(time.Second)
	clientStartTime := strconv.FormatInt(r.currentTimeProvider.currentTime(), 10)
	var
requestGUIDReplacer requestGUIDReplacer
	var retryCountUpdater retryCountUpdater
	var retryReasonUpdater retryReasonUpdater
	for {
		timer := time.Now()
		logger.WithContext(r.ctx).Debugf("retry count: %v", retryCounter)
		// body is rebuilt on each attempt so bodyCreator-based requests stay fresh
		body, err := r.bodyCreator()
		if err != nil {
			return nil, err
		}
		req, err := r.req(r.method, r.fullURL.String(), bytes.NewReader(body))
		if err != nil {
			return nil, err
		}
		if req != nil {
			// req can be nil in tests
			req = req.WithContext(r.ctx)
		}
		for k, v := range r.headers {
			req.Header.Set(k, v)
		}
		res, err = r.client.Do(req)
		// check if it can retry.
		retryable, err := isRetryableError(r.ctx, req, res, err)
		if !retryable {
			return res, err
		}
		// NOTE(review): this log message looks malformed — "%v" is fed a Duration
		// string (not milliseconds) and "with status ." has no value; candidate cleanup.
		logger.WithContext(r.ctx).Debugf("Request to %v - response received after milliseconds %v with status .", r.fullURL.Host, time.Since(timer).String())
		if err != nil {
			logger.WithContext(r.ctx).Warnf(
				"failed http connection. err: %v. retrying...\n", err)
		} else {
			logger.WithContext(r.ctx).Tracef(
				"failed http connection. HTTP Status: %v. retrying...\n", res.StatusCode)
			// close the failed response body so the connection can be reused
			if closeErr := res.Body.Close(); closeErr != nil {
				logger.Warnf("failed to close response body. err: %v", closeErr)
			}
		}
		// uses exponential jitter backoff
		retryCounter++
		if isLoginRequest(req) {
			sleepTime = defaultWaitAlgo.calculateWaitBeforeRetryForAuthRequest(retryCounter, sleepTime)
		} else {
			sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(sleepTime)
		}
		if totalTimeout > 0 {
			// if any timeout is set
			totalTimeout -= sleepTime
		}
		// give up when the budget is spent or the retry limit is exceeded
		if (r.timeout > 0 && totalTimeout <= 0) || retryCounter > r.maxRetryCount {
			if err != nil {
				return nil, err
			}
			if res != nil {
				return nil, fmt.Errorf("timeout after %s and %v attempts. HTTP Status: %v. Hanging?", r.timeout, retryCounter, res.StatusCode)
			}
			return nil, fmt.Errorf("timeout after %s and %v attempts. Hanging?", r.timeout, retryCounter)
		}
		// lazily build the URL mutators, then refresh the retry telemetry params
		if requestGUIDReplacer == nil {
			requestGUIDReplacer = newRequestGUIDReplace(r.fullURL)
		}
		r.fullURL = requestGUIDReplacer.replace()
		if retryCountUpdater == nil {
			retryCountUpdater = newRetryCountUpdater(r.fullURL)
		}
		r.fullURL = retryCountUpdater.replaceOrAdd(retryCounter)
		if retryReasonUpdater == nil {
			retryReasonUpdater = newRetryReasonUpdater(r.fullURL, r.cfg)
		}
		retryReason := 0
		if res != nil {
			retryReason = res.StatusCode
		}
		r.fullURL = retryReasonUpdater.replaceOrAdd(retryReason)
		r.fullURL = ensureClientStartTimeIsSet(r.fullURL, clientStartTime)
		logger.WithContext(r.ctx).Debugf("sleeping %v. to timeout: %v. retrying", sleepTime, totalTimeout)
		logger.WithContext(r.ctx).Debugf("retry count: %v, retry reason: %v", retryCounter, retryReason)
		await := time.NewTimer(sleepTime)
		select {
		case <-await.C:
			// retry the request
		case <-r.ctx.Done():
			await.Stop()
			return res, r.ctx.Err()
		}
	}
}

// isRetryableError decides whether a failed attempt should be retried;
// context cancellation always stops retrying.
func isRetryableError(ctx context.Context, req *http.Request, res *http.Response, err error) (bool, error) {
	if ctx.Err() != nil {
		return false, ctx.Err()
	}
	if err != nil && res == nil {
		// Failed http connection. Most probably client timeout.
return true, err
	}
	if res == nil || req == nil {
		return false, err
	}
	return isRetryableStatus(res.StatusCode), err
}

// isRetryableStatus reports whether the HTTP status is worth retrying:
// any 5xx, plus the allow-listed 4xx codes (429, 408).
func isRetryableStatus(statusCode int) bool {
	return (statusCode >= 500 && statusCode < 600) || slices.Contains(clientErrorsStatusCodesEligibleForRetry, statusCode)
}

// isLoginRequest reports whether the request targets an auth endpoint.
func isLoginRequest(req *http.Request) bool {
	return slices.Contains(authEndpoints, req.URL.Path)
}

================================================
FILE: retry_test.go
================================================
package gosnowflake

import (
	"bytes"
	"context"
	"database/sql"
	"fmt"
	"github.com/snowflakedb/gosnowflake/v2/internal/errors"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"testing"
	"time"
)

// fakeRequestFunc is a requestFunc stub that builds no request at all.
func fakeRequestFunc(_, _ string, _ io.Reader) (*http.Request, error) {
	return nil, nil
}

// emptyRequest delegates to http.NewRequest without any extra setup.
func emptyRequest(method string, urlStr string, body io.Reader) (*http.Request, error) {
	return http.NewRequest(method, urlStr, body)
}

// fakeHTTPError simulates a net.Error-style timeout failure.
type fakeHTTPError struct {
	err     string
	timeout bool
}

func (e *fakeHTTPError) Error() string   { return e.err }
func (e *fakeHTTPError) Timeout() bool   { return e.timeout }
func (e *fakeHTTPError) Temporary() bool { return true }

// fakeResponseBody serves its payload once, then EOF, alternating on reuse.
type fakeResponseBody struct {
	body []byte
	cnt  int
}

func (b *fakeResponseBody) Read(p []byte) (n int, err error) {
	if b.cnt == 0 {
		copy(p, b.body)
		b.cnt = 1
		return len(b.body), nil
	}
	b.cnt = 0
	return 0, io.EOF
}

func (b *fakeResponseBody) Close() error {
	return nil
}

// fakeHTTPClient is a scripted clientInterface: it fails cnt times
// (with statusCode or a timeout) before optionally succeeding, and can
// verify per-attempt query parameters.
type fakeHTTPClient struct {
	t                   *testing.T                // for assertions
	cnt                 int                       // number of retry
	success             bool                      // return success after retry in cnt times
	timeout             bool                      // timeout
	body                []byte                    // return body
	reqBody             []byte                    // last request body
	statusCode          int                       // status code
	retryNumber         int                       // consecutive number of retries
	expectedQueryParams map[int]map[string]string // expected query params per each retry (0-based)
}

func (c *fakeHTTPClient) Do(req *http.Request) (*http.Response, error) {
	defer func() {
		c.retryNumber++
	}()
	if req != nil {
		buf := new(bytes.Buffer)
		_, err :=
buf.ReadFrom(req.Body)
		assertNilF(c.t, err)
		c.reqBody = buf.Bytes()
	}
	// verify the retry telemetry query params scripted for this attempt
	if len(c.expectedQueryParams) > 0 {
		expectedQueryParams, ok := c.expectedQueryParams[c.retryNumber]
		if ok {
			for queryParamName, expectedValue := range expectedQueryParams {
				actualValue := req.URL.Query().Get(queryParamName)
				if actualValue != expectedValue {
					c.t.Fatalf("expected query param %v to be %v, got %v", queryParamName, expectedValue, actualValue)
				}
			}
		}
	}
	c.cnt--
	if c.cnt < 0 {
		c.cnt = 0
	}
	logger.Infof("fakeHTTPClient.cnt: %v", c.cnt)
	var retcode int
	if c.success && c.cnt == 0 {
		retcode = 200
	} else {
		if c.timeout {
			// simulate timeout
			time.Sleep(time.Second * 1)
			return nil, &fakeHTTPError{
				err:     "Whatever reason (Client.Timeout exceeded while awaiting headers)",
				timeout: true,
			}
		}
		if c.statusCode != 0 {
			retcode = c.statusCode
		} else {
			retcode = 0
		}
	}
	ret := &http.Response{
		StatusCode: retcode,
		Body:       &fakeResponseBody{body: c.body},
	}
	return ret, nil
}

// TestRequestGUID checks request_guid replacement for URLs with and
// without an existing request_guid.
func TestRequestGUID(t *testing.T) {
	var ridReplacer requestGUIDReplacer
	var testURL *url.URL
	var actualURL *url.URL
	retryTime := 4
	// empty url
	testURL = &url.URL{}
	ridReplacer = newRequestGUIDReplace(testURL)
	for range retryTime {
		actualURL = ridReplacer.replace()
		if actualURL.String() != "" {
			t.Fatalf("empty url not replaced by an empty one, got %s", actualURL)
		}
	}
	// url with on retry id
	testURL = &url.URL{
		Path: "/" + requestIDKey + "=123-1923-9?param2=value",
	}
	ridReplacer = newRequestGUIDReplace(testURL)
	for range retryTime {
		actualURL = ridReplacer.replace()
		if actualURL != testURL {
			t.Fatalf("url without retry id not replaced by origin one, got %s", actualURL)
		}
	}
	// url with retry id
	// With both prefix and suffix
	prefix := "/" + requestIDKey + "=123-1923-9?" + requestGUIDKey + "="
	suffix := "?param2=value"
	testURL = &url.URL{
		Path: prefix + "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + suffix,
	}
	ridReplacer = newRequestGUIDReplace(testURL)
	for range retryTime {
		actualURL = ridReplacer.replace()
		if (!strings.HasPrefix(actualURL.Path, prefix)) ||
			(!strings.HasSuffix(actualURL.Path, suffix)) ||
			len(testURL.Path) != len(actualURL.Path) {
			t.Fatalf("Retry url not replaced correctedly: \n origin: %s \n result: %s", testURL, actualURL)
		}
	}
	// With no suffix
	prefix = "/" + requestIDKey + "=123-1923-9?" + requestGUIDKey + "="
	suffix = ""
	testURL = &url.URL{
		Path: prefix + "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + suffix,
	}
	ridReplacer = newRequestGUIDReplace(testURL)
	for range retryTime {
		actualURL = ridReplacer.replace()
		if (!strings.HasPrefix(actualURL.Path, prefix)) ||
			(!strings.HasSuffix(actualURL.Path, suffix)) ||
			len(testURL.Path) != len(actualURL.Path) {
			t.Fatalf("Retry url not replaced correctedly: \n origin: %s \n result: %s", testURL, actualURL)
		}
	}
	// With no prefix
	prefix = requestGUIDKey + "="
	suffix = "?param2=value"
	testURL = &url.URL{
		Path: prefix + "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + suffix,
	}
	ridReplacer = newRequestGUIDReplace(testURL)
	for range retryTime {
		actualURL = ridReplacer.replace()
		if (!strings.HasPrefix(actualURL.Path, prefix)) ||
			(!strings.HasSuffix(actualURL.Path, suffix)) ||
			len(testURL.Path) != len(actualURL.Path) {
			t.Fatalf("Retry url not replaced correctedly: \n origin: %s \n result: %s", testURL, actualURL)
		}
	}
}

// TestRetryQuerySuccess verifies retryCount/retryReason/clientStartTime are
// attached per attempt when retry reason is enabled.
func TestRetryQuerySuccess(t *testing.T) {
	logger.Info("Retry N times and Success")
	client := &fakeHTTPClient{
		cnt:        3,
		success:    true,
		statusCode: 429,
		t:          t,
		expectedQueryParams: map[int]map[string]string{
			0: {
				"retryCount":      "",
				"retryReason":     "",
				"clientStartTime": "",
			},
			1: {
				"retryCount":      "1",
				"retryReason":     "429",
				"clientStartTime": "123456",
			},
			2: {
				"retryCount":      "2",
				"retryReason":     "429",
				"clientStartTime": "123456",
			},
		},
	}
	urlPtr, err := url.Parse("https://fakeaccountretrysuccess.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey + "=testid")
	assertNilF(t, err, "failed to parse the test URL")
	_, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, constTimeProvider(123456), &Config{IncludeRetryReason: ConfigBoolTrue}).doPost().setBody([]byte{0}).execute()
	assertNilF(t, err, "failed to run retry")
	var values url.Values
	values, err = url.ParseQuery(urlPtr.RawQuery)
	assertNilF(t, err, "failed to parse the test URL")
	retry, err := strconv.Atoi(values.Get(retryCountKey))
	if err != nil {
		t.Fatalf("failed to get retry counter: %v", err)
	}
	if retry < 2 {
		t.Fatalf("not enough retry counter: %v", retry)
	}
}

// TestRetryQuerySuccessWithRetryReasonDisabled verifies retryReason stays
// empty when IncludeRetryReason is off.
func TestRetryQuerySuccessWithRetryReasonDisabled(t *testing.T) {
	logger.Info("Retry N times and Success")
	client := &fakeHTTPClient{
		cnt:        3,
		success:    true,
		statusCode: 429,
		t:          t,
		expectedQueryParams: map[int]map[string]string{
			0: {
				"retryCount":      "",
				"retryReason":     "",
				"clientStartTime": "",
			},
			1: {
				"retryCount":      "1",
				"retryReason":     "",
				"clientStartTime": "123456",
			},
			2: {
				"retryCount":      "2",
				"retryReason":     "",
				"clientStartTime": "123456",
			},
		},
	}
	urlPtr, err := url.Parse("https://fakeaccountretrysuccess.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey + "=testid")
	assertNilF(t, err, "failed to parse the test URL")
	_, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, constTimeProvider(123456), &Config{IncludeRetryReason: ConfigBoolFalse}).doPost().setBody([]byte{0}).execute()
	assertNilF(t, err, "failed to run retry")
	var values url.Values
	values, err = url.ParseQuery(urlPtr.RawQuery)
	assertNilF(t, err, "failed to parse the test URL")
	retry, err := strconv.Atoi(values.Get(retryCountKey))
	if err != nil {
		t.Fatalf("failed to get retry counter: %v", err)
	}
	if retry < 2 {
		t.Fatalf("not enough retry counter: %v", retry)
	}
}

// TestRetryQuerySuccessWithTimeout verifies retryReason is 0 when the
// failures are client timeouts rather than HTTP statuses.
func TestRetryQuerySuccessWithTimeout(t *testing.T) {
	logger.Info("Retry N times and Success")
	client := &fakeHTTPClient{
		cnt:     3,
		success: true,
		timeout: true,
		t:       t,
		expectedQueryParams: map[int]map[string]string{
			0: {
				"retryCount":  "",
				"retryReason": "",
			},
			1: {
				"retryCount":  "1",
				"retryReason": "0",
			},
			2: {
				"retryCount":  "2",
				"retryReason": "0",
			},
		},
	}
	urlPtr, err := url.Parse("https://fakeaccountretrysuccess.snowflakecomputing.com:443/queries/v1/query-request?"
+ requestIDKey + "=testid")
	assertNilF(t, err, "failed to parse the test URL")
	_, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, constTimeProvider(123456), nil).doPost().setBody([]byte{0}).execute()
	assertNilF(t, err, "failed to run retry")
	var values url.Values
	values, err = url.ParseQuery(urlPtr.RawQuery)
	assertNilF(t, err, "failed to parse the test URL")
	retry, err := strconv.Atoi(values.Get(retryCountKey))
	if err != nil {
		t.Fatalf("failed to get retry counter: %v", err)
	}
	if retry < 2 {
		t.Fatalf("not enough retry counter: %v", retry)
	}
}

// TestRetryQueryFailWithTimeout verifies the retry loop gives up once the
// total timeout budget is exhausted.
func TestRetryQueryFailWithTimeout(t *testing.T) {
	logger.Info("Retry N times until there is a timeout and Fail")
	client := &fakeHTTPClient{
		statusCode: http.StatusTooManyRequests,
		success:    false,
		t:          t,
	}
	urlPtr, err := url.Parse("https://fakeaccountretryfail.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey)
	assertNilF(t, err, "failed to parse the test URL")
	_, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 20*time.Second, 100, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute()
	assertNotNilF(t, err, "should fail to run retry")
	var values url.Values
	values, err = url.ParseQuery(urlPtr.RawQuery)
	assertNilF(t, err, fmt.Sprintf("failed to parse the URL: %v", err))
	retry, err := strconv.Atoi(values.Get(retryCountKey))
	assertNilF(t, err, fmt.Sprintf("failed to get retry counter: %v", err))
	if retry < 2 {
		t.Fatalf("not enough retries: %v", retry)
	}
}

// TestRetryQueryFailWithMaxRetryCount verifies the retry loop gives up when
// maxRetryCount is exceeded, with and without a wall-clock timeout.
func TestRetryQueryFailWithMaxRetryCount(t *testing.T) {
	tcs := []struct {
		name    string
		timeout time.Duration
	}{
		{
			name:    "with timeout",
			timeout: 15 * time.Hour,
		},
		{
			name:    "without timeout",
			timeout: 0,
		},
	}
	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			maxRetryCount := 3
			logger.Info("Retry 3 times until retry reaches MaxRetryCount and Fail")
			client := &fakeHTTPClient{
				statusCode: http.StatusTooManyRequests,
				success:    false,
				t:          t,
			}
			urlPtr, err := url.Parse("https://fakeaccountretryfail.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey)
			assertNilF(t, err, "failed to parse the test URL")
			_, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), tc.timeout, maxRetryCount, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute()
			assertNotNilF(t, err, "should fail to run retry")
			var values url.Values
			values, err = url.ParseQuery(urlPtr.RawQuery)
			if err != nil {
				t.Fatalf("failed to parse the URL: %v", err)
			}
			retryCount, err := strconv.Atoi(values.Get(retryCountKey))
			if err != nil {
				t.Fatalf("failed to get retry counter: %v", err)
			}
			if retryCount < 3 {
				t.Fatalf("not enough retries: %v; expected %v", retryCount, maxRetryCount)
			}
		})
	}
}

// TestRetryLoginRequest verifies login requests never carry retryCount /
// retryReason telemetry params.
func TestRetryLoginRequest(t *testing.T) {
	logger.Info("Retry N times for timeouts and Success")
	client := &fakeHTTPClient{
		cnt:     3,
		success: true,
		timeout: true,
		t:       t,
		expectedQueryParams: map[int]map[string]string{
			0: {
				"retryCount":  "",
				"retryReason": "",
			},
			1: {
				"retryCount":  "",
				"retryReason": "",
			},
			2: {
				"retryCount":  "",
				"retryReason": "",
			},
		},
	}
	urlPtr, err := url.Parse("https://fakeaccountretrylogin.snowflakecomputing.com:443/login-request?request_id=testid")
	assertNilF(t, err, "failed to parse the test URL")
	_, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute()
	assertNilF(t, err, "failed to run retry")
	var values url.Values
	values, err = url.ParseQuery(urlPtr.RawQuery)
	assertNilF(t, err, "failed to parse the test URL")
	if values.Get(retryCountKey) != "" {
		t.Fatalf("no retry counter should be attached: %v", retryCountKey)
	}
	logger.Info("Retry N times for timeouts and Fail")
	client = &fakeHTTPClient{
		success: false,
		timeout: true,
		t:       t,
	}
	_, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 5*time.Second, 3, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute()
	assertNotNilF(t, err, "should fail to run retry")
	values, err = url.ParseQuery(urlPtr.RawQuery)
	if err != nil {
		t.Fatalf("failed to parse the URL: %v", err)
	}
	if values.Get(retryCountKey) != "" {
		t.Fatalf("no retry counter should be attached: %v", retryCountKey)
	}
}

// TestRetryAuthLoginRequest verifies the bodyCreator is re-invoked on every
// attempt so each retry sends a fresh body.
func TestRetryAuthLoginRequest(t *testing.T) {
	logger.Info("Retry N times always with newer body")
	client := &fakeHTTPClient{
		cnt:     3,
		success: true,
		timeout: true,
		t:       t,
	}
	urlPtr, err := url.Parse("https://fakeaccountretrylogin.snowflakecomputing.com:443/login-request?request_id=testid")
	assertNilF(t, err, "failed to parse the test URL")
	execID := 0
	bodyCreator := func() ([]byte, error) {
		execID++
		return fmt.Appendf(nil, "execID: %d", execID), nil
	}
	_, err = newRetryHTTP(context.Background(), client, http.NewRequest, urlPtr, make(map[string]string), 60*time.Second, 3, defaultTimeProvider, nil).doPost().setBodyCreator(bodyCreator).execute()
	assertNilF(t, err, "failed to run retry")
	if lastReqBody := string(client.reqBody); lastReqBody != "execID: 3" {
		t.Fatalf("body should be updated on each request, expected: execID: 3, last body: %v", lastReqBody)
	}
}

// TestLoginRetry429 verifies 429 responses on login requests are retried
// without attaching retry telemetry params.
func TestLoginRetry429(t *testing.T) {
	client := &fakeHTTPClient{
		cnt:        3,
		success:    true,
		statusCode: 429,
		t:          t,
	}
	urlPtr, err := url.Parse("https://fakeaccountretrylogin.snowflakecomputing.com:443/login-request?request_id=testid")
	assertNilF(t, err, "failed to parse the test URL")
	_, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX
	assertNilF(t, err, "failed to run retry")
	var values url.Values
	values, err = url.ParseQuery(urlPtr.RawQuery)
	assertNilF(t, err, fmt.Sprintf("failed to parse the URL: %v", err))
	if values.Get(retryCountKey) != "" {
		t.Fatalf("no retry counter should be attached: %v", retryCountKey)
	}
}

// TestIsRetryable table-tests isRetryableError across error/response/context
// combinations.
func TestIsRetryable(t
*testing.T) {
	deadLineCtx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
	defer cancel()
	// let the 1ns deadline expire so deadLineCtx is already cancelled
	time.Sleep(2 * time.Nanosecond)
	tcs := []struct {
		ctx      context.Context
		req      *http.Request
		res      *http.Response
		err      error
		expected bool
	}{
		{
			ctx:      context.Background(),
			req:      nil,
			res:      nil,
			err:      nil,
			expected: false,
		},
		{
			ctx:      context.Background(),
			req:      nil,
			res:      &http.Response{StatusCode: http.StatusBadRequest},
			err:      nil,
			expected: false,
		},
		{
			ctx:      context.Background(),
			req:      &http.Request{URL: &url.URL{Path: loginRequestPath}},
			res:      nil,
			err:      nil,
			expected: false,
		},
		{
			ctx:      context.Background(),
			req:      &http.Request{URL: &url.URL{Path: loginRequestPath}},
			res:      &http.Response{StatusCode: http.StatusNotFound},
			expected: false,
		},
		{
			ctx:      context.Background(),
			req:      &http.Request{URL: &url.URL{Path: loginRequestPath}},
			res:      nil,
			err:      &url.Error{Err: context.DeadlineExceeded},
			expected: true,
		},
		{
			ctx:      context.Background(),
			req:      &http.Request{URL: &url.URL{Path: loginRequestPath}},
			res:      nil,
			err:      errors.ErrUnknownError(),
			expected: true,
		},
		{
			ctx:      context.Background(),
			req:      &http.Request{URL: &url.URL{Path: loginRequestPath}},
			res:      &http.Response{StatusCode: http.StatusTooManyRequests},
			err:      nil,
			expected: true,
		},
		{
			ctx:      deadLineCtx,
			req:      &http.Request{URL: &url.URL{Path: loginRequestPath}},
			res:      nil,
			err:      &url.Error{Err: context.DeadlineExceeded},
			expected: false,
		},
		{
			ctx:      deadLineCtx,
			req:      &http.Request{URL: &url.URL{Path: queryRequestPath}},
			res:      nil,
			err:      &url.Error{Err: context.DeadlineExceeded},
			expected: false,
		},
	}
	for _, tc := range tcs {
		t.Run(fmt.Sprintf("req %v, resp %v", tc.req, tc.res), func(t *testing.T) {
			result, _ := isRetryableError(tc.ctx, tc.req, tc.res, tc.err)
			if result != tc.expected {
				t.Fatalf("expected %v, got %v; request: %v, response: %v", tc.expected, result, tc.req, tc.res)
			}
		})
	}
}

// TestCalculateRetryWait bounds-checks the auth-request backoff.
func TestCalculateRetryWait(t *testing.T) {
	// test for randomly selected attempt and currWaitTime values
	// minSleepTime, maxSleepTime are limit values
	tcs := []struct {
		attempt      int
		currWaitTime float64
		minSleepTime float64
		maxSleepTime float64
	}{
		{
			attempt:      1,
			currWaitTime: 3.346609,
			minSleepTime: 0.326695,
			maxSleepTime: 5.019914,
		},
		{
			attempt:      2,
			currWaitTime: 4.260357,
			minSleepTime: 1.869821,
			maxSleepTime: 6.390536,
		},
		{
			attempt:      3,
			currWaitTime: 7.857728,
			minSleepTime: 3.928864,
			maxSleepTime: 11.928864,
		},
		{
			attempt:      4,
			currWaitTime: 7.249255,
			minSleepTime: 3.624628,
			maxSleepTime: 19.624628,
		},
		{
			attempt:      5,
			currWaitTime: 23.598257,
			minSleepTime: 11.799129,
			maxSleepTime: 43.799129,
		},
		{
			attempt:      8,
			currWaitTime: 27.088613,
			minSleepTime: 13.544306,
			maxSleepTime: 269.544306,
		},
		{
			attempt:      10,
			currWaitTime: 30.879329,
			minSleepTime: 15.439664,
			maxSleepTime: 1039.439664,
		},
		{
			attempt:      12,
			currWaitTime: 39.919798,
			minSleepTime: 19.959899,
			maxSleepTime: 4115.959899,
		},
		{
			attempt:      15,
			currWaitTime: 33.750758,
			minSleepTime: 16.875379,
			maxSleepTime: 32784.875379,
		},
		{
			attempt:      20,
			currWaitTime: 32.357793,
			minSleepTime: 16.178897,
			maxSleepTime: 1048592.178897,
		},
	}
	for _, tc := range tcs {
		t.Run(fmt.Sprintf("attmept: %v", tc.attempt), func(t *testing.T) {
			result := defaultWaitAlgo.calculateWaitBeforeRetryForAuthRequest(tc.attempt, time.Duration(tc.currWaitTime*float64(time.Second)))
			assertBetweenE(t, result.Seconds(), tc.minSleepTime, tc.maxSleepTime)
		})
	}
}

// TestCalculateRetryWaitForNonAuthRequests bounds-checks the non-auth
// backoff, including the 16s cap.
func TestCalculateRetryWaitForNonAuthRequests(t *testing.T) {
	// test for randomly selected currWaitTime values
	// maxSleepTime is the limit value
	tcs := []struct {
		currWaitTime float64
		maxSleepTime float64
	}{
		{
			currWaitTime: 3.346609,
			maxSleepTime: 10.039827,
		},
		{
			currWaitTime: 4.260357,
			maxSleepTime: 12.781071,
		},
		{
			currWaitTime: 5.154231,
			maxSleepTime: 15.462693,
		},
		{
			currWaitTime: 7.249255,
			maxSleepTime: 16,
		},
		{
			currWaitTime: 23.598257,
			maxSleepTime: 16,
		},
	}
	for _, tc := range tcs {
		defaultMinSleepTime := 1
		t.Run(fmt.Sprintf("currWaitTime: %v", tc.currWaitTime), func(t *testing.T) {
			result := defaultWaitAlgo.calculateWaitBeforeRetry(time.Duration(tc.currWaitTime) * time.Second)
			assertBetweenInclusiveE(t, result.Seconds(), float64(defaultMinSleepTime), tc.maxSleepTime)
		})
	}
}

// TestRedirectRetry runs a wiremock-backed redirection retry workflow.
func TestRedirectRetry(t *testing.T) {
	wiremock.registerMappings(t, newWiremockMapping("retry/redirection_retry_workflow.json"))
	cfg := wiremock.connectionConfig()
	cfg.ClientTimeout = 3 * time.Second
	connector := NewConnector(SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	runSmokeQuery(t, db)
}

================================================
FILE: rows.go
================================================
package gosnowflake

import (
	"context"
	"database/sql/driver"
	"github.com/snowflakedb/gosnowflake/v2/internal/errors"
	"io"
	"reflect"
	"strings"
	"time"

	"github.com/apache/arrow-go/v18/arrow"
	ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"
	"github.com/snowflakedb/gosnowflake/v2/internal/query"
	"github.com/snowflakedb/gosnowflake/v2/internal/types"
)

const (
	headerSseCAlgorithm = "x-amz-server-side-encryption-customer-algorithm"
	headerSseCKey       = "x-amz-server-side-encryption-customer-key"
	headerSseCAes       = "AES256"
)

var (
	// customJSONDecoderEnabled has the chunk downloader use the custom JSON decoder to reduce memory footprint.
	customJSONDecoderEnabled = false

	maxChunkDownloaderErrorCounter = 5
)

const defaultMaxChunkDownloadWorkers = 10

const clientPrefetchThreadsKey = "client_prefetch_threads"

// SnowflakeRows provides an API for methods exposed to the clients
type SnowflakeRows interface {
	GetQueryID() string
	GetStatus() QueryStatus
	// NextResultSet switches Arrow Batches to the next result set.
	// Returns io.EOF if there are no more result sets.
NextResultSet() error
}

// snowflakeRows implements driver.Rows on top of a chunkDownloader.
type snowflakeRows struct {
	sc                  *snowflakeConn
	ChunkDownloader     chunkDownloader
	tailChunkDownloader chunkDownloader
	queryID             string
	status              QueryStatus
	err                 error
	errChannel          chan error
	location            *time.Location
	ctx                 context.Context
}

// getLocation lazily resolves and caches the session time zone.
func (rows *snowflakeRows) getLocation() *time.Location {
	if rows.location == nil && rows.sc != nil && rows.sc.cfg != nil {
		rows.location = getCurrentLocation(&rows.sc.syncParams)
	}
	return rows.location
}

type snowflakeValue any

// chunkRowType is one result row in either JSON (RowSet) or Arrow form.
type chunkRowType struct {
	RowSet   []*string
	ArrowRow []snowflakeValue
}

type rowSetType struct {
	RowType      []query.ExecResponseRowType
	JSON         [][]*string
	RowSetBase64 string
}

// chunkError pairs a download error with the index of the failing chunk.
type chunkError struct {
	Index int
	Error error
}

// Close implements driver.Rows; it waits out async execution and releases
// any raw Arrow batches held by the chunk downloader.
func (rows *snowflakeRows) Close() (err error) {
	if err := rows.waitForAsyncQueryStatus(); err != nil {
		return err
	}
	logger.WithContext(rows.sc.ctx).Debug("Rows.Close")
	if scd, ok := rows.ChunkDownloader.(*snowflakeChunkDownloader); ok {
		scd.releaseRawArrowBatches()
	}
	return nil
}

// ColumnTypeDatabaseTypeName returns the database column name.
func (rows *snowflakeRows) ColumnTypeDatabaseTypeName(index int) string { if err := rows.waitForAsyncQueryStatus(); err != nil { return err.Error() } return strings.ToUpper(rows.ChunkDownloader.getRowType()[index].Type) } // ColumnTypeLength returns the length of the column func (rows *snowflakeRows) ColumnTypeLength(index int) (length int64, ok bool) { if err := rows.waitForAsyncQueryStatus(); err != nil { return 0, false } if index < 0 || index > len(rows.ChunkDownloader.getRowType()) { return 0, false } switch rows.ChunkDownloader.getRowType()[index].Type { case "text", "variant", "object", "array", "binary": return rows.ChunkDownloader.getRowType()[index].Length, true } return 0, false } func (rows *snowflakeRows) ColumnTypeNullable(index int) (nullable, ok bool) { if err := rows.waitForAsyncQueryStatus(); err != nil { return false, false } if index < 0 || index > len(rows.ChunkDownloader.getRowType()) { return false, false } return rows.ChunkDownloader.getRowType()[index].Nullable, true } func (rows *snowflakeRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { if err := rows.waitForAsyncQueryStatus(); err != nil { return 0, 0, false } rowType := rows.ChunkDownloader.getRowType() if index < 0 || index > len(rowType) { return 0, 0, false } switch rowType[index].Type { case "fixed": return rowType[index].Precision, rowType[index].Scale, true case "time": return rowType[index].Scale, 0, true case "timestamp": return rowType[index].Scale, 0, true } return 0, 0, false } func (rows *snowflakeRows) Columns() []string { if err := rows.waitForAsyncQueryStatus(); err != nil { return make([]string, 0) } logger.WithContext(rows.ctx).Debug("Rows.Columns") ret := make([]string, len(rows.ChunkDownloader.getRowType())) for i, n := 0, len(rows.ChunkDownloader.getRowType()); i < n; i++ { ret[i] = rows.ChunkDownloader.getRowType()[i].Name } return ret } func (rows *snowflakeRows) ColumnTypeScanType(index int) reflect.Type { if err := 
rows.waitForAsyncQueryStatus(); err != nil { return nil } return snowflakeTypeToGo(rows.ctx, types.GetSnowflakeType(rows.ChunkDownloader.getRowType()[index].Type), rows.ChunkDownloader.getRowType()[index].Precision, rows.ChunkDownloader.getRowType()[index].Scale, rows.ChunkDownloader.getRowType()[index].Fields) } func (rows *snowflakeRows) GetQueryID() string { return rows.queryID } func (rows *snowflakeRows) GetStatus() QueryStatus { return rows.status } // GetArrowBatches returns raw arrow batch data for use by the arrowbatches sub-package. // Implements ia.BatchDataProvider. func (rows *snowflakeRows) GetArrowBatches() (*ia.BatchDataInfo, error) { if err := rows.waitForAsyncQueryStatus(); err != nil { return nil, err } if rows.ChunkDownloader.getQueryResultFormat() != arrowFormat { return nil, exceptionTelemetry(errors.ErrNonArrowResponseForArrowBatches(rows.queryID), rows.sc) } scd, ok := rows.ChunkDownloader.(*snowflakeChunkDownloader) if !ok { return nil, &SnowflakeError{ Number: ErrNotImplemented, Message: "chunk downloader does not support arrow batch data", } } rawBatches := scd.getRawArrowBatches() batches := make([]ia.BatchRaw, len(rawBatches)) for i, raw := range rawBatches { batch := ia.BatchRaw{ Records: raw.records, Index: i, RowCount: raw.rowCount, Location: raw.loc, } raw.records = nil if batch.Records == nil { capturedIdx := i if scd.firstBatchRaw != nil { capturedIdx = i - 1 } batch.Download = func(ctx context.Context) (*[]arrow.Record, int, error) { if err := scd.FuncDownloadHelper(ctx, scd, capturedIdx); err != nil { return nil, 0, err } actualRaw := scd.rawBatches[capturedIdx] return actualRaw.records, actualRaw.rowCount, nil } } batches[i] = batch } return &ia.BatchDataInfo{ Batches: batches, RowTypes: scd.RowSet.RowType, Allocator: scd.pool, Ctx: scd.ctx, QueryID: rows.queryID, }, nil } func (rows *snowflakeRows) Next(dest []driver.Value) (err error) { if err = rows.waitForAsyncQueryStatus(); err != nil { return err } row, err := 
rows.ChunkDownloader.next() if err != nil { // includes io.EOF if err == io.EOF { rows.ChunkDownloader.reset() } return err } if rows.ChunkDownloader.getQueryResultFormat() == arrowFormat { for i, n := 0, len(row.ArrowRow); i < n; i++ { dest[i] = row.ArrowRow[i] } } else { for i, n := 0, len(row.RowSet); i < n; i++ { // could move to chunk downloader so that each go routine // can convert data err = stringToValue(rows.ctx, &dest[i], rows.ChunkDownloader.getRowType()[i], row.RowSet[i], rows.getLocation(), &rows.sc.syncParams) if err != nil { return err } } } return err } func (rows *snowflakeRows) HasNextResultSet() bool { if err := rows.waitForAsyncQueryStatus(); err != nil { return false } hasNextResultSet := rows.ChunkDownloader.getNextChunkDownloader() != nil logger.WithContext(rows.ctx).Debugf("[queryId: %v] Rows.HasNextResultSet: %v", rows.queryID, hasNextResultSet) return hasNextResultSet } func (rows *snowflakeRows) NextResultSet() error { logger.WithContext(rows.ctx).Debugf("[queryId: %v] Rows.NextResultSet", rows.queryID) if err := rows.waitForAsyncQueryStatus(); err != nil { return err } if rows.ChunkDownloader.getNextChunkDownloader() == nil { return io.EOF } rows.ChunkDownloader = rows.ChunkDownloader.getNextChunkDownloader() if err := rows.ChunkDownloader.start(); err != nil { return err } return nil } func (rows *snowflakeRows) waitForAsyncQueryStatus() error { // if async query, block until query is finished switch rows.status { case QueryStatusInProgress: err := <-rows.errChannel rows.status = QueryStatusComplete if err != nil { rows.status = QueryFailed rows.err = err return rows.err } case QueryFailed: return rows.err default: return nil } return nil } func (rows *snowflakeRows) addDownloader(newDL chunkDownloader) { if rows.ChunkDownloader == nil { rows.ChunkDownloader = newDL rows.tailChunkDownloader = newDL return } rows.tailChunkDownloader.setNextChunkDownloader(newDL) rows.tailChunkDownloader = newDL } 
================================================ FILE: rows_test.go ================================================ package gosnowflake import ( "context" "database/sql" "database/sql/driver" "fmt" sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config" "github.com/snowflakedb/gosnowflake/v2/internal/query" "io" "net/http" "sync" "testing" "time" ) type RowsExtended struct { rows *sql.Rows closeChan *chan bool t *testing.T } func (rs *RowsExtended) Close() error { *rs.closeChan <- true close(*rs.closeChan) return rs.rows.Close() } func (rs *RowsExtended) ColumnTypes() ([]*sql.ColumnType, error) { return rs.rows.ColumnTypes() } func (rs *RowsExtended) Columns() ([]string, error) { return rs.rows.Columns() } func (rs *RowsExtended) Err() error { return rs.rows.Err() } func (rs *RowsExtended) Next() bool { return rs.rows.Next() } func (rs *RowsExtended) mustNext() { assertTrueF(rs.t, rs.rows.Next()) } func (rs *RowsExtended) NextResultSet() bool { return rs.rows.NextResultSet() } func (rs *RowsExtended) Scan(dest ...any) error { return rs.rows.Scan(dest...) } func (rs *RowsExtended) mustScan(dest ...any) { err := rs.rows.Scan(dest...) assertNilF(rs.t, err) } // test variables var ( rowsInChunk = 123 ) // Special cases where rows are already closed func TestRowsClose(t *testing.T) { runDBTest(t, func(dbt *DBTest) { rows, err := dbt.query("SELECT 1") if err != nil { dbt.Fatal(err) } if err = rows.Close(); err != nil { dbt.Fatal(err) } if rows.Next() { dbt.Fatal("unexpected row after rows.Close()") } if err = rows.Err(); err != nil { dbt.Fatal(err) } }) } func TestResultNoRows(t *testing.T) { // DDL runDBTest(t, func(dbt *DBTest) { row, err := dbt.exec("CREATE OR REPLACE TABLE test(c1 int)") if err != nil { t.Fatalf("failed to execute DDL. 
err: %v", err) } if _, err = row.RowsAffected(); err == nil { t.Fatal("should have failed to get RowsAffected") } if _, err = row.LastInsertId(); err == nil { t.Fatal("should have failed to get LastInsertID") } }) } func TestRowsWithoutChunkDownloader(t *testing.T) { sts1 := "1" sts2 := "Test1" var i int cc := make([][]*string, 0) for i = 0; i < 10; i++ { cc = append(cc, []*string{&sts1, &sts2}) } rt := []query.ExecResponseRowType{ {Name: "c1", ByteLength: 10, Length: 10, Type: "FIXED", Scale: 0, Nullable: true}, {Name: "c2", ByteLength: 100000, Length: 100000, Type: "TEXT", Scale: 0, Nullable: false}, } cm := []query.ExecResponseChunk{} rows := new(snowflakeRows) sc := &snowflakeConn{ cfg: &Config{}, } rows.sc = sc rows.ctx = context.Background() rows.ChunkDownloader = &snowflakeChunkDownloader{ sc: sc, ctx: context.Background(), Total: int64(len(cc)), ChunkMetas: cm, TotalRowIndex: int64(-1), Qrmk: "", FuncDownload: nil, FuncDownloadHelper: nil, RowSet: rowSetType{RowType: rt, JSON: cc}, QueryResultFormat: "json", } err := rows.ChunkDownloader.start() assertNilF(t, err) dest := make([]driver.Value, 2) for i = 0; i < len(cc); i++ { if err := rows.Next(dest); err != nil { t.Fatalf("failed to get value. err: %v", err) } if dest[0] != sts1 { t.Fatalf("failed to get value. expected: %v, got: %v", sts1, dest[0]) } if dest[1] != sts2 { t.Fatalf("failed to get value. expected: %v, got: %v", sts2, dest[1]) } } if err := rows.Next(dest); err != io.EOF { t.Fatalf("failed to finish getting data. 
err: %v", err) } logger.Infof("dest: %v", dest) } func downloadChunkTest(ctx context.Context, scd *snowflakeChunkDownloader, idx int) { d := make([][]*string, 0) for i := range rowsInChunk { v1 := fmt.Sprintf("%v", idx*1000+i) v2 := fmt.Sprintf("testchunk%v", idx*1000+i) d = append(d, []*string{&v1, &v2}) } scd.ChunksMutex.Lock() scd.Chunks[idx] = make([]chunkRowType, len(d)) populateJSONRowSet(scd.Chunks[idx], d) scd.DoneDownloadCond.Broadcast() scd.ChunksMutex.Unlock() } func TestRowsWithChunkDownloader(t *testing.T) { numChunks := 12 var i int cc := make([][]*string, 0) for i = 0; i < 100; i++ { v1 := fmt.Sprintf("%v", i) v2 := fmt.Sprintf("Test%v", i) cc = append(cc, []*string{&v1, &v2}) } rt := []query.ExecResponseRowType{ {Name: "c1", ByteLength: 10, Length: 10, Type: "FIXED", Scale: 0, Nullable: true}, {Name: "c2", ByteLength: 100000, Length: 100000, Type: "TEXT", Scale: 0, Nullable: false}, } cm := make([]query.ExecResponseChunk, 0) for i = range numChunks { cm = append(cm, query.ExecResponseChunk{URL: fmt.Sprintf("dummyURL%v", i+1), RowCount: rowsInChunk}) } rows := new(snowflakeRows) two := "2" params := map[string]*string{ clientPrefetchThreadsKey: &two, } sc := &snowflakeConn{ cfg: &Config{}, syncParams: syncParams{params: params}, } rows.sc = sc rows.ctx = context.Background() rows.ChunkDownloader = &snowflakeChunkDownloader{ sc: sc, ctx: context.Background(), Total: int64(len(cc) + numChunks*rowsInChunk), ChunkMetas: cm, TotalRowIndex: int64(-1), Qrmk: "HAHAHA", FuncDownload: downloadChunkTest, RowSet: rowSetType{RowType: rt, JSON: cc}, } assertNilF(t, rows.ChunkDownloader.start()) cnt := 0 dest := make([]driver.Value, 2) var err error for err != io.EOF { err := rows.Next(dest) if err == io.EOF { break } if err != nil { t.Fatalf("failed to get value. err: %v", err) } cnt++ } if cnt != len(cc)+numChunks*rowsInChunk { t.Fatalf("failed to get all results. 
expected:%v, got:%v", len(cc)+numChunks*rowsInChunk, cnt) } logger.Infof("dest: %v", dest) } func downloadChunkTestError(ctx context.Context, scd *snowflakeChunkDownloader, idx int) { // fail to download 6th and 10th chunk, and retry up to N times and success // NOTE: zero based index scd.ChunksMutex.Lock() defer scd.ChunksMutex.Unlock() if (idx == 6 || idx == 10) && scd.ChunksErrorCounter < maxChunkDownloaderErrorCounter { scd.ChunksError <- &chunkError{ Index: idx, Error: fmt.Errorf( "dummy error. idx: %v, errCnt: %v", idx+1, scd.ChunksErrorCounter)} scd.DoneDownloadCond.Broadcast() return } d := make([][]*string, 0) for i := range rowsInChunk { v1 := fmt.Sprintf("%v", idx*1000+i) v2 := fmt.Sprintf("testchunk%v", idx*1000+i) d = append(d, []*string{&v1, &v2}) } scd.Chunks[idx] = make([]chunkRowType, len(d)) populateJSONRowSet(scd.Chunks[idx], d) scd.DoneDownloadCond.Broadcast() } func TestRowsWithChunkDownloaderError(t *testing.T) { numChunks := 12 var i int cc := make([][]*string, 0) for i = 0; i < 100; i++ { v1 := fmt.Sprintf("%v", i) v2 := fmt.Sprintf("Test%v", i) cc = append(cc, []*string{&v1, &v2}) } rt := []query.ExecResponseRowType{ {Name: "c1", ByteLength: 10, Length: 10, Type: "FIXED", Scale: 0, Nullable: true}, {Name: "c2", ByteLength: 100000, Length: 100000, Type: "TEXT", Scale: 0, Nullable: false}, } cm := make([]query.ExecResponseChunk, 0) for i = range numChunks { cm = append(cm, query.ExecResponseChunk{URL: fmt.Sprintf("dummyURL%v", i+1), RowCount: rowsInChunk}) } rows := new(snowflakeRows) three := "3" params := map[string]*string{ clientPrefetchThreadsKey: &three, } sc := &snowflakeConn{ cfg: &Config{}, syncParams: syncParams{params: params}, } rows.sc = sc rows.ctx = context.Background() rows.ChunkDownloader = &snowflakeChunkDownloader{ sc: sc, ctx: context.Background(), Total: int64(len(cc) + numChunks*rowsInChunk), ChunkMetas: cm, TotalRowIndex: int64(-1), Qrmk: "HOHOHO", FuncDownload: downloadChunkTestError, RowSet: rowSetType{RowType: rt, 
JSON: cc}, } assertNilF(t, rows.ChunkDownloader.start()) cnt := 0 dest := make([]driver.Value, 2) var err error for err != io.EOF { err := rows.Next(dest) if err == io.EOF { break } if err != nil { t.Fatalf("failed to get value. err: %v", err) } // fmt.Printf("data: %v\n", dest) cnt++ } if cnt != len(cc)+numChunks*rowsInChunk { t.Fatalf("failed to get all results. expected:%v, got:%v", len(cc)+numChunks*rowsInChunk, cnt) } logger.Infof("dest: %v", dest) } func downloadChunkTestErrorFail(ctx context.Context, scd *snowflakeChunkDownloader, idx int) { // fail to download 6th and 10th chunk, and retry up to N times and fail // NOTE: zero based index scd.ChunksMutex.Lock() defer scd.ChunksMutex.Unlock() if idx == 6 && scd.ChunksErrorCounter <= maxChunkDownloaderErrorCounter { scd.ChunksError <- &chunkError{ Index: idx, Error: fmt.Errorf( "dummy error. idx: %v, errCnt: %v", idx+1, scd.ChunksErrorCounter)} scd.DoneDownloadCond.Broadcast() return } d := make([][]*string, 0) for i := range rowsInChunk { v1 := fmt.Sprintf("%v", idx*1000+i) v2 := fmt.Sprintf("testchunk%v", idx*1000+i) d = append(d, []*string{&v1, &v2}) } scd.Chunks[idx] = make([]chunkRowType, len(d)) populateJSONRowSet(scd.Chunks[idx], d) scd.DoneDownloadCond.Broadcast() } func TestRowsWithChunkDownloaderErrorFail(t *testing.T) { numChunks := 12 // changed the workers logger.Info("START TESTS") var i int cc := make([][]*string, 0) for i = 0; i < 100; i++ { v1 := fmt.Sprintf("%v", i) v2 := fmt.Sprintf("Test%v", i) cc = append(cc, []*string{&v1, &v2}) } rt := []query.ExecResponseRowType{ {Name: "c1", ByteLength: 10, Length: 10, Type: "FIXED", Scale: 0, Nullable: true}, {Name: "c2", ByteLength: 100000, Length: 100000, Type: "TEXT", Scale: 0, Nullable: false}, } cm := make([]query.ExecResponseChunk, 0) for i = range numChunks { cm = append(cm, query.ExecResponseChunk{URL: fmt.Sprintf("dummyURL%v", i+1), RowCount: rowsInChunk}) } rows := new(snowflakeRows) sc := &snowflakeConn{ cfg: &Config{}, } rows.sc = sc 
rows.ctx = context.Background() rows.ChunkDownloader = &snowflakeChunkDownloader{ sc: sc, ctx: context.Background(), Total: int64(len(cc) + numChunks*rowsInChunk), ChunkMetas: cm, TotalRowIndex: int64(-1), Qrmk: "HOHOHO", FuncDownload: downloadChunkTestErrorFail, RowSet: rowSetType{RowType: rt, JSON: cc}, } assertNilF(t, rows.ChunkDownloader.start()) cnt := 0 dest := make([]driver.Value, 2) var err error for err != io.EOF { err := rows.Next(dest) if err == io.EOF { break } if err != nil { logger.Infof( "failure was expected by the number of rows is wrong. expected: %v, got: %v", 715, cnt) break } cnt++ } } func getChunkTestInvalidResponseBody(_ context.Context, _ *snowflakeConn, _ string, _ map[string]string, _ time.Duration) ( *http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } func TestDownloadChunkInvalidResponseBody(t *testing.T) { numChunks := 2 cm := make([]query.ExecResponseChunk, 0) for i := range numChunks { cm = append(cm, query.ExecResponseChunk{URL: fmt.Sprintf( "dummyURL%v", i+1), RowCount: rowsInChunk}) } scd := &snowflakeChunkDownloader{ sc: &snowflakeConn{ rest: &snowflakeRestful{RequestTimeout: sfconfig.DefaultRequestTimeout}, }, ctx: context.Background(), ChunkMetas: cm, TotalRowIndex: int64(-1), Qrmk: "HOHOHO", FuncDownload: downloadChunk, FuncDownloadHelper: downloadChunkHelper, FuncGet: getChunkTestInvalidResponseBody, } scd.ChunksMutex = &sync.Mutex{} scd.DoneDownloadCond = sync.NewCond(scd.ChunksMutex) scd.Chunks = make(map[int][]chunkRowType) scd.ChunksError = make(chan *chunkError, 1) scd.FuncDownload(scd.ctx, scd, 1) select { case errc := <-scd.ChunksError: if errc.Index != 1 { t.Fatalf("the error should have caused with chunk idx: %v", errc.Index) } default: t.Fatal("should have caused an error and queued in scd.ChunksError") } } func getChunkTestErrorStatus(_ context.Context, _ *snowflakeConn, _ string, _ map[string]string, _ time.Duration) ( 
*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadGateway, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } func TestDownloadChunkErrorStatus(t *testing.T) { numChunks := 2 cm := make([]query.ExecResponseChunk, 0) for i := range numChunks { cm = append(cm, query.ExecResponseChunk{URL: fmt.Sprintf( "dummyURL%v", i+1), RowCount: rowsInChunk}) } scd := &snowflakeChunkDownloader{ sc: &snowflakeConn{ rest: &snowflakeRestful{RequestTimeout: sfconfig.DefaultRequestTimeout}, }, ctx: context.Background(), ChunkMetas: cm, TotalRowIndex: int64(-1), Qrmk: "HOHOHO", FuncDownload: downloadChunk, FuncDownloadHelper: downloadChunkHelper, FuncGet: getChunkTestErrorStatus, } scd.ChunksMutex = &sync.Mutex{} scd.DoneDownloadCond = sync.NewCond(scd.ChunksMutex) scd.Chunks = make(map[int][]chunkRowType) scd.ChunksError = make(chan *chunkError, 1) scd.FuncDownload(scd.ctx, scd, 1) select { case errc := <-scd.ChunksError: if errc.Index != 1 { t.Fatalf("the error should have caused with chunk idx: %v", errc.Index) } serr, ok := errc.Error.(*SnowflakeError) if !ok { t.Fatalf("should have been snowflake error. err: %v", errc.Error) } if serr.Number != ErrFailedToGetChunk { t.Fatalf("message error code is not correct. 
msg: %v", serr.Number) } default: t.Fatal("should have caused an error and queued in scd.ChunksError") } } func TestLocationChangesAfterAlterSession(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.mustExec("CREATE OR REPLACE TABLE location_timestamp_ltz (val timestamp_ltz)") defer dbt.mustExec("DROP TABLE location_timestamp_ltz") dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") dbt.mustExec("INSERT INTO location_timestamp_ltz VALUES('2023-08-09 10:00:00')") rows1 := dbt.mustQuery("SELECT * FROM location_timestamp_ltz") defer func() { assertNilF(t, rows1.Close()) }() if !rows1.Next() { t.Fatalf("cannot read a record") } var t1 time.Time assertNilF(t, rows1.Scan(&t1)) if t1.Location().String() != "Europe/Warsaw" { t.Fatalf("should return time in Warsaw timezone") } dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Pacific/Honolulu'") rows2 := dbt.mustQuery("SELECT * FROM location_timestamp_ltz") defer func() { assertNilF(t, rows2.Close()) }() if !rows2.Next() { t.Fatalf("cannot read a record") } var t2 time.Time assertNilF(t, rows2.Scan(&t2)) if t2.Location().String() != "Pacific/Honolulu" { t.Fatalf("should return time in Honolulu timezone") } }) } ================================================ FILE: s3_storage_client.go ================================================ package gosnowflake import ( "bytes" "cmp" "context" "errors" "fmt" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/smithy-go" "github.com/aws/smithy-go/logging" "io" "net/http" "os" "strings" ) const ( sfcDigest = "sfc-digest" amzMatdesc = "x-amz-matdesc" amzKey = "x-amz-key" amzIv = "x-amz-iv" notFound = "NotFound" expiredToken = "ExpiredToken" errNoWsaeconnaborted = "10053" ) type snowflakeS3Client struct { cfg *Config telemetry *snowflakeTelemetry } type s3Location struct { bucketName string s3Path string } // S3LoggingMode allows to configure which 
logs should be included. // By default no logs are included. // See https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws#ClientLogMode for allowed values. // Deprecated: will be moved to DSN/Config in a future release. var S3LoggingMode aws.ClientLogMode func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool, telemetry *snowflakeTelemetry) (cloudClient, error) { stageCredentials := info.Creds s3Logger := logging.LoggerFunc(s3LoggingFunc) endPoint := getS3CustomEndpoint(info) transport, err := newTransportFactory(util.cfg, telemetry).createTransport(transportConfigFor(transportTypeCloudProvider)) if err != nil { return nil, err } return s3.New(s3.Options{ Region: info.Region, Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider( stageCredentials.AwsKeyID, stageCredentials.AwsSecretKey, stageCredentials.AwsToken)), BaseEndpoint: endPoint, UseAccelerate: useAccelerateEndpoint, HTTPClient: &http.Client{ Transport: transport, }, ClientLogMode: S3LoggingMode, Logger: s3Logger, }), nil } // to be used with S3 transferAccelerateConfigWithUtil func (util *snowflakeS3Client) createClientWithConfig(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config, telemetry *snowflakeTelemetry) (cloudClient, error) { // copy snowflakeFileTransferAgent's config onto the cloud client so we could decide which Transport to use util.cfg = cfg util.telemetry = telemetry return util.createClient(info, useAccelerateEndpoint, telemetry) } func getS3CustomEndpoint(info *execResponseStageInfo) *string { var endPoint *string isRegionalURLEnabled := info.UseRegionalURL || info.UseS3RegionalURL if info.EndPoint != "" { tmp := fmt.Sprintf("https://%s", info.EndPoint) endPoint = &tmp } else if info.Region != "" && isRegionalURLEnabled { domainSuffixForRegionalURL := "amazonaws.com" if strings.HasPrefix(strings.ToLower(info.Region), "cn-") { domainSuffixForRegionalURL = "amazonaws.com.cn" } tmp := 
fmt.Sprintf("https://s3.%s.%s", info.Region, domainSuffixForRegionalURL) endPoint = &tmp } return endPoint } func s3LoggingFunc(classification logging.Classification, format string, v ...any) { switch classification { case logging.Debug: logger.WithField("logger", "S3").Debugf(format, v...) case logging.Warn: logger.WithField("logger", "S3").Warnf(format, v...) } } type s3HeaderAPI interface { HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) } // cloudUtil implementation func (util *snowflakeS3Client) getFileHeader(ctx context.Context, meta *fileMetadata, filename string) (*fileHeader, error) { headObjInput, err := util.getS3Object(meta, filename) if err != nil { return nil, err } var s3Cli s3HeaderAPI s3Cli, ok := meta.client.(*s3.Client) if !ok { return nil, errors.New("could not parse client to s3.Client") } // for testing only if meta.mockHeader != nil { s3Cli = meta.mockHeader } out, err := withCloudStorageTimeout(ctx, util.cfg, func(ctx context.Context) (*s3.HeadObjectOutput, error) { return s3Cli.HeadObject(ctx, headObjInput) }) if err != nil { var ae smithy.APIError if errors.As(err, &ae) { if ae.ErrorCode() == notFound { meta.resStatus = notFoundFile return nil, errors.New("could not find file") } else if ae.ErrorCode() == expiredToken { meta.resStatus = renewToken return nil, errors.New("received expired token. renewing") } meta.resStatus = errStatus meta.lastError = err return nil, fmt.Errorf("error while retrieving header, errorCode=%v. 
%w", ae.ErrorCode(), err) } meta.resStatus = errStatus meta.lastError = err return nil, fmt.Errorf("unexpected error while retrieving header: %w", err) } meta.resStatus = uploaded var encMeta encryptMetadata if out.Metadata[amzKey] != "" { encMeta = encryptMetadata{ out.Metadata[amzKey], out.Metadata[amzIv], out.Metadata[amzMatdesc], } } contentLength := convertContentLength(out.ContentLength) return &fileHeader{ out.Metadata[sfcDigest], contentLength, &encMeta, }, nil } // SNOW-974548 remove this function after upgrading AWS SDK func convertContentLength(contentLength any) int64 { switch t := contentLength.(type) { case int64: return t case *int64: if t != nil { return *t } } return 0 } type s3UploadAPI interface { Upload(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) } // cloudUtil implementation func (util *snowflakeS3Client) uploadFile( ctx context.Context, dataFile string, meta *fileMetadata, maxConcurrency int, multiPartThreshold int64) error { s3Meta := map[string]string{ httpHeaderContentType: httpHeaderValueOctetStream, sfcDigest: meta.sha256Digest, } if meta.encryptMeta != nil { s3Meta[amzIv] = meta.encryptMeta.iv s3Meta[amzKey] = meta.encryptMeta.key s3Meta[amzMatdesc] = meta.encryptMeta.matdesc } s3loc, err := util.extractBucketNameAndPath(meta.stageInfo.Location) if err != nil { return err } s3path := s3loc.s3Path + strings.TrimLeft(meta.dstFileName, "/") client, ok := meta.client.(*s3.Client) if !ok { return &SnowflakeError{ Message: "failed to cast to s3 client", } } var uploader s3UploadAPI uploader = manager.NewUploader(client, func(u *manager.Uploader) { u.Concurrency = maxConcurrency u.PartSize = int64Max(multiPartThreshold, manager.DefaultUploadPartSize) }) // for testing only if meta.mockUploader != nil { uploader = meta.mockUploader } _, err = withCloudStorageTimeout(ctx, util.cfg, func(ctx context.Context) (any, error) { if meta.srcStream != nil { uploadStream := 
cmp.Or(meta.realSrcStream, meta.srcStream) return uploader.Upload(ctx, &s3.PutObjectInput{ Bucket: &s3loc.bucketName, Key: &s3path, Body: bytes.NewBuffer(uploadStream.Bytes()), Metadata: s3Meta, }) } var file *os.File file, err = os.Open(dataFile) if err != nil { return nil, err } defer func() { if err = file.Close(); err != nil { logger.Warnf("failed to close %v file: %v", dataFile, err) } }() return uploader.Upload(ctx, &s3.PutObjectInput{ Bucket: &s3loc.bucketName, Key: &s3path, Body: file, Metadata: s3Meta, }) }) if err != nil { var ae smithy.APIError if errors.As(err, &ae) { if ae.ErrorCode() == expiredToken { meta.resStatus = renewToken return err } else if strings.Contains(ae.ErrorCode(), errNoWsaeconnaborted) { meta.lastError = err meta.resStatus = needRetryWithLowerConcurrency return err } } meta.lastError = err meta.resStatus = needRetry return fmt.Errorf("error while uploading file. %w", err) } meta.dstFileSize = meta.uploadSize meta.resStatus = uploaded return nil } type s3DownloadAPI interface { Download(ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error) } // cloudUtil implementation func (util *snowflakeS3Client) nativeDownloadFile( ctx context.Context, meta *fileMetadata, fullDstFileName string, maxConcurrency int64, partSize int64) error { s3Obj, _ := util.getS3Object(meta, meta.srcFileName) client, ok := meta.client.(*s3.Client) if !ok { return &SnowflakeError{ Message: "failed to cast to s3 client", } } logger.Debugf("S3 Client: Send Get Request to the Bucket: %v", meta.stageInfo.Location) var downloader s3DownloadAPI downloader = manager.NewDownloader(client, func(u *manager.Downloader) { u.Concurrency = int(maxConcurrency) u.PartSize = int64Max(partSize, manager.DefaultDownloadPartSize) }) // for testing only if meta.mockDownloader != nil { downloader = meta.mockDownloader } _, err := withCloudStorageTimeout(ctx, util.cfg, func(ctx context.Context) (any, error) { if 
isFileGetStream(ctx) { buf := manager.NewWriteAtBuffer([]byte{}) if _, err := downloader.Download(ctx, buf, &s3.GetObjectInput{ Bucket: s3Obj.Bucket, Key: s3Obj.Key, }); err != nil { return nil, err } meta.dstStream = bytes.NewBuffer(buf.Bytes()) } else { f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode) if err != nil { return nil, err } defer func() { if err = f.Close(); err != nil { logger.Warnf("failed to close %v file: %v", fullDstFileName, err) } }() if _, err = downloader.Download(ctx, f, &s3.GetObjectInput{ Bucket: s3Obj.Bucket, Key: s3Obj.Key, }); err != nil { return nil, err } } return nil, nil }) if err != nil { var ae smithy.APIError if errors.As(err, &ae) { if ae.ErrorCode() == expiredToken { meta.resStatus = renewToken return err } else if strings.Contains(ae.ErrorCode(), errNoWsaeconnaborted) { meta.lastError = err meta.resStatus = needRetryWithLowerConcurrency return err } meta.lastError = err meta.resStatus = errStatus return fmt.Errorf("error while downloading file, errorCode=%v. %w", ae.ErrorCode(), err) } meta.lastError = err meta.resStatus = needRetry return fmt.Errorf("error while downloading file. 
%w", err) } meta.resStatus = downloaded return nil } func (util *snowflakeS3Client) extractBucketNameAndPath(location string) (*s3Location, error) { stageLocation, err := expandUser(location) if err != nil { return nil, err } bucketName := stageLocation s3Path := "" if before, after, ok := strings.Cut(stageLocation, "/"); ok { bucketName = before s3Path = after if s3Path != "" && !strings.HasSuffix(s3Path, "/") { s3Path += "/" } } return &s3Location{bucketName, s3Path}, nil } func (util *snowflakeS3Client) getS3Object(meta *fileMetadata, filename string) (*s3.HeadObjectInput, error) { s3loc, err := util.extractBucketNameAndPath(meta.stageInfo.Location) if err != nil { return nil, err } s3path := s3loc.s3Path + strings.TrimLeft(filename, "/") return &s3.HeadObjectInput{ Bucket: &s3loc.bucketName, Key: &s3path, }, nil } ================================================ FILE: s3_storage_client_test.go ================================================ package gosnowflake import ( "bytes" "context" "errors" "fmt" "io" "os" "path" "strconv" "testing" "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/smithy-go" ) type tcBucketPath struct { in string bucket string path string } func TestExtractBucketNameAndPath(t *testing.T) { s3util := new(snowflakeS3Client) testcases := []tcBucketPath{ {"sfc-eng-regression/test_sub_dir/", "sfc-eng-regression", "test_sub_dir/"}, {"sfc-eng-regression/dir/test_stg/test_sub_dir/", "sfc-eng-regression", "dir/test_stg/test_sub_dir/"}, {"sfc-eng-regression/", "sfc-eng-regression", ""}, {"sfc-eng-regression//", "sfc-eng-regression", "/"}, {"sfc-eng-regression///", "sfc-eng-regression", "//"}, } for _, test := range testcases { t.Run(test.in, func(t *testing.T) { s3Loc, err := s3util.extractBucketNameAndPath(test.in) if err != nil { t.Error(err) } if s3Loc.bucketName != test.bucket { t.Errorf("failed. 
in: %v, expected: %v, got: %v", test.in, test.bucket, s3Loc.bucketName) } if s3Loc.s3Path != test.path { t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.path, s3Loc.s3Path) } }) } } type mockUploadObjectAPI func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) func (m mockUploadObjectAPI) Upload( ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) { return m(ctx, params, optFns...) } func TestUploadOneFileToS3WSAEConnAborted(t *testing.T) { info := execResponseStageInfo{ Location: "sfc-customer-stage/rwyi-testacco/users/9220/", LocationType: "S3", } initialParallel := int64(100) dir, err := os.Getwd() if err != nil { t.Error(err) } s3Cli, err := new(snowflakeS3Client).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "S3", noSleepingTime: false, parallel: initialParallel, client: s3Cli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), encryptMeta: testEncryptionMeta(), overwrite: true, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockUploader: mockUploadObjectAPI(func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) { return nil, &smithy.GenericAPIError{ Code: errNoWsaeconnaborted, Message: "mock err, connection aborted", } }), sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }} uploadMeta.realSrcFileName = uploadMeta.srcFileName fi, err := os.Stat(uploadMeta.srcFileName) if err != nil { t.Error(err) } uploadMeta.uploadSize = fi.Size() err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta) if err == nil { t.Error("should have raised an error") } if uploadMeta.lastMaxConcurrency == 0 { 
t.Fatalf("expected concurrency. got: 0") } if uploadMeta.lastMaxConcurrency != int(initialParallel/defaultMaxRetry) { t.Fatalf("expected last max concurrency to be: %v, got: %v", int(initialParallel/defaultMaxRetry), uploadMeta.lastMaxConcurrency) } initialParallel = 4 uploadMeta.parallel = initialParallel err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta) if err == nil { t.Error("should have raised an error") } if uploadMeta.lastMaxConcurrency == 0 { t.Fatalf("expected no last max concurrency. got: %v", uploadMeta.lastMaxConcurrency) } if uploadMeta.lastMaxConcurrency != 1 { t.Fatalf("expected last max concurrency to be: 1, got: %v", uploadMeta.lastMaxConcurrency) } } func TestUploadOneFileToS3ConnReset(t *testing.T) { info := execResponseStageInfo{ Location: "sfc-teststage/rwyitestacco/users/1234/", LocationType: "S3", } initialParallel := int64(100) dir, err := os.Getwd() if err != nil { t.Error(err) } s3Cli, err := new(snowflakeS3Client).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "S3", noSleepingTime: true, parallel: initialParallel, client: s3Cli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), encryptMeta: testEncryptionMeta(), overwrite: true, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockUploader: mockUploadObjectAPI(func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) { return nil, &smithy.GenericAPIError{ Code: strconv.Itoa(-1), Message: "mock err, connection aborted", } }), sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName fi, err := os.Stat(uploadMeta.srcFileName) if err != nil { t.Error(err) } uploadMeta.uploadSize = fi.Size() err = 
new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta) if err == nil { t.Error("should have raised an error") } if uploadMeta.lastMaxConcurrency != 0 { t.Fatalf("expected no concurrency. got: %v", uploadMeta.lastMaxConcurrency) } } func TestUploadFileWithS3UploadFailedError(t *testing.T) { info := execResponseStageInfo{ Location: "sfc-teststage/rwyitestacco/users/1234/", LocationType: "S3", } initialParallel := int64(100) dir, err := os.Getwd() if err != nil { t.Error(err) } s3Cli, err := new(snowflakeS3Client).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "S3", noSleepingTime: true, parallel: initialParallel, client: s3Cli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), encryptMeta: testEncryptionMeta(), overwrite: true, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockUploader: mockUploadObjectAPI(func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) { return nil, &smithy.GenericAPIError{ Code: expiredToken, Message: "An error occurred (ExpiredToken) when calling the " + "operation: The provided token has expired.", } }), sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName fi, err := os.Stat(uploadMeta.srcFileName) if err != nil { t.Error(err) } uploadMeta.uploadSize = fi.Size() err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta) if err != nil { t.Error(err) } if uploadMeta.resStatus != renewToken { t.Fatalf("expected %v result status, got: %v", renewToken, uploadMeta.resStatus) } } type mockHeaderAPI func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) func (m mockHeaderAPI) HeadObject( ctx 
context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return m(ctx, params, optFns...) } func TestGetHeadExpiryError(t *testing.T) { meta := fileMetadata{ client: s3.New(s3.Options{}), stageInfo: &execResponseStageInfo{Location: ""}, mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return nil, &smithy.GenericAPIError{ Code: expiredToken, } }), sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } if header, err := (&snowflakeS3Client{cfg: &Config{}}).getFileHeader(context.Background(), &meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } if meta.resStatus != renewToken { t.Fatalf("expected %v result status, got: %v", renewToken, meta.resStatus) } } func TestGetHeaderUnexpectedError(t *testing.T) { meta := fileMetadata{ client: s3.New(s3.Options{}), stageInfo: &execResponseStageInfo{Location: ""}, mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return nil, &smithy.GenericAPIError{ Code: "-1", } }), sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } if header, err := (&snowflakeS3Client{cfg: &Config{}}).getFileHeader(context.Background(), &meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } if meta.resStatus != errStatus { t.Fatalf("expected %v result status, got: %v", errStatus, meta.resStatus) } } func TestGetHeaderNonApiError(t *testing.T) { meta := fileMetadata{ client: s3.New(s3.Options{}), stageInfo: &execResponseStageInfo{Location: ""}, mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return nil, errors.New("something went wrong here") }), sfa: &snowflakeFileTransferAgent{ sc: 
&snowflakeConn{ cfg: &Config{}, }, }, } header, err := (&snowflakeS3Client{cfg: &Config{}}).getFileHeader(context.Background(), &meta, "file.txt") assertNilE(t, header, fmt.Sprintf("expected header to be nil, actual: %v", header)) assertNotNilE(t, err, "expected err to not be nil") assertEqualE(t, meta.resStatus, errStatus, fmt.Sprintf("expected %v result status for non-APIerror, got: %v", errStatus, meta.resStatus)) } func TestGetHeaderNotFoundError(t *testing.T) { meta := fileMetadata{ client: s3.New(s3.Options{}), stageInfo: &execResponseStageInfo{Location: ""}, mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return nil, &smithy.GenericAPIError{ Code: notFound, } }), sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } _, err := (&snowflakeS3Client{cfg: &Config{}}).getFileHeader(context.Background(), &meta, "file.txt") if err != nil && err.Error() != "could not find file" { t.Error(err) } if meta.resStatus != notFoundFile { t.Fatalf("expected %v result status, got: %v", errStatus, meta.resStatus) } } type mockDownloadObjectAPI func(ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error) func (m mockDownloadObjectAPI) Download( ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error) { return m(ctx, w, params, optFns...) 
} func TestDownloadFileWithS3TokenExpired(t *testing.T) { info := execResponseStageInfo{ Location: "sfc-teststage/rwyitestacco/users/1234/", LocationType: "S3", } dir, err := os.Getwd() if err != nil { t.Error(err) } s3Cli, err := new(snowflakeS3Client).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } downloadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "S3", noSleepingTime: true, client: s3Cli, stageInfo: &info, dstFileName: "data1.txt.gz", overwrite: true, srcFileName: "data1.txt.gz", localLocation: dir, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockDownloader: mockDownloadObjectAPI(func(ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error) { return 0, &smithy.GenericAPIError{ Code: expiredToken, Message: "An error occurred (ExpiredToken) when calling the " + "operation: The provided token has expired.", } }), mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return &s3.HeadObjectOutput{}, nil }), sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } err = new(remoteStorageUtil).downloadOneFile(context.Background(), &downloadMeta) if err == nil { t.Error("should have raised an error") } if downloadMeta.resStatus != renewToken { t.Fatalf("expected %v result status, got: %v", renewToken, downloadMeta.resStatus) } } func TestDownloadFileWithS3ConnReset(t *testing.T) { info := execResponseStageInfo{ Location: "sfc-teststage/rwyitestacco/users/1234/", LocationType: "S3", } dir, err := os.Getwd() if err != nil { t.Error(err) } s3Cli, err := new(snowflakeS3Client).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } downloadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "S3", noSleepingTime: true, client: s3Cli, stageInfo: &info, dstFileName: "data1.txt.gz", 
overwrite: true, srcFileName: "data1.txt.gz", localLocation: dir, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockDownloader: mockDownloadObjectAPI(func(ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error) { return 0, &smithy.GenericAPIError{ Code: strconv.Itoa(-1), Message: "mock err, connection aborted", } }), mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return &s3.HeadObjectOutput{}, nil }), sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } err = new(remoteStorageUtil).downloadOneFile(context.Background(), &downloadMeta) if err == nil { t.Error("should have raised an error") } if downloadMeta.lastMaxConcurrency != 0 { t.Fatalf("expected no concurrency. got: %v", downloadMeta.lastMaxConcurrency) } } func TestDownloadOneFileToS3WSAEConnAborted(t *testing.T) { info := execResponseStageInfo{ Location: "sfc-teststage/rwyitestacco/users/1234/", LocationType: "S3", } dir, err := os.Getwd() if err != nil { t.Error(err) } s3Cli, err := new(snowflakeS3Client).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } downloadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "S3", noSleepingTime: true, client: s3Cli, stageInfo: &info, dstFileName: "data1.txt.gz", overwrite: true, srcFileName: "data1.txt.gz", localLocation: dir, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockDownloader: mockDownloadObjectAPI(func(ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error) { return 0, &smithy.GenericAPIError{ Code: errNoWsaeconnaborted, Message: "mock err, connection aborted", } }), mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { 
return &s3.HeadObjectOutput{}, nil }), sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } err = new(remoteStorageUtil).downloadOneFile(context.Background(), &downloadMeta) if err == nil { t.Error("should have raised an error") } if downloadMeta.resStatus != needRetryWithLowerConcurrency { t.Fatalf("expected %v result status, got: %v", needRetryWithLowerConcurrency, downloadMeta.resStatus) } } func TestDownloadOneFileToS3Failed(t *testing.T) { info := execResponseStageInfo{ Location: "sfc-teststage/rwyitestacco/users/1234/", LocationType: "S3", } dir, err := os.Getwd() if err != nil { t.Error(err) } s3Cli, err := new(snowflakeS3Client).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } downloadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "S3", noSleepingTime: true, client: s3Cli, stageInfo: &info, dstFileName: "data1.txt.gz", overwrite: true, srcFileName: "data1.txt.gz", localLocation: dir, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockDownloader: mockDownloadObjectAPI(func(ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error) { return 0, errors.New("Failed to upload file") }), mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return &s3.HeadObjectOutput{}, nil }), sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } err = new(remoteStorageUtil).downloadOneFile(context.Background(), &downloadMeta) if err == nil { t.Error("should have raised an error") } if downloadMeta.resStatus != needRetry { t.Fatalf("expected %v result status, got: %v", needRetry, downloadMeta.resStatus) } } func TestUploadFileToS3ClientCastFail(t *testing.T) { info := execResponseStageInfo{ Location: "sfc-customer-stage/rwyi-testacco/users/9220/", LocationType: "S3", } dir, err := os.Getwd() if err != nil { 
t.Error(err) } azureCli, err := new(snowflakeAzureClient).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "S3", noSleepingTime: false, client: azureCli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), encryptMeta: testEncryptionMeta(), overwrite: true, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName fi, err := os.Stat(uploadMeta.srcFileName) if err != nil { t.Error(err) } uploadMeta.uploadSize = fi.Size() err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta) if err == nil { t.Fatal("should have failed") } } func TestGetHeaderClientCastFail(t *testing.T) { info := execResponseStageInfo{ Location: "sfc-customer-stage/rwyi-testacco/users/9220/", LocationType: "S3", } azureCli, err := new(snowflakeAzureClient).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } meta := fileMetadata{ client: azureCli, stageInfo: &execResponseStageInfo{Location: ""}, mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return nil, &smithy.GenericAPIError{ Code: notFound, } }), sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } _, err = new(snowflakeS3Client).getFileHeader(context.Background(), &meta, "file.txt") if err == nil { t.Fatal("should have failed") } } func TestS3UploadRetryWithHeaderNotFound(t *testing.T) { info := execResponseStageInfo{ Location: "sfc-customer-stage/rwyi-testacco/users/9220/", LocationType: "S3", } initialParallel := int64(100) dir, err := os.Getwd() if err != nil { t.Error(err) } s3Cli, err := new(snowflakeS3Client).createClient(&info, false, 
&snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "S3", noSleepingTime: true, parallel: initialParallel, client: s3Cli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), encryptMeta: testEncryptionMeta(), overwrite: true, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockUploader: mockUploadObjectAPI(func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) { return &manager.UploadOutput{ Location: "https://sfc-customer-stage/rwyi-testacco/users/9220/data1.txt.gz", }, nil }), mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return nil, &smithy.GenericAPIError{ Code: notFound, } }), sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName fi, err := os.Stat(uploadMeta.srcFileName) if err != nil { t.Error(err) } uploadMeta.uploadSize = fi.Size() err = (&remoteStorageUtil{cfg: &Config{}}).uploadOneFileWithRetry(context.Background(), &uploadMeta) if err != nil { t.Error(err) } if uploadMeta.resStatus != errStatus { t.Fatalf("expected %v result status, got: %v", errStatus, uploadMeta.resStatus) } } func TestS3UploadStreamFailed(t *testing.T) { info := execResponseStageInfo{ Location: "sfc-customer-stage/rwyi-testacco/users/9220/", LocationType: "S3", } initialParallel := int64(100) src := []byte{65, 66, 67} s3Cli, err := new(snowflakeS3Client).createClient(&info, false, &snowflakeTelemetry{}) if err != nil { t.Error(err) } uploadMeta := fileMetadata{ name: "data1.txt.gz", stageLocationType: "S3", noSleepingTime: true, parallel: initialParallel, client: s3Cli, sha256Digest: "123456789abcdef", stageInfo: &info, dstFileName: "data1.txt.gz", srcStream: 
bytes.NewBuffer(src), encryptMeta: testEncryptionMeta(), overwrite: true, options: &SnowflakeFileTransferOptions{ MultiPartThreshold: multiPartThreshold, }, mockUploader: mockUploadObjectAPI(func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) { return nil, errors.New("unexpected error uploading file") }), sfa: &snowflakeFileTransferAgent{ sc: &snowflakeConn{ cfg: &Config{}, }, }, } uploadMeta.realSrcStream = uploadMeta.srcStream err = new(remoteStorageUtil).uploadOneFile(context.Background(), &uploadMeta) if err == nil { t.Fatal("should have failed") } } func TestConvertContentLength(t *testing.T) { someInt := int64(1) tcs := []struct { contentLength any desc string expected int64 }{ { contentLength: someInt, desc: "int", expected: 1, }, { contentLength: &someInt, desc: "pointer", expected: 1, }, { contentLength: float64(1), desc: "another type", expected: 0, }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { actual := convertContentLength(tc.contentLength) assertEqualF(t, actual, tc.expected, fmt.Sprintf("expected %v (%T) but got %v (%T)", actual, actual, tc.expected, tc.expected)) }) } } func TestGetS3Endpoint(t *testing.T) { testcases := []struct { desc string in execResponseStageInfo out string }{ { desc: "when UseRegionalURL is valid and the region does not start with cn-", in: execResponseStageInfo{ UseS3RegionalURL: false, UseRegionalURL: true, EndPoint: "", Region: "WEST-1", }, out: "https://s3.WEST-1.amazonaws.com", }, { desc: "when UseS3RegionalURL is valid and the region does not start with cn-", in: execResponseStageInfo{ UseS3RegionalURL: true, UseRegionalURL: false, EndPoint: "", Region: "WEST-1", }, out: "https://s3.WEST-1.amazonaws.com", }, { desc: "when endPoint is enabled and the region does not start with cn-", in: execResponseStageInfo{ UseS3RegionalURL: false, UseRegionalURL: false, EndPoint: "s3.endpoint", Region: "mockLocation", }, out: "https://s3.endpoint", }, 
{ desc: "when endPoint is enabled and the region starts with cn-", in: execResponseStageInfo{ UseS3RegionalURL: false, UseRegionalURL: false, EndPoint: "s3.endpoint", Region: "cn-mockLocation", }, out: "https://s3.endpoint", }, { desc: "when useS3RegionalURL is valid and domain starts with cn", in: execResponseStageInfo{ UseS3RegionalURL: true, UseRegionalURL: false, EndPoint: "", Region: "cn-mockLocation", }, out: "https://s3.cn-mockLocation.amazonaws.com.cn", }, { desc: "when useRegionalURL is valid and domain starts with cn", in: execResponseStageInfo{ UseS3RegionalURL: true, UseRegionalURL: false, EndPoint: "", Region: "cn-mockLocation", }, out: "https://s3.cn-mockLocation.amazonaws.com.cn", }, { desc: "when useRegionalURL is valid and domain starts with cn", in: execResponseStageInfo{ UseS3RegionalURL: true, UseRegionalURL: false, EndPoint: "", Region: "cn-mockLocation", }, out: "https://s3.cn-mockLocation.amazonaws.com.cn", }, { desc: "when endPoint is specified, both UseRegionalURL and useS3PRegionalUrl are valid, and the region starts with cn", in: execResponseStageInfo{ UseS3RegionalURL: true, UseRegionalURL: true, EndPoint: "s3.endpoint", Region: "cn-mockLocation", }, out: "https://s3.endpoint", }, } for _, test := range testcases { t.Run(test.desc, func(t *testing.T) { endpoint := getS3CustomEndpoint(&test.in) if *endpoint != test.out { t.Errorf("failed. 
in: %v, expected: %v, got: %v", test.in, test.out, *endpoint) } }) } } ================================================ FILE: secret_detector.go ================================================ package gosnowflake import loggerinternal "github.com/snowflakedb/gosnowflake/v2/internal/logger" // maskSecrets masks secrets in text (unexported for internal use within main package) func maskSecrets(text string) string { return loggerinternal.MaskSecrets(text) } ================================================ FILE: secret_detector_test.go ================================================ package gosnowflake import ( "fmt" "testing" "time" "github.com/golang-jwt/jwt/v5" ) const ( longToken = "_Y1ZNETTn5/qfUWj3Jedby7gipDzQs=UKyJH9DS=nFzzWnfZKGV+C7GopWC" + // pragma: allowlist secret "GD4LjOLLFZKOE26LXHDt3pTi4iI1qwKuSpf/FmClCMBSissVsU3Ei590FP0lPQQhcSG" + // pragma: allowlist secret "cDu69ZL_1X6e9h5z62t/iY7ZkII28n2qU=nrBJUgPRCIbtJQkVJXIuOHjX4G5yUEKjZ" + // pragma: allowlist secret "BAx4w6=_lqtt67bIA=o7D=oUSjfywsRFoloNIkBPXCwFTv+1RVUHgVA2g8A9Lw5XdJY" + // pragma: allowlist secret "uI8vhg=f0bKSq7AhQ2Bh" randomPassword = `Fh[+2J~AcqeqW%?` falsePositiveToken = "2020-04-30 23:06:04,069 - MainThread auth.py:397" + " - write_temporary_credential() - DEBUG - no ID token is given when " + "try to store temporary credential" ) // generateTestJWT creates a test JWT token for masking tests using the JWT library func generateTestJWT(t *testing.T) string { // Create claims for the test JWT claims := jwt.MapClaims{ "sub": "test123", "name": "Test User", "exp": time.Now().Add(time.Hour).Unix(), "iat": time.Now().Unix(), } // Create the token with HS256 signing method token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) // Sign the token with a test secret testSecret := []byte("test-secret-for-masking-validation") tokenString, err := token.SignedString(testSecret) if err != nil { // Fallback to a simple test JWT if signing fails t.Fatalf("Failed to generate test JWT: %s", err) } return 
tokenString } func TestSecretsDetector(t *testing.T) { testCases := []struct { name string input string expected string }{ // Token masking tests {"Token with equals", fmt.Sprintf("Token =%s", longToken), "Token =****"}, {"idToken with colon space", fmt.Sprintf("idToken : %s", longToken), "idToken : ****"}, {"sessionToken with colon space", fmt.Sprintf("sessionToken : %s", longToken), "sessionToken : ****"}, {"masterToken with colon space", fmt.Sprintf("masterToken : %s", longToken), "masterToken : ****"}, {"accessToken with colon space", fmt.Sprintf("accessToken : %s", longToken), "accessToken : ****"}, {"refreshToken with colon space", fmt.Sprintf("refreshToken : %s", longToken), "refreshToken : ****"}, {"programmaticAccessToken with colon space", fmt.Sprintf("programmaticAccessToken : %s", longToken), "programmaticAccessToken : ****"}, {"programmatic_access_token with colon space", fmt.Sprintf("programmatic_access_token : %s", longToken), "programmatic_access_token : ****"}, {"JWT - with Bearer prefix", fmt.Sprintf("Bearer %s", generateTestJWT(t)), "Bearer ****"}, {"JWT - with JWT prefix", fmt.Sprintf("JWT %s", generateTestJWT(t)), "JWT ****"}, // Password masking tests {"password with colon", fmt.Sprintf("password:%s", randomPassword), "password:****"}, {"PASSWORD uppercase with colon", fmt.Sprintf("PASSWORD:%s", randomPassword), "PASSWORD:****"}, {"PaSsWoRd mixed case with colon", fmt.Sprintf("PaSsWoRd:%s", randomPassword), "PaSsWoRd:****"}, {"password with equals and spaces", fmt.Sprintf("password = %s", randomPassword), "password = ****"}, {"pwd with colon", fmt.Sprintf("pwd:%s", randomPassword), "pwd:****"}, // Mixed token and password tests { "token and password mixed", fmt.Sprintf("token=%s foo bar baz password:%s", longToken, randomPassword), "token=**** foo bar baz password:****", }, { "PWD and TOKEN mixed", fmt.Sprintf("PWD = %s blah blah blah TOKEN:%s", randomPassword, longToken), "PWD = **** blah blah blah TOKEN:****", }, // Client secret tests 
{"clientSecret with values", "clientSecret abc oauthClientSECRET=def", "clientSecret **** oauthClientSECRET=****"}, // False positive test {"false positive should not be masked", falsePositiveToken, falsePositiveToken}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { result := maskSecrets(tc.input) assertEqualE(t, result, tc.expected) }) } } ================================================ FILE: secure_storage_manager.go ================================================ package gosnowflake import ( "crypto/sha256" "encoding/hex" "encoding/json" "errors" "fmt" "io" "os" "os/user" "path/filepath" "strconv" "strings" "sync" "time" ) type tokenType string const ( idToken tokenType = "ID_TOKEN" mfaToken tokenType = "MFA_TOKEN" oauthAccessToken tokenType = "OAUTH_ACCESS_TOKEN" oauthRefreshToken tokenType = "OAUTH_REFRESH_TOKEN" ) const ( credCacheDirEnv = "SF_TEMPORARY_CREDENTIAL_CACHE_DIR" credCacheFileName = "credential_cache_v1.json" ) type cacheDirConf struct { envVar string pathSegments []string } var defaultLinuxCacheDirConf = []cacheDirConf{ {envVar: credCacheDirEnv, pathSegments: []string{}}, {envVar: "XDG_CACHE_DIR", pathSegments: []string{"snowflake"}}, {envVar: "HOME", pathSegments: []string{".cache", "snowflake"}}, } type secureTokenSpec struct { host, user string tokenType tokenType } func (t *secureTokenSpec) buildKey() (string, error) { return buildCredentialsKey(t.host, t.user, t.tokenType) } func newMfaTokenSpec(host, user string) *secureTokenSpec { return &secureTokenSpec{ host, user, mfaToken, } } func newIDTokenSpec(host, user string) *secureTokenSpec { return &secureTokenSpec{ host, user, idToken, } } func newOAuthAccessTokenSpec(host, user string) *secureTokenSpec { return &secureTokenSpec{ host, user, oauthAccessToken, } } func newOAuthRefreshTokenSpec(host, user string) *secureTokenSpec { return &secureTokenSpec{ host, user, oauthRefreshToken, } } type secureStorageManager interface { setCredential(tokenSpec *secureTokenSpec, 
// lookupCacheDir resolves a candidate credential-cache directory rooted at
// the directory named by envVar, joined with pathSegments. It returns an
// error when the variable is unset, cannot be stat'ed, or does not name a
// directory. The parent of the cache directory is created with 0755 and the
// cache directory itself with 0700; an already-existing cache directory is
// not an error (its permissions are validated later, by the caller's
// ensure* checks).
func lookupCacheDir(envVar string, pathSegments ...string) (string, error) {
	envVal := os.Getenv(envVar)
	if envVal == "" {
		return "", fmt.Errorf("environment variable %s not set", envVar)
	}
	fileInfo, err := os.Stat(envVal)
	if err != nil {
		return "", fmt.Errorf("failed to stat %s=%s, due to %v", envVar, envVal, err)
	}
	if !fileInfo.IsDir() {
		return "", fmt.Errorf("environment variable %s=%s is not a directory", envVar, envVal)
	}
	cacheDir := filepath.Join(envVal, filepath.Join(pathSegments...))
	// Fixed: slicing at strings.LastIndex(cacheDir, "/") assumed a "/"
	// separator, which breaks on Windows paths (and would mis-slice when no
	// "/" is present). filepath.Dir is the OS-aware equivalent.
	parentOfCacheDir := filepath.Dir(cacheDir)
	if err = os.MkdirAll(parentOfCacheDir, os.FileMode(0755)); err != nil {
		return "", err
	}
	// We don't check if permissions are incorrect here if a directory exists, because we check it later.
	if err = os.Mkdir(cacheDir, os.FileMode(0700)); err != nil && !errors.Is(err, os.ErrExist) {
		return "", err
	}
	return cacheDir, nil
}
if err != nil { logger.Debugf("Skipping %s in cache directory lookup due to %v", conf.envVar, err) } else { logger.Debugf("Using %s as cache directory", path) return path, nil } } return "", errors.New("no credentials cache directory found") } func (ssm *fileBasedSecureStorageManager) getTokens(data map[string]any) map[string]any { val, ok := data["tokens"] if !ok { return map[string]any{} } tokens, ok := val.(map[string]any) if !ok { return map[string]any{} } return tokens } func (ssm *fileBasedSecureStorageManager) withLock(action func(cacheFile *os.File)) { err := ssm.lockFile() if err != nil { logger.Warnf("Unable to lock cache. %v", err) return } defer ssm.unlockFile() ssm.withCacheFile(action) } func (ssm *fileBasedSecureStorageManager) withCacheFile(action func(*os.File)) { cacheFile, err := os.OpenFile(ssm.credFilePath(), os.O_CREATE|os.O_RDWR, 0600) if err != nil { logger.Warnf("cannot access %v. %v", ssm.credFilePath(), err) return } defer func(file *os.File) { if err := file.Close(); err != nil { logger.Warnf("cannot release file descriptor for %v. %v", ssm.credFilePath(), err) } }(cacheFile) cacheDir, err := os.Open(ssm.credDirPath) if err != nil { logger.Warnf("cannot access %v. %v", ssm.credDirPath, err) } defer func(file *os.File) { if err := file.Close(); err != nil { logger.Warnf("cannot release file descriptor for %v. %v", cacheDir, err) } }(cacheDir) if err := ensureFileOwner(cacheFile); err != nil { logger.Warnf("failed to ensure owner for temporary cache file. %v", err) return } if err := ensureFilePermissions(cacheFile, 0600); err != nil { logger.Warnf("failed to ensure permission for temporary cache file. %v", err) return } if err := ensureFileOwner(cacheDir); err != nil { logger.Warnf("failed to ensure owner for temporary cache dir. %v", err) return } if err := ensureFilePermissions(cacheDir, 0700|os.ModeDir); err != nil { logger.Warnf("failed to ensure permission for temporary cache dir. 
%v", err) return } action(cacheFile) } func (ssm *fileBasedSecureStorageManager) setCredential(tokenSpec *secureTokenSpec, value string) { if value == "" { logger.Debug("no token provided") return } credentialsKey, err := tokenSpec.buildKey() if err != nil { logger.Warnf("cannot build token spec: %v", err) return } ssm.withLock(func(cacheFile *os.File) { credCache, err := ssm.readTemporaryCacheFile(cacheFile) if err != nil { logger.Warnf("Error while reading cache file. %v", err) return } tokens := ssm.getTokens(credCache) tokens[credentialsKey] = value credCache["tokens"] = tokens err = ssm.writeTemporaryCacheFile(credCache, cacheFile) if err != nil { logger.Warnf("Set credential failed. Unable to write cache. %v", err) } else { logger.Debugf("Set credential succeeded. Authentication type: %v, User: %v, file location: %v", tokenSpec.tokenType, tokenSpec.user, ssm.credFilePath()) } }) } func (ssm *fileBasedSecureStorageManager) lockPath() string { return filepath.Join(ssm.credDirPath, credCacheFileName+".lck") } func (ssm *fileBasedSecureStorageManager) lockFile() error { const numRetries = 10 const retryInterval = 100 * time.Millisecond lockPath := ssm.lockPath() lockFile, err := os.Open(lockPath) if err != nil && !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("failed to open %v. err: %v", lockPath, err) } defer func() { if lockFile != nil { err = lockFile.Close() if err != nil { logger.Debugf("error while closing lock file. %v", err) } } }() if err == nil { // file exists fileInfo, err := lockFile.Stat() if err != nil { return fmt.Errorf("failed to stat %v and determine if lock is stale. 
err: %v", lockPath, err) } ownerUID, err := provideFileOwner(lockFile) if err != nil && !errors.Is(err, os.ErrNotExist) { return err } currentUser, err := user.Current() if err != nil { return err } if strconv.Itoa(int(ownerUID)) != currentUser.Uid { return errors.New("incorrect owner of " + lockFile.Name()) } // removing stale lock now := time.Now() if fileInfo.ModTime().Add(time.Second).UnixNano() < now.UnixNano() { logger.Debugf("removing credentials cache lock file, stale for %vms", (now.UnixNano()-fileInfo.ModTime().UnixNano())/1000/1000) err = os.Remove(lockPath) if err != nil { return fmt.Errorf("failed to remove %v while trying to remove stale lock. err: %v", lockPath, err) } } } locked := false for range numRetries { err := os.Mkdir(lockPath, 0700) if err != nil { if errors.Is(err, os.ErrExist) { time.Sleep(retryInterval) continue } return fmt.Errorf("failed to create cache lock: %v, err: %v", lockPath, err) } locked = true break } if !locked { return fmt.Errorf("failed to lock cache. lockPath: %v", lockPath) } return nil } func (ssm *fileBasedSecureStorageManager) unlockFile() { lockPath := ssm.lockPath() err := os.Remove(lockPath) if err != nil { logger.Warnf("Failed to unlock cache lock: %v. %v", lockPath, err) } } func (ssm *fileBasedSecureStorageManager) getCredential(tokenSpec *secureTokenSpec) string { credentialsKey, err := tokenSpec.buildKey() if err != nil { logger.Warnf("cannot build token spec: %v", err) return "" } ret := "" ssm.withLock(func(cacheFile *os.File) { credCache, err := ssm.readTemporaryCacheFile(cacheFile) if err != nil { logger.Warnf("Error while reading cache file. 
%v", err)
			return
		}
		cred, ok := ssm.getTokens(credCache)[credentialsKey]
		if !ok {
			return
		}
		credStr, ok := cred.(string)
		if !ok {
			return
		}
		ret = credStr
	})
	return ret
}

// credFilePath returns the absolute path of the credential cache file inside
// the configured cache directory.
func (ssm *fileBasedSecureStorageManager) credFilePath() string {
	return filepath.Join(ssm.credDirPath, credCacheFileName)
}

// ensureFileOwner verifies that f is owned by the current OS user. When the
// platform cannot report an owner (provideFileOwner returns os.ErrNotExist),
// the check is skipped.
func ensureFileOwner(f *os.File) error {
	ownerUID, ownerErr := provideFileOwner(f)
	if ownerErr != nil && !errors.Is(ownerErr, os.ErrNotExist) {
		return ownerErr
	}
	currentUser, err := user.Current()
	if err != nil {
		return err
	}
	if errors.Is(ownerErr, os.ErrNotExist) {
		// Ownership information is unavailable on this platform; nothing to
		// verify. NOTE(review): the previous code tested the error from
		// user.Current() here, which is always nil at this point, so the
		// ownership comparison ran even when the owner lookup had reported
		// os.ErrNotExist.
		return nil
	}
	if strconv.Itoa(int(ownerUID)) != currentUser.Uid {
		return errors.New("incorrect owner of " + f.Name())
	}
	return nil
}

// ensureFilePermissions verifies that f carries exactly the expected
// permission bits (only the permission portion of expectedMode is compared).
func ensureFilePermissions(f *os.File, expectedMode os.FileMode) error {
	fileInfo, err := f.Stat()
	if err != nil {
		return err
	}
	if fileInfo.Mode().Perm() != expectedMode&os.ModePerm {
		return fmt.Errorf("incorrect permissions(%v, expected %v) for credential file", fileInfo.Mode(), expectedMode)
	}
	return nil
}

// readTemporaryCacheFile reads and unmarshals the JSON credential cache from
// cacheFile, rewinding the file afterwards so a subsequent write starts at
// offset zero. An unreadable or empty file yields an empty map without error.
func (ssm *fileBasedSecureStorageManager) readTemporaryCacheFile(cacheFile *os.File) (map[string]any, error) {
	jsonData, err := io.ReadAll(cacheFile)
	if err != nil {
		logger.Warnf("Failed to read credential cache file. %v.\n", err)
		return map[string]any{}, nil
	}
	if _, err = cacheFile.Seek(0, 0); err != nil {
		return map[string]any{}, fmt.Errorf("cannot seek to the beginning of a cache file. %v", err)
	}
	if len(jsonData) == 0 {
		// Happens when the file didn't exist before.
		return map[string]any{}, nil
	}
	credentialsMap := map[string]any{}
	err = json.Unmarshal(jsonData, &credentialsMap)
	if err != nil {
		return map[string]any{}, fmt.Errorf("failed to unmarshal credential cache file.
%v", err) } return credentialsMap, nil } func (ssm *fileBasedSecureStorageManager) deleteCredential(tokenSpec *secureTokenSpec) { credentialsKey, err := tokenSpec.buildKey() if err != nil { logger.Warnf("cannot build token spec: %v", err) return } ssm.withLock(func(cacheFile *os.File) { credCache, err := ssm.readTemporaryCacheFile(cacheFile) if err != nil { logger.Warnf("Error while reading cache file. %v", err) return } delete(ssm.getTokens(credCache), credentialsKey) err = ssm.writeTemporaryCacheFile(credCache, cacheFile) if err != nil { logger.Warnf("Set credential failed. Unable to write cache. %v", err) } else { logger.Debugf("Deleted credential succeeded. Authentication type: %v, User: %v, file location: %v", tokenSpec.tokenType, tokenSpec.user, ssm.credFilePath()) } }) } func (ssm *fileBasedSecureStorageManager) writeTemporaryCacheFile(cache map[string]any, cacheFile *os.File) error { bytes, err := json.Marshal(cache) if err != nil { return fmt.Errorf("failed to marshal credential cache map. %w", err) } if err = cacheFile.Truncate(0); err != nil { return fmt.Errorf("error while truncating credentials cache. 
%v", err) } _, err = cacheFile.Write(bytes) if err != nil { return fmt.Errorf("failed to write the credential cache file: %w", err) } return nil } func buildCredentialsKey(host, user string, credType tokenType) (string, error) { if host == "" { return "", errors.New("host is not provided to store in token cache, skipping") } if user == "" { return "", errors.New("user is not provided to store in token cache, skipping") } plainCredKey := host + ":" + user + ":" + string(credType) checksum := sha256.New() checksum.Write([]byte(plainCredKey)) return hex.EncodeToString(checksum.Sum(nil)), nil } type noopSecureStorageManager struct { } func newNoopSecureStorageManager() *noopSecureStorageManager { return &noopSecureStorageManager{} } func (ssm *noopSecureStorageManager) setCredential(_ *secureTokenSpec, _ string) { } func (ssm *noopSecureStorageManager) getCredential(_ *secureTokenSpec) string { return "" } func (ssm *noopSecureStorageManager) deleteCredential(_ *secureTokenSpec) { } type threadSafeSecureStorageManager struct { mu *sync.Mutex delegate secureStorageManager } func (ssm *threadSafeSecureStorageManager) setCredential(tokenSpec *secureTokenSpec, value string) { ssm.mu.Lock() defer ssm.mu.Unlock() ssm.delegate.setCredential(tokenSpec, value) } func (ssm *threadSafeSecureStorageManager) getCredential(tokenSpec *secureTokenSpec) string { ssm.mu.Lock() defer ssm.mu.Unlock() return ssm.delegate.getCredential(tokenSpec) } func (ssm *threadSafeSecureStorageManager) deleteCredential(tokenSpec *secureTokenSpec) { ssm.mu.Lock() defer ssm.mu.Unlock() ssm.delegate.deleteCredential(tokenSpec) } ================================================ FILE: secure_storage_manager_linux.go ================================================ //go:build linux package gosnowflake import ( "runtime" "sync" ) func defaultOsSpecificSecureStorageManager() secureStorageManager { logger.Debugf("OS is %v, using file based secure storage manager.", runtime.GOOS) ssm, err := 
newFileBasedSecureStorageManager() if err != nil { logger.Debugf("failed to create credentials cache dir: %v. Not storing credentials locally.", err) return newNoopSecureStorageManager() } return &threadSafeSecureStorageManager{&sync.Mutex{}, ssm} } ================================================ FILE: secure_storage_manager_notlinux.go ================================================ //go:build !linux package gosnowflake import ( "github.com/99designs/keyring" "runtime" "strings" "sync" ) func defaultOsSpecificSecureStorageManager() secureStorageManager { switch runtime.GOOS { case "darwin", "windows": logger.Debugf("OS is %v, using keyring based secure storage manager.", runtime.GOOS) return &threadSafeSecureStorageManager{&sync.Mutex{}, newKeyringBasedSecureStorageManager()} default: logger.Debugf("OS %v does not support credentials cache", runtime.GOOS) return newNoopSecureStorageManager() } } type keyringSecureStorageManager struct { } func newKeyringBasedSecureStorageManager() *keyringSecureStorageManager { return &keyringSecureStorageManager{} } func (ssm *keyringSecureStorageManager) setCredential(tokenSpec *secureTokenSpec, value string) { if value == "" { logger.Debug("no token provided") } else { credentialsKey, err := tokenSpec.buildKey() if err != nil { logger.Warnf("cannot build token spec: %v", err) return } switch runtime.GOOS { case "windows": ring, _ := keyring.Open(keyring.Config{ WinCredPrefix: strings.ToUpper(tokenSpec.host), ServiceName: strings.ToUpper(tokenSpec.user), }) item := keyring.Item{ Key: credentialsKey, Data: []byte(value), } if err := ring.Set(item); err != nil { logger.Debugf("Failed to write to Windows credential manager. Err: %v", err) } case "darwin": ring, _ := keyring.Open(keyring.Config{ ServiceName: credentialsKey, }) account := strings.ToUpper(tokenSpec.user) item := keyring.Item{ Key: account, Data: []byte(value), } if err := ring.Set(item); err != nil { logger.Debugf("Failed to write to keychain. 
Err: %v", err) } } } } func (ssm *keyringSecureStorageManager) getCredential(tokenSpec *secureTokenSpec) string { cred := "" credentialsKey, err := tokenSpec.buildKey() if err != nil { logger.Warnf("cannot build token spec: %v", err) return "" } switch runtime.GOOS { case "windows": ring, _ := keyring.Open(keyring.Config{ WinCredPrefix: strings.ToUpper(tokenSpec.host), ServiceName: strings.ToUpper(tokenSpec.user), }) i, err := ring.Get(credentialsKey) if err != nil { logger.Debugf("Failed to read credentialsKey or could not find it in Windows Credential Manager. Error: %v", err) } cred = string(i.Data) case "darwin": ring, _ := keyring.Open(keyring.Config{ ServiceName: credentialsKey, }) account := strings.ToUpper(tokenSpec.user) i, err := ring.Get(account) if err != nil { logger.Debugf("Failed to find the item in keychain or item does not exist. Error: %v", err) } cred = string(i.Data) if cred == "" { logger.Debug("Returned credential is empty") } else { logger.Debug("Successfully read token. Returning as string") } } return cred } func (ssm *keyringSecureStorageManager) deleteCredential(tokenSpec *secureTokenSpec) { credentialsKey, err := tokenSpec.buildKey() if err != nil { logger.Warnf("cannot build token spec: %v", err) return } switch runtime.GOOS { case "windows": ring, _ := keyring.Open(keyring.Config{ WinCredPrefix: strings.ToUpper(tokenSpec.host), ServiceName: strings.ToUpper(tokenSpec.user), }) err := ring.Remove(string(credentialsKey)) if err != nil { logger.Debugf("Failed to delete credentialsKey in Windows Credential Manager. Error: %v", err) } case "darwin": ring, _ := keyring.Open(keyring.Config{ ServiceName: credentialsKey, }) account := strings.ToUpper(tokenSpec.user) err := ring.Remove(account) if err != nil { logger.Debugf("Failed to delete credentialsKey in keychain. 
Error: %v", err) } } } ================================================ FILE: secure_storage_manager_test.go ================================================ package gosnowflake import ( "encoding/json" "os" "path/filepath" "testing" "time" ) func TestBuildCredCacheDirPath(t *testing.T) { skipOnWindows(t, "permission model is different") testRoot1, err := os.MkdirTemp("", "") assertNilF(t, err) defer os.RemoveAll(testRoot1) testRoot2, err := os.MkdirTemp("", "") assertNilF(t, err) defer os.RemoveAll(testRoot2) env1 := overrideEnv("CACHE_DIR_TEST_NOT_EXISTING", "/tmp/not_existing_dir") defer env1.rollback() env2 := overrideEnv("CACHE_DIR_TEST_1", testRoot1) defer env2.rollback() env3 := overrideEnv("CACHE_DIR_TEST_2", testRoot2) defer env3.rollback() t.Run("cannot find any dir", func(t *testing.T) { _, err := buildCredCacheDirPath([]cacheDirConf{ {envVar: "CACHE_DIR_TEST_NOT_EXISTING"}, }) assertEqualE(t, err.Error(), "no credentials cache directory found") _, err = os.Stat("/tmp/not_existing_dir") assertStringContainsE(t, err.Error(), "no such file or directory") }) t.Run("should use first dir that exists", func(t *testing.T) { path, err := buildCredCacheDirPath([]cacheDirConf{ {envVar: "CACHE_DIR_TEST_NOT_EXISTING"}, {envVar: "CACHE_DIR_TEST_1"}, }) assertNilF(t, err) assertEqualE(t, path, testRoot1) stat, err := os.Stat(testRoot1) assertNilF(t, err) assertEqualE(t, stat.Mode(), 0700|os.ModeDir) }) t.Run("should use first dir that exists and append segments", func(t *testing.T) { path, err := buildCredCacheDirPath([]cacheDirConf{ {envVar: "CACHE_DIR_TEST_NOT_EXISTING"}, {envVar: "CACHE_DIR_TEST_2", pathSegments: []string{"sub1", "sub2"}}, }) assertNilF(t, err) assertEqualE(t, path, filepath.Join(testRoot2, "sub1", "sub2")) stat, err := os.Stat(testRoot2) assertNilF(t, err) assertEqualE(t, stat.Mode(), 0700|os.ModeDir) }) } func TestSnowflakeFileBasedSecureStorageManager(t *testing.T) { skipOnWindows(t, "file system permission is different") credCacheDir, err := 
os.MkdirTemp("", "") assertNilF(t, err) assertNilF(t, os.MkdirAll(credCacheDir, os.ModePerm)) credCacheDirEnvOverride := overrideEnv(credCacheDirEnv, credCacheDir) defer credCacheDirEnvOverride.rollback() ssm, err := newFileBasedSecureStorageManager() assertNilF(t, err) t.Run("store single token", func(t *testing.T) { tokenSpec := newMfaTokenSpec("host.com", "johndoe") cred := "token123" ssm.setCredential(tokenSpec, cred) assertEqualE(t, ssm.getCredential(tokenSpec), cred) ssm.deleteCredential(tokenSpec) assertEqualE(t, ssm.getCredential(tokenSpec), "") }) t.Run("store tokens of different types, hosts and users", func(t *testing.T) { mfaTokenSpec := newMfaTokenSpec("host.com", "johndoe") mfaCred := "token12" idTokenSpec := newIDTokenSpec("host.com", "johndoe") idCred := "token34" idTokenSpec2 := newIDTokenSpec("host.org", "johndoe") idCred2 := "token56" idTokenSpec3 := newIDTokenSpec("host.com", "someoneelse") idCred3 := "token78" ssm.setCredential(mfaTokenSpec, mfaCred) ssm.setCredential(idTokenSpec, idCred) ssm.setCredential(idTokenSpec2, idCred2) ssm.setCredential(idTokenSpec3, idCred3) assertEqualE(t, ssm.getCredential(mfaTokenSpec), mfaCred) assertEqualE(t, ssm.getCredential(idTokenSpec), idCred) assertEqualE(t, ssm.getCredential(idTokenSpec2), idCred2) assertEqualE(t, ssm.getCredential(idTokenSpec3), idCred3) ssm.deleteCredential(mfaTokenSpec) assertEqualE(t, ssm.getCredential(mfaTokenSpec), "") assertEqualE(t, ssm.getCredential(idTokenSpec), idCred) assertEqualE(t, ssm.getCredential(idTokenSpec2), idCred2) assertEqualE(t, ssm.getCredential(idTokenSpec3), idCred3) }) t.Run("override single token", func(t *testing.T) { mfaTokenSpec := newMfaTokenSpec("host.com", "johndoe") mfaCred := "token123" idTokenSpec := newIDTokenSpec("host.com", "johndoe") idCred := "token456" ssm.setCredential(mfaTokenSpec, mfaCred) ssm.setCredential(idTokenSpec, idCred) assertEqualE(t, ssm.getCredential(mfaTokenSpec), mfaCred) mfaCredOverride := "token789" 
ssm.setCredential(mfaTokenSpec, mfaCredOverride) assertEqualE(t, ssm.getCredential(mfaTokenSpec), mfaCredOverride) ssm.setCredential(idTokenSpec, idCred) }) t.Run("unlock stale cache", func(t *testing.T) { tokenSpec := newMfaTokenSpec("stale", "cache") assertNilF(t, os.Mkdir(ssm.lockPath(), 0700)) time.Sleep(1000 * time.Millisecond) ssm.setCredential(tokenSpec, "unlocked") assertEqualE(t, ssm.getCredential(tokenSpec), "unlocked") }) t.Run("wait for other process to unlock cache", func(t *testing.T) { tokenSpec := newMfaTokenSpec("stale", "cache") startTime := time.Now() assertNilF(t, os.Mkdir(ssm.lockPath(), 0700)) time.Sleep(500 * time.Millisecond) go func() { time.Sleep(500 * time.Millisecond) assertNilF(t, os.Remove(ssm.lockPath())) }() ssm.setCredential(tokenSpec, "unlocked") totalDurationMillis := time.Since(startTime).Milliseconds() assertEqualE(t, ssm.getCredential(tokenSpec), "unlocked") assertTrueE(t, totalDurationMillis > 1000 && totalDurationMillis < 1200) }) t.Run("should not modify keys other than tokens", func(t *testing.T) { content := []byte(`{ "otherKey": "otherValue" }`) err = os.WriteFile(ssm.credFilePath(), content, 0600) assertNilF(t, err) ssm.setCredential(newMfaTokenSpec("somehost.com", "someUser"), "someToken") result, err := os.ReadFile(ssm.credFilePath()) assertNilF(t, err) assertStringContainsE(t, string(result), `"otherKey":"otherValue"`) }) t.Run("should not modify file if it has wrong permission", func(t *testing.T) { tokenSpec := newMfaTokenSpec("somehost.com", "someUser") ssm.setCredential(tokenSpec, "initialValue") assertEqualE(t, ssm.getCredential(tokenSpec), "initialValue") err = os.Chmod(ssm.credFilePath(), 0644) assertNilF(t, err) defer func() { assertNilE(t, os.Chmod(ssm.credFilePath(), 0600)) }() ssm.setCredential(tokenSpec, "newValue") assertEqualE(t, ssm.getCredential(tokenSpec), "") fileContent, err := os.ReadFile(ssm.credFilePath()) assertNilF(t, err) var m map[string]any err = json.Unmarshal(fileContent, &m) assertNilF(t, 
err) cacheKey, err := tokenSpec.buildKey() assertNilF(t, err) tokens := m["tokens"].(map[string]any) assertEqualE(t, tokens[cacheKey], "initialValue") }) t.Run("should not modify file if its dir has wrong permission", func(t *testing.T) { tokenSpec := newMfaTokenSpec("somehost.com", "someUser") ssm.setCredential(tokenSpec, "initialValue") assertEqualE(t, ssm.getCredential(tokenSpec), "initialValue") err = os.Chmod(ssm.credDirPath, 0777) assertNilF(t, err) defer func() { assertNilE(t, os.Chmod(ssm.credDirPath, 0700)) }() ssm.setCredential(tokenSpec, "newValue") assertEqualE(t, ssm.getCredential(tokenSpec), "") fileContent, err := os.ReadFile(ssm.credFilePath()) assertNilF(t, err) var m map[string]any err = json.Unmarshal(fileContent, &m) assertNilF(t, err) cacheKey, err := tokenSpec.buildKey() assertNilF(t, err) tokens := m["tokens"].(map[string]any) assertEqualE(t, tokens[cacheKey], "initialValue") }) } func TestSetAndGetCredential(t *testing.T) { skipOnMissingHome(t) for _, tokenSpec := range []*secureTokenSpec{ newMfaTokenSpec("testhost", "testuser"), newIDTokenSpec("testhost", "testuser"), } { t.Run(string(tokenSpec.tokenType), func(t *testing.T) { skipOnMac(t, "keyring asks for password") fakeMfaToken := "test token" tokenSpec := newMfaTokenSpec("testHost", "testUser") credentialsStorage.setCredential(tokenSpec, fakeMfaToken) assertEqualE(t, credentialsStorage.getCredential(tokenSpec), fakeMfaToken) // delete credential and check it no longer exists credentialsStorage.deleteCredential(tokenSpec) assertEqualE(t, credentialsStorage.getCredential(tokenSpec), "") }) } } func TestSkipStoringCredentialIfUserIsEmpty(t *testing.T) { tokenSpecs := []*secureTokenSpec{ newMfaTokenSpec("mfaHost.com", ""), newIDTokenSpec("idHost.com", ""), } for _, tokenSpec := range tokenSpecs { t.Run(tokenSpec.host, func(t *testing.T) { credentialsStorage.setCredential(tokenSpec, "non-empty-value") assertEqualE(t, credentialsStorage.getCredential(tokenSpec), "") }) } } func 
TestSkipStoringCredentialIfHostIsEmpty(t *testing.T) { tokenSpecs := []*secureTokenSpec{ newMfaTokenSpec("", "mfaUser"), newIDTokenSpec("", "idUser"), } for _, tokenSpec := range tokenSpecs { t.Run(tokenSpec.user, func(t *testing.T) { credentialsStorage.setCredential(tokenSpec, "non-empty-value") assertEqualE(t, credentialsStorage.getCredential(tokenSpec), "") }) } } func TestStoreTemporaryCredential(t *testing.T) { if runningOnGithubAction() { t.Skip("cannot write to github file system") } testcases := []struct { tokenSpec *secureTokenSpec value string }{ {newMfaTokenSpec("testhost", "testuser"), "mfa token"}, {newIDTokenSpec("testhost", "testuser"), "id token"}, {newOAuthAccessTokenSpec("testhost", "testuser"), "access token"}, {newOAuthRefreshTokenSpec("testhost", "testuser"), "refresh token"}, } ssm, err := newFileBasedSecureStorageManager() assertNilF(t, err) for _, test := range testcases { t.Run(test.value, func(t *testing.T) { ssm.setCredential(test.tokenSpec, test.value) assertEqualE(t, ssm.getCredential(test.tokenSpec), test.value) ssm.deleteCredential(test.tokenSpec) assertEqualE(t, ssm.getCredential(test.tokenSpec), "") }) } } func TestBuildCredentialsKey(t *testing.T) { testcases := []struct { host string user string credType tokenType out string }{ {"testaccount.snowflakecomputing.com", "testuser", "mfaToken", "c4e781475e7a5e74aca87cd462afafa8cc48ebff6f6ccb5054b894dae5eb6345"}, // pragma: allowlist secret {"testaccount.snowflakecomputing.com", "testuser", "IdToken", "5014e26489992b6ea56b50e936ba85764dc51338f60441bdd4a69eac7e15bada"}, // pragma: allowlist secret } for _, test := range testcases { target, err := buildCredentialsKey(test.host, test.user, test.credType) assertNilF(t, err) if target != test.out { t.Fatalf("failed to convert target. 
expected: %v, but got: %v", test.out, target) } } } ================================================ FILE: sflog/interface.go ================================================ // Package sflog package defines the logging interface for Snowflake's Go driver. // If you want to implement a custom logger, you should implement the SFLogger interface defined in this package. package sflog import ( "context" "io" ) // ClientLogContextHook is a client-defined hook that can be used to insert log // fields based on the Context. type ClientLogContextHook func(context.Context) string // LogEntry allows for logging using a snapshot of field values. // No implementation-specific logging details should be placed into this interface. type LogEntry interface { Tracef(format string, args ...any) Debugf(format string, args ...any) Infof(format string, args ...any) Warnf(format string, args ...any) Errorf(format string, args ...any) Fatalf(format string, args ...any) Trace(msg string) Debug(msg string) Info(msg string) Warn(msg string) Error(msg string) Fatal(msg string) } // SFLogger Snowflake logger interface which abstracts away the underlying logging mechanism. // No implementation-specific logging details should be placed into this interface. type SFLogger interface { LogEntry WithField(key string, value any) LogEntry WithFields(fields map[string]any) LogEntry SetLogLevel(level string) error SetLogLevelInt(level Level) error GetLogLevel() string GetLogLevelInt() Level WithContext(ctx context.Context) LogEntry SetOutput(output io.Writer) } ================================================ FILE: sflog/levels.go ================================================ package sflog import ( "fmt" "math" "strings" ) // Level represents the log level for a log message. It extends slog's standard levels with custom levels. 
type Level int // Custom level constants that extend slog's standard levels const ( LevelTrace = Level(-8) LevelDebug = Level(-4) LevelInfo = Level(0) LevelWarn = Level(4) LevelError = Level(8) LevelFatal = Level(12) LevelOff = Level(math.MaxInt) ) // ParseLevel converts a string level to Level func ParseLevel(level string) (Level, error) { switch strings.ToUpper(level) { case "TRACE": return LevelTrace, nil case "DEBUG": return LevelDebug, nil case "INFO": return LevelInfo, nil case "WARN": return LevelWarn, nil case "ERROR": return LevelError, nil case "FATAL": return LevelFatal, nil case "OFF": return LevelOff, nil default: return LevelInfo, fmt.Errorf("unknown log level: %s", level) } } // LevelToString converts Level to string func LevelToString(level Level) (string, error) { switch level { case LevelTrace: return "TRACE", nil case LevelDebug: return "DEBUG", nil case LevelInfo: return "INFO", nil case LevelWarn: return "WARN", nil case LevelError: return "ERROR", nil case LevelFatal: return "FATAL", nil case LevelOff: return "OFF", nil default: return "", fmt.Errorf("unknown log level: %d", level) } } ================================================ FILE: sflog/slog.go ================================================ package sflog import "log/slog" // SFSlogLogger is an optional interface for advanced slog handler configuration. // This interface is separate from SFLogger to maintain framework-agnostic design. // Users can type-assert the logger to check if slog handler configuration is supported. 
// // Example usage: // // logger := gosnowflake.GetLogger() // if slogLogger, ok := logger.(gosnowflake.SFSlogLogger); ok { // customHandler := slog.NewJSONHandler(os.Stdout, nil) // slogLogger.SetHandler(customHandler) // } type SFSlogLogger interface { SetHandler(handler slog.Handler) error } ================================================ FILE: sqlstate.go ================================================ package gosnowflake const ( // SQLStateNumericValueOutOfRange is a SQL State code indicating Numeric value is out of range. SQLStateNumericValueOutOfRange = "22003" // SQLStateInvalidDataTimeFormat is a SQL State code indicating DataTime format is invalid. SQLStateInvalidDataTimeFormat = "22007" // SQLStateConnectionWasNotEstablished is a SQL State code indicating connection was not established. SQLStateConnectionWasNotEstablished = "08001" // SQLStateConnectionRejected is a SQL State code indicating connection was rejected. SQLStateConnectionRejected = "08004" // SQLStateConnectionFailure is a SQL State code indicating connection failed. SQLStateConnectionFailure = "08006" // SQLStateFeatureNotSupported is a SQL State code indicating the feature is not enabled. SQLStateFeatureNotSupported = "0A000" ) ================================================ FILE: statement.go ================================================ package gosnowflake import ( "context" "database/sql/driver" "errors" "fmt" "time" ) // SnowflakeStmt represents the prepared statement in driver. type SnowflakeStmt interface { GetQueryID() string } type snowflakeStmt struct { sc *snowflakeConn query string lastQueryID string } func (stmt *snowflakeStmt) Close() error { logger.WithContext(stmt.sc.ctx).Info("Stmt.Close") // noop return nil } func (stmt *snowflakeStmt) NumInput() int { logger.WithContext(stmt.sc.ctx).Info("Stmt.NumInput") // Go Snowflake doesn't know the number of binding parameters. 
return -1
}

// ExecContext executes the prepared statement with the provided context and
// named arguments.
func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	logger.WithContext(stmt.sc.ctx).Info("Stmt.ExecContext")
	return stmt.execInternal(ctx, args)
}

// QueryContext runs the prepared query and records the query ID of the
// returned rows, or of the error when the query fails.
func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	logger.WithContext(stmt.sc.ctx).Info("Stmt.QueryContext")
	rows, err := stmt.sc.QueryContext(ctx, stmt.query, args)
	if err != nil {
		stmt.setQueryIDFromError(err)
		return nil, err
	}
	r, ok := rows.(SnowflakeRows)
	if !ok {
		// fixed typo: "convertion" -> "conversion"
		return nil, fmt.Errorf("interface conversion. expected type SnowflakeRows but got %T", rows)
	}
	stmt.lastQueryID = r.GetQueryID()
	return rows, nil
}

// Exec executes the prepared statement with positional arguments under a
// background context.
func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) {
	logger.WithContext(stmt.sc.ctx).Info("Stmt.Exec")
	return stmt.execInternal(context.Background(), toNamedValues(args))
}

// execInternal runs the statement through the underlying connection, tagging
// the context as a statement execution and tracking the resulting query ID
// (taken from the result on success, or from the SnowflakeError on failure).
func (stmt *snowflakeStmt) execInternal(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	logger.WithContext(stmt.sc.ctx).Debug("Stmt.execInternal")
	if ctx == nil {
		ctx = context.Background()
	}
	stmtCtx := context.WithValue(ctx, executionType, executionTypeStatement)
	timer := time.Now()
	result, err := stmt.sc.ExecContext(stmtCtx, stmt.query, args)
	if err != nil {
		stmt.setQueryIDFromError(err)
		logger.WithContext(ctx).Errorf("QueryID: %v failed to execute because of the error %v. It took %v ms.", stmt.lastQueryID, err, time.Since(timer).String())
		return nil, err
	}
	rnr, ok := result.(*snowflakeResultNoRows)
	if ok {
		stmt.lastQueryID = rnr.GetQueryID()
		logger.WithContext(ctx).Debugf("Query ID: %v has no result. It took %v ms.,", stmt.lastQueryID, time.Since(timer).String())
		return driver.ResultNoRows, nil
	}
	r, ok := result.(SnowflakeResult)
	if !ok {
		// fixed typo: "convertion" -> "conversion"
		return nil, fmt.Errorf("interface conversion. expected type SnowflakeResult but got %T", result)
	}
	stmt.lastQueryID = r.GetQueryID()
	logger.WithContext(ctx).Debugf("Query ID: %v has no result.
It took %v ms.,", stmt.lastQueryID, time.Since(timer).String()) return result, err } func (stmt *snowflakeStmt) Query(args []driver.Value) (driver.Rows, error) { logger.WithContext(stmt.sc.ctx).Info("Stmt.Query") timer := time.Now() rows, err := stmt.sc.Query(stmt.query, args) if err != nil { logger.WithContext(stmt.sc.ctx).Errorf("QueryID: %v failed to execute because of the error %v. It took %v ms.", stmt.lastQueryID, err, time.Since(timer).String()) stmt.setQueryIDFromError(err) return nil, err } r, ok := rows.(SnowflakeRows) if !ok { logger.WithContext(stmt.sc.ctx).Errorf("Query ID: %v failed to convert the rows to SnowflakeRows. It took %v ms.,", stmt.lastQueryID, time.Since(timer).String()) return nil, fmt.Errorf("interface convertion. expected type SnowflakeRows but got %T", rows) } stmt.lastQueryID = r.GetQueryID() logger.WithContext(stmt.sc.ctx).Debugf("Query ID: %v has no result. It took %v ms.,", stmt.lastQueryID, time.Since(timer).String()) return rows, err } func (stmt *snowflakeStmt) GetQueryID() string { return stmt.lastQueryID } func (stmt *snowflakeStmt) setQueryIDFromError(err error) { var snowflakeError *SnowflakeError if errors.As(err, &snowflakeError) { stmt.lastQueryID = snowflakeError.QueryID } } ================================================ FILE: statement_test.go ================================================ //lint:file-ignore SA1019 Ignore deprecated methods. We should leave them as-is to keep backward compatibility. package gosnowflake import ( "context" "database/sql" "database/sql/driver" "errors" "fmt" "net/http" "net/url" "testing" "time" ) func openDB(t *testing.T) *sql.DB { var db *sql.DB var err error if db, err = sql.Open("snowflake", dsn); err != nil { t.Fatalf("failed to open db. %v", err) } return db } func openConn(t *testing.T, config *testConfig) (*sql.DB, *sql.Conn) { var db *sql.DB var conn *sql.Conn var err error if db, err = sql.Open("snowflake", config.dsn); err != nil { t.Fatalf("failed to open db. 
%v, err: %v", dsn, err) } if conn, err = db.Conn(context.Background()); err != nil { t.Fatalf("failed to open connection: %v", err) } return db, conn } func TestExecStmt(t *testing.T) { dqlQuery := "SELECT 1" dmlQuery := "INSERT INTO TestDDLExec VALUES (1)" ddlQuery := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)" multiStmtQuery := "DELETE FROM TestDDLExec;\n" + "SELECT 1;\n" + "SELECT 2;" ctx := context.Background() multiStmtCtx := WithMultiStatement(ctx, 3) runDBTest(t, func(dbt *DBTest) { dbt.mustExec(ddlQuery) defer dbt.mustExec("DROP TABLE IF EXISTS TestDDLExec") testcases := []struct { name string query string f func(stmt driver.Stmt) (any, error) }{ { name: "dql Exec", query: dqlQuery, f: func(stmt driver.Stmt) (any, error) { return stmt.Exec(nil) }, }, { name: "dql ExecContext", query: dqlQuery, f: func(stmt driver.Stmt) (any, error) { return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) }, }, { name: "ddl Exec", query: ddlQuery, f: func(stmt driver.Stmt) (any, error) { return stmt.Exec(nil) }, }, { name: "ddl ExecContext", query: ddlQuery, f: func(stmt driver.Stmt) (any, error) { return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) }, }, { name: "dml Exec", query: dmlQuery, f: func(stmt driver.Stmt) (any, error) { return stmt.Exec(nil) }, }, { name: "dml ExecContext", query: dmlQuery, f: func(stmt driver.Stmt) (any, error) { return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) }, }, { name: "multistmt ExecContext", query: multiStmtQuery, f: func(stmt driver.Stmt) (any, error) { return stmt.(driver.StmtExecContext).ExecContext(multiStmtCtx, nil) }, }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { err := dbt.conn.Raw(func(x any) error { stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) if err != nil { t.Error(err) } if stmt.(SnowflakeStmt).GetQueryID() != "" { t.Error("queryId should be empty before executing any query") } if _, err := tc.f(stmt); err != nil { t.Errorf("should have not 
failed to execute the query, err: %s\n", err) } if stmt.(SnowflakeStmt).GetQueryID() == "" { t.Error("should have set the query id") } return nil }) if err != nil { t.Fatal(err) } }) } }) } func TestFailedQueryIdInSnowflakeError(t *testing.T) { failingQuery := "SELECTT 1" failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE" runDBTest(t, func(dbt *DBTest) { testcases := []struct { name string query string f func(dbt *DBTest) (any, error) }{ { name: "query", f: func(dbt *DBTest) (any, error) { return dbt.query(failingQuery) }, }, { name: "exec", f: func(dbt *DBTest) (any, error) { return dbt.exec(failingExec) }, }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { _, err := tc.f(dbt) if err == nil { t.Error("should have failed") } var snowflakeError *SnowflakeError if !errors.As(err, &snowflakeError) { t.Error("should be a SnowflakeError") } if snowflakeError.QueryID == "" { t.Error("QueryID should be set") } }) } }) } func TestSetFailedQueryId(t *testing.T) { ctx := context.Background() failingQuery := "SELECTT 1" failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE" runDBTest(t, func(dbt *DBTest) { testcases := []struct { name string query string f func(stmt driver.Stmt) (any, error) }{ { name: "query", query: failingQuery, f: func(stmt driver.Stmt) (any, error) { return stmt.Query(nil) }, }, { name: "exec", query: failingExec, f: func(stmt driver.Stmt) (any, error) { return stmt.Exec(nil) }, }, { name: "queryContext", query: failingQuery, f: func(stmt driver.Stmt) (any, error) { return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) }, }, { name: "execContext", query: failingExec, f: func(stmt driver.Stmt) (any, error) { return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) }, }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { err := dbt.conn.Raw(func(x any) error { stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) if err != nil { t.Error(err) } if stmt.(SnowflakeStmt).GetQueryID() != "" { 
t.Error("queryId should be empty before executing any query")
					}
					if _, err := tc.f(stmt); err == nil {
						t.Error("should have failed to execute the query")
					}
					if stmt.(SnowflakeStmt).GetQueryID() == "" {
						t.Error("should have set the query id")
					}
					return nil
				})
				if err != nil {
					t.Fatal(err)
				}
			})
		}
	})
}

// TestAsyncFailQueryId verifies that an async query that fails still exposes a
// query ID, and that the rows' query ID matches the statement's.
func TestAsyncFailQueryId(t *testing.T) {
	ctx := WithAsyncMode(context.Background())
	runDBTest(t, func(dbt *DBTest) {
		err := dbt.conn.Raw(func(x any) error {
			stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECTT 1")
			if err != nil {
				// Fail fast: a nil stmt would panic on the assertion below.
				return err
			}
			if stmt.(SnowflakeStmt).GetQueryID() != "" {
				t.Error("queryId should be empty before executing any query")
			}
			rows, err := stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
			if err != nil {
				// In async mode the initial request should succeed even for a
				// failing query; a nil rows value would panic below.
				t.Error("should not fail the initial request")
				return err
			}
			if rows.(SnowflakeRows).GetStatus() != QueryStatusInProgress {
				t.Error("should be in progress")
			}
			// Wait for the query to complete
			assertNotNilE(t, rows.Next(nil))
			if rows.(SnowflakeRows).GetStatus() != QueryFailed {
				t.Error("should have failed")
			}
			if rows.(SnowflakeRows).GetQueryID() != stmt.(SnowflakeStmt).GetQueryID() {
				t.Error("last query id should be the same as rows query id")
			}
			return nil
		})
		if err != nil {
			t.Fatal(err)
		}
	})
}

// TestGetQueryID verifies that a failing query surfaces a SnowflakeError with
// the expected error number and an associated query ID.
func TestGetQueryID(t *testing.T) {
	ctx := context.Background()
	runDBTest(t, func(dbt *DBTest) {
		if err := dbt.conn.Raw(func(x any) error {
			rows, err := x.(driver.QueryerContext).QueryContext(ctx, "select 1", nil)
			if err != nil {
				return err
			}
			defer rows.Close()
			if _, err = x.(driver.QueryerContext).QueryContext(ctx, "selectt 1", nil); err == nil {
				t.Fatal("should have failed to execute query")
			}
			if driverErr, ok := err.(*SnowflakeError); ok {
				if driverErr.Number != 1003 {
					t.Fatalf("incorrect error code. expected: 1003, got: %v", driverErr.Number)
				}
				if driverErr.QueryID == "" {
					t.Fatal("should have an associated query ID")
				}
			} else {
				t.Fatal("should have been able to cast to Snowflake Error")
			}
			return nil
		}); err != nil {
			t.Fatalf("failed to prepare statement. err: %v", err)
		}
	})
}

// TestEmitQueryID verifies that the query ID of an async query is delivered on
// the channel installed via WithQueryIDChan.
func TestEmitQueryID(t *testing.T) {
	queryIDChan := make(chan string, 1)
	numrows := 100000
	ctx := WithAsyncMode(context.Background())
	ctx = WithQueryIDChan(ctx, queryIDChan)
	goRoutineChan := make(chan string)
	go func(grCh chan string, qIDch chan string) {
		// Fix: read from the channel parameter instead of capturing the outer
		// variable (the parameter was previously unused).
		queryID := <-qIDch
		grCh <- queryID
	}(goRoutineChan, queryIDChan)
	cnt := 0
	var idx int
	var v string
	runDBTest(t, func(dbt *DBTest) {
		rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows))
		defer rows.Close()
		for rows.Next() {
			if err := rows.Scan(&idx, &v); err != nil {
				t.Fatal(err)
			}
			cnt++
		}
		logger.Infof("NextResultSet: %v", rows.NextResultSet())
	})
	queryID := <-goRoutineChan
	if queryID == "" {
		t.Fatal("expected a nonempty query ID")
	}
	if cnt != numrows {
		t.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt)
	}
}

// End-to-end test to fetch result with queryID
func TestE2EFetchResultByID(t *testing.T) {
	db := openDB(t)
	defer db.Close()
	if _, err := db.Exec(`create or replace table test_fetch_result(c1 number, c2 string) as select 10, 'z'`); err != nil {
		t.Fatalf("failed to create table: %v", err)
	}
	ctx := context.Background()
	conn, err := db.Conn(ctx)
	if err != nil {
		// Fail fast: a nil conn would panic on conn.Raw below.
		t.Fatal(err)
	}
	defer conn.Close()
	if err = conn.Raw(func(x any) error {
		stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "select * from test_fetch_result")
		if err != nil {
			return err
		}
		rows1, err := stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
		if err != nil {
			return err
		}
		qid := rows1.(SnowflakeResult).GetQueryID()
		newCtx := context.WithValue(context.Background(), fetchResultByID, qid)
		rows2, err := db.QueryContext(newCtx, "")
		if err != nil {
			t.Fatalf("Fetch Query Result by ID failed: %v", err)
		}
		var c1 sql.NullInt64
		var c2 sql.NullString
		for rows2.Next() {
			// Fix: surface scan errors instead of silently discarding them.
			if err = rows2.Scan(&c1, &c2); err != nil {
				return err
			}
		}
		if c1.Int64 != 10 || c2.String != "z" {
			// Fix: report the actual values; err is nil at this point.
			t.Fatalf("Query result is not expected: c1: %v, c2: %v", c1.Int64, c2.String)
		}
		return nil
	}); err != nil {
		// Fix: this path fetches the result by ID; the old message claimed a
		// table drop failed.
		t.Fatalf("failed to fetch result by query ID: %v", err)
	}
	if _, err := db.Exec("drop table if exists test_fetch_result"); err != nil {
		t.Fatalf("failed to drop table: %v", err)
	}
}

// TestWithDescribeOnly verifies that describe-only mode returns column
// metadata but no rows.
func TestWithDescribeOnly(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		ctx := WithDescribeOnly(context.Background())
		rows := dbt.mustQueryContext(ctx, selectVariousTypes)
		defer rows.Close()
		cols, err := rows.Columns()
		if err != nil {
			t.Error(err)
		}
		types, err := rows.ColumnTypes()
		if err != nil {
			t.Error(err)
		}
		for i, col := range cols {
			if types[i].Name() != col {
				t.Fatalf("column name mismatch. expected: %v, got: %v", col, types[i].Name())
			}
		}
		if rows.Next() {
			t.Fatal("there should not be any rows in describe only mode")
		}
	})
}

// TestCallStatement verifies calling a stored procedure with bound arguments
// when CALL statement typing is enabled for the session.
func TestCallStatement(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		in1 := float64(1)
		in2 := string("[2,3]")
		expected := "1 \"[2,3]\" [2,3]"
		var out string
		dbt.mustExec("ALTER SESSION SET USE_STATEMENT_TYPE_CALL_FOR_STORED_PROC_CALLS = true")
		dbt.mustExec("create or replace procedure " +
			"TEST_SP_CALL_STMT_ENABLED(in1 float, in2 variant) " +
			"returns string language javascript as $$ " +
			"let res = snowflake.execute({sqlText: 'select ? c1, ? c2', binds:[IN1, JSON.stringify(IN2)]}); " +
			"res.next(); " +
			"return res.getColumnValueAsString(1) + ' ' + res.getColumnValueAsString(2) + ' ' + IN2; " +
			"$$;")
		stmt, err := dbt.conn.PrepareContext(context.Background(), "call TEST_SP_CALL_STMT_ENABLED(?, to_variant(?))")
		if err != nil {
			dbt.Errorf("failed to prepare query: %v", err)
		}
		defer stmt.Close()
		err = stmt.QueryRow(in1, in2).Scan(&out)
		if err != nil {
			dbt.Errorf("failed to scan: %v", err)
		}
		if expected != out {
			dbt.Errorf("expected: %s, got: %s", expected, out)
		}
		dbt.mustExec("drop procedure if exists TEST_SP_CALL_STMT_ENABLED(float, variant)")
	})
}

// TestStmtExec verifies that the concrete snowflakeStmt supports both Exec and
// Query through the raw driver connection.
func TestStmtExec(t *testing.T) {
	ctx := context.Background()
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExecT(t, `create or replace table test_table(col1 int, col2 int)`)
		if err := dbt.conn.Raw(func(x any) error {
			stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "insert into test_table values (1, 2)")
			if err != nil {
				// Fail fast: a nil stmt would panic on the assertions below.
				return err
			}
			_, err = stmt.(*snowflakeStmt).Exec(nil)
			if err != nil {
				t.Error(err)
			}
			_, err = stmt.(*snowflakeStmt).Query(nil)
			if err != nil {
				t.Error(err)
			}
			return nil
		}); err != nil {
			// Fix: this path executes statements; the old message claimed a
			// table drop failed.
			t.Fatalf("failed to execute statement: %v", err)
		}
		dbt.mustExecT(t, "drop table if exists test_table")
	})
}

// TestStmtExec_Error verifies that binding invalid data surfaces an error.
func TestStmtExec_Error(t *testing.T) {
	ctx := context.Background()
	runDBTest(t, func(dbt *DBTest) {
		// Create a test table
		dbt.mustExecT(t, `create or replace table test_table(col1 int, col2 int)`)
		defer dbt.mustExecT(t, "drop table if exists test_table")
		// Attempt to execute an invalid statement
		if err := dbt.conn.Raw(func(x any) error {
			stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "insert into test_table values (?, ?)")
			if err != nil {
				t.Fatalf("failed to prepare statement: %v", err)
			}
			// Intentionally passing a string instead of an integer to cause an error
			_, err = stmt.(*snowflakeStmt).Exec([]driver.Value{"invalid_data", 2})
			if err == nil {
				t.Errorf("expected an error, but got none")
			}
			return nil
		}); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	})
}

// getStatusSuccessButInvalidJSONfunc fakes a 200 response whose body is not
// valid JSON, to exercise decode-error handling.
func getStatusSuccessButInvalidJSONfunc(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) {
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
	}, nil
}

// TestUnitCheckQueryStatus exercises checkQueryStatus against invalid JSON,
// a failed request, and an error response.
func TestUnitCheckQueryStatus(t *testing.T) {
	sc := getDefaultSnowflakeConn()
	ctx := context.Background()
	qid := NewUUID()
	sr := &snowflakeRestful{
		FuncGet:       getStatusSuccessButInvalidJSONfunc,
		TokenAccessor: getSimpleTokenAccessor(),
	}
	sc.rest = sr
	_, err := sc.checkQueryStatus(ctx, qid.String())
	if err == nil {
		t.Fatal("invalid json. should have failed")
	}
	sc.rest.FuncGet = funcGetQueryRespFail
	_, err = sc.checkQueryStatus(ctx, qid.String())
	if err == nil {
		t.Fatal("should have failed")
	}
	sc.rest.FuncGet = funcGetQueryRespError
	_, err = sc.checkQueryStatus(ctx, qid.String())
	if err == nil {
		t.Fatal("should have failed")
	}
	driverErr, ok := err.(*SnowflakeError)
	if !ok {
		t.Fatalf("should be snowflake error. err: %v", err)
	}
	if driverErr.Number != ErrQueryStatus {
		t.Fatalf("unexpected error code. expected: %v, got: %v", ErrQueryStatus, driverErr.Number)
	}
}

// TestStatementQueryIdForQueries verifies that the statement's query ID tracks
// each successive query's ID.
func TestStatementQueryIdForQueries(t *testing.T) {
	ctx := context.Background()
	testcases := []struct {
		name string
		f    func(stmt driver.Stmt) (driver.Rows, error)
	}{
		{
			"query",
			func(stmt driver.Stmt) (driver.Rows, error) {
				return stmt.Query(nil)
			},
		},
		{
			"queryContext",
			func(stmt driver.Stmt) (driver.Rows, error) {
				return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
			},
		},
	}
	runDBTest(t, func(dbt *DBTest) {
		for _, tc := range testcases {
			t.Run(tc.name, func(t *testing.T) {
				err := dbt.conn.Raw(func(x any) error {
					stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECT 1")
					if err != nil {
						t.Fatal(err)
					}
					if stmt.(SnowflakeStmt).GetQueryID() != "" {
						t.Error("queryId should be empty before executing any query")
					}
					firstQuery, err := tc.f(stmt)
					if err != nil {
						t.Fatal(err)
					}
					if stmt.(SnowflakeStmt).GetQueryID() == "" {
						t.Error("queryId should not be empty after executing query")
					}
					if stmt.(SnowflakeStmt).GetQueryID() != firstQuery.(SnowflakeRows).GetQueryID() {
						t.Error("queryId should be equal among query result and prepared statement")
					}
					secondQuery, err := tc.f(stmt)
					if err != nil {
						t.Fatal(err)
					}
					if stmt.(SnowflakeStmt).GetQueryID() == "" {
						t.Error("queryId should not be empty after executing query")
					}
					if stmt.(SnowflakeStmt).GetQueryID() != secondQuery.(SnowflakeRows).GetQueryID() {
						t.Error("queryId should be equal among query result and prepared statement")
					}
					return nil
				})
				if err != nil {
					t.Fatal(err)
				}
			})
		}
	})
}

// TestStatementQuery verifies Query/QueryContext behavior for valid and
// invalid queries.
func TestStatementQuery(t *testing.T) {
	ctx := context.Background()
	testcases := []struct {
		name    string
		query   string
		f       func(stmt driver.Stmt) (driver.Rows, error)
		wantErr bool
	}{
		{
			"validQuery",
			"SELECT 1",
			func(stmt driver.Stmt) (driver.Rows, error) {
				return stmt.Query(nil)
			},
			false,
		},
		{
			"validQueryContext",
			"SELECT 1",
			func(stmt driver.Stmt) (driver.Rows, error) {
				return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
			},
			false,
		},
		{
			"invalidQuery",
			"SELECT * FROM non_existing_table",
			func(stmt driver.Stmt) (driver.Rows, error) {
				return stmt.Query(nil)
			},
			true,
		},
		{
			"invalidQueryContext",
			"SELECT * FROM non_existing_table",
			func(stmt driver.Stmt) (driver.Rows, error) {
				return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
			},
			true,
		},
	}
	runDBTest(t, func(dbt *DBTest) {
		for _, tc := range testcases {
			t.Run(tc.name, func(t *testing.T) {
				err := dbt.conn.Raw(func(x any) error {
					stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query)
					if err != nil {
						if tc.wantErr {
							return nil // expected error
						}
						t.Fatal(err)
					}
					_, err = tc.f(stmt)
					if (err != nil) != tc.wantErr {
						t.Fatalf("error = %v, wantErr %v", err, tc.wantErr)
					}
					return nil
				})
				if err != nil {
					t.Fatal(err)
				}
			})
		}
	})
}

// TestStatementQueryIdForExecs verifies that the statement's query ID tracks
// each successive exec's ID.
func TestStatementQueryIdForExecs(t *testing.T) {
	ctx := context.Background()
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("CREATE TABLE TestStatementQueryIdForExecs (v INTEGER)")
		defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementQueryIdForExecs")
		testcases := []struct {
			name string
			f    func(stmt driver.Stmt) (driver.Result, error)
		}{
			{
				"exec",
				func(stmt driver.Stmt) (driver.Result, error) {
					return stmt.Exec(nil)
				},
			},
			{
				"execContext",
				func(stmt driver.Stmt) (driver.Result, error) {
					return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
				},
			},
		}
		for _, tc := range testcases {
			t.Run(tc.name, func(t *testing.T) {
				err := dbt.conn.Raw(func(x any) error {
					stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "INSERT INTO TestStatementQueryIdForExecs VALUES (1)")
					if err != nil {
						t.Fatal(err)
					}
					if stmt.(SnowflakeStmt).GetQueryID() != "" {
						t.Error("queryId should be empty before executing any query")
					}
					firstExec, err := tc.f(stmt)
					if err != nil {
						t.Fatal(err)
					}
					if stmt.(SnowflakeStmt).GetQueryID() == "" {
						t.Error("queryId should not be empty after executing query")
					}
					if stmt.(SnowflakeStmt).GetQueryID() != firstExec.(SnowflakeResult).GetQueryID() {
						t.Error("queryId should be equal among query result and prepared statement")
					}
secondExec, err := tc.f(stmt) if err != nil { t.Fatal(err) } if stmt.(SnowflakeStmt).GetQueryID() == "" { t.Error("queryId should not be empty after executing query") } if stmt.(SnowflakeStmt).GetQueryID() != secondExec.(SnowflakeResult).GetQueryID() { t.Error("queryId should be equal among query result and prepared statement") } return nil }) if err != nil { t.Fatal(err) } }) } }) } func TestStatementQueryExecs(t *testing.T) { ctx := context.Background() runDBTest(t, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE TestStatementQueryExecs (v INTEGER)") defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementForExecs") testcases := []struct { name string query string f func(stmt driver.Stmt) (driver.Result, error) wantErr bool }{ { "validExec", "INSERT INTO TestStatementQueryExecs VALUES (1)", func(stmt driver.Stmt) (driver.Result, error) { return stmt.Exec(nil) }, false, }, { "validExecContext", "INSERT INTO TestStatementQueryExecs VALUES (1)", func(stmt driver.Stmt) (driver.Result, error) { return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) }, false, }, { "invalidExec", "INSERT INTO TestStatementQueryExecs VALUES ('invalid_data')", func(stmt driver.Stmt) (driver.Result, error) { return stmt.Exec(nil) }, true, }, { "invalidExecContext", "INSERT INTO TestStatementQueryExecs VALUES ('invalid_data')", func(stmt driver.Stmt) (driver.Result, error) { return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) }, true, }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { err := dbt.conn.Raw(func(x any) error { stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) if err != nil { if tc.wantErr { return nil // expected error } t.Fatal(err) } _, err = tc.f(stmt) if (err != nil) != tc.wantErr { t.Fatalf("error = %v, wantErr %v", err, tc.wantErr) } return nil }) if err != nil { t.Fatal(err) } }) } }) } func TestWithQueryTag(t *testing.T) { runDBTest(t, func(dbt *DBTest) { testQueryTag := "TEST QUERY TAG" ctx := 
WithQueryTag(context.Background(), testQueryTag) // This query itself will be part of the history and will have the query tag rows := dbt.mustQueryContext( ctx, "SELECT QUERY_TAG FROM table(information_schema.query_history_by_session())") defer rows.Close() assertTrueF(t, rows.Next()) var tag sql.NullString err := rows.Scan(&tag) assertNilF(t, err) assertTrueF(t, tag.Valid, "no QUERY_TAG set") assertEqualF(t, tag.String, testQueryTag) }) } ================================================ FILE: storage_client.go ================================================ package gosnowflake import ( "context" "fmt" "math" "os" "path" "path/filepath" "time" ) const ( defaultConcurrency = 1 defaultMaxRetry = 5 ) // implemented by localUtil and remoteStorageUtil type storageUtil interface { createClient(*execResponseStageInfo, bool, *Config, *snowflakeTelemetry) (cloudClient, error) uploadOneFileWithRetry(context.Context, *fileMetadata) error downloadOneFile(context.Context, *fileMetadata) error } // implemented by snowflakeS3Util, snowflakeAzureUtil and snowflakeGcsUtil type cloudUtil interface { createClient(*execResponseStageInfo, bool, *snowflakeTelemetry) (cloudClient, error) getFileHeader(context.Context, *fileMetadata, string) (*fileHeader, error) uploadFile(context.Context, string, *fileMetadata, int, int64) error nativeDownloadFile(context.Context, *fileMetadata, string, int64, int64) error } type cloudClient any type remoteStorageUtil struct { cfg *Config telemetry *snowflakeTelemetry } func (rsu *remoteStorageUtil) getNativeCloudType(cli string, cfg *Config) cloudUtil { if cloudType(cli) == s3Client { logger.Info("Using S3 client for remote storage") return &snowflakeS3Client{ cfg, rsu.telemetry, } } else if cloudType(cli) == azureClient { logger.Info("Using Azure client for remote storage") return &snowflakeAzureClient{ cfg, rsu.telemetry, } } else if cloudType(cli) == gcsClient { logger.Info("Using GCS client for remote storage") return &snowflakeGcsClient{ cfg, 
rsu.telemetry, } } return nil } // call cloud utils' native create client methods func (rsu *remoteStorageUtil) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config, telemetry *snowflakeTelemetry) (cloudClient, error) { utilClass := rsu.getNativeCloudType(info.LocationType, cfg) return utilClass.createClient(info, useAccelerateEndpoint, telemetry) } func (rsu *remoteStorageUtil) uploadOneFile(ctx context.Context, meta *fileMetadata) error { utilClass := rsu.getNativeCloudType(meta.stageInfo.LocationType, meta.sfa.sc.cfg) maxConcurrency := int(meta.parallel) var lastErr error var timer time.Time var elapsedTime string maxRetry := defaultMaxRetry logger.Debugf( "Started Uploading. File: %v, location: %v", meta.realSrcFileName, meta.stageInfo.Location) for retry := range maxRetry { timer = time.Now() if !meta.overwrite { header, err := utilClass.getFileHeader(ctx, meta, meta.dstFileName) if meta.resStatus == notFoundFile { err := utilClass.uploadFile(ctx, meta.realSrcFileName, meta, maxConcurrency, meta.options.MultiPartThreshold) if err != nil { logger.Warnf("Error uploading %v. err: %v", meta.realSrcFileName, err) } } else if err != nil { return err } if header != nil && meta.resStatus == uploaded { meta.dstFileSize = 0 meta.resStatus = skipped return nil } } if meta.overwrite || meta.resStatus == notFoundFile { err := utilClass.uploadFile(ctx, meta.realSrcFileName, meta, maxConcurrency, meta.options.MultiPartThreshold) if err != nil { logger.Warnf("Error uploading %v. err: %v", meta.realSrcFileName, err) } } elapsedTime = time.Since(timer).String() switch meta.resStatus { case uploaded, renewToken, renewPresignedURL: logger.Debugf("Uploading file: %v finished in %v ms with the status: %v.", meta.realSrcFileName, elapsedTime, meta.resStatus) return nil case needRetry: if !meta.noSleepingTime { sleepingTime := intMin(int(math.Exp2(float64(retry))), 16) logger.Debugf("Need to retry for uploading file: %v. 
Current retry: %v, Sleeping time: %v.", meta.realSrcFileName, retry, sleepingTime) time.Sleep(time.Second * time.Duration(sleepingTime)) } else { logger.Debugf("Need to retry for uploading file: %v. Current retry: %v without the sleeping time.", meta.realSrcFileName, retry) } case needRetryWithLowerConcurrency: maxConcurrency = int(meta.parallel) - (retry * int(meta.parallel) / maxRetry) maxConcurrency = intMax(defaultConcurrency, maxConcurrency) meta.lastMaxConcurrency = maxConcurrency if !meta.noSleepingTime { sleepingTime := intMin(int(math.Exp2(float64(retry))), 16) logger.Debugf("Need to retry with lower concurrency for uploading file: %v. Current retry: %v, Sleeping time: %v.", meta.realSrcFileName, retry, sleepingTime) time.Sleep(time.Second * time.Duration(sleepingTime)) } else { logger.Debugf("Need to retry with lower concurrency for uploading file: %v. Current retry: %v without Sleeping time.", meta.realSrcFileName, retry) } } lastErr = meta.lastError } if lastErr != nil { logger.Errorf(`Failed to uploading file: %v, with error: %v`, meta.realSrcFileName, lastErr) return lastErr } return fmt.Errorf("unkown error uploading %v", meta.realSrcFileName) } func (rsu *remoteStorageUtil) uploadOneFileWithRetry(ctx context.Context, meta *fileMetadata) error { utilClass := rsu.getNativeCloudType(meta.stageInfo.LocationType, rsu.cfg) retryOuter := true for range 10 { // retry if err := rsu.uploadOneFile(ctx, meta); err != nil { return err } retryInner := true if meta.resStatus == uploaded || meta.resStatus == skipped { for range 10 { status := meta.resStatus if _, err := utilClass.getFileHeader(ctx, meta, meta.dstFileName); err != nil { logger.Warnf("error while getting file %v header. 
%v", meta.dstFileSize, err) } // check file header status and verify upload/skip if meta.resStatus == notFoundFile { if !meta.noSleepingTime { time.Sleep(time.Second) // wait 1 second for S3 eventual consistency } continue } else { retryInner = false meta.resStatus = status break } } } if !retryInner { retryOuter = false break } else { continue } } if retryOuter { // wanted to continue retrying but could not upload/find file meta.resStatus = errStatus } return nil } func (rsu *remoteStorageUtil) downloadOneFile(ctx context.Context, meta *fileMetadata) error { fullDstFileName := path.Join(meta.localLocation, baseName(meta.dstFileName)) fullDstFileName, err := expandUser(fullDstFileName) if err != nil { return err } if !filepath.IsAbs(fullDstFileName) { cwd, err := os.Getwd() if err != nil { return err } fullDstFileName = filepath.Join(cwd, fullDstFileName) } baseDir, err := getDirectory() if err != nil { return err } if _, err = os.Stat(baseDir); os.IsNotExist(err) { if err = os.MkdirAll(baseDir, os.ModePerm); err != nil { return err } } utilClass := rsu.getNativeCloudType(meta.stageInfo.LocationType, meta.sfa.sc.cfg) header, err := utilClass.getFileHeader(ctx, meta, meta.srcFileName) if err != nil { return err } if header != nil { meta.srcFileSize = header.contentLength } maxConcurrency := meta.parallel partSize := meta.options.MultiPartThreshold var lastErr error maxRetry := defaultMaxRetry timer := time.Now() for range maxRetry { tempDownloadFile := fullDstFileName + ".tmp" defer func() { // Clean up temp file if it still exists if _, statErr := os.Stat(tempDownloadFile); statErr == nil { logger.Debugf("Cleaning up temporary download file: %s", tempDownloadFile) if removeErr := os.Remove(tempDownloadFile); removeErr != nil { logger.Warnf("Failed to clean up temporary file %s: %v", tempDownloadFile, removeErr) } } }() if err = utilClass.nativeDownloadFile(ctx, meta, tempDownloadFile, maxConcurrency, partSize); err != nil { logger.Errorf("Failed to download file to 
temporary location %s: %v", tempDownloadFile, err) return err } if meta.resStatus == downloaded { logger.Debugf("Downloading file: %v finished in %v ms. File size: %v", meta.srcFileName, time.Since(timer).String(), meta.srcFileSize) if meta.encryptionMaterial != nil { if meta.presignedURL != nil { header, err = utilClass.getFileHeader(ctx, meta, meta.srcFileName) if err != nil { logger.Errorf("Failed to get file header for %s: %v", meta.srcFileName, err) return err } } timer = time.Now() if isFileGetStream(ctx) { totalFileSize, err := decryptStreamCBC(header.encryptionMetadata, meta.encryptionMaterial, 0, meta.dstStream, meta.sfa.streamBuffer) if err != nil { logger.Errorf("Stream decryption failed for %s - temp file will be cleaned up to prevent corrupted data: %v", meta.srcFileName, err) return err } logger.Debugf("Total file size: %d", totalFileSize) if totalFileSize < 0 || totalFileSize > meta.sfa.streamBuffer.Len() { return fmt.Errorf("invalid total file size: %d", totalFileSize) } meta.sfa.streamBuffer.Truncate(totalFileSize) meta.dstFileSize = int64(totalFileSize) } else { if err = rsu.processEncryptedFileToDestination(meta, header, tempDownloadFile, fullDstFileName); err != nil { return err } } logger.Debugf("Decrypting file: %v finished in %v ms.", meta.srcFileName, time.Since(timer).String()) } else { // file is not encrypted if !isFileGetStream(ctx) { // if we have a real file, and not a stream, move the file if err = os.Rename(tempDownloadFile, fullDstFileName); err != nil { return fmt.Errorf("failed to move downloaded file to destination: %w", err) } } else { // if we have a stream and no encyrption, just reuse the stream meta.sfa.streamBuffer = meta.dstStream } } if !isFileGetStream(ctx) { if fi, err := os.Stat(fullDstFileName); err == nil { meta.dstFileSize = fi.Size() } else { logger.Warnf("Failed to get file size for %s: %v", fullDstFileName, err) } } logger.Debugf("File download completed successfully for %s (size: %d bytes)", meta.srcFileName, 
meta.dstFileSize) return nil } lastErr = meta.lastError } if lastErr != nil { logger.Errorf(`Failed to downloading file: %v, with error: %v`, meta.srcFileName, lastErr) return lastErr } return fmt.Errorf("unkown error downloading %v", fullDstFileName) } func (rsu *remoteStorageUtil) processEncryptedFileToDestination(meta *fileMetadata, header *fileHeader, tempDownloadFile, fullDstFileName string) error { // Clean up the temp download file on any exit path defer func() { if _, statErr := os.Stat(tempDownloadFile); statErr == nil { logger.Debugf("Cleaning up temporary download file: %s", tempDownloadFile) err := os.Remove(tempDownloadFile) if err != nil { logger.Warnf("Failed to clean up temporary download file %s: %v", tempDownloadFile, err) } } }() tmpDstFileName, err := decryptFileCBC(header.encryptionMetadata, meta.encryptionMaterial, tempDownloadFile, 0, meta.tmpDir) // Ensure cleanup of the decrypted temp file if decryption or rename fails defer func() { if _, statErr := os.Stat(tmpDstFileName); statErr == nil { err := os.Remove(tmpDstFileName) if err != nil { logger.Warnf("Failed to clean up temporary decrypted file %s: %v", tmpDstFileName, err) } } }() if err != nil { logger.Errorf("File decryption failed for %s: %v", meta.srcFileName, err) return err } if err = os.Rename(tmpDstFileName, fullDstFileName); err != nil { logger.Errorf("Failed to move decrypted file from %s to final destination %s: %v", tmpDstFileName, fullDstFileName, err) return err } logger.Debugf("Successfully decrypted and moved file to %s", fullDstFileName) return nil } ================================================ FILE: storage_client_test.go ================================================ package gosnowflake import ( "os" "path/filepath" "strings" "testing" ) // TestProcessEncryptedFileToDestination_DecryptionFailure tests that temporary files // are cleaned up when decryption fails due to invalid encryption data func TestProcessEncryptedFileToDestination_DecryptionFailure(t 
*testing.T) { tmpDir := t.TempDir() fullDstFileName := filepath.Join(tmpDir, "final_destination.txt") tempDownloadFile := fullDstFileName + ".tmp" assertNilF(t, os.WriteFile(tempDownloadFile, []byte("invalid encrypted content"), 0644), "Failed to create temp download file") // Create metadata with invalid encryption material to trigger decryption failure meta := &fileMetadata{ encryptionMaterial: &snowflakeFileEncryption{ QueryStageMasterKey: "invalid_key", // Invalid key to cause decryption failure QueryID: "test-query-id", SMKID: 12345, }, tmpDir: tmpDir, srcFileName: "test_file.txt", } // Create header with invalid encryption metadata header := &fileHeader{ encryptionMetadata: &encryptMetadata{ key: "invalid_key_data", // Invalid encryption data iv: "invalid_iv_data", matdesc: `{"smkId":"12345","queryId":"test-query-id","keySize":"256"}`, }, } // Test: decryption should fail due to invalid encryption data rsu := &remoteStorageUtil{} err := rsu.processEncryptedFileToDestination(meta, header, tempDownloadFile, fullDstFileName) assertNotNilF(t, err, "Expected decryption to fail with invalid encryption data") // Verify that the final destination file was not created _, err = os.Stat(fullDstFileName) assertTrueF(t, os.IsNotExist(err), "Final destination file should not exist after decryption failure") // Verify the temp download file was cleaned up even though decryption failed _, err = os.Stat(tempDownloadFile) assertTrueF(t, os.IsNotExist(err), "Temp download file should be cleaned up even after decryption failure") verifyNoTmpFilesLeftBehind(t, fullDstFileName) } // TestProcessEncryptedFileToDestination_Success tests successful decryption and file handling func TestProcessEncryptedFileToDestination_Success(t *testing.T) { tmpDir := t.TempDir() // Create test data and encrypt it properly inputData := "test data for successful encryption/decryption" inputFile := filepath.Join(tmpDir, "input.txt") assertNilF(t, os.WriteFile(inputFile, []byte(inputData), 0644), 
"Failed to create input file") // Create valid encryption material encMat := &snowflakeFileEncryption{ QueryStageMasterKey: "ztke8tIdVt1zmlQIZm0BMA==", QueryID: "test-query-id", SMKID: 12345, } // Encrypt the file to create valid encrypted content metadata, encryptedFile, err := encryptFileCBC(encMat, inputFile, 0, tmpDir) assertNilF(t, err, "Failed to encrypt test file") defer os.Remove(encryptedFile) // Create final destination path fullDstFileName := filepath.Join(tmpDir, "final_destination.txt") // Create metadata for decryption meta := &fileMetadata{ encryptionMaterial: encMat, tmpDir: tmpDir, srcFileName: "test_file.txt", } header := &fileHeader{ encryptionMetadata: metadata, } // Test: successful decryption and file move rsu := &remoteStorageUtil{} err = rsu.processEncryptedFileToDestination(meta, header, encryptedFile, fullDstFileName) assertNilF(t, err, "Expected successful decryption and file move") // Verify that the final destination file was created with correct content finalContent, err := os.ReadFile(fullDstFileName) assertNilF(t, err, "Failed to read final destination file") assertEqualF(t, string(finalContent), inputData, "Final file content should match original input") // Verify the final destination file exists and has correct content _, err = os.Stat(fullDstFileName) assertNilF(t, err, "Final destination file should exist") verifyNoTmpFilesLeftBehind(t, fullDstFileName) } func verifyNoTmpFilesLeftBehind(t *testing.T, fullDstFileName string) { destDir := filepath.Dir(fullDstFileName) files, err := os.ReadDir(destDir) assertNilF(t, err, "Failed to read destination directory") tmpFileCount := 0 for _, file := range files { if strings.HasSuffix(file.Name(), ".tmp") { tmpFileCount++ } } assertEqualF(t, tmpFileCount, 0, "No .tmp files should remain in destination directory after successful operation") } ================================================ FILE: storage_file_util_test.go ================================================ package gosnowflake 
func testEncryptionMeta() *encryptMetadata { const mockMatDesc = "{\"queryid\":\"01abc874-0406-1bf0-0000-53b10668e056\",\"smkid\":\"92019681909886\",\"key\":\"128\"}" return &encryptMetadata{ key: "testencryptedkey12345678910==", iv: "testIVkey12345678910==", matdesc: mockMatDesc, } } ================================================ FILE: structured_type.go ================================================ package gosnowflake import ( "context" "database/sql" "database/sql/driver" "encoding/hex" "encoding/json" "errors" "fmt" "github.com/snowflakedb/gosnowflake/v2/internal/query" "github.com/snowflakedb/gosnowflake/v2/internal/types" "math/big" "reflect" "slices" "strconv" "strings" "time" "unicode" ) // ObjectType Empty marker of an object used in column type ScanType function type ObjectType struct { } var structuredObjectWriterType = reflect.TypeFor[StructuredObjectWriter]() // StructuredObject is a representation of structured object for reading. type StructuredObject interface { GetString(fieldName string) (string, error) GetNullString(fieldName string) (sql.NullString, error) GetByte(fieldName string) (byte, error) GetNullByte(fieldName string) (sql.NullByte, error) GetInt16(fieldName string) (int16, error) GetNullInt16(fieldName string) (sql.NullInt16, error) GetInt32(fieldName string) (int32, error) GetNullInt32(fieldName string) (sql.NullInt32, error) GetInt64(fieldName string) (int64, error) GetNullInt64(fieldName string) (sql.NullInt64, error) GetBigInt(fieldName string) (*big.Int, error) GetFloat32(fieldName string) (float32, error) GetFloat64(fieldName string) (float64, error) GetNullFloat64(fieldName string) (sql.NullFloat64, error) GetBigFloat(fieldName string) (*big.Float, error) GetBool(fieldName string) (bool, error) GetNullBool(fieldName string) (sql.NullBool, error) GetBytes(fieldName string) ([]byte, error) GetTime(fieldName string) (time.Time, error) GetNullTime(fieldName string) (sql.NullTime, error) GetStruct(fieldName string, scanner 
sql.Scanner) (sql.Scanner, error) GetRaw(fieldName string) (any, error) ScanTo(sc sql.Scanner) error } // StructuredObjectWriter is an interface to implement, when binding structured objects. type StructuredObjectWriter interface { Write(sowc StructuredObjectWriterContext) error } // StructuredObjectWriterContext is a helper interface to write particular fields of structured object. type StructuredObjectWriterContext interface { WriteString(fieldName string, value string) error WriteNullString(fieldName string, value sql.NullString) error WriteByt(fieldName string, value byte) error // WriteByte name is prohibited by go vet WriteNullByte(fieldName string, value sql.NullByte) error WriteInt16(fieldName string, value int16) error WriteNullInt16(fieldName string, value sql.NullInt16) error WriteInt32(fieldName string, value int32) error WriteNullInt32(fieldName string, value sql.NullInt32) error WriteInt64(fieldName string, value int64) error WriteNullInt64(fieldName string, value sql.NullInt64) error WriteFloat32(fieldName string, value float32) error WriteFloat64(fieldName string, value float64) error WriteNullFloat64(fieldName string, value sql.NullFloat64) error WriteBytes(fieldName string, value []byte) error WriteBool(fieldName string, value bool) error WriteNullBool(fieldName string, value sql.NullBool) error WriteTime(fieldName string, value time.Time, tsmode []byte) error WriteNullTime(fieldName string, value sql.NullTime, tsmode []byte) error WriteStruct(fieldName string, value StructuredObjectWriter) error WriteNullableStruct(fieldName string, value StructuredObjectWriter, typ reflect.Type) error // WriteRaw is used for inserting slices and maps only. WriteRaw(fieldName string, value any, tsmode ...[]byte) error // WriteNullRaw is used for inserting nil slices and maps only. WriteNullRaw(fieldName string, typ reflect.Type, tsmode ...[]byte) error WriteAll(sow StructuredObjectWriter) error } // NilMapTypes is used to define types when binding nil maps. 
type NilMapTypes struct {
	Key   reflect.Type
	Value reflect.Type
}

// structuredObjectWriterEntry describes one written field: its name plus the
// type metadata needed to build the bind's FieldMetadata.
type structuredObjectWriterEntry struct {
	name      string
	typ       string
	nullable  bool
	length    int
	scale     int
	precision int
	fields    []query.FieldMetadata
}

// toFieldMetadata converts the entry to the wire-level query.FieldMetadata.
func (e *structuredObjectWriterEntry) toFieldMetadata() query.FieldMetadata {
	return query.FieldMetadata{
		Name:      e.name,
		Type:      e.typ,
		Nullable:  e.nullable,
		Length:    e.length,
		Scale:     e.scale,
		Precision: e.precision,
		Fields:    e.fields,
	}
}

// structuredObjectWriterContext is the driver's StructuredObjectWriterContext
// implementation; it accumulates field values and their metadata entries.
type structuredObjectWriterContext struct {
	values  map[string]any
	entries []structuredObjectWriterEntry
	params  *syncParams
}

// init prepares the context for writing; params carries session parameters
// (used for date/time formats, among others).
func (sowc *structuredObjectWriterContext) init(params *syncParams) {
	sowc.values = make(map[string]any)
	sowc.params = params
}

// WriteString writes a non-null text field.
func (sowc *structuredObjectWriterContext) WriteString(fieldName string, value string) error {
	return sowc.writeString(fieldName, value)
}

// WriteNullString writes a text field; invalid values are stored as SQL NULL.
func (sowc *structuredObjectWriterContext) WriteNullString(fieldName string, value sql.NullString) error {
	if value.Valid {
		return sowc.WriteString(fieldName, value.String)
	}
	return sowc.writeString(fieldName, nil)
}

// writeString records a text entry; value may be nil for NULL.
func (sowc *structuredObjectWriterContext) writeString(fieldName string, value any) error {
	return sowc.write(value, structuredObjectWriterEntry{
		name:     fieldName,
		typ:      "text",
		nullable: true,
		length:   134217728, // 2^27; presumably the maximum text length accepted server-side - confirm
	})
}

// WriteByt writes a non-null byte field. (The WriteByte name is prohibited by go vet.)
func (sowc *structuredObjectWriterContext) WriteByt(fieldName string, value byte) error {
	return sowc.writeFixed(fieldName, value)
}

// WriteNullByte writes a byte field; invalid values are stored as SQL NULL.
func (sowc *structuredObjectWriterContext) WriteNullByte(fieldName string, value sql.NullByte) error {
	if value.Valid {
		return sowc.writeFixed(fieldName, value.Byte)
	}
	return sowc.writeFixed(fieldName, nil)
}

// WriteInt16 writes a non-null int16 field.
func (sowc *structuredObjectWriterContext) WriteInt16(fieldName string, value int16) error {
	return sowc.writeFixed(fieldName, value)
}

// WriteNullInt16 writes an int16 field; invalid values are stored as SQL NULL.
func (sowc *structuredObjectWriterContext) WriteNullInt16(fieldName string, value sql.NullInt16) error {
	if value.Valid {
		return sowc.writeFixed(fieldName, value.Int16)
	}
	return sowc.writeFixed(fieldName, nil)
}

// WriteInt32 writes a non-null int32 field.
func (sowc *structuredObjectWriterContext) WriteInt32(fieldName string, value int32) error {
	return sowc.writeFixed(fieldName, value)
}

// WriteNullInt32 writes an int32 field; invalid values are stored as SQL NULL.
func (sowc *structuredObjectWriterContext) WriteNullInt32(fieldName string, value sql.NullInt32) error {
	if value.Valid {
		return sowc.writeFixed(fieldName, value.Int32)
	}
	return sowc.writeFixed(fieldName, nil)
}

// WriteInt64 writes a non-null int64 field.
func (sowc *structuredObjectWriterContext) WriteInt64(fieldName string, value int64) error {
	return sowc.writeFixed(fieldName, value)
}

// WriteNullInt64 writes an int64 field; invalid values are stored as SQL NULL.
func (sowc *structuredObjectWriterContext) WriteNullInt64(fieldName string, value sql.NullInt64) error {
	if value.Valid {
		return sowc.writeFixed(fieldName, value.Int64)
	}
	return sowc.writeFixed(fieldName, nil)
}

// WriteFloat32 writes a non-null float32 field.
func (sowc *structuredObjectWriterContext) WriteFloat32(fieldName string, value float32) error {
	return sowc.writeFloat(fieldName, value)
}

// WriteFloat64 writes a non-null float64 field.
func (sowc *structuredObjectWriterContext) WriteFloat64(fieldName string, value float64) error {
	return sowc.writeFloat(fieldName, value)
}

// WriteNullFloat64 writes a float64 field; invalid values are stored as SQL NULL.
func (sowc *structuredObjectWriterContext) WriteNullFloat64(fieldName string, value sql.NullFloat64) error {
	if value.Valid {
		return sowc.writeFloat(fieldName, value.Float64)
	}
	return sowc.writeFloat(fieldName, nil)
}

// WriteBool writes a non-null boolean field.
func (sowc *structuredObjectWriterContext) WriteBool(fieldName string, value bool) error {
	return sowc.writeBool(fieldName, value)
}

// WriteNullBool writes a boolean field; invalid values are stored as SQL NULL.
func (sowc *structuredObjectWriterContext) WriteNullBool(fieldName string, value sql.NullBool) error {
	if value.Valid {
		return sowc.writeBool(fieldName, value.Bool)
	}
	return sowc.writeBool(fieldName, nil)
}

// writeBool records a boolean entry; value may be nil for NULL.
func (sowc *structuredObjectWriterContext) writeBool(fieldName string, value any) error {
	return sowc.write(value, structuredObjectWriterEntry{
		name:     fieldName,
		typ:      "boolean",
		nullable: true,
	})
}

// WriteBytes writes a binary field; the payload is hex-encoded for transport
// and a nil slice maps to SQL NULL.
func (sowc *structuredObjectWriterContext) WriteBytes(fieldName string, value []byte) error {
	var res *string
	if value != nil {
		r := hex.EncodeToString(value)
		res = &r
	}
	return sowc.write(res, structuredObjectWriterEntry{
		name:     fieldName,
		typ:      "binary",
		nullable: true,
	})
}
// WriteTime writes a non-null date/time field; tsmode (one of the DataType*
// markers) selects the Snowflake type, and the value is rendered using the
// session's input format for that type.
func (sowc *structuredObjectWriterContext) WriteTime(fieldName string, value time.Time, tsmode []byte) error {
	snowflakeType, err := dataTypeMode(tsmode)
	if err != nil {
		return err
	}
	typ := types.DriverTypeToSnowflake[snowflakeType]
	sfFormat, err := dateTimeInputFormatByType(typ, sowc.params)
	if err != nil {
		return err
	}
	goFormat, err := snowflakeFormatToGoFormat(sfFormat)
	if err != nil {
		return err
	}
	return sowc.writeTime(fieldName, value.Format(goFormat), typ)
}

// WriteNullTime writes a date/time field; invalid values become SQL NULL while
// the type metadata derived from tsmode is still recorded.
func (sowc *structuredObjectWriterContext) WriteNullTime(fieldName string, value sql.NullTime, tsmode []byte) error {
	if value.Valid {
		return sowc.WriteTime(fieldName, value.Time, tsmode)
	}
	snowflakeType, err := dataTypeMode(tsmode)
	if err != nil {
		return err
	}
	typ := types.DriverTypeToSnowflake[snowflakeType]
	return sowc.writeTime(fieldName, nil, typ)
}

// writeTime records a date/time entry with nanosecond (scale 9) precision.
func (sowc *structuredObjectWriterContext) writeTime(fieldName string, value any, typ string) error {
	return sowc.write(value, structuredObjectWriterEntry{
		name:     fieldName,
		typ:      strings.ToLower(typ),
		nullable: true,
		scale:    9,
	})
}

// WriteStruct writes a non-nil nested structured object by delegating to the
// child's Write implementation in a fresh child context.
func (sowc *structuredObjectWriterContext) WriteStruct(fieldName string, value StructuredObjectWriter) error {
	if reflect.ValueOf(value).IsNil() {
		return fmt.Errorf("%s is nil, use WriteNullableStruct instead", fieldName)
	}
	childSowc := structuredObjectWriterContext{}
	childSowc.init(sowc.params)
	err := value.Write(&childSowc)
	if err != nil {
		return err
	}
	return sowc.write(childSowc.values, structuredObjectWriterEntry{
		name:     fieldName,
		typ:      "object",
		nullable: true,
		fields:   childSowc.toFields(),
	})
}

// WriteNullableStruct writes a nested structured object that may be nil; for a
// nil value the field metadata is reconstructed from typ via reflection.
func (sowc *structuredObjectWriterContext) WriteNullableStruct(structFieldName string, value StructuredObjectWriter, typ reflect.Type) error {
	if value == nil || reflect.ValueOf(value).IsNil() {
		childSowc, err := buildSowcFromType(sowc.params, typ)
		if err != nil {
			return err
		}
		// NOTE(review): the type is written as "OBJECT" here but "object" in
		// WriteStruct - confirm whether the casing difference is intentional.
		return sowc.write(nil, structuredObjectWriterEntry{
			name:     structFieldName,
			typ:      "OBJECT",
			nullable: true,
			fields:   childSowc.toFields(),
		})
	}
	return sowc.WriteStruct(structFieldName, value)
}

// WriteRaw is used for inserting slices and maps only. The optional
// dataTypeModes argument carries the DataType marker applied to time-typed
// elements; it defaults to DataTypeArray.
func (sowc *structuredObjectWriterContext) WriteRaw(fieldName string, value any, dataTypeModes ...[]byte) error {
	dataTypeModeSingle := DataTypeArray
	if len(dataTypeModes) == 1 && dataTypeModes[0] != nil {
		dataTypeModeSingle = dataTypeModes[0]
	}
	tsmode, err := dataTypeMode(dataTypeModeSingle)
	if err != nil {
		return err
	}
	switch reflect.ValueOf(value).Kind() {
	case reflect.Slice:
		metadata, err := goTypeToFieldMetadata(reflect.TypeOf(value).Elem(), tsmode, sowc.params)
		if err != nil {
			return err
		}
		return sowc.write(value, structuredObjectWriterEntry{
			name:     fieldName,
			typ:      "ARRAY",
			nullable: true,
			fields:   []query.FieldMetadata{metadata},
		})
	case reflect.Map:
		keyMetadata, err := goTypeToFieldMetadata(reflect.TypeOf(value).Key(), tsmode, sowc.params)
		if err != nil {
			return err
		}
		valueMetadata, err := goTypeToFieldMetadata(reflect.TypeOf(value).Elem(), tsmode, sowc.params)
		if err != nil {
			return err
		}
		return sowc.write(value, structuredObjectWriterEntry{
			name:     fieldName,
			typ:      "MAP",
			nullable: true,
			fields:   []query.FieldMetadata{keyMetadata, valueMetadata},
		})
	}
	return fmt.Errorf("unsupported raw type: %T", value)
}

// WriteNullRaw is used for inserting nil slices and maps only; the metadata is
// derived from typ so the server still receives the element types.
func (sowc *structuredObjectWriterContext) WriteNullRaw(fieldName string, typ reflect.Type, dataTypeModes ...[]byte) error {
	dataTypeModeSingle := DataTypeArray
	if len(dataTypeModes) == 1 && dataTypeModes[0] != nil {
		dataTypeModeSingle = dataTypeModes[0]
	}
	tsmode, err := dataTypeMode(dataTypeModeSingle)
	if err != nil {
		return err
	}
	if typ.Kind() == reflect.Slice || typ.Kind() == reflect.Map {
		metadata, err := goTypeToFieldMetadata(typ, tsmode, sowc.params)
		if err != nil {
			return err
		}
		if err := sowc.write(nil, structuredObjectWriterEntry{
			name:     fieldName,
			typ:      metadata.Type,
			nullable: true,
			fields:   metadata.Fields,
		}); err != nil {
			return err
		}
		return nil
	}
	return fmt.Errorf("cannot use %v as nillable field", typ.Kind().String())
}

// buildSowcFromType builds a child writer context for typ by writing a nil
// placeholder for every field, so that full field metadata can be produced
// even when the bound structured object itself is nil.
func buildSowcFromType(params *syncParams, typ reflect.Type) (*structuredObjectWriterContext, error) {
	childSowc := &structuredObjectWriterContext{}
	childSowc.init(params)
	if typ.Kind() == reflect.Pointer {
		typ = typ.Elem()
	}
	for i := 0; i < typ.NumField(); i++ {
		field := typ.Field(i)
		fieldName := getSfFieldName(field)
		// NOTE(review): unlike WriteAll, this loop does not consult
		// shouldIgnoreField - confirm whether ignored fields should be skipped.
		if field.Type.Kind() == reflect.String {
			if err := childSowc.writeString(fieldName, nil); err != nil {
				return nil, err
			}
		} else if field.Type.Kind() == reflect.Uint8 || field.Type.Kind() == reflect.Int16 || field.Type.Kind() == reflect.Int32 || field.Type.Kind() == reflect.Int64 {
			if err := childSowc.writeFixed(fieldName, nil); err != nil {
				return nil, err
			}
		} else if field.Type.Kind() == reflect.Float32 || field.Type.Kind() == reflect.Float64 {
			if err := childSowc.writeFloat(fieldName, nil); err != nil {
				return nil, err
			}
		} else if field.Type.Kind() == reflect.Bool {
			if err := childSowc.writeBool(fieldName, nil); err != nil {
				return nil, err
			}
		} else if (field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) && field.Type.Elem().Kind() == reflect.Uint8 {
			if err := childSowc.WriteBytes(fieldName, nil); err != nil {
				return nil, err
			}
		} else if field.Type.Kind() == reflect.Struct || field.Type.Kind() == reflect.Pointer {
			// Dereference pointer fields, then dispatch on the well-known
			// sql.Null* / time.Time types before falling back to nested
			// structured objects and driver.Valuer.
			t := field.Type
			if field.Type.Kind() == reflect.Pointer {
				t = field.Type.Elem()
			}
			if t.AssignableTo(reflect.TypeFor[sql.NullString]()) {
				if err := childSowc.WriteNullString(fieldName, sql.NullString{}); err != nil {
					return nil, err
				}
			} else if t.AssignableTo(reflect.TypeFor[sql.NullByte]()) {
				if err := childSowc.WriteNullByte(fieldName, sql.NullByte{}); err != nil {
					return nil, err
				}
			} else if t.AssignableTo(reflect.TypeFor[sql.NullInt16]()) {
				if err := childSowc.WriteNullInt16(fieldName, sql.NullInt16{}); err != nil {
					return nil, err
				}
			} else if t.AssignableTo(reflect.TypeFor[sql.NullInt32]()) {
				if err := childSowc.WriteNullInt32(fieldName, sql.NullInt32{}); err != nil {
					return nil, err
				}
			} else if t.AssignableTo(reflect.TypeFor[sql.NullInt64]()) {
				if err := childSowc.WriteNullInt64(fieldName, sql.NullInt64{}); err != nil {
					return nil, err
				}
			} else if t.AssignableTo(reflect.TypeFor[sql.NullFloat64]()) {
				if err := childSowc.WriteNullFloat64(fieldName, sql.NullFloat64{}); err != nil {
					return nil, err
				}
			} else if t.AssignableTo(reflect.TypeFor[sql.NullBool]()) {
				if err := childSowc.WriteNullBool(fieldName, sql.NullBool{}); err != nil {
					return nil, err
				}
			} else if t.AssignableTo(reflect.TypeFor[sql.NullTime]()) || t.AssignableTo(reflect.TypeFor[time.Time]()) {
				// Time-typed fields require an sf tag selecting the Snowflake
				// time type (time/date/ltz/ntz/tz).
				timeSnowflakeType, err := getTimeSnowflakeType(field)
				if err != nil {
					return nil, err
				}
				if timeSnowflakeType == nil {
					return nil, fmt.Errorf("field %v does not have proper sf tag", fieldName)
				}
				if err := childSowc.WriteNullTime(fieldName, sql.NullTime{}, timeSnowflakeType); err != nil {
					return nil, err
				}
			} else if field.Type.AssignableTo(structuredObjectWriterType) {
				// NOTE(review): checks field.Type (possibly a pointer) rather
				// than the dereferenced t - presumably because the writer
				// interface is implemented on the pointer receiver; confirm.
				if err := childSowc.WriteNullableStruct(fieldName, nil, field.Type); err != nil {
					return nil, err
				}
			} else if t.Implements(reflect.TypeFor[driver.Valuer]()) {
				// driver.Valuer fields are bound as text.
				if err := childSowc.WriteNullString(fieldName, sql.NullString{}); err != nil {
					return nil, err
				}
			} else {
				return nil, fmt.Errorf("field %s has unsupported type", field.Name)
			}
		} else if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Map {
			timeSnowflakeType, err := getTimeSnowflakeType(field)
			if err != nil {
				return nil, err
			}
			if err := childSowc.WriteNullRaw(fieldName, field.Type, timeSnowflakeType); err != nil {
				return nil, err
			}
		}
	}
	return childSowc, nil
}

// writeFixed records a "fixed" (integer) entry with precision 38, scale 0;
// value may be nil for NULL.
func (sowc *structuredObjectWriterContext) writeFixed(fieldName string, value any) error {
	return sowc.write(value, structuredObjectWriterEntry{
		name:      fieldName,
		typ:       "fixed",
		nullable:  true,
		precision: 38,
		scale:     0,
	})
}

// writeFloat records a "real" (floating-point) entry; value may be nil for NULL.
func (sowc *structuredObjectWriterContext) writeFloat(fieldName string, value any) error {
	return sowc.write(value, structuredObjectWriterEntry{
		name:      fieldName,
		typ:       "real",
		nullable:  true,
		precision: 38,
		scale:     0,
	})
}

func (sowc
*structuredObjectWriterContext) write(value any, entry structuredObjectWriterEntry) error {
	// Store the value and remember its metadata entry in insertion order.
	sowc.values[entry.name] = value
	sowc.entries = append(sowc.entries, entry)
	return nil
}

// WriteAll writes every non-ignored field of sow using reflection, dispatching
// on the Go field type to the matching typed writer.
func (sowc *structuredObjectWriterContext) WriteAll(sow StructuredObjectWriter) error {
	typ := reflect.TypeOf(sow)
	if typ.Kind() == reflect.Pointer {
		typ = typ.Elem()
	}
	val := reflect.Indirect(reflect.ValueOf(sow))
	for i := 0; i < typ.NumField(); i++ {
		field := typ.Field(i)
		if shouldIgnoreField(field) {
			continue
		}
		fieldName := getSfFieldName(field)
		if field.Type.Kind() == reflect.String {
			if err := sowc.WriteString(fieldName, val.Field(i).String()); err != nil {
				return err
			}
		} else if field.Type.Kind() == reflect.Uint8 {
			if err := sowc.WriteByt(fieldName, byte(val.Field(i).Uint())); err != nil {
				return err
			}
		} else if field.Type.Kind() == reflect.Int16 {
			if err := sowc.WriteInt16(fieldName, int16(val.Field(i).Int())); err != nil {
				return err
			}
		} else if field.Type.Kind() == reflect.Int32 {
			if err := sowc.WriteInt32(fieldName, int32(val.Field(i).Int())); err != nil {
				return err
			}
		} else if field.Type.Kind() == reflect.Int64 {
			if err := sowc.WriteInt64(fieldName, val.Field(i).Int()); err != nil {
				return err
			}
		} else if field.Type.Kind() == reflect.Float32 {
			if err := sowc.WriteFloat32(fieldName, float32(val.Field(i).Float())); err != nil {
				return err
			}
		} else if field.Type.Kind() == reflect.Float64 {
			if err := sowc.WriteFloat64(fieldName, val.Field(i).Float()); err != nil {
				return err
			}
		} else if field.Type.Kind() == reflect.Bool {
			if err := sowc.WriteBool(fieldName, val.Field(i).Bool()); err != nil {
				return err
			}
		} else if (field.Type.Kind() == reflect.Array || field.Type.Kind() == reflect.Slice) && field.Type.Elem().Kind() == reflect.Uint8 {
			if err := sowc.WriteBytes(fieldName, val.Field(i).Bytes()); err != nil {
				return err
			}
		} else if field.Type.Kind() == reflect.Struct || field.Type.Kind() == reflect.Pointer {
			// Struct-like fields: probe the concrete value against the
			// well-known time / sql.Null* types, then nested writers.
			if v, ok := val.Field(i).Interface().(time.Time); ok {
				timeSnowflakeType, err := getTimeSnowflakeType(typ.Field(i))
				if err != nil {
					return err
				}
				if timeSnowflakeType == nil {
					return fmt.Errorf("field %v does not have a proper sf tag", fieldName)
				}
				if err := sowc.WriteTime(fieldName, v, timeSnowflakeType); err != nil {
					return err
				}
			} else if v, ok := val.Field(i).Interface().(sql.NullString); ok {
				if err := sowc.WriteNullString(fieldName, v); err != nil {
					return err
				}
			} else if v, ok := val.Field(i).Interface().(sql.NullByte); ok {
				if err := sowc.WriteNullByte(fieldName, v); err != nil {
					return err
				}
			} else if v, ok := val.Field(i).Interface().(sql.NullInt16); ok {
				if err := sowc.WriteNullInt16(fieldName, v); err != nil {
					return err
				}
			} else if v, ok := val.Field(i).Interface().(sql.NullInt32); ok {
				if err := sowc.WriteNullInt32(fieldName, v); err != nil {
					return err
				}
			} else if v, ok := val.Field(i).Interface().(sql.NullInt64); ok {
				if err := sowc.WriteNullInt64(fieldName, v); err != nil {
					return err
				}
			} else if v, ok := val.Field(i).Interface().(sql.NullFloat64); ok {
				if err := sowc.WriteNullFloat64(fieldName, v); err != nil {
					return err
				}
			} else if v, ok := val.Field(i).Interface().(sql.NullBool); ok {
				if err := sowc.WriteNullBool(fieldName, v); err != nil {
					return err
				}
			} else if v, ok := val.Field(i).Interface().(sql.NullTime); ok {
				timeSnowflakeType, err := getTimeSnowflakeType(typ.Field(i))
				if err != nil {
					return err
				}
				if timeSnowflakeType == nil {
					return fmt.Errorf("field %v does not have a proper sf tag", fieldName)
				}
				if err := sowc.WriteNullTime(fieldName, v, timeSnowflakeType); err != nil {
					return err
				}
			} else if v, ok := val.Field(i).Interface().(StructuredObjectWriter); ok {
				if reflect.ValueOf(v).IsNil() {
					if err := sowc.WriteNullableStruct(fieldName, nil, reflect.TypeOf(v)); err != nil {
						return err
					}
				} else {
					childSowc := &structuredObjectWriterContext{}
					childSowc.init(sowc.params)
					if err := v.Write(childSowc); err != nil {
						return err
					}
					if err := sowc.write(childSowc.values, structuredObjectWriterEntry{
						name:     fieldName,
						typ:      "OBJECT",
						nullable: true,
						fields:   childSowc.toFields(),
					}); err != nil {
						return err
					}
				}
			}
		} else if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Map {
			// Time-typed elements need the sf-tag-selected DataType marker.
			var timeSfType []byte
			var err error
			if field.Type.Elem().AssignableTo(reflect.TypeFor[time.Time]()) || field.Type.Elem().AssignableTo(reflect.TypeFor[sql.NullTime]()) {
				timeSfType, err = getTimeSnowflakeType(typ.Field(i))
				if err != nil {
					return err
				}
			}
			if err := sowc.WriteRaw(fieldName, val.Field(i).Interface(), timeSfType); err != nil {
				return err
			}
		} else {
			return fmt.Errorf("field %s has unsupported type", field.Name)
		}
	}
	return nil
}

// toFields converts the collected entries to wire-level field metadata.
func (sowc *structuredObjectWriterContext) toFields() []query.FieldMetadata {
	fieldMetadatas := make([]query.FieldMetadata, len(sowc.entries))
	for i, entry := range sowc.entries {
		fieldMetadatas[i] = entry.toFieldMetadata()
	}
	return fieldMetadatas
}

// ArrayOfScanners Helper type for scanning array of sql.Scanner values.
type ArrayOfScanners[T sql.Scanner] []T

// Scan implements sql.Scanner. val must be a []*structuredType; each element
// is scanned into a freshly allocated T.
func (st *ArrayOfScanners[T]) Scan(val any) error {
	if val == nil {
		return nil
	}
	sts := val.([]*structuredType)
	*st = make([]T, len(sts))
	var t T
	for i, s := range sts {
		(*st)[i] = reflect.New(reflect.TypeOf(t).Elem()).Interface().(T)
		if err := (*st)[i].Scan(s); err != nil {
			return err
		}
	}
	return nil
}

// ScanArrayOfScanners is a helper function for scanning arrays of sql.Scanner values.
// Example:
//
//	var res []*simpleObject
//	err := rows.Scan(ScanArrayOfScanners(&res))
func ScanArrayOfScanners[T sql.Scanner](value *[]T) *ArrayOfScanners[T] {
	return (*ArrayOfScanners[T])(value)
}

// MapOfScanners Helper type for scanning map of sql.Scanner values.
type MapOfScanners[K comparable, V sql.Scanner] map[K]V func (st *MapOfScanners[K, V]) Scan(val any) error { if val == nil { return nil } sts := val.(map[K]*structuredType) *st = make(map[K]V) var someV V for k, v := range sts { if v != nil && !reflect.ValueOf(v).IsNil() { (*st)[k] = reflect.New(reflect.TypeOf(someV).Elem()).Interface().(V) if err := (*st)[k].Scan(sts[k]); err != nil { return err } } else { (*st)[k] = reflect.Zero(reflect.TypeOf(someV)).Interface().(V) } } return nil } // ScanMapOfScanners is a helper function for scanning maps of sql.Scanner values. // Example: // // var res map[string]*simpleObject // err := rows.Scan(ScanMapOfScanners(&res)) func ScanMapOfScanners[K comparable, V sql.Scanner](m *map[K]V) *MapOfScanners[K, V] { return (*MapOfScanners[K, V])(m) } type structuredType struct { values map[string]any fieldMetadata []query.FieldMetadata params *syncParams } func getType[T any](st *structuredType, fieldName string, emptyValue T) (T, bool, error) { v, ok := st.values[fieldName] if !ok { return emptyValue, false, errors.New("field " + fieldName + " does not exist") } if v == nil { return emptyValue, true, nil } v, ok = v.(T) if !ok { return emptyValue, false, fmt.Errorf("cannot convert field %v to %T", fieldName, emptyValue) } return v.(T), false, nil } func (st *structuredType) GetString(fieldName string) (string, error) { nullString, err := st.GetNullString(fieldName) if err != nil { return "", err } if !nullString.Valid { return "", fmt.Errorf("nil value for %v, use GetNullString instead", fieldName) } return nullString.String, nil } func (st *structuredType) GetNullString(fieldName string) (sql.NullString, error) { s, wasNil, err := getType[string](st, fieldName, "") if err != nil { return sql.NullString{}, err } if wasNil { return sql.NullString{Valid: false}, err } return sql.NullString{Valid: true, String: s}, nil } func (st *structuredType) GetByte(fieldName string) (byte, error) { nullByte, err := st.GetNullByte(fieldName) if err 
!= nil { return 0, err } if !nullByte.Valid { return 0, fmt.Errorf("nil value for %v, use GetNullByte instead", fieldName) } return nullByte.Byte, nil } func (st *structuredType) GetNullByte(fieldName string) (sql.NullByte, error) { b, err := st.GetNullInt64(fieldName) if err != nil { return sql.NullByte{}, err } if !b.Valid { return sql.NullByte{Valid: false}, nil } return sql.NullByte{Valid: true, Byte: byte(b.Int64)}, nil } func (st *structuredType) GetInt16(fieldName string) (int16, error) { nullInt16, err := st.GetNullInt16(fieldName) if err != nil { return 0, err } if !nullInt16.Valid { return 0, fmt.Errorf("nil value for %v, use GetNullInt16 instead", fieldName) } return nullInt16.Int16, nil } func (st *structuredType) GetNullInt16(fieldName string) (sql.NullInt16, error) { b, err := st.GetNullInt64(fieldName) if err != nil { return sql.NullInt16{}, err } if !b.Valid { return sql.NullInt16{Valid: false}, nil } return sql.NullInt16{Valid: true, Int16: int16(b.Int64)}, nil } func (st *structuredType) GetInt32(fieldName string) (int32, error) { nullInt32, err := st.GetNullInt32(fieldName) if err != nil { return 0, err } if !nullInt32.Valid { return 0, fmt.Errorf("nil value for %v, use GetNullInt32 instead", fieldName) } return nullInt32.Int32, nil } func (st *structuredType) GetNullInt32(fieldName string) (sql.NullInt32, error) { b, err := st.GetNullInt64(fieldName) if err != nil { return sql.NullInt32{}, err } if !b.Valid { return sql.NullInt32{Valid: false}, nil } return sql.NullInt32{Valid: true, Int32: int32(b.Int64)}, nil } func (st *structuredType) GetInt64(fieldName string) (int64, error) { nullInt64, err := st.GetNullInt64(fieldName) if err != nil { return 0, err } if !nullInt64.Valid { return 0, fmt.Errorf("nil value for %v, use GetNullInt64 instead", fieldName) } return nullInt64.Int64, nil } func (st *structuredType) GetNullInt64(fieldName string) (sql.NullInt64, error) { i64, wasNil, err := getType[int64](st, fieldName, 0) if wasNil { return 
sql.NullInt64{Valid: false}, err } if err == nil { return sql.NullInt64{Valid: true, Int64: i64}, nil } if s, _, err := getType[string](st, fieldName, ""); err == nil { i, err := strconv.ParseInt(s, 10, 64) if err != nil { return sql.NullInt64{Valid: false}, err } return sql.NullInt64{Valid: true, Int64: i}, nil } else if b, _, err := getType[float64](st, fieldName, 0); err == nil { return sql.NullInt64{Valid: true, Int64: int64(b)}, nil } else if b, _, err := getType[json.Number](st, fieldName, ""); err == nil { i, err := strconv.ParseInt(string(b), 10, 64) if err != nil { return sql.NullInt64{Valid: false}, err } return sql.NullInt64{Valid: true, Int64: i}, err } else { return sql.NullInt64{Valid: false}, fmt.Errorf("cannot cast column %v to byte", fieldName) } } func (st *structuredType) GetBigInt(fieldName string) (*big.Int, error) { b, wasNull, err := getType[*big.Int](st, fieldName, new(big.Int)) if wasNull { return nil, nil } return b, err } func (st *structuredType) GetFloat32(fieldName string) (float32, error) { f32, err := st.GetFloat64(fieldName) if err != nil { return 0, err } return float32(f32), err } func (st *structuredType) GetFloat64(fieldName string) (float64, error) { nullFloat64, err := st.GetNullFloat64(fieldName) if err != nil { return 0, err } if !nullFloat64.Valid { return 0, fmt.Errorf("nil value for %v, use GetNullFloat64 instead", fieldName) } return nullFloat64.Float64, nil } func (st *structuredType) GetNullFloat64(fieldName string) (sql.NullFloat64, error) { float64, wasNull, err := getType[float64](st, fieldName, 0) if wasNull { return sql.NullFloat64{Valid: false}, nil } if err == nil { return sql.NullFloat64{Valid: true, Float64: float64}, nil } s, _, err := getType[string](st, fieldName, "") if err == nil { f64, err := strconv.ParseFloat(s, 64) if err != nil { return sql.NullFloat64{}, err } return sql.NullFloat64{Valid: true, Float64: f64}, err } jsonNumber, _, err := getType[json.Number](st, fieldName, "") if err != nil { return 
sql.NullFloat64{}, err } f64, err := strconv.ParseFloat(string(jsonNumber), 64) if err != nil { return sql.NullFloat64{}, err } return sql.NullFloat64{Valid: true, Float64: f64}, nil } func (st *structuredType) GetBigFloat(fieldName string) (*big.Float, error) { float, wasNull, err := getType[*big.Float](st, fieldName, new(big.Float)) if wasNull { return nil, nil } return float, err } func (st *structuredType) GetBool(fieldName string) (bool, error) { nullBool, err := st.GetNullBool(fieldName) if err != nil { return false, err } if !nullBool.Valid { return false, fmt.Errorf("nil value for %v, use GetNullBool instead", fieldName) } return nullBool.Bool, err } func (st *structuredType) GetNullBool(fieldName string) (sql.NullBool, error) { b, wasNull, err := getType[bool](st, fieldName, false) if wasNull { return sql.NullBool{Valid: false}, nil } if err != nil { return sql.NullBool{}, err } return sql.NullBool{Valid: true, Bool: b}, nil } func (st *structuredType) GetBytes(fieldName string) ([]byte, error) { if bi, _, err := getType[[]byte](st, fieldName, nil); err == nil { return bi, nil } else if bi, _, err := getType[string](st, fieldName, ""); err == nil { return hex.DecodeString(bi) } bytes, _, err := getType[[]byte](st, fieldName, []byte{}) return bytes, err } func (st *structuredType) GetTime(fieldName string) (time.Time, error) { nullTime, err := st.GetNullTime(fieldName) if err != nil { return time.Time{}, err } if !nullTime.Valid { return time.Time{}, fmt.Errorf("nil value for %v, use GetNullBool instead", fieldName) } return nullTime.Time, nil } func (st *structuredType) GetNullTime(fieldName string) (sql.NullTime, error) { s, wasNull, err := getType[string](st, fieldName, "") if wasNull { return sql.NullTime{Valid: false}, nil } if err == nil { fieldMetadata, err := st.fieldMetadataByFieldName(fieldName) if err != nil { return sql.NullTime{}, err } format, err := dateTimeOutputFormatByType(fieldMetadata.Type, st.params) if err != nil { return 
sql.NullTime{}, err } goFormat, err := snowflakeFormatToGoFormat(format) if err != nil { return sql.NullTime{}, err } time, err := time.Parse(goFormat, s) return sql.NullTime{Valid: true, Time: time}, err } time, _, err := getType[time.Time](st, fieldName, time.Time{}) if err != nil { return sql.NullTime{}, err } return sql.NullTime{Valid: true, Time: time}, nil } func (st *structuredType) GetStruct(fieldName string, scanner sql.Scanner) (sql.Scanner, error) { childSt, wasNull, err := getType[*structuredType](st, fieldName, &structuredType{}) if wasNull { return nil, nil } if err != nil { return nil, err } err = scanner.Scan(childSt) return scanner, err } func (st *structuredType) GetRaw(fieldName string) (any, error) { return st.values[fieldName], nil } func (st *structuredType) ScanTo(sc sql.Scanner) error { v := reflect.Indirect(reflect.ValueOf(sc)) t := v.Type() for i := 0; i < t.NumField(); i++ { field := t.Field(i) if shouldIgnoreField(field) { continue } switch field.Type.Kind() { case reflect.String: s, err := st.GetString(getSfFieldName(field)) if err != nil { return err } v.FieldByName(field.Name).SetString(s) case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: i, err := st.GetInt64(getSfFieldName(field)) if err != nil { return err } v.FieldByName(field.Name).SetInt(i) case reflect.Uint8: b, err := st.GetByte(getSfFieldName(field)) if err != nil { return err } v.FieldByName(field.Name).SetUint(uint64(int64(b))) case reflect.Float32, reflect.Float64: f, err := st.GetFloat64(getSfFieldName(field)) if err != nil { return err } v.FieldByName(field.Name).SetFloat(f) case reflect.Bool: b, err := st.GetBool(getSfFieldName(field)) if err != nil { return err } v.FieldByName(field.Name).SetBool(b) case reflect.Slice, reflect.Array: switch field.Type.Elem().Kind() { case reflect.Uint8: b, err := st.GetBytes(getSfFieldName(field)) if err != nil { return err } v.FieldByName(field.Name).SetBytes(b) default: raw, err := st.GetRaw(getSfFieldName(field)) if 
err != nil {
					return err
				}
				if raw != nil {
					v.FieldByName(field.Name).Set(reflect.ValueOf(raw))
				}
			}
		case reflect.Map:
			// Maps are assigned raw; the stored value must already have the
			// field's map type.
			raw, err := st.GetRaw(getSfFieldName(field))
			if err != nil {
				return err
			}
			if raw != nil {
				v.FieldByName(field.Name).Set(reflect.ValueOf(raw))
			}
		case reflect.Struct:
			// Struct fields: probe the concrete value against time.Time,
			// sql.Scanner and the sql.Null* wrappers.
			a := v.FieldByName(field.Name).Interface()
			if _, ok := a.(time.Time); ok {
				time, err := st.GetTime(getSfFieldName(field))
				if err != nil {
					return err
				}
				v.FieldByName(field.Name).Set(reflect.ValueOf(time))
			} else if _, ok := a.(sql.Scanner); ok {
				scanner := reflect.New(reflect.TypeOf(a)).Interface().(sql.Scanner)
				s, err := st.GetStruct(getSfFieldName(field), scanner)
				if err != nil {
					return err
				}
				v.FieldByName(field.Name).Set(reflect.Indirect(reflect.ValueOf(s)))
			} else if _, ok := a.(sql.NullString); ok {
				ns, err := st.GetNullString(getSfFieldName(field))
				if err != nil {
					return err
				}
				v.FieldByName(field.Name).Set(reflect.ValueOf(ns))
			} else if _, ok := a.(sql.NullByte); ok {
				nb, err := st.GetNullByte(getSfFieldName(field))
				if err != nil {
					return err
				}
				v.FieldByName(field.Name).Set(reflect.ValueOf(nb))
			} else if _, ok := a.(sql.NullBool); ok {
				nb, err := st.GetNullBool(getSfFieldName(field))
				if err != nil {
					return err
				}
				v.FieldByName(field.Name).Set(reflect.ValueOf(nb))
			} else if _, ok := a.(sql.NullInt16); ok {
				ni, err := st.GetNullInt16(getSfFieldName(field))
				if err != nil {
					return err
				}
				v.FieldByName(field.Name).Set(reflect.ValueOf(ni))
			} else if _, ok := a.(sql.NullInt32); ok {
				ni, err := st.GetNullInt32(getSfFieldName(field))
				if err != nil {
					return err
				}
				v.FieldByName(field.Name).Set(reflect.ValueOf(ni))
			} else if _, ok := a.(sql.NullInt64); ok {
				ni, err := st.GetNullInt64(getSfFieldName(field))
				if err != nil {
					return err
				}
				v.FieldByName(field.Name).Set(reflect.ValueOf(ni))
			} else if _, ok := a.(sql.NullFloat64); ok {
				nf, err := st.GetNullFloat64(getSfFieldName(field))
				if err != nil {
					return err
				}
				v.FieldByName(field.Name).Set(reflect.ValueOf(nf))
			} else if _, ok := a.(sql.NullTime); ok {
				nt, err := st.GetNullTime(getSfFieldName(field))
				if err != nil {
					return err
				}
				v.FieldByName(field.Name).Set(reflect.ValueOf(nt))
			}
		case reflect.Pointer:
			// Only pointers to sql.Scanner structs are supported; the target
			// is allocated and scanned via GetStruct.
			switch field.Type.Elem().Kind() {
			case reflect.Struct:
				a := reflect.New(field.Type.Elem()).Interface()
				s, err := st.GetStruct(getSfFieldName(field), a.(sql.Scanner))
				if err != nil {
					return err
				}
				if s != nil {
					v.FieldByName(field.Name).Set(reflect.ValueOf(s))
				}
			default:
				return errors.New("only struct pointers are supported")
			}
		}
	}
	return nil
}

// fieldMetadataByFieldName finds the server-side metadata for fieldName.
func (st *structuredType) fieldMetadataByFieldName(fieldName string) (query.FieldMetadata, error) {
	for _, fm := range st.fieldMetadata {
		if fm.Name == fieldName {
			return fm, nil
		}
	}
	return query.FieldMetadata{}, errors.New("no metadata for field " + fieldName)
}

// structuredTypesEnabled reports whether structured type support was enabled
// on ctx via the enableStructuredTypes context key.
func structuredTypesEnabled(ctx context.Context) bool {
	v := ctx.Value(enableStructuredTypes)
	if v == nil {
		return false
	}
	d, ok := v.(bool)
	return ok && d
}

// embeddedValuesNullableEnabled reports whether the embeddedValuesNullable
// context flag was set on ctx.
func embeddedValuesNullableEnabled(ctx context.Context) bool {
	v := ctx.Value(embeddedValuesNullable)
	if v == nil {
		return false
	}
	d, ok := v.(bool)
	return ok && d
}

// getSfFieldName returns the Snowflake field name for a struct field: the
// first element of the `sf` tag if present, otherwise the Go name with its
// first rune lower-cased.
func getSfFieldName(field reflect.StructField) string {
	sfTag := field.Tag.Get("sf")
	if sfTag != "" {
		return strings.Split(sfTag, ",")[0]
	}
	r := []rune(field.Name)
	r[0] = unicode.ToLower(r[0])
	return string(r)
}

// shouldIgnoreField reports whether the `sf` tag options include "ignore".
func shouldIgnoreField(field reflect.StructField) bool {
	sfTag := strings.ToLower(field.Tag.Get("sf"))
	if sfTag == "" {
		return false
	}
	return slices.Contains(strings.Split(sfTag, ",")[1:], "ignore")
}

// getTimeSnowflakeType maps the `sf` tag options (time/date/ltz/ntz/tz) to
// the corresponding DataType marker; returns nil when no time-related option
// is present.
func getTimeSnowflakeType(field reflect.StructField) ([]byte, error) {
	sfTag := strings.ToLower(field.Tag.Get("sf"))
	if sfTag == "" {
		return nil, nil
	}
	values := strings.Split(sfTag, ",")[1:]
	if slices.Contains(values, "time") {
		return DataTypeTime, nil
	} else if slices.Contains(values, "date") {
		return DataTypeDate, nil
	} else if slices.Contains(values, "ltz") {
		return DataTypeTimestampLtz, nil
	} else if slices.Contains(values, "ntz") {
		return DataTypeTimestampNtz, nil
	} else if slices.Contains(values, "tz") {
		return DataTypeTimestampTz, nil
	}
	return nil, nil
}

================================================ FILE: structured_type_arrow_batches_test.go ================================================

package gosnowflake_test

import (
	"context"
	"crypto/rsa"
	"crypto/x509"
	"database/sql"
	"database/sql/driver"
	"encoding/pem"
	"fmt"
	"os"
	"path/filepath"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/memory"
	sf "github.com/snowflakedb/gosnowflake/v2"
	"github.com/snowflakedb/gosnowflake/v2/arrowbatches"
)

// arrowTestRepoRoot walks upward from the working directory until it finds a
// directory containing go.mod and returns it; fails the test if none exists.
func arrowTestRepoRoot(t *testing.T) string {
	t.Helper()
	dir, err := os.Getwd()
	if err != nil {
		t.Fatalf("failed to get working directory: %v", err)
	}
	for {
		if _, err = os.Stat(filepath.Join(dir, "go.mod")); err == nil {
			return dir
		}
		if !os.IsNotExist(err) {
			t.Fatalf("failed to stat go.mod in %q: %v", dir, err)
		}
		parent := filepath.Dir(dir)
		if parent == dir {
			// Reached the filesystem root without finding go.mod.
			t.Fatal("could not find repository root (no go.mod found)")
		}
		dir = parent
	}
}

// arrowTestReadPrivateKey loads and parses a PKCS#8 RSA private key from path;
// relative paths are resolved against the repository root.
func arrowTestReadPrivateKey(t *testing.T, path string) *rsa.PrivateKey {
	t.Helper()
	if !filepath.IsAbs(path) {
		path = filepath.Join(arrowTestRepoRoot(t), path)
	}
	data, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("failed to read private key file %q: %v", path, err)
	}
	block, _ := pem.Decode(data)
	if block == nil {
		t.Fatalf("failed to decode PEM block from %q", path)
	}
	key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		t.Fatalf("failed to parse private key from %q: %v", path, err)
	}
	rsaKey, ok := key.(*rsa.PrivateKey)
	if !ok {
		t.Fatalf("private key in %q is not RSA (got %T)", path, key)
	}
	return rsaKey
}

// arrowTestConn manages a Snowflake connection for arrow batch tests.
type arrowTestConn struct {
	db   *sql.DB   // owning database handle; closed together with conn
	conn *sql.Conn // dedicated connection so ALTER SESSION settings stick
}

// openArrowTestConn builds a Config from SNOWFLAKE_TEST_* environment
// variables (password or JWT key-pair auth), forces the session timezone to
// UTC, and returns an open dedicated connection wrapped in arrowTestConn.
func openArrowTestConn(t *testing.T) *arrowTestConn {
	t.Helper()
	configParams := []*sf.ConfigParam{
		{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true},
		{Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true},
		{Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false},
		{Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false},
		{Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false},
		{Name: "Warehouse", EnvName: "SNOWFLAKE_TEST_WAREHOUSE", FailOnMissing: false},
	}
	isJWT := os.Getenv("SNOWFLAKE_TEST_AUTHENTICATOR") == "SNOWFLAKE_JWT"
	if !isJWT {
		// Password is only mandatory when not authenticating via key pair.
		configParams = append(configParams,
			&sf.ConfigParam{Name: "Password", EnvName: "SNOWFLAKE_TEST_PASSWORD", FailOnMissing: true},
		)
	}
	cfg, err := sf.GetConfigFromEnv(configParams)
	if err != nil {
		t.Fatalf("failed to get config from environment: %v", err)
	}
	if isJWT {
		privKeyPath := os.Getenv("SNOWFLAKE_TEST_PRIVATE_KEY")
		if privKeyPath == "" {
			t.Fatal("SNOWFLAKE_TEST_PRIVATE_KEY must be set for JWT authentication")
		}
		cfg.PrivateKey = arrowTestReadPrivateKey(t, privKeyPath)
		cfg.Authenticator = sf.AuthTypeJwt
	}
	// Pin the session timezone so timestamp expectations are deterministic.
	tz := "UTC"
	if cfg.Params == nil {
		cfg.Params = make(map[string]*string)
	}
	cfg.Params["timezone"] = &tz
	dsn, err := sf.DSN(cfg)
	if err != nil {
		t.Fatalf("failed to create DSN: %v", err)
	}
	db, err := sql.Open("snowflake", dsn)
	if err != nil {
		t.Fatalf("failed to open db: %v", err)
	}
	conn, err := db.Conn(context.Background())
	if err != nil {
		db.Close()
		t.Fatalf("failed to get connection: %v", err)
	}
	return &arrowTestConn{db: db, conn: conn}
}

// close releases the dedicated connection and then the database handle.
func (tc *arrowTestConn) close() {
	tc.conn.Close()
	tc.db.Close()
}

// exec runs a statement on the dedicated connection and fails the test on error.
func (tc *arrowTestConn) exec(t *testing.T, query string) {
	t.Helper()
	_, err := tc.conn.ExecContext(context.Background(), query)
	if err != nil {
		t.Fatalf("exec %q failed: %v", query, err)
	}
}

// enableStructuredTypes turns on the session parameters required for
// structured types to be returned to this client.
// NOTE(review): "VESRION" below looks misspelled but presumably matches the
// actual server-side parameter name — confirm before changing.
func (tc *arrowTestConn) enableStructuredTypes(t *testing.T) {
	t.Helper()
	tc.exec(t, "ALTER SESSION SET ENABLE_STRUCTURED_TYPES_IN_CLIENT_RESPONSE = true")
	tc.exec(t, "ALTER SESSION SET IGNORE_CLIENT_VESRION_IN_STRUCTURED_TYPES_RESPONSE = true")
}

// forceNativeArrow forces the session to deliver results, including structured
// types, in the native Arrow format.
func (tc *arrowTestConn) forceNativeArrow(t *testing.T) {
	t.Helper()
	tc.exec(t, "ALTER SESSION SET GO_QUERY_RESULT_FORMAT = ARROW")
	tc.exec(t, "ALTER SESSION SET ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT = true")
	tc.exec(t, "ALTER SESSION SET FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT = true")
}

// queryArrowBatches runs query through the raw driver connection and returns
// the resulting arrow batches plus a cleanup func that closes the rows. The
// rows must stay open until the caller is done with the batches.
func (tc *arrowTestConn) queryArrowBatches(t *testing.T, ctx context.Context, query string) ([]*arrowbatches.ArrowBatch, func()) {
	t.Helper()
	var rows driver.Rows
	var err error
	err = tc.conn.Raw(func(x any) error {
		queryer, ok := x.(driver.QueryerContext)
		if !ok {
			return fmt.Errorf("connection does not implement QueryerContext")
		}
		rows, err = queryer.QueryContext(ctx, query, nil)
		return err
	})
	if err != nil {
		t.Fatalf("failed to execute query: %v", err)
	}
	sfRows, ok := rows.(sf.SnowflakeRows)
	if !ok {
		rows.Close()
		t.Fatalf("rows do not implement SnowflakeRows")
	}
	batches, err := arrowbatches.GetArrowBatches(sfRows)
	if err != nil {
		rows.Close()
		t.Fatalf("GetArrowBatches failed: %v", err)
	}
	if len(batches) == 0 {
		rows.Close()
		t.Fatal("expected at least one batch")
	}
	return batches, func() { rows.Close() }
}

// fetchFirst fetches the records of the first arrow batch for query and
// returns them with a cleanup func; it fails the test when the batch is empty.
func (tc *arrowTestConn) fetchFirst(t *testing.T, ctx context.Context, query string) ([]arrow.Record, func()) {
	t.Helper()
	batches, closeRows := tc.queryArrowBatches(t, ctx, query)
	records, err := batches[0].Fetch()
	if err != nil {
		closeRows()
		t.Fatalf("Fetch failed: %v", err)
	}
	if records == nil || len(*records) == 0 {
		closeRows()
		t.Fatal("expected at least one record")
	}
	return *records, closeRows
}

// equalIgnoringWhitespace compares two strings after stripping all space and
// newline characters from both.
func equalIgnoringWhitespace(a, b string) bool {
	return strings.ReplaceAll(strings.ReplaceAll(a, " ", ""), "\n", "") == strings.ReplaceAll(strings.ReplaceAll(b, " ", ""), "\n", "")
}

// TestStructuredTypeInArrowBatchesSimple checks that a simple structured
// OBJECT arrives as an Arrow struct column.
func TestStructuredTypeInArrowBatchesSimple(t *testing.T) {
	pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
	defer
pool.AssertSize(t, 0)
	ctx := sf.WithArrowAllocator(arrowbatches.WithArrowBatches(context.Background()), pool)
	tc := openArrowTestConn(t)
	defer tc.close()
	tc.enableStructuredTypes(t)
	tc.forceNativeArrow(t)
	records, closeRows := tc.fetchFirst(t, ctx, "SELECT 1, {'s': 'some string'}::OBJECT(s VARCHAR)")
	defer closeRows()
	for _, record := range records {
		defer record.Release()
		if v := record.Column(0).(*array.Int8).Value(0); v != int8(1) {
			t.Errorf("expected column 0 = 1, got %v", v)
		}
		if v := record.Column(1).(*array.Struct).Field(0).(*array.String).Value(0); v != "some string" {
			t.Errorf("expected 'some string', got %q", v)
		}
	}
}

// TestStructuredTypeInArrowBatchesAllTypes checks that string, integer, time,
// date, and the three timestamp variants of a structured OBJECT arrive with
// the expected Arrow array types and values (session timezone is UTC).
func TestStructuredTypeInArrowBatchesAllTypes(t *testing.T) {
	pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
	defer pool.AssertSize(t, 0)
	ctx := sf.WithArrowAllocator(arrowbatches.WithArrowBatches(context.Background()), pool)
	tc := openArrowTestConn(t)
	defer tc.close()
	tc.enableStructuredTypes(t)
	tc.forceNativeArrow(t)
	records, closeRows := tc.fetchFirst(t, ctx, "SELECT 1, {'s': 'some string', 'i': 1, 'time': '11:22:33'::TIME, 'date': '2024-04-16'::DATE, "+
		"'ltz': '2024-04-16 11:22:33'::TIMESTAMPLTZ, 'tz': '2025-04-16 22:33:11 +0100'::TIMESTAMPTZ, "+
		"'ntz': '2026-04-16 15:22:31'::TIMESTAMPNTZ}::OBJECT(s VARCHAR, i INTEGER, time TIME, date DATE, "+
		"ltz TIMESTAMPLTZ, tz TIMESTAMPTZ, ntz TIMESTAMPNTZ)")
	defer closeRows()
	for _, record := range records {
		defer record.Release()
		if v := record.Column(0).(*array.Int8).Value(0); v != int8(1) {
			t.Errorf("expected column 0 = 1, got %v", v)
		}
		st := record.Column(1).(*array.Struct)
		if v := st.Field(0).(*array.String).Value(0); v != "some string" {
			t.Errorf("expected 'some string', got %q", v)
		}
		if v := st.Field(1).(*array.Int64).Value(0); v != 1 {
			t.Errorf("expected i=1, got %v", v)
		}
		if v := st.Field(2).(*array.Time64).Value(0).ToTime(arrow.Nanosecond); !v.Equal(time.Date(1970, 1, 1, 11, 22, 33, 0, time.UTC)) {
			t.Errorf("expected time 11:22:33, got %v", v)
		}
		if v := st.Field(3).(*array.Date32).Value(0).ToTime(); !v.Equal(time.Date(2024, 4, 16, 0, 0, 0, 0, time.UTC)) {
			t.Errorf("expected date 2024-04-16, got %v", v)
		}
		if v := st.Field(4).(*array.Timestamp).Value(0).ToTime(arrow.Nanosecond); !v.Equal(time.Date(2024, 4, 16, 11, 22, 33, 0, time.UTC)) {
			t.Errorf("expected ltz 2024-04-16 11:22:33 UTC, got %v", v)
		}
		// The +0100 input collapses to 21:33:11 UTC.
		if v := st.Field(5).(*array.Timestamp).Value(0).ToTime(arrow.Nanosecond); !v.Equal(time.Date(2025, 4, 16, 21, 33, 11, 0, time.UTC)) {
			t.Errorf("expected tz 2025-04-16 21:33:11 UTC, got %v", v)
		}
		if v := st.Field(6).(*array.Timestamp).Value(0).ToTime(arrow.Nanosecond); !v.Equal(time.Date(2026, 4, 16, 15, 22, 31, 0, time.UTC)) {
			t.Errorf("expected ntz 2026-04-16 15:22:31, got %v", v)
		}
	}
}

// TestStructuredTypeInArrowBatchesWithTimestampOptionAndHigherPrecisionAndUtf8Validation
// combines UseOriginalTimestamp, higher precision (Decimal128 for integers),
// and UTF-8 validation on one query and checks every field's representation.
func TestStructuredTypeInArrowBatchesWithTimestampOptionAndHigherPrecisionAndUtf8Validation(t *testing.T) {
	pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
	defer pool.AssertSize(t, 0)
	ctx := arrowbatches.WithUtf8Validation(
		sf.WithHigherPrecision(
			arrowbatches.WithTimestampOption(
				sf.WithArrowAllocator(arrowbatches.WithArrowBatches(context.Background()), pool),
				arrowbatches.UseOriginalTimestamp,
			),
		),
	)
	tc := openArrowTestConn(t)
	defer tc.close()
	tc.enableStructuredTypes(t)
	tc.forceNativeArrow(t)
	records, closeRows := tc.fetchFirst(t, ctx, "SELECT 1, {'i': 123, 'f': 12.34, 'n0': 321, 'n19': 1.5, 's': 'some string', "+
		"'bi': TO_BINARY('616263', 'HEX'), 'bool': true, 'time': '11:22:33', "+
		"'date': '2024-04-18', 'ntz': '2024-04-01 11:22:33', "+
		"'tz': '2024-04-02 11:22:33 +0100', 'ltz': '2024-04-03 11:22:33'}::"+
		"OBJECT(i INTEGER, f DOUBLE, n0 NUMBER(38, 0), n19 NUMBER(38, 19), "+
		"s VARCHAR, bi BINARY, bool BOOLEAN, time TIME, date DATE, "+
		"ntz TIMESTAMP_NTZ, tz TIMESTAMP_TZ, ltz TIMESTAMP_LTZ)")
	defer closeRows()
	for _, record := range records {
		defer record.Release()
		if v := record.Column(0).(*array.Int8).Value(0); v != int8(1) {
			t.Errorf("expected column 0 = 1, got %v", v)
		}
		st := record.Column(1).(*array.Struct)
		// Higher precision: fixed-point fields come back as Decimal128.
		if v := st.Field(0).(*array.Decimal128).Value(0).LowBits(); v != uint64(123) {
			t.Errorf("expected i=123, got %v", v)
		}
		if v := st.Field(1).(*array.Float64).Value(0); v != 12.34 {
			t.Errorf("expected f=12.34, got %v", v)
		}
		if v := st.Field(2).(*array.Decimal128).Value(0).LowBits(); v != uint64(321) {
			t.Errorf("expected n0=321, got %v", v)
		}
		// 1.5 scaled by 10^19 for NUMBER(38, 19).
		if v := st.Field(3).(*array.Decimal128).Value(0).LowBits(); v != uint64(15000000000000000000) {
			t.Errorf("expected n19=15000000000000000000, got %v", v)
		}
		if v := st.Field(4).(*array.String).Value(0); v != "some string" {
			t.Errorf("expected 'some string', got %q", v)
		}
		if v := st.Field(5).(*array.Binary).Value(0); !reflect.DeepEqual(v, []byte{'a', 'b', 'c'}) {
			t.Errorf("expected 'abc' binary, got %v", v)
		}
		if v := st.Field(6).(*array.Boolean).Value(0); v != true {
			t.Errorf("expected true, got %v", v)
		}
		if v := st.Field(7).(*array.Time64).Value(0).ToTime(arrow.Nanosecond); !v.Equal(time.Date(1970, 1, 1, 11, 22, 33, 0, time.UTC)) {
			t.Errorf("expected time 11:22:33, got %v", v)
		}
		if v := st.Field(8).(*array.Date32).Value(0).ToTime(); !v.Equal(time.Date(2024, 4, 18, 0, 0, 0, 0, time.UTC)) {
			t.Errorf("expected date 2024-04-18, got %v", v)
		}
		// With UseOriginalTimestamp, timestamps remain as raw structs (epoch + fraction)
		if v := st.Field(9).(*array.Struct).Field(0).(*array.Int64).Value(0); v != int64(1711970553) {
			t.Errorf("expected ntz epoch=1711970553, got %v", v)
		}
		if v := st.Field(9).(*array.Struct).Field(1).(*array.Int32).Value(0); v != int32(0) {
			t.Errorf("expected ntz fraction=0, got %v", v)
		}
		if v := st.Field(10).(*array.Struct).Field(0).(*array.Int64).Value(0); v != int64(1712053353) {
			t.Errorf("expected tz epoch=1712053353, got %v", v)
		}
		if v := st.Field(10).(*array.Struct).Field(1).(*array.Int32).Value(0); v != int32(0) {
			t.Errorf("expected tz fraction=0, got %v", v)
		}
		if v := st.Field(11).(*array.Struct).Field(0).(*array.Int64).Value(0); v != int64(1712143353) {
			t.Errorf("expected ltz epoch=1712143353, got %v", v)
		}
		if v := st.Field(11).(*array.Struct).Field(1).(*array.Int32).Value(0); v != int32(0) {
			t.Errorf("expected ltz fraction=0, got %v", v)
		}
	}
}

// TestStructuredTypeInArrowBatchesWithEmbeddedObject checks that an OBJECT
// nested inside an OBJECT maps to nested Arrow structs.
func TestStructuredTypeInArrowBatchesWithEmbeddedObject(t *testing.T) {
	pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
	defer pool.AssertSize(t, 0)
	ctx := sf.WithArrowAllocator(arrowbatches.WithArrowBatches(context.Background()), pool)
	tc := openArrowTestConn(t)
	defer tc.close()
	tc.enableStructuredTypes(t)
	tc.forceNativeArrow(t)
	records, closeRows := tc.fetchFirst(t, ctx, "SELECT {'o': {'s': 'some string'}}::OBJECT(o OBJECT(s VARCHAR))")
	defer closeRows()
	for _, record := range records {
		defer record.Release()
		if v := record.Column(0).(*array.Struct).Field(0).(*array.Struct).Field(0).(*array.String).Value(0); v != "some string" {
			t.Errorf("expected 'some string', got %q", v)
		}
	}
}

// TestStructuredTypeInArrowBatchesAsNull checks null handling for a structured
// OBJECT column: first row non-null, second row null.
func TestStructuredTypeInArrowBatchesAsNull(t *testing.T) {
	pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
	defer pool.AssertSize(t, 0)
	ctx := sf.WithArrowAllocator(arrowbatches.WithArrowBatches(context.Background()), pool)
	tc := openArrowTestConn(t)
	defer tc.close()
	tc.enableStructuredTypes(t)
	tc.forceNativeArrow(t)
	records, closeRows := tc.fetchFirst(t, ctx, "SELECT {'s': 'some string'}::OBJECT(s VARCHAR) UNION SELECT null ORDER BY 1")
	defer closeRows()
	for _, record := range records {
		defer record.Release()
		if record.Column(0).IsNull(0) {
			t.Error("expected first row to be non-null")
		}
		if !record.Column(0).IsNull(1) {
			t.Error("expected second row to be null")
		}
	}
}

// TestStructuredArrayInArrowBatches checks that ARRAY(INTEGER) maps to an
// Arrow list column with the expected flattened values and offsets.
func TestStructuredArrayInArrowBatches(t *testing.T) {
	pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
	defer pool.AssertSize(t, 0)
	ctx := sf.WithArrowAllocator(arrowbatches.WithArrowBatches(context.Background()), pool)
	tc := openArrowTestConn(t)
	defer tc.close()
	tc.enableStructuredTypes(t)
	tc.forceNativeArrow(t)
	records, closeRows := tc.fetchFirst(t, ctx, "SELECT [1, 2, 3]::ARRAY(INTEGER) UNION SELECT [4, 5, 6]::ARRAY(INTEGER) ORDER BY 1")
	defer closeRows()
	record := records[0]
	defer record.Release()
	listCol := record.Column(0).(*array.List)
	vals := listCol.ListValues().(*array.Int64)
	expectedVals := []int64{1, 2, 3, 4, 5, 6}
	for i, exp := range expectedVals {
		if v := vals.Value(i); v != exp {
			t.Errorf("list value[%d]: expected %d, got %d", i, exp, v)
		}
	}
	expectedOffsets := []int32{0, 3, 6}
	for i, exp := range expectedOffsets {
		if v := listCol.Offsets()[i]; v != exp {
			t.Errorf("offset[%d]: expected %d, got %d", i, exp, v)
		}
	}
}

// TestStructuredMapInArrowBatches checks that MAP(VARCHAR, VARCHAR) maps to an
// Arrow map column with matching keys and items.
func TestStructuredMapInArrowBatches(t *testing.T) {
	pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
	defer pool.AssertSize(t, 0)
	ctx := sf.WithArrowAllocator(arrowbatches.WithArrowBatches(context.Background()), pool)
	tc := openArrowTestConn(t)
	defer tc.close()
	tc.enableStructuredTypes(t)
	tc.forceNativeArrow(t)
	records, closeRows := tc.fetchFirst(t, ctx, "SELECT {'a': 'b', 'c': 'd'}::MAP(VARCHAR, VARCHAR)")
	defer closeRows()
	for _, record := range records {
		defer record.Release()
		m := record.Column(0).(*array.Map)
		keys := m.Keys().(*array.String)
		items := m.Items().(*array.String)
		if v := keys.Value(0); v != "a" {
			t.Errorf("expected key 'a', got %q", v)
		}
		if v := keys.Value(1); v != "c" {
			t.Errorf("expected key 'c', got %q", v)
		}
		if v := items.Value(0); v != "b" {
			t.Errorf("expected item 'b', got %q", v)
		}
		if v := items.Value(1); v != "d" {
			t.Errorf("expected item 'd', got %q", v)
		}
	}
}

// TestSelectingNullObjectsInArrowBatches checks that null typed and untyped
// OBJECT values arrive as single-row, single-column null records.
func TestSelectingNullObjectsInArrowBatches(t *testing.T) {
	testcases := []string{
		"select null::object(v VARCHAR)",
		"select null::object",
	}
	tc := openArrowTestConn(t)
	defer tc.close()
	tc.enableStructuredTypes(t)
	for _, query := range testcases {
		t.Run(query, func(t *testing.T) {
			pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
			defer pool.AssertSize(t, 0)
			ctx := sf.WithArrowAllocator(arrowbatches.WithArrowBatches(context.Background()), pool)
			records, closeRows := tc.fetchFirst(t, ctx, query)
			defer closeRows()
			for _, record := range records {
				defer record.Release()
				if record.NumRows() != 1 {
					t.Fatalf("wrong number of rows: expected 1, got %d", record.NumRows())
				}
				if record.NumCols() != 1 {
					t.Fatalf("wrong number of cols: expected 1, got %d", record.NumCols())
				}
				if !record.Column(0).IsNull(0) {
					t.Error("expected null value")
				}
			}
		})
	}
}

// TestSelectingSemistructuredTypesInArrowBatches checks that semistructured
// OBJECT/ARRAY values (no type arguments) arrive as JSON strings, with and
// without UTF-8 validation enabled.
func TestSelectingSemistructuredTypesInArrowBatches(t *testing.T) {
	testcases := []struct {
		name               string
		query              string
		expected           string
		withUtf8Validation bool
	}{
		{
			name:               "semistructured object with utf8 validation",
			withUtf8Validation: true,
			expected:           `{"s":"someString"}`,
			query:              "SELECT {'s':'someString'}::OBJECT",
		},
		{
			name:               "semistructured object without utf8 validation",
			withUtf8Validation: false,
			expected:           `{"s":"someString"}`,
			query:              "SELECT {'s':'someString'}::OBJECT",
		},
		{
			name:               "semistructured array without utf8 validation",
			withUtf8Validation: false,
			expected:           `[1,2,3]`,
			query:              "SELECT [1, 2, 3]::ARRAY",
		},
		{
			name:               "semistructured array with utf8 validation",
			withUtf8Validation: true,
			expected:           `[1,2,3]`,
			query:              "SELECT [1, 2, 3]::ARRAY",
		},
	}
	tc := openArrowTestConn(t)
	defer tc.close()
	for _, tc2 := range testcases {
		t.Run(tc2.name, func(t *testing.T) {
			pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
			defer pool.AssertSize(t, 0)
			ctx := sf.WithArrowAllocator(arrowbatches.WithArrowBatches(context.Background()), pool)
			if tc2.withUtf8Validation {
				ctx = arrowbatches.WithUtf8Validation(ctx)
			}
			records, closeRows := tc.fetchFirst(t, ctx, tc2.query)
			defer closeRows()
			for _, record := range records {
				defer record.Release()
				if record.NumCols() != 1 {
					t.Fatalf("unexpected number of columns: %d", record.NumCols())
				}
				if record.NumRows() != 1 {
					t.Fatalf("unexpected number of rows: %d", record.NumRows())
				}
				stringCol, ok := record.Column(0).(*array.String)
				if !ok {
					t.Fatalf("wrong type for column, expected string, got %T", record.Column(0))
				}
				// Server pretty-printing may differ; compare ignoring whitespace.
				if !equalIgnoringWhitespace(stringCol.Value(0), tc2.expected) {
					t.Errorf("expected %q, got %q", tc2.expected, stringCol.Value(0))
				}
			}
		})
	}
}
================================================ FILE: structured_type_read_test.go ================================================

package gosnowflake

import (
	"context"
	"database/sql"
	"fmt"
	"math/big"
	"reflect"
	"strings"
	"testing"
	"time"
)

// objectWithAllTypes is a test fixture covering every scalar, time, nested
// struct, array, and map field supported by structured OBJECT reading. The
// `sf` tag options select the Snowflake time type used per time.Time field.
type objectWithAllTypes struct {
	s         string
	b         byte
	i16       int16
	i32       int32
	i64       int64
	f32       float32
	f64       float64
	nfraction float64
	bo        bool
	bi        []byte
	date      time.Time `sf:"date,date"`
	time      time.Time `sf:"time,time"`
	ltz       time.Time `sf:"ltz,ltz"`
	tz        time.Time `sf:"tz,tz"`
	ntz       time.Time `sf:"ntz,ntz"`
	so        *simpleObject
	sArr      []string
	f64Arr    []float64
	someMap   map[string]bool
	uuid      testUUID
}

// Scan populates the fixture field by field from a StructuredObject.
func (o *objectWithAllTypes) Scan(val any) error {
	st, ok := val.(StructuredObject)
	if !ok {
		return fmt.Errorf("expected StructuredObject, got %T", val)
	}
	var err error
	if o.s, err = st.GetString("s"); err != nil {
		return err
	}
	if o.b, err = st.GetByte("b"); err != nil {
		return err
	}
	if o.i16, err = st.GetInt16("i16"); err != nil {
		return err
	}
	if o.i32, err = st.GetInt32("i32"); err != nil {
		return err
	}
	if o.i64, err = st.GetInt64("i64"); err != nil {
		return err
	}
	if o.f32, err = st.GetFloat32("f32"); err != nil {
		return err
	}
	if o.f64, err = st.GetFloat64("f64"); err != nil {
		return err
	}
	if o.nfraction, err = st.GetFloat64("nfraction"); err != nil {
		return err
	}
	if o.bo, err = st.GetBool("bo"); err != nil {
		return err
	}
	if o.bi, err = st.GetBytes("bi"); err != nil {
		return err
	}
	if o.date, err = st.GetTime("date"); err != nil {
		return err
	}
	if o.time, err = st.GetTime("time"); err != nil {
		return err
	}
	if o.ltz, err = st.GetTime("ltz"); err != nil {
		return err
	}
	if o.tz, err = st.GetTime("tz"); err != nil {
		return err
	}
	if o.ntz, err = st.GetTime("ntz"); err != nil {
		return err
	}
	so, err := st.GetStruct("so", &simpleObject{})
	if err != nil {
		return err
	}
	o.so = so.(*simpleObject)
	// Raw getters return nil for SQL NULL; only assert the type when present.
	sArr, err := st.GetRaw("sArr")
	if err != nil {
		return err
	}
	if sArr != nil {
		o.sArr = sArr.([]string)
	}
	f64Arr, err := st.GetRaw("f64Arr")
	if err != nil {
		return err
	}
	if f64Arr != nil {
		o.f64Arr = f64Arr.([]float64)
	}
	someMap, err := st.GetRaw("someMap")
	if err != nil {
		return err
	}
	if someMap != nil {
		o.someMap = someMap.(map[string]bool)
	}
	uuidStr, err := st.GetString("uuid")
	if err != nil {
		return err
	}
	o.uuid = parseTestUUID(uuidStr)
	return nil
}

// Write serializes the fixture through the StructuredObjectWriterContext,
// mirroring Scan field by field.
// NOTE(review): WriteByt (not WriteByte) appears to be the writer API's name.
func (o objectWithAllTypes) Write(sowc StructuredObjectWriterContext) error {
	if err := sowc.WriteString("s", o.s); err != nil {
		return err
	}
	if err := sowc.WriteByt("b", o.b); err != nil {
		return err
	}
	if err := sowc.WriteInt16("i16", o.i16); err != nil {
		return err
	}
	if err := sowc.WriteInt32("i32", o.i32); err != nil {
		return err
	}
	if err := sowc.WriteInt64("i64", o.i64); err != nil {
		return err
	}
	if err := sowc.WriteFloat32("f32", o.f32); err != nil {
		return err
	}
	if err := sowc.WriteFloat64("f64", o.f64); err != nil {
		return err
	}
	if err := sowc.WriteFloat64("nfraction", o.nfraction); err != nil {
		return err
	}
	if err := sowc.WriteBool("bo", o.bo); err != nil {
		return err
	}
	if err := sowc.WriteBytes("bi", o.bi); err != nil {
		return err
	}
	if err := sowc.WriteTime("date", o.date, DataTypeDate); err != nil {
		return err
	}
	if err := sowc.WriteTime("time", o.time, DataTypeTime); err != nil {
		return err
	}
	if err := sowc.WriteTime("ltz", o.ltz, DataTypeTimestampLtz); err != nil {
		return err
	}
	if err := sowc.WriteTime("ntz", o.ntz, DataTypeTimestampNtz); err != nil {
		return err
	}
	if err := sowc.WriteTime("tz", o.tz, DataTypeTimestampTz); err != nil {
		return err
	}
	if err := sowc.WriteStruct("so", o.so); err != nil {
		return err
	}
	if err := sowc.WriteRaw("sArr", o.sArr); err != nil {
		return err
	}
	if err := sowc.WriteRaw("f64Arr", o.f64Arr); err != nil {
		return err
	}
	if err := sowc.WriteRaw("someMap", o.someMap); err != nil {
		return err
	}
	if err := sowc.WriteString("uuid", o.uuid.String()); err != nil {
		return err
	}
	return nil
}

// simpleObject is a minimal nested-object fixture (one string, one int32).
type simpleObject struct {
	s string
	i int32
}

// Scan populates simpleObject from a StructuredObject.
func (so *simpleObject) Scan(val any) error {
	st, ok := val.(StructuredObject)
	if !ok {
		return fmt.Errorf("expected StructuredObject, got %T", val)
	}
	var err error
	if so.s, err = st.GetString("s"); err != nil {
		return err
	}
	if so.i, err = st.GetInt32("i"); err != nil {
		return err
	}
	return nil
}

// Write serializes simpleObject through the StructuredObjectWriterContext.
func (so *simpleObject) Write(sowc StructuredObjectWriterContext) error {
	if err := sowc.WriteString("s", so.s); err != nil {
		return err
	}
	if err := sowc.WriteInt32("i", so.i); err != nil {
		return err
	}
	return nil
}

// TestObjectWithAllTypesAsString verifies that, without structured types
// enabled on the context, an OBJECT is scanned into a plain JSON string.
func TestObjectWithAllTypesAsString(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			skipForStringingNativeArrow(t, format)
			rows := dbt.mustQuery("SELECT {'s': 'some string', 'i32': 3}::OBJECT(s VARCHAR, i32 INTEGER)")
			defer rows.Close()
			assertTrueF(t, rows.Next())
			var res string
			err := rows.Scan(&res)
			assertNilF(t, err)
			assertEqualIgnoringWhitespaceE(t, res, `{"s": "some string", "i32": 3}`)
		})
	})
}

// TestObjectWithAllTypesAsObject scans a fully-populated structured OBJECT
// into objectWithAllTypes and checks every field, with the session timezone
// pinned to Europe/Warsaw.
func TestObjectWithAllTypesAsObject(t *testing.T) {
	warsawTz, err := time.LoadLocation("Europe/Warsaw")
	assertNilF(t, err)
	ctx := WithStructuredTypesEnabled(context.Background())
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'")
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			uid := newTestUUID()
			rows := dbt.mustQueryContextT(ctx, t, fmt.Sprintf("SELECT 1, {'s': 'some string', 'b': 1, 'i16': 2, 'i32': 3, 'i64': 9223372036854775807, 'f32': '1.1', 'f64': 2.2, 'nfraction': 3.3, 'bo': true, 'bi': TO_BINARY('616263', 'HEX'), 'date': '2024-03-21'::DATE, 'time': '13:03:02'::TIME, 'ltz': '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz': '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz': '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so': {'s': 'child', 'i': 9}, 'sArr': ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr': ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap': {'x': true, 'y': false}, 'uuid': '%s'}::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)", uid))
			defer rows.Close()
			rows.Next()
			var ignore int
			var res objectWithAllTypes
			err := rows.Scan(&ignore, &res)
			assertNilF(t, err)
			assertEqualE(t, res.s, "some string")
			assertEqualE(t, res.b, byte(1))
			assertEqualE(t, res.i16, int16(2))
			assertEqualE(t, res.i32, int32(3))
			assertEqualE(t, res.i64, int64(9223372036854775807))
			assertEqualE(t, res.f32, float32(1.1))
			assertEqualE(t, res.f64, 2.2)
			assertEqualE(t, res.nfraction, 3.3)
			assertEqualE(t, res.bo, true)
			assertBytesEqualE(t, res.bi, []byte{'a', 'b', 'c'})
			assertEqualE(t, res.date, time.Date(2024, time.March, 21, 0, 0, 0, 0, time.UTC))
			assertEqualE(t, res.time.Hour(), 13)
			assertEqualE(t, res.time.Minute(), 3)
			assertEqualE(t, res.time.Second(), 2)
			assertTrueE(t, res.ltz.Equal(time.Date(2021, time.July, 21, 11, 22, 33, 0, warsawTz)))
			assertTrueE(t, res.tz.Equal(time.Date(2022, time.August, 31, 13, 43, 22, 0, warsawTz)))
			assertTrueE(t, res.ntz.Equal(time.Date(2023, time.May, 22, 1, 17, 19, 0, time.UTC)))
			assertDeepEqualE(t, res.so, &simpleObject{s: "child", i: 9})
			assertDeepEqualE(t, res.sArr, []string{"x", "y", "z"})
			assertDeepEqualE(t, res.f64Arr, []float64{1.1, 2.2, 3.3})
			assertDeepEqualE(t, res.someMap, map[string]bool{"x": true, "y": false})
			assertEqualE(t, res.uuid.String(), uid.String())
		})
	})
}

// TestNullObject verifies scanning into a pointer target: a NULL object yields
// a nil pointer; a non-null object populates the pointed-to struct.
func TestNullObject(t *testing.T) {
	ctx := WithStructuredTypesEnabled(context.Background())
	runDBTest(t, func(dbt *DBTest) {
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			t.Run("null", func(t *testing.T) {
				rows := dbt.mustQueryContextT(ctx, t, "SELECT null::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)")
				defer rows.Close()
				assertTrueF(t, rows.Next())
				var res *objectWithAllTypes
				err := rows.Scan(&res)
				assertNilF(t, err)
				assertNilE(t, res)
			})
			t.Run("not null", func(t *testing.T) {
				uid := newTestUUID()
				rows := dbt.mustQueryContextT(ctx, t, fmt.Sprintf("SELECT {'s': 'some string', 'b': 1, 'i16': 2, 'i32': 3, 'i64': 9223372036854775807, 'f32': '1.1', 'f64': 2.2, 'nfraction': 3.3, 'bo': true, 'bi': TO_BINARY('616263', 'HEX'), 'date': '2024-03-21'::DATE, 'time': '13:03:02'::TIME, 'ltz': '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz': '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz': '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so': {'s': 'child', 'i': 9}, 'sArr': ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr': ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap': {'x': true, 'y': false}, 'uuid': '%s'}::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)", uid))
				defer rows.Close()
				assertTrueF(t, rows.Next())
				var res *objectWithAllTypes
				err := rows.Scan(&res)
				assertNilF(t, err)
				assertEqualE(t, res.s, "some string")
			})
		})
	})
}

// objectWithAllTypesNullable mirrors objectWithAllTypes using database/sql
// Null* wrappers so every scalar field can carry SQL NULL.
type objectWithAllTypesNullable struct {
	s       sql.NullString
	b       sql.NullByte
	i16     sql.NullInt16
	i32     sql.NullInt32
	i64     sql.NullInt64
	f64     sql.NullFloat64
	bo      sql.NullBool
	bi      []byte
	date    sql.NullTime
	time    sql.NullTime
	ltz     sql.NullTime
	tz      sql.NullTime
	ntz     sql.NullTime
	so      *simpleObject
	sArr    []string
	f64Arr  []float64
	someMap map[string]bool
	uuid    testUUID
}

// Scan populates the nullable fixture via the GetNull* accessors.
func (o *objectWithAllTypesNullable) Scan(val any) error {
	st, ok := val.(StructuredObject)
	if !ok {
		return fmt.Errorf("expected StructuredObject, got %T", val)
	}
	var err error
	if o.s, err = st.GetNullString("s"); err != nil {
		return err
	}
	if o.b, err = st.GetNullByte("b"); err != nil {
		return err
	}
	if o.i16, err = st.GetNullInt16("i16"); err != nil {
		return err
	}
	if o.i32, err = st.GetNullInt32("i32"); err != nil {
		return err
	}
	if o.i64, err = st.GetNullInt64("i64"); err != nil {
		return err
	}
	if o.f64, err = st.GetNullFloat64("f64"); err != nil {
		return err
	}
	if o.bo, err = st.GetNullBool("bo"); err != nil {
		return err
	}
	if o.bi, err = st.GetBytes("bi"); err != nil {
		return err
	}
	if o.date, err = st.GetNullTime("date"); err != nil {
		return err
	}
	if o.time, err = st.GetNullTime("time"); err != nil {
		return err
	}
	if o.ltz, err = st.GetNullTime("ltz"); err != nil {
		return err
	}
	if o.tz, err = st.GetNullTime("tz"); err != nil {
		return err
	}
	if o.ntz, err = st.GetNullTime("ntz"); err != nil {
		return err
	}
	so, err := st.GetStruct("so", &simpleObject{})
	if err != nil {
		return err
	}
	// A NULL nested object comes back as nil; keep the pointer nil in that case.
	if so != nil {
		o.so = so.(*simpleObject)
	} else {
		o.so = nil
	}
	sArr, err := st.GetRaw("sArr")
	if err != nil {
		return err
	}
	if sArr != nil {
		o.sArr = sArr.([]string)
	}
	f64Arr, err := st.GetRaw("f64Arr")
	if err != nil {
		return err
	}
	if f64Arr != nil {
		o.f64Arr = f64Arr.([]float64)
	}
	someMap, err := st.GetRaw("someMap")
	if err != nil {
		return err
	}
	if someMap != nil {
		o.someMap = someMap.(map[string]bool)
	}
	uuidStr, err := st.GetNullString("uuid")
	if err != nil {
		return err
	}
	o.uuid = parseTestUUID(uuidStr.String)
	return nil
}

// Write serializes the nullable fixture via the WriteNull* writers.
func (o *objectWithAllTypesNullable) Write(sowc StructuredObjectWriterContext) error {
	if err := sowc.WriteNullString("s", o.s); err != nil {
		return err
	}
	if err := sowc.WriteNullByte("b", o.b); err != nil {
		return err
	}
	if err := sowc.WriteNullInt16("i16", o.i16); err != nil {
		return err
	}
	if err := sowc.WriteNullInt32("i32", o.i32); err != nil {
		return err
	}
	if err := sowc.WriteNullInt64("i64", o.i64); err != nil {
		return err
	}
	if err := sowc.WriteNullFloat64("f64", o.f64); err != nil {
		return err
	}
	if err := sowc.WriteNullBool("bo", o.bo); err != nil {
		return err
	}
	if err := sowc.WriteBytes("bi", o.bi); err != nil {
		return err
	}
	if err := sowc.WriteNullTime("date", o.date, DataTypeDate); err != nil {
		return err
	}
	if err := sowc.WriteNullTime("time", o.time, DataTypeTime); err != nil {
		return err
	}
	if err := sowc.WriteNullTime("ltz", o.ltz, DataTypeTimestampLtz); err != nil {
		return err
	}
	if err := sowc.WriteNullTime("ntz", o.ntz, DataTypeTimestampNtz); err != nil {
		return err
	}
	if err := sowc.WriteNullTime("tz", o.tz, DataTypeTimestampTz); err != nil {
		return err
	}
	if err := sowc.WriteNullableStruct("so", o.so, reflect.TypeFor[simpleObject]()); err != nil {
		return err
	}
	if err := sowc.WriteRaw("sArr", o.sArr); err != nil {
		return err
	}
	if err := sowc.WriteRaw("f64Arr", o.f64Arr); err != nil {
		return err
	}
	if err := sowc.WriteRaw("someMap", o.someMap); err != nil {
		return err
	}
	if err := sowc.WriteNullString("uuid", sql.NullString{String: o.uuid.String(), Valid: true}); err != nil {
		return err
	}
	return nil
}

// TestObjectWithAllTypesNullable exercises the nullable fixture with an
// all-NULL object (every Null* field invalid) and a fully-populated one.
func TestObjectWithAllTypesNullable(t *testing.T) {
	warsawTz, err := time.LoadLocation("Europe/Warsaw")
	assertNilF(t, err)
	ctx := WithStructuredTypesEnabled(context.Background())
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'")
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			t.Run("null", func(t *testing.T) {
				rows := dbt.mustQueryContextT(ctx, t, "select null, object_construct_keep_null('s', null, 'b', null, 'i16', null, 'i32', null, 'i64', null, 'f64', null, 'bo', null, 'bi', null, 'date', null, 'time', null, 'ltz', null, 'tz', null, 'ntz', null, 'so', null, 'sArr', null, 'f64Arr', null, 'someMap', null, 'uuid', null)::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)")
				defer rows.Close()
				assertTrueF(t, rows.Next())
				var ignore sql.NullInt32
				var res objectWithAllTypesNullable
				err := rows.Scan(&ignore, &res)
				assertNilF(t, err)
				assertEqualE(t, ignore, sql.NullInt32{Valid: false})
				assertEqualE(t, res.s, sql.NullString{Valid: false})
				assertEqualE(t, res.b, sql.NullByte{Valid: false})
				assertEqualE(t, res.i16, sql.NullInt16{Valid: false})
				assertEqualE(t, res.i32, sql.NullInt32{Valid: false})
				assertEqualE(t, res.i64, sql.NullInt64{Valid: false})
				assertEqualE(t, res.f64, sql.NullFloat64{Valid: false})
				assertEqualE(t, res.bo, sql.NullBool{Valid: false})
				assertBytesEqualE(t, res.bi, nil)
				assertEqualE(t, res.date, sql.NullTime{Valid: false})
				assertEqualE(t, res.time, sql.NullTime{Valid: false})
				assertEqualE(t, res.ltz, sql.NullTime{Valid: false})
				assertEqualE(t, res.tz, sql.NullTime{Valid: false})
				assertEqualE(t, res.ntz, sql.NullTime{Valid: false})
				var so *simpleObject
				assertDeepEqualE(t, res.so, so)
				assertEqualE(t, res.uuid, testUUID{})
			})
			t.Run("not null", func(t *testing.T) {
				uuid := newTestUUID()
				rows := dbt.mustQueryContextT(ctx, t, fmt.Sprintf("select 1, object_construct_keep_null('s', 'abc', 'b', 1, 'i16', 2, 'i32', 3, 'i64', 9223372036854775807, 'f64', 2.2, 'bo', true, 'bi', TO_BINARY('616263', 'HEX'), 'date', '2024-03-21'::DATE, 'time', '13:03:02'::TIME, 'ltz', '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz', '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz', '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so', {'s': 'child', 'i': 9}::OBJECT, 'sArr', ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr', ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap', {'x': true, 'y': false}, 'uuid', '%s')::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)", uuid))
				defer rows.Close()
				rows.Next()
				var ignore sql.NullInt32
				var res objectWithAllTypesNullable
				err := rows.Scan(&ignore, &res)
				assertNilF(t, err)
				assertEqualE(t, ignore, sql.NullInt32{Valid: true, Int32: 1})
				assertEqualE(t, res.s, sql.NullString{Valid: true, String: "abc"})
				assertEqualE(t, res.b, sql.NullByte{Valid: true, Byte: byte(1)})
				assertEqualE(t, res.i16, sql.NullInt16{Valid: true, Int16: int16(2)})
				assertEqualE(t, res.i32, sql.NullInt32{Valid: true, Int32: 3})
				assertEqualE(t, res.i64, sql.NullInt64{Valid: true, Int64: 9223372036854775807})
				assertEqualE(t, res.f64, sql.NullFloat64{Valid: true, Float64: 2.2})
				assertEqualE(t, res.bo, sql.NullBool{Valid: true, Bool: true})
				assertBytesEqualE(t, res.bi, []byte{'a', 'b', 'c'})
				assertEqualE(t, res.date, sql.NullTime{Valid: true, Time: time.Date(2024, time.March, 21, 0, 0, 0, 0, time.UTC)})
				assertTrueE(t, res.time.Valid)
				assertEqualE(t, res.time.Time.Hour(), 13)
				assertEqualE(t, res.time.Time.Minute(), 3)
				assertEqualE(t, res.time.Time.Second(), 2)
				assertTrueE(t, res.ltz.Valid)
				assertTrueE(t, res.ltz.Time.Equal(time.Date(2021, time.July, 21, 11, 22, 33, 0, warsawTz)))
				assertTrueE(t, res.tz.Valid)
				assertTrueE(t, res.tz.Time.Equal(time.Date(2022, time.August, 31, 13, 43, 22, 0, warsawTz)))
				assertTrueE(t, res.ntz.Valid)
				assertTrueE(t, res.ntz.Time.Equal(time.Date(2023, time.May, 22, 1, 17, 19, 0, time.UTC)))
				assertDeepEqualE(t, res.so, &simpleObject{s: "child", i: 9})
				assertDeepEqualE(t, res.sArr, []string{"x", "y", "z"})
				assertDeepEqualE(t, res.f64Arr, []float64{1.1, 2.2, 3.3})
				assertDeepEqualE(t, res.someMap, map[string]bool{"x": true, "y": false})
				assertEqualE(t, res.uuid.String(), uuid.String())
			})
		})
	})
}

// objectWithAllTypesSimpleScan uses exported fields so ScanTo/WriteAll can map
// them by reflection; `sf` tags pick field names and time types as needed.
type objectWithAllTypesSimpleScan struct {
	S         string
	B         byte
	I16       int16
	I32       int32
	I64       int64
	F32       float32
	F64       float64
	Nfraction float64
	Bo        bool
	Bi        []byte
	Date      time.Time `sf:"date,date"`
	Time      time.Time `sf:"time,time"`
	Ltz       time.Time `sf:"ltz,ltz"`
	Tz        time.Time `sf:"tz,tz"`
	Ntz       time.Time `sf:"ntz,ntz"`
	So        *simpleObject
	SArr      []string
	F64Arr    []float64
	SomeMap   map[string]bool
}

// Scan delegates to the reflection-based ScanTo helper.
func (so *objectWithAllTypesSimpleScan) Scan(val any) error {
	st, ok := val.(StructuredObject)
	if !ok {
		return fmt.Errorf("expected StructuredObject, got %T", val)
	}
	return st.ScanTo(so)
}

// Write delegates to the reflection-based WriteAll helper.
func (so *objectWithAllTypesSimpleScan) Write(sowc StructuredObjectWriterContext) error {
	return sowc.WriteAll(so)
}

// TestObjectWithAllTypesSimpleScan mirrors TestObjectWithAllTypesAsObject but
// relies on ScanTo's reflection mapping instead of hand-written getters.
func TestObjectWithAllTypesSimpleScan(t *testing.T) {
	uid := newTestUUID()
	warsawTz, err := time.LoadLocation("Europe/Warsaw")
	assertNilF(t, err)
	ctx := WithStructuredTypesEnabled(context.Background())
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'")
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			rows := dbt.mustQueryContextT(ctx, t, fmt.Sprintf("SELECT 1, {'s': 'some string', 'b': 1, 'i16': 2, 'i32': 3, 'i64': 9223372036854775807, 'f32': '1.1', 'f64': 2.2, 'nfraction': 3.3, 'bo': true, 'bi': TO_BINARY('616263', 'HEX'), 'date': '2024-03-21'::DATE, 'time': '13:03:02'::TIME, 'ltz': '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz': '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz': '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so': {'s': 'child', 'i': 9}, 'sArr': ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr': ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap': {'x': true, 'y': false}, 'uuid': '%s'}::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)", uid))
			defer rows.Close()
			rows.Next()
			var ignore int
			var res objectWithAllTypesSimpleScan
			err := rows.Scan(&ignore, &res)
			assertNilF(t, err)
			assertEqualE(t, res.S, "some string")
			assertEqualE(t, res.B, byte(1))
			assertEqualE(t, res.I16, int16(2))
			assertEqualE(t, res.I32, int32(3))
			assertEqualE(t, res.I64, int64(9223372036854775807))
			assertEqualE(t, res.F32, float32(1.1))
			assertEqualE(t, res.F64, 2.2)
			assertEqualE(t, res.Nfraction, 3.3)
			assertEqualE(t, res.Bo, true)
			assertBytesEqualE(t, res.Bi, []byte{'a', 'b', 'c'})
			assertEqualE(t, res.Date, time.Date(2024,
time.March, 21, 0, 0, 0, 0, time.UTC)) assertEqualE(t, res.Time.Hour(), 13) assertEqualE(t, res.Time.Minute(), 3) assertEqualE(t, res.Time.Second(), 2) assertTrueE(t, res.Ltz.Equal(time.Date(2021, time.July, 21, 11, 22, 33, 0, warsawTz))) assertTrueE(t, res.Tz.Equal(time.Date(2022, time.August, 31, 13, 43, 22, 0, warsawTz))) assertTrueE(t, res.Ntz.Equal(time.Date(2023, time.May, 22, 1, 17, 19, 0, time.UTC))) assertDeepEqualE(t, res.So, &simpleObject{s: "child", i: 9}) assertDeepEqualE(t, res.SArr, []string{"x", "y", "z"}) assertDeepEqualE(t, res.F64Arr, []float64{1.1, 2.2, 3.3}) assertDeepEqualE(t, res.SomeMap, map[string]bool{"x": true, "y": false}) }) }) } func TestNullObjectSimpleScan(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { forAllStructureTypeFormats(dbt, func(t *testing.T, format string) { t.Run("null", func(t *testing.T) { rows := dbt.mustQueryContextT(ctx, t, "SELECT null::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)") defer rows.Close() assertTrueF(t, rows.Next()) var res *objectWithAllTypesSimpleScan err := rows.Scan(&res) assertNilF(t, err) assertNilE(t, res) }) t.Run("not null", func(t *testing.T) { uid := newTestUUID() rows := dbt.mustQueryContextT(ctx, t, fmt.Sprintf("SELECT {'s': 'some string', 'b': 1, 'i16': 2, 'i32': 3, 'i64': 9223372036854775807, 'f32': '1.1', 'f64': 2.2, 'nfraction': 3.3, 'bo': true, 'bi': TO_BINARY('616263', 'HEX'), 'date': '2024-03-21'::DATE, 'time': '13:03:02'::TIME, 'ltz': '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz': '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz': '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so': {'s': 'child', 'i': 9}, 'sArr': ARRAY_CONSTRUCT('x', 'y', 'z'), 
'f64Arr': ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap': {'x': true, 'y': false}, 'uuid': '%s'}::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)", uid)) defer rows.Close() assertTrueF(t, rows.Next()) var res *objectWithAllTypesSimpleScan err := rows.Scan(&res) assertNilF(t, err) assertEqualE(t, res.S, "some string") }) }) }) } type objectWithAllTypesNullableSimpleScan struct { S sql.NullString B sql.NullByte I16 sql.NullInt16 I32 sql.NullInt32 I64 sql.NullInt64 F64 sql.NullFloat64 Bo sql.NullBool Bi []byte Date sql.NullTime `sf:"date,date"` Time sql.NullTime `sf:"time,time"` Ltz sql.NullTime `sf:"ltz,ltz"` Tz sql.NullTime `sf:"tz,tz"` Ntz sql.NullTime `sf:"ntz,ntz"` So *simpleObject SArr []string F64Arr []float64 SomeMap map[string]bool } func (o *objectWithAllTypesNullableSimpleScan) Scan(val any) error { st, ok := val.(StructuredObject) if !ok { return fmt.Errorf("expected StructuredObject, got %T", val) } return st.ScanTo(o) } func (o *objectWithAllTypesNullableSimpleScan) Write(sowc StructuredObjectWriterContext) error { return sowc.WriteAll(o) } func TestObjectWithAllTypesSimpleScanNullable(t *testing.T) { warsawTz, err := time.LoadLocation("Europe/Warsaw") assertNilF(t, err) ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") forAllStructureTypeFormats(dbt, func(t *testing.T, format string) { t.Run("null", func(t *testing.T) { rows := dbt.mustQueryContextT(ctx, t, "select null, object_construct_keep_null('s', null, 'b', null, 'i16', null, 'i32', null, 'i64', null, 'f64', null, 'bo', null, 'bi', null, 'date', null, 'time', null, 'ltz', null, 'tz', null, 'ntz', null, 'so', null, 
'sArr', null, 'f64Arr', null, 'someMap', null)::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN))") defer rows.Close() rows.Next() var ignore sql.NullInt32 var res objectWithAllTypesNullableSimpleScan err := rows.Scan(&ignore, &res) assertNilF(t, err) assertEqualE(t, ignore, sql.NullInt32{Valid: false}) assertEqualE(t, res.S, sql.NullString{Valid: false}) assertEqualE(t, res.B, sql.NullByte{Valid: false}) assertEqualE(t, res.I16, sql.NullInt16{Valid: false}) assertEqualE(t, res.I32, sql.NullInt32{Valid: false}) assertEqualE(t, res.I64, sql.NullInt64{Valid: false}) assertEqualE(t, res.F64, sql.NullFloat64{Valid: false}) assertEqualE(t, res.Bo, sql.NullBool{Valid: false}) assertBytesEqualE(t, res.Bi, nil) assertEqualE(t, res.Date, sql.NullTime{Valid: false}) assertEqualE(t, res.Time, sql.NullTime{Valid: false}) assertEqualE(t, res.Ltz, sql.NullTime{Valid: false}) assertEqualE(t, res.Tz, sql.NullTime{Valid: false}) assertEqualE(t, res.Ntz, sql.NullTime{Valid: false}) var so *simpleObject assertDeepEqualE(t, res.So, so) }) t.Run("not null", func(t *testing.T) { uuid := newTestUUID() rows := dbt.mustQueryContextT(ctx, t, fmt.Sprintf("select 1, object_construct_keep_null('s', 'abc', 'b', 1, 'i16', 2, 'i32', 3, 'i64', 9223372036854775807, 'f64', 2.2, 'bo', true, 'bi', TO_BINARY('616263', 'HEX'), 'date', '2024-03-21'::DATE, 'time', '13:03:02'::TIME, 'ltz', '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz', '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz', '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so', {'s': 'child', 'i': 9}::OBJECT, 'sArr', ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr', ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap', {'x': true, 'y': false}, 'uuid', '%s')::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo 
BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)", uuid)) defer rows.Close() rows.Next() var ignore sql.NullInt32 var res objectWithAllTypesNullableSimpleScan err := rows.Scan(&ignore, &res) assertNilF(t, err) assertEqualE(t, ignore, sql.NullInt32{Valid: true, Int32: 1}) assertEqualE(t, res.S, sql.NullString{Valid: true, String: "abc"}) assertEqualE(t, res.B, sql.NullByte{Valid: true, Byte: byte(1)}) assertEqualE(t, res.I16, sql.NullInt16{Valid: true, Int16: int16(2)}) assertEqualE(t, res.I32, sql.NullInt32{Valid: true, Int32: 3}) assertEqualE(t, res.I64, sql.NullInt64{Valid: true, Int64: 9223372036854775807}) assertEqualE(t, res.F64, sql.NullFloat64{Valid: true, Float64: 2.2}) assertEqualE(t, res.Bo, sql.NullBool{Valid: true, Bool: true}) assertBytesEqualE(t, res.Bi, []byte{'a', 'b', 'c'}) assertEqualE(t, res.Date, sql.NullTime{Valid: true, Time: time.Date(2024, time.March, 21, 0, 0, 0, 0, time.UTC)}) assertTrueE(t, res.Time.Valid) assertEqualE(t, res.Time.Time.Hour(), 13) assertEqualE(t, res.Time.Time.Minute(), 3) assertEqualE(t, res.Time.Time.Second(), 2) assertTrueE(t, res.Ltz.Valid) assertTrueE(t, res.Ltz.Time.Equal(time.Date(2021, time.July, 21, 11, 22, 33, 0, warsawTz))) assertTrueE(t, res.Tz.Valid) assertTrueE(t, res.Tz.Time.Equal(time.Date(2022, time.August, 31, 13, 43, 22, 0, warsawTz))) assertTrueE(t, res.Ntz.Valid) assertTrueE(t, res.Ntz.Time.Equal(time.Date(2023, time.May, 22, 1, 17, 19, 0, time.UTC))) assertDeepEqualE(t, res.So, &simpleObject{s: "child", i: 9}) assertDeepEqualE(t, res.SArr, []string{"x", "y", "z"}) assertDeepEqualE(t, res.F64Arr, []float64{1.1, 2.2, 3.3}) assertDeepEqualE(t, res.SomeMap, map[string]bool{"x": true, "y": false}) }) }) }) } type objectWithCustomNameAndIgnoredField struct { SomeString string `sf:"anotherName"` IgnoreMe string `sf:"ignoreMe,ignore"` } func 
(o *objectWithCustomNameAndIgnoredField) Scan(val any) error {
	// Scan implements sql.Scanner by delegating to StructuredObject.ScanTo,
	// which honors the custom-name and ignore tag options on the struct.
	st, ok := val.(StructuredObject)
	if !ok {
		return fmt.Errorf("expected StructuredObject, got %T", val)
	}
	return st.ScanTo(o)
}

// Write serializes all non-ignored exported fields through the writer context.
func (o *objectWithCustomNameAndIgnoredField) Write(sowc StructuredObjectWriterContext) error {
	return sowc.WriteAll(o)
}

// TestObjectWithCustomName verifies that the `sf:"anotherName"` tag maps the
// object field to SomeString and that a ",ignore"-tagged field stays zero.
func TestObjectWithCustomName(t *testing.T) {
	ctx := WithStructuredTypesEnabled(context.Background())
	runDBTest(t, func(dbt *DBTest) {
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			rows := dbt.mustQueryContextT(ctx, t, "SELECT {'anotherName': 'some string'}::OBJECT(anotherName VARCHAR)")
			defer rows.Close()
			rows.Next()
			var res objectWithCustomNameAndIgnoredField
			err := rows.Scan(&res)
			assertNilF(t, err)
			assertEqualE(t, res.SomeString, "some string")
			assertEqualE(t, res.IgnoreMe, "")
		})
	})
}

// TestObjectMetadataAsObject checks column metadata for a structured OBJECT
// when structured types are enabled: ScanType is ObjectType.
func TestObjectMetadataAsObject(t *testing.T) {
	ctx := WithStructuredTypesEnabled(context.Background())
	runDBTest(t, func(dbt *DBTest) {
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			rows := dbt.mustQueryContextT(ctx, t, "SELECT {'a': 'b'}::OBJECT(a VARCHAR) as structured_type")
			defer rows.Close()
			columnTypes, err := rows.ColumnTypes()
			assertNilF(t, err)
			assertEqualE(t, len(columnTypes), 1)
			assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[ObjectType]())
			assertEqualE(t, columnTypes[0].DatabaseTypeName(), "OBJECT")
			assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE")
		})
	})
}

// TestObjectMetadataAsString checks column metadata for the same OBJECT when
// structured types are NOT enabled: ScanType falls back to string.
func TestObjectMetadataAsString(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			skipForStringingNativeArrow(t, format)
			rows := dbt.mustQueryT(t, "SELECT {'a': 'b'}::OBJECT(a VARCHAR) as structured_type")
			defer rows.Close()
			columnTypes, err := rows.ColumnTypes()
			assertNilF(t, err)
			assertEqualE(t, len(columnTypes), 1)
			assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[string]())
			assertEqualE(t, columnTypes[0].DatabaseTypeName(), "OBJECT")
			assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE")
		})
	})
}

// TestObjectWithoutSchema scans a schemaless OBJECT as its JSON string form.
func TestObjectWithoutSchema(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			if format == "NATIVE_ARROW" {
				t.Skip("Native arrow is not supported in objects without schema")
			}
			rows := dbt.mustQuery("SELECT {'a': 'b'}::OBJECT AS STRUCTURED_TYPE")
			defer rows.Close()
			rows.Next()
			var v string
			err := rows.Scan(&v)
			assertNilF(t, err)
			assertStringContainsE(t, v, `"a": "b"`)
		})
	})
}

// TestObjectWithoutSchemaMetadata checks metadata for a schemaless OBJECT:
// ScanType is string while DatabaseTypeName stays OBJECT.
func TestObjectWithoutSchemaMetadata(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			if format == "NATIVE_ARROW" {
				t.Skip("Native arrow is not supported in objects without schema")
			}
			rows := dbt.mustQuery("SELECT {'a': 'b'}::OBJECT AS structured_type")
			defer rows.Close()
			columnTypes, err := rows.ColumnTypes()
			assertNilF(t, err)
			assertEqualE(t, len(columnTypes), 1)
			assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[string]())
			assertEqualE(t, columnTypes[0].DatabaseTypeName(), "OBJECT")
			assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE")
		})
	})
}

// TestArrayAndMetadataAsString checks that, without structured types enabled,
// a structured ARRAY is scanned as its string representation and its metadata
// reports ScanType string / DatabaseTypeName ARRAY.
func TestArrayAndMetadataAsString(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			skipForStringingNativeArrow(t, format)
			rows := dbt.mustQueryT(t, "SELECT ARRAY_CONSTRUCT(1, 2)::ARRAY(INTEGER) AS STRUCTURED_TYPE")
			defer rows.Close()
			assertTrueF(t, rows.Next())
			var res string
			err := rows.Scan(&res)
			assertNilF(t, err)
			assertEqualIgnoringWhitespaceE(t, "[1, 2]", res)
			columnTypes, err := rows.ColumnTypes()
			assertNilF(t, err)
			assertEqualE(t, len(columnTypes), 1)
			assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[string]())
			assertEqualE(t, columnTypes[0].DatabaseTypeName(), "ARRAY")
			assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE")
		})
	})
}

// TestArrayAndMetadataAsArray table-tests scanning structured ARRAYs of each
// element type into typed Go slices, plus the resulting column metadata.
// Each case may carry a second expected row (expected2) from the UNIONed query.
func TestArrayAndMetadataAsArray(t *testing.T) {
	ctx := WithStructuredTypesEnabled(context.Background())
	warsawTz, err := time.LoadLocation("Europe/Warsaw")
	assertNilF(t, err)
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'")
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			testcases := []struct {
				name      string
				query     string
				expected1 any
				expected2 any
				actual    any
			}{
				{
					name:      "integer",
					query:     "SELECT ARRAY_CONSTRUCT(1, 2)::ARRAY(INTEGER) as structured_type UNION SELECT ARRAY_CONSTRUCT(4, 5, 6)::ARRAY(INTEGER) ORDER BY 1",
					expected1: []int64{1, 2},
					expected2: []int64{4, 5, 6},
					actual:    []int64{},
				},
				{
					name:      "double",
					query:     "SELECT ARRAY_CONSTRUCT(1.1, 2.2)::ARRAY(DOUBLE) as structured_type UNION SELECT ARRAY_CONSTRUCT(3.3)::ARRAY(DOUBLE) ORDER BY 1",
					expected1: []float64{1.1, 2.2},
					expected2: []float64{3.3},
					actual:    []float64{},
				},
				{
					name:      "number - fixed integer",
					query:     "SELECT ARRAY_CONSTRUCT(1, 2)::ARRAY(NUMBER(38, 0)) as structured_type UNION SELECT ARRAY_CONSTRUCT(3)::ARRAY(NUMBER(38, 0)) ORDER BY 1",
					expected1: []int64{1, 2},
					expected2: []int64{3},
					actual:    []int64{},
				},
				{
					// The empty array sorts first, so it is the first expected row.
					name:      "number - fixed fraction",
					query:     "SELECT ARRAY_CONSTRUCT(1.1, 2.2)::ARRAY(NUMBER(38, 19)) as structured_type UNION SELECT ARRAY_CONSTRUCT()::ARRAY(NUMBER(38, 19)) ORDER BY 1",
					expected1: []float64{},
					expected2: []float64{1.1, 2.2},
					actual:    []float64{},
				},
				{
					name:      "string",
					query:     "SELECT ARRAY_CONSTRUCT('a', 'b')::ARRAY(VARCHAR) as structured_type",
					expected1: []string{"a", "b"},
					actual:    []string{},
				},
				{
					name:      "time",
					query:     "SELECT ARRAY_CONSTRUCT('13:03:02'::TIME, '05:13:22'::TIME)::ARRAY(TIME) as structured_type",
					expected1: []time.Time{time.Date(0, 1, 1, 13, 3, 2, 0, time.UTC), time.Date(0, 1, 1, 5, 13, 22, 0, time.UTC)},
					actual:    []time.Time{},
				},
				{
					name:      "date",
					query:     "SELECT ARRAY_CONSTRUCT('2024-01-05'::DATE, '2001-11-12'::DATE)::ARRAY(DATE) as structured_type",
					expected1: []time.Time{time.Date(2024, time.January, 5, 0, 0, 0, 0, time.UTC), time.Date(2001, time.November, 12, 0, 0, 0, 0, time.UTC)},
					actual:    []time.Time{},
				},
				{
					name:      "timestamp_ntz",
					query:     "SELECT ARRAY_CONSTRUCT('2024-01-05 11:22:33'::TIMESTAMP_NTZ, '2001-11-12 11:22:33'::TIMESTAMP_NTZ)::ARRAY(TIMESTAMP_NTZ) as structured_type",
					expected1: []time.Time{time.Date(2024, time.January, 5, 11, 22, 33, 0, time.UTC), time.Date(2001, time.November, 12, 11, 22, 33, 0, time.UTC)},
					actual:    []time.Time{},
				},
				{
					name:      "timestamp_ltz",
					query:     "SELECT ARRAY_CONSTRUCT('2024-01-05 11:22:33'::TIMESTAMP_LTZ, '2001-11-12 11:22:33'::TIMESTAMP_LTZ)::ARRAY(TIMESTAMP_LTZ) as structured_type",
					expected1: []time.Time{time.Date(2024, time.January, 5, 11, 22, 33, 0, warsawTz), time.Date(2001, time.November, 12, 11, 22, 33, 0, warsawTz)},
					actual:    []time.Time{},
				},
				{
					name:      "timestamp_tz",
					query:     "SELECT ARRAY_CONSTRUCT('2024-01-05 11:22:33 +0100'::TIMESTAMP_TZ, '2001-11-12 11:22:33 +0100'::TIMESTAMP_TZ)::ARRAY(TIMESTAMP_TZ) as structured_type",
					expected1: []time.Time{time.Date(2024, time.January, 5, 11, 22, 33, 0, warsawTz), time.Date(2001, time.November, 12, 11, 22, 33, 0, warsawTz)},
					actual:    []time.Time{},
				},
				{
					name:      "bool",
					query:     "SELECT ARRAY_CONSTRUCT(true, false)::ARRAY(boolean) as structured_type",
					expected1: []bool{true, false},
					actual:    []bool{},
				},
				{
					name:      "binary",
					query:     "SELECT ARRAY_CONSTRUCT(TO_BINARY('616263', 'HEX'), TO_BINARY('646566', 'HEX'))::ARRAY(BINARY) as structured_type",
					expected1: [][]byte{{'a', 'b', 'c'}, {'d', 'e', 'f'}},
					actual:    [][]byte{},
				},
			}
			for _, tc := range testcases {
				t.Run(tc.name, func(t *testing.T) {
					rows := dbt.mustQueryContextT(ctx, t, tc.query)
					defer rows.Close()
					rows.Next()
					err := rows.Scan(&tc.actual)
					assertNilF(t, err)
					if _, ok := tc.actual.([]time.Time); ok {
						// time.Time needs Equal/component comparison instead of DeepEqual.
						assertEqualE(t, len(tc.actual.([]time.Time)), len(tc.expected1.([]time.Time)))
						for i := range tc.actual.([]time.Time) {
							if tc.name == "time" {
								// TIME has no date part; compare clock components only.
								assertEqualE(t, tc.actual.([]time.Time)[i].Hour(), tc.expected1.([]time.Time)[i].Hour())
								assertEqualE(t, tc.actual.([]time.Time)[i].Minute(), tc.expected1.([]time.Time)[i].Minute())
								assertEqualE(t, tc.actual.([]time.Time)[i].Second(), tc.expected1.([]time.Time)[i].Second())
							} else {
								assertTrueE(t, tc.actual.([]time.Time)[i].UTC().Equal(tc.expected1.([]time.Time)[i].UTC()))
							}
						}
					} else {
						assertDeepEqualE(t, tc.actual, tc.expected1)
					}
					// Second row only exists for the UNION-based cases.
					if tc.expected2 != nil {
						rows.Next()
						err := rows.Scan(&tc.actual)
						assertNilF(t, err)
						if _, ok := tc.actual.([]time.Time); ok {
							assertEqualE(t, len(tc.actual.([]time.Time)), len(tc.expected2.([]time.Time)))
							for i := range tc.actual.([]time.Time) {
								assertTrueE(t, tc.actual.([]time.Time)[i].UTC().Equal(tc.expected2.([]time.Time)[i].UTC()))
							}
						} else {
							assertDeepEqualE(t, tc.actual, tc.expected2)
						}
					}
					columnTypes, err := rows.ColumnTypes()
					assertNilF(t, err)
					assertEqualE(t, len(columnTypes), 1)
					assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeOf(tc.expected1))
					assertEqualE(t, columnTypes[0].DatabaseTypeName(), "ARRAY")
					assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE")
				})
			}
		})
	})
}

// TestArrayWithoutSchema scans a schemaless array as its string form.
func TestArrayWithoutSchema(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			if format == "NATIVE_ARROW" {
				t.Skip("Native arrow is not supported in arrays without schema")
			}
			rows := dbt.mustQuery("SELECT ARRAY_CONSTRUCT(1, 2)")
			defer rows.Close()
			rows.Next()
			var v string
			err := rows.Scan(&v)
			assertNilF(t, err)
			assertEqualIgnoringWhitespaceE(t, v, "[1, 2]")
		})
	})
}

// TestEmptyArraysAndNullArrays distinguishes an empty array (empty, non-nil
// slice) from SQL NULL (nil pointer) when scanning into **[]int64.
func TestEmptyArraysAndNullArrays(t *testing.T) {
	ctx := WithStructuredTypesEnabled(context.Background())
	runDBTest(t, func(dbt *DBTest) {
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			rows := dbt.mustQueryContextT(ctx, t, "SELECT ARRAY_CONSTRUCT(1, 2)::ARRAY(INTEGER) as structured_type UNION SELECT ARRAY_CONSTRUCT()::ARRAY(INTEGER) UNION SELECT NULL UNION SELECT ARRAY_CONSTRUCT(4, 5, 6)::ARRAY(INTEGER) ORDER BY 1")
			defer rows.Close()
			// checkRow scans the next row into a *[]int64 and compares it to expected.
			checkRow := func(rows *RowsExtended, expected *[]int64) {
				var res *[]int64
				rows.Next()
				err := rows.Scan(&res)
				assertNilF(t, err)
				assertDeepEqualE(t, res, expected)
			}
checkRow(rows, &[]int64{})
			checkRow(rows, &[]int64{1, 2})
			checkRow(rows, &[]int64{4, 5, 6})
			// NULL sorts last here and must come back as a nil pointer.
			checkRow(rows, nil)
		})
	})
}

// TestArrayWithoutSchemaMetadata checks metadata for a schemaless array:
// ScanType is string while DatabaseTypeName stays ARRAY.
func TestArrayWithoutSchemaMetadata(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			if format == "NATIVE_ARROW" {
				t.Skip("Native arrow is not supported in arrays without schema")
			}
			rows := dbt.mustQuery("SELECT ARRAY_CONSTRUCT(1, 2) AS structured_type")
			defer rows.Close()
			columnTypes, err := rows.ColumnTypes()
			assertNilF(t, err)
			assertEqualE(t, len(columnTypes), 1)
			assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[string]())
			assertEqualE(t, columnTypes[0].DatabaseTypeName(), "ARRAY")
			assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE")
		})
	})
}

// TestArrayOfObjects scans ARRAY(OBJECT(...)) columns into []*simpleObject via
// ScanArrayOfScanners and checks the resulting column metadata.
func TestArrayOfObjects(t *testing.T) {
	ctx := WithStructuredTypesEnabled(context.Background())
	runDBTest(t, func(dbt *DBTest) {
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			rows := dbt.mustQueryContextT(ctx, t, "SELECT ARRAY_CONSTRUCT({'s': 's1', 'i': 9}, {'s': 's2', 'i': 8})::ARRAY(OBJECT(s VARCHAR, i INTEGER)) as structured_type UNION SELECT ARRAY_CONSTRUCT({'s': 's3', 'i': 7})::ARRAY(OBJECT(s VARCHAR, i INTEGER)) ORDER BY 1")
			defer rows.Close()
			rows.Next()
			var res []*simpleObject
			err := rows.Scan(ScanArrayOfScanners(&res))
			assertNilF(t, err)
			// Per ORDER BY 1, the single-element array row arrives first.
			assertDeepEqualE(t, res, []*simpleObject{{s: "s3", i: 7}})
			rows.Next()
			err = rows.Scan(ScanArrayOfScanners(&res))
			assertNilF(t, err)
			assertDeepEqualE(t, res, []*simpleObject{{s: "s1", i: 9}, {s: "s2", i: 8}})
			columnTypes, err := rows.ColumnTypes()
			assertNilF(t, err)
			assertEqualE(t, len(columnTypes), 1)
			assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[[]ObjectType]())
			assertEqualE(t, columnTypes[0].DatabaseTypeName(), "ARRAY")
			assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE")
		})
	})
}

// TestArrayOfArrays table-tests scanning nested ARRAY(ARRAY(...)) columns into
// two-dimensional Go slices for every supported element type.
func TestArrayOfArrays(t *testing.T) {
	ctx := WithStructuredTypesEnabled(context.Background())
	warsawTz, err := time.LoadLocation("Europe/Warsaw")
	assertNilF(t, err)
	testcases := []struct {
		name     string
		query    string
		actual   any
		expected any
	}{
		{
			name:     "string",
			query:    "SELECT ARRAY_CONSTRUCT(ARRAY_CONSTRUCT('a', 'b', 'c'), ARRAY_CONSTRUCT('d', 'e'))::ARRAY(ARRAY(VARCHAR))",
			actual:   make([][]string, 2),
			expected: [][]string{{"a", "b", "c"}, {"d", "e"}},
		},
		{
			name:     "int64",
			query:    "SELECT ARRAY_CONSTRUCT(ARRAY_CONSTRUCT(1, 2), ARRAY_CONSTRUCT(3, 4))::ARRAY(ARRAY(INTEGER))",
			actual:   make([][]int64, 2),
			expected: [][]int64{{1, 2}, {3, 4}},
		},
		{
			name:     "float64 - fixed",
			query:    "SELECT ARRAY_CONSTRUCT(ARRAY_CONSTRUCT(1.1, 2.2), ARRAY_CONSTRUCT(3.3, 4.4))::ARRAY(ARRAY(NUMBER(38, 19)))",
			actual:   make([][]float64, 2),
			expected: [][]float64{{1.1, 2.2}, {3.3, 4.4}},
		},
		{
			name:     "float64 - real",
			query:    "SELECT ARRAY_CONSTRUCT(ARRAY_CONSTRUCT(1.1, 2.2), ARRAY_CONSTRUCT(3.3, 4.4))::ARRAY(ARRAY(DOUBLE))",
			actual:   make([][]float64, 2),
			expected: [][]float64{{1.1, 2.2}, {3.3, 4.4}},
		},
		{
			name:     "bool",
			query:    "SELECT ARRAY_CONSTRUCT(ARRAY_CONSTRUCT(true, false), ARRAY_CONSTRUCT(false, true, false))::ARRAY(ARRAY(BOOLEAN))",
			actual:   make([][]bool, 2),
			expected: [][]bool{{true, false}, {false, true, false}},
		},
		{
			name:     "binary",
			query:    "SELECT ARRAY_CONSTRUCT(ARRAY_CONSTRUCT(TO_BINARY('6162'), TO_BINARY('6364')), ARRAY_CONSTRUCT(TO_BINARY('6566'), TO_BINARY('6768')))::ARRAY(ARRAY(BINARY))",
			actual:   make([][][]byte, 2),
			expected: [][][]byte{{{'a', 'b'}, {'c', 'd'}}, {{'e', 'f'}, {'g', 'h'}}},
		},
		{
			name:     "date",
			query:    "SELECT ARRAY_CONSTRUCT(ARRAY_CONSTRUCT('2024-01-01'::DATE, '2024-02-02'::DATE), ARRAY_CONSTRUCT('2024-03-03'::DATE, '2024-04-04'::DATE))::ARRAY(ARRAY(DATE))",
			actual:   make([][]time.Time, 2),
			expected: [][]time.Time{{time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2024, 2, 2, 0, 0, 0, 0, time.UTC)}, {time.Date(2024, 3, 3, 0, 0, 0, 0, time.UTC), time.Date(2024, 4, 4, 0, 0, 0, 0, time.UTC)}},
		},
		{
			name:     "time",
			query:    "SELECT ARRAY_CONSTRUCT(ARRAY_CONSTRUCT('01:01:01'::TIME, '02:02:02'::TIME), ARRAY_CONSTRUCT('03:03:03'::TIME, '04:04:04'::TIME))::ARRAY(ARRAY(TIME))",
			actual:   make([][]time.Time, 2),
			expected: [][]time.Time{{time.Date(0, 1, 1, 1, 1, 1, 0, time.UTC), time.Date(0, 1, 1, 2, 2, 2, 0, time.UTC)}, {time.Date(0, 1, 1, 3, 3, 3, 0, time.UTC), time.Date(0, 1, 1, 4, 4, 4, 0, time.UTC)}},
		},
		{
			name:     "timestamp_ltz",
			query:    "SELECT ARRAY_CONSTRUCT(ARRAY_CONSTRUCT('2024-01-05 11:22:33'::TIMESTAMP_LTZ), ARRAY_CONSTRUCT('2001-11-12 11:22:33'::TIMESTAMP_LTZ))::ARRAY(ARRAY(TIMESTAMP_LTZ))",
			actual:   make([][]time.Time, 2),
			expected: [][]time.Time{{time.Date(2024, time.January, 5, 11, 22, 33, 0, warsawTz)}, {time.Date(2001, time.November, 12, 11, 22, 33, 0, warsawTz)}},
		},
		{
			name:     "timestamp_ntz",
			query:    "SELECT ARRAY_CONSTRUCT(ARRAY_CONSTRUCT('2024-01-05 11:22:33'::TIMESTAMP_NTZ), ARRAY_CONSTRUCT('2001-11-12 11:22:33'::TIMESTAMP_NTZ))::ARRAY(ARRAY(TIMESTAMP_NTZ))",
			actual:   make([][]time.Time, 2),
			expected: [][]time.Time{{time.Date(2024, time.January, 5, 11, 22, 33, 0, time.UTC)}, {time.Date(2001, time.November, 12, 11, 22, 33, 0, time.UTC)}},
		},
		{
			name:     "timestamp_tz",
			query:    "SELECT ARRAY_CONSTRUCT(ARRAY_CONSTRUCT('2024-01-05 11:22:33 +0100'::TIMESTAMP_TZ), ARRAY_CONSTRUCT('2001-11-12 11:22:33 +0100'::TIMESTAMP_TZ))::ARRAY(ARRAY(TIMESTAMP_TZ))",
			actual:   make([][]time.Time, 2),
			expected: [][]time.Time{{time.Date(2024, time.January, 5, 11, 22, 33, 0, warsawTz)}, {time.Date(2001, time.November, 12, 11, 22, 33, 0, warsawTz)}},
		},
	}
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'")
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			for _, tc := range testcases {
				t.Run(tc.name, func(t *testing.T) {
					rows := dbt.mustQueryContextT(ctx, t, tc.query)
					defer rows.Close()
					rows.Next()
					err := rows.Scan(&tc.actual)
					assertNilF(t, err)
					if timesOfTimes, ok := tc.expected.([][]time.Time); ok {
						// time.Time needs Equal/component comparison instead of DeepEqual.
						for i, timeOfTimes := range timesOfTimes {
							for j, tm := range timeOfTimes {
								if tc.name == "time" {
									// TIME has no date part; compare clock components only.
									assertEqualE(t, tm.Hour(), tc.actual.([][]time.Time)[i][j].Hour())
									assertEqualE(t, tm.Minute(), tc.actual.([][]time.Time)[i][j].Minute())
									assertEqualE(t, tm.Second(), tc.actual.([][]time.Time)[i][j].Second())
								} else {
									assertTrueE(t, tm.Equal(tc.actual.([][]time.Time)[i][j]))
								}
							}
						}
					} else {
						assertDeepEqualE(t, tc.actual, tc.expected)
					}
				})
			}
		})
	})
}

// TestMapAndMetadataAsString checks that, without structured types enabled, a
// MAP column scans as its string form with string/MAP metadata.
func TestMapAndMetadataAsString(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		forAllStructureTypeFormats(dbt, func(t *testing.T, format string) {
			if format == "NATIVE_ARROW" {
				t.Skip("Native arrow is not supported in maps without schema")
			}
			rows := dbt.mustQuery("SELECT {'a': 'b', 'c': 'd'}::MAP(VARCHAR, VARCHAR) AS STRUCTURED_TYPE")
			defer rows.Close()
			assertTrueF(t, rows.Next())
			var v string
			err := rows.Scan(&v)
			assertNilF(t, err)
			assertEqualIgnoringWhitespaceE(t, v, `{"a": "b", "c": "d"}`)
			columnTypes, err := rows.ColumnTypes()
			assertNilF(t, err)
			assertEqualE(t, len(columnTypes), 1)
			assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[string]())
			assertEqualE(t, columnTypes[0].DatabaseTypeName(), "MAP")
			assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE")
		})
	})
}

// TestMapAndMetadataAsMap table-tests scanning MAP columns with various
// key/value type combinations into typed Go maps.
func TestMapAndMetadataAsMap(t *testing.T) {
	ctx := WithStructuredTypesEnabled(context.Background())
	warsawTz, err := time.LoadLocation("Europe/Warsaw")
	assertNilF(t, err)
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'")
		testcases := []struct {
			name      string
			query     string
			expected1 any
			expected2 any
			actual    any
		}{
			{
				name:      "string string",
				query:     "SELECT {'a': 'x', 'b': 'y'}::MAP(VARCHAR, VARCHAR) as structured_type UNION SELECT {'c': 'z'}::MAP(VARCHAR, VARCHAR) ORDER BY 1 DESC",
				expected1: map[string]string{"a": "x", "b": "y"},
				expected2: map[string]string{"c": "z"},
				actual:    make(map[string]string),
			},
			{
				name:      "integer string",
				query:     "SELECT {'1': 'x', '2': 'y'}::MAP(INTEGER, VARCHAR) as structured_type UNION SELECT {'3': 'z'}::MAP(INTEGER, VARCHAR) ORDER BY 1 DESC",
				expected1: map[int64]string{int64(1): "x", int64(2): "y"},
expected2: map[int64]string{int64(3): "z"}, actual: make(map[int64]string), }, { name: "string bool", query: "SELECT {'a': true, 'b': false}::MAP(VARCHAR, BOOLEAN) as structured_type UNION SELECT {'c': true}::MAP(VARCHAR, BOOLEAN) ORDER BY 1 DESC", expected1: map[string]bool{"a": true, "b": false}, expected2: map[string]bool{"c": true}, actual: make(map[string]bool), }, { name: "integer bool", query: "SELECT {'1': true, '2': false}::MAP(INTEGER, BOOLEAN) as structured_type UNION SELECT {'3': true}::MAP(INTEGER, BOOLEAN) ORDER BY 1 DESC", expected1: map[int64]bool{int64(1): true, int64(2): false}, expected2: map[int64]bool{int64(3): true}, actual: make(map[int64]bool), }, { name: "string integer", query: "SELECT {'a': 11, 'b': 22}::MAP(VARCHAR, INTEGER) as structured_type UNION SELECT {'c': 33}::MAP(VARCHAR, INTEGER) ORDER BY 1 DESC", expected1: map[string]int64{"a": 11, "b": 22}, expected2: map[string]int64{"c": 33}, actual: make(map[string]int64), }, { name: "integer integer", query: "SELECT {'1': 11, '2': 22}::MAP(INTEGER, INTEGER) as structured_type UNION SELECT {'3': 33}::MAP(INTEGER, INTEGER) ORDER BY 1 DESC", expected1: map[int64]int64{int64(1): int64(11), int64(2): int64(22)}, expected2: map[int64]int64{int64(3): int64(33)}, actual: make(map[int64]int64), }, { name: "string double", query: "SELECT {'a': 11.1, 'b': 22.2}::MAP(VARCHAR, DOUBLE) as structured_type UNION SELECT {'c': 33.3}::MAP(VARCHAR, DOUBLE) ORDER BY 1 DESC", expected1: map[string]float64{"a": 11.1, "b": 22.2}, expected2: map[string]float64{"c": 33.3}, actual: make(map[string]float64), }, { name: "integer double", query: "SELECT {'1': 11.1, '2': 22.2}::MAP(INTEGER, DOUBLE) as structured_type UNION SELECT {'3': 33.3}::MAP(INTEGER, DOUBLE) ORDER BY 1 DESC", expected1: map[int64]float64{int64(1): 11.1, int64(2): 22.2}, expected2: map[int64]float64{int64(3): 33.3}, actual: make(map[int64]float64), }, { name: "string number integer", query: "SELECT {'a': 11, 'b': 22}::MAP(VARCHAR, NUMBER(38, 0)) as 
structured_type UNION SELECT {'c': 33}::MAP(VARCHAR, NUMBER(38, 0)) ORDER BY 1 DESC", expected1: map[string]int64{"a": 11, "b": 22}, expected2: map[string]int64{"c": 33}, actual: make(map[string]int64), }, { name: "integer number integer", query: "SELECT {'1': 11, '2': 22}::MAP(INTEGER, NUMBER(38, 0)) as structured_type UNION SELECT {'3': 33}::MAP(INTEGER, NUMBER(38, 0)) ORDER BY 1 DESC", expected1: map[int64]int64{int64(1): int64(11), int64(2): int64(22)}, expected2: map[int64]int64{int64(3): int64(33)}, actual: make(map[int64]int64), }, { name: "string number fraction", query: "SELECT {'a': 11.1, 'b': 22.2}::MAP(VARCHAR, NUMBER(38, 19)) as structured_type UNION SELECT {'c': 33.3}::MAP(VARCHAR, NUMBER(38, 19)) ORDER BY 1 DESC", expected1: map[string]float64{"a": 11.1, "b": 22.2}, expected2: map[string]float64{"c": 33.3}, actual: make(map[string]float64), }, { name: "integer number fraction", query: "SELECT {'1': 11.1, '2': 22.2}::MAP(INTEGER, NUMBER(38, 19)) as structured_type UNION SELECT {'3': 33.3}::MAP(INTEGER, NUMBER(38, 19)) ORDER BY 1 DESC", expected1: map[int64]float64{int64(1): 11.1, int64(2): 22.2}, expected2: map[int64]float64{int64(3): 33.3}, actual: make(map[int64]float64), }, { name: "string binary", query: "SELECT {'a': TO_BINARY('616263', 'HEX'), 'b': TO_BINARY('646566', 'HEX')}::MAP(VARCHAR, BINARY) as structured_type UNION SELECT {'c': TO_BINARY('676869', 'HEX')}::MAP(VARCHAR, BINARY) ORDER BY 1 DESC", expected1: map[string][]byte{"a": {'a', 'b', 'c'}, "b": {'d', 'e', 'f'}}, expected2: map[string][]byte{"c": {'g', 'h', 'i'}}, actual: make(map[string][]byte), }, { name: "integer binary", query: "SELECT {'1': TO_BINARY('616263', 'HEX'), '2': TO_BINARY('646566', 'HEX')}::MAP(INTEGER, BINARY) as structured_type UNION SELECT {'3': TO_BINARY('676869', 'HEX')}::MAP(INTEGER, BINARY) ORDER BY 1 DESC", expected1: map[int64][]byte{1: {'a', 'b', 'c'}, 2: {'d', 'e', 'f'}}, expected2: map[int64][]byte{3: {'g', 'h', 'i'}}, actual: make(map[int64][]byte), }, { 
name: "string date", query: "SELECT {'a': '2024-04-02'::DATE, 'b': '2024-04-03'::DATE}::MAP(VARCHAR, DATE) as structured_type UNION SELECT {'c': '2024-04-04'::DATE}::MAP(VARCHAR, DATE) ORDER BY 1 DESC", expected1: map[string]time.Time{"a": time.Date(2024, time.April, 2, 0, 0, 0, 0, time.UTC), "b": time.Date(2024, time.April, 3, 0, 0, 0, 0, time.UTC)}, expected2: map[string]time.Time{"c": time.Date(2024, time.April, 4, 0, 0, 0, 0, time.UTC)}, actual: make(map[string]time.Time), }, { name: "integer date", query: "SELECT {'1': '2024-04-02'::DATE, '2': '2024-04-03'::DATE}::MAP(INTEGER, DATE) as structured_type UNION SELECT {'3': '2024-04-04'::DATE}::MAP(INTEGER, DATE) ORDER BY 1 DESC", expected1: map[int64]time.Time{1: time.Date(2024, time.April, 2, 0, 0, 0, 0, time.UTC), 2: time.Date(2024, time.April, 3, 0, 0, 0, 0, time.UTC)}, expected2: map[int64]time.Time{3: time.Date(2024, time.April, 4, 0, 0, 0, 0, time.UTC)}, actual: make(map[int64]time.Time), }, { name: "string time", query: "SELECT {'a': '13:03:02'::TIME, 'b': '13:03:03'::TIME}::MAP(VARCHAR, TIME) as structured_type UNION SELECT {'c': '13:03:04'::TIME}::MAP(VARCHAR, TIME) ORDER BY 1 DESC", expected1: map[string]time.Time{"a": time.Date(0, 0, 0, 13, 3, 2, 0, time.UTC), "b": time.Date(0, 0, 0, 13, 3, 3, 0, time.UTC)}, expected2: map[string]time.Time{"c": time.Date(0, 0, 0, 13, 3, 4, 0, time.UTC)}, actual: make(map[string]time.Time), }, { name: "integer time", query: "SELECT {'1': '13:03:02'::TIME, '2': '13:03:03'::TIME}::MAP(VARCHAR, TIME) as structured_type UNION SELECT {'3': '13:03:04'::TIME}::MAP(VARCHAR, TIME) ORDER BY 1 DESC", expected1: map[string]time.Time{"1": time.Date(0, 0, 0, 13, 3, 2, 0, time.UTC), "2": time.Date(0, 0, 0, 13, 3, 3, 0, time.UTC)}, expected2: map[string]time.Time{"3": time.Date(0, 0, 0, 13, 3, 4, 0, time.UTC)}, actual: make(map[int64]time.Time), }, { name: "string timestamp_ntz", query: "SELECT {'a': '2024-01-05 11:22:33'::TIMESTAMP_NTZ, 'b': '2024-01-06 
11:22:33'::TIMESTAMP_NTZ}::MAP(VARCHAR, TIMESTAMP_NTZ) as structured_type UNION SELECT {'c': '2024-01-07 11:22:33'::TIMESTAMP_NTZ}::MAP(VARCHAR, TIMESTAMP_NTZ) ORDER BY 1 DESC", expected1: map[string]time.Time{"a": time.Date(2024, time.January, 5, 11, 22, 33, 0, time.UTC), "b": time.Date(2024, time.January, 6, 11, 22, 33, 0, time.UTC)}, expected2: map[string]time.Time{"c": time.Date(2024, time.January, 7, 11, 22, 33, 0, time.UTC)}, actual: make(map[string]time.Time), }, { name: "integer timestamp_ntz", query: "SELECT {'1': '2024-01-05 11:22:33'::TIMESTAMP_NTZ, '2': '2024-01-06 11:22:33'::TIMESTAMP_NTZ}::MAP(INTEGER, TIMESTAMP_NTZ) as structured_type UNION SELECT {'3': '2024-01-07 11:22:33'::TIMESTAMP_NTZ}::MAP(INTEGER, TIMESTAMP_NTZ) ORDER BY 1 DESC", expected1: map[int64]time.Time{1: time.Date(2024, time.January, 5, 11, 22, 33, 0, time.UTC), 2: time.Date(2024, time.January, 6, 11, 22, 33, 0, time.UTC)}, expected2: map[int64]time.Time{3: time.Date(2024, time.January, 7, 11, 22, 33, 0, time.UTC)}, actual: make(map[int64]time.Time), }, { name: "string timestamp_tz", query: "SELECT {'a': '2024-01-05 11:22:33 +0100'::TIMESTAMP_TZ, 'b': '2024-01-06 11:22:33 +0100'::TIMESTAMP_TZ}::MAP(VARCHAR, TIMESTAMP_TZ) as structured_type UNION SELECT {'c': '2024-01-07 11:22:33 +0100'::TIMESTAMP_TZ}::MAP(VARCHAR, TIMESTAMP_TZ) ORDER BY 1 DESC", expected1: map[string]time.Time{"a": time.Date(2024, time.January, 5, 11, 22, 33, 0, warsawTz), "b": time.Date(2024, time.January, 6, 11, 22, 33, 0, warsawTz)}, expected2: map[string]time.Time{"c": time.Date(2024, time.January, 7, 11, 22, 33, 0, warsawTz)}, actual: make(map[string]time.Time), }, { name: "integer timestamp_tz", query: "SELECT {'1': '2024-01-05 11:22:33 +0100'::TIMESTAMP_TZ, '2': '2024-01-06 11:22:33 +0100'::TIMESTAMP_TZ}::MAP(INTEGER, TIMESTAMP_TZ) as structured_type UNION SELECT {'3': '2024-01-07 11:22:33 +0100'::TIMESTAMP_TZ}::MAP(INTEGER, TIMESTAMP_TZ) ORDER BY 1 DESC", expected1: map[int64]time.Time{1: time.Date(2024, 
time.January, 5, 11, 22, 33, 0, time.UTC), 2: time.Date(2024, time.January, 6, 11, 22, 33, 0, time.UTC)}, expected2: map[int64]time.Time{3: time.Date(2024, time.January, 7, 11, 22, 33, 0, time.UTC)}, actual: make(map[int64]time.Time), }, { name: "string timestamp_ltz", query: "SELECT {'a': '2024-01-05 11:22:33'::TIMESTAMP_LTZ, 'b': '2024-01-06 11:22:33'::TIMESTAMP_LTZ}::MAP(VARCHAR, TIMESTAMP_LTZ) as structured_type UNION SELECT {'c': '2024-01-07 11:22:33'::TIMESTAMP_LTZ}::MAP(VARCHAR, TIMESTAMP_LTZ) ORDER BY 1 DESC", expected1: map[string]time.Time{"a": time.Date(2024, time.January, 5, 11, 22, 33, 0, warsawTz), "b": time.Date(2024, time.January, 6, 11, 22, 33, 0, warsawTz)}, expected2: map[string]time.Time{"c": time.Date(2024, time.January, 7, 11, 22, 33, 0, warsawTz)}, actual: make(map[string]time.Time), }, { name: "integer timestamp_ltz", query: "SELECT {'1': '2024-01-05 11:22:33'::TIMESTAMP_LTZ, '2': '2024-01-06 11:22:33'::TIMESTAMP_LTZ}::MAP(INTEGER, TIMESTAMP_LTZ) as structured_type UNION SELECT {'3': '2024-01-07 11:22:33'::TIMESTAMP_LTZ}::MAP(INTEGER, TIMESTAMP_LTZ) ORDER BY 1 DESC", expected1: map[int64]time.Time{1: time.Date(2024, time.January, 5, 11, 22, 33, 0, time.UTC), 2: time.Date(2024, time.January, 6, 11, 22, 33, 0, time.UTC)}, expected2: map[int64]time.Time{3: time.Date(2024, time.January, 7, 11, 22, 33, 0, time.UTC)}, actual: make(map[int64]time.Time), }, } forAllStructureTypeFormats(dbt, func(t *testing.T, format string) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { rows := dbt.mustQueryContextT(ctx, t, tc.query) defer rows.Close() checkRow := func(expected any) { rows.Next() err := rows.Scan(&tc.actual) assertNilF(t, err) if _, ok := expected.(map[string]time.Time); ok { assertEqualE(t, len(tc.actual.(map[string]time.Time)), len(expected.(map[string]time.Time))) for k, v := range expected.(map[string]time.Time) { if strings.Contains(tc.name, "time") { assertEqualE(t, v.Hour(), tc.actual.(map[string]time.Time)[k].Hour()) 
assertEqualE(t, v.Minute(), tc.actual.(map[string]time.Time)[k].Minute()) assertEqualE(t, v.Second(), tc.actual.(map[string]time.Time)[k].Second()) } else { assertTrueE(t, v.UTC().Equal(tc.actual.(map[string]time.Time)[k].UTC())) } } } else if _, ok := expected.(map[int64]time.Time); ok { assertEqualE(t, len(tc.actual.(map[int64]time.Time)), len(expected.(map[int64]time.Time))) for k, v := range expected.(map[int64]time.Time) { if strings.Contains(tc.name, "time") { /* fixed: this branch was empty, so TIME/TIMESTAMP values behind integer map keys were never verified (only the map length was); mirror the map[string]time.Time branch above and compare the wall-clock components, which are timezone-stable for the TIME/TIMESTAMP fixtures used here */ assertEqualE(t, v.Hour(), tc.actual.(map[int64]time.Time)[k].Hour()) assertEqualE(t, v.Minute(), tc.actual.(map[int64]time.Time)[k].Minute()) assertEqualE(t, v.Second(), tc.actual.(map[int64]time.Time)[k].Second()) } else { assertTrueE(t, v.UTC().Equal(tc.actual.(map[int64]time.Time)[k].UTC())) } } } else { assertDeepEqualE(t, tc.actual, expected) } } checkRow(tc.expected1) checkRow(tc.expected2) columnTypes, err := rows.ColumnTypes() assertNilF(t, err) assertEqualE(t, len(columnTypes), 1) assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeOf(tc.expected1)) assertEqualE(t, columnTypes[0].DatabaseTypeName(), "MAP") assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE") }) } }) }) } func TestMapOfObjects(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { forAllStructureTypeFormats(dbt, func(t *testing.T, format string) { rows := dbt.mustQueryContextT(ctx, t, "SELECT {'x': {'s': 'abc', 'i': 1}, 'y': {'s': 'def', 'i': 2}}::MAP(VARCHAR, OBJECT(s VARCHAR, i INTEGER))") defer rows.Close() var res map[string]*simpleObject rows.Next() err := rows.Scan(ScanMapOfScanners(&res)) assertNilF(t, err) assertDeepEqualE(t, res, map[string]*simpleObject{"x": {s: "abc", i: 1}, "y": {s: "def", i: 2}}) }) }) } func TestMapOfArrays(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) warsawTz, err := time.LoadLocation("Europe/Warsaw") assertNilF(t, err) testcases := []struct { name string query string actual any expected any }{ { name: "string", query: "SELECT {'x': ARRAY_CONSTRUCT('ab', 'cd'), 'y': ARRAY_CONSTRUCT('ef')}::MAP(VARCHAR, ARRAY(VARCHAR))", actual: make(map[string][]string), expected: map[string][]string{"x": {"ab", "cd"}, "y": {"ef"}}, }, { 
name: "fixed - scale == 0", query: "SELECT {'x': ARRAY_CONSTRUCT(1, 2), 'y': ARRAY_CONSTRUCT(3, 4)}::MAP(VARCHAR, ARRAY(INTEGER))", actual: make(map[string][]int64), expected: map[string][]int64{"x": {1, 2}, "y": {3, 4}}, }, { name: "fixed - scale != 0", query: "SELECT {'x': ARRAY_CONSTRUCT(1.1, 2.2), 'y': ARRAY_CONSTRUCT(3.3, 4.4)}::MAP(VARCHAR, ARRAY(NUMBER(38, 19)))", actual: make(map[string][]float64), expected: map[string][]float64{"x": {1.1, 2.2}, "y": {3.3, 4.4}}, }, { name: "real", query: "SELECT {'x': ARRAY_CONSTRUCT(1.1, 2.2), 'y': ARRAY_CONSTRUCT(3.3, 4.4)}::MAP(VARCHAR, ARRAY(DOUBLE))", actual: make(map[string][]float64), expected: map[string][]float64{"x": {1.1, 2.2}, "y": {3.3, 4.4}}, }, { name: "binary", query: "SELECT {'x': ARRAY_CONSTRUCT(TO_BINARY('6162')), 'y': ARRAY_CONSTRUCT(TO_BINARY('6364'), TO_BINARY('6566'))}::MAP(VARCHAR, ARRAY(BINARY))", actual: make(map[string][][]byte), expected: map[string][][]byte{"x": {[]byte{'a', 'b'}}, "y": {[]byte{'c', 'd'}, []byte{'e', 'f'}}}, }, { name: "boolean", query: "SELECT {'x': ARRAY_CONSTRUCT(true, false), 'y': ARRAY_CONSTRUCT(false, true)}::MAP(VARCHAR, ARRAY(BOOLEAN))", actual: make(map[string][]bool), expected: map[string][]bool{"x": {true, false}, "y": {false, true}}, }, { name: "date", query: "SELECT {'a': ARRAY_CONSTRUCT('2024-04-02'::DATE, '2024-04-03'::DATE)}::MAP(VARCHAR, ARRAY(DATE))", expected: map[string][]time.Time{"a": {time.Date(2024, time.April, 2, 0, 0, 0, 0, time.UTC), time.Date(2024, time.April, 3, 0, 0, 0, 0, time.UTC)}}, actual: make(map[string]time.Time), }, { name: "time", query: "SELECT {'a': ARRAY_CONSTRUCT('13:03:02'::TIME, '13:03:03'::TIME)}::MAP(VARCHAR, ARRAY(TIME))", expected: map[string][]time.Time{"a": {time.Date(0, 0, 0, 13, 3, 2, 0, time.UTC), time.Date(0, 0, 0, 13, 3, 3, 0, time.UTC)}}, actual: make(map[string]time.Time), }, { name: "timestamp_ntz", query: "SELECT {'a': ARRAY_CONSTRUCT('2024-01-05 11:22:33'::TIMESTAMP_NTZ, '2024-01-06 
11:22:33'::TIMESTAMP_NTZ)}::MAP(VARCHAR, ARRAY(TIMESTAMP_NTZ))", expected: map[string][]time.Time{"a": {time.Date(2024, time.January, 5, 11, 22, 33, 0, time.UTC), time.Date(2024, time.January, 6, 11, 22, 33, 0, time.UTC)}}, actual: make(map[string]time.Time), }, { name: "string timestamp_tz", query: "SELECT {'a': ARRAY_CONSTRUCT('2024-01-05 11:22:33 +0100'::TIMESTAMP_TZ, '2024-01-06 11:22:33 +0100'::TIMESTAMP_TZ)}::MAP(VARCHAR, ARRAY(TIMESTAMP_TZ))", expected: map[string][]time.Time{"a": {time.Date(2024, time.January, 5, 11, 22, 33, 0, warsawTz), time.Date(2024, time.January, 6, 11, 22, 33, 0, warsawTz)}}, actual: make(map[string]time.Time), }, { name: "string timestamp_ltz", query: "SELECT {'a': ARRAY_CONSTRUCT('2024-01-05 11:22:33'::TIMESTAMP_LTZ, '2024-01-06 11:22:33'::TIMESTAMP_LTZ)}::MAP(VARCHAR, ARRAY(TIMESTAMP_LTZ))", expected: map[string][]time.Time{"a": {time.Date(2024, time.January, 5, 11, 22, 33, 0, warsawTz), time.Date(2024, time.January, 6, 11, 22, 33, 0, warsawTz)}}, actual: make(map[string]time.Time), }, } runDBTest(t, func(dbt *DBTest) { dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") forAllStructureTypeFormats(dbt, func(t *testing.T, format string) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { rows := dbt.mustQueryContextT(ctx, t, tc.query) defer rows.Close() rows.Next() err := rows.Scan(&tc.actual) assertNilF(t, err) if expected, ok := tc.expected.(map[string][]time.Time); ok { for k, v := range expected { for i, expectedTime := range v { if tc.name == "time" { assertEqualE(t, expectedTime.Hour(), tc.actual.(map[string][]time.Time)[k][i].Hour()) assertEqualE(t, expectedTime.Minute(), tc.actual.(map[string][]time.Time)[k][i].Minute()) assertEqualE(t, expectedTime.Second(), tc.actual.(map[string][]time.Time)[k][i].Second()) } else { assertTrueE(t, expectedTime.Equal(tc.actual.(map[string][]time.Time)[k][i])) } } } } else { assertDeepEqualE(t, tc.actual, tc.expected) } }) } }) }) } func TestNullAndEmptyMaps(t 
*testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { forAllStructureTypeFormats(dbt, func(t *testing.T, format string) { rows := dbt.mustQueryContextT(ctx, t, "SELECT {'a': 1}::MAP(VARCHAR, INTEGER) UNION SELECT NULL UNION SELECT {}::MAP(VARCHAR, INTEGER) UNION SELECT {'d': 4}::MAP(VARCHAR, INTEGER) ORDER BY 1") defer rows.Close() checkRow := func(rows *RowsExtended, expected *map[string]int64) { rows.Next() var res *map[string]int64 err := rows.Scan(&res) assertNilF(t, err) assertDeepEqualE(t, res, expected) } checkRow(rows, &map[string]int64{}) checkRow(rows, &map[string]int64{"d": 4}) checkRow(rows, &map[string]int64{"a": 1}) checkRow(rows, nil) }) }) } func TestMapWithNullValues(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) warsawTz, err := time.LoadLocation("Europe/Warsaw") assertNilF(t, err) testcases := []struct { name string query string actual any expected any }{ { name: "string", query: "SELECT object_construct_keep_null('x', 'abc', 'y', null)::MAP(VARCHAR, VARCHAR)", actual: make(map[string]sql.NullString), expected: map[string]sql.NullString{"x": {Valid: true, String: "abc"}, "y": {Valid: false}}, }, { name: "bool", query: "SELECT object_construct_keep_null('x', true, 'y', null)::MAP(VARCHAR, BOOLEAN)", actual: make(map[string]sql.NullBool), expected: map[string]sql.NullBool{"x": {Valid: true, Bool: true}, "y": {Valid: false}}, }, { name: "fixed - scale == 0", query: "SELECT object_construct_keep_null('x', 1, 'y', null)::MAP(VARCHAR, BIGINT)", actual: make(map[string]sql.NullInt64), expected: map[string]sql.NullInt64{"x": {Valid: true, Int64: 1}, "y": {Valid: false}}, }, { name: "fixed - scale != 0", query: "SELECT object_construct_keep_null('x', 1.1, 'y', null)::MAP(VARCHAR, NUMBER(38, 19))", actual: make(map[string]sql.NullFloat64), expected: map[string]sql.NullFloat64{"x": {Valid: true, Float64: 1.1}, "y": {Valid: false}}, }, { name: "real", query: "SELECT 
object_construct_keep_null('x', 1.1, 'y', null)::MAP(VARCHAR, DOUBLE)", actual: make(map[string]sql.NullFloat64), expected: map[string]sql.NullFloat64{"x": {Valid: true, Float64: 1.1}, "y": {Valid: false}}, }, { name: "binary", query: "SELECT object_construct_keep_null('x', TO_BINARY('616263'), 'y', null)::MAP(VARCHAR, BINARY)", actual: make(map[string][]byte), expected: map[string][]byte{"x": {'a', 'b', 'c'}, "y": nil}, }, { name: "date", query: "SELECT object_construct_keep_null('x', '2024-04-05'::DATE, 'y', null)::MAP(VARCHAR, DATE)", actual: make(map[string]sql.NullTime), expected: map[string]sql.NullTime{"x": {Valid: true, Time: time.Date(2024, time.April, 5, 0, 0, 0, 0, time.UTC)}, "y": {Valid: false}}, }, { name: "time", query: "SELECT object_construct_keep_null('x', '13:14:15'::TIME, 'y', null)::MAP(VARCHAR, TIME)", actual: make(map[string]sql.NullTime), expected: map[string]sql.NullTime{"x": {Valid: true, Time: time.Date(1, 0, 0, 13, 14, 15, 0, time.UTC)}, "y": {Valid: false}}, }, { name: "timestamp_tz", query: "SELECT object_construct_keep_null('x', '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'y', null)::MAP(VARCHAR, TIMESTAMP_TZ)", actual: make(map[string]sql.NullTime), expected: map[string]sql.NullTime{"x": {Valid: true, Time: time.Date(2022, 8, 31, 13, 43, 22, 0, warsawTz)}, "y": {Valid: false}}, }, { name: "timestamp_ntz", query: "SELECT object_construct_keep_null('x', '2022-08-31 13:43:22'::TIMESTAMP_NTZ, 'y', null)::MAP(VARCHAR, TIMESTAMP_NTZ)", actual: make(map[string]sql.NullTime), expected: map[string]sql.NullTime{"x": {Valid: true, Time: time.Date(2022, 8, 31, 13, 43, 22, 0, time.UTC)}, "y": {Valid: false}}, }, { name: "timestamp_ltz", query: "SELECT object_construct_keep_null('x', '2022-08-31 13:43:22'::TIMESTAMP_LTZ, 'y', null)::MAP(VARCHAR, TIMESTAMP_LTZ)", actual: make(map[string]sql.NullTime), expected: map[string]sql.NullTime{"x": {Valid: true, Time: time.Date(2022, 8, 31, 13, 43, 22, 0, warsawTz)}, "y": {Valid: false}}, }, } runDBTest(t, 
func(dbt *DBTest) { dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") forAllStructureTypeFormats(dbt, func(t *testing.T, format string) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { rows := dbt.mustQueryContextT(WithEmbeddedValuesNullable(ctx), t, tc.query) defer rows.Close() rows.Next() err = rows.Scan(&tc.actual) assertNilF(t, err) switch tc.name { case "time": for i, nt := range tc.actual.(map[string]sql.NullTime) { assertEqualE(t, nt.Valid, tc.expected.(map[string]sql.NullTime)[i].Valid) assertEqualE(t, nt.Time.Hour(), tc.expected.(map[string]sql.NullTime)[i].Time.Hour()) assertEqualE(t, nt.Time.Minute(), tc.expected.(map[string]sql.NullTime)[i].Time.Minute()) assertEqualE(t, nt.Time.Second(), tc.expected.(map[string]sql.NullTime)[i].Time.Second()) } case "timestamp_tz", "timestamp_ltz", "timestamp_ntz": for i, nt := range tc.actual.(map[string]sql.NullTime) { assertEqualE(t, nt.Valid, tc.expected.(map[string]sql.NullTime)[i].Valid) assertTrueE(t, nt.Time.Equal(tc.expected.(map[string]sql.NullTime)[i].Time)) } default: assertDeepEqualE(t, tc.actual, tc.expected) } }) } }) }) } func TestArraysWithNullValues(t *testing.T) { warsawTz, err := time.LoadLocation("Europe/Warsaw") assertNilF(t, err) testcases := []struct { name string query string actual any expected any }{ { name: "string", query: "SELECT ARRAY_CONSTRUCT('x', null, 'yz', null)::ARRAY(STRING)", actual: []sql.NullString{}, expected: []sql.NullString{{Valid: true, String: "x"}, {Valid: false}, {Valid: true, String: "yz"}, {Valid: false}}, }, { name: "bool", query: "SELECT ARRAY_CONSTRUCT(true, null, false)::ARRAY(BOOLEAN)", actual: []sql.NullBool{}, expected: []sql.NullBool{{Valid: true, Bool: true}, {Valid: false}, {Valid: true, Bool: false}}, }, { name: "fixed - scale == 0", query: "SELECT ARRAY_CONSTRUCT(null, 2, 3)::ARRAY(BIGINT)", actual: []sql.NullInt64{}, expected: []sql.NullInt64{{Valid: false}, {Valid: true, Int64: 2}, {Valid: true, Int64: 3}}, }, { name: "fixed 
- scale != 0", query: "SELECT ARRAY_CONSTRUCT(1.3, 2.0, null, null)::ARRAY(NUMBER(38, 19))", actual: []sql.NullFloat64{}, expected: []sql.NullFloat64{{Valid: true, Float64: 1.3}, {Valid: true, Float64: 2.0}, {Valid: false}, {Valid: false}}, }, /* renamed from the duplicated "fixed - scale == 0": this case tests fractional NUMBER(38, 19) values, so it follows the "fixed - scale != 0" naming used by the sibling tests and no longer collides with the BIGINT case above (t.Run would otherwise auto-suffix the duplicate) */ { name: "real", query: "SELECT ARRAY_CONSTRUCT(1.9, 0.2, null)::ARRAY(DOUBLE)", actual: []sql.NullFloat64{}, expected: []sql.NullFloat64{{Valid: true, Float64: 1.9}, {Valid: true, Float64: 0.2}, {Valid: false}}, }, { name: "binary", query: "SELECT ARRAY_CONSTRUCT(null, TO_BINARY('616263'))::ARRAY(BINARY)", actual: [][]byte{}, expected: [][]byte{nil, {'a', 'b', 'c'}}, }, { name: "date", query: "SELECT ARRAY_CONSTRUCT('2024-04-05'::DATE, null)::ARRAY(DATE)", actual: []sql.NullTime{}, expected: []sql.NullTime{{Valid: true, Time: time.Date(2024, time.April, 5, 0, 0, 0, 0, time.UTC)}, {Valid: false}}, }, { name: "time", query: "SELECT ARRAY_CONSTRUCT('13:14:15'::TIME, null)::ARRAY(TIME)", actual: []sql.NullTime{}, expected: []sql.NullTime{{Valid: true, Time: time.Date(1, 0, 0, 13, 14, 15, 0, time.UTC)}, {Valid: false}}, }, { name: "timestamp_tz", query: "SELECT ARRAY_CONSTRUCT('2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, null)::ARRAY(TIMESTAMP_TZ)", actual: []sql.NullTime{}, expected: []sql.NullTime{{Valid: true, Time: time.Date(2022, 8, 31, 13, 43, 22, 0, warsawTz)}, {Valid: false}}, }, { name: "timestamp_ntz", query: "SELECT ARRAY_CONSTRUCT('2022-08-31 13:43:22'::TIMESTAMP_NTZ, null)::ARRAY(TIMESTAMP_NTZ)", actual: []sql.NullTime{}, expected: []sql.NullTime{{Valid: true, Time: time.Date(2022, 8, 31, 13, 43, 22, 0, time.UTC)}, {Valid: false}}, }, { name: "timestamp_ltz", query: "SELECT ARRAY_CONSTRUCT('2022-08-31 13:43:22'::TIMESTAMP_LTZ, null)::ARRAY(TIMESTAMP_LTZ)", actual: []sql.NullTime{}, expected: []sql.NullTime{{Valid: true, Time: time.Date(2022, 8, 31, 13, 43, 22, 0, warsawTz)}, {Valid: false}}, }, { name: "array", query: "SELECT ARRAY_CONSTRUCT(ARRAY_CONSTRUCT(true, null), null, ARRAY_CONSTRUCT(null, false, 
true))::ARRAY(ARRAY(BOOLEAN))", actual: [][]sql.NullBool{}, expected: [][]sql.NullBool{{{Valid: true, Bool: true}, {Valid: false}}, nil, {{Valid: false}, {Valid: true, Bool: false}, {Valid: true, Bool: true}}}, }, } runDBTest(t, func(dbt *DBTest) { dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") dbt.forceNativeArrow() dbt.enableStructuredTypes() for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { rows := dbt.mustQueryContext(WithStructuredTypesEnabled(WithEmbeddedValuesNullable(context.Background())), tc.query) defer rows.Close() rows.Next() err := rows.Scan(&tc.actual) assertNilF(t, err) switch tc.name { case "time": for i, nt := range tc.actual.([]sql.NullTime) { assertEqualE(t, nt.Valid, tc.expected.([]sql.NullTime)[i].Valid) assertEqualE(t, nt.Time.Hour(), tc.expected.([]sql.NullTime)[i].Time.Hour()) assertEqualE(t, nt.Time.Minute(), tc.expected.([]sql.NullTime)[i].Time.Minute()) assertEqualE(t, nt.Time.Second(), tc.expected.([]sql.NullTime)[i].Time.Second()) } case "timestamp_tz", "timestamp_ltz", "timestamp_ntz": for i, nt := range tc.actual.([]sql.NullTime) { assertEqualE(t, nt.Valid, tc.expected.([]sql.NullTime)[i].Valid) assertTrueE(t, nt.Time.Equal(tc.expected.([]sql.NullTime)[i].Time)) } default: assertDeepEqualE(t, tc.actual, tc.expected) } }) } }) } func TestArraysWithNullValuesHigherPrecision(t *testing.T) { testcases := []struct { name string query string actual any expected any }{ { name: "fixed - scale == 0", query: "SELECT ARRAY_CONSTRUCT(null, 2)::ARRAY(BIGINT)", actual: []*big.Int{}, }, } runDBTest(t, func(dbt *DBTest) { dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") dbt.forceNativeArrow() dbt.enableStructuredTypes() for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { ctx := WithHigherPrecision(WithStructuredTypesEnabled(WithEmbeddedValuesNullable(context.Background()))) rows := dbt.mustQueryContext(ctx, tc.query) defer rows.Close() rows.Next() err := rows.Scan(&tc.actual) assertNilF(t, 
err) assertNilF(t, tc.actual.([]*big.Int)[0]) bigInt, _ := new(big.Int).SetString("2", 10) assertEqualE(t, tc.actual.([]*big.Int)[1].Cmp(bigInt), 0) }) } }) } type HigherPrecisionStruct struct { i *big.Int f *big.Float } func (hps *HigherPrecisionStruct) Scan(val any) error { st, ok := val.(StructuredObject) if !ok { return fmt.Errorf("expected StructuredObject, got %T", val) } var err error if hps.i, err = st.GetBigInt("i"); err != nil { return err } if hps.f, err = st.GetBigFloat("f"); err != nil { return err } return nil } func TestWithHigherPrecision(t *testing.T) { ctx := WithHigherPrecision(WithStructuredTypesEnabled(context.Background())) runDBTest(t, func(dbt *DBTest) { forAllStructureTypeFormats(dbt, func(t *testing.T, format string) { if format != "NATIVE_ARROW" { t.Skip("JSON structured type does not support higher precision") } t.Run("object", func(t *testing.T) { rows := dbt.mustQueryContext(ctx, "SELECT {'i': 10000000000000000000000000000000000000::DECIMAL(38, 0), 'f': 1.2345678901234567890123456789012345678::DECIMAL(38, 37)}::OBJECT(i DECIMAL(38, 0), f DECIMAL(38, 37)) as structured_type") defer rows.Close() rows.Next() var v HigherPrecisionStruct err := rows.Scan(&v) assertNilF(t, err) bigInt, b := new(big.Int).SetString("10000000000000000000000000000000000000", 10) assertTrueF(t, b) assertEqualE(t, bigInt.Cmp(v.i), 0) bigFloat, b := new(big.Float).SetPrec(v.f.Prec()).SetString("1.2345678901234567890123456789012345678") assertTrueE(t, b) assertEqualE(t, bigFloat.Cmp(v.f), 0) columnTypes, err := rows.ColumnTypes() assertNilF(t, err) assertEqualE(t, len(columnTypes), 1) assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[ObjectType]()) assertEqualE(t, columnTypes[0].DatabaseTypeName(), "OBJECT") assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE") }) t.Run("array of big ints", func(t *testing.T) { rows := dbt.mustQueryContext(ctx, "SELECT ARRAY_CONSTRUCT(10000000000000000000000000000000000000)::ARRAY(DECIMAL(38, 0)) as structured_type") 
defer rows.Close() rows.Next() var v *[]*big.Int err := rows.Scan(&v) assertNilF(t, err) bigInt, b := new(big.Int).SetString("10000000000000000000000000000000000000", 10) assertTrueF(t, b) assertEqualE(t, bigInt.Cmp((*v)[0]), 0) columnTypes, err := rows.ColumnTypes() assertNilF(t, err) assertEqualE(t, len(columnTypes), 1) assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[[]*big.Int]()) assertEqualE(t, columnTypes[0].DatabaseTypeName(), "ARRAY") assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE") }) t.Run("array of big floats", func(t *testing.T) { rows := dbt.mustQueryContext(ctx, "SELECT ARRAY_CONSTRUCT(1.2345678901234567890123456789012345678)::ARRAY(DECIMAL(38, 37)) as structured_type") defer rows.Close() rows.Next() var v *[]*big.Float err := rows.Scan(&v) assertNilF(t, err) bigFloat, b := new(big.Float).SetPrec((*v)[0].Prec()).SetString("1.2345678901234567890123456789012345678") assertTrueE(t, b) assertEqualE(t, bigFloat.Cmp((*v)[0]), 0) columnTypes, err := rows.ColumnTypes() assertNilF(t, err) assertEqualE(t, len(columnTypes), 1) assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[[]*big.Float]()) assertEqualE(t, columnTypes[0].DatabaseTypeName(), "ARRAY") assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE") }) t.Run("map of string to big ints", func(t *testing.T) { rows := dbt.mustQueryContext(ctx, "SELECT object_construct_keep_null('x', 10000000000000000000000000000000000000, 'y', null)::MAP(VARCHAR, DECIMAL(38, 0)) as structured_type") defer rows.Close() rows.Next() var v *map[string]*big.Int err := rows.Scan(&v) assertNilF(t, err) bigInt, b := new(big.Int).SetString("10000000000000000000000000000000000000", 10) assertTrueF(t, b) assertEqualE(t, bigInt.Cmp((*v)["x"]), 0) assertEqualE(t, (*v)["y"], (*big.Int)(nil)) columnTypes, err := rows.ColumnTypes() assertNilF(t, err) assertEqualE(t, len(columnTypes), 1) assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[map[string]*big.Int]()) assertEqualE(t, 
columnTypes[0].DatabaseTypeName(), "MAP") assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE") }) t.Run("map of string to big floats", func(t *testing.T) { rows := dbt.mustQueryContext(ctx, "SELECT {'x': 1.2345678901234567890123456789012345678, 'y': null}::MAP(VARCHAR, DECIMAL(38, 37)) as structured_type") defer rows.Close() rows.Next() var v *map[string]*big.Float err := rows.Scan(&v) assertNilF(t, err) bigFloat, b := new(big.Float).SetPrec((*v)["x"].Prec()).SetString("1.2345678901234567890123456789012345678") assertTrueE(t, b) assertEqualE(t, bigFloat.Cmp((*v)["x"]), 0) assertEqualE(t, (*v)["y"], (*big.Float)(nil)) columnTypes, err := rows.ColumnTypes() assertNilF(t, err) assertEqualE(t, len(columnTypes), 1) assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[map[string]*big.Float]()) assertEqualE(t, columnTypes[0].DatabaseTypeName(), "MAP") assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE") }) t.Run("map of int64 to big ints", func(t *testing.T) { rows := dbt.mustQueryContext(ctx, "SELECT {'1': 10000000000000000000000000000000000000}::MAP(INTEGER, DECIMAL(38, 0)) as structured_type") defer rows.Close() rows.Next() var v *map[int64]*big.Int err := rows.Scan(&v) assertNilF(t, err) bigInt, b := new(big.Int).SetString("10000000000000000000000000000000000000", 10) assertTrueF(t, b) assertEqualE(t, bigInt.Cmp((*v)[1]), 0) columnTypes, err := rows.ColumnTypes() assertNilF(t, err) assertEqualE(t, len(columnTypes), 1) assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[map[int64]*big.Int]()) assertEqualE(t, columnTypes[0].DatabaseTypeName(), "MAP") assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE") }) t.Run("map of int64 to big floats", func(t *testing.T) { rows := dbt.mustQueryContext(ctx, "SELECT {'1': 1.2345678901234567890123456789012345678}::MAP(INTEGER, DECIMAL(38, 37)) as structured_type") defer rows.Close() rows.Next() var v *map[int64]*big.Float err := rows.Scan(&v) assertNilF(t, err) bigFloat, b := 
new(big.Float).SetPrec((*v)[1].Prec()).SetString("1.2345678901234567890123456789012345678") assertTrueE(t, b) assertEqualE(t, bigFloat.Cmp((*v)[1]), 0) columnTypes, err := rows.ColumnTypes() assertNilF(t, err) assertEqualE(t, len(columnTypes), 1) assertEqualE(t, columnTypes[0].ScanType(), reflect.TypeFor[map[int64]*big.Float]()) assertEqualE(t, columnTypes[0].DatabaseTypeName(), "MAP") assertEqualE(t, columnTypes[0].Name(), "STRUCTURED_TYPE") }) }) }) } func forAllStructureTypeFormats(dbt *DBTest, f func(t *testing.T, format string)) { for _, tc := range []struct { name string forceFormat func(test *DBTest) }{ { name: "JSON", forceFormat: func(test *DBTest) { dbt.forceJSON() }, }, { name: "ARROW", forceFormat: func(test *DBTest) { dbt.forceArrow() }, }, { name: "NATIVE_ARROW", forceFormat: func(test *DBTest) { dbt.forceNativeArrow() }, }, } { dbt.Run(tc.name, func(t *testing.T) { tc.forceFormat(dbt) dbt.enableStructuredTypes() f(t, tc.name) }) } } func skipForStringingNativeArrow(t *testing.T, format string) { if format == "NATIVE_ARROW" { t.Skip("returning native arrow structured types as string is currently not supported") } } ================================================ FILE: structured_type_write_test.go ================================================ package gosnowflake import ( "context" "database/sql" "fmt" "reflect" "testing" "time" ) func TestBindingVariant(t *testing.T) { t.Skip("binding variant is currently not supported") runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE TABLE test_variant_binding (var VARIANT)") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_variant_binding") }() dbt.mustExec("INSERT INTO test_variant_binding SELECT (?)", DataTypeVariant, nil) dbt.mustExec("INSERT INTO test_variant_binding SELECT (?)", DataTypeVariant, sql.NullString{Valid: false}) dbt.mustExec("INSERT INTO test_variant_binding SELECT (?)", DataTypeVariant, "{'s': 'some string'}") dbt.mustExec("INSERT INTO 
test_variant_binding SELECT (?)", DataTypeVariant, sql.NullString{Valid: true, String: "{'s': 'some string2'}"}) rows := dbt.mustQuery("SELECT * FROM test_variant_binding") defer rows.Close() var res sql.NullString assertTrueF(t, rows.Next()) err := rows.Scan(&res) assertNilF(t, err) assertFalseF(t, res.Valid) assertTrueF(t, rows.Next()) err = rows.Scan(&res) assertNilF(t, err) assertFalseF(t, res.Valid) assertTrueF(t, rows.Next()) err = rows.Scan(&res) assertNilF(t, err) assertTrueE(t, res.Valid) assertEqualIgnoringWhitespaceE(t, res.String, `{"s": "some string"}`) assertTrueF(t, rows.Next()) err = rows.Scan(&res) assertNilF(t, err) assertTrueE(t, res.Valid) assertEqualIgnoringWhitespaceE(t, res.String, `{"s": "some string2"}`) }) } func TestBindingObjectWithoutSchema(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE TABLE test_object_binding (obj OBJECT)") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_object_binding") }() dbt.mustExec("INSERT INTO test_object_binding SELECT (?)", DataTypeObject, nil) dbt.mustExec("INSERT INTO test_object_binding SELECT (?)", DataTypeObject, sql.NullString{Valid: false}) dbt.mustExec("INSERT INTO test_object_binding SELECT (?)", DataTypeObject, "{'s': 'some string'}") dbt.mustExec("INSERT INTO test_object_binding SELECT (?)", DataTypeObject, sql.NullString{Valid: true, String: "{'s': 'some string2'}"}) rows := dbt.mustQuery("SELECT * FROM test_object_binding") defer rows.Close() var res sql.NullString assertTrueF(t, rows.Next()) err := rows.Scan(&res) assertNilF(t, err) assertFalseF(t, res.Valid) assertTrueF(t, rows.Next()) err = rows.Scan(&res) assertNilF(t, err) assertFalseF(t, res.Valid) assertTrueF(t, rows.Next()) err = rows.Scan(&res) assertNilF(t, err) assertTrueE(t, res.Valid) assertEqualIgnoringWhitespaceE(t, res.String, `{"s": "some string"}`) assertTrueF(t, rows.Next()) err = rows.Scan(&res) assertNilF(t, err) assertTrueE(t, res.Valid) 
assertEqualIgnoringWhitespaceE(t, res.String, `{"s": "some string2"}`) }) } func TestBindingArrayWithoutSchema(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE TABLE test_array_binding (arr ARRAY)") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_array_binding") }() dbt.mustExec("INSERT INTO test_array_binding SELECT (?)", DataTypeArray, nil) dbt.mustExec("INSERT INTO test_array_binding SELECT (?)", DataTypeArray, sql.NullString{Valid: false}) dbt.mustExec("INSERT INTO test_array_binding SELECT (?)", DataTypeArray, "[1, 2, 3]") dbt.mustExec("INSERT INTO test_array_binding SELECT (?)", DataTypeArray, sql.NullString{Valid: true, String: "[1, 2, 3]"}) dbt.mustExec("INSERT INTO test_array_binding SELECT (?)", DataTypeArray, []int{1, 2, 3}) rows := dbt.mustQuery("SELECT * FROM test_array_binding") defer rows.Close() var res sql.NullString assertTrueF(t, rows.Next()) err := rows.Scan(&res) assertNilF(t, err) assertFalseF(t, res.Valid) assertTrueF(t, rows.Next()) err = rows.Scan(&res) assertNilF(t, err) assertFalseF(t, res.Valid) assertTrueF(t, rows.Next()) err = rows.Scan(&res) assertNilF(t, err) assertTrueE(t, res.Valid) assertEqualIgnoringWhitespaceE(t, res.String, `[1, 2, 3]`) assertTrueF(t, rows.Next()) err = rows.Scan(&res) assertNilF(t, err) assertTrueE(t, res.Valid) assertEqualIgnoringWhitespaceE(t, res.String, `[1, 2, 3]`) }) } func TestBindingObjectWithSchema(t *testing.T) { warsawTz, err := time.LoadLocation("Europe/Warsaw") ctx := WithStructuredTypesEnabled(context.Background()) assertNilF(t, err) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 9), bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, ntz TIMESTAMPNTZ, tz TIMESTAMPTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr 
ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR))") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_object_binding") }() dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") dbt.mustExec("ALTER SESSION SET TIMESTAMP_OUTPUT_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF9 TZHTZM'") o := objectWithAllTypes{ s: "some string", b: 1, i16: 2, i32: 3, i64: 4, f32: 1.1, f64: 2.2, nfraction: 3.3, bo: true, bi: []byte{'a', 'b', 'c'}, date: time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC), time: time.Date(1, 1, 1, 11, 22, 33, 0, time.UTC), ltz: time.Date(2025, time.May, 24, 11, 22, 33, 44, warsawTz), ntz: time.Date(2026, time.May, 24, 11, 22, 33, 0, time.UTC), tz: time.Date(2027, time.May, 24, 11, 22, 33, 44, warsawTz), so: &simpleObject{s: "another string", i: 123}, sArr: []string{"a", "b"}, f64Arr: []float64{1.1, 2.2}, someMap: map[string]bool{"a": true, "b": false}, uuid: newTestUUID(), } dbt.mustExecT(t, "INSERT INTO test_object_binding SELECT (?)", o) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_object_binding WHERE obj = ?", o) defer rows.Close() assertTrueE(t, rows.Next()) var res objectWithAllTypes err := rows.Scan(&res) assertNilF(t, err) assertEqualE(t, res.s, o.s) assertEqualE(t, res.b, o.b) assertEqualE(t, res.i16, o.i16) assertEqualE(t, res.i32, o.i32) assertEqualE(t, res.i64, o.i64) assertEqualE(t, res.f32, o.f32) assertEqualE(t, res.f64, o.f64) assertEqualE(t, res.nfraction, o.nfraction) assertEqualE(t, res.bo, o.bo) assertDeepEqualE(t, res.bi, o.bi) assertTrueE(t, res.date.Equal(o.date)) assertEqualE(t, res.time.Hour(), o.time.Hour()) assertEqualE(t, res.time.Minute(), o.time.Minute()) assertEqualE(t, res.time.Second(), o.time.Second()) assertTrueE(t, res.ltz.Equal(o.ltz)) assertTrueE(t, res.tz.Equal(o.tz)) assertTrueE(t, res.ntz.Equal(o.ntz)) assertDeepEqualE(t, res.so, o.so) assertDeepEqualE(t, res.sArr, o.sArr) assertDeepEqualE(t, res.f64Arr, o.f64Arr) assertDeepEqualE(t, res.someMap, o.someMap) assertEqualE(t, 
res.uuid.String(), o.uuid.String()) }) } func TestBindingObjectWithNullableFieldsWithSchema(t *testing.T) { warsawTz, err := time.LoadLocation("Europe/Warsaw") assertNilF(t, err) ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, ntz TIMESTAMPNTZ, tz TIMESTAMPTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR))") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_object_binding") }() dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") dbt.mustExec("ALTER SESSION SET TIMESTAMP_OUTPUT_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF9 TZHTZM'") t.Run("not null", func(t *testing.T) { o := &objectWithAllTypesNullable{ s: sql.NullString{String: "some string", Valid: true}, b: sql.NullByte{Byte: 1, Valid: true}, i16: sql.NullInt16{Int16: 2, Valid: true}, i32: sql.NullInt32{Int32: 3, Valid: true}, i64: sql.NullInt64{Int64: 4, Valid: true}, f64: sql.NullFloat64{Float64: 2.2, Valid: true}, bo: sql.NullBool{Bool: true, Valid: true}, bi: []byte{'a', 'b', 'c'}, date: sql.NullTime{Time: time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC), Valid: true}, time: sql.NullTime{Time: time.Date(1, 1, 1, 11, 22, 33, 0, time.UTC), Valid: true}, ltz: sql.NullTime{Time: time.Date(2025, time.May, 24, 11, 22, 33, 44, warsawTz), Valid: true}, ntz: sql.NullTime{Time: time.Date(2026, time.May, 24, 11, 22, 33, 0, time.UTC), Valid: true}, tz: sql.NullTime{Time: time.Date(2027, time.May, 24, 11, 22, 33, 44, warsawTz), Valid: true}, so: &simpleObject{s: "another string", i: 123}, sArr: []string{"a", "b"}, f64Arr: []float64{1.1, 2.2}, someMap: map[string]bool{"a": true, "b": false}, uuid: newTestUUID(), } dbt.mustExecT(t, "INSERT INTO test_object_binding 
SELECT (?)", o) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_object_binding WHERE obj = ?", o) defer rows.Close() assertTrueE(t, rows.Next()) var res objectWithAllTypesNullable err := rows.Scan(&res) assertNilF(t, err) assertEqualE(t, res.s, o.s) assertEqualE(t, res.b, o.b) assertEqualE(t, res.i16, o.i16) assertEqualE(t, res.i32, o.i32) assertEqualE(t, res.i64, o.i64) assertEqualE(t, res.f64, o.f64) assertEqualE(t, res.bo, o.bo) assertDeepEqualE(t, res.bi, o.bi) assertTrueE(t, res.date.Time.Equal(o.date.Time)) assertEqualE(t, res.time.Time.Hour(), o.time.Time.Hour()) assertEqualE(t, res.time.Time.Minute(), o.time.Time.Minute()) assertEqualE(t, res.time.Time.Second(), o.time.Time.Second()) assertTrueE(t, res.ltz.Time.Equal(o.ltz.Time)) assertTrueE(t, res.tz.Time.Equal(o.tz.Time)) assertTrueE(t, res.ntz.Time.Equal(o.ntz.Time)) assertDeepEqualE(t, res.so, o.so) assertDeepEqualE(t, res.sArr, o.sArr) assertDeepEqualE(t, res.f64Arr, o.f64Arr) assertDeepEqualE(t, res.someMap, o.someMap) assertEqualE(t, res.uuid.String(), o.uuid.String()) }) t.Run("null", func(t *testing.T) { o := &objectWithAllTypesNullable{ s: sql.NullString{}, b: sql.NullByte{}, i16: sql.NullInt16{}, i32: sql.NullInt32{}, i64: sql.NullInt64{}, f64: sql.NullFloat64{}, bo: sql.NullBool{}, bi: nil, date: sql.NullTime{}, time: sql.NullTime{}, ltz: sql.NullTime{}, ntz: sql.NullTime{}, tz: sql.NullTime{}, so: nil, sArr: nil, f64Arr: nil, someMap: nil, } dbt.mustExecT(t, "INSERT INTO test_object_binding SELECT (?)", o) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_object_binding WHERE obj = ?", o) defer rows.Close() assertTrueE(t, rows.Next()) var res objectWithAllTypesNullable err := rows.Scan(&res) assertNilF(t, err) assertEqualE(t, res.s, o.s) assertEqualE(t, res.b, o.b) assertEqualE(t, res.i16, o.i16) assertEqualE(t, res.i32, o.i32) assertEqualE(t, res.i64, o.i64) assertEqualE(t, res.f64, o.f64) assertEqualE(t, res.bo, o.bo) assertDeepEqualE(t, res.bi, o.bi) assertTrueE(t, 
res.date.Time.Equal(o.date.Time)) assertEqualE(t, res.time.Time.Hour(), o.time.Time.Hour()) assertEqualE(t, res.time.Time.Minute(), o.time.Time.Minute()) assertEqualE(t, res.time.Time.Second(), o.time.Time.Second()) assertTrueE(t, res.ltz.Time.Equal(o.ltz.Time)) assertTrueE(t, res.tz.Time.Equal(o.tz.Time)) assertTrueE(t, res.ntz.Time.Equal(o.ntz.Time)) assertDeepEqualE(t, res.so, o.so) assertDeepEqualE(t, res.sArr, o.sArr) assertDeepEqualE(t, res.f64Arr, o.f64Arr) assertDeepEqualE(t, res.someMap, o.someMap) }) }) } func TestBindingObjectWithSchemaSimpleWrite(t *testing.T) { warsawTz, err := time.LoadLocation("Europe/Warsaw") assertNilF(t, err) ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 9), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN)))") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_object_binding") }() dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") dbt.mustExec("ALTER SESSION SET TIMESTAMP_OUTPUT_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF9 TZHTZM'") o := &objectWithAllTypesSimpleScan{ S: "some string", B: 1, I16: 2, I32: 3, I64: 4, F32: 1.1, F64: 2.2, Nfraction: 3.3, Bo: true, Bi: []byte{'a', 'b', 'c'}, Date: time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC), Time: time.Date(1, 1, 1, 11, 22, 33, 0, time.UTC), Ltz: time.Date(2025, time.May, 24, 11, 22, 33, 44, warsawTz), Ntz: time.Date(2026, time.May, 24, 11, 22, 33, 0, time.UTC), Tz: time.Date(2027, time.May, 24, 11, 22, 33, 44, warsawTz), So: &simpleObject{s: "another string", i: 123}, SArr: []string{"a", "b"}, F64Arr: []float64{1.1, 2.2}, SomeMap: map[string]bool{"a": true, "b": false}, 
} dbt.mustExecT(t, "INSERT INTO test_object_binding SELECT (?)", o) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_object_binding WHERE obj = ?", o) defer rows.Close() assertTrueE(t, rows.Next()) var res objectWithAllTypesSimpleScan err := rows.Scan(&res) assertNilF(t, err) assertEqualE(t, res.S, o.S) assertEqualE(t, res.B, o.B) assertEqualE(t, res.I16, o.I16) assertEqualE(t, res.I32, o.I32) assertEqualE(t, res.I64, o.I64) assertEqualE(t, res.F32, o.F32) assertEqualE(t, res.F64, o.F64) assertEqualE(t, res.Nfraction, o.Nfraction) assertEqualE(t, res.Bo, o.Bo) assertDeepEqualE(t, res.Bi, o.Bi) assertTrueE(t, res.Date.Equal(o.Date)) assertEqualE(t, res.Time.Hour(), o.Time.Hour()) assertEqualE(t, res.Time.Minute(), o.Time.Minute()) assertEqualE(t, res.Time.Second(), o.Time.Second()) assertTrueE(t, res.Ltz.Equal(o.Ltz)) assertTrueE(t, res.Tz.Equal(o.Tz)) assertTrueE(t, res.Ntz.Equal(o.Ntz)) assertDeepEqualE(t, res.So, o.So) assertDeepEqualE(t, res.SArr, o.SArr) assertDeepEqualE(t, res.F64Arr, o.F64Arr) assertDeepEqualE(t, res.SomeMap, o.SomeMap) }) } func TestBindingObjectWithNullableFieldsWithSchemaSimpleWrite(t *testing.T) { warsawTz, err := time.LoadLocation("Europe/Warsaw") assertNilF(t, err) ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.forceJSON() dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, tz TIMESTAMPTZ, ntz TIMESTAMPNTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN)))") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_object_binding") }() dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") dbt.mustExec("ALTER SESSION SET TIMESTAMP_OUTPUT_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF9 TZHTZM'") t.Run("not null", func(t *testing.T) { o := 
&objectWithAllTypesNullableSimpleScan{ S: sql.NullString{String: "some string", Valid: true}, B: sql.NullByte{Byte: 1, Valid: true}, I16: sql.NullInt16{Int16: 2, Valid: true}, I32: sql.NullInt32{Int32: 3, Valid: true}, I64: sql.NullInt64{Int64: 4, Valid: true}, F64: sql.NullFloat64{Float64: 2.2, Valid: true}, Bo: sql.NullBool{Bool: true, Valid: true}, Bi: []byte{'a', 'b', 'c'}, Date: sql.NullTime{Time: time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC), Valid: true}, Time: sql.NullTime{Time: time.Date(1, 1, 1, 11, 22, 33, 0, time.UTC), Valid: true}, Ltz: sql.NullTime{Time: time.Date(2025, time.May, 24, 11, 22, 33, 44, warsawTz), Valid: true}, Ntz: sql.NullTime{Time: time.Date(2026, time.May, 24, 11, 22, 33, 0, time.UTC), Valid: true}, Tz: sql.NullTime{Time: time.Date(2027, time.May, 24, 11, 22, 33, 44, warsawTz), Valid: true}, So: &simpleObject{s: "another string", i: 123}, SArr: []string{"a", "b"}, F64Arr: []float64{1.1, 2.2}, SomeMap: map[string]bool{"a": true, "b": false}, } dbt.mustExecT(t, "INSERT INTO test_object_binding SELECT (?)", o) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_object_binding WHERE obj = ?", o) defer rows.Close() assertTrueE(t, rows.Next()) var res objectWithAllTypesNullableSimpleScan err := rows.Scan(&res) assertNilF(t, err) assertEqualE(t, res.S, o.S) assertEqualE(t, res.B, o.B) assertEqualE(t, res.I16, o.I16) assertEqualE(t, res.I32, o.I32) assertEqualE(t, res.I64, o.I64) assertEqualE(t, res.F64, o.F64) assertEqualE(t, res.Bo, o.Bo) assertDeepEqualE(t, res.Bi, o.Bi) assertTrueE(t, res.Date.Time.Equal(o.Date.Time)) assertEqualE(t, res.Time.Time.Hour(), o.Time.Time.Hour()) assertEqualE(t, res.Time.Time.Minute(), o.Time.Time.Minute()) assertEqualE(t, res.Time.Time.Second(), o.Time.Time.Second()) assertTrueE(t, res.Ltz.Time.Equal(o.Ltz.Time)) assertTrueE(t, res.Tz.Time.Equal(o.Tz.Time)) assertTrueE(t, res.Ntz.Time.Equal(o.Ntz.Time)) assertDeepEqualE(t, res.So, o.So) assertDeepEqualE(t, res.SArr, o.SArr) assertDeepEqualE(t, 
res.F64Arr, o.F64Arr) assertDeepEqualE(t, res.SomeMap, o.SomeMap) }) t.Run("null", func(t *testing.T) { o := &objectWithAllTypesNullableSimpleScan{ S: sql.NullString{}, B: sql.NullByte{}, I16: sql.NullInt16{}, I32: sql.NullInt32{}, I64: sql.NullInt64{}, F64: sql.NullFloat64{}, Bo: sql.NullBool{}, Bi: nil, Date: sql.NullTime{}, Time: sql.NullTime{}, Ltz: sql.NullTime{}, Ntz: sql.NullTime{}, Tz: sql.NullTime{}, So: nil, SArr: nil, F64Arr: nil, SomeMap: nil, } dbt.mustExecT(t, "INSERT INTO test_object_binding SELECT (?)", o) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_object_binding WHERE obj = ?", o) defer rows.Close() assertTrueE(t, rows.Next()) var res objectWithAllTypesNullableSimpleScan err := rows.Scan(&res) assertNilF(t, err) assertEqualE(t, res.S, o.S) assertEqualE(t, res.B, o.B) assertEqualE(t, res.I16, o.I16) assertEqualE(t, res.I32, o.I32) assertEqualE(t, res.I64, o.I64) assertEqualE(t, res.F64, o.F64) assertEqualE(t, res.Bo, o.Bo) assertDeepEqualE(t, res.Bi, o.Bi) assertTrueE(t, res.Date.Time.Equal(o.Date.Time)) assertEqualE(t, res.Time.Time.Hour(), o.Time.Time.Hour()) assertEqualE(t, res.Time.Time.Minute(), o.Time.Time.Minute()) assertEqualE(t, res.Time.Time.Second(), o.Time.Time.Second()) assertTrueE(t, res.Ltz.Time.Equal(o.Ltz.Time)) assertTrueE(t, res.Tz.Time.Equal(o.Tz.Time)) assertTrueE(t, res.Ntz.Time.Equal(o.Ntz.Time)) assertDeepEqualE(t, res.So, o.So) assertDeepEqualE(t, res.SArr, o.SArr) assertDeepEqualE(t, res.F64Arr, o.F64Arr) assertDeepEqualE(t, res.SomeMap, o.SomeMap) }) }) } type objectWithAllTypesWrapper struct { o *objectWithAllTypes } func (o *objectWithAllTypesWrapper) Scan(val any) error { st := val.(StructuredObject) var owat *objectWithAllTypes _, err := st.GetStruct("o", owat) if err == nil { return err } o.o = owat return err } func (o *objectWithAllTypesWrapper) Write(sowc StructuredObjectWriterContext) error { return sowc.WriteNullableStruct("o", o.o, reflect.TypeFor[objectWithAllTypes]()) } func 
TestBindingObjectWithAllTypesNullable(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.forceJSON() dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (o OBJECT(o OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 9), bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, tz TIMESTAMPTZ, ntz TIMESTAMPNTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)))") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_object_binding") }() dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") dbt.mustExec("ALTER SESSION SET TIMESTAMP_OUTPUT_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF9 TZHTZM'") o := &objectWithAllTypesWrapper{} dbt.mustExec("INSERT INTO test_object_binding SELECT (?)", o) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_object_binding WHERE o = ?", o) defer rows.Close() assertTrueE(t, rows.Next()) var res objectWithAllTypesWrapper err := rows.Scan(&res) assertNilF(t, err) assertDeepEqualE(t, o, &res) }) } func TestBindingObjectWithSchemaWithCustomNameAndIgnoredField(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(anotherName VARCHAR))") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_object_binding") }() o := &objectWithCustomNameAndIgnoredField{ SomeString: "some string", IgnoreMe: "ignore me", } dbt.mustExec("INSERT INTO test_object_binding SELECT (?)", o) rows := dbt.mustQueryContext(ctx, "SELECT * FROM test_object_binding WHERE obj = ?", o) defer rows.Close() assertTrueE(t, rows.Next()) var res objectWithCustomNameAndIgnoredField err := rows.Scan(&res) assertNilF(t, err) assertEqualE(t, res.SomeString, "some string") assertEqualE(t, 
res.IgnoreMe, "") }) } func TestBindingNullStructuredObjects(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(s VARCHAR, i INTEGER))") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_object_binding") }() dbt.mustExec("INSERT INTO test_object_binding SELECT (?)", DataTypeNilObject, reflect.TypeFor[simpleObject]()) rows := dbt.mustQueryContext(ctx, "SELECT * FROM test_object_binding") defer rows.Close() assertTrueE(t, rows.Next()) var res *simpleObject err := rows.Scan(&res) assertNilF(t, err) assertNilE(t, res) }) } func TestBindingArrayWithSchema(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() testcases := []struct { name string arrayType string values []any expected any }{ { name: "byte - empty", arrayType: "TINYINT", values: []any{[]byte{}}, expected: []int64{}, }, { name: "byte - not empty", arrayType: "TINYINT", values: []any{[]byte{1, 2, 3}}, expected: []int64{1, 2, 3}, }, { name: "int16", arrayType: "SMALLINT", values: []any{[]int16{1, 2, 3}}, expected: []int64{1, 2, 3}, }, { name: "int16 - empty", arrayType: "SMALLINT", values: []any{[]int16{}}, expected: []int64{}, }, { name: "int32", arrayType: "INTEGER", values: []any{[]int32{1, 2, 3}}, expected: []int64{1, 2, 3}, }, { name: "int64", arrayType: "BIGINT", values: []any{[]int64{1, 2, 3}}, expected: []int64{1, 2, 3}, }, { name: "float32", arrayType: "FLOAT", values: []any{[]float32{1.2, 3.4}}, expected: []float64{1.2, 3.4}, }, { name: "float64", arrayType: "FLOAT", values: []any{[]float64{1.2, 3.4}}, expected: []float64{1.2, 3.4}, }, { name: "bool", arrayType: "BOOLEAN", values: []any{[]bool{true, false}}, expected: []bool{true, false}, }, { name: "binary", arrayType: "BINARY", values: []any{DataTypeBinary, [][]byte{{'a', 'b'}, {'c', 'd'}}}, expected: 
[][]byte{{'a', 'b'}, {'c', 'd'}},
			},
			{
				name:      "binary - empty",
				arrayType: "BINARY",
				values:    []any{DataTypeBinary, [][]byte{}},
				expected:  [][]byte{},
			},
			{
				name:      "date",
				arrayType: "DATE",
				values:    []any{DataTypeDate, []time.Time{time.Date(2024, time.June, 4, 0, 0, 0, 0, time.UTC)}},
				expected:  []time.Time{time.Date(2024, time.June, 4, 0, 0, 0, 0, time.UTC)},
			},
		}
		// Each case creates a typed ARRAY column, binds the Go slice and
		// verifies the driver returns the expected canonical Go slice type.
		for _, tc := range testcases {
			t.Run(tc.name, func(t *testing.T) {
				dbt.mustExecT(t, fmt.Sprintf("CREATE OR REPLACE TABLE test_array_binding (arr ARRAY(%s))", tc.arrayType))
				defer func() {
					dbt.mustExecT(t, "DROP TABLE IF EXISTS test_array_binding")
				}()
				dbt.mustExecT(t, "INSERT INTO test_array_binding SELECT (?)", tc.values...)
				rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_array_binding")
				defer rows.Close()
				assertTrueE(t, rows.Next())
				var res any
				err := rows.Scan(&res)
				assertNilF(t, err)
				assertDeepEqualE(t, res, tc.expected)
			})
		}
	})
}

// TestBindingArrayOfObjects binds a slice of structured objects into an
// ARRAY(OBJECT(...)) column and reads it back via ScanArrayOfScanners.
func TestBindingArrayOfObjects(t *testing.T) {
	ctx := WithStructuredTypesEnabled(context.Background())
	runDBTest(t, func(dbt *DBTest) {
		dbt.enableStructuredTypesBinding()
		dbt.mustExec("CREATE OR REPLACE TABLE test_array_binding (arr ARRAY(OBJECT(s VARCHAR, i INTEGER)))")
		defer func() {
			dbt.mustExec("DROP TABLE IF EXISTS test_array_binding")
		}()
		arr := []*simpleObject{{s: "some string", i: 123}}
		dbt.mustExec("INSERT INTO test_array_binding SELECT (?)", arr)
		// The same bound slice is also usable as a comparison value in WHERE.
		rows := dbt.mustQueryContext(ctx, "SELECT * FROM test_array_binding WHERE arr = ?", arr)
		defer rows.Close()
		assertTrueE(t, rows.Next())
		var res []*simpleObject
		err := rows.Scan(ScanArrayOfScanners(&res))
		assertNilF(t, err)
		assertDeepEqualE(t, res, arr)
	})
}

// TestBindingEmptyArrayOfObjects checks the empty-slice round trip for an
// ARRAY(OBJECT(...)) column.
func TestBindingEmptyArrayOfObjects(t *testing.T) {
	ctx := WithStructuredTypesEnabled(context.Background())
	runDBTest(t, func(dbt *DBTest) {
		dbt.enableStructuredTypesBinding()
		dbt.mustExec("CREATE OR REPLACE TABLE test_array_binding (arr ARRAY(OBJECT(s VARCHAR, i INTEGER)))")
		defer func() {
			dbt.mustExec("DROP TABLE IF EXISTS test_array_binding")
		}()
		arr :=
[]*simpleObject{} dbt.mustExec("INSERT INTO test_array_binding SELECT (?)", arr) rows := dbt.mustQueryContext(ctx, "SELECT * FROM test_array_binding WHERE arr = ?", arr) defer rows.Close() assertTrueF(t, rows.Next()) var res []*simpleObject err := rows.Scan(ScanArrayOfScanners(&res)) assertNilF(t, err) assertDeepEqualE(t, res, arr) }) } func TestBindingNilArrayOfObjects(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_array_binding (arr ARRAY(OBJECT(s VARCHAR, i INTEGER)))") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_array_binding") }() var arr []*simpleObject dbt.mustExec("INSERT INTO test_array_binding SELECT (?)", DataTypeNilArray, reflect.TypeFor[simpleObject]()) rows := dbt.mustQueryContext(ctx, "SELECT * FROM test_array_binding") defer rows.Close() assertTrueF(t, rows.Next()) var res []*simpleObject err := rows.Scan(ScanArrayOfScanners(&res)) assertNilF(t, err) assertDeepEqualE(t, res, arr) }) } func TestBindingNilArrayOfInts(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_array_binding (arr ARRAY(INTEGER))") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_array_binding") }() var arr *[]int64 dbt.mustExec("INSERT INTO test_array_binding SELECT (?)", DataTypeNilArray, reflect.TypeFor[int]()) rows := dbt.mustQueryContext(ctx, "SELECT * FROM test_array_binding") defer rows.Close() assertTrueF(t, rows.Next()) var res *[]int64 err := rows.Scan(&res) assertNilF(t, err) assertDeepEqualE(t, res, arr) }) } func TestBindingMap(t *testing.T) { warsawTz, err := time.LoadLocation("Europe/Warsaw") assertNilF(t, err) ctx := WithStructuredTypesEnabled(context.Background()) testcases := []struct { tableDefinition string values []any expected any isTimeOnly bool }{ { tableDefinition: "VARCHAR, 
VARCHAR", values: []any{map[string]string{ "a": "b", "c": "d", }}, expected: map[string]string{ "a": "b", "c": "d", }, }, { tableDefinition: "INTEGER, VARCHAR", values: []any{map[int64]string{ 1: "b", 2: "d", }}, expected: map[int64]string{ 1: "b", 2: "d", }, }, { tableDefinition: "VARCHAR, BOOLEAN", values: []any{map[string]bool{ "a": true, "c": false, }}, expected: map[string]bool{ "a": true, "c": false, }, }, { tableDefinition: "VARCHAR, INTEGER", values: []any{map[string]int64{ "a": 1, "b": 2, }}, expected: map[string]int64{ "a": 1, "b": 2, }, }, { tableDefinition: "VARCHAR, DOUBLE", values: []any{map[string]float64{ "a": 1.1, "b": 2.2, }}, expected: map[string]float64{ "a": 1.1, "b": 2.2, }, }, { tableDefinition: "INTEGER, BINARY", values: []any{DataTypeBinary, map[int64][]byte{ 1: {'a', 'b'}, 2: {'c', 'd'}, }}, expected: map[int64][]byte{ 1: {'a', 'b'}, 2: {'c', 'd'}, }, }, { tableDefinition: "VARCHAR, BINARY", values: []any{DataTypeBinary, map[string][]byte{ "a": {'a', 'b'}, "b": {'c', 'd'}, }}, expected: map[string][]byte{ "a": {'a', 'b'}, "b": {'c', 'd'}, }, }, { tableDefinition: "VARCHAR, DATE", values: []any{DataTypeDate, map[string]time.Time{ "a": time.Date(2024, time.June, 25, 0, 0, 0, 0, time.UTC), "b": time.Date(2024, time.June, 26, 0, 0, 0, 0, time.UTC), }}, expected: map[string]time.Time{ "a": time.Date(2024, time.June, 25, 0, 0, 0, 0, time.UTC), "b": time.Date(2024, time.June, 26, 0, 0, 0, 0, time.UTC), }, }, { tableDefinition: "VARCHAR, TIME", values: []any{DataTypeTime, map[string]time.Time{ "a": time.Date(1, time.January, 1, 11, 22, 33, 0, time.UTC), "b": time.Date(2, time.January, 1, 22, 11, 44, 0, time.UTC), }}, expected: map[string]time.Time{ "a": time.Date(1, time.January, 1, 11, 22, 33, 0, time.UTC), "b": time.Date(2, time.January, 1, 22, 11, 44, 0, time.UTC), }, isTimeOnly: true, }, { tableDefinition: "VARCHAR, TIMESTAMPNTZ", values: []any{DataTypeTimestampNtz, map[string]time.Time{ "a": time.Date(2024, time.June, 25, 11, 22, 33, 0, 
time.UTC), "b": time.Date(2024, time.June, 26, 11, 22, 33, 0, time.UTC), }}, expected: map[string]time.Time{ "a": time.Date(2024, time.June, 25, 11, 22, 33, 0, time.UTC), "b": time.Date(2024, time.June, 26, 11, 22, 33, 0, time.UTC), }, }, { tableDefinition: "VARCHAR, TIMESTAMPTZ", values: []any{DataTypeTimestampTz, map[string]time.Time{ "a": time.Date(2024, time.June, 25, 11, 22, 33, 0, warsawTz), "b": time.Date(2024, time.June, 26, 11, 22, 33, 0, warsawTz), }}, expected: map[string]time.Time{ "a": time.Date(2024, time.June, 25, 11, 22, 33, 0, warsawTz), "b": time.Date(2024, time.June, 26, 11, 22, 33, 0, warsawTz), }, }, { tableDefinition: "VARCHAR, TIMESTAMPLTZ", values: []any{DataTypeTimestampLtz, map[string]time.Time{ "a": time.Date(2024, time.June, 25, 11, 22, 33, 0, warsawTz), "b": time.Date(2024, time.June, 26, 11, 22, 33, 0, warsawTz), }}, expected: map[string]time.Time{ "a": time.Date(2024, time.June, 25, 11, 22, 33, 0, warsawTz), "b": time.Date(2024, time.June, 26, 11, 22, 33, 0, warsawTz), }, }, } runDBTest(t, func(dbt *DBTest) { dbt.mustExecT(t, "ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") dbt.enableStructuredTypesBinding() for _, tc := range testcases { t.Run(tc.tableDefinition, func(t *testing.T) { dbt.mustExecT(t, fmt.Sprintf("CREATE OR REPLACE TABLE test_map_binding (m MAP(%v))", tc.tableDefinition)) defer func() { dbt.mustExecT(t, "DROP TABLE IF EXISTS test_map_binding") }() dbt.mustExecT(t, "INSERT INTO test_map_binding SELECT (?)", tc.values...) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_map_binding WHERE m = ?", tc.values...) 
defer rows.Close() assertTrueE(t, rows.Next()) var res any err := rows.Scan(&res) assertNilF(t, err) if m, ok := tc.expected.(map[string]time.Time); ok { resTimes := res.(map[string]time.Time) for k, v := range m { if tc.isTimeOnly { assertEqualE(t, resTimes[k].Hour(), v.Hour()) assertEqualE(t, resTimes[k].Minute(), v.Minute()) assertEqualE(t, resTimes[k].Second(), v.Second()) } else { assertTrueE(t, resTimes[k].Equal(v)) } } } else { assertDeepEqualE(t, res, tc.expected) } }) } }) } func TestBindingMapOfStructs(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_map_binding (m MAP(VARCHAR, OBJECT(s VARCHAR, i INTEGER)))") defer func() { dbt.mustExecT(t, "DROP TABLE IF EXISTS test_map_binding") }() m := map[string]*simpleObject{ "a": {"abc", 1}, "b": nil, "c": {"def", 2}, } dbt.mustExecT(t, "INSERT INTO test_map_binding SELECT ?", m) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_map_binding WHERE m = ?", m) defer rows.Close() rows.Next() var res map[string]*simpleObject err := rows.Scan(ScanMapOfScanners(&res)) assertNilF(t, err) assertDeepEqualE(t, res, m) }) } func TestBindingMapOfWithAllValuesNil(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_map_binding (m MAP(VARCHAR, OBJECT(s VARCHAR, i INTEGER)))") defer func() { dbt.mustExecT(t, "DROP TABLE IF EXISTS test_map_binding") }() m := map[string]*simpleObject{ "a": nil, } dbt.mustExecT(t, "INSERT INTO test_map_binding SELECT ?", m) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_map_binding WHERE m = ?", m) defer rows.Close() rows.Next() var res map[string]*simpleObject err := rows.Scan(ScanMapOfScanners(&res)) assertNilF(t, err) assertDeepEqualE(t, res, m) }) } func TestBindingEmptyMapOfStructs(t *testing.T) { ctx := 
WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_map_binding (m MAP(VARCHAR, OBJECT(s VARCHAR, i INTEGER)))") defer func() { dbt.mustExecT(t, "DROP TABLE IF EXISTS test_map_binding") }() m := map[string]*simpleObject{} dbt.mustExecT(t, "INSERT INTO test_map_binding SELECT ?", m) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_map_binding WHERE m = ?", m) defer rows.Close() assertTrueF(t, rows.Next()) var res map[string]*simpleObject err := rows.Scan(ScanMapOfScanners(&res)) assertNilF(t, err) assertDeepEqualE(t, res, m) }) } func TestBindingEmptyMapOfInts(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_map_binding (m MAP(VARCHAR, INTEGER))") defer func() { dbt.mustExecT(t, "DROP TABLE IF EXISTS test_map_binding") }() m := map[string]int64{} dbt.mustExecT(t, "INSERT INTO test_map_binding SELECT ?", m) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_map_binding WHERE m = ?", m) defer rows.Close() assertTrueF(t, rows.Next()) var res map[string]int64 err := rows.Scan(&res) assertNilF(t, err) assertDeepEqualE(t, res, m) }) } func TestBindingNilMapOfStructs(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_map_binding (m MAP(VARCHAR, OBJECT(s VARCHAR, i INTEGER)))") defer func() { dbt.mustExecT(t, "DROP TABLE IF EXISTS test_map_binding") }() var m map[string]*simpleObject dbt.mustExecT(t, "INSERT INTO test_map_binding SELECT ?", DataTypeNilMap, NilMapTypes{Key: reflect.TypeFor[string](), Value: reflect.TypeFor[*simpleObject]()}) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_map_binding", DataTypeNilMap, NilMapTypes{Key: reflect.TypeFor[string](), Value: 
reflect.TypeFor[*simpleObject]()}) defer rows.Close() assertTrueF(t, rows.Next()) var res map[string]*simpleObject err := rows.Scan(ScanMapOfScanners(&res)) assertNilF(t, err) assertDeepEqualE(t, res, m) }) } func TestBindingNilMapOfInts(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_map_binding (m MAP(VARCHAR, INTEGER))") defer func() { dbt.mustExecT(t, "DROP TABLE IF EXISTS test_map_binding") }() var m *map[string]int64 dbt.mustExecT(t, "INSERT INTO test_map_binding SELECT ?", DataTypeNilMap, NilMapTypes{Key: reflect.TypeFor[string](), Value: reflect.TypeFor[int]()}) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_map_binding", DataTypeNilMap, NilMapTypes{Key: reflect.TypeFor[string](), Value: reflect.TypeFor[int]()}) defer rows.Close() assertTrueF(t, rows.Next()) var res *map[string]int64 err := rows.Scan(&res) assertNilF(t, err) assertDeepEqualE(t, res, m) }) } func TestBindingMapOfArrays(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.enableStructuredTypesBinding() dbt.mustExec("CREATE OR REPLACE TABLE test_map_binding (m MAP(VARCHAR, ARRAY(INTEGER)))") defer func() { dbt.mustExecT(t, "DROP TABLE IF EXISTS test_map_binding") }() m := map[string][]int64{ "a": {1, 2}, "b": nil, } dbt.mustExecT(t, "INSERT INTO test_map_binding SELECT ?", m) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_map_binding", m) defer rows.Close() assertTrueF(t, rows.Next()) var res map[string][]int64 err := rows.Scan(&res) assertNilF(t, err) assertDeepEqualE(t, res, m) }) } func TestBindingMapWithNillableValues(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) warsawTz, err := time.LoadLocation("Europe/Warsaw") assertNilF(t, err) var testcases = []struct { tableDefinition string values []any expected any isTimeOnly bool }{ { tableDefinition: "VARCHAR, 
VARCHAR", values: []any{map[string]sql.NullString{ "a": {String: "b", Valid: true}, "c": {}, }}, expected: map[string]sql.NullString{ "a": {String: "b", Valid: true}, "c": {}, }, }, { tableDefinition: "INTEGER, VARCHAR", values: []any{map[int64]sql.NullString{ 1: {String: "b", Valid: true}, 2: {}, }}, expected: map[int64]sql.NullString{ 1: {String: "b", Valid: true}, 2: {}, }, }, { tableDefinition: "VARCHAR, BOOLEAN", values: []any{map[string]sql.NullBool{ "a": {Bool: true, Valid: true}, "c": {}, }}, expected: map[string]sql.NullBool{ "a": {Bool: true, Valid: true}, "c": {}, }, }, { tableDefinition: "VARCHAR, INTEGER", values: []any{map[string]sql.NullInt64{ "a": {Int64: 1, Valid: true}, "b": {}, }}, expected: map[string]sql.NullInt64{ "a": {Int64: 1, Valid: true}, "b": {}, }, }, { tableDefinition: "VARCHAR, DOUBLE", values: []any{map[string]sql.NullFloat64{ "a": {Float64: 1.1, Valid: true}, "b": {}, }}, expected: map[string]sql.NullFloat64{ "a": {Float64: 1.1, Valid: true}, "b": {}, }, }, { tableDefinition: "INTEGER, BINARY", values: []any{DataTypeBinary, map[int64][]byte{ 1: {'a', 'b'}, 2: nil, }}, expected: map[int64][]byte{ 1: {'a', 'b'}, 2: nil, }, }, { tableDefinition: "VARCHAR, BINARY", values: []any{DataTypeBinary, map[string][]byte{ "a": {'a', 'b'}, "b": nil, }}, expected: map[string][]byte{ "a": {'a', 'b'}, "b": nil, }, }, { tableDefinition: "VARCHAR, DATE", values: []any{DataTypeDate, map[string]sql.NullTime{ "a": {Time: time.Date(2024, time.June, 25, 0, 0, 0, 0, time.UTC), Valid: true}, "b": {}, }}, expected: map[string]sql.NullTime{ "a": {Time: time.Date(2024, time.June, 25, 0, 0, 0, 0, time.UTC), Valid: true}, "b": {}, }, }, { tableDefinition: "VARCHAR, TIME", values: []any{DataTypeTime, map[string]sql.NullTime{ "a": {Time: time.Date(1, time.January, 1, 11, 22, 33, 0, time.UTC), Valid: true}, "b": {}, }}, expected: map[string]sql.NullTime{ "a": {Time: time.Date(1, time.January, 1, 11, 22, 33, 0, time.UTC), Valid: true}, "b": {}, }, isTimeOnly: true, 
}, { tableDefinition: "VARCHAR, TIMESTAMPNTZ", values: []any{DataTypeTimestampNtz, map[string]sql.NullTime{ "a": {Time: time.Date(2024, time.June, 25, 11, 22, 33, 0, time.UTC), Valid: true}, "b": {}, }}, expected: map[string]sql.NullTime{ "a": {Time: time.Date(2024, time.June, 25, 11, 22, 33, 0, time.UTC), Valid: true}, "b": {}, }, }, { tableDefinition: "VARCHAR, TIMESTAMPTZ", values: []any{DataTypeTimestampTz, map[string]sql.NullTime{ "a": {Time: time.Date(2024, time.June, 25, 11, 22, 33, 0, warsawTz), Valid: true}, "b": {}, }}, expected: map[string]sql.NullTime{ "a": {Time: time.Date(2024, time.June, 25, 11, 22, 33, 0, warsawTz), Valid: true}, "b": {}, }, }, { tableDefinition: "VARCHAR, TIMESTAMPLTZ", values: []any{DataTypeTimestampLtz, map[string]sql.NullTime{ "a": {Time: time.Date(2024, time.June, 25, 11, 22, 33, 0, warsawTz), Valid: true}, "b": {}, }}, expected: map[string]sql.NullTime{ "a": {Time: time.Date(2024, time.June, 25, 11, 22, 33, 0, warsawTz), Valid: true}, "b": {}, }, }, } runDBTest(t, func(dbt *DBTest) { dbt.mustExecT(t, "ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") dbt.enableStructuredTypesBinding() for _, tc := range testcases { t.Run(tc.tableDefinition, func(t *testing.T) { dbt.mustExecT(t, fmt.Sprintf("CREATE OR REPLACE TABLE test_map_binding (m MAP(%v))", tc.tableDefinition)) defer func() { dbt.mustExecT(t, "DROP TABLE IF EXISTS test_map_binding") }() dbt.mustExecT(t, "INSERT INTO test_map_binding SELECT (?)", tc.values...) rows := dbt.mustQueryContextT(WithEmbeddedValuesNullable(ctx), t, "SELECT * FROM test_map_binding WHERE m = ?", tc.values...) 
defer rows.Close() assertTrueE(t, rows.Next()) var res any err := rows.Scan(&res) assertNilF(t, err) if m, ok := tc.expected.(map[string]sql.NullTime); ok { resTimes := res.(map[string]sql.NullTime) for k, v := range m { if tc.isTimeOnly { assertEqualE(t, resTimes[k].Valid, v.Valid) assertEqualE(t, resTimes[k].Time.Hour(), v.Time.Hour()) assertEqualE(t, resTimes[k].Time.Minute(), v.Time.Minute()) assertEqualE(t, resTimes[k].Time.Second(), v.Time.Second()) } else { assertEqualE(t, resTimes[k].Valid, v.Valid) if v.Valid { assertTrueE(t, resTimes[k].Time.Equal(v.Time)) } } } } else { assertDeepEqualE(t, res, tc.expected) } }) } }) } ================================================ FILE: telemetry.go ================================================ package gosnowflake import ( "context" "encoding/json" "fmt" "net/http" "strings" "sync" "time" ) const ( telemetryPath = "/telemetry/send" defaultTelemetryTimeout = 10 * time.Second defaultFlushSize = 100 ) const ( typeKey = "type" sourceKey = "source" queryIDKey = "QueryID" driverTypeKey = "DriverType" driverVersionKey = "DriverVersion" golangVersionKey = "GolangVersion" sqlStateKey = "SQLState" reasonKey = "reason" errorNumberKey = "ErrorNumber" stacktraceKey = "Stacktrace" ) const ( telemetrySource = "golang_driver" sqlException = "client_sql_exception" connectionParameters = "client_connection_parameters" ) type telemetryData struct { Timestamp int64 `json:"timestamp,omitempty"` Message map[string]string `json:"message,omitempty"` } type snowflakeTelemetry struct { logs []*telemetryData flushSize int sr *snowflakeRestful mutex *sync.Mutex enabled bool } func (st *snowflakeTelemetry) addLog(data *telemetryData) error { if !st.enabled { logger.Debug("telemetry disabled; not adding log") return nil } st.mutex.Lock() st.logs = append(st.logs, data) shouldFlush := len(st.logs) >= st.flushSize st.mutex.Unlock() if shouldFlush { if err := st.sendBatch(); err != nil { return err } } return nil } func (st *snowflakeTelemetry) 
sendBatch() error { if !st.enabled { logger.Debug("telemetry disabled; not sending log") return nil } type telemetry struct { Logs []*telemetryData `json:"logs"` } st.mutex.Lock() logsToSend := st.logs minicoreLoadLogs.mu.Lock() if mcLogs := minicoreLoadLogs.logs; len(mcLogs) > 0 { logsToSend = append(logsToSend, &telemetryData{ Timestamp: time.Now().UnixMilli(), Message: map[string]string{ "minicoreLogs": strings.Join(mcLogs, "; "), }, }) minicoreLoadLogs.logs = make([]string, 0) } minicoreLoadLogs.mu.Unlock() st.logs = make([]*telemetryData, 0) st.mutex.Unlock() if len(logsToSend) == 0 { logger.Debug("nothing to send to telemetry") return nil } s := &telemetry{logsToSend} body, err := json.Marshal(s) if err != nil { return err } logger.Debugf("sending %v logs to telemetry.", len(logsToSend)) logger.Debugf("telemetry payload being sent: %v", string(body)) headers := getHeaders() if token, _, _ := st.sr.TokenAccessor.GetTokens(); token != "" { headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) } fullURL := st.sr.getFullURL(telemetryPath, nil) resp, err := st.sr.FuncPost(context.Background(), st.sr, fullURL, headers, body, defaultTelemetryTimeout, defaultTimeProvider, nil) if err != nil { logger.Errorf("failed to upload metrics to telemetry. err: %v", err) return err } defer func() { if err = resp.Body.Close(); err != nil { logger.Errorf("failed to close response body for %v. err: %v", fullURL, err) } }() if resp.StatusCode != http.StatusOK { err = fmt.Errorf("non-successful response from telemetry server: %v. 
"+ "disabling telemetry", resp.StatusCode) logger.Error(err.Error()) st.enabled = false return err } var respd telemetryResponse if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil { logger.Errorf("cannot decode telemetry response body: %v", err) st.enabled = false return err } if !respd.Success { err = fmt.Errorf("telemetry send failed with error code: %v, message: %v", respd.Code, respd.Message) logger.Error(err.Error()) st.enabled = false return err } logger.Debug("successfully uploaded metrics to telemetry") return nil } ================================================ FILE: telemetry_test.go ================================================ package gosnowflake import ( "context" "errors" "math/rand" "net/http" "net/url" "sync" "testing" "time" ) func TestTelemetryAddLog(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { st := &snowflakeTelemetry{ sr: sct.sc.rest, mutex: &sync.Mutex{}, enabled: true, flushSize: defaultFlushSize, } r := rand.New(rand.NewSource(time.Now().UnixNano())) randNum := r.Int() % 10000 for range randNum { if err := st.addLog(&telemetryData{ Message: map[string]string{ typeKey: "client_telemetry_type", queryIDKey: "123", }, Timestamp: time.Now().UnixNano() / int64(time.Millisecond), }); err != nil { t.Fatal(err) } } if len(st.logs) != randNum%defaultFlushSize { t.Errorf("length of remaining logs does not match. 
expected: %v, got: %v", randNum%defaultFlushSize, len(st.logs)) } if err := st.sendBatch(); err != nil { t.Fatal(err) } }) } func TestTelemetrySQLException(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { sct.sc.telemetry = &snowflakeTelemetry{ sr: sct.sc.rest, mutex: &sync.Mutex{}, enabled: true, flushSize: defaultFlushSize, } sfa := &snowflakeFileTransferAgent{ ctx: context.Background(), sc: sct.sc, commandType: uploadCommand, srcFiles: make([]string, 0), data: &execResponseData{ SrcLocations: make([]string, 0), }, } if err := sfa.initFileMetadata(); err == nil { t.Fatal("this should have thrown an error") } if len(sct.sc.telemetry.logs) != 1 { t.Errorf("there should be 1 telemetry data in log. found: %v", len(sct.sc.telemetry.logs)) } if sendErr := sct.sc.telemetry.sendBatch(); sendErr != nil { t.Fatal(sendErr) } if len(sct.sc.telemetry.logs) != 0 { t.Errorf("there should be no telemetry data in log. found: %v", len(sct.sc.telemetry.logs)) } }) } func funcPostTelemetryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) { return nil, errors.New("failed to upload metrics to telemetry") } func TestTelemetryError(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { st := &snowflakeTelemetry{ sr: &snowflakeRestful{ FuncPost: funcPostTelemetryRespFail, TokenAccessor: getSimpleTokenAccessor(), }, mutex: &sync.Mutex{}, enabled: true, flushSize: defaultFlushSize, } if err := st.addLog(&telemetryData{ Message: map[string]string{ typeKey: "client_telemetry_type", queryIDKey: "123", }, Timestamp: time.Now().UnixNano() / int64(time.Millisecond), }); err != nil { t.Fatal(err) } err := st.sendBatch() if err == nil { t.Fatal("should have failed") } }) } func TestTelemetryDisabledOnBadResponse(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { st := &snowflakeTelemetry{ sr: &snowflakeRestful{ FuncPost: postTestAppBadGatewayError, 
TokenAccessor: getSimpleTokenAccessor(), }, mutex: &sync.Mutex{}, enabled: true, flushSize: defaultFlushSize, } if err := st.addLog(&telemetryData{ Message: map[string]string{ typeKey: "client_telemetry_type", queryIDKey: "123", }, Timestamp: time.Now().UnixNano() / int64(time.Millisecond), }); err != nil { t.Fatal(err) } err := st.sendBatch() if err == nil { t.Fatal("should have failed") } if st.enabled == true { t.Fatal("telemetry should be disabled") } st.enabled = true st.sr.FuncPost = postTestQueryNotExecuting if err = st.addLog(&telemetryData{ Message: map[string]string{ typeKey: "client_telemetry_type", queryIDKey: "123", }, Timestamp: time.Now().UnixNano() / int64(time.Millisecond), }); err != nil { t.Fatal(err) } err = st.sendBatch() if err == nil { t.Fatal("should have failed") } if st.enabled == true { t.Fatal("telemetry should be disabled") } st.enabled = true st.sr.FuncPost = postTestSuccessButInvalidJSON if err = st.addLog(&telemetryData{ Message: map[string]string{ typeKey: "client_telemetry_type", queryIDKey: "123", }, Timestamp: time.Now().UnixNano() / int64(time.Millisecond), }); err != nil { t.Fatal(err) } err = st.sendBatch() if err == nil { t.Fatal("should have failed") } if st.enabled == true { t.Fatal("telemetry should be disabled") } }) } func TestTelemetryDisabled(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { st := &snowflakeTelemetry{ sr: &snowflakeRestful{ FuncPost: postTestAppBadGatewayError, TokenAccessor: getSimpleTokenAccessor(), }, mutex: &sync.Mutex{}, enabled: false, // disable flushSize: defaultFlushSize, } if err := st.addLog(&telemetryData{ Message: map[string]string{ typeKey: "client_telemetry_type", queryIDKey: "123", }, Timestamp: time.Now().UnixNano() / int64(time.Millisecond), }); err != nil { t.Fatalf("calling addLog should not return an error just because telemetry is disabled, but did: %v", err) } st.enabled = true if err := st.addLog(&telemetryData{ Message: map[string]string{ typeKey: 
"client_telemetry_type", queryIDKey: "123", }, Timestamp: time.Now().UnixNano() / int64(time.Millisecond), }); err != nil { t.Fatal(err) } st.enabled = false err := st.sendBatch() if err != nil { t.Fatalf("calling sendBatch should not return an error just because telemetry is disabled, but did: %v", err) } }) } func TestAddLogError(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { st := &snowflakeTelemetry{ sr: &snowflakeRestful{ FuncPost: funcPostTelemetryRespFail, TokenAccessor: getSimpleTokenAccessor(), }, mutex: &sync.Mutex{}, enabled: true, flushSize: 1, } if err := st.addLog(&telemetryData{ Message: map[string]string{ typeKey: "client_telemetry_type", queryIDKey: "123", }, Timestamp: time.Now().UnixNano() / int64(time.Millisecond), }); err == nil { t.Fatal("should have failed") } }) } ================================================ FILE: test_data/.gitignore ================================================ writeonly.csv ================================================ FILE: test_data/connections.toml ================================================ [default] account = 'snowdriverswarsaw.us-west-2.aws' user = 'test_default_user' password = 'test_default_pass' warehouse = 'testw_default' database = 'test_default_db' schema = 'test_default_go' protocol = 'https' port = '300' [aws-oauth] account = 'snowdriverswarsaw.us-west-2.aws' user = 'test_oauth_user' password = 'test_oauth_pass' warehouse = 'testw_oauth' database = 'test_oauth_db' schema = 'test_oauth_go' protocol = 'https' port = '443' authenticator = 'oauth' testNot = 'problematicParameter' token = 'token_value' disableOCSPChecks = true [aws-oauth-file] account = 'snowdriverswarsaw.us-west-2.aws' user = 'test_user' password = 'test_pass' warehouse = 'testw' database = 'test_db' schema = 'test_go' protocol = 'https' port = '443' authenticator = 'oauth' testNot = 'problematicParameter' token_file_path = '/Users/test/.snowflake/token' [read-token] account = 'snowdriverswarsaw.us-west-2.aws' user = 
'test_default_user' password = 'test_default_pass' warehouse = 'testw_default' database = 'test_default_db' schema = 'test_default_go' protocol = 'https' authenticator = 'oauth' token_file_path = './test_data/snowflake/session/token' disable_ocsp_checks = true [snake-case] account = 'snowdriverswarsaw.us-west-2.aws' user = 'test_default_user' password = 'test_default_pass' warehouse = 'testw_default' database = 'test_default_db' schema = 'test_default_go' protocol = 'https' port = '300' ocsp_fail_open=true ================================================ FILE: test_data/multistatements.sql ================================================ CREATE OR REPLACE TABLE jj_1(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_2(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_3(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_4(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_5(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_6(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_7(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_8(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_9(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_10(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_11(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_12(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_13(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_14(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_15(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_16(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_17(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_18(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_19(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_20(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_21(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_22(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_23(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_24(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_25(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_26(i int, 
v varchar(10)); CREATE OR REPLACE TABLE jj_27(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_28(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_29(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_30(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_31(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_32(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_33(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_34(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_35(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_36(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_37(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_38(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_39(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_40(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_41(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_42(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_43(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_44(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_45(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_46(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_47(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_48(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_49(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_50(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_51(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_52(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_53(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_54(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_55(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_56(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_57(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_58(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_59(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_60(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_61(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_62(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_63(i int, v varchar(10)); CREATE OR REPLACE 
TABLE jj_64(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_65(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_66(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_67(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_68(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_69(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_70(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_71(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_72(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_73(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_74(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_75(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_76(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_77(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_78(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_79(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_80(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_81(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_82(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_83(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_84(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_85(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_86(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_87(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_88(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_89(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_90(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_91(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_92(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_93(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_94(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_95(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_96(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_97(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_98(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_99(i int, v varchar(10)); CREATE OR REPLACE TABLE jj_100(i int, v varchar(10)); ================================================ FILE: 
test_data/multistatements_drop.sql ================================================ drop table if exists jj_1; drop table if exists jj_2; drop table if exists jj_3; drop table if exists jj_4; drop table if exists jj_5; drop table if exists jj_6; drop table if exists jj_7; drop table if exists jj_8; drop table if exists jj_9; drop table if exists jj_10; drop table if exists jj_11; drop table if exists jj_12; drop table if exists jj_13; drop table if exists jj_14; drop table if exists jj_15; drop table if exists jj_16; drop table if exists jj_17; drop table if exists jj_18; drop table if exists jj_19; drop table if exists jj_20; drop table if exists jj_21; drop table if exists jj_22; drop table if exists jj_23; drop table if exists jj_24; drop table if exists jj_25; drop table if exists jj_26; drop table if exists jj_27; drop table if exists jj_28; drop table if exists jj_29; drop table if exists jj_30; drop table if exists jj_31; drop table if exists jj_32; drop table if exists jj_33; drop table if exists jj_34; drop table if exists jj_35; drop table if exists jj_36; drop table if exists jj_37; drop table if exists jj_38; drop table if exists jj_39; drop table if exists jj_40; drop table if exists jj_41; drop table if exists jj_42; drop table if exists jj_43; drop table if exists jj_44; drop table if exists jj_45; drop table if exists jj_46; drop table if exists jj_47; drop table if exists jj_48; drop table if exists jj_49; drop table if exists jj_50; drop table if exists jj_51; drop table if exists jj_52; drop table if exists jj_53; drop table if exists jj_54; drop table if exists jj_55; drop table if exists jj_56; drop table if exists jj_57; drop table if exists jj_58; drop table if exists jj_59; drop table if exists jj_60; drop table if exists jj_61; drop table if exists jj_62; drop table if exists jj_63; drop table if exists jj_64; drop table if exists jj_65; drop table if exists jj_66; drop table if exists jj_67; drop table if exists jj_68; drop table if exists 
jj_69; drop table if exists jj_70; drop table if exists jj_71; drop table if exists jj_72; drop table if exists jj_73; drop table if exists jj_74; drop table if exists jj_75; drop table if exists jj_76; drop table if exists jj_77; drop table if exists jj_78; drop table if exists jj_79; drop table if exists jj_80; drop table if exists jj_81; drop table if exists jj_82; drop table if exists jj_83; drop table if exists jj_84; drop table if exists jj_85; drop table if exists jj_86; drop table if exists jj_87; drop table if exists jj_88; drop table if exists jj_89; drop table if exists jj_90; drop table if exists jj_91; drop table if exists jj_92; drop table if exists jj_93; drop table if exists jj_94; drop table if exists jj_95; drop table if exists jj_96; drop table if exists jj_97; drop table if exists jj_98; drop table if exists jj_99; drop table if exists jj_100; ================================================ FILE: test_data/orders_100.csv ================================================ 1|36901|O|173665.47|1996-01-02|5-LOW|Clerk#000000951|0|nstructions sleep furiously among | 2|78002|O|46929.18|1996-12-01|1-URGENT|Clerk#000000880|0| foxes. pending accounts at the pending, silent asymptot| 3|123314|F|193846.25|1993-10-14|5-LOW|Clerk#000000955|0|sly final accounts boost. carefully regular ideas cajole carefully. depos| 4|136777|O|32151.78|1995-10-11|5-LOW|Clerk#000000124|0|sits. slyly regular warthogs cajole. regular, regular theodolites acro| 5|44485|F|144659.20|1994-07-30|5-LOW|Clerk#000000925|0|quickly. bold deposits sleep slyly. packages use slyly| 6|55624|F|58749.59|1992-02-21|4-NOT SPECIFIED|Clerk#000000058|0|ggle. special, final requests are against the furiously specia| 7|39136|O|252004.18|1996-01-10|2-HIGH|Clerk#000000470|0|ly special requests | 32|130057|O|208660.75|1995-07-16|2-HIGH|Clerk#000000616|0|ise blithely bold, regular requests. quickly unusual dep| 33|66958|F|163243.98|1993-10-27|3-MEDIUM|Clerk#000000409|0|uriously. 
furiously final request| 34|61001|O|58949.67|1998-07-21|3-MEDIUM|Clerk#000000223|0|ly final packages. fluffily final deposits wake blithely ideas. spe| 35|127588|O|253724.56|1995-10-23|4-NOT SPECIFIED|Clerk#000000259|0|zzle. carefully enticing deposits nag furio| 36|115252|O|68289.96|1995-11-03|1-URGENT|Clerk#000000358|0| quick packages are blithely. slyly silent accounts wake qu| 37|86116|F|206680.66|1992-06-03|3-MEDIUM|Clerk#000000456|0|kly regular pinto beans. carefully unusual waters cajole never| 38|124828|O|82500.05|1996-08-21|4-NOT SPECIFIED|Clerk#000000604|0|haggle blithely. furiously express ideas haggle blithely furiously regular re| 39|81763|O|341734.47|1996-09-20|3-MEDIUM|Clerk#000000659|0|ole express, ironic requests: ir| 64|32113|F|39414.99|1994-07-16|3-MEDIUM|Clerk#000000661|0|wake fluffily. sometimes ironic pinto beans about the dolphin| 65|16252|P|110643.60|1995-03-18|1-URGENT|Clerk#000000632|0|ular requests are blithely pending orbits-- even requests against the deposit| 66|129200|F|103740.67|1994-01-20|5-LOW|Clerk#000000743|0|y pending requests integrate| 67|56614|O|169405.01|1996-12-19|4-NOT SPECIFIED|Clerk#000000547|0|symptotes haggle slyly around the furiously iron| 68|28547|O|330793.52|1998-04-18|3-MEDIUM|Clerk#000000440|0| pinto beans sleep carefully. blithely ironic deposits haggle furiously acro| 69|84487|F|197689.49|1994-06-04|4-NOT SPECIFIED|Clerk#000000330|0| depths atop the slyly thin deposits detect among the furiously silent accou| 70|64340|F|113534.42|1993-12-18|5-LOW|Clerk#000000322|0| carefully ironic request| 71|3373|O|276992.74|1998-01-24|4-NOT SPECIFIED|Clerk#000000271|0| express deposits along the blithely regul| 96|107779|F|68989.90|1994-04-17|2-HIGH|Clerk#000000395|0|oost furiously. pinto| 97|21061|F|110512.84|1993-01-29|3-MEDIUM|Clerk#000000547|0|hang blithely along the regular accounts. furiously even ideas after the| 98|104480|F|69168.33|1994-09-25|1-URGENT|Clerk#000000448|0|c asymptotes. 
quickly regular packages should have to nag re| 99|88910|F|112126.95|1994-03-13|4-NOT SPECIFIED|Clerk#000000973|0|e carefully ironic packages. pending| 100|147004|O|187782.63|1998-02-28|4-NOT SPECIFIED|Clerk#000000577|0|heodolites detect slyly alongside of the ent| ================================================ FILE: test_data/orders_101.csv ================================================ 353|1777|F|249710.43|1993-12-31|5-LOW|Clerk#000000449|0| quiet ideas sleep. even instructions cajole slyly. silently spe| 354|138268|O|217160.72|1996-03-14|2-HIGH|Clerk#000000511|0|ly regular ideas wake across the slyly silent ideas. final deposits eat b| 355|70007|F|99516.75|1994-06-14|5-LOW|Clerk#000000532|0|s. sometimes regular requests cajole. regular, pending accounts a| 356|146809|F|209439.04|1994-06-30|4-NOT SPECIFIED|Clerk#000000944|0|as wake along the bold accounts. even, | 357|60395|O|157411.61|1996-10-09|2-HIGH|Clerk#000000301|0|e blithely about the express, final accounts. quickl| 358|2290|F|354132.39|1993-09-20|2-HIGH|Clerk#000000392|0|l, silent instructions are slyly. silently even de| 359|77600|F|239998.53|1994-12-19|3-MEDIUM|Clerk#000000934|0|n dolphins. special courts above the carefully ironic requests use| 384|113009|F|166753.71|1992-03-03|5-LOW|Clerk#000000206|0|, even accounts use furiously packages. slyly ironic pla| 385|32947|O|54948.26|1996-03-22|5-LOW|Clerk#000000600|0|hless accounts unwind bold pain| 386|60110|F|110216.57|1995-01-25|2-HIGH|Clerk#000000648|0| haggle quickly. stealthily bold asymptotes haggle among the furiously even re| 387|3296|O|204546.39|1997-01-26|4-NOT SPECIFIED|Clerk#000000768|0| are carefully among the quickly even deposits. furiously silent req| 388|44668|F|198800.71|1992-12-16|4-NOT SPECIFIED|Clerk#000000356|0|ar foxes above the furiously ironic deposits nag slyly final reque| 389|126973|F|2519.40|1994-02-17|2-HIGH|Clerk#000000062|0|ing to the regular asymptotes. 
final, pending foxes about the blithely sil| 390|102563|O|269761.09|1998-04-07|5-LOW|Clerk#000000404|0|xpress asymptotes use among the regular, final pinto b| 391|110278|F|20890.17|1994-11-17|2-HIGH|Clerk#000000256|0|orges thrash fluffil| 416|40130|F|105675.20|1993-09-27|5-LOW|Clerk#000000294|0| the accounts. fluffily bold depo| 417|54583|F|125155.22|1994-02-06|3-MEDIUM|Clerk#000000468|0|ironic, even packages. thinly unusual accounts sleep along the slyly unusual | 418|94834|P|53328.48|1995-04-13|4-NOT SPECIFIED|Clerk#000000643|0|. furiously ironic instruc| 419|116261|O|165454.42|1996-10-01|3-MEDIUM|Clerk#000000376|0|osits. blithely pending theodolites boost carefully| 420|90145|O|343254.06|1995-10-31|4-NOT SPECIFIED|Clerk#000000756|0|leep carefully final excuses. fluffily pending requests unwind carefully above| 421|39149|F|1156.67|1992-02-22|5-LOW|Clerk#000000405|0|egular, even packages according to the final, un| 422|73075|O|188124.81|1997-05-31|4-NOT SPECIFIED|Clerk#000000049|0|aggle carefully across the accounts. regular accounts eat fluffi| 423|103396|O|50240.88|1996-06-01|1-URGENT|Clerk#000000674|0|quests. deposits cajole quickly. furiously bold accounts haggle q| 448|149641|O|165954.35|1995-08-21|3-MEDIUM|Clerk#000000597|0| regular, express foxes use blithely. quic| 449|95767|O|71120.82|1995-07-20|2-HIGH|Clerk#000000841|0|. furiously regular theodolites affix blithely | 450|47380|P|228518.02|1995-03-05|4-NOT SPECIFIED|Clerk#000000293|0|d theodolites. boldly bold foxes since the pack| 451|98758|O|141490.92|1998-05-25|5-LOW|Clerk#000000048|0|nic pinto beans. theodolites poach carefully; | 452|59560|O|3270.20|1997-10-14|1-URGENT|Clerk#000000498|0|t, unusual instructions above the blithely bold pint| 453|44030|O|329149.33|1997-05-26|5-LOW|Clerk#000000504|0|ss foxes. furiously regular ideas sleep according to t| 454|48776|O|36743.83|1995-12-27|5-LOW|Clerk#000000890|0|dolites sleep carefully blithely regular deposits. 
quickly regul| 455|12098|O|183606.42|1996-12-04|1-URGENT|Clerk#000000796|0| about the final platelets. dependen| 480|71383|F|23699.64|1993-05-08|5-LOW|Clerk#000000004|0|ealthy pinto beans. fluffily regular requests along the special sheaves wake | 481|30352|F|201254.08|1992-10-08|2-HIGH|Clerk#000000230|0|ly final ideas. packages haggle fluffily| 482|125059|O|182312.78|1996-03-26|1-URGENT|Clerk#000000295|0|ts. deposits wake: final acco| 483|34820|O|70146.28|1995-07-11|2-HIGH|Clerk#000000025|0|cross the carefully final e| 484|54244|O|327889.57|1997-01-03|3-MEDIUM|Clerk#000000545|0|grouches use. furiously bold accounts maintain. bold, regular deposits| 485|100561|O|192867.30|1997-03-26|2-HIGH|Clerk#000000105|0| regular ideas nag thinly furiously s| 486|50861|O|284644.07|1996-03-11|4-NOT SPECIFIED|Clerk#000000803|0|riously dolphins. fluffily ironic requ| 487|107825|F|90657.45|1992-08-18|1-URGENT|Clerk#000000086|0|ithely unusual courts eat accordi| 512|63022|P|194834.40|1995-05-20|5-LOW|Clerk#000000814|0|ding requests. carefully express theodolites was quickly. furious| 513|60569|O|105559.70|1995-05-01|2-HIGH|Clerk#000000522|0|regular packages. pinto beans cajole carefully against the even| 514|74872|O|154735.68|1996-04-04|2-HIGH|Clerk#000000094|0| cajole furiously. slyly final excuses cajole. slyly special instructions | 515|141829|F|244660.33|1993-08-29|4-NOT SPECIFIED|Clerk#000000700|0|eposits are furiously furiously silent pinto beans. pending pack| 516|43903|O|21920.56|1998-04-21|2-HIGH|Clerk#000000305|0|lar, unusual platelets are carefully. even courts sleep bold, final pinto bea| 517|9220|O|121396.01|1997-04-07|5-LOW|Clerk#000000359|0|slyly pending deposits cajole quickly packages. 
furiou| ================================================ FILE: test_data/put_get_1.txt ================================================ 1,2014-01-02,2014-01-02 11:30:21,2014-01-02 11:30:22,2014-01-02 11:30:23,2014-01-02T11:30:24-07:00,8.765,9.876 2,2014-02-02,2014-02-02 11:30:21,2014-02-02 11:30:22,2014-02-02 11:30:23,2014-02-02T11:30:24+02:00,8.764,9.875 3,2014-03-02,2014-03-02 11:30:21,2014-03-02 11:30:22,2014-03-02 11:30:23,2014-03-02T11:30:24Z,8.763,9.874 ================================================ FILE: test_data/snowflake/session/token ================================================ mock_token123456 ================================================ FILE: test_data/wiremock/mappings/auth/external_browser/parallel_login_first_fails_then_successful_flow.json ================================================ { "mappings": [ { "scenarioName": "External browser parallel login first fails then successful flow", "requiredScenarioState": "Started", "newScenarioState": "First request failed", "request": { "urlPathPattern": "/session/authenticator-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser" } }, "ignoreExtraElements" : true }, { "matchesJsonPath": { "expression": "$.data.TOKEN", "absent": "(absent)" } } ] }, "response": { "status": 200, "jsonBody": { "code": null, "message": "auth failed", "success": false }, "fixedDelayMilliseconds": 2000 } }, { "scenarioName": "External browser parallel login first fails then successful flow", "requiredScenarioState": "First request failed", "newScenarioState": "Second request successful", "request": { "urlPathPattern": "/session/authenticator-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser" } }, "ignoreExtraElements" : true }, { "matchesJsonPath": { "expression": "$.data.TOKEN", "absent": "(absent)" } } ] }, "response": { "status": 200, "jsonBody": { "data": { "ssoUrl": "http://localhost:{{ jsonPath request.body 
'$.data.BROWSER_MODE_REDIRECT_PORT' }}?token=test-saml-token", "proofKey": "test-proof-key" }, "code": null, "message": null, "success": true }, "transformers": ["response-template"], "fixedDelayMilliseconds": 2000 } }, { "scenarioName": "External browser parallel login first fails then successful flow", "requiredScenarioState": "Second request successful", "newScenarioState": "Login request with ID token required", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "TOKEN": "test-saml-token" } }, "ignoreExtraElements" : true } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": "test-id-token", "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } }, { "scenarioName": "External browser parallel login first fails then successful flow", "requiredScenarioState": "Login request with ID token required", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "TOKEN": "test-id-token" } }, "ignoreExtraElements" : true } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": 
"TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/external_browser/parallel_login_successful_flow.json ================================================ { "mappings": [ { "scenarioName": "External browser parallel login successful flow", "requiredScenarioState": "Started", "newScenarioState": "Login request with SAML token required", "request": { "urlPathPattern": "/session/authenticator-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser" } }, "ignoreExtraElements" : true }, { "matchesJsonPath": { "expression": "$.data.TOKEN", "absent": "(absent)" } } ] }, "response": { "status": 200, "jsonBody": { "data": { "ssoUrl": "http://localhost:{{ jsonPath request.body '$.data.BROWSER_MODE_REDIRECT_PORT' }}?token=test-saml-token", "proofKey": "test-proof-key" }, "code": null, "message": null, "success": true }, "transformers": ["response-template"], "fixedDelayMilliseconds": 2000 } }, { "scenarioName": "External browser parallel login successful flow", "requiredScenarioState": "Login request with SAML token required", "newScenarioState": "Login request with ID token required", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "TOKEN": "test-saml-token" } }, "ignoreExtraElements" : true } ] }, 
"response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": "test-id-token", "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } }, { "scenarioName": "External browser parallel login successful flow", "requiredScenarioState": "Login request with ID token required", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "TOKEN": "test-id-token" } }, "ignoreExtraElements" : true } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } } ] } ================================================ FILE: 
test_data/wiremock/mappings/auth/external_browser/successful_flow.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/session/authenticator-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser" } }, "ignoreExtraElements" : true }, { "matchesJsonPath": { "expression": "$.data.TOKEN", "absent": "(absent)" } } ] }, "response": { "status": 200, "jsonBody": { "data": { "ssoUrl": "http://localhost:{{ jsonPath request.body '$.data.BROWSER_MODE_REDIRECT_PORT' }}?token=test-token", "proofKey": "test-proof-key" }, "code": null, "message": null, "success": true }, "transformers": ["response-template"], "fixedDelayMilliseconds": 2000 } }, { "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "TOKEN": "test-token" } }, "ignoreExtraElements" : true } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": "test-id-token", "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } }, { "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "TOKEN": "test-id-token" } }, "ignoreExtraElements" : true } ] }, "response": { "status": 
200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/mfa/parallel_login_first_fails_then_successful_flow.json ================================================ { "mappings": [ { "scenarioName": "MFA Authentication Flow", "requiredScenarioState": "Started", "newScenarioState": "MFA first attempt failed", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "PASSWORD": "testPassword" } }, "ignoreExtraElements" : true }, { "matchesJsonPath": { "expression": "$.data.TOKEN", "absent": "(absent)" } } ] }, "response": { "status": 200, "jsonBody": { "code": "394508", "data": { "authnMethod": "USERNAME_PASSWORD", "loginName": "testUser", "nextAction": "RETRY_LOGIN", "requestId": "8239b728-24d5-4d1b-5af6-593402a1cea2", "signInOptions": {} }, "headers": null, "message": "Failed to authenticate: MFA with TOTP is required. 
To authenticate, provide both your password and a current TOTP passcode.", "success": false }, "fixedDelayMilliseconds": 2000 } }, { "scenarioName": "MFA Authentication Flow", "requiredScenarioState": "MFA first attempt failed", "newScenarioState": "MFA token required", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "PASSWORD": "testPassword" } }, "ignoreExtraElements" : true }, { "matchesJsonPath": { "expression": "$.data.TOKEN", "absent": "(absent)" } } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": "mfa-token", "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true }, "fixedDelayMilliseconds": 2000 } }, { "scenarioName": "MFA Authentication Flow", "requiredScenarioState": "MFA token required", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "PASSWORD": "testPassword", "TOKEN": "mfa-token" } }, "ignoreExtraElements" : true } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": 
false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/mfa/parallel_login_successful_flow.json ================================================ { "mappings": [ { "scenarioName": "MFA Authentication Flow", "requiredScenarioState": "Started", "newScenarioState": "MFA token required", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "PASSWORD": "testPassword" } }, "ignoreExtraElements" : true }, { "matchesJsonPath": { "expression": "$.data.TOKEN", "absent": "(absent)" } } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": "mfa-token", "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true }, "fixedDelayMilliseconds": 2000 } }, { "scenarioName": "MFA Authentication 
Flow", "requiredScenarioState": "MFA token required", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "PASSWORD": "testPassword", "TOKEN": "mfa-token" } }, "ignoreExtraElements" : true } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/oauth2/authorization_code/error_from_idp.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/oauth/authorize", "queryParameters": { "response_type": { "equalTo": "code" }, "scope": { "equalTo": "session:role:ANALYST" }, "code_challenge_method": { "equalTo": "S256" }, "redirect_uri": { "equalTo": "http://localhost:1234/snowflake/oauth-redirect" }, "code_challenge": { "matches": ".+" }, "state": { "matches": "testState|invalidState" }, "client_id": { "equalTo": "testClientId" } }, "method": "GET" }, "response": { "status": 302, "headers": { "Location": "http://localhost:1234/snowflake/oauth-redirect?error=some+error&error_description=some+error+desc" } } } ] } ================================================ FILE: 
test_data/wiremock/mappings/auth/oauth2/authorization_code/invalid_code.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/oauth/authorize", "queryParameters": { "response_type": { "equalTo": "code" }, "scope": { "equalTo": "session:role:ANALYST" }, "code_challenge_method": { "equalTo": "S256" }, "redirect_uri": { "equalTo": "http://localhost:1234/snowflake/oauth-redirect" }, "code_challenge": { "matches": ".+" }, "state": { "matches": "testState" }, "client_id": { "equalTo": "testClientId" } }, "method": "GET" }, "response": { "status": 302, "headers": { "Location": "http://localhost:1234/snowflake/oauth-redirect?code=testCode&state=testState" } } }, { "scenarioName": "Successful token exchange", "request": { "urlPathPattern": "/oauth/token", "method": "POST", "headers": { "Content-Type": { "contains": "application/x-www-form-urlencoded" }, "Authorization": { "equalTo": "Basic dGVzdENsaWVudElkOnRlc3RDbGllbnRTZWNyZXQ=" } }, "formParameters": { "grant_type": { "equalTo": "authorization_code" }, "code_verifier": { "matches": "[a-zA-Z0-9\\-_]+" }, "code": { "equalTo": "testCode" }, "redirect_uri": { "equalTo": "http://localhost:1234/snowflake/oauth-redirect" } } }, "response": { "status": 400, "jsonBody": { "error" : "invalid_grant", "error_description" : "The authorization code is invalid or has expired." 
} } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/oauth2/authorization_code/successful_flow.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/oauth/authorize", "queryParameters": { "response_type": { "equalTo": "code" }, "scope": { "equalTo": "session:role:ANALYST" }, "code_challenge_method": { "equalTo": "S256" }, "redirect_uri": { "matches": "http:.+" }, "code_challenge": { "matches": "JZpN_-zfNduuWm-zUo-D-m7vMw_pgUGv8wGDGqBR8PM" }, "state": { "matches": "testState|invalidState" }, "client_id": { "equalTo": "testClientId" } }, "method": "GET" }, "response": { "status": 302, "headers": { "Location": "{{ request.query.redirect_uri }}?code=testCode&state=testState" }, "transformers": ["response-template"] } }, { "request": { "urlPathPattern": "/oauth/token", "method": "POST", "headers": { "Content-Type": { "contains": "application/x-www-form-urlencoded" }, "Authorization": { "equalTo": "Basic dGVzdENsaWVudElkOnRlc3RDbGllbnRTZWNyZXQ=" } }, "formParameters": { "grant_type": { "equalTo": "authorization_code" }, "code_verifier": { "matches": "testCodeVerifier" }, "code": { "equalTo": "testCode" }, "redirect_uri": { "matches": "http://(127.0.0.1|localhost):[0-9]+.*" } } }, "response": { "status": 200, "jsonBody": { "access_token": "access-token-123", "token_type": "Bearer", "username": "test-user", "scope": "refresh_token session:role:ANALYST", "expires_in": 600, "refresh_token_expires_in": 86399, "idpInitiated": false } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/oauth2/authorization_code/successful_flow_with_offline_access.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/oauth/authorize", "queryParameters": { "response_type": { "equalTo": "code" }, "scope": { "equalTo": "session:role:ANALYST offline_access" }, "code_challenge_method": { "equalTo": "S256" }, 
"redirect_uri": { "equalTo": "http://localhost:1234/snowflake/oauth-redirect" }, "code_challenge": { "matches": "JZpN_-zfNduuWm-zUo-D-m7vMw_pgUGv8wGDGqBR8PM" }, "state": { "matches": "testState|invalidState" }, "client_id": { "equalTo": "testClientId" } }, "method": "GET" }, "response": { "status": 302, "headers": { "Location": "http://localhost:1234/snowflake/oauth-redirect?code=testCode&state=testState" } } }, { "request": { "urlPathPattern": "/oauth/token", "method": "POST", "headers": { "Content-Type": { "contains": "application/x-www-form-urlencoded" }, "Authorization": { "equalTo": "Basic dGVzdENsaWVudElkOnRlc3RDbGllbnRTZWNyZXQ=" } }, "formParameters": { "grant_type": { "equalTo": "authorization_code" }, "code_verifier": { "matches": "testCodeVerifier" }, "code": { "equalTo": "testCode" }, "redirect_uri": { "equalTo": "http://localhost:1234/snowflake/oauth-redirect" } } }, "response": { "status": 200, "jsonBody": { "access_token": "access-token-123", "refresh_token": "refresh-token-123", "token_type": "Bearer", "username": "test-user", "scope": "refresh_token session:role:ANALYST", "expires_in": 600, "refresh_token_expires_in": 86399, "idpInitiated": false } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/oauth2/authorization_code/successful_flow_with_single_use_refresh_token.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/oauth/authorize", "queryParameters": { "response_type": { "equalTo": "code" }, "scope": { "equalTo": "session:role:ANALYST" }, "code_challenge_method": { "equalTo": "S256" }, "redirect_uri": { "equalTo": "http://localhost:1234/snowflake/oauth-redirect" }, "code_challenge": { "matches": "JZpN_-zfNduuWm-zUo-D-m7vMw_pgUGv8wGDGqBR8PM" }, "state": { "matches": "testState|invalidState" }, "client_id": { "equalTo": "testClientId" } }, "method": "GET" }, "response": { "status": 302, "headers": { "Location": 
"http://localhost:1234/snowflake/oauth-redirect?code=testCode&state=testState" } } }, { "request": { "urlPathPattern": "/oauth/token", "method": "POST", "headers": { "Content-Type": { "contains": "application/x-www-form-urlencoded" }, "Authorization": { "equalTo": "Basic dGVzdENsaWVudElkOnRlc3RDbGllbnRTZWNyZXQ=" } }, "formParameters": { "grant_type": { "equalTo": "authorization_code" }, "code_verifier": { "matches": "testCodeVerifier" }, "code": { "equalTo": "testCode" }, "redirect_uri": { "equalTo": "http://localhost:1234/snowflake/oauth-redirect" }, "enable_single_use_refresh_tokens": { "equalTo": "true" } } }, "response": { "status": 200, "jsonBody": { "access_token": "access-token-123", "token_type": "Bearer", "username": "test-user", "scope": "refresh_token session:role:ANALYST", "expires_in": 600, "refresh_token_expires_in": 86399, "idpInitiated": false } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/oauth2/client_credentials/invalid_client.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/oauth/token", "method": "POST", "headers": { "Content-Type": { "contains": "application/x-www-form-urlencoded" } } }, "response": { "status": 401, "jsonBody": { "error": "invalid_client", "error_description": "The client secret supplied for a confidential client is invalid." 
} } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/oauth2/client_credentials/successful_flow.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/oauth/token", "method": "POST", "headers": { "Content-Type": { "contains": "application/x-www-form-urlencoded" }, "Authorization": { "equalTo": "Basic dGVzdENsaWVudElkOnRlc3RDbGllbnRTZWNyZXQ=" } }, "formParameters": { "grant_type": { "equalTo": "client_credentials" }, "scope": { "equalTo": "session:role:ANALYST" } } }, "response": { "status": 200, "jsonBody": { "access_token": "access-token-123", "refresh_token": "123", "token_type": "Bearer", "username": "user", "scope": "refresh_token session:role:ANALYST", "expires_in": 600, "refresh_token_expires_in": 86399, "idpInitiated": false } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/oauth2/login_request.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson": { "data": { "TOKEN": "access-token-123" } }, "ignoreExtraElements": true } ] }, "response": { "status": 200, "headers": { "Content-Type": "application/json" }, "jsonBody": { "code": null, "data": { "token": "session token" }, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/oauth2/login_request_with_expired_access_token.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson": { "data": { "TOKEN": "expired-token" } }, "ignoreExtraElements": true } ] }, "response": { "status": 200, "headers": { "Content-Type": "application/json" }, "jsonBody": { "code": "390303", "data": { "authnMethod": "OAUTH", "nextAction": "RETRY_LOGIN", "requestId": 
"89c7289e-b984-4038-565b-dda3d96dcef3", "signInOptions": {} }, "headers": null, "message": "Invalid OAuth access token. ", "success": false } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/oauth2/refresh_token/invalid_refresh_token.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/oauth/token", "method": "POST", "headers": { "Content-Type": { "contains": "application/x-www-form-urlencoded" }, "Authorization": { "equalTo": "Basic dGVzdENsaWVudElkOnRlc3RDbGllbnRTZWNyZXQ=" } }, "formParameters": { "scope": { "equalTo": "session:role:ANALYST offline_access" }, "grant_type": { "equalTo": "refresh_token" }, "refresh_token": { "equalTo": "expired-refresh-token" } } }, "response": { "status": 400, "jsonBody": { "error" : "invalid_grant", "error_description" : "The authorization code is invalid or has expired." } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/oauth2/refresh_token/successful_flow.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/oauth/token", "method": "POST", "headers": { "Content-Type": { "contains": "application/x-www-form-urlencoded" }, "Authorization": { "equalTo": "Basic dGVzdENsaWVudElkOnRlc3RDbGllbnRTZWNyZXQ=" } }, "formParameters": { "scope": { "equalTo": "session:role:ANALYST offline_access" }, "grant_type": { "equalTo": "refresh_token" }, "refresh_token": { "equalTo": "refresh-token-123" } } }, "response": { "status": 200, "jsonBody": { "access_token": "access-token-123", "refresh_token": "refresh-token-123a", "token_type": "Bearer", "username": "test-user", "scope": "session:role:ANALYST offline_access", "expires_in": 600, "refresh_token_expires_in": 86399, "idpInitiated": false } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/oauth2/refresh_token/successful_flow_without_new_refresh_token.json 
================================================ { "mappings": [ { "request": { "urlPathPattern": "/oauth/token", "method": "POST", "headers": { "Content-Type": { "contains": "application/x-www-form-urlencoded" }, "Authorization": { "equalTo": "Basic dGVzdENsaWVudElkOnRlc3RDbGllbnRTZWNyZXQ=" } }, "formParameters": { "scope": { "equalTo": "session:role:ANALYST offline_access" }, "grant_type": { "equalTo": "refresh_token" }, "refresh_token": { "equalTo": "refresh-token-123" } } }, "response": { "status": 200, "jsonBody": { "access_token": "access-token-123", "token_type": "Bearer", "username": "test-user", "scope": "session:role:ANALYST offline_access", "expires_in": 600, "refresh_token_expires_in": 86399, "idpInitiated": false } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/password/invalid_host.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST" }, "response": { "status": 403, "jsonBody": { "data": null, "code": "390144", "message": "Invalid account name or host.", "success": false } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/password/invalid_password.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "PASSWORD": "INVALID_PASSWORD" } }, "ignoreExtraElements" : true } ] }, "response": { "status": 200, "jsonBody": { "data": null, "code": "390100", "message": "Incorrect username or password was specified.", "success": false } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/password/invalid_user.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" 
: { "data": { "LOGIN_NAME": "bogus" } }, "ignoreExtraElements" : true } ] }, "response": { "status": 200, "jsonBody": { "data": null, "code": "390422", "message": "Incorrect username or password was specified.", "success": false } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/password/successful_flow.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "PASSWORD": "testPassword" } }, "ignoreExtraElements" : true } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/password/successful_flow_with_telemetry.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "PASSWORD": "testPassword" } }, "ignoreExtraElements" : true } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, 
"masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 }, { "name": "CLIENT_TELEMETRY_ENABLED", "value": %CLIENT_TELEMETRY_ENABLED% } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/pat/invalid_token.json ================================================ { "mappings": [ { "scenarioName": "Successful PAT authentication flow", "requiredScenarioState": "Started", "newScenarioState": "Authenticated", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "AUTHENTICATOR": "PROGRAMMATIC_ACCESS_TOKEN", "TOKEN": "some PAT" } }, "ignoreExtraElements" : true } ] }, "response": { "status": 200, "jsonBody": { "data": { "nextAction": "RETRY_LOGIN", "authnMethod": "PAT", "signInOptions": {} }, "code": "394400", "message": "Programmatic access token is invalid.", "success": false, "headers": null } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/pat/reading_fresh_token.json ================================================ { "mappings": [ { "scenarioName": "Successful PAT authentication flow", "requiredScenarioState": "Started", "newScenarioState": "Second authentication", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": 
"testUser", "AUTHENTICATOR": "PROGRAMMATIC_ACCESS_TOKEN", "TOKEN": "some PAT" } }, "ignoreExtraElements" : true }, { "matchesJsonPath": { "expression": "$.data.PASSWORD", "absent": "(absent)" } } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "OAUTH_TEST_AUTH_CODE", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DHEYMAN", "schemaName": "TEST_JDBC", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } }, { "scenarioName": "Successful PAT authentication flow", "requiredScenarioState": "Second authentication", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "AUTHENTICATOR": "PROGRAMMATIC_ACCESS_TOKEN", "TOKEN": "some PAT 2" } }, "ignoreExtraElements" : true }, { "matchesJsonPath": { "expression": "$.data.PASSWORD", "absent": "(absent)" } } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "OAUTH_TEST_AUTH_CODE", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DHEYMAN", 
"schemaName": "TEST_JDBC", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/pat/successful_flow.json ================================================ { "mappings": [ { "scenarioName": "Successful PAT authentication flow", "requiredScenarioState": "Started", "newScenarioState": "Authenticated", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "AUTHENTICATOR": "PROGRAMMATIC_ACCESS_TOKEN", "TOKEN": "some PAT" } }, "ignoreExtraElements" : true }, { "matchesJsonPath": { "expression": "$.data.PASSWORD", "absent": "(absent)" } } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "OAUTH_TEST_AUTH_CODE", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DHEYMAN", "schemaName": "TEST_JDBC", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/azure/http_error.json ================================================ { "mappings": [ { "request": { "urlPattern": "/metadata/identity/oauth2/token.*", "queryParameters": { "api-version": { 
"equalTo": "2018-02-01" }, "resource": { "equalTo": "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" } }, "method": "GET", "headers": { "Metadata": { "equalTo": "true" } } }, "response": { "status": 400 } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/azure/missing_issuer_claim.json ================================================ { "mappings": [ { "request": { "urlPattern": "/metadata/identity/oauth2/token.*", "queryParameters": { "api-version": { "equalTo": "2018-02-01" }, "resource": { "equalTo": "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" } }, "method": "GET", "headers": { "Metadata": { "equalTo": "true" } } }, "response": { "status": 200, "jsonBody": { "access_token": "eyJ0eXAiOiJhdCtqd3QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk0ZGI4N2NiMjdmNjdjZDA1Zjk5OTlkZjMwNjg1NmQ4In0.eyJhdWQiOiJhcGkxIiwic3ViIjoiNzcyMTNFMzAtRThDQi00NTk1LUIxQjYtNUYwNTBFODMwOEZEIiwiZXhwIjoxNzQ0NzE2MDUxLCJpYXQiOjE3NDQ3MTI0NTEsImp0aSI6Ijg3MTMzNzcwMDk0MTZmYmFhNDM0MmFkMjMxZGUwMDBkIn0.xv_rY9IUnnoC0SeBsoXbF2UZo5wmeYNuumLJuTa7cwq0P6OHa2R5DkrHVMu4Zgz3eipQ_O9wln66BQPr_VG1iQ" } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/azure/missing_sub_claim.json ================================================ { "mappings": [ { "request": { "urlPattern": "/metadata/identity/oauth2/token.*", "queryParameters": { "api-version": { "equalTo": "2018-02-01" }, "resource": { "equalTo": "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" } }, "method": "GET", "headers": { "Metadata": { "equalTo": "true" } } }, "response": { "status": 200, "jsonBody": { "access_token": 
"eyJ0eXAiOiJhdCtqd3QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk0ZGI4N2NiMjdmNjdjZDA1Zjk5OTlkZjMwNjg1NmQ4In0.eyJhdWQiOiJhcGkxIiwiaXNzIjoiaHR0cHM6Ly9zdHMud2luZG93cy5uZXQvZmExNWQ2OTItZTljNy00NDYwLWE3NDMtMjlmMjk1MjIyMjkvIiwiZXhwIjoxNzQ0NzE2MDUxLCJpYXQiOjE3NDQ3MTI0NTEsImp0aSI6Ijg3MTMzNzcwMDk0MTZmYmFhNDM0MmFkMjMxZGUwMDBkIn0.KfVQlyouRS2EoGZTvzTN77pTviXdyPl27WrC9rPsr9AiTwnsXnOxIj-CDahyeFksWGNuhRcyzN_nI_ewBS7fVw" } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/azure/non_json_response.json ================================================ { "mappings": [ { "request": { "urlPattern": "/metadata/identity/oauth2/token.*", "queryParameters": { "api-version": { "equalTo": "2018-02-01" }, "resource": { "equalTo": "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" } }, "method": "GET", "headers": { "Metadata": { "equalTo": "true" } } }, "response": { "status": 200, "body": "not a JSON format" } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/azure/successful_flow_azure_functions.json ================================================ { "mappings": [ { "request": { "urlPattern": "/metadata/identity/endpoint/from/env.*", "queryParameters": { "api-version": { "equalTo": "2019-08-01" }, "resource": { "equalTo": "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" }, "client_id": { "equalTo": "managed-client-id-from-env" } }, "method": "GET", "headers": { "X-IDENTITY-HEADER": { "equalTo": "some-identity-header-from-env" } } }, "response": { "status": 200, "jsonBody": { "access_token": 
"eyJ0eXAiOiJhdCtqd3QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk0ZGI4N2NiMjdmNjdjZDA1Zjk5OTlkZjMwNjg1NmQ4In0.eyJhdWQiOiJhcGkxIiwiaXNzIjoiaHR0cHM6Ly9zdHMud2luZG93cy5uZXQvZmExNWQ2OTItZTljNy00NDYwLWE3NDMtMjlmMjk1MjIyMjkvIiwic3ViIjoiNzcyMTNFMzAtRThDQi00NTk1LUIxQjYtNUYwNTBFODMwOEZEIiwiZXhwIjoxNzQ0NzE2MDUxLCJpYXQiOjE3NDQ3MTI0NTEsImp0aSI6Ijg3MTMzNzcwMDk0MTZmYmFhNDM0MmFkMjMxZGUwMDBkIn0.C5jTYoybRs5YF5GvPgoDq4WK5U9-gDzh_N3IPaqEBI0IifdYSWpKQ72v3UISnVpp7Fc46C-ZC8kijUGe3IU9zA" } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/azure/successful_flow_azure_functions_custom_entra_resource.json ================================================ { "mappings": [ { "request": { "urlPattern": "/metadata/identity/endpoint/from/env.*", "queryParameters": { "api-version": { "equalTo": "2019-08-01" }, "resource": { "equalTo": "api://1111111-2222-3333-44444-55555555" }, "client_id": { "equalTo": "managed-client-id-from-env" } }, "method": "GET", "headers": { "X-IDENTITY-HEADER": { "equalTo": "some-identity-header-from-env" } } }, "response": { "status": 200, "jsonBody": { "access_token": "eyJ0eXAiOiJhdCtqd3QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk0ZGI4N2NiMjdmNjdjZDA1Zjk5OTlkZjMwNjg1NmQ4In0.eyJhdWQiOiJhcGkxIiwiaXNzIjoiaHR0cHM6Ly9zdHMud2luZG93cy5uZXQvZmExNWQ2OTItZTljNy00NDYwLWE3NDMtMjlmMjk1MjIyMjkvIiwic3ViIjoiNzcyMTNFMzAtRThDQi00NTk1LUIxQjYtNUYwNTBFODMwOEZEIiwiZXhwIjoxNzQ0NzE2MDUxLCJpYXQiOjE3NDQ3MTI0NTEsImp0aSI6Ijg3MTMzNzcwMDk0MTZmYmFhNDM0MmFkMjMxZGUwMDBkIn0.C5jTYoybRs5YF5GvPgoDq4WK5U9-gDzh_N3IPaqEBI0IifdYSWpKQ72v3UISnVpp7Fc46C-ZC8kijUGe3IU9zA" } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/azure/successful_flow_azure_functions_no_client_id.json ================================================ { "mappings": [ { "request": { "urlPattern": "/metadata/identity/endpoint/from/env.*", "queryParameters": { "api-version": { "equalTo": "2019-08-01" }, "resource": { "equalTo": 
"api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" } }, "method": "GET", "headers": { "X-IDENTITY-HEADER": { "equalTo": "some-identity-header-from-env" } } }, "response": { "status": 200, "jsonBody": { "access_token": "eyJ0eXAiOiJhdCtqd3QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk0ZGI4N2NiMjdmNjdjZDA1Zjk5OTlkZjMwNjg1NmQ4In0.eyJhdWQiOiJhcGkxIiwiaXNzIjoiaHR0cHM6Ly9zdHMud2luZG93cy5uZXQvZmExNWQ2OTItZTljNy00NDYwLWE3NDMtMjlmMjk1MjIyMjkvIiwic3ViIjoiNzcyMTNFMzAtRThDQi00NTk1LUIxQjYtNUYwNTBFODMwOEZEIiwiZXhwIjoxNzQ0NzE2MDUxLCJpYXQiOjE3NDQ3MTI0NTEsImp0aSI6Ijg3MTMzNzcwMDk0MTZmYmFhNDM0MmFkMjMxZGUwMDBkIn0.C5jTYoybRs5YF5GvPgoDq4WK5U9-gDzh_N3IPaqEBI0IifdYSWpKQ72v3UISnVpp7Fc46C-ZC8kijUGe3IU9zA" } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/azure/successful_flow_azure_functions_v2_issuer.json ================================================ { "mappings": [ { "request": { "urlPattern": "/metadata/identity/endpoint/from/env.*", "queryParameters": { "api-version": { "equalTo": "2019-08-01" }, "resource": { "equalTo": "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" }, "client_id": { "equalTo": "managed-client-id-from-env" } }, "method": "GET", "headers": { "X-IDENTITY-HEADER": { "equalTo": "some-identity-header-from-env" } } }, "response": { "status": 200, "jsonBody": { "access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJhdWQiOiJhcGk6Ly9mZDNmNzUzYi1lZWQzLTQ2MmMtYjZhNy1hNGI1YmI2NTBhYWQiLCJleHAiOjE3NDQ3MTYwNTEsImlhdCI6MTc0NDcxMjQ1MSwiaXNzIjoiaHR0cHM6Ly9sb2dpbi5taWNyb3NvZnRvbmxpbmUuY29tL2ZhMTVkNjkyLWU5YzctNDQ2MC1hNzQzLTI5ZjI5NTIyMjI5LyIsImp0aSI6Ijg3MTMzNzcwMDk0MTZmYmFhNDM0MmFkMjMxZGUwMDBkIiwic3ViIjoiNzcyMTNFMzAtRThDQi00NTk1LUIxQjYtNUYwNTBFODMwOEZEIn0.5mAlEPkzHLR7YbllpKgk-8ZEd88XfzA15DUK8u1rLWs" } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/azure/successful_flow_basic.json ================================================ { "mappings": [ { "request": { "urlPattern": 
"/metadata/identity/oauth2/token.*", "queryParameters": { "api-version": { "equalTo": "2018-02-01" }, "resource": { "equalTo": "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" } }, "method": "GET", "headers": { "Metadata": { "equalTo": "true" } } }, "response": { "status": 200, "jsonBody": { "access_token": "eyJ0eXAiOiJhdCtqd3QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk0ZGI4N2NiMjdmNjdjZDA1Zjk5OTlkZjMwNjg1NmQ4In0.eyJhdWQiOiJhcGkxIiwiaXNzIjoiaHR0cHM6Ly9zdHMud2luZG93cy5uZXQvZmExNWQ2OTItZTljNy00NDYwLWE3NDMtMjlmMjk1MjIyMjkvIiwic3ViIjoiNzcyMTNFMzAtRThDQi00NTk1LUIxQjYtNUYwNTBFODMwOEZEIiwiZXhwIjoxNzQ0NzE2MDUxLCJpYXQiOjE3NDQ3MTI0NTEsImp0aSI6Ijg3MTMzNzcwMDk0MTZmYmFhNDM0MmFkMjMxZGUwMDBkIn0.C5jTYoybRs5YF5GvPgoDq4WK5U9-gDzh_N3IPaqEBI0IifdYSWpKQ72v3UISnVpp7Fc46C-ZC8kijUGe3IU9zA" } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/azure/successful_flow_v2_issuer.json ================================================ { "mappings": [ { "request": { "urlPattern": "/metadata/identity/oauth2/token.*", "queryParameters": { "api-version": { "equalTo": "2018-02-01" }, "resource": { "equalTo": "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" } }, "method": "GET", "headers": { "Metadata": { "equalTo": "true" } } }, "response": { "status": 200, "jsonBody": { "access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJhdWQiOiJhcGk6Ly9mZDNmNzUzYi1lZWQzLTQ2MmMtYjZhNy1hNGI1YmI2NTBhYWQiLCJleHAiOjE3NDQ3MTYwNTEsImlhdCI6MTc0NDcxMjQ1MSwiaXNzIjoiaHR0cHM6Ly9sb2dpbi5taWNyb3NvZnRvbmxpbmUuY29tL2ZhMTVkNjkyLWU5YzctNDQ2MC1hNzQzLTI5ZjI5NTIyMjI5LyIsImp0aSI6Ijg3MTMzNzcwMDk0MTZmYmFhNDM0MmFkMjMxZGUwMDBkIiwic3ViIjoiNzcyMTNFMzAtRThDQi00NTk1LUIxQjYtNUYwNTBFODMwOEZEIn0.5mAlEPkzHLR7YbllpKgk-8ZEd88XfzA15DUK8u1rLWs" } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/azure/unparsable_token.json ================================================ { "mappings": [ { "request": { "urlPattern": "/metadata/identity/oauth2/token.*", "queryParameters": 
{ "api-version": { "equalTo": "2018-02-01" }, "resource": { "equalTo": "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" } }, "method": "GET", "headers": { "Metadata": { "equalTo": "true" } } }, "response": { "status": 200, "jsonBody": { "access_token": "unparsable.token" } } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/gcp/http_error.json ================================================ { "mappings": [ { "request": { "urlPattern": "/computeMetadata/v1/instance/service-accounts/default/identity.*", "queryParameters": { "audience": { "equalTo": "snowflakecomputing.com" } }, "method": "GET", "headers": { "Metadata-Flavor": { "equalTo": "Google" } } }, "response": { "status": 400 } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/gcp/missing_issuer_claim.json ================================================ { "mappings": [ { "request": { "urlPattern": "/computeMetadata/v1/instance/service-accounts/default/identity.*", "queryParameters": { "audience": { "equalTo": "snowflakecomputing.com" } }, "method": "GET", "headers": { "Metadata-Flavor": { "equalTo": "Google" } } }, "response": { "status": 200, "body": "eyJ0eXAiOiJhdCtqd3QiLCJhbGciOiJFUzI1NiIsImtpZCI6ImU2M2I5NzA1OTRiY2NmZTAxMDlkOTg4OWM2MDk3OWEwIn0.eyJzdWIiOiJzb21lLXN1YmplY3QiLCJpYXQiOjE3NDM3NjEyMTMsImV4cCI6MTc0Mzc2NDgxMywiYXVkIjoid3d3LmV4YW1wbGUuY29tIn0.H6sN6kjA82EuijFcv-yCJTqau5qvVTCsk0ZQ4gvFQMkB7c71XPs4lkwTa7ZlNNlx9e6TpN1CVGnpCIRDDAZaDw" } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/gcp/missing_sub_claim.json ================================================ { "mappings": [ { "request": { "urlPattern": "/computeMetadata/v1/instance/service-accounts/default/identity.*", "queryParameters": { "audience": { "equalTo": "snowflakecomputing.com" } }, "method": "GET", "headers": { "Metadata-Flavor": { "equalTo": "Google" } } }, "response": { "status": 200, "body": 
"eyJ0eXAiOiJhdCtqd3QiLCJhbGciOiJFUzI1NiIsImtpZCI6ImU2M2I5NzA1OTRiY2NmZTAxMDlkOTg4OWM2MDk3OWEwIn0.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJpYXQiOjE3NDM3NjEyMTMsImV4cCI6MTc0Mzc2NDgxMywiYXVkIjoid3d3LmV4YW1wbGUuY29tIn0.w0njdpfWFETVK8Ktq9GdvuKRQJjvhOplcSyvQ_zHHwBUSMapqO1bjEWBx5VhGkdECZIGS1VY7db_IOqT45yOMA" } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/gcp/successful_flow.json ================================================ { "mappings": [ { "request": { "urlPattern": "/computeMetadata/v1/instance/service-accounts/default/identity.*", "queryParameters": { "audience": { "equalTo": "snowflakecomputing.com" } }, "method": "GET", "headers": { "Metadata-Flavor": { "equalTo": "Google" } } }, "response": { "status": 200, "body": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJpYXQiOjE3NDM2OTIwMTcsImV4cCI6MTc3NTIyODAxNCwiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoic29tZS1zdWJqZWN0In0.k7018udXQjw-sgVY8sTLTnNrnJoGwVpjE6HozZN-h0w" } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/gcp/successful_impersionation_flow.json ================================================ { "mappings": [ { "request": { "urlPattern": "/computeMetadata/v1/instance/service-accounts/default/token", "method": "GET", "headers": { "Metadata-Flavor": { "equalTo": "Google" } } }, "response": { "status": 200, "jsonBody": {"access_token":"randomToken123","expires_in":3599,"token_type":"Bearer"} } }, { "request": { "urlPattern": "/v1/projects/-/serviceAccounts/targetServiceAccount:generateIdToken", "method": "POST", "bodyPatterns": [ { "matchesJsonPath": { "expression": "$.delegates", "equalToJson": "[\"projects/-/serviceAccounts/delegate1\", \"projects/-/serviceAccounts/delegate2\"]" } }, { "matchesJsonPath": { "expression": "$.audience", "equalTo": "snowflakecomputing.com" } } ] }, "response": { "status": 200, "jsonBody": {"token": 
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJpYXQiOjE3NDM2OTIwMTcsImV4cCI6MTc3NTIyODAxNCwiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoic29tZS1pbXBlcnNvbmF0ZWQtc3ViamVjdCJ9.5KC0hjxwAheysO-hWCgjBGPUe143-xjytC72epRG8Ks"} } } ] } ================================================ FILE: test_data/wiremock/mappings/auth/wif/gcp/unparsable_token.json ================================================ { "mappings": [ { "request": { "urlPattern": "/computeMetadata/v1/instance/service-accounts/default/identity.*", "queryParameters": { "audience": { "equalTo": "snowflakecomputing.com" } }, "method": "GET", "headers": { "Metadata-Flavor": { "equalTo": "Google" } } }, "response": { "status": 200, "body": "unparsable.token" } } ] } ================================================ FILE: test_data/wiremock/mappings/close_session.json ================================================ { "mappings": [ { "scenarioName": "Successful close session", "request": { "urlPathPattern": "/session", "method": "POST", "queryParameters": { "delete": { "equalTo": "true" } } }, "response": { "status": 200, "jsonBody": { "code": null, "data": null, "message": null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/hang.json ================================================ { "mappings": [ { "request": { "url": "/hang" }, "response": { "status": 200, "fixedDelayMilliseconds": 2000 } } ] } ================================================ FILE: test_data/wiremock/mappings/minicore/auth/disabled_flow.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "matchesJsonPath": "$.data.CLIENT_ENVIRONMENT[?(@.CORE_LOAD_ERROR =~ /.*disabled at compile time.*/)]" }, { "matchesJsonPath": { "expression": "$.data.CLIENT_ENVIRONMENT.CORE_VERSION", "absent": "(absent)" } } ] }, "response": { "status": 200, 
"jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/minicore/auth/successful_flow.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "PASSWORD": "testPassword", "CLIENT_ENVIRONMENT": { "CORE_VERSION": "0.0.1", "CGO_ENABLED": true, "LINKING_MODE": "unknown" } } }, "ignoreExtraElements" : true }, { "matchesJsonPath": "$.data.CLIENT_ENVIRONMENT[?(@.CORE_FILE_NAME =~ /.+/)]" }, { "matchesJsonPath": { "expression": "$.data.CLIENT_ENVIRONMENT.CORE_LOAD_ERROR", "absent": "(absent)" } } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": 
"TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/minicore/auth/successful_flow_linux.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson" : { "data": { "LOGIN_NAME": "testUser", "PASSWORD": "testPassword", "CLIENT_ENVIRONMENT": { "CORE_VERSION": "0.0.1", "CGO_ENABLED": true, "LINKING_MODE": "dynamic" } } }, "ignoreExtraElements" : true }, { "matchesJsonPath": "$.data.CLIENT_ENVIRONMENT[?(@.CORE_FILE_NAME =~ /.+/)]" }, { "matchesJsonPath": { "expression": "$.data.CLIENT_ENVIRONMENT.CORE_LOAD_ERROR", "absent": "(absent)" } }, { "matchesJsonPath": "$.data.CLIENT_ENVIRONMENT[?(@.LIBC_FAMILY =~ /^(glibc|musl)$/)]" }, { "matchesJsonPath": "$.data.CLIENT_ENVIRONMENT[?(@.LIBC_VERSION =~ /\\d+\\.\\d+.*/)]" } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": null, "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true } } } ] } ================================================ FILE: 
test_data/wiremock/mappings/ocsp/auth_failure.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST" }, "response": { "status": 401, "jsonBody": { "data": null, "code": "390100", "message": "Authentication failed for OCSP test", "success": false } } } ] } ================================================ FILE: test_data/wiremock/mappings/ocsp/malformed.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/" }, "response": { "status": 200, "base64Body": "AQID" } } ] } ================================================ FILE: test_data/wiremock/mappings/ocsp/unauthorized.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/", "method": "POST" }, "response": { "status": 200, "base64Body": "MAMKAQY=" } } ] } ================================================ FILE: test_data/wiremock/mappings/platform_detection/aws_ec2_instance_success.json ================================================ { "mappings": [ { "request": { "method": "PUT", "urlPath": "/latest/api/token" }, "response": { "status": 200, "body": "AQAEAEV4aW1hbGVUb2tlbg==", "headers": { "Content-Type": "text/plain" } } }, { "request": { "method": "GET", "urlPath": "/latest/meta-data/iam/security-credentials/" }, "response": { "status": 200, "body": "test-role", "headers": { "Content-Type": "text/plain" } } }, { "request": { "method": "GET", "urlPath": "/latest/meta-data/iam/security-credentials/test-role" }, "response": { "status": 200, "jsonBody": { "Code": "Success", "LastUpdated": "2023-01-01T00:00:00Z", "Type": "AWS-HMAC", "AccessKeyId": "AKIAIOSFODNN7EXAMPLE", "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", "Token": "AQoDYXdzEJr...", "Expiration": "2030-01-01T06:00:00Z" }, "headers": { "Content-Type": "application/json" } } }, { "request": { "method": "GET", "urlPath": 
"/latest/dynamic/instance-identity/document" }, "response": { "status": 200, "jsonBody": { "instanceId": "i-1234567890abcdef0", "imageId": "ami-12345678", "availabilityZone": "us-east-1a", "instanceType": "t2.micro", "accountId": "123456789012", "architecture": "x86_64", "kernelId": null, "ramdiskId": null, "region": "us-east-1", "version": "2017-09-30", "privateIp": "10.0.0.1", "billingProducts": null, "marketplaceProductCodes": null, "pendingTime": "2023-01-01T00:00:00Z", "devpayProductCodes": null }, "headers": { "Content-Type": "application/json" } } } ] } ================================================ FILE: test_data/wiremock/mappings/platform_detection/aws_identity_success.json ================================================ { "mappings": [ { "request": { "method": "PUT", "urlPath": "/latest/api/token" }, "response": { "status": 200, "body": "AQAEAEV4aW1hbGVUb2tlbg==", "headers": { "Content-Type": "text/plain" } } }, { "request": { "method": "GET", "urlPath": "/latest/meta-data/iam/security-credentials/" }, "response": { "status": 200, "body": "test-role", "headers": { "Content-Type": "text/plain" } } }, { "request": { "method": "GET", "urlPath": "/latest/meta-data/iam/security-credentials/test-role" }, "response": { "status": 200, "jsonBody": { "Code": "Success", "LastUpdated": "2023-01-01T00:00:00Z", "Type": "AWS-HMAC", "AccessKeyId": "AKIAIOSFODNN7EXAMPLE", "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", "Token": "AQoDYXdzEJr...", "Expiration": "2030-01-01T06:00:00Z" }, "headers": { "Content-Type": "application/json" } } }, { "request": { "method": "POST", "urlPath": "/", "bodyPatterns": [ { "contains": "Action=GetCallerIdentity" } ] }, "response": { "status": 200, "body": "\n\n \n arn:aws:iam::123456789012:user/test-user\n AIDACKCEVSQ6C2EXAMPLE\n 123456789012\n \n \n 01234567-89ab-cdef-0123-456789abcdef\n \n", "headers": { "Content-Type": "text/xml" } } } ] } ================================================ FILE: 
test_data/wiremock/mappings/platform_detection/azure_managed_identity_success.json ================================================ { "mappings": [ { "request": { "method": "GET", "urlPattern": "/metadata/identity/oauth2/token\\?.*", "headers": { "Metadata": { "equalTo": "true" } } }, "response": { "status": 200, "headers": { "Content-Type": "application/json" }, "jsonBody": { "access_token": "test-token", "token_type": "Bearer", "expires_in": 3600 } } } ] } ================================================ FILE: test_data/wiremock/mappings/platform_detection/azure_vm_success.json ================================================ { "mappings": [ { "request": { "method": "GET", "url": "/metadata/instance?api-version=2019-03-11", "headers": { "Metadata": { "equalTo": "true" } } }, "response": { "status": 200, "headers": { "Content-Type": "application/json" }, "jsonBody": { "compute": { "vmId": "test-vm-id", "name": "test-vm" } } } } ] } ================================================ FILE: test_data/wiremock/mappings/platform_detection/gce_identity_success.json ================================================ { "mappings": [ { "request": { "method": "GET", "url": "/computeMetadata/v1/instance/service-accounts/default/email", "headers": { "Metadata-Flavor": { "equalTo": "Google" } } }, "response": { "status": 200, "headers": { "Content-Type": "text/plain" }, "body": "test-service-account@test-project.iam.gserviceaccount.com" } } ] } ================================================ FILE: test_data/wiremock/mappings/platform_detection/gce_vm_success.json ================================================ { "mappings": [ { "request": { "method": "GET", "url": "/" }, "response": { "status": 200, "headers": { "Metadata-Flavor": "Google", "Content-Type": "text/plain" }, "body": "v1/" } } ] } ================================================ FILE: test_data/wiremock/mappings/platform_detection/timeout_response.json ================================================ { "mappings": [ 
{ "request": { "urlPattern": ".*" }, "response": { "status": 200, "fixedDelayMilliseconds": 1000, "body": "timeout" } } ] } ================================================ FILE: test_data/wiremock/mappings/query/long_running_query.json ================================================ { "mappings": [ { "request": { "urlPathPattern": "/queries/v1/query-request.*", "method": "POST", "headers": { "Authorization": { "matches": ".*" } } }, "response": { "status": 200, "jsonBody": { "code": "333334", "data": { "getResultUrl": "/queries/01bfd516-0009-ae23-0000-4c390101d1aa/result", "progressDesc": null, "queryAbortsAfterSecs": 300, "queryId": "01bfd516-0009-ae23-0000-4c390101d1aa" }, "message": "Asynchronous execution in progress. Use provided query id to perform query monitoring and management.", "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/query/query_by_id_timeout.json ================================================ { "mappings": [ { "scenarioName": "Query status monitoring - RUNNING", "request": { "urlPathPattern": "/queries.*", "method": "GET", "headers": { "Authorization": { "matches": ".*" } } }, "response": { "status": 200, "fixedDelayMilliseconds": 3000 } } ] } ================================================ FILE: test_data/wiremock/mappings/query/query_execution.json ================================================ { "mappings": [ { "scenarioName": "SQL Query execution for fetchResultByQueryID", "request": { "urlPathPattern": "/queries/v1/query-request.*", "method": "POST", "headers": { "Authorization": { "matches": ".*" } } }, "response": { "status": 200, "jsonBody": { "success": true, "data": { "queryId": "mock-query-id-12345", "resultSetMetaData": { "columnCount": 2, "columns": [ {"name": "MS", "type": "number"}, {"name": "SUM(C1)", "type": "number"} ] }, "rowType": [ {"name": "MS", "type": "FIXED", "length": 10, "precision": 38, "scale": 0}, {"name": "SUM(C1)", "type": "FIXED", "length": 10, 
"precision": 38, "scale": 0} ], "rowset": [["1", "5050"], ["2", "5100"]], "total": 2, "queryResultFormat": "json" } } } }, { "scenarioName": "Query result fetching", "request": { "urlPathPattern": "/queries/.*/result.*", "method": "GET", "headers": { "Authorization": { "matches": ".*" } } }, "response": { "status": 200, "jsonBody": { "success": true, "data": { "queryId": "mock-query-id-12345", "resultSetMetaData": { "columnCount": 2, "columns": [ {"name": "MS", "type": "number"}, {"name": "SUM(C1)", "type": "number"} ] }, "rowType": [ {"name": "MS", "type": "FIXED", "length": 10, "precision": 38, "scale": 0}, {"name": "SUM(C1)", "type": "FIXED", "length": 10, "precision": 38, "scale": 0} ], "rowset": [["1", "5050"], ["2", "5100"]], "total": 2, "queryResultFormat": "json" } } } } ] } ================================================ FILE: test_data/wiremock/mappings/query/query_monitoring.json ================================================ { "mappings": [ { "scenarioName": "Query status monitoring - SUCCESS", "request": { "urlPathPattern": "/monitoring/queries.*", "method": "GET", "headers": { "Authorization": { "matches": ".*" } } }, "response": { "status": 200, "jsonBody": { "success": true, "data": { "queries": [ { "id": "mock-query-id-12345", "status": "SUCCESS", "errorCode": "", "errorMessage": "" } ] } } } } ] } ================================================ FILE: test_data/wiremock/mappings/query/query_monitoring_error.json ================================================ { "mappings": [ { "scenarioName": "Query status monitoring - FAILED_WITH_ERROR", "request": { "urlPathPattern": "/monitoring/queries.*", "method": "GET", "headers": { "Authorization": { "matches": ".*" } } }, "response": { "status": 200, "jsonBody": { "success": true, "data": { "queries": [ { "id": "mock-query-id-12345", "status": "FAILED_WITH_ERROR", "errorCode": "", "errorMessage": "" } ] }, "code": null, "message": null } } } ] } ================================================ FILE: 
test_data/wiremock/mappings/query/query_monitoring_malformed.json ================================================ { "mappings": [ { "scenarioName": "Query status monitoring - Malformed JSON", "request": { "urlPathPattern": "/monitoring/queries.*", "method": "GET", "headers": { "Authorization": { "matches": ".*" } } }, "response": { "status": 200, "body": "{\"malformedJson\"}", "headers": { "Content-Type": "application/json" } } } ] } ================================================ FILE: test_data/wiremock/mappings/query/query_monitoring_running.json ================================================ { "mappings": [ { "scenarioName": "Query status monitoring - RUNNING", "request": { "urlPathPattern": "/monitoring/queries.*", "method": "GET", "headers": { "Authorization": { "matches": ".*" } } }, "response": { "status": 200, "jsonBody": { "success": true, "data": { "queries": [ { "id": "mock-query-id-12345", "status": "RUNNING", "state": "FILE_SET_INITIALIZATION", "errorCode": "", "errorMessage": null } ] }, "code": null, "message": null } } } ] } ================================================ FILE: test_data/wiremock/mappings/retry/redirection_retry_workflow.json ================================================ { "mappings": [ { "scenarioName": "wiremock retry strategy", "requiredScenarioState": "Started", "newScenarioState": "Successful login", "request": { "urlPathPattern": "/session/v1/login-request.*", "method": "POST", "bodyPatterns": [ { "equalToJson": { "data": { "LOGIN_NAME": "testUser", "PASSWORD": "testPassword" } }, "ignoreExtraElements": true } ] }, "response": { "status": 200, "jsonBody": { "data": { "masterToken": "master token", "token": "session token", "validityInSeconds": 3600, "masterValidityInSeconds": 14400, "displayUserName": "TEST_USER", "serverVersion": "8.48.0 b2024121104444034239f05", "firstLogin": false, "remMeToken": null, "remMeValidityInSeconds": 0, "healthCheckInterval": 45, "newClientForUpgrade": "3.12.3", "sessionId": 
1172562260498, "parameters": [ { "name": "CLIENT_PREFETCH_THREADS", "value": 4 } ], "sessionInfo": { "databaseName": "TEST_DB", "schemaName": "TEST_GO", "warehouseName": "TEST_XSMALL", "roleName": "ANALYST" }, "idToken": null, "idTokenValidityInSeconds": 0, "responseData": null, "mfaToken": "mfa-token", "mfaTokenValidityInSeconds": 0 }, "code": null, "message": null, "success": true }, "fixedDelayMilliseconds": 2000 } }, { "scenarioName": "wiremock retry strategy", "requiredScenarioState": "Successful login", "newScenarioState": "Query attempt with HTTP 3xx response", "request": { "urlPathPattern": "/queries/v1/query-request.*", "method": "POST" }, "response": { "status": 307, "headers": { "Location": "/temp-redirect-1" } } }, { "scenarioName": "wiremock retry strategy", "requiredScenarioState": "Query attempt with HTTP 3xx response", "newScenarioState": "3xx redirect followed and times out", "request": { "urlPathPattern": "/temp-redirect-1", "method": "POST" }, "response": { "fixedDelayMilliseconds": 5000 } }, { "scenarioName": "wiremock retry strategy", "requiredScenarioState": "3xx redirect followed and times out", "newScenarioState": "Retry attempt successful", "request": { "urlPathPattern": "/queries/v1/query-request.*", "method": "POST" }, "response": { "status": 200, "headers": { "date": "Fri, 31 Oct 2025 06:26:51 GMT", "cache-control": "no-cache, no-store", "content-type": "application/json", "vary": "Accept-Encoding, User-Agent", "server": "SF-LB", "x-envoy-upstream-service-time": "72", "x-content-type-options": "nosniff", "x-xss-protection": "1; mode=block", "expect-ct": "enforce, max-age=3600", "strict-transport-security": "max-age=31536000", "x-snowflake-fe-instance": "-", "x-snowflake-fe-config": "v20251022.0.0-4d0dc170.1761148450.prod1.1761891997993", "x-frame-options": "deny", "x-envoy-attempt-count": "1", "transfer-encoding": "chunked" }, "jsonBody": { "data": { "parameters": [ { "name": "TIMESTAMP_OUTPUT_FORMAT", "value": "YYYY-MM-DD HH24:MI:SS.FF3 
TZHTZM" }, { "name": "CLIENT_PREFETCH_THREADS", "value": 4 }, { "name": "JS_TREAT_INTEGER_AS_BIGINT", "value": false }, { "name": "TIME_OUTPUT_FORMAT", "value": "HH24:MI:SS" }, { "name": "CLIENT_RESULT_CHUNK_SIZE", "value": 160 }, { "name": "TIMESTAMP_TZ_OUTPUT_FORMAT", "value": "" }, { "name": "CLIENT_SESSION_KEEP_ALIVE", "value": false }, { "name": "CLIENT_OUT_OF_BAND_TELEMETRY_ENABLED", "value": false }, { "name": "CLIENT_METADATA_USE_SESSION_DATABASE", "value": false }, { "name": "QUERY_CONTEXT_CACHE_SIZE", "value": 5 }, { "name": "ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1", "value": true }, { "name": "TIMESTAMP_NTZ_OUTPUT_FORMAT", "value": "YYYY-MM-DD HH24:MI:SS.FF3" }, { "name": "CLIENT_RESULT_PREFETCH_THREADS", "value": 1 }, { "name": "CLIENT_METADATA_REQUEST_USE_CONNECTION_CTX", "value": false }, { "name": "CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ", "value": true }, { "name": "CLIENT_MEMORY_LIMIT", "value": 1536 }, { "name": "CLIENT_TIMESTAMP_TYPE_MAPPING", "value": "TIMESTAMP_NTZ" }, { "name": "TIMEZONE", "value": "America/Los_Angeles" }, { "name": "CLIENT_RESULT_PREFETCH_SLOTS", "value": 2 }, { "name": "CLIENT_TELEMETRY_ENABLED", "value": true }, { "name": "CLIENT_DISABLE_INCIDENTS", "value": true }, { "name": "CLIENT_USE_V1_QUERY_API", "value": true }, { "name": "CLIENT_RESULT_COLUMN_CASE_INSENSITIVE", "value": false }, { "name": "BINARY_OUTPUT_FORMAT", "value": "HEX" }, { "name": "CSV_TIMESTAMP_FORMAT", "value": "" }, { "name": "CLIENT_ENABLE_LOG_INFO_STATEMENT_PARAMETERS", "value": false }, { "name": "CLIENT_TELEMETRY_SESSIONLESS_ENABLED", "value": true }, { "name": "JS_DRIVER_DISABLE_OCSP_FOR_NON_SF_ENDPOINTS", "value": false }, { "name": "DATE_OUTPUT_FORMAT", "value": "YYYY-MM-DD" }, { "name": "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD", "value": 65280 }, { "name": "CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY", "value": 3600 }, { "name": "AUTOCOMMIT", "value": true }, { "name": "CLIENT_SESSION_CLONE", "value": false }, { "name": 
"TIMESTAMP_LTZ_OUTPUT_FORMAT", "value": "" } ], "rowtype": [ { "name": "1", "database": "", "schema": "", "table": "", "scale": 0, "nullable": false, "byteLength": null, "precision": 1, "length": null, "type": "fixed", "collation": null } ], "rowset": [ [ "1" ] ], "total": 1, "returned": 1, "queryId": "01c01270-0e12-4b04-0000-53b10b9c95be", "databaseProvider": null, "finalDatabaseName": "WIREMOCKTESTDB", "finalSchemaName": "TESTSCHEMA", "finalWarehouseName": "WIREMOCK_WH", "finalRoleName": "SYSADMIN", "numberOfBinds": 0, "arrayBindSupported": false, "statementTypeId": 4096, "version": 1, "sendResultTime": 1761890916147, "queryResultFormat": "json", "queryContext": { "entries": [ { "id": 0, "timestamp": 1761890916132138, "priority": 0, "context": "CJLYpAI=" } ] } }, "code": null, "message": null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/select1.json ================================================ { "mappings": [ { "scenarioName": "Successful SELECT 1 flow", "request": { "urlPathPattern": "/queries/v1/query-request.*", "method": "POST", "headers": { "Authorization": { "equalTo": "Snowflake Token=\"session token\"" } } }, "response": { "status": 200, "jsonBody": { "data": { "parameters": [ { "name": "TIMESTAMP_OUTPUT_FORMAT", "value": "YYYY-MM-DD HH24:MI:SS.FF3 TZHTZM" }, { "name": "CLIENT_PREFETCH_THREADS", "value": 4 }, { "name": "TIME_OUTPUT_FORMAT", "value": "HH24:MI:SS" }, { "name": "CLIENT_RESULT_CHUNK_SIZE", "value": 16 }, { "name": "TIMESTAMP_TZ_OUTPUT_FORMAT", "value": "" }, { "name": "CLIENT_SESSION_KEEP_ALIVE", "value": false }, { "name": "QUERY_CONTEXT_CACHE_SIZE", "value": 5 }, { "name": "CLIENT_METADATA_USE_SESSION_DATABASE", "value": false }, { "name": "CLIENT_OUT_OF_BAND_TELEMETRY_ENABLED", "value": false }, { "name": "ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1", "value": true }, { "name": "TIMESTAMP_NTZ_OUTPUT_FORMAT", "value": "YYYY-MM-DD HH24:MI:SS.FF3" }, { "name": 
"CLIENT_RESULT_PREFETCH_THREADS", "value": 1 }, { "name": "CLIENT_METADATA_REQUEST_USE_CONNECTION_CTX", "value": false }, { "name": "CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ", "value": true }, { "name": "CLIENT_MEMORY_LIMIT", "value": 1536 }, { "name": "CLIENT_TIMESTAMP_TYPE_MAPPING", "value": "TIMESTAMP_LTZ" }, { "name": "TIMEZONE", "value": "America/Los_Angeles" }, { "name": "SERVICE_NAME", "value": "" }, { "name": "CLIENT_RESULT_PREFETCH_SLOTS", "value": 2 }, { "name": "CLIENT_TELEMETRY_ENABLED", "value": true }, { "name": "CLIENT_DISABLE_INCIDENTS", "value": true }, { "name": "CLIENT_USE_V1_QUERY_API", "value": true }, { "name": "CLIENT_RESULT_COLUMN_CASE_INSENSITIVE", "value": false }, { "name": "CSV_TIMESTAMP_FORMAT", "value": "" }, { "name": "BINARY_OUTPUT_FORMAT", "value": "HEX" }, { "name": "CLIENT_ENABLE_LOG_INFO_STATEMENT_PARAMETERS", "value": false }, { "name": "CLIENT_TELEMETRY_SESSIONLESS_ENABLED", "value": true }, { "name": "DATE_OUTPUT_FORMAT", "value": "YYYY-MM-DD" }, { "name": "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD", "value": 65280 }, { "name": "CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY", "value": 3600 }, { "name": "CLIENT_SESSION_CLONE", "value": false }, { "name": "AUTOCOMMIT", "value": true }, { "name": "TIMESTAMP_LTZ_OUTPUT_FORMAT", "value": "" } ], "rowtype": [ { "name": "1", "database": "", "schema": "", "table": "", "nullable": false, "length": null, "type": "fixed", "scale": 0, "precision": 1, "byteLength": null, "collation": null } ], "rowset": [ [ "1" ] ], "total": 1, "returned": 1, "queryId": "01ba13b4-0104-e9fd-0000-0111029ca00e", "databaseProvider": null, "finalDatabaseName": null, "finalSchemaName": null, "finalWarehouseName": "TEST_XSMALL", "numberOfBinds": 0, "arrayBindSupported": false, "statementTypeId": 4096, "version": 1, "sendResultTime": 1738317395581, "queryResultFormat": "json", "queryContext": { "entries": [ { "id": 0, "timestamp": 1738317395574564, "priority": 0, "context": "CPbPTg==" } ] } }, "code": null, "message": 
null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/telemetry/custom_telemetry.json ================================================ { "mappings": [ { "scenarioName": "Successful telemetry flow", "request": { "urlPathPattern": "/telemetry/send", "method": "POST", "bodyPatterns": [ { "equalToJson": { "logs": { "message": { "test_key": "test_value" } } }, "ignoreExtraElements": true } ] }, "response": { "status": 200, "jsonBody": { "code": null, "data": "Log Received", "message": null, "success": true } } } ] } ================================================ FILE: test_data/wiremock/mappings/telemetry/telemetry.json ================================================ { "mappings": [ { "scenarioName": "Successful telemetry flow", "request": { "urlPathPattern": "/telemetry/send", "method": "POST" }, "response": { "status": 200, "jsonBody": { "code": null, "data": "Log Received", "message": null, "success": true } } } ] } ================================================ FILE: test_utils_test.go ================================================ package gosnowflake import ( "net/http" "os" "runtime" "strings" "sync" "testing" "time" ) type countingRoundTripper struct { delegate http.RoundTripper getReqCount map[string]int postReqCount map[string]int mu sync.Mutex } func newCountingRoundTripper(delegate http.RoundTripper) *countingRoundTripper { return &countingRoundTripper{ delegate: delegate, getReqCount: make(map[string]int), postReqCount: make(map[string]int), } } func (crt *countingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { crt.mu.Lock() switch req.Method { case http.MethodGet: crt.getReqCount[req.URL.String()]++ case http.MethodPost: crt.postReqCount[req.URL.String()]++ } crt.mu.Unlock() return crt.delegate.RoundTrip(req) } func (crt *countingRoundTripper) reset() { crt.getReqCount = make(map[string]int) crt.postReqCount = make(map[string]int) } func (crt *countingRoundTripper) 
totalRequestsByPath(urlPath string) int { total := 0 for url, reqs := range crt.getReqCount { if strings.Contains(url, urlPath) { total += reqs } } for url, reqs := range crt.postReqCount { if strings.Contains(url, urlPath) { total += reqs } } return total } func (crt *countingRoundTripper) totalRequests() int { total := 0 for _, reqs := range crt.getReqCount { total += reqs } for _, reqs := range crt.postReqCount { total += reqs } return total } type blockingRoundTripper struct { delegate http.RoundTripper defaultBlockTime time.Duration pathBlockTime map[string]time.Duration } func newBlockingRoundTripper(delegate http.RoundTripper, defaultBlockTime time.Duration) *blockingRoundTripper { return &blockingRoundTripper{ delegate: delegate, defaultBlockTime: defaultBlockTime, pathBlockTime: make(map[string]time.Duration), } } func (brt *blockingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { if blockTime, exists := brt.pathBlockTime[req.URL.Path]; exists { time.Sleep(blockTime) } else if brt.defaultBlockTime != 0 { time.Sleep(brt.defaultBlockTime) } return brt.delegate.RoundTrip(req) } func (brt *blockingRoundTripper) setPathBlockTime(path string, blockTime time.Duration) { brt.pathBlockTime[path] = blockTime } func (brt *blockingRoundTripper) reset() { brt.pathBlockTime = make(map[string]time.Duration) } func skipOnMissingHome(t *testing.T) { if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && os.Getenv("HOME") == "" { t.Skip("skipping on missing HOME environment variable") } } ================================================ FILE: tls_config.go ================================================ package gosnowflake import ( "crypto/tls" sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config" ) // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. // Use the key as a value in the DSN where tlsConfigName=value. 
func RegisterTLSConfig(key string, cfg *tls.Config) error { return sfconfig.RegisterTLSConfig(key, cfg) } // DeregisterTLSConfig removes the tls.Config associated with key. func DeregisterTLSConfig(key string) error { return sfconfig.DeregisterTLSConfig(key) } ================================================ FILE: tls_config_test.go ================================================ package gosnowflake import ( "context" "database/sql" "testing" ) // TODO move this test to config package when we have wiremock support in an internal package func TestShouldSetUpTlsConfig(t *testing.T) { tlsConfig := wiremockHTTPS.tlsConfig(t) err := RegisterTLSConfig("wiremock", tlsConfig) assertNilF(t, err) wiremockHTTPS.registerMappings(t, newWiremockMapping("auth/password/successful_flow.json")) for _, dbFunc := range []func() *sql.DB{ func() *sql.DB { cfg := wiremockHTTPS.connectionConfig(t) cfg.TLSConfigName = "wiremock" cfg.Transporter = nil return sql.OpenDB(NewConnector(SnowflakeDriver{}, *cfg)) }, func() *sql.DB { cfg := wiremockHTTPS.connectionConfig(t) cfg.TLSConfigName = "wiremock" cfg.Transporter = nil dsn, err := DSN(cfg) assertNilF(t, err) db, err := sql.Open("snowflake", dsn) assertNilF(t, err) return db }, } { t.Run("", func(t *testing.T) { db := dbFunc() defer db.Close() // mock connection, no need to close _, err := db.Conn(context.Background()) assertNilF(t, err) }) } } ================================================ FILE: transaction.go ================================================ package gosnowflake import ( "context" "database/sql/driver" "errors" ) type snowflakeTx struct { sc *snowflakeConn ctx context.Context } type txCommand int const ( commit txCommand = iota rollback ) func (cmd txCommand) string() (string, error) { switch cmd { case commit: return "COMMIT", nil case rollback: return "ROLLBACK", nil } return "", errors.New("unsupported transaction command") } func (tx *snowflakeTx) Commit() error { return tx.execTxCommand(commit) } func (tx 
*snowflakeTx) Rollback() error { return tx.execTxCommand(rollback) } func (tx *snowflakeTx) execTxCommand(command txCommand) (err error) { txStr, err := command.string() if err != nil { return } if tx.sc == nil || tx.sc.rest == nil { return driver.ErrBadConn } isInternal := isInternal(tx.ctx) _, err = tx.sc.exec(tx.ctx, txStr, false /* noResult */, isInternal, false /* describeOnly */, nil) if err != nil { return } tx.sc = nil return } ================================================ FILE: transaction_test.go ================================================ package gosnowflake import ( "context" "database/sql" "errors" "fmt" errors2 "github.com/snowflakedb/gosnowflake/v2/internal/errors" "testing" "time" ) func TestTransactionOptions(t *testing.T) { var tx *sql.Tx var err error runDBTest(t, func(dbt *DBTest) { tx, err = dbt.conn.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { t.Fatal("failed to start transaction.") } if err = tx.Rollback(); err != nil { t.Fatal("failed to rollback") } if _, err = dbt.conn.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}); err == nil { t.Fatal("should have failed.") } if driverErr, ok := err.(*SnowflakeError); !ok || driverErr.Number != ErrNoReadOnlyTransaction { t.Fatalf("should have returned Snowflake Error: %v", errors2.ErrMsgNoReadOnlyTransaction) } if _, err = dbt.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: 100}); err == nil { t.Fatal("should have failed.") } if driverErr, ok := err.(*SnowflakeError); !ok || driverErr.Number != ErrNoDefaultTransactionIsolationLevel { t.Fatalf("should have returned Snowflake Error: %v", errors2.ErrMsgNoDefaultTransactionIsolationLevel) } }) } // SNOW-823072: Test that transaction uses the context object supplied by BeginTx(), not from the parent connection func TestTransactionContext(t *testing.T) { var tx *sql.Tx var err error ctx := context.Background() runDBTest(t, func(dbt *DBTest) { pingWithRetry := withRetry(PingFunc, 5, 3*time.Second) err = 
pingWithRetry(context.Background(), dbt.conn) if err != nil { t.Fatal(err) } tx, err = dbt.conn.BeginTx(ctx, nil) if err != nil { t.Fatal(err) } _, err = tx.ExecContext(ctx, "SELECT SYSTEM$WAIT(10, 'SECONDS')") if err != nil { t.Fatal(err) } err = tx.Commit() if err != nil { t.Fatal(err) } }) } func PingFunc(ctx context.Context, conn *sql.Conn) error { return conn.PingContext(ctx) } // Helper function for SNOW-823072 repro func withRetry(fn func(context.Context, *sql.Conn) error, numAttempts int, timeout time.Duration) func(context.Context, *sql.Conn) error { return func(ctx context.Context, db *sql.Conn) error { for currAttempt := 1; currAttempt <= numAttempts; currAttempt++ { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() err := fn(ctx, db) if err != nil { if errors.Is(err, context.DeadlineExceeded) { continue } return err } return nil } return fmt.Errorf("context deadline exceeded, failed after [%d] attempts", numAttempts) } } func TestTransactionError(t *testing.T) { sr := &snowflakeRestful{ FuncPostQuery: postQueryFail, } tx := snowflakeTx{ sc: &snowflakeConn{ cfg: &Config{}, rest: sr, }, ctx: context.Background(), } // test for post query error when executing the txCommand err := tx.execTxCommand(rollback) assertNotNilF(t, err, "") assertEqualE(t, err.Error(), "failed to get query response") // test for invalid txCommand err = tx.execTxCommand(2) assertNotNilF(t, err, "") assertEqualE(t, err.Error(), "unsupported transaction command") // test for bad connection error when snowflakeConn is nil tx.sc = nil err = tx.execTxCommand(rollback) assertNotNilF(t, err, "") assertEqualE(t, err.Error(), "driver: bad connection") } ================================================ FILE: transport.go ================================================ package gosnowflake import ( "cmp" "crypto/tls" "crypto/x509" "errors" "fmt" sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config" "net" "net/http" "net/url" "strconv" "time" 
"golang.org/x/net/http/httpproxy" ) type transportConfigs interface { forTransportType(transportType transportType) *transportConfig } type transportType int const ( transportTypeOAuth transportType = iota transportTypeCloudProvider transportTypeOCSP transportTypeCRL transportTypeSnowflake transportTypeWIF ) var defaultTransportConfigs transportConfigs = newDefaultTransportConfigs() // transportConfig holds the configuration for creating HTTP transports type transportConfig struct { MaxIdleConns int IdleConnTimeout time.Duration DialTimeout time.Duration KeepAlive time.Duration DisableProxy bool } // TransportFactory handles creation of HTTP transports with different validation modes type transportFactory struct { config *Config telemetry *snowflakeTelemetry } func (tf *transportConfig) String() string { return fmt.Sprintf("{MaxIdleConns: %d, IdleConnTimeout: %s, DialTimeout: %s, KeepAlive: %s}", tf.MaxIdleConns, tf.IdleConnTimeout, tf.DialTimeout, tf.KeepAlive) } // NewTransportFactory creates a new transport factory func newTransportFactory(config *Config, telemetry *snowflakeTelemetry) *transportFactory { return &transportFactory{config: config, telemetry: telemetry} } func (tf *transportFactory) createProxy(transportConfig *transportConfig) func(*http.Request) (*url.URL, error) { if transportConfig.DisableProxy { return nil } logger.Debug("Initializing proxy configuration") if tf.config == nil || tf.config.ProxyHost == "" { logger.Debug("Config is empty or ProxyHost is not set. 
Using proxy settings from environment variables.") return http.ProxyFromEnvironment } connectionProxy := &url.URL{ Scheme: tf.config.ProxyProtocol, Host: fmt.Sprintf("%s:%d", tf.config.ProxyHost, tf.config.ProxyPort), } if tf.config.ProxyUser != "" && tf.config.ProxyPassword != "" { connectionProxy.User = url.UserPassword(tf.config.ProxyUser, tf.config.ProxyPassword) logger.Infof("Connection Proxy is configured: Connection proxy %v: ****@%v NoProxy:%v", tf.config.ProxyUser, connectionProxy.Host, tf.config.NoProxy) } else { logger.Infof("Connection Proxy is configured: Connection proxy: %v NoProxy: %v", connectionProxy.Host, tf.config.NoProxy) } cfg := httpproxy.Config{ HTTPSProxy: connectionProxy.String(), HTTPProxy: connectionProxy.String(), NoProxy: tf.config.NoProxy, } proxyURLFunc := cfg.ProxyFunc() return func(req *http.Request) (*url.URL, error) { return proxyURLFunc(req.URL) } } // createBaseTransport creates a base HTTP transport with the given configuration func (tf *transportFactory) createBaseTransport(transportConfig *transportConfig, tlsConfig *tls.Config) *http.Transport { logger.Debugf("Create a new Base Transport with transportConfig %v", transportConfig.String()) dialer := &net.Dialer{ Timeout: transportConfig.DialTimeout, KeepAlive: transportConfig.KeepAlive, } defaultTransport := http.DefaultTransport.(*http.Transport) return &http.Transport{ TLSClientConfig: tlsConfig, MaxIdleConns: cmp.Or(transportConfig.MaxIdleConns, defaultTransport.MaxIdleConns), MaxIdleConnsPerHost: cmp.Or(transportConfig.MaxIdleConns, defaultTransport.MaxIdleConns), IdleConnTimeout: cmp.Or(transportConfig.IdleConnTimeout, defaultTransport.IdleConnTimeout), Proxy: tf.createProxy(transportConfig), DialContext: dialer.DialContext, } } // createOCSPTransport creates a transport with OCSP validation func (tf *transportFactory) createOCSPTransport(transportConfig *transportConfig) (*http.Transport, error) { // Chain OCSP verification with custom TLS config ov := 
newOcspValidator(tf.config) tlsConfig, ok := sfconfig.GetTLSConfig(tf.config.TLSConfigName) if ok && tlsConfig != nil { tlsConfig.VerifyPeerCertificate = tf.chainVerificationCallbacks(tlsConfig.VerifyPeerCertificate, ov.verifyPeerCertificateSerial) } else { tlsConfig = &tls.Config{ VerifyPeerCertificate: ov.verifyPeerCertificateSerial, } } return tf.createBaseTransport(transportConfig, tlsConfig), nil } // createNoRevocationTransport creates a transport without certificate revocation checking func (tf *transportFactory) createNoRevocationTransport(transportConfig *transportConfig) http.RoundTripper { if tf.config != nil && tf.config.Transporter != nil { return tf.config.Transporter } return tf.createBaseTransport(transportConfig, nil) } // createCRLValidator creates a CRL validator func (tf *transportFactory) createCRLValidator() (*crlValidator, error) { allowCertificatesWithoutCrlURL := tf.config.CrlAllowCertificatesWithoutCrlURL == ConfigBoolTrue client := &http.Client{ Timeout: cmp.Or(tf.config.CrlHTTPClientTimeout, defaultCrlHTTPClientTimeout), Transport: tf.createNoRevocationTransport(transportConfigFor(transportTypeCRL)), } return newCrlValidator( tf.config.CertRevocationCheckMode, allowCertificatesWithoutCrlURL, tf.config.CrlInMemoryCacheDisabled, tf.config.CrlOnDiskCacheDisabled, cmp.Or(tf.config.CrlDownloadMaxSize, defaultCrlDownloadMaxSize), client, tf.telemetry, ) } // createTransport is the main entry point for creating transports func (tf *transportFactory) createTransport(transportConfig *transportConfig) (http.RoundTripper, error) { if tf.config == nil { // should never happen in production, only in tests logger.Warn("createTransport: got nil Config, using default one") return tf.createNoRevocationTransport(transportConfig), nil } // if user configured a custom Transporter, prioritize that if tf.config.Transporter != nil { logger.Debug("createTransport: using Transporter configured by the user") return tf.config.Transporter, nil } // Validate 
configuration if err := tf.validateRevocationConfig(); err != nil { return nil, err } // Handle CRL validation path if tf.config.CertRevocationCheckMode != CertRevocationCheckDisabled { logger.Debug("createTransport: will perform CRL validation") crlValidator, err := tf.createCRLValidator() if err != nil { return nil, err } crlCacheCleaner.startPeriodicCacheCleanup() // Chain CRL verification with custom TLS config tlsConfig, ok := sfconfig.GetTLSConfig(tf.config.TLSConfigName) if ok && tlsConfig != nil { crlVerify := crlValidator.verifyPeerCertificates tlsConfig.VerifyPeerCertificate = tf.chainVerificationCallbacks(tlsConfig.VerifyPeerCertificate, crlVerify) } else { tlsConfig = &tls.Config{ VerifyPeerCertificate: crlValidator.verifyPeerCertificates, } } return tf.createBaseTransport(transportConfig, tlsConfig), nil } // Handle no revocation checking path if tf.config.DisableOCSPChecks { logger.Debug("createTransport: skipping OCSP validation") return tf.createNoRevocationTransport(transportConfig), nil } logger.Debug("createTransport: will perform OCSP validation") return tf.createOCSPTransport(transportConfig) } // validateRevocationConfig checks for conflicting revocation settings func (tf *transportFactory) validateRevocationConfig() error { if !tf.config.DisableOCSPChecks && tf.config.CertRevocationCheckMode != CertRevocationCheckDisabled { return errors.New("both OCSP and CRL cannot be enabled at the same time, please disable one of them") } return nil } // chainVerificationCallbacks chains a user's custom verification with the provided verification function func (tf *transportFactory) chainVerificationCallbacks(orignalVerificationFunc func([][]byte, [][]*x509.Certificate) error, verificationFunc func([][]byte, [][]*x509.Certificate) error) func([][]byte, [][]*x509.Certificate) error { if orignalVerificationFunc == nil { return verificationFunc } // Chain the existing verification with the new one newVerify := func(rawCerts [][]byte, verifiedChains 
[][]*x509.Certificate) error { // Run the user's custom verification first if err := orignalVerificationFunc(rawCerts, verifiedChains); err != nil { return err } // Then run the provided verification return verificationFunc(rawCerts, verifiedChains) } return newVerify } type defaultTransportConfigsType struct { oauthTransportConfig *transportConfig cloudProviderTransportConfig *transportConfig ocspTransportConfig *transportConfig crlTransportConfig *transportConfig snowflakeTransportConfig *transportConfig wifTransportConfig *transportConfig } func newDefaultTransportConfigs() *defaultTransportConfigsType { return &defaultTransportConfigsType{ oauthTransportConfig: &transportConfig{ MaxIdleConns: 1, IdleConnTimeout: 30 * time.Second, DialTimeout: 30 * time.Second, }, cloudProviderTransportConfig: &transportConfig{ MaxIdleConns: 15, IdleConnTimeout: 30 * time.Second, DialTimeout: 30 * time.Second, }, ocspTransportConfig: &transportConfig{ MaxIdleConns: 1, IdleConnTimeout: 5 * time.Second, DialTimeout: 5 * time.Second, KeepAlive: -1, }, crlTransportConfig: &transportConfig{ MaxIdleConns: 1, IdleConnTimeout: 5 * time.Second, DialTimeout: 5 * time.Second, KeepAlive: -1, }, snowflakeTransportConfig: &transportConfig{ MaxIdleConns: 3, IdleConnTimeout: 30 * time.Minute, DialTimeout: 30 * time.Second, }, wifTransportConfig: &transportConfig{ MaxIdleConns: 1, IdleConnTimeout: 30 * time.Second, DialTimeout: 30 * time.Second, DisableProxy: true, }, } } func (dtc *defaultTransportConfigsType) forTransportType(transportType transportType) *transportConfig { switch transportType { case transportTypeOAuth: return dtc.oauthTransportConfig case transportTypeCloudProvider: return dtc.cloudProviderTransportConfig case transportTypeOCSP: return dtc.ocspTransportConfig case transportTypeCRL: return dtc.crlTransportConfig case transportTypeSnowflake: return dtc.snowflakeTransportConfig case transportTypeWIF: return dtc.wifTransportConfig } panic("unknown transport type: " + 
strconv.Itoa(int(transportType))) } ================================================ FILE: transport_test.go ================================================ package gosnowflake import ( "crypto/tls" "net/http" "testing" sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config" ) func TestTransportFactoryErrorHandling(t *testing.T) { tlsConfig := &tls.Config{InsecureSkipVerify: true} assertNilF(t, RegisterTLSConfig("TestTransportFactoryErrorHandlingTlsConfig", tlsConfig)) // Test CreateCustomTLSTransport with conflicting OCSP and CRL settings conflictingConfig := &Config{ DisableOCSPChecks: false, CertRevocationCheckMode: CertRevocationCheckEnabled, TLSConfigName: "TestTransportFactoryErrorHandlingTlsConfig", } factory := newTransportFactory(conflictingConfig, nil) transport, err := factory.createTransport(transportConfigFor(transportTypeSnowflake)) assertNotNilF(t, err, "Expected error for conflicting OCSP and CRL configuration") assertNilF(t, transport, "Expected nil transport when error occurs") expectedError := "both OCSP and CRL cannot be enabled at the same time, please disable one of them" assertEqualF(t, err.Error(), expectedError, "Expected specific error message") } func TestCreateStandardTransportErrorHandling(t *testing.T) { // Test CreateStandardTransport with conflicting settings conflictingConfig := &Config{ DisableOCSPChecks: false, CertRevocationCheckMode: CertRevocationCheckEnabled, } factory := newTransportFactory(conflictingConfig, nil) transport, err := factory.createTransport(transportConfigFor(transportTypeSnowflake)) assertNotNilF(t, err, "Expected error for conflicting OCSP and CRL configuration") assertNilF(t, transport, "Expected nil transport when error occurs") } func TestCreateCustomTLSTransportSuccess(t *testing.T) { tlsConfig := &tls.Config{InsecureSkipVerify: true} assertNilF(t, RegisterTLSConfig("TestCreateCustomTLSTransportSuccessTlsConfig", tlsConfig)) // Test successful creation with valid config validConfig := &Config{ 
DisableOCSPChecks: true, CertRevocationCheckMode: CertRevocationCheckDisabled, TLSConfigName: "TestCreateCustomTLSTransportSuccessTlsConfig", } factory := newTransportFactory(validConfig, nil) transport, err := factory.createTransport(transportConfigFor(transportTypeSnowflake)) assertNilF(t, err, "Unexpected error") assertNotNilF(t, transport, "Expected non-nil transport for valid configuration") } func TestCreateStandardTransportSuccess(t *testing.T) { // Test successful creation with valid config validConfig := &Config{ DisableOCSPChecks: true, CertRevocationCheckMode: CertRevocationCheckDisabled, } factory := newTransportFactory(validConfig, nil) transport, err := factory.createTransport(transportConfigFor(transportTypeSnowflake)) assertNilF(t, err, "Unexpected error") assertNotNilF(t, transport, "Expected non-nil transport for valid configuration") } func TestDirectTLSConfigUsage(t *testing.T) { // Test the new direct TLS config approach customTLS := &tls.Config{ InsecureSkipVerify: true, ServerName: "custom.example.com", } assertNilF(t, RegisterTLSConfig("TestDirectTLSConfigUsageTlsConfig", customTLS)) config := &Config{ DisableOCSPChecks: true, CertRevocationCheckMode: CertRevocationCheckDisabled, TLSConfigName: "TestDirectTLSConfigUsageTlsConfig", } factory := newTransportFactory(config, nil) transport, err := factory.createTransport(transportConfigFor(transportTypeSnowflake)) assertNilF(t, err, "Unexpected error") assertNotNilF(t, transport, "Expected non-nil transport") } func TestRegisteredTLSConfigUsage(t *testing.T) { // Test registered TLS config approach through DSN parsing // Clean up any existing registry sfconfig.ResetTLSConfigRegistry() // Register a custom TLS config customTLS := &tls.Config{ InsecureSkipVerify: true, ServerName: "registered.example.com", } err := RegisterTLSConfig("test-direct", customTLS) assertNilF(t, err, "Failed to register TLS config") defer func() { err := DeregisterTLSConfig("test-direct") assertNilF(t, err, "Failed to 
deregister test TLS config") }() // Parse DSN that references the registered config dsn := "user:pass@account/db?tls=test-direct&ocspFailOpen=false&disableOCSPChecks=true" config, err2 := ParseDSN(dsn) assertNilF(t, err2, "Failed to parse DSN") config.CertRevocationCheckMode = CertRevocationCheckDisabled factory := newTransportFactory(config, nil) transport, err := factory.createTransport(transportConfigFor(transportTypeSnowflake)) assertNilF(t, err, "Unexpected error") assertNotNilF(t, transport, "Expected non-nil transport") } func TestDirectTLSConfigOnly(t *testing.T) { // Test that direct TLS config works without any registration // Create a direct TLS config directTLS := &tls.Config{ InsecureSkipVerify: true, ServerName: "direct.example.com", } assertNilF(t, RegisterTLSConfig("TestDirectTLSConfigOnlyTlsConfig", directTLS)) config := &Config{ DisableOCSPChecks: true, CertRevocationCheckMode: CertRevocationCheckDisabled, TLSConfigName: "TestDirectTLSConfigOnlyTlsConfig", } factory := newTransportFactory(config, nil) transport, err := factory.createTransport(transportConfigFor(transportTypeSnowflake)) assertNilF(t, err, "Unexpected error") assertNotNilF(t, transport, "Expected non-nil transport") } func TestProxyTransportCreation(t *testing.T) { proxyTests := []struct { config *Config proxyURL string disableProxy bool }{ { config: &Config{ ProxyProtocol: "http", ProxyHost: "proxy.connection.com", ProxyPort: 1234, }, disableProxy: true, proxyURL: "", }, { config: &Config{ ProxyProtocol: "https", ProxyHost: "proxy.connection.com", ProxyPort: 1234, }, disableProxy: true, proxyURL: "", }, { config: &Config{ ProxyProtocol: "http", ProxyHost: "proxy.connection.com", ProxyPort: 1234, }, proxyURL: "http://proxy.connection.com:1234", }, { config: &Config{ ProxyProtocol: "http", ProxyHost: "proxy.connection.com", ProxyPort: 1234, }, proxyURL: "http://proxy.connection.com:1234", }, { config: &Config{ ProxyProtocol: "https", ProxyHost: "proxy.connection.com", ProxyPort: 
1234, }, proxyURL: "https://proxy.connection.com:1234", }, { config: &Config{ ProxyProtocol: "http", ProxyHost: "proxy.connection.com", ProxyPort: 1234, NoProxy: "*.snowflakecomputing.com,ocsp.testing.com", }, proxyURL: "", }, } for _, test := range proxyTests { t.Run(test.proxyURL, func(t *testing.T) { factory := newTransportFactory(test.config, nil) proxyFunc := factory.createProxy(&transportConfig{DisableProxy: test.disableProxy}) if test.disableProxy { assertNilF(t, proxyFunc, "Expected nil proxy function when proxy is disabled") return } req, _ := http.NewRequest("GET", "https://testing.snowflakecomputing.com", nil) proxyURL, _ := proxyFunc(req) if test.proxyURL == "" { assertNilF(t, proxyURL, "Expected nil proxy for https request") } else { assertEqualF(t, proxyURL.String(), test.proxyURL) } req, _ = http.NewRequest("GET", "http://ocsp.testing.com", nil) proxyURL, _ = proxyFunc(req) if test.proxyURL == "" { assertNilF(t, proxyURL, "Expected nil proxy for https request") } else { assertEqualF(t, proxyURL.String(), test.proxyURL) } }) } } func createTestNoRevocationTransport() http.RoundTripper { return newTransportFactory(&Config{}, nil).createNoRevocationTransport(defaultTransportConfigs.forTransportType(transportTypeSnowflake)) } ================================================ FILE: url_util.go ================================================ package gosnowflake import ( "net/url" "regexp" ) var ( matcher, _ = regexp.Compile(`^http(s?)\:\/\/[0-9a-zA-Z]([-.\w]*[0-9a-zA-Z@:])*(:(0-9)*)*(\/?)([a-zA-Z0-9\-\.\?\,\&\(\)\/\\\+&%\$#_=@]*)?$`) ) func isValidURL(targetURL string) bool { if !matcher.MatchString(targetURL) { logger.Infof(" The provided URL is not a valid URL - " + targetURL) return false } return true } func urlEncode(targetString string) string { // We use QueryEscape instead of PathEscape here // for consistency across Drivers. For example: // QueryEscape escapes space as "+" whereas PE // it as %20F. 
PE also does not escape @ or & // either but QE does. // The behavior of QE in Golang is more in sync // with URL encoders in Python and Java hence the choice return url.QueryEscape(targetString) } ================================================ FILE: util.go ================================================ package gosnowflake import ( "context" "database/sql/driver" "fmt" "io" "iter" "maps" "math/rand" "os" "strings" "sync" "time" "github.com/apache/arrow-go/v18/arrow/memory" ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow" sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config" ) // ContextKey is a type for context keys used in gosnowflake. Using a custom type helps avoid collisions with other context keys. type ContextKey string const ( multiStatementCount ContextKey = "MULTI_STATEMENT_COUNT" asyncMode ContextKey = "ASYNC_MODE_QUERY" queryIDChannel ContextKey = "QUERY_ID_CHANNEL" snowflakeRequestIDKey ContextKey = "SNOWFLAKE_REQUEST_ID" fetchResultByID ContextKey = "SF_FETCH_RESULT_BY_ID" filePutStream ContextKey = "STREAMING_PUT_FILE" fileGetStream ContextKey = "STREAMING_GET_FILE" fileTransferOptions ContextKey = "FILE_TRANSFER_OPTIONS" enableDecfloat ContextKey = "ENABLE_DECFLOAT" arrowAlloc ContextKey = "ARROW_ALLOC" queryTag ContextKey = "QUERY_TAG" enableStructuredTypes ContextKey = "ENABLE_STRUCTURED_TYPES" embeddedValuesNullable ContextKey = "EMBEDDED_VALUES_NULLABLE" describeOnly ContextKey = "DESCRIBE_ONLY" internalQuery ContextKey = "INTERNAL_QUERY" cancelRetry ContextKey = "CANCEL_RETRY" logQueryText ContextKey = "LOG_QUERY_TEXT" logQueryParameters ContextKey = "LOG_QUERY_PARAMETERS" ) var ( defaultTimeProvider = &unixTimeProvider{} ) // WithMultiStatement returns a context that allows the user to execute the desired number of sql queries in one query func WithMultiStatement(ctx context.Context, num int) context.Context { return context.WithValue(ctx, multiStatementCount, num) } // WithAsyncMode returns a context that allows 
execution of query in async mode func WithAsyncMode(ctx context.Context) context.Context { return context.WithValue(ctx, asyncMode, true) } // WithQueryIDChan returns a context that contains the channel to receive the query ID func WithQueryIDChan(ctx context.Context, c chan<- string) context.Context { return context.WithValue(ctx, queryIDChannel, c) } // WithRequestID returns a new context with the specified snowflake request id func WithRequestID(ctx context.Context, requestID UUID) context.Context { return context.WithValue(ctx, snowflakeRequestIDKey, requestID) } // WithFetchResultByID returns a context that allows retrieving the result by query ID func WithFetchResultByID(ctx context.Context, queryID string) context.Context { return context.WithValue(ctx, fetchResultByID, queryID) } // WithFilePutStream returns a context that contains the address of the file stream to be PUT func WithFilePutStream(ctx context.Context, reader io.Reader) context.Context { return context.WithValue(ctx, filePutStream, reader) } // WithFileGetStream returns a context that contains the address of the file stream to be GET func WithFileGetStream(ctx context.Context, writer io.Writer) context.Context { return context.WithValue(ctx, fileGetStream, writer) } // WithFileTransferOptions returns a context that contains the address of file transfer options func WithFileTransferOptions(ctx context.Context, options *SnowflakeFileTransferOptions) context.Context { return context.WithValue(ctx, fileTransferOptions, options) } // WithDescribeOnly returns a context that enables a describe only query func WithDescribeOnly(ctx context.Context) context.Context { return context.WithValue(ctx, describeOnly, true) } // WithHigherPrecision returns a context that enables higher precision by // returning a *big.Int or *big.Float variable when querying rows for column // types with numbers that don't fit into its native Golang counterpart // When used in combination with arrowbatches.WithBatches, original 
BigDecimal in arrow batches will be preserved. func WithHigherPrecision(ctx context.Context) context.Context { return ia.WithHigherPrecision(ctx) } // WithDecfloatMappingEnabled returns a context that enables native support for DECFLOAT. // Without this context, DECFLOAT columns are returned as strings. // With this context enabled, DECFLOAT columns are returned as *big.Float or float64 (depending on HigherPrecision setting). // Keep in mind that both float64 and *big.Float are not able to precisely represent some DECFLOAT values. // If precision is important, you have to use string representation and use your own library to parse it. func WithDecfloatMappingEnabled(ctx context.Context) context.Context { return context.WithValue(ctx, enableDecfloat, true) } // WithArrowAllocator returns a context embedding the provided allocator // which will be utilized by chunk downloaders when constructing Arrow // objects. func WithArrowAllocator(ctx context.Context, pool memory.Allocator) context.Context { return context.WithValue(ctx, arrowAlloc, pool) } // WithQueryTag returns a context that will set the given tag as the QUERY_TAG // parameter on any queries that are run func WithQueryTag(ctx context.Context, tag string) context.Context { return context.WithValue(ctx, queryTag, tag) } // WithStructuredTypesEnabled changes how structured types are returned. // Without this context structured types are returned as strings. // With this context enabled, structured types are returned as native Go types. func WithStructuredTypesEnabled(ctx context.Context) context.Context { return context.WithValue(ctx, enableStructuredTypes, true) } // WithEmbeddedValuesNullable changes how complex structures are returned. // Instead of simple values (like string) sql.NullXXX wrappers (like sql.NullString) are used. // It applies to map values and arrays. 
func WithEmbeddedValuesNullable(ctx context.Context) context.Context { return context.WithValue(ctx, embeddedValuesNullable, true) } // WithInternal sets the internal query flag. func WithInternal(ctx context.Context) context.Context { return context.WithValue(ctx, internalQuery, true) } // WithLogQueryText enables logging of the query text. func WithLogQueryText(ctx context.Context) context.Context { return context.WithValue(ctx, logQueryText, true) } // WithLogQueryParameters enables logging of the query parameters. func WithLogQueryParameters(ctx context.Context) context.Context { return context.WithValue(ctx, logQueryParameters, true) } // Get the request ID from the context if specified, otherwise generate one func getOrGenerateRequestIDFromContext(ctx context.Context) UUID { requestID, ok := ctx.Value(snowflakeRequestIDKey).(UUID) if ok && requestID != nilUUID { return requestID } return NewUUID() } // integer min func intMin(a, b int) int { if a < b { return a } return b } // integer max func intMax(a, b int) int { if a > b { return a } return b } func int64Max(a, b int64) int64 { if a > b { return a } return b } func getMin(arr []int) int { if len(arr) == 0 { return -1 } min := arr[0] for _, v := range arr { if v <= min { min = v } } return min } // time.Duration max func durationMax(d1, d2 time.Duration) time.Duration { if d1-d2 > 0 { return d1 } return d2 } // time.Duration min func durationMin(d1, d2 time.Duration) time.Duration { if d1-d2 < 0 { return d1 } return d2 } // toNamedValues converts a slice of driver.Value to a slice of driver.NamedValue for Go 1.8 SQL package func toNamedValues(values []driver.Value) []driver.NamedValue { namedValues := make([]driver.NamedValue, len(values)) for idx, value := range values { namedValues[idx] = driver.NamedValue{Name: "", Ordinal: idx + 1, Value: value} } return namedValues } // TokenAccessor manages the session token and master token type TokenAccessor = sfconfig.TokenAccessor type simpleTokenAccessor struct 
{ token string masterToken string sessionID int64 accessorLock sync.Mutex // Used to implement accessor's Lock and Unlock tokenLock sync.RWMutex // Used to synchronize SetTokens and GetTokens } func getSimpleTokenAccessor() TokenAccessor { return &simpleTokenAccessor{sessionID: -1} } func (sta *simpleTokenAccessor) Lock() error { sta.accessorLock.Lock() return nil } func (sta *simpleTokenAccessor) Unlock() { sta.accessorLock.Unlock() } func (sta *simpleTokenAccessor) GetTokens() (token string, masterToken string, sessionID int64) { sta.tokenLock.RLock() defer sta.tokenLock.RUnlock() return sta.token, sta.masterToken, sta.sessionID } func (sta *simpleTokenAccessor) SetTokens(token string, masterToken string, sessionID int64) { sta.tokenLock.Lock() defer sta.tokenLock.Unlock() sta.token = token sta.masterToken = masterToken sta.sessionID = sessionID } func safeGetTokens(sr *snowflakeRestful) (token string, masterToken string, sessionID int64) { if sr == nil || sr.TokenAccessor == nil { logger.Error("safeGetTokens: could not get tokens as TokenAccessor was nil") return "", "", 0 } return sr.TokenAccessor.GetTokens() } func escapeForCSV(value string) string { if value == "" { return "\"\"" } if strings.Contains(value, "\"") || strings.Contains(value, "\n") || strings.Contains(value, ",") || strings.Contains(value, "\\") { return "\"" + strings.ReplaceAll(value, "\"", "\"\"") + "\"" } return value } // GetFromEnv is used to get the value of an environment variable from the system func GetFromEnv(name string, failOnMissing bool) (string, error) { if value := os.Getenv(name); value != "" { return value, nil } if failOnMissing { return "", fmt.Errorf("%v environment variable is not set", name) } return "", nil } type currentTimeProvider interface { currentTime() int64 } type unixTimeProvider struct { } func (utp *unixTimeProvider) currentTime() int64 { return time.Now().UnixMilli() } type syncParams struct { mu sync.Mutex params map[string]*string } func 
newSyncParams(params map[string]*string) syncParams { copied := make(map[string]*string) if params != nil { maps.Copy(copied, params) } return syncParams{params: copied} } func (sp *syncParams) get(key string) (*string, bool) { sp.mu.Lock() defer sp.mu.Unlock() if sp.params == nil { return nil, false } v, ok := sp.params[key] return v, ok } func (sp *syncParams) set(key string, value *string) { sp.mu.Lock() defer sp.mu.Unlock() if sp.params == nil { sp.params = make(map[string]*string) } sp.params[key] = value } // All returns an iterator over all params, holding the lock for the // duration of iteration. Callers use: for k, v := range sp.All() { ... } func (sp *syncParams) All() iter.Seq2[string, string] { return func(yield func(string, string) bool) { sp.mu.Lock() defer sp.mu.Unlock() for k, v := range sp.params { if !yield(k, *v) { return } } } } func chooseRandomFromRange(min float64, max float64) float64 { return rand.Float64()*(max-min) + min } func withLowerKeys[T any](in map[string]T) map[string]T { out := make(map[string]T) for k, v := range in { out[strings.ToLower(k)] = v } return out } func findByPrefix(in []string, prefix string) int { for i, v := range in { if strings.HasPrefix(v, prefix) { return i } } return -1 } ================================================ FILE: util_test.go ================================================ package gosnowflake import ( "context" "database/sql/driver" "fmt" "maps" "math/rand" "os" "runtime" "strconv" "sync" "testing" "time" ) type tcIntMinMax struct { v1 int v2 int out int } type tcUUID struct { uuid string } type constTypeProvider struct { constTime int64 } type tcSafeGetTokens struct { name string sr *snowflakeRestful expectedSessionID int64 } func (ctp *constTypeProvider) currentTime() int64 { return ctp.constTime } func constTimeProvider(constTime int64) *constTypeProvider { return &constTypeProvider{constTime: constTime} } func TestSimpleTokenAccessor(t *testing.T) { accessor := getSimpleTokenAccessor() 
token, masterToken, sessionID := accessor.GetTokens() if token != "" { t.Errorf("unexpected token %v", token) } if masterToken != "" { t.Errorf("unexpected master token %v", masterToken) } if sessionID != -1 { t.Errorf("unexpected session id %v", sessionID) } expectedToken, expectedMasterToken, expectedSessionID := "token123", "master123", int64(123) accessor.SetTokens(expectedToken, expectedMasterToken, expectedSessionID) token, masterToken, sessionID = accessor.GetTokens() if token != expectedToken { t.Errorf("unexpected token %v", token) } if masterToken != expectedMasterToken { t.Errorf("unexpected master token %v", masterToken) } if sessionID != expectedSessionID { t.Errorf("unexpected session id %v", sessionID) } } func TestSimpleTokenAccessorGetTokensSynchronization(t *testing.T) { accessor := getSimpleTokenAccessor() var wg sync.WaitGroup failed := false for range 1000 { wg.Add(1) go func() { // set a random session and token session := rand.Int63() sessionStr := strconv.FormatInt(session, 10) accessor.SetTokens("t"+sessionStr, "m"+sessionStr, session) // read back session and token and verify that invariant still holds token, masterToken, session := accessor.GetTokens() sessionStr = strconv.FormatInt(session, 10) if "t"+sessionStr != token || "m"+sessionStr != masterToken { failed = true } wg.Done() }() } // wait for all competing goroutines to finish setting and getting tokens wg.Wait() if failed { t.Fail() } } func TestSafeGetTokens(t *testing.T) { testcases := []tcSafeGetTokens{ { name: "with simple token accessor", sr: &snowflakeRestful{ FuncPostQuery: postQueryTest, TokenAccessor: getSimpleTokenAccessor(), }, expectedSessionID: -1, }, { name: "without token accessor", sr: &snowflakeRestful{ FuncPostQuery: postQueryTest, }, expectedSessionID: 0, }, } for _, test := range testcases { t.Run(fmt.Sprintf("%v", test.name), func(t *testing.T) { _, _, sessionID := safeGetTokens(test.sr) assertEqualE(t, sessionID, test.expectedSessionID, "expected sessionId to 
be %v, was %v", fmt.Sprintf("%d", test.expectedSessionID), fmt.Sprintf("%d", sessionID)) }) } } func TestGetRequestIDFromContext(t *testing.T) { expectedRequestID := NewUUID() ctx := WithRequestID(context.Background(), expectedRequestID) requestID := getOrGenerateRequestIDFromContext(ctx) if requestID != expectedRequestID { t.Errorf("unexpected request id: %v, expected: %v", requestID, expectedRequestID) } ctx = WithRequestID(context.Background(), nilUUID) requestID = getOrGenerateRequestIDFromContext(ctx) if requestID == nilUUID { t.Errorf("unexpected request id, should not be nil") } } func TestGenerateRequestID(t *testing.T) { firstRequestID := getOrGenerateRequestIDFromContext(context.Background()) otherRequestID := getOrGenerateRequestIDFromContext(context.Background()) if firstRequestID == otherRequestID { t.Errorf("request id should not be the same") } } func TestIntMin(t *testing.T) { testcases := []tcIntMinMax{ {1, 3, 1}, {5, 100, 5}, {321, 3, 3}, {123, 123, 123}, } for _, test := range testcases { t.Run(fmt.Sprintf("%v_%v_%v", test.v1, test.v2, test.out), func(t *testing.T) { a := intMin(test.v1, test.v2) if test.out != a { t.Errorf("failed int min. v1: %v, v2: %v, expected: %v, got: %v", test.v1, test.v2, test.out, a) } }) } } func TestIntMax(t *testing.T) { testcases := []tcIntMinMax{ {1, 3, 3}, {5, 100, 100}, {321, 3, 321}, {123, 123, 123}, } for _, test := range testcases { t.Run(fmt.Sprintf("%v_%v_%v", test.v1, test.v2, test.out), func(t *testing.T) { a := intMax(test.v1, test.v2) if test.out != a { t.Errorf("failed int max. 
v1: %v, v2: %v, expected: %v, got: %v", test.v1, test.v2, test.out, a) } }) } } type tcDurationMinMax struct { v1 time.Duration v2 time.Duration out time.Duration } func TestDurationMin(t *testing.T) { testcases := []tcDurationMinMax{ {1 * time.Second, 3 * time.Second, 1 * time.Second}, {5 * time.Second, 100 * time.Second, 5 * time.Second}, {321 * time.Second, 3 * time.Second, 3 * time.Second}, {123 * time.Second, 123 * time.Second, 123 * time.Second}, } for _, test := range testcases { t.Run(fmt.Sprintf("%v_%v_%v", test.v1, test.v2, test.out), func(t *testing.T) { a := durationMin(test.v1, test.v2) if test.out != a { t.Errorf("failed duratoin max. v1: %v, v2: %v, expected: %v, got: %v", test.v1, test.v2, test.out, a) } }) } } func TestDurationMax(t *testing.T) { testcases := []tcDurationMinMax{ {1 * time.Second, 3 * time.Second, 3 * time.Second}, {5 * time.Second, 100 * time.Second, 100 * time.Second}, {321 * time.Second, 3 * time.Second, 321 * time.Second}, {123 * time.Second, 123 * time.Second, 123 * time.Second}, } for _, test := range testcases { t.Run(fmt.Sprintf("%v_%v_%v", test.v1, test.v2, test.out), func(t *testing.T) { a := durationMax(test.v1, test.v2) if test.out != a { t.Errorf("failed duratoin max. 
v1: %v, v2: %v, expected: %v, got: %v", test.v1, test.v2, test.out, a) } }) } } type tcNamedValues struct { values []driver.Value out []driver.NamedValue } func compareNamedValues(v1 []driver.NamedValue, v2 []driver.NamedValue) bool { if v1 == nil && v2 == nil { return true } if v1 == nil || v2 == nil { return false } if len(v1) != len(v2) { return false } for i := range v1 { if v1[i] != v2[i] { return false } } return true } func TestToNamedValues(t *testing.T) { testcases := []tcNamedValues{ { values: []driver.Value{}, out: []driver.NamedValue{}, }, { values: []driver.Value{1}, out: []driver.NamedValue{{Name: "", Ordinal: 1, Value: 1}}, }, { values: []driver.Value{1, "test1", 9.876, nil}, out: []driver.NamedValue{ {Name: "", Ordinal: 1, Value: 1}, {Name: "", Ordinal: 2, Value: "test1"}, {Name: "", Ordinal: 3, Value: 9.876}, {Name: "", Ordinal: 4, Value: nil}}, }, } for _, test := range testcases { t.Run("", func(t *testing.T) { a := toNamedValues(test.values) if !compareNamedValues(test.out, a) { t.Errorf("failed int max. v1: %v, v2: %v, expected: %v, got: %v", test.values, test.out, test.out, a) } }) } } type tcIntArrayMin struct { in []int out int } func TestGetMin(t *testing.T) { testcases := []tcIntArrayMin{ {[]int{1, 2, 3, 4, 5}, 1}, {[]int{10, 25, 15, 5, 20}, 5}, {[]int{15, 12, 9, 6, 3}, 3}, {[]int{123, 123, 123, 123, 123}, 123}, {[]int{}, -1}, } for _, test := range testcases { t.Run(fmt.Sprintf("%v", test.out), func(t *testing.T) { a := getMin(test.in) if test.out != a { t.Errorf("failed get min. 
in: %v, expected: %v, got: %v", test.in, test.out, a) } }) } } type tcURLList struct { in string out bool } func TestValidURL(t *testing.T) { testcases := []tcURLList{ {"https://ssoTestURL.okta.com", true}, {"https://ssoTestURL.okta.com:8080", true}, {"https://ssoTestURL.okta.com/testpathvalue", true}, {"-a calculator", false}, {"This is a random test", false}, {"file://TestForFile", false}, } for _, test := range testcases { t.Run(test.in, func(t *testing.T) { result := isValidURL(test.in) if test.out != result { t.Errorf("Failed to validate URL, input :%v, expected: %v, got: %v", test.in, test.out, result) } }) } } type tcEncodeList struct { in string out string } func TestEncodeURL(t *testing.T) { testcases := []tcEncodeList{ {"Hello @World", "Hello+%40World"}, {"Test//String", "Test%2F%2FString"}, } for _, test := range testcases { t.Run(test.in, func(t *testing.T) { result := urlEncode(test.in) if test.out != result { t.Errorf("Failed to encode string, input %v, expected: %v, got: %v", test.in, test.out, result) } }) } } func TestParseUUID(t *testing.T) { testcases := []tcUUID{ {"6ba7b812-9dad-11d1-80b4-00c04fd430c8"}, {"00302010-0504-0706-0809-0a0b0c0d0e0f"}, } for _, test := range testcases { t.Run(test.uuid, func(t *testing.T) { requestID := ParseUUID(test.uuid) if requestID.String() != test.uuid { t.Fatalf("failed to parse uuid") } }) } } type tcEscapeCsv struct { in string out string } func TestEscapeForCSV(t *testing.T) { testcases := []tcEscapeCsv{ {"", "\"\""}, {"\n", "\"\n\""}, {"test\\", "\"test\\\""}, } for _, test := range testcases { t.Run(test.out, func(t *testing.T) { result := escapeForCSV(test.in) if test.out != result { t.Errorf("Failed to escape string, input %v, expected: %v, got: %v", test.in, test.out, result) } }) } } func TestGetFromEnv(t *testing.T) { os.Setenv("SF_TEST", "test") defer os.Unsetenv("SF_TEST") result, err := GetFromEnv("SF_TEST", true) if err != nil { t.Error("failed to read SF_TEST environment variable") } if result != 
"test" { t.Errorf("incorrect value read for SF_TEST. Expected: test, read %v", result) } } func TestGetFromEnvFailOnMissing(t *testing.T) { _, err := GetFromEnv("SF_TEST_MISSING", true) if err == nil { t.Error("should report error when there is missing env parameter") } } func skipOnJenkins(t *testing.T, message string) { if os.Getenv("JENKINS_HOME") != "" { t.Skip("Skipping test on Jenkins: " + message) } } func skipAuthTests(t *testing.T, message string) { if os.Getenv("RUN_AUTH_TESTS") != "true" { t.Skip("Setup 'RUN_AUTH_TESTS' flag to perform this test" + message) } } func skipOnMac(t *testing.T, reason string) { if runtime.GOOS == "darwin" && runningOnGithubAction() { t.Skip("skipped on Mac: " + reason) } } func skipOnWindows(t *testing.T, reason string) { if runtime.GOOS == "windows" { t.Skip("skipped on Windows: " + reason) } } func randomString(n int) string { r := rand.New(rand.NewSource(time.Now().UnixNano())) alpha := []rune("abcdefghijklmnopqrstuvwxyz") b := make([]rune, n) for i := range b { b[i] = alpha[r.Intn(len(alpha))] } return string(b) } func TestWithLowerKeys(t *testing.T) { m := make(map[string]string) m["abc"] = "def" m["GHI"] = "KLM" lowerM := withLowerKeys(m) assertEqualE(t, lowerM["abc"], "def") assertEqualE(t, lowerM["ghi"], "KLM") } func TestFindByPrefix(t *testing.T) { nonEmpty := []string{"aaa", "bbb", "ccc"} assertEqualE(t, findByPrefix(nonEmpty, "a"), 0) assertEqualE(t, findByPrefix(nonEmpty, "aa"), 0) assertEqualE(t, findByPrefix(nonEmpty, "aaa"), 0) assertEqualE(t, findByPrefix(nonEmpty, "bb"), 1) assertEqualE(t, findByPrefix(nonEmpty, "ccc"), 2) assertEqualE(t, findByPrefix(nonEmpty, "dd"), -1) assertEqualE(t, findByPrefix([]string{}, "dd"), -1) } func TestInternal(t *testing.T) { ctx := context.Background() assertFalseE(t, isInternal(ctx)) ctx = WithInternal(ctx) assertTrueE(t, isInternal(ctx)) } type envOverride struct { envName string oldValue string } func (e *envOverride) rollback() { if e.oldValue != "" { 
os.Setenv(e.envName, e.oldValue) } else { os.Unsetenv(e.envName) } } func overrideEnv(env string, value string) envOverride { oldValue := os.Getenv(env) os.Setenv(env, value) return envOverride{env, oldValue} } func TestSyncParamsAll(t *testing.T) { t.Run("nil map constructor", func(t *testing.T) { assertEqualE(t, len(syncParams{}.params), 0) }) t.Run("original map is left intact", func(t *testing.T) { m := make(map[string]*string) a := "a" m["a"] = &a sp := newSyncParams(m) b := "b" sp.set("a", &b) assertEqualE(t, *m["a"], "a") }) t.Run("nil map yields nothing", func(t *testing.T) { var sp syncParams count := 0 for range sp.All() { count++ } assertEqualE(t, count, 0) }) t.Run("empty map yields nothing", func(t *testing.T) { sp := newSyncParams(map[string]*string{}) count := 0 for range sp.All() { count++ } assertEqualE(t, count, 0) }) t.Run("iterates all entries", func(t *testing.T) { a, b := "1", "2" sp := newSyncParams(map[string]*string{"a": &a, "b": &b}) got := maps.Collect(sp.All()) assertEqualE(t, len(got), 2) assertEqualE(t, got["a"], "1") assertEqualE(t, got["b"], "2") }) t.Run("break stops early", func(t *testing.T) { a, b, c := "1", "2", "3" sp := newSyncParams(map[string]*string{"a": &a, "b": &b, "c": &c}) count := 0 for range sp.All() { count++ break } assertEqualE(t, count, 1) }) // This test verifies there's no data race — All() holds the mutex during iteration // while set() also acquires it, so they must serialize correctly. Running this test // under -race would catch it if the locking were missing or broken. 
t.Run("concurrent iteration and mutation", func(t *testing.T) { sp := newSyncParams(map[string]*string{}) var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() for i := range 100 { v := strconv.Itoa(i) sp.set(v, &v) } }() go func() { defer wg.Done() for range 100 { for range sp.All() { } } }() wg.Wait() }) } ================================================ FILE: uuid.go ================================================ package gosnowflake import ( "crypto/rand" "fmt" "strconv" ) const rfc4122 = 0x40 // UUID is a RFC4122 compliant uuid type type UUID [16]byte var nilUUID UUID // NewUUID creates a new snowflake UUID func NewUUID() UUID { var u UUID _, err := rand.Read(u[:]) if err != nil { logger.Warnf("error while reading random bytes to UUID. %v", err) } u[8] = (u[8] | rfc4122) & 0x7F var version byte = 4 u[6] = (u[6] & 0xF) | (version << 4) return u } func getChar(str string) byte { i, _ := strconv.ParseUint(str, 16, 8) return byte(i) } // ParseUUID parses a string of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx into its UUID form func ParseUUID(str string) UUID { return UUID{ getChar(str[0:2]), getChar(str[2:4]), getChar(str[4:6]), getChar(str[6:8]), getChar(str[9:11]), getChar(str[11:13]), getChar(str[14:16]), getChar(str[16:18]), getChar(str[19:21]), getChar(str[21:23]), getChar(str[24:26]), getChar(str[26:28]), getChar(str[28:30]), getChar(str[30:32]), getChar(str[32:34]), getChar(str[34:36]), } } func (u UUID) String() string { return fmt.Sprintf("%x-%x-%x-%x-%x", u[0:4], u[4:6], u[6:8], u[8:10], u[10:]) } ================================================ FILE: value_awaiter.go ================================================ package gosnowflake import ( "bytes" "runtime" "strconv" "sync" ) type valueAwaiterType struct { lockKey lockKeyType working bool cond *sync.Cond mu sync.Mutex h *valueAwaitHolderType } func newValueAwaiter(lockKey lockKeyType, h *valueAwaitHolderType) *valueAwaiterType { ret := &valueAwaiterType{ lockKey: lockKey, h: h, } ret.cond = 
sync.NewCond(&ret.mu)
	return ret
}

// awaitValue lets many goroutines wait for a value guarded by valueAwaiter.
// runFunc fetches the current value, acceptFunc decides whether that value
// (or error) is final, and defaultFactoryFunc builds the placeholder returned
// to the single goroutine elected to produce the real value.
// At most one caller at a time is flagged as "working"; all others block on
// the condition variable until done()/resumeOne() wakes them.
func awaitValue[T any](valueAwaiter *valueAwaiterType, runFunc func() (T, error), acceptFunc func(t T, err error) bool, defaultFactoryFunc func() T) (T, error) {
	logger.Tracef("awaitValue[%v] entered awaitValue for %s", goroutineID(), valueAwaiter.lockKey.lockID())
	valueAwaiter.mu.Lock()
	value, err := runFunc()
	// check if the value is already ready
	if acceptFunc(value, err) {
		logger.Tracef("awaitValue[%v] value was ready", goroutineID())
		valueAwaiter.mu.Unlock()
		return value, err
	}
	// value is not ready, check if no other thread is working
	if !valueAwaiter.working {
		logger.Tracef("awaitValue[%v] start working", goroutineID())
		valueAwaiter.working = true
		valueAwaiter.mu.Unlock()
		// continue working only in this thread
		return defaultFactoryFunc(), nil
	}
	// Check again if the value is ready after each wakeup.
	// If one thread is woken up and the value is still not ready, it should return default and continue working on this.
	// If the value is ready, all threads should be woken up and return the value.
	ret, err := runFunc()
	for !acceptFunc(ret, err) {
		logger.Tracef("awaitValue[%v] waiting for value", goroutineID())
		// cond.Wait atomically releases mu while blocked and re-acquires it
		// before returning, so the mutex is held on every path below.
		valueAwaiter.cond.Wait()
		logger.Tracef("awaitValue[%v] woke up", goroutineID())
		ret, err = runFunc()
		if !acceptFunc(ret, err) && !valueAwaiter.working {
			logger.Tracef("awaitValue[%v] start working after wait", goroutineID())
			valueAwaiter.working = true
			valueAwaiter.mu.Unlock()
			return defaultFactoryFunc(), nil
		}
	}
	// Value is ready - all threads should return the value.
	logger.Tracef("awaitValue[%v] value was ready after wait", goroutineID())
	valueAwaiter.mu.Unlock()
	return ret, err
}

// done marks the producer as finished, wakes every waiter and removes this
// awaiter from its holder.
func (v *valueAwaiterType) done() {
	logger.Tracef("valueAwaiter[%v] done working for %s, resuming all threads", goroutineID(), v.lockKey.lockID())
	v.mu.Lock()
	defer v.mu.Unlock()
	v.working = false
	v.cond.Broadcast()
	v.h.remove(v)
}

// resumeOne clears the working flag and wakes a single waiter, which will
// either observe the value or elect itself as the next producer.
func (v *valueAwaiterType) resumeOne() {
	logger.Tracef("valueAwaiter[%v] done working for %s, resuming one thread", goroutineID(), v.lockKey.lockID())
	v.mu.Lock()
	defer v.mu.Unlock()
	v.working = false
	v.cond.Signal()
}

// valueAwaitHolderType is a registry of live awaiters keyed by lock ID.
type valueAwaitHolderType struct {
	mu      sync.Mutex
	holders map[string]*valueAwaiterType
}

// valueAwaitHolder is the process-wide registry instance.
var valueAwaitHolder = newValueAwaitHolder()

func newValueAwaitHolder() *valueAwaitHolderType {
	return &valueAwaitHolderType{
		holders: make(map[string]*valueAwaiterType),
	}
}

// get returns the awaiter registered for lockKey, creating and registering
// one on first use.
func (h *valueAwaitHolderType) get(lockKey lockKeyType) *valueAwaiterType {
	lockID := lockKey.lockID()
	h.mu.Lock()
	defer h.mu.Unlock()
	holder, ok := h.holders[lockID]
	if !ok {
		holder = newValueAwaiter(lockKey, h)
		h.holders[lockID] = holder
	}
	return holder
}

// remove drops v from the registry.
func (h *valueAwaitHolderType) remove(v *valueAwaiterType) {
	h.mu.Lock()
	defer h.mu.Unlock()
	delete(h.holders, v.lockKey.lockID())
}

// goroutineID extracts the current goroutine's numeric ID from the stack
// header ("goroutine N [running]: ..."). Negative return values signal parse
// failure. Intended for trace logging only.
func goroutineID() int {
	buf := make([]byte, 32)
	n := runtime.Stack(buf, false)
	buf = buf[:n]
	// goroutine 1 [running]: ...
	buf, ok := bytes.CutPrefix(buf, []byte("goroutine "))
	if !ok {
		return -1
	}
	before, _, ok := bytes.Cut(buf, []byte{' '})
	if !ok {
		return -2
	}
	goid, err := strconv.Atoi(string(before))
	if err != nil {
		logger.Tracef("goroutineID err: %v", err)
		return -3
	}
	return goid
}

================================================
FILE: version.go
================================================
package gosnowflake

// SnowflakeGoDriverVersion is the version of Go Snowflake Driver.
const SnowflakeGoDriverVersion = "2.0.1" ================================================ FILE: wiremock_test.go ================================================ package gosnowflake import ( "crypto/tls" "crypto/x509" "database/sql" "fmt" "io" "net/http" "os" "strconv" "strings" "testing" "time" ) var wiremock = newWiremock() var wiremockHTTPS = newWiremockHTTPS() type wiremockClient struct { protocol string host string port int adminPort int client http.Client } type wiremockClientHTTPS struct { wiremockClient } func newWiremock() *wiremockClient { wmHost := os.Getenv("WIREMOCK_HOST") if wmHost == "" { wmHost = "127.0.0.1" } wmPortStr := os.Getenv("WIREMOCK_PORT") if wmPortStr == "" { wmPortStr = "14355" } wmPort, err := strconv.Atoi(wmPortStr) if err != nil { panic(fmt.Sprintf("WIREMOCK_PORT is not a number: %v", wmPortStr)) } return &wiremockClient{ protocol: "http", host: wmHost, port: wmPort, adminPort: wmPort, } } func newWiremockHTTPS() *wiremockClientHTTPS { wmHost := os.Getenv("WIREMOCK_HOST_HTTPS") if wmHost == "" { wmHost = "127.0.0.1" } wmPortStr := os.Getenv("WIREMOCK_PORT_HTTPS") if wmPortStr == "" { wmPortStr = "13567" } wmPort, err := strconv.Atoi(wmPortStr) if err != nil { panic(fmt.Sprintf("WIREMOCK_PORT is not a number: %v", wmPortStr)) } wmAdminPortStr := os.Getenv("WIREMOCK_PORT") if wmAdminPortStr == "" { wmAdminPortStr = "14355" } wmAdminPort, err := strconv.Atoi(wmAdminPortStr) if err != nil { panic(fmt.Sprintf("WIREMOCK_PORT is not a number: %v", wmPortStr)) } return &wiremockClientHTTPS{ wiremockClient: wiremockClient{ protocol: "https", host: wmHost, port: wmPort, adminPort: wmAdminPort, }, } } func (wm *wiremockClient) openDb(t *testing.T) *sql.DB { cfg := wm.connectionConfig() connector := NewConnector(SnowflakeDriver{}, *cfg) return sql.OpenDB(connector) } func (wm *wiremockClient) connectionConfig() *Config { cfg := &Config{ Account: "testAccount", User: "testUser", Password: "testPassword", Host: wm.host, Port: wm.port, Protocol: 
wm.protocol, LoginTimeout: time.Duration(30) * time.Second, RequestTimeout: time.Duration(30) * time.Second, MaxRetryCount: 3, OauthClientID: "testClientId", OauthClientSecret: "testClientSecret", OauthAuthorizationURL: wm.baseURL() + "/oauth/authorize", OauthTokenRequestURL: wm.baseURL() + "/oauth/token", } return cfg } func (wm *wiremockClientHTTPS) connectionConfig(t *testing.T) *Config { cfg := wm.wiremockClient.connectionConfig() cfg.Transporter = &http.Transport{ TLSClientConfig: wm.tlsConfig(t), } return cfg } func (wm *wiremockClientHTTPS) certPool(t *testing.T) *x509.CertPool { testCertPool := x509.NewCertPool() caBytes, err := os.ReadFile("ci/scripts/ca.der") assertNilF(t, err) certificate, err := x509.ParseCertificate(caBytes) assertNilF(t, err) testCertPool.AddCert(certificate) return testCertPool } func (wm *wiremockClientHTTPS) ocspTransporter(t *testing.T, delegate http.RoundTripper) http.RoundTripper { if delegate == nil { delegate = http.DefaultTransport } cfg := wm.connectionConfig(t) cfg.Transporter = delegate ov := newOcspValidator(cfg) return &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: wiremockHTTPS.certPool(t), VerifyPeerCertificate: ov.verifyPeerCertificateSerial, }, DisableKeepAlives: true, } } func (wm *wiremockClientHTTPS) tlsConfig(t *testing.T) *tls.Config { return &tls.Config{ RootCAs: wm.certPool(t), } } type wiremockMapping struct { filePath string params map[string]string } func newWiremockMapping(filePath string) wiremockMapping { return wiremockMapping{filePath: filePath} } type disableEnrichingWithTelemetry struct{} func (wm *wiremockClient) registerMappings(t *testing.T, args ...any) { skipOnJenkins(t, "wiremock does not work on Jenkins") enrichWithTelemetry := true var mappings []wiremockMapping for _, arg := range args { switch v := arg.(type) { case wiremockMapping: mappings = append(mappings, v) case []wiremockMapping: mappings = append(mappings, v...) 
case disableEnrichingWithTelemetry: enrichWithTelemetry = false default: t.Fatalf("unsupported argument type: %T", v) } } allMappings := mappings if enrichWithTelemetry { allMappings = append(allMappings, newWiremockMapping("telemetry/telemetry.json")) } for _, mapping := range allMappings { f, err := os.Open("test_data/wiremock/mappings/" + mapping.filePath) assertNilF(t, err) defer f.Close() mappingBodyBytes, err := io.ReadAll(f) assertNilF(t, err) mappingBody := string(mappingBodyBytes) for key, val := range mapping.params { mappingBody = strings.Replace(mappingBody, key, val, 1) } resp, err := wm.client.Post(fmt.Sprintf("%v/import", wm.mappingsURL()), "application/json", strings.NewReader(mappingBody)) assertNilF(t, err) if resp.StatusCode != http.StatusOK { respBody, err := io.ReadAll(resp.Body) assertNilF(t, err) t.Fatalf("cannot create mapping. status=%v body=\n%v", resp.StatusCode, string(respBody)) } } t.Cleanup(func() { req, err := http.NewRequest("DELETE", wm.mappingsURL(), nil) assertNilF(t, err) _, err = wm.client.Do(req) assertNilE(t, err) req, err = http.NewRequest("POST", fmt.Sprintf("%v/reset", wm.scenariosURL()), nil) assertNilF(t, err) _, err = wm.client.Do(req) assertNilE(t, err) }) } func (wm *wiremockClient) mappingsURL() string { return fmt.Sprintf("http://%v:%v/__admin/mappings", wm.host, wm.adminPort) } func (wm *wiremockClient) scenariosURL() string { return fmt.Sprintf("http://%v:%v/__admin/scenarios", wm.host, wm.adminPort) } func (wm *wiremockClient) baseURL() string { return fmt.Sprintf("%v://%v:%v", wm.protocol, wm.host, wm.port) } func TestQueryViaHttps(t *testing.T) { wiremockHTTPS.registerMappings(t, wiremockMapping{filePath: "auth/password/successful_flow.json"}, wiremockMapping{filePath: "select1.json", params: map[string]string{ "%AUTHORIZATION_HEADER%": "session token", }}, ) cfg := wiremockHTTPS.connectionConfig(t) testCertPool := x509.NewCertPool() caBytes, err := os.ReadFile("ci/scripts/ca.der") assertNilF(t, err) 
certificate, err := x509.ParseCertificate(caBytes) assertNilF(t, err) testCertPool.AddCert(certificate) cfg.Transporter = &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: testCertPool, }, } connector := NewConnector(SnowflakeDriver{}, *cfg) db := sql.OpenDB(connector) rows, err := db.Query("SELECT 1") assertNilF(t, err) defer rows.Close() var v int assertTrueF(t, rows.Next()) assertNilF(t, rows.Scan(&v)) assertEqualE(t, v, 1) }