Repository: snowflakedb/gosnowflake
Branch: master
Commit: a0b59e44724f
Files: 394
Total size: 2.4 MB
Directory structure:
gitextract_osgz45y5/
├── .cursor/
│ └── rules/
│ ├── overall-guidelines.mdc
│ └── testing.mdc
├── .github/
│ ├── CODEOWNERS
│ ├── ISSUE_TEMPLATE/
│ │ ├── BUG_REPORT.md
│ │ └── FEATURE_REQUEST.md
│ ├── ISSUE_TEMPLATE.md
│ ├── PULL_REQUEST_TEMPLATE.md
│ ├── repo_meta.yaml
│ ├── secret_scanning.yml
│ └── workflows/
│ ├── build-test.yml
│ ├── changelog.yml
│ ├── cla_bot.yml
│ ├── jira_close.yml
│ ├── jira_comment.yml
│ ├── jira_issue.yml
│ ├── parameters/
│ │ └── public/
│ │ ├── rsa_key_golang_aws.p8.gpg
│ │ ├── rsa_key_golang_azure.p8.gpg
│ │ └── rsa_key_golang_gcp.p8.gpg
│ ├── parameters_aws_auth_tests.json.gpg
│ ├── parameters_aws_golang.json.gpg
│ ├── parameters_azure_golang.json.gpg
│ ├── parameters_gcp_golang.json.gpg
│ ├── rsa-2048-private-key.p8.gpg
│ ├── rsa_keys/
│ │ ├── rsa_key.p8.gpg
│ │ └── rsa_key_invalid.p8.gpg
│ └── semgrep.yml
├── .gitignore
├── .golangci.yml
├── .pre-commit-config.yaml
├── .windsurf/
│ └── rules/
│ └── go.md
├── CHANGELOG.md
├── CONTRIBUTING.md
├── Jenkinsfile
├── LICENSE
├── Makefile
├── README.md
├── SECURITY.md
├── aaa_test.go
├── arrow_chunk.go
├── arrow_stream.go
├── arrow_test.go
├── arrowbatches/
│ ├── batches.go
│ ├── batches_test.go
│ ├── context.go
│ ├── converter.go
│ ├── converter_test.go
│ └── schema.go
├── assert_test.go
├── async.go
├── async_test.go
├── auth.go
├── auth_generic_test_methods_test.go
├── auth_oauth.go
├── auth_oauth_test.go
├── auth_test.go
├── auth_wif.go
├── auth_wif_test.go
├── auth_with_external_browser_test.go
├── auth_with_keypair_test.go
├── auth_with_mfa_test.go
├── auth_with_oauth_okta_authorization_code_test.go
├── auth_with_oauth_okta_client_credentials_test.go
├── auth_with_oauth_snowflake_authorization_code_test.go
├── auth_with_oauth_snowflake_authorization_code_wildcards_test.go
├── auth_with_oauth_test.go
├── auth_with_okta_test.go
├── auth_with_pat_test.go
├── authexternalbrowser.go
├── authexternalbrowser_test.go
├── authokta.go
├── authokta_test.go
├── azure_storage_client.go
├── azure_storage_client_test.go
├── bind_uploader.go
├── bindings_test.go
├── chunk.go
├── chunk_downloader.go
├── chunk_downloader_test.go
├── chunk_test.go
├── ci/
│ ├── _init.sh
│ ├── build.bat
│ ├── build.sh
│ ├── container/
│ │ ├── test_authentication.sh
│ │ └── test_component.sh
│ ├── docker/
│ │ └── rockylinux9/
│ │ └── Dockerfile
│ ├── gofix.sh
│ ├── image/
│ │ ├── Dockerfile
│ │ ├── build.sh
│ │ ├── scripts/
│ │ │ └── entrypoint.sh
│ │ └── update.sh
│ ├── scripts/
│ │ ├── .gitignore
│ │ ├── README.md
│ │ ├── ca.crt
│ │ ├── ca.der
│ │ ├── ca.key
│ │ ├── ca.srl
│ │ ├── execute_tests.sh
│ │ ├── hang_webserver.py
│ │ ├── login_internal_docker.sh
│ │ ├── run_wiremock.sh
│ │ ├── setup_connection_parameters.sh
│ │ ├── setup_gpg.sh
│ │ ├── wiremock-ecdsa-pub.key
│ │ ├── wiremock-ecdsa.crt
│ │ ├── wiremock-ecdsa.csr
│ │ ├── wiremock-ecdsa.key
│ │ ├── wiremock-ecdsa.p12
│ │ ├── wiremock.crt
│ │ ├── wiremock.csr
│ │ ├── wiremock.key
│ │ ├── wiremock.p12
│ │ └── wiremock.v3.ext
│ ├── test.bat
│ ├── test.sh
│ ├── test_authentication.sh
│ ├── test_revocation.sh
│ ├── test_rockylinux9.sh
│ ├── test_rockylinux9_docker.sh
│ ├── test_wif.sh
│ └── wif/
│ └── parameters/
│ ├── parameters_wif.json.gpg
│ ├── rsa_wif_aws_azure.gpg
│ └── rsa_wif_gcp.gpg
├── client.go
├── client_configuration.go
├── client_configuration_test.go
├── client_test.go
├── cmd/
│ ├── arrow/
│ │ ├── .gitignore
│ │ ├── Makefile
│ │ └── transform_batches_to_rows/
│ │ ├── Makefile
│ │ └── transform_batches_to_rows.go
│ ├── logger/
│ │ ├── Makefile
│ │ └── logger.go
│ ├── mfa/
│ │ ├── Makefile
│ │ └── mfa.go
│ ├── programmatic_access_token/
│ │ ├── .gitignore
│ │ ├── Makefile
│ │ └── pat.go
│ ├── tomlfileconnection/
│ │ ├── .gitignore
│ │ └── Makefile
│ └── variant/
│ ├── Makefile
│ └── insertvariantobject.go
├── codecov.yml
├── connection.go
├── connection_configuration_test.go
├── connection_test.go
├── connection_util.go
├── connectivity_diagnosis.go
├── connectivity_diagnosis_test.go
├── connector.go
├── connector_test.go
├── converter.go
├── converter_test.go
├── crl.go
├── crl_test.go
├── ctx_test.go
├── datatype.go
├── datatype_test.go
├── datetime.go
├── datetime_test.go
├── doc.go
├── driver.go
├── driver_ocsp_test.go
├── driver_test.go
├── dsn.go
├── easy_logging.go
├── easy_logging_test.go
├── encrypt_util.go
├── encrypt_util_test.go
├── errors.go
├── errors_test.go
├── file_compression_type.go
├── file_transfer_agent.go
├── file_transfer_agent_test.go
├── file_util.go
├── file_util_test.go
├── function_wrapper_test.go
├── function_wrappers.go
├── gcs_storage_client.go
├── gcs_storage_client_test.go
├── go.mod
├── go.sum
├── gosnowflake.mak
├── heartbeat.go
├── heartbeat_test.go
├── htap.go
├── htap_test.go
├── internal/
│ ├── arrow/
│ │ └── arrow.go
│ ├── compilation/
│ │ ├── cgo_disabled.go
│ │ ├── cgo_enabled.go
│ │ ├── linking_mode.go
│ │ ├── minicore_disabled.go
│ │ └── minicore_enabled.go
│ ├── config/
│ │ ├── assert_test.go
│ │ ├── auth_type.go
│ │ ├── config.go
│ │ ├── config_bool.go
│ │ ├── connection_configuration.go
│ │ ├── connection_configuration_test.go
│ │ ├── crl_mode.go
│ │ ├── dsn.go
│ │ ├── dsn_test.go
│ │ ├── ocsp_mode.go
│ │ ├── priv_key.go
│ │ ├── tls_config.go
│ │ ├── tls_config_test.go
│ │ └── token_accessor.go
│ ├── errors/
│ │ └── errors.go
│ ├── logger/
│ │ ├── accessor.go
│ │ ├── accessor_test.go
│ │ ├── context.go
│ │ ├── easy_logging_support.go
│ │ ├── interfaces.go
│ │ ├── level_filtering.go
│ │ ├── optional_interfaces.go
│ │ ├── proxy.go
│ │ ├── secret_detector.go
│ │ ├── secret_detector_test.go
│ │ ├── secret_masking.go
│ │ ├── secret_masking_test.go
│ │ ├── slog_handler.go
│ │ ├── slog_logger.go
│ │ └── source_location_test.go
│ ├── os/
│ │ ├── libc_info.go
│ │ ├── libc_info_linux.go
│ │ ├── libc_info_notlinux.go
│ │ ├── libc_info_test.go
│ │ ├── os_details.go
│ │ ├── os_details_linux.go
│ │ ├── os_details_notlinux.go
│ │ ├── os_details_test.go
│ │ └── test_data/
│ │ └── sample_os_release
│ ├── query/
│ │ ├── response_types.go
│ │ └── transform.go
│ └── types/
│ └── types.go
├── local_storage_client.go
├── local_storage_client_test.go
├── location.go
├── location_test.go
├── locker.go
├── log.go
├── log_client_test.go
├── log_test.go
├── minicore.go
├── minicore_disabled_test.go
├── minicore_posix.go
├── minicore_provider_darwin_amd64.go
├── minicore_provider_darwin_arm64.go
├── minicore_provider_linux_amd64.go
├── minicore_provider_linux_arm64.go
├── minicore_provider_windows_amd64.go
├── minicore_provider_windows_arm64.go
├── minicore_test.go
├── minicore_windows.go
├── monitoring.go
├── multistatement.go
├── multistatement_test.go
├── ocsp.go
├── ocsp_test.go
├── old_driver_test.go
├── os_specific_posix.go
├── os_specific_windows.go
├── parameters.json.local
├── parameters.json.tmpl
├── permissions_test.go
├── platform_detection.go
├── platform_detection_test.go
├── prepared_statement_test.go
├── priv_key_test.go
├── put_get_test.go
├── put_get_user_stage_test.go
├── put_get_with_aws_test.go
├── query.go
├── restful.go
├── restful_test.go
├── result.go
├── retry.go
├── retry_test.go
├── rows.go
├── rows_test.go
├── s3_storage_client.go
├── s3_storage_client_test.go
├── secret_detector.go
├── secret_detector_test.go
├── secure_storage_manager.go
├── secure_storage_manager_linux.go
├── secure_storage_manager_notlinux.go
├── secure_storage_manager_test.go
├── sflog/
│ ├── interface.go
│ ├── levels.go
│ └── slog.go
├── sqlstate.go
├── statement.go
├── statement_test.go
├── storage_client.go
├── storage_client_test.go
├── storage_file_util_test.go
├── structured_type.go
├── structured_type_arrow_batches_test.go
├── structured_type_read_test.go
├── structured_type_write_test.go
├── telemetry.go
├── telemetry_test.go
├── test_data/
│ ├── .gitignore
│ ├── connections.toml
│ ├── multistatements.sql
│ ├── multistatements_drop.sql
│ ├── orders_100.csv
│ ├── orders_101.csv
│ ├── put_get_1.txt
│ ├── snowflake/
│ │ └── session/
│ │ └── token
│ ├── userdata1.parquet
│ ├── userdata1_orc
│ └── wiremock/
│ └── mappings/
│ ├── auth/
│ │ ├── external_browser/
│ │ │ ├── parallel_login_first_fails_then_successful_flow.json
│ │ │ ├── parallel_login_successful_flow.json
│ │ │ └── successful_flow.json
│ │ ├── mfa/
│ │ │ ├── parallel_login_first_fails_then_successful_flow.json
│ │ │ └── parallel_login_successful_flow.json
│ │ ├── oauth2/
│ │ │ ├── authorization_code/
│ │ │ │ ├── error_from_idp.json
│ │ │ │ ├── invalid_code.json
│ │ │ │ ├── successful_flow.json
│ │ │ │ ├── successful_flow_with_offline_access.json
│ │ │ │ └── successful_flow_with_single_use_refresh_token.json
│ │ │ ├── client_credentials/
│ │ │ │ ├── invalid_client.json
│ │ │ │ └── successful_flow.json
│ │ │ ├── login_request.json
│ │ │ ├── login_request_with_expired_access_token.json
│ │ │ └── refresh_token/
│ │ │ ├── invalid_refresh_token.json
│ │ │ ├── successful_flow.json
│ │ │ └── successful_flow_without_new_refresh_token.json
│ │ ├── password/
│ │ │ ├── invalid_host.json
│ │ │ ├── invalid_password.json
│ │ │ ├── invalid_user.json
│ │ │ ├── successful_flow.json
│ │ │ └── successful_flow_with_telemetry.json
│ │ ├── pat/
│ │ │ ├── invalid_token.json
│ │ │ ├── reading_fresh_token.json
│ │ │ └── successful_flow.json
│ │ └── wif/
│ │ ├── azure/
│ │ │ ├── http_error.json
│ │ │ ├── missing_issuer_claim.json
│ │ │ ├── missing_sub_claim.json
│ │ │ ├── non_json_response.json
│ │ │ ├── successful_flow_azure_functions.json
│ │ │ ├── successful_flow_azure_functions_custom_entra_resource.json
│ │ │ ├── successful_flow_azure_functions_no_client_id.json
│ │ │ ├── successful_flow_azure_functions_v2_issuer.json
│ │ │ ├── successful_flow_basic.json
│ │ │ ├── successful_flow_v2_issuer.json
│ │ │ └── unparsable_token.json
│ │ └── gcp/
│ │ ├── http_error.json
│ │ ├── missing_issuer_claim.json
│ │ ├── missing_sub_claim.json
│ │ ├── successful_flow.json
│ │ ├── successful_impersionation_flow.json
│ │ └── unparsable_token.json
│ ├── close_session.json
│ ├── hang.json
│ ├── minicore/
│ │ └── auth/
│ │ ├── disabled_flow.json
│ │ ├── successful_flow.json
│ │ └── successful_flow_linux.json
│ ├── ocsp/
│ │ ├── auth_failure.json
│ │ ├── malformed.json
│ │ └── unauthorized.json
│ ├── platform_detection/
│ │ ├── aws_ec2_instance_success.json
│ │ ├── aws_identity_success.json
│ │ ├── azure_managed_identity_success.json
│ │ ├── azure_vm_success.json
│ │ ├── gce_identity_success.json
│ │ ├── gce_vm_success.json
│ │ └── timeout_response.json
│ ├── query/
│ │ ├── long_running_query.json
│ │ ├── query_by_id_timeout.json
│ │ ├── query_execution.json
│ │ ├── query_monitoring.json
│ │ ├── query_monitoring_error.json
│ │ ├── query_monitoring_malformed.json
│ │ └── query_monitoring_running.json
│ ├── retry/
│ │ └── redirection_retry_workflow.json
│ ├── select1.json
│ └── telemetry/
│ ├── custom_telemetry.json
│ └── telemetry.json
├── test_utils_test.go
├── tls_config.go
├── tls_config_test.go
├── transaction.go
├── transaction_test.go
├── transport.go
├── transport_test.go
├── url_util.go
├── util.go
├── util_test.go
├── uuid.go
├── value_awaiter.go
├── version.go
└── wiremock_test.go
================================================
FILE CONTENTS
================================================
================================================
FILE: .cursor/rules/overall-guidelines.mdc
================================================
---
alwaysApply: true
---
# Cursor Rules for Go Snowflake Driver
## General Development Standards
### Code Quality
- Follow Go formatting standards (use `gofmt`)
- Use meaningful variable and function names
- Include error handling for all operations that can fail
- Write comprehensive documentation for public APIs
### Project Structure
- Place test files in the same package as the code being tested
- Use `test_data/` directory for test fixtures and sample data
- Group related functionality in logical packages
### Testing
- Test files should be named `*_test.go`
- **For test-specific rules, see `testing.mdc`**
- Write both positive and negative test cases
- Use table-driven tests for testing multiple scenarios
### Code Review Guidelines
- Ensure code follows Go best practices
- Verify comprehensive test coverage
- Check that error messages are descriptive and helpful for debugging
- Validate that public APIs are properly documented
================================================
FILE: .cursor/rules/testing.mdc
================================================
---
alwaysApply: true
---
# Cursor Rules for Go Test Files
This file automatically applies when working on `*_test.go` files.
## Testing Standards
### Assertion Helper Usage
- **ALWAYS** Attempt to use assertion helpers from `assert_test.go` instead of direct `t.Fatal`, `t.Fatalf`, `t.Error`, or `t.Errorf` calls. Where it makes sense, add new assertion helpers.
- **NEVER** write manual if-then-fatal patterns in test functions when a suitable assertion helper exists.
#### Common Assertion Patterns:
**Error Checking:**
```go
// ❌ WRONG
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// ✅ CORRECT
assertNilF(t, err, "Unexpected error")
```
**Nil Checking:**
```go
// ❌ WRONG
if obj == nil {
t.Fatal("Expected non-nil object")
}
// ✅ CORRECT
assertNotNilF(t, obj, "Expected non-nil object")
```
**Equality Checking:**
```go
// ❌ WRONG
if actual != expected {
t.Fatalf("Expected %v, got %v", expected, actual)
}
// ✅ CORRECT
assertEqualF(t, actual, expected, "Values should match")
```
**Error Message Validation:**
```go
// ❌ WRONG
if err.Error() != expectedMsg {
t.Fatalf("Expected error: %s, got: %s", expectedMsg, err.Error())
}
// ✅ CORRECT
assertEqualF(t, err.Error(), expectedMsg, "Error message should match")
```
**Boolean Assertions:**
```go
// ❌ WRONG
if !condition {
t.Fatal("Condition should be true")
}
// ✅ CORRECT
assertTrueF(t, condition, "Condition should be true")
```
#### Helper Function Reference:
Always examine `assert_test.go` for the latest set of helpers. Consider these existing examples below.
- `assertNilF/E(t, value, description)` - Assert value is nil
- `assertNotNilF/E(t, value, description)` - Assert value is not nil
- `assertEqualF/E(t, actual, expected, description)` - Assert equality
- `assertNotEqualF/E(t, actual, expected, description)` - Assert inequality
- `assertTrueF/E(t, value, description)` - Assert boolean is true
- `assertFalseF/E(t, value, description)` - Assert boolean is false
- `assertStringContainsF/E(t, str, substring, description)` - Assert string contains substring
- `assertErrIsF/E(t, actual, expected, description)` - Assert error matches expected error
#### When to Use F vs E:
- Use `F` suffix (Fatal) for critical failures that should stop the test immediately as well as for preconditions
- Use `E` suffix (Error) for non-critical failures that allow the test to continue
## Code Review Guidelines:
- Flag any direct use of `t.Fatal*` or `t.Error*` in new code
- Ensure all test functions use appropriate assertion helpers
- Verify that error messages are descriptive and helpful for debugging
- Check that tests are comprehensive and cover edge cases
================================================
FILE: .github/CODEOWNERS
================================================
* @snowflakedb/Client
/transport.go @snowflakedb/pki-oversight @snowflakedb/Client
/crl.go @snowflakedb/pki-oversight @snowflakedb/Client
/ocsp.go @snowflakedb/pki-oversight @snowflakedb/Client
# GitHub Advanced Security Secret Scanning config
/.github/secret_scanning.yml @snowflakedb/prodsec-security-manager-write
================================================
FILE: .github/ISSUE_TEMPLATE/BUG_REPORT.md
================================================
---
name: Bug Report 🐞
about: Something isn't working as expected? Here is the right place to report.
labels: bug
---
:exclamation: If you need **urgent assistance** then [file a case with Snowflake Support](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge).
Otherwise continue here.
Please answer these questions before submitting your issue.
In order to accurately debug the issue this information is required. Thanks!
1. What version of GO driver are you using?
2. What operating system and processor architecture are you using?
3. What version of GO are you using?
run `go version` in your console
4. *Server version:* E.g. 1.90.1
You may get the server version by running a query:
```
SELECT CURRENT_VERSION();
```
5. What did you do?
If possible, provide a recipe for reproducing the error.
A complete runnable program is good.
6. What did you expect to see?
What should have happened and what happened instead?
7. Can you set logging to DEBUG and collect the logs?
https://community.snowflake.com/s/article/How-to-generate-log-file-on-Snowflake-connectors
Before sharing any information, please be sure to review the log and remove any sensitive
information.
================================================
FILE: .github/ISSUE_TEMPLATE/FEATURE_REQUEST.md
================================================
---
name: Feature Request 💡
about: Suggest a new idea for the project.
labels: feature
---
## What is the current behavior?
## What is the desired behavior?
## How would this improve `gosnowflake`?
## References, Other Background
================================================
FILE: .github/ISSUE_TEMPLATE.md
================================================
### Issue description
Tell us what should happen and what happens instead
### Example code
```go
If possible, please enter some example code here to reproduce the issue.
```
### Error log
```
If you have an error log, please paste it here.
```
Add the `glog` option to your application to collect log files.
### Configuration
*Driver version (or git SHA):*
*Go version:* run `go version` in your console
*Server version:* E.g. 1.90.1
You may get the server version by running a query:
```
SELECT CURRENT_VERSION();
```
*Client OS:* E.g. Debian 8.1 (Jessie), Windows 10
================================================
FILE: .github/PULL_REQUEST_TEMPLATE.md
================================================
### Description
SNOW-XXX Please explain the changes you made here.
### Checklist
- [ ] Added proper logging (if possible)
- [ ] Created tests which fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary
================================================
FILE: .github/repo_meta.yaml
================================================
point_of_contact: @snowflakedb/client
production: true
code_owners_file_present: false
jira_area: Developer Platform
================================================
FILE: .github/secret_scanning.yml
================================================
paths-ignore:
- "**/test_data/**"
================================================
FILE: .github/workflows/build-test.yml
================================================
name: Build and Test
permissions:
contents: read
on:
push:
branches:
- master
tags:
- v*
pull_request:
schedule:
- cron: '7 3 * * *'
workflow_dispatch:
inputs:
goTestParams:
default:
description: 'Parameters passed to go test'
sequentialTests:
type: boolean
default: false
description: 'Run tests sequentially (no buffering, slower)'
concurrency:
# older builds for the same pull request number or branch should be cancelled
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
lint:
runs-on: ubuntu-latest
name: Check linter
steps:
- uses: actions/checkout@v4
- name: Setup go
uses: actions/setup-go@v5
with:
go-version: '1.26'
- name: golangci-lint
uses: golangci/golangci-lint-action@v7
with:
version: v2.11
- name: Format, Lint
shell: bash
run: ./ci/build.sh
- name: Run go fix across all platforms and tags
shell: bash
run: ./ci/gofix.sh
build-test-linux:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
cloud: [ 'AWS', 'AZURE', 'GCP' ]
go: [ '1.24', '1.25', '1.26' ]
name: ${{ matrix.cloud }} Go ${{ matrix.go }} on Ubuntu
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go }}
- name: Test
shell: bash
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }}
CLOUD_PROVIDER: ${{ matrix.cloud }}
GORACE: history_size=7
GO_TEST_PARAMS: ${{ inputs.goTestParams }}
SEQUENTIAL_TESTS: ${{ inputs.sequentialTests }}
WIREMOCK_PORT: 14335
WIREMOCK_HTTPS_PORT: 13567
run: ./ci/test.sh
- name: Upload test results to Codecov
if: ${{!cancelled()}}
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }}
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }}
build-test-linux-no-home:
runs-on: ubuntu-latest
name: Ubuntu - no HOME
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
go-version: '1.25'
- name: Test
shell: bash
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }}
CLOUD_PROVIDER: AWS
GORACE: history_size=7
GO_TEST_PARAMS: ${{ inputs.goTestParams }}
SEQUENTIAL_TESTS: ${{ inputs.sequentialTests }}
WIREMOCK_PORT: 14335
WIREMOCK_HTTPS_PORT: 13567
HOME_EMPTY: "yes"
run: ./ci/test.sh
build-test-mac:
runs-on: macos-latest
strategy:
fail-fast: false
matrix:
cloud: [ 'AWS', 'AZURE', 'GCP' ]
go: [ '1.24', '1.25', '1.26' ]
name: ${{ matrix.cloud }} Go ${{ matrix.go }} on Mac
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go }}
- name: Test
shell: bash
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }}
CLOUD_PROVIDER: ${{ matrix.cloud }}
GO_TEST_PARAMS: ${{ inputs.goTestParams }}
WIREMOCK_PORT: 14335
WIREMOCK_HTTPS_PORT: 13567
run: ./ci/test.sh
- name: Upload test results to Codecov
if: ${{!cancelled()}}
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }}
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }}
build-test-mac-no-home:
runs-on: macos-latest
name: Mac - no HOME
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
go-version: '1.25'
- name: Test
shell: bash
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }}
CLOUD_PROVIDER: AWS
GO_TEST_PARAMS: ${{ inputs.goTestParams }}
WIREMOCK_PORT: 14335
WIREMOCK_HTTPS_PORT: 13567
HOME_EMPTY: "yes"
run: ./ci/test.sh
build-test-windows:
runs-on: windows-latest
strategy:
fail-fast: false
matrix:
cloud: [ 'AWS', 'AZURE', 'GCP' ]
go: [ '1.24', '1.25', '1.26' ]
name: ${{ matrix.cloud }} Go ${{ matrix.go }} on Windows
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go }}
- uses: actions/setup-python@v5
with:
python-version: '3.x'
architecture: 'x64'
- name: Test
shell: cmd
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }}
CLOUD_PROVIDER: ${{ matrix.cloud }}
GO_TEST_PARAMS: ${{ inputs.goTestParams }}
SEQUENTIAL_TESTS: ${{ inputs.sequentialTests }}
WIREMOCK_PORT: 14335
WIREMOCK_HTTPS_PORT: 13567
run: ci\\test.bat
- name: Upload test results to Codecov
if: ${{!cancelled()}}
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }}
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }}
fipsOnly:
runs-on: ubuntu-latest
strategy:
fail-fast: false
name: FIPS only mode
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
go-version: '1.25'
- name: Test
shell: bash
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }}
CLOUD_PROVIDER: AWS
GORACE: history_size=7
GO_TEST_PARAMS: ${{ inputs.goTestParams }}
TEST_GODEBUG: fips140=only
SEQUENTIAL_TESTS: ${{ inputs.sequentialTests }}
WIREMOCK_PORT: 14335
WIREMOCK_HTTPS_PORT: 13567
run: ./ci/test.sh
- name: Upload test results to Codecov
if: ${{!cancelled()}}
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }}
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }}
build-test-linux-minicore-disabled:
runs-on: ubuntu-latest
name: Ubuntu - minicore disabled
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
go-version: '1.25'
- name: Test
shell: bash
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }}
CLOUD_PROVIDER: AWS
GORACE: history_size=7
GO_TEST_PARAMS: ${{ inputs.goTestParams }} -tags=minicore_disabled
WIREMOCK_PORT: 14335
WIREMOCK_HTTPS_PORT: 13567
run: ./ci/test.sh
build-test-mac-minicore-disabled:
runs-on: macos-latest
name: Mac - minicore disabled
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
go-version: '1.25'
- name: Test
shell: bash
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }}
CLOUD_PROVIDER: AWS
GO_TEST_PARAMS: ${{ inputs.goTestParams }} -tags=minicore_disabled
WIREMOCK_PORT: 14335
WIREMOCK_HTTPS_PORT: 13567
run: ./ci/test.sh
build-test-windows-minicore-disabled:
runs-on: windows-latest
name: Windows - minicore disabled
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
go-version: '1.25'
- uses: actions/setup-python@v5
with:
python-version: '3.x'
architecture: 'x64'
- name: Test
shell: cmd
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }}
CLOUD_PROVIDER: AWS
GO_TEST_PARAMS: ${{ inputs.goTestParams }} -tags=minicore_disabled
WIREMOCK_PORT: 14335
WIREMOCK_HTTPS_PORT: 13567
run: ci\\test.bat
ecc:
runs-on: ubuntu-latest
strategy:
fail-fast: false
name: Elliptic curves check
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
go-version: '1.25'
- name: Test
shell: bash
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }}
CLOUD_PROVIDER: AWS
GORACE: history_size=7
GO_TEST_PARAMS: ${{ inputs.goTestParams }} -run TestQueryViaHttps
WIREMOCK_PORT: 14335
WIREMOCK_HTTPS_PORT: 13567
WIREMOCK_ENABLE_ECDSA: true
run: ./ci/test.sh
build-test-rockylinux9:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
cloud_go:
- cloud: 'AWS'
go: '1.24.2'
- cloud: 'AZURE'
go: '1.25.0'
- cloud: 'GCP'
go: '1.26.0'
name: ${{ matrix.cloud_go.cloud }} Go ${{ matrix.cloud_go.go }} on Rocky Linux 9
steps:
- uses: actions/checkout@v4
- name: Test
shell: bash
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }}
CLOUD_PROVIDER: ${{ matrix.cloud_go.cloud }}
GORACE: history_size=7
GO_TEST_PARAMS: ${{ inputs.goTestParams }}
SEQUENTIAL_TESTS: ${{ inputs.sequentialTests }}
WIREMOCK_PORT: 14335
WIREMOCK_HTTPS_PORT: 13567
run: ./ci/test_rockylinux9_docker.sh ${{ matrix.cloud_go.go }}
build-test-ubuntu-arm:
runs-on: ubuntu-24.04-arm
strategy:
fail-fast: false
matrix:
cloud_go:
- cloud: 'AWS'
go: '1.24'
- cloud: 'AZURE'
go: '1.25'
- cloud: 'GCP'
go: '1.26'
name: ${{ matrix.cloud_go.cloud }} Go ${{ matrix.cloud_go.go }} on Ubuntu ARM
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 17
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.cloud_go.go }}
- name: Test
shell: bash
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }}
CLOUD_PROVIDER: ${{ matrix.cloud_go.cloud }}
GORACE: history_size=7
GO_TEST_PARAMS: ${{ inputs.goTestParams }}
WIREMOCK_PORT: 14335
WIREMOCK_HTTPS_PORT: 13567
run: ./ci/test.sh
- name: Upload test results to Codecov
if: ${{!cancelled()}}
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }}
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }}
build-test-windows-arm:
runs-on: windows-11-arm
strategy:
fail-fast: false
matrix:
cloud_go:
- cloud: 'AWS'
go: '1.24'
- cloud: 'AZURE'
go: '1.25'
- cloud: 'GCP'
go: '1.26'
name: ${{ matrix.cloud_go.cloud }} Go ${{ matrix.cloud_go.go }} on Windows ARM
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4 # for wiremock
with:
java-version: 21
distribution: 'temurin'
- name: Setup go
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.cloud_go.go }}
- uses: actions/setup-python@v5
with:
python-version: '3.x'
architecture: 'x64'
- name: Test
shell: cmd
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
GOLANG_PRIVATE_KEY_SECRET: ${{ secrets.GOLANG_PRIVATE_KEY_SECRET }}
CLOUD_PROVIDER: ${{ matrix.cloud_go.cloud }}
GO_TEST_PARAMS: ${{ inputs.goTestParams }}
WIREMOCK_PORT: 14335
WIREMOCK_HTTPS_PORT: 13567
run: ci\\test.bat
- name: Upload test results to Codecov
if: ${{!cancelled()}}
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }}
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODE_COV_UPLOAD_TOKEN }}
================================================
FILE: .github/workflows/changelog.yml
================================================
name: Changelog Check
on:
pull_request:
types: [opened, synchronize, labeled, unlabeled]
jobs:
check_change_log:
runs-on: ubuntu-latest
if: ${{!contains(github.event.pull_request.labels.*.name, 'NO-CHANGELOG-UPDATES')}}
steps:
- name: Checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Ensure CHANGELOG.md is updated
run: git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -wq "CHANGELOG.md"
================================================
FILE: .github/workflows/cla_bot.yml
================================================
name: "CLA Assistant"
on:
issue_comment:
types: [created]
pull_request_target:
types: [opened,closed,synchronize]
jobs:
CLAAssistant:
runs-on: ubuntu-latest
permissions:
actions: write
contents: write
pull-requests: write
statuses: write
steps:
- name: "CLA Assistant"
if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target'
uses: contributor-assistant/github-action/@master
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PERSONAL_ACCESS_TOKEN : ${{ secrets.CLA_BOT_TOKEN }}
with:
path-to-signatures: 'signatures/version1.json'
path-to-document: 'https://github.com/snowflakedb/CLA/blob/main/README.md'
branch: 'main'
allowlist: 'dependabot[bot],github-actions,Jenkins User,_jenkins,sfc-gh-snyk-sca-sa,snyk-bot'
remote-organization-name: 'snowflake-eng'
remote-repository-name: 'cla-db'
================================================
FILE: .github/workflows/jira_close.yml
================================================
name: Jira closure
on:
issues:
types: [closed, deleted]
jobs:
close-issue:
runs-on: ubuntu-latest
steps:
- name: Extract issue from title
id: extract
env:
TITLE: "${{ github.event.issue.title }}"
run: |
jira=$(echo -n $TITLE | awk '{print $1}' | sed -e 's/://')
echo ::set-output name=jira::$jira
- name: Close Jira Issue
if: startsWith(steps.extract.outputs.jira, 'SNOW-')
env:
ISSUE_KEY: ${{ steps.extract.outputs.jira }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
run: |
JIRA_API_URL="${JIRA_BASE_URL}/rest/api/2/issue/${ISSUE_KEY}/transitions"
curl -X POST \
--url "$JIRA_API_URL" \
--user "${JIRA_USER_EMAIL}:${JIRA_API_TOKEN}" \
--header "Content-Type: application/json" \
--data "{
\"update\": {
\"comment\": [
{ \"add\": { \"body\": \"Closed on GitHub\" } }
]
},
\"fields\": {
\"customfield_12860\": { \"id\": \"11506\" },
\"customfield_10800\": { \"id\": \"-1\" },
\"customfield_12500\": { \"id\": \"11302\" },
\"customfield_12400\": { \"id\": \"-1\" },
\"resolution\": { \"name\": \"Done\" }
},
\"transition\": { \"id\": \"71\" }
}"
================================================
FILE: .github/workflows/jira_comment.yml
================================================
name: Jira comment
on:
issue_comment:
types: [created]
jobs:
comment-issue:
runs-on: ubuntu-latest
steps:
- name: Jira login
uses: atlassian/gajira-login@master
env:
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
- name: Extract issue from title
id: extract
env:
TITLE: "${{ github.event.issue.title }}"
run: |
jira=$(echo -n $TITLE | awk '{print $1}' | sed -e 's/://')
echo ::set-output name=jira::$jira
- name: Comment on issue
uses: atlassian/gajira-comment@master
if: startsWith(steps.extract.outputs.jira, 'SNOW-') && github.event.comment.user.login != 'codecov[bot]'
with:
issue: "${{ steps.extract.outputs.jira }}"
comment: "${{ github.event.comment.user.login }} commented:\n\n${{ github.event.comment.body }}\n\n${{ github.event.comment.html_url }}"
================================================
FILE: .github/workflows/jira_issue.yml
================================================
name: Jira creation
on:
issues:
types: [opened]
issue_comment:
types: [created]
jobs:
create-issue:
runs-on: ubuntu-latest
permissions:
issues: write
if: ((github.event_name == 'issue_comment' && github.event.comment.body == 'recreate jira' && github.event.comment.user.login == 'sfc-gh-mkeller') || (github.event_name == 'issues' && github.event.pull_request.user.login != 'whitesource-for-github-com[bot]'))
steps:
- name: Create JIRA Ticket
id: create
env:
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
ISSUE_TITLE: ${{ github.event.issue.title }}
ISSUE_BODY: ${{ github.event.issue.body }}
ISSUE_URL: ${{ github.event.issue.html_url }}
run: |
# debug
#set -x
TMP_BODY=$(mktemp)
trap "rm -f $TMP_BODY" EXIT
# Escape special characters in title and body
TITLE=$(echo "${ISSUE_TITLE//`/\\`}" | sed 's/"/\\"/g' | sed "s/'/\\\'/g")
echo "${ISSUE_BODY//`/\\`}" | sed 's/"/\\"/g' | sed "s/'/\\\'/g" > $TMP_BODY
echo -e "\n\n_Created from GitHub Action_ for $ISSUE_URL" >> $TMP_BODY
BODY=$(cat "$TMP_BODY")
PAYLOAD=$(jq -n \
--arg issuetitle "$TITLE" \
--arg issuebody "$BODY" \
'{
fields: {
project: { key: "SNOW" },
issuetype: { name: "Bug" },
summary: $issuetitle,
description: $issuebody,
customfield_11401: { id: "14723" },
assignee: { id: "712020:e527ae71-55cc-4e02-9217-1ca4ca8028a2" },
components: [{ id: "19286" }],
labels: ["oss"],
priority: { id: "10001" }
}
}')
# Create JIRA issue using REST API
RESPONSE=$(curl -s -X POST \
-H "Content-Type: application/json" \
-H "Accept: application/json" \
-u "$JIRA_USER_EMAIL:$JIRA_API_TOKEN" \
"$JIRA_BASE_URL/rest/api/2/issue" \
-d "$PAYLOAD")
# Extract JIRA issue key from response
JIRA_KEY=$(echo "$RESPONSE" | jq -r '.key')
if [ "$JIRA_KEY" = "null" ] || [ -z "$JIRA_KEY" ]; then
echo "Failed to create JIRA issue"
echo "Response: $RESPONSE"
echo "Request payload: $PAYLOAD"
exit 1
fi
echo "Created JIRA issue: $JIRA_KEY"
echo "jira_key=$JIRA_KEY" >> $GITHUB_OUTPUT
- name: Update GitHub Issue
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
REPOSITORY: ${{ github.repository }}
ISSUE_NUMBER: ${{ github.event.issue.number }}
JIRA_KEY: ${{ steps.create.outputs.jira_key }}
ISSUE_TITLE: ${{ github.event.issue.title }}
run: |
TITLE=$(echo "${ISSUE_TITLE//`/\\`}" | sed 's/"/\\"/g' | sed "s/'/\\\'/g")
PAYLOAD=$(jq -n \
--arg issuetitle "$TITLE" \
--arg jirakey "$JIRA_KEY" \
'{
title: ($jirakey + ": " + $issuetitle)
}')
# Update Github issue title with jira id
curl -s \
-X PATCH \
-H "Authorization: Bearer $GITHUB_TOKEN" \
-H "Accept: application/vnd.github+json" \
-H "X-GitHub-Api-Version: 2022-11-28" \
"https://api.github.com/repos/$REPOSITORY/issues/$ISSUE_NUMBER" \
-d "$PAYLOAD"
if [ "$?" != 0 ]; then
echo "Failed to update GH issue. Payload was:"
echo "$PAYLOAD"
exit 1
fi
================================================
FILE: .github/workflows/semgrep.yml
================================================
name: Run semgrep checks
on:
pull_request:
branches: [main, master]
permissions:
contents: read
jobs:
run-semgrep-reusable-workflow:
uses: snowflakedb/reusable-workflows/.github/workflows/semgrep-v2.yml@main
secrets:
token: ${{ secrets.SEMGREP_APP_TOKEN }}
================================================
FILE: .gitignore
================================================
*.DS_Store
.idea/
.vscode/
parameters*.json
parameters*.bat
*.p8
coverage.txt
fuzz-*/
/select1
/selectmany
/verifycert
wss-golang-agent.config
wss-unified-agent.jar
whitesource/
*.swp
cp.out
__debug_bin*
test-output.txt
test-report.junit.xml
# exclude vendor
vendor
# SSH private key for WIF tests
ci/wif/parameters/rsa_wif_aws_azure
ci/wif/parameters/rsa_wif_gcp
================================================
FILE: .golangci.yml
================================================
version: "2"
run:
tests: true
linters:
exclusions:
rules:
- path: "_test.go"
linters:
- errcheck
- path: "cmd/"
linters:
- errcheck
- path: "_test.go"
linters:
- staticcheck
text: "implement StmtQueryContext"
- path: "_test.go"
linters:
- staticcheck
text: "implement StmtExecContext"
- linters:
- staticcheck
text: "QF1001"
- linters:
- staticcheck
text: "SA1019: .+\\.(LoginTimeout|RequestTimeout|JWTExpireTimeout|ClientTimeout|JWTClientTimeout|ExternalBrowserTimeout|CloudStorageTimeout|Tracing) is deprecated"
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: git@github.com:snowflakedb/casec_precommit.git # SSH
# - repo: https://github.com/snowflakedb/casec_precommit.git # HTTPS
rev: v1.5
hooks:
- id: snapps-secret-scanner
================================================
FILE: .windsurf/rules/go.md
================================================
---
trigger: glob
description:
globs: **/*.go
---
# Go files rules
## General
1. Unless it's necessary or told otherwise, try reusing existing files, both for implementation and tests.
2. If possible, try running relevant tests.
## Tests
1. Create a test file with the name same as prod code file by default.
2. For assertions use our test helpers defined in assert_test.go.
## Logging
1. Add reasonable logging - don't repeat logs, but add them when it's meaningful.
2. Always consider log levels.
================================================
FILE: CHANGELOG.md
================================================
# Changelog
## Upcoming release
Bug fixes:
- Fixed empty `Account` when connecting with programmatic `Config` and `database/sql.Connector` by deriving `Account` from the first DNS label of `Host` in `FillMissingConfigParameters` when `Host` matches the Snowflake hostname pattern (snowflakedb/gosnowflake#1772).
## 2.0.1
Bug fixes:
- Fixed default `CrlDownloadMaxSize` to be 20MB instead of 200MB, as the previous value was set too high and could cause out-of-memory issues (snowflakedb/gosnowflake#1735).
- Replaced global `paramsMutex` with per-connection `syncParams` to encapsulate parameter synchronization and avoid cross-connection contention (snowflakedb/gosnowflake#1747).
- `Config.Params` map is not modified anymore, to avoid changing parameter values across connections of the same connection pool (snowflakedb/gosnowflake#1747).
- Set `BlobContentMD5` on Azure uploads so that multi-part uploads have the blob content-MD5 property populated (snowflakedb/gosnowflake#1757).
- Fixed 403 errors from Google/GCP/GCS PUT queries on versioned stages (snowflakedb/gosnowflake#1760).
- Fixed not updating query context cache for failed queries (snowflakedb/gosnowflake#1763).
Internal changes:
- Moved configuration to a dedicated internal package (snowflakedb/gosnowflake#1720).
- Modernized Go syntax idioms throughout the codebase.
- Added libc family, version and dynamic linking marker to client environment telemetry (snowflakedb/gosnowflake#1750).
- Bumped a few libraries to fix vulnerabilities (snowflakedb/gosnowflake#1751, snowflakedb/gosnowflake#1756).
- Depointerised query context cache in `snowflakeConn` (snowflakedb/gosnowflake#1763).
## 2.0.0
Breaking changes:
- Removed `RaisePutGetError` from `SnowflakeFileTransferOptions` - current behaviour is aligned to always raise errors for PUT/GET operations (snowflakedb/gosnowflake#1690).
- Removed `GetFileToStream` from `SnowflakeFileTransferOptions` - using `WithFileGetStream` automatically enables file streaming for GETs (snowflakedb/gosnowflake#1690).
- Renamed `WithFileStream` to `WithFilePutStream` for consistency (snowflakedb/gosnowflake#1690).
- `Array` function now returns error for unsupported types (snowflakedb/gosnowflake#1693).
- `WithMultiStatement` does not return error anymore (snowflakedb/gosnowflake#1693).
- `WithOriginalTimestamp` is removed, use `WithArrowBatchesTimestampOption(UseOriginalTimestamp)` instead (snowflakedb/gosnowflake#1693).
- `WithMapValuesNullable` and `WithArrayValuesNullable` combined into one option `WithEmbeddedValuesNullable` (snowflakedb/gosnowflake#1693).
- Hid streaming chunk downloader. It will be removed completely in the future (snowflakedb/gosnowflake#1696).
- Maximum number of chunk download goroutines is now configured with `CLIENT_PREFETCH_THREADS` session parameter (snowflakedb/gosnowflake#1696)
and defaults to 4.
- Fixed typo in `GOSNOWFLAKE_SKIP_REGISTRATION` env variable (snowflakedb/gosnowflake#1696).
- Removed `ClientIP` field from `Config` struct. This field was never used and is not needed for any functionality (snowflakedb/gosnowflake#1692).
- Unexported MfaToken and IdToken (snowflakedb/gosnowflake#1692).
- Removed `InsecureMode` field from `Config` struct. Use `DisableOCSPChecks` instead (snowflakedb/gosnowflake#1692).
- Renamed `KeepSessionAlive` field in `Config` struct to `ServerSessionKeepAlive` to adjust with the remaining drivers (snowflakedb/gosnowflake#1692).
- Removed `DisableTelemetry` field from `Config` struct. Use `CLIENT_TELEMETRY_ENABLED` session parameter instead (snowflakedb/gosnowflake#1692).
- Removed stream chunk downloader. Use a regular, default downloader instead. (snowflakedb/gosnowflake#1702).
- Removed `SnowflakeTransport`. Use `Config.Transporter` or simply register your own TLS config with `RegisterTLSConfig` if you just need a custom root certificates set (snowflakedb/gosnowflake#1703).
- Arrow batches changes (snowflakedb/gosnowflake#1706):
- Arrow batches have been extracted to a separate package. It should significantly drop the compilation size for those who don't need arrow batches (~34MB -> ~18MB).
- Removed `GetArrowBatches` from `SnowflakeRows` and `SnowflakeResult`. Use `arrowbatches.GetArrowBatches(rows.(SnowflakeRows))` instead.
- Migrated functions:
- `sf.WithArrowBatchesTimestampOption` -> `arrowbatches.WithTimestampOption`
- `sf.WithArrowBatchesUtf8Validation` -> `arrowbatches.WithUtf8Validation`
- `sf.ArrowSnowflakeTimestampToTime` -> `arrowbatches.ArrowSnowflakeTimestampToTime`
- Logging changes (snowflakedb/gosnowflake#1710):
- Removed Logrus logger and migrated to slog.
- Simplified `SFLogger` interface.
- Added `SFSlogLogger` interface for setting custom slog handler.
Bug fixes:
- The query `context.Context` is now propagated to cloud storage operations for PUT and GET queries, allowing for better cancellation handling (snowflakedb/gosnowflake#1690).
New features:
- Added support for Go 1.26, dropped support for Go 1.23 (snowflakedb/gosnowflake#1707).
- Added support for FIPS-only mode (snowflakedb/gosnowflake#1496).
Bug fixes:
- Added panic recovery block for stage file uploads and downloads operation (snowflakedb/gosnowflake#1687).
- Fixed WIF metadata request from Azure container, manifested with HTTP 400 error (snowflakedb/gosnowflake#1701).
- Fixed SAML authentication port validation bypass in `isPrefixEqual` where the second URL's port was never checked (snowflakedb/gosnowflake#1712).
- Fixed a race condition in OCSP cache clearer (snowflakedb/gosnowflake#1704).
- The query `context.Context` is now propagated to cloud storage operations for PUT and GET queries, allowing for better cancellation handling (snowflakedb/gosnowflake#1690).
- Fixed `tokenFilePath` DSN parameter triggering false validation error claiming both `token` and `tokenFilePath` were specified when only `tokenFilePath` was provided in the DSN string (snowflakedb/gosnowflake#1715).
- Fixed minicore crash (SIGFPE) on fully statically linked Linux binaries by detecting static linking via ELF PT_INTERP inspection and skipping `dlopen` gracefully (snowflakedb/gosnowflake#1721).
Internal changes:
- Moved configuration to a dedicated internal package (snowflakedb/gosnowflake#1720).
## 1.19.0
New features:
- Added ability to disable minicore loading at compile time (snowflakedb/gosnowflake#1679).
- Exposed `tokenFilePath` in `Config` (snowflakedb/gosnowflake#1666).
- `tokenFilePath` is now read for every new connection (snowflakedb/gosnowflake#1666).
- Added support for identity impersonation when using workload identity federation (snowflakedb/gosnowflake#1652, snowflakedb/gosnowflake#1660).
Bug fixes:
- Fixed getting file from an unencrypted stage (snowflakedb/gosnowflake#1672).
- Fixed minicore file name gathering in client environment (snowflakedb/gosnowflake#1661).
- Fixed file descriptor leaks in cloud storage calls (snowflakedb/gosnowflake#1682)
- Fixed path escaping for GCS urls (snowflakedb/gosnowflake#1678).
Internal changes:
- Improved Linux telemetry gathering (snowflakedb/gosnowflake#1677).
- Improved some logs returned from cloud storage clients (snowflakedb/gosnowflake#1665).
## 1.18.1
Bug fixes:
- Handle HTTP307 & 308 in drivers to achieve better resiliency to backend errors (snowflakedb/gosnowflake#1616).
- Create temp directory only if needed during file transfer (snowflakedb/gosnowflake#1647)
- Fix unnecessary user expansion for file paths (snowflakedb/gosnowflake#1646).
Internal changes:
- Remove spammy "telemetry disabled" log messages (snowflakedb/gosnowflake#1638).
- Introduced shared library ([source code](https://github.com/snowflakedb/universal-driver/tree/main/sf_mini_core)) for extended telemetry to identify and prepare testing platform for native rust extensions (snowflakedb/gosnowflake#1629)
## 1.18.0
New features:
- Added validation of CRL `NextUpdate` for freshly downloaded CRLs (snowflakedb/gosnowflake#1617)
- Exposed function to send arbitrary telemetry data (snowflakedb/gosnowflake#1627)
- Added logging of query text and parameters (snowflakedb/gosnowflake#1625)
Bug fixes:
- Fixed a data race error in tests caused by platform_detection init() function (snowflakedb/gosnowflake#1618)
- Make secrets detector initialization thread safe and more maintainable (snowflakedb/gosnowflake#1621)
Internal changes:
- Added ISA to login request telemetry (snowflakedb/gosnowflake#1620)
## 1.17.1
- Fix unsafe reflection of nil pointer on DECFLOAT func in bind uploader (snowflakedb/gosnowflake#1604).
- Added temporary download files cleanup (snowflakedb/gosnowflake#1577)
- Marked fields as deprecated (snowflakedb/gosnowflake#1556)
- Exposed `QueryStatus` from `SnowflakeResult` and `SnowflakeRows` in `GetStatus()` function (snowflakedb/gosnowflake#1556)
- Split timeout settings into separate groups based on target service types (snowflakedb/gosnowflake#1531)
- Added small clarification in oauth.go example on token escaping (snowflakedb/gosnowflake#1574)
- Ensured proper permissions for CRL cache directory (snowflakedb/gosnowflake#1588)
- Added `CrlDownloadMaxSize` to limit the size of CRL downloads (snowflakedb/gosnowflake#1588)
- Added platform telemetry to login requests. Can be disabled with `SNOWFLAKE_DISABLE_PLATFORM_DETECTION` environment variable (snowflakedb/gosnowflake#1601)
- Bypassed proxy settings for WIF metadata requests (snowflakedb/gosnowflake#1593)
- Fixed a bug where GCP PUT/GET operations would fail when the connection context was cancelled (snowflakedb/gosnowflake#1584)
- Fixed nil pointer dereference while calling long-running queries (snowflakedb/gosnowflake#1592) (snowflakedb/gosnowflake#1596)
- Moved keyring-based secure storage manager into separate file to avoid the need to initialize keyring on Linux. (snowflakedb/gosnowflake#1595)
- Enabling official support for RHEL9 by testing and enabling CI/CD checks for Rocky Linux in CICD, (snowflakedb/gosnowflake#1597)
- Improve logging (snowflakedb/gosnowflake#1570)
## 1.17.0
- Added ability to configure OCSP per connection (snowflakedb/gosnowflake#1528)
- Added `DECFLOAT` support, see details in `doc.go` (snowflakedb/gosnowflake#1504, snowflakedb/gosnowflake#1506)
- Added support for Go 1.25, dropped support for Go 1.22 (snowflakedb/gosnowflake#1544)
- Added proxy options to connection parameters (snowflakedb/gosnowflake#1511)
- Added `client_session_keep_alive_heartbeat_frequency` connection param (snowflakedb/gosnowflake#1576)
- Added support for multi-part downloads for S3, Azure and GCP (snowflakedb/gosnowflake#1549)
- Added `singleAuthenticationPrompt` to control whether only one authentication should be performed at the same time for authentications that need human interactions (like MFA or OAuth authorization code). Default is true. (snowflakedb/gosnowflake#1561)
- Fixed missing `DisableTelemetry` option in connection parameters (snowflakedb/gosnowflake#1520)
- Fixed multistatements in large result sets (snowflakedb/gosnowflake#1539, snowflakedb/gosnowflake#1543, snowflakedb/gosnowflake#1547)
- Fixed unnecessary retries when context is cancelled (snowflakedb/gosnowflake#1540)
- Fixed regression in TOML connection file (snowflakedb/gosnowflake#1530)
## Prior Releases
Release notes available at https://docs.snowflake.com/en/release-notes/clients-drivers/golang
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing Guidelines
## Reporting Issues
Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/snowflakedb/gosnowflake/issues?state=open) or was [recently closed](https://github.com/snowflakedb/gosnowflake/issues?direction=desc&page=1&sort=updated&state=closed).
## Contributing Code
By contributing to this project, you share your code under the Apache License 2, as specified in the LICENSE file.
### Code Review
Everyone is invited to review and comment on pull requests.
If it looks fine to you, comment with "LGTM" (Looks good to me).
If changes are required, notify the reviewers with "PTAL" (Please take another look) after committing the fixes.
Before merging the Pull Request, at least one Snowflake team member must have commented with "LGTM".
================================================
FILE: Jenkinsfile
================================================
@Library('pipeline-utils')
import com.snowflake.DevEnvUtils
import groovy.json.JsonOutput
timestamps {
node('high-memory-node') {
stage('checkout') {
scmInfo = checkout scm
println("${scmInfo}")
env.GIT_BRANCH = scmInfo.GIT_BRANCH
env.GIT_COMMIT = scmInfo.GIT_COMMIT
}
params = [
string(name: 'svn_revision', value: 'temptest-deployed'),
string(name: 'branch', value: 'main'),
string(name: 'client_git_commit', value: scmInfo.GIT_COMMIT),
string(name: 'client_git_branch', value: scmInfo.GIT_BRANCH),
string(name: 'TARGET_DOCKER_TEST_IMAGE', value: 'go-chainguard-go1_24'),
string(name: 'parent_job', value: env.JOB_NAME),
string(name: 'parent_build_number', value: env.BUILD_NUMBER)
]
stage('Authenticate Artifactory') {
script {
new DevEnvUtils().withSfCli {
sh "sf artifact oci auth"
}
}
}
parallel(
'Test': {
stage('Test') {
build job: 'RT-LanguageGo-PC', parameters: params
}
},
'Test Authentication': {
stage('Test Authentication') {
withCredentials([
string(credentialsId: 'sfctest0-parameters-secret', variable: 'PARAMETERS_SECRET')
]) {
sh '''\
|#!/bin/bash -e
|$WORKSPACE/ci/test_authentication.sh
'''.stripMargin()
}
}
},
'Test WIF Auth': {
stage('Test WIF Auth') {
withCredentials([
string(credentialsId: 'sfctest0-parameters-secret', variable: 'PARAMETERS_SECRET'),
]) {
sh '''\
|#!/bin/bash -e
|$WORKSPACE/ci/test_wif.sh
'''.stripMargin()
}
}
},
'Test Revocation Validation': {
stage('Test Revocation Validation') {
withCredentials([
usernamePassword(credentialsId: 'jenkins-snowflakedb-github-app',
usernameVariable: 'GITHUB_USER',
passwordVariable: 'GITHUB_TOKEN')
]) {
try {
sh '''\
|#!/bin/bash -e
|chmod +x $WORKSPACE/ci/test_revocation.sh
|$WORKSPACE/ci/test_revocation.sh
'''.stripMargin()
} finally {
archiveArtifacts artifacts: 'revocation-results.json,revocation-report.html', allowEmptyArchive: true
publishHTML(target: [
allowMissing: true,
alwaysLinkToLastBuild: true,
keepAll: true,
reportDir: '.',
reportFiles: 'revocation-report.html',
reportName: 'Revocation Validation Report'
])
}
}
}
}
)
}
}
pipeline {
agent { label 'high-memory-node' }
options { timestamps() }
environment {
COMMIT_SHA_LONG = sh(returnStdout: true, script: "echo \$(git rev-parse " + "HEAD)").trim()
// environment variables for semgrep_agent (for findings / analytics page)
// remove .git at the end
// remove SCM URL + .git at the end
BASELINE_BRANCH = "${env.CHANGE_TARGET}"
}
stages {
stage('Checkout') {
steps {
checkout scm
}
}
}
}
def wgetUpdateGithub(String state, String folder, String targetUrl, String seconds) {
def ghURL = "https://api.github.com/repos/snowflakedb/gosnowflake/statuses/$COMMIT_SHA_LONG"
def data = JsonOutput.toJson([state: "${state}", context: "jenkins/${folder}",target_url: "${targetUrl}"])
sh "wget ${ghURL} --spider -q --header='Authorization: token $GIT_PASSWORD' --post-data='${data}'"
}
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: Makefile
================================================
NAME:=gosnowflake
VERSION:=$(shell git describe --tags --abbrev=0)
REVISION:=$(shell git rev-parse --short HEAD)
COVFLAGS:=
## Run fmt, lint and test
all: fmt lint cov
include gosnowflake.mak
## Run tests
test_setup: test_teardown
python3 ci/scripts/hang_webserver.py 12345 &
test_teardown:
pkill -9 hang_webserver || true
test: deps test_setup
./ci/scripts/execute_tests.sh
## Run Coverage tests
cov:
make test COVFLAGS="-coverprofile=coverage.txt -covermode=atomic"
## Lint
lint: clint
## Format source codes
fmt: cfmt
@for c in $$(ls cmd); do \
(cd cmd/$$c; make fmt); \
done
## Install sample programs
install:
for c in $$(ls cmd); do \
(cd cmd/$$c; GOBIN=$$GOPATH/bin go install $$c.go); \
done
## Build fuzz tests
fuzz-build:
for c in $$(ls | grep -E "fuzz-*"); do \
(cd $$c; make fuzz-build); \
done
## Run fuzz-dsn
fuzz-dsn:
(cd fuzz-dsn; go-fuzz -bin=./dsn-fuzz.zip -workdir=.)
.PHONY: setup deps update test lint help fuzz-dsn
================================================
FILE: README.md
================================================
## Migrating to v2
**Version 2.0.0 of the Go Snowflake Driver was released on March 3rd, 2026.** This major version includes breaking changes that require code updates when migrating from v1.x.
### Key Changes and Migration Steps
#### 1. Update Import Paths
Update your `go.mod` to use v2:
```sh
go get -u github.com/snowflakedb/gosnowflake/v2
```
Update imports in your code:
```go
// Old (v1)
import "github.com/snowflakedb/gosnowflake"
// New (v2)
import "github.com/snowflakedb/gosnowflake/v2"
```
#### 2. Arrow Batches Moved to Separate Package
The public Arrow batches API now lives in `github.com/snowflakedb/gosnowflake/v2/arrowbatches`.
Importing that sub-package pulls in the additional Arrow compute dependency only for applications
that use Arrow batches directly.
**Migration:**
```go
import (
"context"
"database/sql/driver"
sf "github.com/snowflakedb/gosnowflake/v2"
"github.com/snowflakedb/gosnowflake/v2/arrowbatches"
)
ctx := arrowbatches.WithArrowBatches(context.Background())
var rows driver.Rows
err := conn.Raw(func(x any) error {
rows, err = x.(driver.QueryerContext).QueryContext(ctx, query, nil)
return err
})
if err != nil {
// handle error
}
batches, err := arrowbatches.GetArrowBatches(rows.(sf.SnowflakeRows))
if err != nil {
// handle error
}
```
**Optional helper mapping:**
- `sf.WithArrowBatchesTimestampOption` → `arrowbatches.WithTimestampOption`
- `sf.WithArrowBatchesUtf8Validation` → `arrowbatches.WithUtf8Validation`
- `sf.ArrowSnowflakeTimestampToTime` → `arrowbatches.ArrowSnowflakeTimestampToTime`
- `sf.WithOriginalTimestamp` → `arrowbatches.WithTimestampOption(ctx, arrowbatches.UseOriginalTimestamp)`
#### 3. Configuration Struct Changes
**Renamed fields:**
```go
// Old (v1)
config := &gosnowflake.Config{
KeepSessionAlive: true,
InsecureMode: true,
DisableTelemetry: true,
}
// New (v2)
config := &gosnowflake.Config{
ServerSessionKeepAlive: true, // Renamed for consistency with other drivers
DisableOCSPChecks: true, // Replaces InsecureMode
// DisableTelemetry removed - use CLIENT_TELEMETRY_ENABLED session parameter
}
```
**Removed fields:**
- `ClientIP` - No longer used
- `MfaToken` and `IdToken` - Now unexported
- `DisableTelemetry` - Use `CLIENT_TELEMETRY_ENABLED` session parameter instead
#### 4. Logger Changes
The built-in logger is now based on Go's standard `log/slog`:
```go
logger := gosnowflake.GetLogger()
_ = logger.SetLogLevel("debug")
```
For custom logging, continue implementing `SFLogger`.
If you want to customize the built-in slog handler, type-assert `GetLogger()` to `SFSlogLogger`
and call `SetHandler`.
#### 5. File Transfer Changes
**Configuration options:**
```go
// Old (v1)
options := &gosnowflake.SnowflakeFileTransferOptions{
RaisePutGetError: true,
GetFileToStream: true,
}
ctx = gosnowflake.WithFileStream(ctx, stream)
// New (v2)
// RaisePutGetError removed - errors always raised
// GetFileToStream removed - use WithFileGetStream instead
ctx = gosnowflake.WithFilePutStream(ctx, stream) // Renamed from WithFileStream
ctx = gosnowflake.WithFileGetStream(ctx, stream) // For GET operations
```
#### 6. Context and Function Changes
```go
// Old (v1)
ctx, err := gosnowflake.WithMultiStatement(ctx, 0)
if err != nil {
// handle error
}
// New (v2)
ctx = gosnowflake.WithMultiStatement(ctx, 0) // No error returned
```
```go
// Old (v1)
values := gosnowflake.Array(data)
// New (v2)
values, err := gosnowflake.Array(data) // Now returns error for unsupported types
if err != nil {
// handle error
}
```
#### 7. Nullable Options Combined
```go
// Old (v1)
ctx = gosnowflake.WithMapValuesNullable(ctx)
ctx = gosnowflake.WithArrayValuesNullable(ctx)
// New (v2)
ctx = gosnowflake.WithEmbeddedValuesNullable(ctx) // Handles both maps and arrays
```
#### 8. Session Parameter Changes
**Chunk download workers:**
```go
// Old (v1)
gosnowflake.MaxChunkDownloadWorkers = 10 // Global variable
// New (v2)
// Configure via CLIENT_PREFETCH_THREADS session parameter.
// NOTE: The default is 4.
db.Exec("ALTER SESSION SET CLIENT_PREFETCH_THREADS = 10")
```
#### 9. Transport Configuration
```go
import "crypto/tls"
// Old (v1)
gosnowflake.SnowflakeTransport = yourTransport
// New (v2)
config := &gosnowflake.Config{
Transporter: yourCustomTransport,
}
// Or, if you only need custom TLS settings/certificates:
tlsConfig := &tls.Config{
// ...
}
_ = gosnowflake.RegisterTLSConfig("custom", tlsConfig)
config.TLSConfigName = "custom"
```
#### 10. Environment Variable Fix
If you use the skip registration environment variable:
```sh
# Old (v1)
GOSNOWFLAKE_SKIP_REGISTERATION=true # Note the typo
# New (v2)
GOSNOWFLAKE_SKIP_REGISTRATION=true # Typo fixed
```
### Additional Resources
- Full list of changes: See [CHANGELOG.md](./CHANGELOG.md)
- Questions or issues: [GitHub Issues](https://github.com/snowflakedb/gosnowflake/issues)
## Support
For official support and urgent, production-impacting issues, please [contact Snowflake Support](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge).
# Go Snowflake Driver
This topic provides instructions for installing, running, and modifying the Go Snowflake Driver. The driver supports Go's [database/sql](https://golang.org/pkg/database/sql/) package.
# Prerequisites
The following software packages are required to use the Go Snowflake Driver.
## Go
The latest driver requires the [Go language](https://golang.org/) 1.24 or higher. The supported operating systems are 64-bit Linux, macOS, and Windows, but you may run the driver on other platforms if the Go language works correctly on those platforms.
# Installation
If you don't have a project initialized, set it up.
```sh
go mod init example.com/snowflake
```
Get Gosnowflake source code, if not installed.
```sh
go get -u github.com/snowflakedb/gosnowflake/v2
```
# Docs
For detailed documentation and basic usage examples, please see the documentation at
[godoc.org](https://godoc.org/github.com/snowflakedb/gosnowflake/v2).
## Notes
This driver currently does not support GCP regional endpoints. Please ensure that any workloads using this driver do not require support for regional endpoints on GCP. If you have questions about this, please contact Snowflake Support.
The driver uses a Rust library called sf_mini_core; you can find its source code [here](https://github.com/snowflakedb/universal-driver/tree/main/sf_mini_core).
# Sample Programs
Snowflake provides a set of sample programs to test with. Set the environment variable ``$GOPATH`` to the top directory of your workspace, e.g., ``~/go`` and make certain to
include ``$GOPATH/bin`` in the environment variable ``$PATH``. Run the ``make`` command to build all sample programs.
```sh
make install
```
In the following example, the program ``select1.go`` is built and installed in ``$GOPATH/bin`` and can be run from the command line:
```sh
SNOWFLAKE_TEST_ACCOUNT= \
SNOWFLAKE_TEST_USER= \
SNOWFLAKE_TEST_PASSWORD= \
select1
Congrats! You have successfully run SELECT 1 with Snowflake DB!
```
# Development
The developer notes are hosted with the source code on [GitHub](https://github.com/snowflakedb/gosnowflake/v2).
## Testing Code
Set the Snowflake connection info in ``parameters.json``:
```json
{
"testconnection": {
"SNOWFLAKE_TEST_USER": "",
"SNOWFLAKE_TEST_PASSWORD": "",
"SNOWFLAKE_TEST_ACCOUNT": "",
"SNOWFLAKE_TEST_WAREHOUSE": "",
"SNOWFLAKE_TEST_DATABASE": "",
"SNOWFLAKE_TEST_SCHEMA": "",
"SNOWFLAKE_TEST_ROLE": "",
"SNOWFLAKE_TEST_DEBUG": "false"
}
}
```
Install [jq](https://stedolan.github.io/jq) so that the parameters can get parsed correctly, and run ``make test`` in your Go development environment:
```sh
make test
```
### Setting debug mode during tests
This is for debugging Large SQL statements (greater than 300 characters). If you want to enable debug mode, set `SNOWFLAKE_TEST_DEBUG` to `true` in `parameters.json`, or export it in your shell instance.
## Customizing Logging Tags
If you would like to ensure that certain tags are always present in the logs, `RegisterClientLogContextHook` can be used in your init function. See example below.
```go
import "github.com/snowflakedb/gosnowflake/v2"
func init() {
// each time the logger is used, the logs will contain a REQUEST_ID field with requestID the value extracted
// from the context
gosnowflake.RegisterClientLogContextHook("REQUEST_ID", func(ctx context.Context) interface{} {
return requestIdFromContext(ctx)
})
}
```
## Setting Log Level
If you want to change the log level, `SetLogLevel` can be used in your init function like this:
```go
import "github.com/snowflakedb/gosnowflake/v2"
func init() {
// The following line changes the log level to debug
_ = gosnowflake.GetLogger().SetLogLevel("debug")
}
```
The following is a list of options you can pass in to set the level from least to most verbose:
- `"OFF"`
- `"fatal"`
- `"error"`
- `"warn"`
- `"info"`
- `"debug"`
- `"trace"`
## Capturing Code Coverage
Configure your testing environment as described above and run ``make cov``. The coverage percentage will be printed on the console when the testing completes.
```sh
make cov
```
For more detailed analysis, results are printed to ``coverage.txt`` in the project directory.
To read the coverage report, run:
```sh
go tool cover -html=coverage.txt
```
## Submitting Pull Requests
You may use your preferred editor to edit the driver code. Make certain to run ``make fmt lint`` before submitting any pull request to Snowflake. This command formats your source code according to the standard Go style and detects any coding style issues.
================================================
FILE: SECURITY.md
================================================
# Security Policy
Please refer to the Snowflake [HackerOne program](https://hackerone.com/snowflake?type=team) for our security policies and for reporting any security vulnerabilities.
For other security related questions and concerns, please contact the Snowflake security team at security@snowflake.com
================================================
FILE: aaa_test.go
================================================
package gosnowflake
import (
"testing"
)
// TestShowServerVersion is a smoke test: it verifies that a connection can be
// established and that SELECT CURRENT_VERSION() returns a scannable row.
func TestShowServerVersion(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		rows := dbt.mustQuery("SELECT CURRENT_VERSION()")
		defer func() {
			assertNilF(t, rows.Close())
		}()
		var version string
		// The original ignored rows.Next()'s result; fail fast when no row is
		// returned so Scan does not report a confusing secondary error.
		assertTrueF(t, rows.Next(), "no row returned for CURRENT_VERSION()")
		assertNilF(t, rows.Scan(&version))
		// Use the test logger instead of the builtin println so the output is
		// associated with this test and only shown on failure/-v.
		t.Log(version)
	})
}
================================================
FILE: arrow_chunk.go
================================================
package gosnowflake
import (
"bytes"
"context"
"encoding/base64"
"github.com/snowflakedb/gosnowflake/v2/internal/query"
"time"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/ipc"
"github.com/apache/arrow-go/v18/arrow/memory"
)
// arrowResultChunk wraps an Arrow IPC reader for a single result chunk
// together with the session state required to decode its records.
type arrowResultChunk struct {
	reader *ipc.Reader
	// rowCount accumulates the number of rows decoded so far.
	rowCount int
	// loc is the session timezone used for timestamp conversion.
	loc       *time.Location
	allocator memory.Allocator
}
// decodeArrowChunk drains the chunk's IPC reader and converts every record
// into row-major chunkRowType values using the column metadata in rowType.
// It releases the reader when done and updates arc.rowCount.
func (arc *arrowResultChunk) decodeArrowChunk(ctx context.Context, rowType []query.ExecResponseRowType, highPrec bool, params *syncParams) ([]chunkRowType, error) {
	defer arc.reader.Release()
	logger.Debug("Arrow Decoder")
	var rows []chunkRowType
	for arc.reader.Next() {
		rec := arc.reader.Record()
		offset := len(rows)
		n := int(rec.NumRows())
		logger.Debugf("rows in current record: %v", n)
		cols := rec.Columns()
		// Grow the row slice and allocate per-row value storage up front.
		rows = append(rows, make([]chunkRowType, n)...)
		for r := offset; r < offset+n; r++ {
			rows[r].ArrowRow = make([]snowflakeValue, len(cols))
		}
		// Convert column-by-column, then scatter the values into the rows.
		for c, col := range cols {
			vals := make([]snowflakeValue, n)
			if err := arrowToValues(ctx, vals, rowType[c], col, arc.loc, highPrec, params); err != nil {
				return nil, err
			}
			for r, v := range vals {
				rows[offset+r].ArrowRow[c] = v
			}
		}
		arc.rowCount += n
	}
	logger.Debugf("The number of chunk rows: %v", len(rows))
	return rows, arc.reader.Err()
}
// decodeArrowBatchRaw reads raw (untransformed) arrow records from the IPC reader.
// The records are not transformed with arrow-compute; the arrowbatches sub-package
// handles transformation when the user calls ArrowBatch.Fetch().
func (arc *arrowResultChunk) decodeArrowBatchRaw() (*[]arrow.Record, error) {
	defer arc.reader.Release()
	var collected []arrow.Record
	for arc.reader.Next() {
		rec := arc.reader.Record()
		// Retain each record so it survives the reader's Release.
		rec.Retain()
		collected = append(collected, rec)
	}
	return &collected, arc.reader.Err()
}
// buildFirstArrowChunk decodes the base64-encoded inline RowSet and wraps it
// in an arrowResultChunk backed by a fresh IPC reader.
func buildFirstArrowChunk(rowsetBase64 string, loc *time.Location, alloc memory.Allocator) (arrowResultChunk, error) {
	decoded, err := base64.StdEncoding.DecodeString(rowsetBase64)
	if err != nil {
		return arrowResultChunk{}, err
	}
	reader, err := ipc.NewReader(bytes.NewReader(decoded), ipc.WithAllocator(alloc))
	if err != nil {
		return arrowResultChunk{}, err
	}
	return arrowResultChunk{reader: reader, rowCount: 0, loc: loc, allocator: alloc}, nil
}
================================================
FILE: arrow_stream.go
================================================
package gosnowflake
import (
"bufio"
"bytes"
"compress/gzip"
"context"
"encoding/base64"
"fmt"
"io"
"maps"
"net/http"
"strconv"
"time"
"github.com/apache/arrow-go/v18/arrow/ipc"
"github.com/snowflakedb/gosnowflake/v2/internal/query"
)
// ArrowStreamLoader is a convenience interface for downloading
// Snowflake results via multiple Arrow Record Batch streams.
//
// Some queries from Snowflake do not return Arrow data regardless
// of the settings, such as "SHOW WAREHOUSES". In these cases,
// you'll find TotalRows() > 0 but GetBatches returns no batches
// and no errors. In this case, the data is accessible via JSONData
// with the actual types matching up to the metadata in RowTypes.
type ArrowStreamLoader interface {
	// GetBatches returns the batches of the current result set.
	GetBatches() ([]ArrowStreamBatch, error)
	// NextResultSet advances to the next result set (multi-statement
	// queries), returning io.EOF when none remain.
	NextResultSet(ctx context.Context) error
	// TotalRows reports the total row count of the current result set.
	TotalRows() int64
	// RowTypes describes the columns of the current result set.
	RowTypes() []query.ExecResponseRowType
	// Location returns the session timezone, if available.
	Location() *time.Location
	// JSONData returns row data for results delivered as JSON instead of Arrow.
	JSONData() [][]*string
}
// ArrowStreamBatch is a type describing a potentially yet-to-be-downloaded
// Arrow IPC stream. Call GetStream to download and retrieve an io.Reader
// that can be used with ipc.NewReader to get record batch results.
type ArrowStreamBatch struct {
	// idx is this batch's index into the downloader's ChunkMetas.
	idx int
	// numrows is the row count reported by the chunk metadata.
	numrows int64
	// scd is the owning downloader, used to fetch the chunk on demand.
	scd *snowflakeArrowStreamChunkDownloader
	// Loc is the session timezone for timestamp columns.
	Loc *time.Location
	// rr holds the downloaded stream; nil until GetStream succeeds.
	rr io.ReadCloser
}
// NumRows returns the total number of rows that the metadata stated should
// be in this stream of record batches.
func (asb *ArrowStreamBatch) NumRows() int64 {
	return asb.numrows
}
// GetStream returns a stream of bytes consisting of an Arrow IPC Record
// batch stream. Close should be called on the returned stream when done
// to ensure no leaked memory. The stream is downloaded lazily on the first
// call and cached for subsequent calls.
func (asb *ArrowStreamBatch) GetStream(ctx context.Context) (io.ReadCloser, error) {
	if asb.rr != nil {
		return asb.rr, nil
	}
	if err := asb.downloadChunkStreamHelper(ctx); err != nil {
		return nil, err
	}
	return asb.rr, nil
}
// streamWrapReader wraps an io.Reader so that Close closes the underlying body.
type streamWrapReader struct {
io.Reader
wrapped io.ReadCloser
}
func (w *streamWrapReader) Close() error {
if cl, ok := w.Reader.(io.ReadCloser); ok {
if err := cl.Close(); err != nil {
return err
}
}
return w.wrapped.Close()
}
// downloadChunkStreamHelper downloads the chunk at asb.idx and stores a
// ReadCloser over its (possibly gzip-compressed) Arrow IPC stream in asb.rr.
func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) error {
	// Either the server-provided chunk headers or the SSE-C key material is
	// sent, never both.
	headers := make(map[string]string)
	if len(asb.scd.ChunkHeader) > 0 {
		maps.Copy(headers, asb.scd.ChunkHeader)
	} else {
		headers[headerSseCAlgorithm] = headerSseCAes
		headers[headerSseCKey] = asb.scd.Qrmk
	}
	resp, err := asb.scd.FuncGet(ctx, asb.scd.sc, asb.scd.ChunkMetas[asb.idx].URL, headers, asb.scd.sc.rest.RequestTimeout)
	if err != nil {
		return err
	}
	if resp.StatusCode != http.StatusOK {
		defer func() {
			_ = resp.Body.Close()
		}()
		// Drain the body (presumably so the connection can be reused); the
		// payload itself is discarded and not included in the returned error.
		b, err := io.ReadAll(resp.Body)
		if err != nil {
			return err
		}
		_ = b
		return &SnowflakeError{
			Number:      ErrFailedToGetChunk,
			SQLState:    SQLStateConnectionFailure,
			Message:     fmt.Sprintf("failed to get chunk. idx: %v", asb.idx),
			MessageArgs: []any{asb.idx},
		}
	}
	// Close the body only if we fail before handing ownership to asb.rr; on
	// success the caller closes it via streamWrapReader.Close.
	defer func() {
		if asb.rr == nil {
			_ = resp.Body.Close()
		}
	}()
	bufStream := bufio.NewReader(resp.Body)
	// Sniff the first two bytes for the gzip magic number (0x1f 0x8b) to
	// decide whether the stream needs decompression.
	gzipMagic, err := bufStream.Peek(2)
	if err != nil {
		return err
	}
	if gzipMagic[0] == 0x1f && gzipMagic[1] == 0x8b {
		bufStream0, err := gzip.NewReader(bufStream)
		if err != nil {
			return err
		}
		asb.rr = &streamWrapReader{Reader: bufStream0, wrapped: resp.Body}
	} else {
		asb.rr = &streamWrapReader{Reader: bufStream, wrapped: resp.Body}
	}
	return nil
}
// snowflakeArrowStreamChunkDownloader implements ArrowStreamLoader on top of
// the chunk metadata returned in a query execution response.
type snowflakeArrowStreamChunkDownloader struct {
	sc         *snowflakeConn
	ChunkMetas []query.ExecResponseChunk
	// Total is the total row count of the current result set.
	Total int64
	// Qrmk is the SSE-C key used when no explicit chunk headers are provided.
	Qrmk        string
	ChunkHeader map[string]string
	// FuncGet performs the HTTP GET for a chunk URL (injectable for tests).
	FuncGet func(context.Context, *snowflakeConn, string, map[string]string, time.Duration) (*http.Response, error)
	RowSet  rowSetType
	// resultIDs queues the remaining result sets for multi-statement queries.
	resultIDs []string
}
// Location returns the session timezone, or nil when no connection is attached.
func (scd *snowflakeArrowStreamChunkDownloader) Location() *time.Location {
	if scd.sc == nil {
		return nil
	}
	return getCurrentLocation(&scd.sc.syncParams)
}
// TotalRows reports the total number of rows in the current result set.
func (scd *snowflakeArrowStreamChunkDownloader) TotalRows() int64 { return scd.Total }

// RowTypes returns the column metadata of the current result set.
func (scd *snowflakeArrowStreamChunkDownloader) RowTypes() []query.ExecResponseRowType {
	return scd.RowSet.RowType
}

// JSONData returns row data for results delivered as JSON rather than Arrow.
func (scd *snowflakeArrowStreamChunkDownloader) JSONData() [][]*string {
	return scd.RowSet.JSON
}
// maybeFirstBatch returns the decoded bytes of the inline first batch when
// RowSetBase64 is present and forms a valid Arrow IPC stream; otherwise it
// returns nil bytes (and the validation error, if any).
func (scd *snowflakeArrowStreamChunkDownloader) maybeFirstBatch() ([]byte, error) {
	encoded := scd.RowSet.RowSetBase64
	if encoded == "" {
		return nil, nil
	}
	decoded, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		logger.Warnf("skipping first batch as it is not a valid base64 response. %v", err)
		return nil, err
	}
	// Validate that the bytes form a readable IPC stream before returning them.
	reader, err := ipc.NewReader(bytes.NewReader(decoded))
	if err != nil {
		logger.Warnf("skipping first batch as it is not a valid IPC stream. %v", err)
		return nil, err
	}
	reader.Release()
	return decoded, nil
}
// GetBatches returns the batches of the current result set. Remote chunks are
// returned un-downloaded; the optional inline first batch (decoded from
// RowSetBase64) is prepended as an already-readable batch.
func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamBatch, err error) {
	chunkMetaLen := len(scd.ChunkMetas)
	loc := scd.Location()
	// Reserve one extra capacity slot so an inline first batch can be
	// prepended without reallocating.
	out = make([]ArrowStreamBatch, chunkMetaLen, chunkMetaLen+1)
	toFill := out
	rowSetBytes, err := scd.maybeFirstBatch()
	if err != nil {
		return nil, err
	}
	if len(rowSetBytes) > 0 {
		// Extend into the reserved slot; the chunk batches now occupy
		// out[1:], which toFill aliases.
		out = out[:chunkMetaLen+1]
		out[0] = ArrowStreamBatch{
			scd: scd,
			Loc: loc,
			rr:  io.NopCloser(bytes.NewReader(rowSetBytes)),
		}
		toFill = out[1:]
	}
	var totalCounted int64
	for i := range toFill {
		toFill[i] = ArrowStreamBatch{
			idx:     i,
			numrows: int64(scd.ChunkMetas[i].RowCount),
			Loc:     loc,
			scd:     scd,
		}
		totalCounted += int64(scd.ChunkMetas[i].RowCount)
	}
	if len(rowSetBytes) > 0 {
		// The inline batch's row count is not in ChunkMetas; derive it as the
		// remainder of the total.
		out[0].numrows = scd.Total - totalCounted
	}
	return
}
// NextResultSet fetches the metadata of the next queued result set
// (multi-statement queries) and repoints the downloader at it.
// It returns io.EOF when no result sets remain.
func (scd *snowflakeArrowStreamChunkDownloader) NextResultSet(ctx context.Context) error {
	if !scd.hasNextResultSet() {
		return io.EOF
	}
	// Pop the next result ID from the queue.
	resultID := scd.resultIDs[0]
	scd.resultIDs = scd.resultIDs[1:]
	resultPath := fmt.Sprintf(urlQueriesResultFmt, resultID)
	resp, err := scd.sc.getQueryResultResp(ctx, resultPath)
	if err != nil {
		return err
	}
	if !resp.Success {
		code, err := strconv.Atoi(resp.Code)
		if err != nil {
			// Non-numeric server code: log it and fall through with code's
			// zero value.
			logger.WithContext(ctx).Errorf("error while parsing code: %v", err)
		}
		return exceptionTelemetry(&SnowflakeError{
			Number:   code,
			SQLState: resp.Data.SQLState,
			Message:  resp.Message,
			QueryID:  resp.Data.QueryID,
		}, scd.sc)
	}
	// Replace all per-result-set state with the new result set's metadata.
	scd.ChunkMetas = resp.Data.Chunks
	scd.Total = resp.Data.Total
	scd.Qrmk = resp.Data.Qrmk
	scd.ChunkHeader = resp.Data.ChunkHeaders
	scd.RowSet = rowSetType{
		RowType:      resp.Data.RowType,
		JSON:         resp.Data.RowSet,
		RowSetBase64: resp.Data.RowSetBase64,
	}
	return nil
}
// hasNextResultSet reports whether any unconsumed result IDs remain.
func (scd *snowflakeArrowStreamChunkDownloader) hasNextResultSet() bool {
	return len(scd.resultIDs) != 0
}
================================================
FILE: arrow_test.go
================================================
package gosnowflake
import (
"bytes"
"context"
"fmt"
"math/big"
"reflect"
"strings"
"testing"
"time"
"github.com/apache/arrow-go/v18/arrow/memory"
ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"
"database/sql/driver"
)
// TestArrowBatchDataProvider verifies that rows returned with arrow batches
// enabled implement the internal BatchDataProvider interface and expose raw,
// pre-decoded Arrow records with column data.
func TestArrowBatchDataProvider(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		ctx := ia.EnableArrowBatches(context.Background())
		query := "select '0.1':: DECIMAL(38, 19) as c"
		var rows driver.Rows
		var err error
		// Use the raw driver connection to get driver.Rows (not database/sql rows).
		err = dbt.conn.Raw(func(x any) error {
			queryer, implementsQueryContext := x.(driver.QueryerContext)
			assertTrueF(t, implementsQueryContext, "snowflake connection driver does not implement queryerContext")
			rows, err = queryer.QueryContext(ctx, query, nil)
			return err
		})
		assertNilF(t, err, "error running select query")
		sfRows, isSfRows := rows.(SnowflakeRows)
		assertTrueF(t, isSfRows, "rows should be snowflakeRows")
		provider, isProvider := sfRows.(ia.BatchDataProvider)
		assertTrueF(t, isProvider, "rows should implement BatchDataProvider")
		info, err := provider.GetArrowBatches()
		assertNilF(t, err, "error getting arrow batch data")
		assertNotEqualF(t, len(info.Batches), 0, "should have at least one batch")
		// Verify raw records are available for the first batch
		batch := info.Batches[0]
		assertNotNilF(t, batch.Records, "first batch should have pre-decoded records")
		records := *batch.Records
		assertNotEqualF(t, len(records), 0, "should have at least one record")
		// Verify column 0 has data (raw decimal value)
		strVal := records[0].Column(0).ValueStr(0)
		assertTrueF(t, len(strVal) > 0, fmt.Sprintf("column should have a value, got: %s", strVal))
	})
}
// TestArrowBigInt verifies that 38-digit NUMBER values round-trip through the
// Arrow result format when scanned into *big.Int with WithHigherPrecision.
func TestArrowBigInt(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		testcases := []struct {
			num  string
			prec int
			sc   int
		}{
			{"10000000000000000000000000000000000000", 38, 0},
			{"-10000000000000000000000000000000000000", 38, 0},
			{"12345678901234567890123456789012345678", 38, 0}, // #pragma: allowlist secret
			{"-12345678901234567890123456789012345678", 38, 0},
			{"99999999999999999999999999999999999999", 38, 0},
			{"-99999999999999999999999999999999999999", 38, 0},
		}
		for _, tc := range testcases {
			// Run each case in a closure so the deferred rows.Close() fires
			// at the end of the iteration instead of piling up until the
			// whole test returns (classic defer-in-loop leak).
			func() {
				rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()),
					fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
				defer rows.Close()
				if !rows.Next() {
					dbt.Error("failed to query")
				}
				var v *big.Int
				if err := rows.Scan(&v); err != nil {
					dbt.Errorf("failed to scan. %#v", err)
				}
				b, ok := new(big.Int).SetString(tc.num, 10)
				if !ok {
					dbt.Errorf("failed to convert %v big.Int.", tc.num)
				}
				if v.Cmp(b) != 0 {
					dbt.Errorf("big.Int value mismatch: expected %v, got %v", b, v)
				}
			}()
		}
	})
}
// TestArrowBigFloat verifies that high-scale DECIMAL values round-trip
// through the Arrow result format when scanned into *big.Float with
// WithHigherPrecision.
func TestArrowBigFloat(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		testcases := []struct {
			num  string
			prec int
			sc   int
		}{
			{"1.23", 30, 2},
			{"1.0000000000000000000000000000000000000", 38, 37},
			{"-1.0000000000000000000000000000000000000", 38, 37},
			{"1.2345678901234567890123456789012345678", 38, 37},
			{"-1.2345678901234567890123456789012345678", 38, 37},
			{"9.9999999999999999999999999999999999999", 38, 37},
			{"-9.9999999999999999999999999999999999999", 38, 37},
		}
		for _, tc := range testcases {
			// Run each case in a closure so the deferred rows.Close() fires
			// per iteration rather than accumulating until the test returns.
			func() {
				rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()),
					fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
				defer rows.Close()
				if !rows.Next() {
					dbt.Error("failed to query")
				}
				var v *big.Float
				if err := rows.Scan(&v); err != nil {
					dbt.Errorf("failed to scan. %#v", err)
				}
				// Parse the expected value at the same precision as the
				// scanned result so Cmp compares like with like.
				prec := v.Prec()
				b, ok := new(big.Float).SetPrec(prec).SetString(tc.num)
				if !ok {
					dbt.Errorf("failed to convert %v to big.Float.", tc.num)
				}
				if v.Cmp(b) != 0 {
					dbt.Errorf("big.Float value mismatch: expected %v, got %v", b, v)
				}
			}()
		}
	})
}
// TestArrowIntPrecision verifies scanning of 38-digit NUMBER values under
// both result formats: with JSON forced, an int64 scan must fail (overflow)
// while a string scan succeeds; with Arrow forced, values round-trip as
// string and, with WithHigherPrecision, as *big.Int.
func TestArrowIntPrecision(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec(forceJSON)
		intTestcases := []struct {
			num  string
			prec int
			sc   int
		}{
			{"10000000000000000000000000000000000000", 38, 0},
			{"-10000000000000000000000000000000000000", 38, 0},
			{"12345678901234567890123456789012345678", 38, 0}, // pragma: allowlist secret
			{"-12345678901234567890123456789012345678", 38, 0},
			{"99999999999999999999999999999999999999", 38, 0},
			{"-99999999999999999999999999999999999999", 38, 0},
		}
		t.Run("arrow_disabled_scan_int64", func(t *testing.T) {
			for _, tc := range intTestcases {
				rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
				// NOTE(review): defer in a loop — Close runs when the subtest
				// returns, not per iteration.
				defer rows.Close()
				if !rows.Next() {
					t.Error("failed to query")
				}
				var v int64
				// Values exceed int64 range, so the scan must fail.
				if err := rows.Scan(&v); err == nil {
					t.Error("should fail to scan")
				}
			}
		})
		t.Run("arrow_disabled_scan_string", func(t *testing.T) {
			for _, tc := range intTestcases {
				rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
				defer rows.Close()
				if !rows.Next() {
					t.Error("failed to query")
				}
				var v string
				if err := rows.Scan(&v); err != nil {
					t.Errorf("failed to scan. %#v", err)
				}
				if v != tc.num {
					t.Errorf("string value mismatch: expected %v, got %v", tc.num, v)
				}
			}
		})
		dbt.mustExec(forceARROW)
		t.Run("arrow_enabled_scan_big_int", func(t *testing.T) {
			for _, tc := range intTestcases {
				rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
				defer rows.Close()
				if !rows.Next() {
					t.Error("failed to query")
				}
				var v string
				if err := rows.Scan(&v); err != nil {
					t.Errorf("failed to scan. %#v", err)
				}
				if !strings.EqualFold(v, tc.num) {
					t.Errorf("int value mismatch: expected %v, got %v", tc.num, v)
				}
			}
		})
		t.Run("arrow_high_precision_enabled_scan_big_int", func(t *testing.T) {
			for _, tc := range intTestcases {
				rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()), fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
				defer rows.Close()
				if !rows.Next() {
					t.Error("failed to query")
				}
				var v *big.Int
				if err := rows.Scan(&v); err != nil {
					t.Errorf("failed to scan. %#v", err)
				}
				b, ok := new(big.Int).SetString(tc.num, 10)
				if !ok {
					t.Errorf("failed to convert %v big.Int.", tc.num)
				}
				if v.Cmp(b) != 0 {
					t.Errorf("big.Int value mismatch: expected %v, got %v", b, v)
				}
			}
		})
	})
}
// TestArrowFloatPrecision tests the different variable types allowed in the
// rows.Scan() method. Note that for lower precision types we do not attempt
// to check the value as precision could be lost.
func TestArrowFloatPrecision(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec(forceJSON)
		fltTestcases := []struct {
			num  string
			prec int
			sc   int
		}{
			{"1.23", 30, 2},
			{"1.0000000000000000000000000000000000000", 38, 37},
			{"-1.0000000000000000000000000000000000000", 38, 37},
			{"1.2345678901234567890123456789012345678", 38, 37},
			{"-1.2345678901234567890123456789012345678", 38, 37},
			{"9.9999999999999999999999999999999999999", 38, 37},
			{"-9.9999999999999999999999999999999999999", 38, 37},
		}
		t.Run("arrow_disabled_scan_float64", func(t *testing.T) {
			for _, tc := range fltTestcases {
				rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
				// NOTE(review): defer in a loop — Close runs when the subtest
				// returns, not per iteration.
				defer rows.Close()
				if !rows.Next() {
					t.Error("failed to query")
				}
				var v float64
				if err := rows.Scan(&v); err != nil {
					t.Errorf("failed to scan. %#v", err)
				}
			}
		})
		t.Run("arrow_disabled_scan_float32", func(t *testing.T) {
			for _, tc := range fltTestcases {
				rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
				defer rows.Close()
				if !rows.Next() {
					t.Error("failed to query")
				}
				var v float32
				if err := rows.Scan(&v); err != nil {
					t.Errorf("failed to scan. %#v", err)
				}
			}
		})
		t.Run("arrow_disabled_scan_string", func(t *testing.T) {
			for _, tc := range fltTestcases {
				rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
				defer rows.Close()
				if !rows.Next() {
					t.Error("failed to query")
				}
				var v string
				if err := rows.Scan(&v); err != nil {
					t.Errorf("failed to scan. %#v", err)
				}
				if !strings.EqualFold(v, tc.num) {
					t.Errorf("int value mismatch: expected %v, got %v", tc.num, v)
				}
			}
		})
		dbt.mustExec(forceARROW)
		t.Run("arrow_enabled_scan_float64", func(t *testing.T) {
			for _, tc := range fltTestcases {
				rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
				defer rows.Close()
				if !rows.Next() {
					t.Error("failed to query")
				}
				var v float64
				if err := rows.Scan(&v); err != nil {
					t.Errorf("failed to scan. %#v", err)
				}
			}
		})
		t.Run("arrow_enabled_scan_float32", func(t *testing.T) {
			for _, tc := range fltTestcases {
				rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
				defer rows.Close()
				if !rows.Next() {
					t.Error("failed to query")
				}
				var v float32
				if err := rows.Scan(&v); err != nil {
					t.Errorf("failed to scan. %#v", err)
				}
			}
		})
		t.Run("arrow_enabled_scan_string", func(t *testing.T) {
			for _, tc := range fltTestcases {
				rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
				defer rows.Close()
				if !rows.Next() {
					t.Error("failed to query")
				}
				var v string
				if err := rows.Scan(&v); err != nil {
					t.Errorf("failed to scan. %#v", err)
				}
				if v != tc.num {
					t.Errorf("string value mismatch: expected %v, got %v", tc.num, v)
				}
			}
		})
		t.Run("arrow_high_precision_enabled_scan_big_float", func(t *testing.T) {
			for _, tc := range fltTestcases {
				rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()), fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
				defer rows.Close()
				if !rows.Next() {
					t.Error("failed to query")
				}
				var v *big.Float
				if err := rows.Scan(&v); err != nil {
					t.Errorf("failed to scan. %#v", err)
				}
				// Compare at the precision of the scanned value.
				prec := v.Prec()
				b, ok := new(big.Float).SetPrec(prec).SetString(tc.num)
				if !ok {
					t.Errorf("failed to convert %v to big.Float.", tc.num)
				}
				if v.Cmp(b) != 0 {
					t.Errorf("big.Float value mismatch: expected %v, got %v", b, v)
				}
			}
		})
	})
}
// TestArrowTimePrecision verifies that TIME(5..8) and TIMESTAMP_NTZ(1..8)
// columns are truncated to the declared fractional-second precision when
// scanned into time.Time; each column's nanoseconds are checked against the
// base value truncated at the matching power of ten.
func TestArrowTimePrecision(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("CREATE TABLE t (col5 TIME(5), col6 TIME(6), col7 TIME(7), col8 TIME(8));")
		defer dbt.mustExec("DROP TABLE IF EXISTS t")
		dbt.mustExec("INSERT INTO t VALUES ('23:59:59.99999', '23:59:59.999999', '23:59:59.9999999', '23:59:59.99999999');")
		rows := dbt.mustQuery("select * from t")
		defer rows.Close()
		var c5, c6, c7, c8 time.Time
		for rows.Next() {
			if err := rows.Scan(&c5, &c6, &c7, &c8); err != nil {
				t.Errorf("values were not scanned: %v", err)
			}
		}
		// Maximum nanoseconds representable at TIME(8): .99999999 seconds.
		nano := 999999990
		expected := time.Time{}.Add(23*time.Hour + 59*time.Minute + 59*time.Second + 99*time.Millisecond)
		if c8.Unix() != expected.Unix() || c8.Nanosecond() != nano {
			t.Errorf("the value did not match. expected: %v, got: %v", expected, c8)
		}
		// Lower precisions drop trailing digits: TIME(p) keeps nano truncated
		// at 10^(9-p).
		if c7.Unix() != expected.Unix() || c7.Nanosecond() != nano-(nano%1e2) {
			t.Errorf("the value did not match. expected: %v, got: %v", expected, c7)
		}
		if c6.Unix() != expected.Unix() || c6.Nanosecond() != nano-(nano%1e3) {
			t.Errorf("the value did not match. expected: %v, got: %v", expected, c6)
		}
		if c5.Unix() != expected.Unix() || c5.Nanosecond() != nano-(nano%1e4) {
			t.Errorf("the value did not match. expected: %v, got: %v", expected, c5)
		}
		dbt.mustExec(`CREATE TABLE t_ntz (
			col1 TIMESTAMP_NTZ(1),
			col2 TIMESTAMP_NTZ(2),
			col3 TIMESTAMP_NTZ(3),
			col4 TIMESTAMP_NTZ(4),
			col5 TIMESTAMP_NTZ(5),
			col6 TIMESTAMP_NTZ(6),
			col7 TIMESTAMP_NTZ(7),
			col8 TIMESTAMP_NTZ(8)
		);`)
		defer dbt.mustExec("DROP TABLE IF EXISTS t_ntz")
		dbt.mustExec(`INSERT INTO t_ntz VALUES (
			'9999-12-31T23:59:59.9',
			'9999-12-31T23:59:59.99',
			'9999-12-31T23:59:59.999',
			'9999-12-31T23:59:59.9999',
			'9999-12-31T23:59:59.99999',
			'9999-12-31T23:59:59.999999',
			'9999-12-31T23:59:59.9999999',
			'9999-12-31T23:59:59.99999999'
		);`)
		rows2 := dbt.mustQuery("select * from t_ntz")
		defer rows2.Close()
		var c1, c2, c3, c4 time.Time
		for rows2.Next() {
			if err := rows2.Scan(&c1, &c2, &c3, &c4, &c5, &c6, &c7, &c8); err != nil {
				t.Errorf("values were not scanned: %v", err)
			}
		}
		expected = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
		if c8.Unix() != expected.Unix() || c8.Nanosecond() != nano {
			t.Errorf("the value did not match. expected: %v, got: %v", expected, c8)
		}
		if c7.Unix() != expected.Unix() || c7.Nanosecond() != nano-(nano%1e2) {
			t.Errorf("the value did not match. expected: %v, got: %v", expected, c7)
		}
		if c6.Unix() != expected.Unix() || c6.Nanosecond() != nano-(nano%1e3) {
			t.Errorf("the value did not match. expected: %v, got: %v", expected, c6)
		}
		if c5.Unix() != expected.Unix() || c5.Nanosecond() != nano-(nano%1e4) {
			t.Errorf("the value did not match. expected: %v, got: %v", expected, c5)
		}
		if c4.Unix() != expected.Unix() || c4.Nanosecond() != nano-(nano%1e5) {
			t.Errorf("the value did not match. expected: %v, got: %v", expected, c4)
		}
		if c3.Unix() != expected.Unix() || c3.Nanosecond() != nano-(nano%1e6) {
			t.Errorf("the value did not match. expected: %v, got: %v", expected, c3)
		}
		if c2.Unix() != expected.Unix() || c2.Nanosecond() != nano-(nano%1e7) {
			t.Errorf("the value did not match. expected: %v, got: %v", expected, c2)
		}
		if c1.Unix() != expected.Unix() || c1.Nanosecond() != nano-(nano%1e8) {
			t.Errorf("the value did not match. expected: %v, got: %v", expected, c1)
		}
	})
}
// TestArrowVariousTypes verifies scanning and column metadata (names, scan
// types, precision/scale, length, nullability) for a query returning a mix of
// Snowflake types in arrow format with higher precision enabled.
func TestArrowVariousTypes(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		rows := dbt.mustQueryContext(
			WithHigherPrecision(context.Background()), selectVariousTypes)
		defer rows.Close()
		if !rows.Next() {
			dbt.Error("failed to query")
		}
		cc, err := rows.Columns()
		if err != nil {
			dbt.Errorf("columns: %v", cc)
		}
		ct, err := rows.ColumnTypes()
		if err != nil {
			dbt.Errorf("column types: %v", ct)
		}
		var v1 *big.Float
		var v2, v2a int
		var v3 string
		var v4 float64
		var v5 []byte
		var v6 bool
		if err = rows.Scan(&v1, &v2, &v2a, &v3, &v4, &v5, &v6); err != nil {
			dbt.Errorf("failed to scan: %#v", err)
		}
		if v1.Cmp(big.NewFloat(1.0)) != 0 {
			dbt.Errorf("failed to scan. %#v", *v1)
		}
		if ct[0].Name() != "C1" || ct[1].Name() != "C2" || ct[2].Name() != "C2A" || ct[3].Name() != "C3" || ct[4].Name() != "C4" || ct[5].Name() != "C5" || ct[6].Name() != "C6" {
			dbt.Errorf("failed to get column names: %#v", ct)
		}
		// With higher precision enabled, the fixed-point C1 scans as *big.Float.
		if ct[0].ScanType() != reflect.TypeFor[*big.Float]() {
			// BUGFIX: the failure message previously reported float64 as the
			// expected type, contradicting the *big.Float check above.
			dbt.Errorf("failed to get scan type. expected: %v, got: %v", reflect.TypeFor[*big.Float](), ct[0].ScanType())
		}
		if ct[1].ScanType() != reflect.TypeFor[int64]() {
			dbt.Errorf("failed to get scan type. expected: %v, got: %v", reflect.TypeFor[int64](), ct[1].ScanType())
		}
		if ct[2].ScanType() != reflect.TypeFor[*big.Int]() {
			dbt.Errorf("failed to get scan type. expected: %v, got: %v", reflect.TypeFor[*big.Int](), ct[2].ScanType())
		}
		var pr, sc int64
		var cLen int64
		pr, sc = dbt.mustDecimalSize(ct[0])
		if pr != 30 || sc != 2 {
			dbt.Errorf("failed to get precision and scale. %#v", ct[0])
		}
		dbt.mustFailLength(ct[0])
		if canNull := dbt.mustNullable(ct[0]); canNull {
			dbt.Errorf("failed to get nullable. %#v", ct[0])
		}
		// NOTE(review): cLen is still its zero value here, so this check can
		// never fail; it looks vestigial — confirm intent before removing.
		if cLen != 0 {
			dbt.Errorf("failed to get length. %#v", ct[0])
		}
		if v2 != 2 {
			dbt.Errorf("failed to scan. %#v", v2)
		}
		pr, sc = dbt.mustDecimalSize(ct[1])
		if pr != 18 || sc != 0 {
			dbt.Errorf("failed to get precision and scale. %#v", ct[1])
		}
		dbt.mustFailLength(ct[1])
		if canNull := dbt.mustNullable(ct[1]); canNull {
			dbt.Errorf("failed to get nullable. %#v", ct[1])
		}
		if v2a != 22 {
			dbt.Errorf("failed to scan. %#v", v2a)
		}
		dbt.mustFailLength(ct[2])
		if canNull := dbt.mustNullable(ct[2]); canNull {
			dbt.Errorf("failed to get nullable. %#v", ct[2])
		}
		if v3 != "t3" {
			dbt.Errorf("failed to scan. %#v", v3)
		}
		dbt.mustFailDecimalSize(ct[3])
		if cLen = dbt.mustLength(ct[3]); cLen != 2 {
			dbt.Errorf("failed to get length. %#v", ct[3])
		}
		if canNull := dbt.mustNullable(ct[3]); canNull {
			dbt.Errorf("failed to get nullable. %#v", ct[3])
		}
		if v4 != 4.2 {
			dbt.Errorf("failed to scan. %#v", v4)
		}
		dbt.mustFailDecimalSize(ct[4])
		dbt.mustFailLength(ct[4])
		if canNull := dbt.mustNullable(ct[4]); canNull {
			dbt.Errorf("failed to get nullable. %#v", ct[4])
		}
		if !bytes.Equal(v5, []byte{0xab, 0xcd}) {
			dbt.Errorf("failed to scan. %#v", v5)
		}
		dbt.mustFailDecimalSize(ct[5])
		if cLen = dbt.mustLength(ct[5]); cLen != 8388608 { // BINARY
			dbt.Errorf("failed to get length. %#v", ct[5])
		}
		if canNull := dbt.mustNullable(ct[5]); canNull {
			dbt.Errorf("failed to get nullable. %#v", ct[5])
		}
		if !v6 {
			dbt.Errorf("failed to scan. %#v", v6)
		}
		dbt.mustFailDecimalSize(ct[6])
		dbt.mustFailLength(ct[6])
	})
}
// TestArrowMemoryCleanedUp checks that all arrow memory obtained through a
// caller-supplied allocator is released by the time the rows are closed.
func TestArrowMemoryCleanedUp(t *testing.T) {
	mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
	// The checked allocator fails the test if any allocation is still live.
	defer mem.AssertSize(t, 0)
	runDBTest(t, func(dbt *DBTest) {
		rows := dbt.mustQueryContext(
			WithArrowAllocator(context.Background(), mem),
			"select 1 UNION select 2 ORDER BY 1")
		defer rows.Close()
		for _, want := range []int{1, 2} {
			var got int
			assertTrueF(t, rows.Next())
			assertNilF(t, rows.Scan(&got))
			assertEqualE(t, got, want)
		}
		assertFalseE(t, rows.Next())
	})
}
================================================
FILE: arrowbatches/batches.go
================================================
package arrowbatches
import (
"cmp"
"context"
"github.com/snowflakedb/gosnowflake/v2/internal/query"
"github.com/snowflakedb/gosnowflake/v2/internal/types"
"time"
sf "github.com/snowflakedb/gosnowflake/v2"
ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/memory"
)
// ArrowBatch represents a chunk of data retrievable in arrow.Record format.
// It is created by GetArrowBatches; records are materialized (and possibly
// downloaded) lazily by Fetch.
type ArrowBatch struct {
	raw       ia.BatchRaw                 // records already present, or a lazy Download handle
	rowTypes  []query.ExecResponseRowType // per-column Snowflake metadata for the result set
	allocator memory.Allocator            // allocator used when transforming records
	ctx       context.Context             // context for Fetch; nil falls back to context.Background()
}
// WithContext sets the context for subsequent Fetch calls on this batch.
// It returns the receiver so calls can be chained, e.g.
// batch.WithContext(ctx).Fetch().
func (rb *ArrowBatch) WithContext(ctx context.Context) *ArrowBatch {
	rb.ctx = ctx
	return rb
}
// Fetch returns an array of arrow.Record representing this batch's data.
// Records are transformed from Snowflake's internal format to standard Arrow types.
//
// If the batch is backed by a lazy download, the chunk is downloaded on the
// first call. After a successful Fetch the raw records are released and the
// cached reference dropped, so the caller owns the returned records and must
// Release them; a subsequent Fetch on the same batch will return an empty
// slice unless a Download handle re-fetches.
func (rb *ArrowBatch) Fetch() (*[]arrow.Record, error) {
	var rawRecords *[]arrow.Record
	// Fall back to a background context when WithContext was never called.
	ctx := cmp.Or(rb.ctx, context.Background())
	if rb.raw.Records != nil {
		// Records are already materialized in memory.
		rawRecords = rb.raw.Records
	} else if rb.raw.Download != nil {
		// Lazily download the chunk and cache it on the batch.
		recs, rowCount, err := rb.raw.Download(ctx)
		if err != nil {
			return nil, err
		}
		rawRecords = recs
		rb.raw.Records = recs
		rb.raw.RowCount = rowCount
	}
	if rawRecords == nil || len(*rawRecords) == 0 {
		empty := make([]arrow.Record, 0)
		return &empty, nil
	}
	var transformed []arrow.Record
	for i, rec := range *rawRecords {
		newRec, err := arrowToRecord(ctx, rec, rb.allocator, rb.rowTypes, rb.raw.Location)
		if err != nil {
			// On failure release everything produced so far plus the raw
			// records not yet consumed ([0:i] were released in earlier
			// iterations), then drop the cached reference.
			for _, t := range transformed {
				t.Release()
			}
			for _, r := range (*rawRecords)[i:] {
				r.Release()
			}
			rb.raw.Records = nil
			return nil, err
		}
		transformed = append(transformed, newRec)
		// The raw record is no longer needed once transformed.
		rec.Release()
	}
	rb.raw.Records = nil
	rb.raw.RowCount = countArrowBatchRows(&transformed)
	return &transformed, nil
}
// GetRowCount returns the number of rows in this batch.
// NOTE(review): for a download-backed batch this is only guaranteed to be
// populated after Fetch has run — confirm with the batch provider.
func (rb *ArrowBatch) GetRowCount() int {
	return rb.raw.RowCount
}
// GetLocation returns the timezone location for this batch, as used when
// converting Snowflake timestamps to time.Time.
func (rb *ArrowBatch) GetLocation() *time.Location {
	return rb.raw.Location
}
// GetRowTypes returns the column metadata for this batch.
// The returned slice is the batch's own backing slice; callers should treat
// it as read-only.
func (rb *ArrowBatch) GetRowTypes() []query.ExecResponseRowType {
	return rb.rowTypes
}
// ArrowSnowflakeTimestampToTime converts an original Snowflake timestamp to time.Time.
// rec is a record from this batch (intended for records fetched with the
// UseOriginalTimestamp option), colIdx the timestamp column and recIdx the row.
// Returns nil when the value is NULL.
func (rb *ArrowBatch) ArrowSnowflakeTimestampToTime(rec arrow.Record, colIdx int, recIdx int) *time.Time {
	scale := int(rb.rowTypes[colIdx].Scale)
	dbType := rb.rowTypes[colIdx].Type
	return ArrowSnowflakeTimestampToTime(rec.Column(colIdx), types.GetSnowflakeType(dbType), scale, recIdx, rb.raw.Location)
}
// GetArrowBatches retrieves arrow batches from SnowflakeRows.
// The rows must have been queried with arrowbatches.WithArrowBatches(ctx);
// otherwise an ErrNotImplemented SnowflakeError is returned.
func GetArrowBatches(rows sf.SnowflakeRows) ([]*ArrowBatch, error) {
	provider, ok := rows.(ia.BatchDataProvider)
	if !ok {
		return nil, &sf.SnowflakeError{
			Number:  sf.ErrNotImplemented,
			Message: "rows do not support arrow batch data",
		}
	}
	info, err := provider.GetArrowBatches()
	if err != nil {
		return nil, err
	}
	// Wrap every raw batch together with the shared result-set metadata.
	result := make([]*ArrowBatch, 0, len(info.Batches))
	for _, raw := range info.Batches {
		result = append(result, &ArrowBatch{
			raw:       raw,
			rowTypes:  info.RowTypes,
			allocator: info.Allocator,
			ctx:       info.Ctx,
		})
	}
	return result, nil
}
// countArrowBatchRows sums the row counts of all records in the slice.
func countArrowBatchRows(recs *[]arrow.Record) int {
	total := 0
	for _, rec := range *recs {
		total += int(rec.NumRows())
	}
	return total
}
// GetAllocator returns the memory allocator for this batch, i.e. the one used
// by Fetch when transforming records.
func (rb *ArrowBatch) GetAllocator() memory.Allocator {
	return rb.allocator
}
================================================
FILE: arrowbatches/batches_test.go
================================================
package arrowbatches
import (
"context"
"crypto/rsa"
"crypto/x509"
"database/sql"
"database/sql/driver"
"encoding/pem"
"errors"
"fmt"
"math"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/memory"
sf "github.com/snowflakedb/gosnowflake/v2"
ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"
)
// testConn holds a reusable database connection for running multiple queries.
// Use openTestConn to create one and close() to tear it down.
type testConn struct {
	db   *sql.DB   // owning database handle, closed after conn
	conn *sql.Conn // pinned connection so session state persists across queries
}
// repoRoot walks up from the current working directory to find the directory
// containing go.mod, which is the repository root. Fails the test when no
// go.mod is found on the way to the filesystem root.
func repoRoot(t *testing.T) string {
	t.Helper()
	cur, err := os.Getwd()
	if err != nil {
		t.Fatalf("failed to get working directory: %v", err)
	}
	for {
		_, statErr := os.Stat(filepath.Join(cur, "go.mod"))
		switch {
		case statErr == nil:
			return cur
		case !os.IsNotExist(statErr):
			t.Fatalf("failed to stat go.mod in %q: %v", cur, statErr)
		}
		parent := filepath.Dir(cur)
		if parent == cur {
			// Reached the filesystem root without finding go.mod.
			t.Fatal("could not find repository root (no go.mod found)")
		}
		cur = parent
	}
}
// readPrivateKey reads an RSA private key from a PEM file. If the path is
// relative it is resolved against the repository root so that tests in
// sub-packages work with repo-root-relative paths. Any failure aborts the test.
func readPrivateKey(t *testing.T, path string) *rsa.PrivateKey {
	t.Helper()
	if !filepath.IsAbs(path) {
		path = filepath.Join(repoRoot(t), path)
	}
	pemBytes, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("failed to read private key file %q: %v", path, err)
	}
	block, _ := pem.Decode(pemBytes)
	if block == nil {
		t.Fatalf("failed to decode PEM block from %q", path)
	}
	parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		t.Fatalf("failed to parse private key from %q: %v", path, err)
	}
	rsaKey, ok := parsed.(*rsa.PrivateKey)
	if !ok {
		t.Fatalf("private key in %q is not RSA (got %T)", path, parsed)
	}
	return rsaKey
}
// testConfig builds a Snowflake connection config from SNOWFLAKE_TEST_*
// environment variables. When SNOWFLAKE_TEST_AUTHENTICATOR is SNOWFLAKE_JWT,
// key-pair authentication is used instead of a password; the session timezone
// is pinned to UTC either way.
func testConfig(t *testing.T) *sf.Config {
	t.Helper()
	jwtAuth := os.Getenv("SNOWFLAKE_TEST_AUTHENTICATOR") == "SNOWFLAKE_JWT"
	params := []*sf.ConfigParam{
		{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true},
		{Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true},
		{Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false},
		{Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false},
		{Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false},
		{Name: "Warehouse", EnvName: "SNOWFLAKE_TEST_WAREHOUSE", FailOnMissing: false},
	}
	// A password is only required when not authenticating with a key pair.
	if !jwtAuth {
		params = append(params,
			&sf.ConfigParam{Name: "Password", EnvName: "SNOWFLAKE_TEST_PASSWORD", FailOnMissing: true})
	}
	cfg, err := sf.GetConfigFromEnv(params)
	if err != nil {
		t.Fatalf("failed to get config from environment: %v", err)
	}
	if jwtAuth {
		keyPath := os.Getenv("SNOWFLAKE_TEST_PRIVATE_KEY")
		if keyPath == "" {
			t.Fatal("SNOWFLAKE_TEST_PRIVATE_KEY must be set for JWT authentication")
		}
		cfg.PrivateKey = readPrivateKey(t, keyPath)
		cfg.Authenticator = sf.AuthTypeJwt
	}
	// Pin the session timezone so timestamp assertions are deterministic.
	if cfg.Params == nil {
		cfg.Params = make(map[string]*string)
	}
	utc := "UTC"
	cfg.Params["timezone"] = &utc
	return cfg
}
// openTestConn opens a database handle and pins a single connection on it so
// that subsequent queries share session state. Fails the test on any error.
func openTestConn(ctx context.Context, t *testing.T) *testConn {
	t.Helper()
	dsn, err := sf.DSN(testConfig(t))
	if err != nil {
		t.Fatalf("failed to create DSN: %v", err)
	}
	db, err := sql.Open("snowflake", dsn)
	if err != nil {
		t.Fatalf("failed to open db: %v", err)
	}
	conn, err := db.Conn(ctx)
	if err != nil {
		// Don't leak the db handle when pinning a connection fails.
		db.Close()
		t.Fatalf("failed to get connection: %v", err)
	}
	return &testConn{db: db, conn: conn}
}
// close releases the pinned connection and then the database handle,
// in that order.
func (tc *testConn) close() {
	tc.conn.Close()
	tc.db.Close()
}
// queryRows executes a query on the existing connection and returns
// SnowflakeRows plus a function to close just the rows.
func (tc *testConn) queryRows(ctx context.Context, t *testing.T, query string) (sf.SnowflakeRows, func()) {
	t.Helper()
	var rows driver.Rows
	var err error
	// Raw exposes the underlying driver connection so the query bypasses
	// database/sql and yields the driver-level rows implementation.
	err = tc.conn.Raw(func(x any) error {
		queryer, ok := x.(driver.QueryerContext)
		if !ok {
			return fmt.Errorf("connection does not implement QueryerContext")
		}
		// Deliberately assigns the outer rows/err so the result survives
		// the closure.
		rows, err = queryer.QueryContext(ctx, query, nil)
		return err
	})
	if err != nil {
		t.Fatalf("failed to execute query: %v", err)
	}
	sfRows, ok := rows.(sf.SnowflakeRows)
	if !ok {
		// Close before failing so the connection isn't left with open rows.
		rows.Close()
		t.Fatalf("rows do not implement SnowflakeRows")
	}
	return sfRows, func() { rows.Close() }
}
// queryRawRows is a convenience wrapper that opens a new connection,
// runs a single query, and returns SnowflakeRows with a full cleanup.
func queryRawRows(ctx context.Context, t *testing.T, query string) (sf.SnowflakeRows, func()) {
	t.Helper()
	conn := openTestConn(ctx, t)
	rows, closeRows := conn.queryRows(ctx, t, query)
	cleanup := func() {
		// Close the rows first, then the connection that produced them.
		closeRows()
		conn.close()
	}
	return rows, cleanup
}
// TestGetArrowBatches runs a trivial two-column query in arrow batch mode and
// checks the shape of the first fetched record.
func TestGetArrowBatches(t *testing.T) {
	ctx := WithArrowBatches(context.Background())
	sfRows, cleanup := queryRawRows(ctx, t, "SELECT 1 AS num, 'hello' AS str")
	defer cleanup()
	batches, err := GetArrowBatches(sfRows)
	if err != nil {
		t.Fatalf("GetArrowBatches failed: %v", err)
	}
	if len(batches) == 0 {
		t.Fatal("expected at least one batch")
	}
	recs, err := batches[0].Fetch()
	if err != nil {
		t.Fatalf("Fetch failed: %v", err)
	}
	if recs == nil || len(*recs) == 0 {
		t.Fatal("expected at least one record")
	}
	first := (*recs)[0]
	defer first.Release()
	if got := first.NumCols(); got != 2 {
		t.Fatalf("expected 2 columns, got %d", got)
	}
	if got := first.NumRows(); got != 1 {
		t.Fatalf("expected 1 row, got %d", got)
	}
}
// TestGetArrowBatchesHighPrecision verifies that a DECIMAL(38, 19) column is
// delivered as the raw scaled integer when higher precision is enabled.
func TestGetArrowBatchesHighPrecision(t *testing.T) {
	ctx := sf.WithHigherPrecision(WithArrowBatches(context.Background()))
	sfRows, cleanup := queryRawRows(ctx, t, "SELECT '0.1'::DECIMAL(38, 19) AS c")
	defer cleanup()
	batches, err := GetArrowBatches(sfRows)
	if err != nil {
		t.Fatalf("GetArrowBatches failed: %v", err)
	}
	if len(batches) == 0 {
		t.Fatal("expected at least one batch")
	}
	recs, err := batches[0].Fetch()
	if err != nil {
		t.Fatalf("Fetch failed: %v", err)
	}
	if recs == nil || len(*recs) == 0 {
		t.Fatal("expected at least one record")
	}
	first := (*recs)[0]
	defer first.Release()
	// 0.1 scaled by 10^19 is the unscaled decimal representation.
	want := "1000000000000000000"
	if got := first.Column(0).ValueStr(0); got != want {
		t.Fatalf("expected %q, got %q", want, got)
	}
}
// TestGetArrowBatchesLargeResultSet fetches a multi-chunk result concurrently
// and verifies the total row count while a checked allocator confirms that all
// arrow memory is released.
func TestGetArrowBatchesLargeResultSet(t *testing.T) {
	numrows := 3000
	pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
	defer pool.AssertSize(t, 0)
	ctx := sf.WithArrowAllocator(WithArrowBatches(context.Background()), pool)
	query := fmt.Sprintf("SELECT SEQ8(), RANDSTR(1000, RANDOM()) FROM TABLE(GENERATOR(ROWCOUNT=>%v))", numrows)
	sfRows, cleanup := queryRawRows(ctx, t, query)
	defer cleanup()
	batches, err := GetArrowBatches(sfRows)
	if err != nil {
		t.Fatalf("GetArrowBatches failed: %v", err)
	}
	if len(batches) == 0 {
		t.Fatal("expected at least one batch")
	}
	// Fetch batches concurrently; each worker tallies rows locally and the
	// totals are summed after all workers finish, avoiding shared state.
	const maxWorkers = 10
	work := make(chan int, len(batches))
	counts := make(chan int, maxWorkers)
	var wg sync.WaitGroup
	for range maxWorkers {
		wg.Add(1)
		go func() {
			defer wg.Done()
			local := 0
			for i := range work {
				recs, fetchErr := batches[i].Fetch()
				if fetchErr != nil {
					t.Errorf("Fetch failed for batch %d: %v", i, fetchErr)
					return
				}
				for _, r := range *recs {
					local += int(r.NumRows())
					r.Release()
				}
			}
			counts <- local
		}()
	}
	for i := range batches {
		work <- i
	}
	close(work)
	wg.Wait()
	close(counts)
	total := 0
	for c := range counts {
		total += c
	}
	if total != numrows {
		t.Fatalf("row count mismatch: expected %d, got %d", numrows, total)
	}
}
// TestGetArrowBatchesWithTimestampOption checks that a timestamp column can be
// fetched with the UseOriginalTimestamp option and yields a 1x1 record.
func TestGetArrowBatchesWithTimestampOption(t *testing.T) {
	ctx := WithTimestampOption(WithArrowBatches(context.Background()), UseOriginalTimestamp)
	sfRows, cleanup := queryRawRows(ctx, t, "SELECT TO_TIMESTAMP_NTZ('2024-01-15 13:45:30.123456789') AS ts")
	defer cleanup()
	batches, err := GetArrowBatches(sfRows)
	if err != nil {
		t.Fatalf("GetArrowBatches failed: %v", err)
	}
	if len(batches) == 0 {
		t.Fatal("expected at least one batch")
	}
	recs, err := batches[0].Fetch()
	if err != nil {
		t.Fatalf("Fetch failed: %v", err)
	}
	if recs == nil || len(*recs) == 0 {
		t.Fatal("expected at least one record")
	}
	first := (*recs)[0]
	defer first.Release()
	if got := first.NumRows(); got != 1 {
		t.Fatalf("expected 1 row, got %d", got)
	}
	if got := first.NumCols(); got != 1 {
		t.Fatalf("expected 1 column, got %d", got)
	}
}
// TestGetArrowBatchesJSONResponseError verifies that requesting arrow batches
// on a connection whose session is forced to the JSON result format fails
// with ErrNonArrowResponseInArrowBatches.
func TestGetArrowBatchesJSONResponseError(t *testing.T) {
	ctx := WithArrowBatches(context.Background())
	cfg := testConfig(t)
	dsn, err := sf.DSN(cfg)
	if err != nil {
		t.Fatalf("failed to create DSN: %v", err)
	}
	db, err := sql.Open("snowflake", dsn)
	if err != nil {
		t.Fatalf("failed to open db: %v", err)
	}
	defer db.Close()
	// Pin one connection so the ALTER SESSION below applies to the same
	// session that runs the query.
	conn, err := db.Conn(ctx)
	if err != nil {
		t.Fatalf("failed to get connection: %v", err)
	}
	defer conn.Close()
	_, err = conn.ExecContext(ctx, "ALTER SESSION SET GO_QUERY_RESULT_FORMAT = json")
	if err != nil {
		t.Fatalf("failed to set JSON format: %v", err)
	}
	var rows driver.Rows
	// Go through the raw driver connection to obtain driver-level rows.
	err = conn.Raw(func(x any) error {
		queryer, ok := x.(driver.QueryerContext)
		if !ok {
			return fmt.Errorf("connection does not implement QueryerContext")
		}
		rows, err = queryer.QueryContext(ctx, "SELECT 'hello'", nil)
		return err
	})
	if err != nil {
		t.Fatalf("failed to execute query: %v", err)
	}
	defer rows.Close()
	sfRows, ok := rows.(sf.SnowflakeRows)
	if !ok {
		t.Fatal("rows do not implement SnowflakeRows")
	}
	_, err = GetArrowBatches(sfRows)
	if err == nil {
		t.Fatal("expected error when using arrow batches with JSON response")
	}
	// The error must be a SnowflakeError with the dedicated error code.
	var se *sf.SnowflakeError
	if !errors.As(err, &se) {
		t.Fatalf("expected SnowflakeError, got %T: %v", err, err)
	}
	if se.Number != sf.ErrNonArrowResponseInArrowBatches {
		t.Fatalf("expected error code %d, got %d", sf.ErrNonArrowResponseInArrowBatches, se.Number)
	}
}
// TestTimestampConversionDistantDates tests all 10 timestamp scales (0-9)
// because each scale exercises a mathematically distinct code path in
// extractEpoch/extractFraction (converter.go). Past bugs have been
// scale-specific: SNOW-526255 (time scale for arrow) and SNOW-2091309
// (precision loss at scale 0). Do not reduce the scale range.
func TestTimestampConversionDistantDates(t *testing.T) {
	// Extreme dates at both ends of the supported range.
	timestamps := [2]string{
		"9999-12-12 23:59:59.999999999",
		"0001-01-01 00:00:00.000000000",
	}
	tsTypes := [3]string{"TIMESTAMP_NTZ", "TIMESTAMP_LTZ", "TIMESTAMP_TZ"}
	precisions := []struct {
		name        string
		option      ia.TimestampOption
		unit        arrow.TimeUnit
		expectError bool // nanosecond precision cannot represent year 9999/0001
	}{
		{"second", UseSecondTimestamp, arrow.Second, false},
		{"millisecond", UseMillisecondTimestamp, arrow.Millisecond, false},
		{"microsecond", UseMicrosecondTimestamp, arrow.Microsecond, false},
		{"nanosecond", UseNanosecondTimestamp, arrow.Nanosecond, true},
	}
	for _, prec := range precisions {
		t.Run(prec.name, func(t *testing.T) {
			// Each precision subtest runs in parallel with its own pool and
			// connection, created after Parallel so defers fire per subtest.
			t.Parallel()
			pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
			defer pool.AssertSize(t, 0)
			ctx := sf.WithArrowAllocator(
				WithTimestampOption(WithArrowBatches(context.Background()), prec.option),
				pool,
			)
			tc := openTestConn(ctx, t)
			defer tc.close()
			for _, tsStr := range timestamps {
				for _, tp := range tsTypes {
					for scale := 0; scale <= 9; scale++ {
						t.Run(tp+"("+strconv.Itoa(scale)+")_"+tsStr, func(t *testing.T) {
							query := fmt.Sprintf("SELECT '%s'::%s(%v)", tsStr, tp, scale)
							sfRows, closeRows := tc.queryRows(ctx, t, query)
							// Deferred inside this innermost subtest closure,
							// so it fires when the subtest ends.
							defer closeRows()
							batches, err := GetArrowBatches(sfRows)
							if err != nil {
								t.Fatalf("GetArrowBatches failed: %v", err)
							}
							if len(batches) == 0 {
								t.Fatal("expected at least one batch")
							}
							records, err := batches[0].Fetch()
							if prec.expectError {
								// Nanosecond conversion must fail for these dates.
								expectedError := "Cannot convert timestamp"
								if err == nil {
									t.Fatalf("no error, expected: %v", expectedError)
								}
								if !strings.Contains(err.Error(), expectedError) {
									t.Fatalf("improper error, expected: %v, got: %v", expectedError, err.Error())
								}
								return
							}
							if err != nil {
								t.Fatalf("Fetch failed: %v", err)
							}
							if records == nil || len(*records) == 0 {
								t.Fatal("expected at least one record")
							}
							rec := (*records)[0]
							defer rec.Release()
							// Only the year is compared: it is the part that
							// past precision bugs corrupted.
							actual := rec.Column(0).(*array.Timestamp).TimestampValues()[0]
							actualYear := actual.ToTime(prec.unit).Year()
							ts, err := time.Parse("2006-01-02 15:04:05", tsStr)
							if err != nil {
								t.Fatalf("failed to parse time: %v", err)
							}
							exp := ts.Truncate(time.Duration(math.Pow10(9 - scale)))
							if actualYear != exp.Year() {
								t.Fatalf("unexpected year, expected: %v, got: %v", exp.Year(), actualYear)
							}
						})
					}
				}
			}
		})
	}
}
// TestTimestampConversionWithOriginalTimestamp tests all 10 timestamp scales
// (0-9) because each scale exercises a mathematically distinct code path in
// extractEpoch/extractFraction. See TestTimestampConversionDistantDates for
// rationale on why the full scale range is required.
func TestTimestampConversionWithOriginalTimestamp(t *testing.T) {
	timestamps := [3]string{
		"2000-10-10 10:10:10.123456789",
		"9999-12-12 23:59:59.999999999",
		"0001-01-01 00:00:00.000000000",
	}
	tsTypes := [3]string{"TIMESTAMP_NTZ", "TIMESTAMP_LTZ", "TIMESTAMP_TZ"}
	pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
	defer pool.AssertSize(t, 0)
	ctx := sf.WithArrowAllocator(
		WithTimestampOption(WithArrowBatches(context.Background()), UseOriginalTimestamp),
		pool,
	)
	tc := openTestConn(ctx, t)
	defer tc.close()
	for _, tsStr := range timestamps {
		ts, err := time.Parse("2006-01-02 15:04:05", tsStr)
		if err != nil {
			t.Fatalf("failed to parse time: %v", err)
		}
		for _, tp := range tsTypes {
			t.Run(tp+"_"+tsStr, func(t *testing.T) {
				// Batch all 10 scales into a single multi-column query to reduce round trips.
				var cols []string
				for scale := 0; scale <= 9; scale++ {
					cols = append(cols, fmt.Sprintf("'%s'::%s(%v)", tsStr, tp, scale))
				}
				query := "SELECT " + strings.Join(cols, ", ")
				sfRows, closeRows := tc.queryRows(ctx, t, query)
				defer closeRows()
				batches, err := GetArrowBatches(sfRows)
				if err != nil {
					t.Fatalf("GetArrowBatches failed: %v", err)
				}
				if len(batches) != 1 {
					t.Fatalf("expected 1 batch, got %d", len(batches))
				}
				records, err := batches[0].Fetch()
				if err != nil {
					t.Fatalf("Fetch failed: %v", err)
				}
				if records == nil || len(*records) == 0 {
					t.Fatal("expected at least one record")
				}
				// BUGFIX: release each record exactly once. The previous code
				// deferred r.Release() inside the scale loop below, releasing
				// every record once per scale (10 times in total) and
				// over-decrementing the refcount tracked by the checked
				// allocator.
				for _, r := range *records {
					defer r.Release()
				}
				for scale := 0; scale <= 9; scale++ {
					exp := ts.Truncate(time.Duration(math.Pow10(9 - scale)))
					for _, r := range *records {
						// The column index equals the scale by construction of
						// the query above.
						act := batches[0].ArrowSnowflakeTimestampToTime(r, scale, 0)
						if act == nil {
							t.Fatalf("scale %d: unexpected nil, expected: %v", scale, exp)
						} else if !exp.Equal(*act) {
							t.Fatalf("scale %d: unexpected result, expected: %v, got: %v", scale, exp, *act)
						}
					}
				}
			})
		}
	}
}
================================================
FILE: arrowbatches/context.go
================================================
package arrowbatches
import (
"context"
ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"
)
// Timestamp option constants.
//
// These re-export the internal arrow package's timestamp conversion options so
// callers of this package do not have to import internal packages. The first
// four select the arrow.Timestamp unit produced during conversion;
// UseOriginalTimestamp keeps Snowflake's original representation so values can
// be decoded with ArrowSnowflakeTimestampToTime.
const (
	UseNanosecondTimestamp  = ia.UseNanosecondTimestamp
	UseMicrosecondTimestamp = ia.UseMicrosecondTimestamp
	UseMillisecondTimestamp = ia.UseMillisecondTimestamp
	UseSecondTimestamp      = ia.UseSecondTimestamp
	UseOriginalTimestamp    = ia.UseOriginalTimestamp
)
// WithArrowBatches returns a context that enables arrow batch mode for queries.
// Rows produced under this context can be passed to GetArrowBatches.
func WithArrowBatches(ctx context.Context) context.Context {
	return ia.EnableArrowBatches(ctx)
}
// WithTimestampOption returns a context that sets the timestamp conversion option
// for arrow batches. Use one of the Use*Timestamp constants of this package.
func WithTimestampOption(ctx context.Context, option ia.TimestampOption) context.Context {
	return ia.WithTimestampOption(ctx, option)
}
// WithUtf8Validation returns a context that enables UTF-8 validation for
// string columns in arrow batches; invalid sequences are replaced during
// conversion.
func WithUtf8Validation(ctx context.Context) context.Context {
	return ia.EnableUtf8Validation(ctx)
}
================================================
FILE: arrowbatches/converter.go
================================================
package arrowbatches
import (
"context"
"fmt"
"github.com/snowflakedb/gosnowflake/v2/internal/query"
"github.com/snowflakedb/gosnowflake/v2/internal/types"
"math"
"math/big"
"strings"
"time"
"unicode/utf8"
sf "github.com/snowflakedb/gosnowflake/v2"
ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/compute"
"github.com/apache/arrow-go/v18/arrow/memory"
)
// arrowToRecord transforms a raw arrow.Record from Snowflake into a record
// with standard Arrow types (e.g., converting struct-based timestamps to
// arrow.Timestamp, decimal128 to int64/float64, etc.)
//
// The caller owns the returned record and must Release it.
func arrowToRecord(ctx context.Context, record arrow.Record, pool memory.Allocator, rowType []query.ExecResponseRowType, loc *time.Location) (arrow.Record, error) {
	timestampOption := ia.GetTimestampOption(ctx)
	higherPrecision := ia.HigherPrecisionEnabled(ctx)
	s, err := recordToSchema(record.Schema(), rowType, loc, timestampOption, higherPrecision)
	if err != nil {
		return nil, err
	}
	var cols []arrow.Array
	numRows := record.NumRows()
	// Route compute-kernel allocations through the caller-provided pool.
	ctxAlloc := compute.WithAllocator(ctx, pool)
	for i, col := range record.Columns() {
		fieldMetadata := rowType[i].ToFieldMetadata()
		newCol, err := arrowToRecordSingleColumn(ctxAlloc, s.Field(i), col, fieldMetadata, higherPrecision, timestampOption, pool, loc, numRows)
		if err != nil {
			return nil, err
		}
		cols = append(cols, newCol)
		// NOTE(review): releasing our reference after NewRecord relies on
		// NewRecord retaining the columns (arrow-go convention) — confirm
		// before reordering.
		defer newCol.Release()
	}
	newRecord := array.NewRecord(s, cols, numRows)
	return newRecord, nil
}
// arrowToRecordSingleColumn converts one column of a raw Snowflake arrow
// record into its standard Arrow representation, recursing into structured
// types (object/array/map). Every branch either retains the input column or
// builds a fresh array, so the caller always owns (and must Release) the
// returned array.
func arrowToRecordSingleColumn(ctx context.Context, field arrow.Field, col arrow.Array, fieldMetadata query.FieldMetadata, higherPrecisionEnabled bool, timestampOption ia.TimestampOption, pool memory.Allocator, loc *time.Location, numRows int64) (arrow.Array, error) {
	var err error
	newCol := col
	snowflakeType := types.GetSnowflakeType(fieldMetadata.Type)
	switch snowflakeType {
	case types.FixedType:
		if higherPrecisionEnabled {
			// Keep the decimal/int representation untouched.
			col.Retain()
		} else if col.DataType().ID() == arrow.DECIMAL || col.DataType().ID() == arrow.DECIMAL256 {
			// Decimal columns collapse to int64 (no scale) or float64.
			var toType arrow.DataType
			if fieldMetadata.Scale == 0 {
				toType = arrow.PrimitiveTypes.Int64
			} else {
				toType = arrow.PrimitiveTypes.Float64
			}
			newCol, err = compute.CastArray(ctx, col, compute.UnsafeCastOptions(toType))
			if err != nil {
				return nil, err
			}
		} else if fieldMetadata.Scale != 0 && col.DataType().ID() != arrow.INT64 {
			// Small scaled integers: divide by 10^scale via the compute kernel.
			result, err := compute.Divide(ctx, compute.ArithmeticOptions{NoCheckOverflow: true},
				&compute.ArrayDatum{Value: newCol.Data()},
				compute.NewDatum(math.Pow10(int(fieldMetadata.Scale))))
			if err != nil {
				return nil, err
			}
			defer result.Release()
			newCol = result.(*compute.ArrayDatum).MakeArray()
		} else if fieldMetadata.Scale != 0 && col.DataType().ID() == arrow.INT64 {
			// Scaled int64: go through big.Float per value to avoid the
			// precision loss of a direct float64 division.
			values := col.(*array.Int64).Int64Values()
			floatValues := make([]float64, len(values))
			for i, val := range values {
				floatValues[i], _ = intToBigFloat(val, int64(fieldMetadata.Scale)).Float64()
			}
			builder := array.NewFloat64Builder(pool)
			builder.AppendValues(floatValues, nil)
			newCol = builder.NewArray()
			builder.Release()
		} else {
			col.Retain()
		}
	case types.TimeType:
		// Normalize all time columns to nanosecond Time64.
		newCol, err = compute.CastArray(ctx, col, compute.SafeCastOptions(arrow.FixedWidthTypes.Time64ns))
		if err != nil {
			return nil, err
		}
	case types.TimestampNtzType, types.TimestampLtzType, types.TimestampTzType:
		if timestampOption == ia.UseOriginalTimestamp {
			// Leave Snowflake's original representation in place; callers
			// decode it with ArrowSnowflakeTimestampToTime.
			col.Retain()
		} else {
			// Rebuild the column as arrow.Timestamp in the requested unit.
			var unit arrow.TimeUnit
			switch timestampOption {
			case ia.UseMicrosecondTimestamp:
				unit = arrow.Microsecond
			case ia.UseMillisecondTimestamp:
				unit = arrow.Millisecond
			case ia.UseSecondTimestamp:
				unit = arrow.Second
			case ia.UseNanosecondTimestamp:
				unit = arrow.Nanosecond
			}
			var tb *array.TimestampBuilder
			if snowflakeType == types.TimestampLtzType {
				// LTZ carries the session timezone in the arrow type.
				tb = array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: unit, TimeZone: loc.String()})
			} else {
				tb = array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: unit})
			}
			defer tb.Release()
			for i := 0; i < int(numRows); i++ {
				ts := ArrowSnowflakeTimestampToTime(col, snowflakeType, int(fieldMetadata.Scale), i, loc)
				if ts != nil {
					var ar arrow.Timestamp
					switch timestampOption {
					case ia.UseMicrosecondTimestamp:
						ar = arrow.Timestamp(ts.UnixMicro())
					case ia.UseMillisecondTimestamp:
						ar = arrow.Timestamp(ts.UnixMilli())
					case ia.UseSecondTimestamp:
						ar = arrow.Timestamp(ts.Unix())
					case ia.UseNanosecondTimestamp:
						ar = arrow.Timestamp(ts.UnixNano())
						// UnixNano overflows outside ~1677-2262; detect it by
						// comparing years after the round trip.
						if ts.UTC().Year() != ar.ToTime(arrow.Nanosecond).Year() {
							return nil, &sf.SnowflakeError{
								Number:   sf.ErrTooHighTimestampPrecision,
								SQLState: sf.SQLStateInvalidDataTimeFormat,
								Message:  fmt.Sprintf("Cannot convert timestamp %v in column %v to Arrow.Timestamp data type due to too high precision. Please use context with WithOriginalTimestamp.", ts.UTC(), fieldMetadata.Name),
							}
						}
					}
					tb.Append(ar)
				} else {
					tb.AppendNull()
				}
			}
			newCol = tb.NewArray()
		}
	case types.TextType:
		if stringCol, ok := col.(*array.String); ok {
			// Optionally sanitize invalid UTF-8 (see WithUtf8Validation).
			newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows)
		}
	case types.ObjectType:
		if structCol, ok := col.(*array.Struct); ok {
			// Structured object: convert each field recursively, then rebuild
			// the struct preserving the original null bitmap.
			var internalCols []arrow.Array
			for i := 0; i < structCol.NumField(); i++ {
				internalCol := structCol.Field(i)
				newInternalCol, err := arrowToRecordSingleColumn(ctx, field.Type.(*arrow.StructType).Field(i), internalCol, fieldMetadata.Fields[i], higherPrecisionEnabled, timestampOption, pool, loc, numRows)
				if err != nil {
					return nil, err
				}
				internalCols = append(internalCols, newInternalCol)
				defer newInternalCol.Release()
			}
			var fieldNames []string
			for _, f := range field.Type.(*arrow.StructType).Fields() {
				fieldNames = append(fieldNames, f.Name)
			}
			nullBitmap := memory.NewBufferBytes(structCol.NullBitmapBytes())
			numberOfNulls := structCol.NullN()
			return array.NewStructArrayWithNulls(internalCols, fieldNames, nullBitmap, numberOfNulls, 0)
		} else if stringCol, ok := col.(*array.String); ok {
			// Semi-structured object delivered as JSON text.
			newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows)
		}
	case types.ArrayType:
		if listCol, ok := col.(*array.List); ok {
			// Convert the element column, then re-wrap it with the list's
			// original offsets/null buffers.
			newCol, err = arrowToRecordSingleColumn(ctx, field.Type.(*arrow.ListType).ElemField(), listCol.ListValues(), fieldMetadata.Fields[0], higherPrecisionEnabled, timestampOption, pool, loc, numRows)
			if err != nil {
				return nil, err
			}
			defer newCol.Release()
			newData := array.NewData(arrow.ListOf(newCol.DataType()), listCol.Len(), listCol.Data().Buffers(), []arrow.ArrayData{newCol.Data()}, listCol.NullN(), 0)
			defer newData.Release()
			return array.NewListData(newData), nil
		} else if stringCol, ok := col.(*array.String); ok {
			newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows)
		}
	case types.MapType:
		if mapCol, ok := col.(*array.Map); ok {
			// Convert keys and items separately, rebuild the k/v struct, then
			// re-wrap with the map's original offsets/null buffers.
			keyCol, err := arrowToRecordSingleColumn(ctx, field.Type.(*arrow.MapType).KeyField(), mapCol.Keys(), fieldMetadata.Fields[0], higherPrecisionEnabled, timestampOption, pool, loc, numRows)
			if err != nil {
				return nil, err
			}
			defer keyCol.Release()
			valueCol, err := arrowToRecordSingleColumn(ctx, field.Type.(*arrow.MapType).ItemField(), mapCol.Items(), fieldMetadata.Fields[1], higherPrecisionEnabled, timestampOption, pool, loc, numRows)
			if err != nil {
				return nil, err
			}
			defer valueCol.Release()
			structArr, err := array.NewStructArray([]arrow.Array{keyCol, valueCol}, []string{"k", "v"})
			if err != nil {
				return nil, err
			}
			defer structArr.Release()
			newData := array.NewData(arrow.MapOf(keyCol.DataType(), valueCol.DataType()), mapCol.Len(), mapCol.Data().Buffers(), []arrow.ArrayData{structArr.Data()}, mapCol.NullN(), 0)
			defer newData.Release()
			return array.NewMapData(newData), nil
		} else if stringCol, ok := col.(*array.String); ok {
			newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows)
		}
	default:
		// All other types pass through unchanged.
		col.Retain()
	}
	return newCol, nil
}
// arrowStringRecordToColumn returns the string column, sanitizing invalid
// UTF-8 sequences when validation is enabled on the context. The returned
// array carries its own reference; the caller must Release it.
func arrowStringRecordToColumn(
	ctx context.Context,
	stringCol *array.String,
	mem memory.Allocator,
	numRows int64,
) arrow.Array {
	if !ia.Utf8ValidationEnabled(ctx) || stringCol.DataType().ID() != arrow.STRING {
		// Pass-through: hand the caller its own reference.
		stringCol.Retain()
		return stringCol
	}
	builder := array.NewStringBuilder(mem)
	defer builder.Release()
	for row := 0; row < int(numRows); row++ {
		if !stringCol.IsValid(row) {
			builder.AppendNull()
			continue
		}
		value := stringCol.Value(row)
		if !utf8.ValidString(value) {
			// Replace invalid byte sequences with the Unicode replacement rune.
			value = strings.ToValidUTF8(value, "�")
		}
		builder.Append(value)
	}
	return builder.NewArray()
}
// intToBigFloat returns val / 10^scale as a big.Float, i.e. the decimal value
// of a scaled fixed-point integer.
func intToBigFloat(val int64, scale int64) *big.Float {
	numerator := new(big.Float).SetInt64(val)
	tenToScale := new(big.Int).Exp(big.NewInt(10), big.NewInt(scale), nil)
	denominator := new(big.Float).SetInt(tenToScale)
	return new(big.Float).Quo(numerator, denominator)
}
// ArrowSnowflakeTimestampToTime converts original timestamp returned by Snowflake to time.Time.
// Returns nil when the value at recIdx is NULL. For TIMESTAMP_LTZ the result is
// expressed in loc; for TIMESTAMP_TZ the offset stored in the column itself is used.
func ArrowSnowflakeTimestampToTime(
	column arrow.Array,
	sfType types.SnowflakeType,
	scale int,
	recIdx int,
	loc *time.Location) *time.Time {
	if column.IsNull(recIdx) {
		return nil
	}

	// structParts reads epoch seconds and nanosecond fraction from the first
	// two fields of a struct-encoded timestamp column.
	structParts := func(s *array.Struct) (int64, int64) {
		epochs := s.Field(0).(*array.Int64).Int64Values()
		fractions := s.Field(1).(*array.Int32).Int32Values()
		return epochs[recIdx], int64(fractions[recIdx])
	}
	// intParts splits a single scaled int64 timestamp into epoch seconds and
	// nanoseconds.
	intParts := func(c *array.Int64) (int64, int64) {
		v := c.Value(recIdx)
		return extractEpoch(v, scale), extractFraction(v, scale)
	}

	var ret time.Time
	switch sfType {
	case types.TimestampNtzType:
		if column.DataType().ID() == arrow.STRUCT {
			epoch, fraction := structParts(column.(*array.Struct))
			ret = time.Unix(epoch, fraction).UTC()
		} else {
			epoch, fraction := intParts(column.(*array.Int64))
			ret = time.Unix(epoch, fraction).UTC()
		}
	case types.TimestampLtzType:
		if column.DataType().ID() == arrow.STRUCT {
			epoch, fraction := structParts(column.(*array.Struct))
			ret = time.Unix(epoch, fraction).In(loc)
		} else {
			epoch, fraction := intParts(column.(*array.Int64))
			ret = time.Unix(epoch, fraction).In(loc)
		}
	case types.TimestampTzType:
		structData := column.(*array.Struct)
		if structData.NumField() == 2 {
			// Two-field layout: scaled epoch value + timezone offset
			// (minutes, biased by 1440).
			values := structData.Field(0).(*array.Int64).Int64Values()
			timezones := structData.Field(1).(*array.Int32).Int32Values()
			locTz := sf.Location(int(timezones[recIdx]) - 1440)
			ret = time.Unix(extractEpoch(values[recIdx], scale), extractFraction(values[recIdx], scale)).In(locTz)
		} else {
			// Three-field layout: epoch seconds, nanosecond fraction,
			// timezone offset (minutes, biased by 1440).
			epoch, fraction := structParts(structData)
			timezones := structData.Field(2).(*array.Int32).Int32Values()
			locTz := sf.Location(int(timezones[recIdx]) - 1440)
			ret = time.Unix(epoch, fraction).In(locTz)
		}
	}
	return &ret
}
// extractEpoch returns the whole-second part of a timestamp encoded as an
// int64 scaled by 10^scale.
func extractEpoch(value int64, scale int) int64 {
	divisor := int64(math.Pow10(scale))
	return value / divisor
}
// extractFraction returns the sub-second part of a timestamp encoded as an
// int64 scaled by 10^scale, converted to nanoseconds.
func extractFraction(value int64, scale int) int64 {
	modulus := int64(math.Pow10(scale))
	remainder := value % modulus
	return remainder * int64(math.Pow10(9-scale))
}
================================================
FILE: arrowbatches/converter_test.go
================================================
package arrowbatches
import (
"context"
"fmt"
"github.com/snowflakedb/gosnowflake/v2/internal/query"
"github.com/snowflakedb/gosnowflake/v2/internal/types"
"math/big"
"strings"
"testing"
"time"
ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/decimal128"
"github.com/apache/arrow-go/v18/arrow/memory"
)
// decimalShift is 2^64, the factor separating the high and low 64-bit words
// of a 128-bit decimal value; used to split/combine decimal128 numbers.
var decimalShift = new(big.Int).Exp(big.NewInt(2), big.NewInt(64), nil)
// stringIntToDecimal parses a base-10 integer string into a decimal128 value.
// The second return value is false when src is not a valid integer.
func stringIntToDecimal(src string) (decimal128.Num, bool) {
	parsed, ok := new(big.Int).SetString(src, 10)
	if !ok {
		return decimal128.Num{}, false
	}
	quotient, remainder := new(big.Int), new(big.Int)
	quotient.QuoRem(parsed, decimalShift, remainder)
	return decimal128.New(quotient.Int64(), remainder.Uint64()), true
}
// decimalToBigInt reassembles a decimal128 value into a single big.Int:
// high word * 2^64 + low word.
func decimalToBigInt(num decimal128.Num) *big.Int {
	hi := new(big.Int).SetInt64(num.HighBits())
	lo := new(big.Int).SetUint64(num.LowBits())
	return hi.Mul(hi, decimalShift).Add(hi, lo)
}
func TestArrowToRecord(t *testing.T) {
pool := memory.NewCheckedAllocator(memory.NewGoAllocator())
defer pool.AssertSize(t, 0)
var valids []bool
localTime := time.Date(2019, 1, 1, 1, 17, 31, 123456789, time.FixedZone("-08:00", -8*3600))
localTimeFarIntoFuture := time.Date(9000, 2, 6, 14, 17, 31, 123456789, time.FixedZone("-08:00", -8*3600))
epochField := arrow.Field{Name: "epoch", Type: &arrow.Int64Type{}}
timezoneField := arrow.Field{Name: "timezone", Type: &arrow.Int32Type{}}
fractionField := arrow.Field{Name: "fraction", Type: &arrow.Int32Type{}}
timestampTzStructWithoutFraction := arrow.StructOf(epochField, timezoneField)
timestampTzStructWithFraction := arrow.StructOf(epochField, fractionField, timezoneField)
timestampNtzStruct := arrow.StructOf(epochField, fractionField)
timestampLtzStruct := arrow.StructOf(epochField, fractionField)
type testObj struct {
field1 int
field2 string
}
for _, tc := range []struct {
logical string
physical string
sc *arrow.Schema
rowType query.ExecResponseRowType
values any
expected any
error string
arrowBatchesTimestampOption ia.TimestampOption
enableArrowBatchesUtf8Validation bool
withHigherPrecision bool
nrows int
builder array.Builder
append func(b array.Builder, vs any)
compare func(src any, expected any, rec arrow.Record) int
}{
{
logical: "fixed",
physical: "number",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil),
values: []int64{1, 2},
nrows: 2,
builder: array.NewInt64Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) },
},
{
logical: "fixed",
physical: "int64",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Decimal128Type{Precision: 38, Scale: 0}}}, nil),
values: []string{"10000000000000000000000000000000000000", "-12345678901234567890123456789012345678"},
nrows: 2,
builder: array.NewDecimal128Builder(pool, &arrow.Decimal128Type{Precision: 38, Scale: 0}),
append: func(b array.Builder, vs any) {
for _, s := range vs.([]string) {
num, ok := stringIntToDecimal(s)
if !ok {
t.Fatalf("failed to convert to Int64")
}
b.(*array.Decimal128Builder).Append(num)
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]string)
for i, dec := range convertedRec.Column(0).(*array.Int64).Int64Values() {
num, ok := stringIntToDecimal(srcvs[i])
if !ok {
return i
}
srcDec := decimalToBigInt(num).Int64()
if srcDec != dec {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "number(38,0)",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Decimal128Type{Precision: 38, Scale: 0}}}, nil),
values: []string{"10000000000000000000000000000000000000", "-12345678901234567890123456789012345678"},
withHigherPrecision: true,
nrows: 2,
builder: array.NewDecimal128Builder(pool, &arrow.Decimal128Type{Precision: 38, Scale: 0}),
append: func(b array.Builder, vs any) {
for _, s := range vs.([]string) {
num, ok := stringIntToDecimal(s)
if !ok {
t.Fatalf("failed to convert to Int64")
}
b.(*array.Decimal128Builder).Append(num)
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]string)
for i, dec := range convertedRec.Column(0).(*array.Decimal128).Values() {
srcDec, ok := stringIntToDecimal(srcvs[i])
if !ok {
return i
}
if srcDec != dec {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "float64",
rowType: query.ExecResponseRowType{Scale: 37},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Decimal128Type{Precision: 38, Scale: 37}}}, nil),
values: []string{"1.2345678901234567890123456789012345678", "-9.999999999999999"},
nrows: 2,
builder: array.NewDecimal128Builder(pool, &arrow.Decimal128Type{Precision: 38, Scale: 37}),
append: func(b array.Builder, vs any) {
for _, s := range vs.([]string) {
num, err := decimal128.FromString(s, 38, 37)
if err != nil {
t.Fatalf("failed to convert to decimal: %s", err)
}
b.(*array.Decimal128Builder).Append(num)
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]string)
for i, dec := range convertedRec.Column(0).(*array.Float64).Float64Values() {
num, err := decimal128.FromString(srcvs[i], 38, 37)
if err != nil {
return i
}
srcDec := num.ToFloat64(37)
if srcDec != dec {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "number(38,37)",
rowType: query.ExecResponseRowType{Scale: 37},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Decimal128Type{Precision: 38, Scale: 37}}}, nil),
values: []string{"1.2345678901234567890123456789012345678", "-9.999999999999999"},
withHigherPrecision: true,
nrows: 2,
builder: array.NewDecimal128Builder(pool, &arrow.Decimal128Type{Precision: 38, Scale: 37}),
append: func(b array.Builder, vs any) {
for _, s := range vs.([]string) {
num, err := decimal128.FromString(s, 38, 37)
if err != nil {
t.Fatalf("failed to convert to decimal: %s", err)
}
b.(*array.Decimal128Builder).Append(num)
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]string)
for i, dec := range convertedRec.Column(0).(*array.Decimal128).Values() {
srcDec, err := decimal128.FromString(srcvs[i], 38, 37)
if err != nil {
return i
}
if srcDec != dec {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "int8",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int8Type{}}}, nil),
values: []int8{1, 2},
nrows: 2,
builder: array.NewInt8Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Int8Builder).AppendValues(vs.([]int8), valids) },
},
{
logical: "fixed",
physical: "int16",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int16Type{}}}, nil),
values: []int16{1, 2},
nrows: 2,
builder: array.NewInt16Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Int16Builder).AppendValues(vs.([]int16), valids) },
},
{
logical: "fixed",
physical: "int32",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int32Type{}}}, nil),
values: []int32{1, 2},
nrows: 2,
builder: array.NewInt32Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Int32Builder).AppendValues(vs.([]int32), valids) },
},
{
logical: "fixed",
physical: "int64",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil),
values: []int64{1, 2},
nrows: 2,
builder: array.NewInt64Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) },
},
{
logical: "fixed",
physical: "float8",
rowType: query.ExecResponseRowType{Scale: 1},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int8Type{}}}, nil),
values: []int8{10, 16},
nrows: 2,
builder: array.NewInt8Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Int8Builder).AppendValues(vs.([]int8), valids) },
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]int8)
for i, f := range convertedRec.Column(0).(*array.Float64).Float64Values() {
rawFloat, _ := intToBigFloat(int64(srcvs[i]), 1).Float64()
if rawFloat != f {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "int8",
rowType: query.ExecResponseRowType{Scale: 1},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int8Type{}}}, nil),
values: []int8{10, 16},
withHigherPrecision: true,
nrows: 2,
builder: array.NewInt8Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Int8Builder).AppendValues(vs.([]int8), valids) },
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]int8)
for i, f := range convertedRec.Column(0).(*array.Int8).Int8Values() {
if srcvs[i] != f {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "float16",
rowType: query.ExecResponseRowType{Scale: 1},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int16Type{}}}, nil),
values: []int16{20, 26},
nrows: 2,
builder: array.NewInt16Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Int16Builder).AppendValues(vs.([]int16), valids) },
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]int16)
for i, f := range convertedRec.Column(0).(*array.Float64).Float64Values() {
rawFloat, _ := intToBigFloat(int64(srcvs[i]), 1).Float64()
if rawFloat != f {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "int16",
rowType: query.ExecResponseRowType{Scale: 1},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int16Type{}}}, nil),
values: []int16{20, 26},
withHigherPrecision: true,
nrows: 2,
builder: array.NewInt16Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Int16Builder).AppendValues(vs.([]int16), valids) },
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]int16)
for i, f := range convertedRec.Column(0).(*array.Int16).Int16Values() {
if srcvs[i] != f {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "float32",
rowType: query.ExecResponseRowType{Scale: 2},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int32Type{}}}, nil),
values: []int32{200, 265},
nrows: 2,
builder: array.NewInt32Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Int32Builder).AppendValues(vs.([]int32), valids) },
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]int32)
for i, f := range convertedRec.Column(0).(*array.Float64).Float64Values() {
rawFloat, _ := intToBigFloat(int64(srcvs[i]), 2).Float64()
if rawFloat != f {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "int32",
rowType: query.ExecResponseRowType{Scale: 2},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int32Type{}}}, nil),
values: []int32{200, 265},
withHigherPrecision: true,
nrows: 2,
builder: array.NewInt32Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Int32Builder).AppendValues(vs.([]int32), valids) },
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]int32)
for i, f := range convertedRec.Column(0).(*array.Int32).Int32Values() {
if srcvs[i] != f {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "float64",
rowType: query.ExecResponseRowType{Scale: 5},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil),
values: []int64{12345, 234567},
nrows: 2,
builder: array.NewInt64Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) },
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]int64)
for i, f := range convertedRec.Column(0).(*array.Float64).Float64Values() {
rawFloat, _ := intToBigFloat(srcvs[i], 5).Float64()
if rawFloat != f {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "int64",
rowType: query.ExecResponseRowType{Scale: 5},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil),
values: []int64{12345, 234567},
withHigherPrecision: true,
nrows: 2,
builder: array.NewInt64Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) },
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]int64)
for i, f := range convertedRec.Column(0).(*array.Int64).Int64Values() {
if srcvs[i] != f {
return i
}
}
return -1
},
},
{
logical: "boolean",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.BooleanType{}}}, nil),
values: []bool{true, false},
nrows: 2,
builder: array.NewBooleanBuilder(pool),
append: func(b array.Builder, vs any) { b.(*array.BooleanBuilder).AppendValues(vs.([]bool), valids) },
},
{
logical: "real",
physical: "float",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Float64Type{}}}, nil),
values: []float64{1, 2},
nrows: 2,
builder: array.NewFloat64Builder(pool),
append: func(b array.Builder, vs any) { b.(*array.Float64Builder).AppendValues(vs.([]float64), valids) },
},
{
logical: "text",
physical: "string",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}}}, nil),
values: []string{"foo", "bar"},
nrows: 2,
builder: array.NewStringBuilder(pool),
append: func(b array.Builder, vs any) { b.(*array.StringBuilder).AppendValues(vs.([]string), valids) },
},
{
logical: "text",
physical: "string with invalid utf8",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}}}, nil),
rowType: query.ExecResponseRowType{Type: "TEXT"},
values: []string{"\xFF", "bar", "baz\xFF\xFF"},
expected: []string{"�", "bar", "baz��"},
enableArrowBatchesUtf8Validation: true,
nrows: 2,
builder: array.NewStringBuilder(pool),
append: func(b array.Builder, vs any) { b.(*array.StringBuilder).AppendValues(vs.([]string), valids) },
compare: func(src any, expected any, convertedRec arrow.Record) int {
arr := convertedRec.Column(0).(*array.String)
for i := 0; i < arr.Len(); i++ {
if expected.([]string)[i] != arr.Value(i) {
return i
}
}
return -1
},
},
{
logical: "binary",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.BinaryType{}}}, nil),
values: [][]byte{[]byte("foo"), []byte("bar")},
nrows: 2,
builder: array.NewBinaryBuilder(pool, arrow.BinaryTypes.Binary),
append: func(b array.Builder, vs any) { b.(*array.BinaryBuilder).AppendValues(vs.([][]byte), valids) },
},
{
logical: "date",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Date32Type{}}}, nil),
values: []time.Time{time.Now(), localTime},
nrows: 2,
builder: array.NewDate32Builder(pool),
append: func(b array.Builder, vs any) {
for _, d := range vs.([]time.Time) {
b.(*array.Date32Builder).Append(arrow.Date32(d.Unix()))
}
},
},
{
logical: "time",
sc: arrow.NewSchema([]arrow.Field{{Type: arrow.FixedWidthTypes.Time64ns}}, nil),
values: []time.Time{time.Now(), time.Now()},
nrows: 2,
builder: array.NewTime64Builder(pool, arrow.FixedWidthTypes.Time64ns.(*arrow.Time64Type)),
append: func(b array.Builder, vs any) {
for _, t := range vs.([]time.Time) {
b.(*array.Time64Builder).Append(arrow.Time64(t.UnixNano()))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
arr := convertedRec.Column(0).(*array.Time64)
for i := 0; i < arr.Len(); i++ {
if srcvs[i].UnixNano() != int64(arr.Value(i)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ntz",
physical: "int64",
values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)},
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 3},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil),
builder: array.NewInt64Builder(pool),
append: func(b array.Builder, vs any) {
for _, t := range vs.([]time.Time) {
b.(*array.Int64Builder).Append(t.UnixMilli())
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ntz",
physical: "struct",
values: []time.Time{time.Now(), localTime},
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil),
builder: array.NewStructBuilder(pool, timestampNtzStruct),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ntz",
physical: "struct",
values: []time.Time{time.Now().Truncate(time.Microsecond), localTime.Truncate(time.Microsecond)},
arrowBatchesTimestampOption: ia.UseMicrosecondTimestamp,
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil),
builder: array.NewStructBuilder(pool, timestampNtzStruct),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Microsecond)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ntz",
physical: "struct",
values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)},
arrowBatchesTimestampOption: ia.UseMillisecondTimestamp,
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil),
builder: array.NewStructBuilder(pool, timestampNtzStruct),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Millisecond)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ntz",
physical: "struct",
values: []time.Time{time.Now().Truncate(time.Second), localTime.Truncate(time.Second)},
arrowBatchesTimestampOption: ia.UseSecondTimestamp,
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil),
builder: array.NewStructBuilder(pool, timestampNtzStruct),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Second)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ntz",
physical: "error",
values: []time.Time{localTimeFarIntoFuture},
error: "Cannot convert timestamp",
nrows: 1,
rowType: query.ExecResponseRowType{Scale: 3},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil),
builder: array.NewInt64Builder(pool),
append: func(b array.Builder, vs any) {
for _, t := range vs.([]time.Time) {
b.(*array.Int64Builder).Append(t.UnixMilli())
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int { return 0 },
},
{
logical: "timestamp_ntz",
physical: "int64 with original timestamp",
values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond), localTimeFarIntoFuture.Truncate(time.Millisecond)},
arrowBatchesTimestampOption: ia.UseOriginalTimestamp,
nrows: 3,
rowType: query.ExecResponseRowType{Scale: 3},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil),
builder: array.NewInt64Builder(pool),
append: func(b array.Builder, vs any) {
for _, t := range vs.([]time.Time) {
b.(*array.Int64Builder).Append(t.UnixMilli())
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i := 0; i < convertedRec.Column(0).Len(); i++ {
ts := ArrowSnowflakeTimestampToTime(convertedRec.Column(0), types.GetSnowflakeType("timestamp_ntz"), 3, i, nil)
if !srcvs[i].Equal(*ts) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ntz",
physical: "struct with original timestamp",
values: []time.Time{time.Now(), localTime, localTimeFarIntoFuture},
arrowBatchesTimestampOption: ia.UseOriginalTimestamp,
nrows: 3,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil),
builder: array.NewStructBuilder(pool, timestampNtzStruct),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i := 0; i < convertedRec.Column(0).Len(); i++ {
ts := ArrowSnowflakeTimestampToTime(convertedRec.Column(0), types.GetSnowflakeType("timestamp_ntz"), 9, i, nil)
if !srcvs[i].Equal(*ts) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ltz",
physical: "int64",
values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)},
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 3},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil),
builder: array.NewInt64Builder(pool),
append: func(b array.Builder, vs any) {
for _, t := range vs.([]time.Time) {
b.(*array.Int64Builder).Append(t.UnixMilli())
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ltz",
physical: "struct",
values: []time.Time{time.Now(), localTime},
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil),
builder: array.NewStructBuilder(pool, timestampNtzStruct),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ltz",
physical: "struct",
values: []time.Time{time.Now().Truncate(time.Microsecond), localTime.Truncate(time.Microsecond)},
arrowBatchesTimestampOption: ia.UseMicrosecondTimestamp,
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil),
builder: array.NewStructBuilder(pool, timestampNtzStruct),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Microsecond)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ltz",
physical: "struct",
values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)},
arrowBatchesTimestampOption: ia.UseMillisecondTimestamp,
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil),
builder: array.NewStructBuilder(pool, timestampNtzStruct),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Millisecond)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ltz",
physical: "struct",
values: []time.Time{time.Now().Truncate(time.Second), localTime.Truncate(time.Second)},
arrowBatchesTimestampOption: ia.UseSecondTimestamp,
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampNtzStruct}}, nil),
builder: array.NewStructBuilder(pool, timestampNtzStruct),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Second)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ltz",
physical: "error",
values: []time.Time{localTimeFarIntoFuture},
error: "Cannot convert timestamp",
nrows: 1,
rowType: query.ExecResponseRowType{Scale: 3},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil),
builder: array.NewInt64Builder(pool),
append: func(b array.Builder, vs any) {
for _, t := range vs.([]time.Time) {
b.(*array.Int64Builder).Append(t.UnixMilli())
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int { return 0 },
},
{
logical: "timestamp_ltz",
physical: "int64 with original timestamp",
values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond), localTimeFarIntoFuture.Truncate(time.Millisecond)},
arrowBatchesTimestampOption: ia.UseOriginalTimestamp,
nrows: 3,
rowType: query.ExecResponseRowType{Scale: 3},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil),
builder: array.NewInt64Builder(pool),
append: func(b array.Builder, vs any) {
for _, t := range vs.([]time.Time) {
b.(*array.Int64Builder).Append(t.UnixMilli())
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i := 0; i < convertedRec.Column(0).Len(); i++ {
ts := ArrowSnowflakeTimestampToTime(convertedRec.Column(0), types.GetSnowflakeType("timestamp_ltz"), 3, i, localTime.Location())
if !srcvs[i].Equal(*ts) {
return i
}
}
return -1
},
},
{
logical: "timestamp_ltz",
physical: "struct with original timestamp",
values: []time.Time{time.Now(), localTime, localTimeFarIntoFuture},
arrowBatchesTimestampOption: ia.UseOriginalTimestamp,
nrows: 3,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampLtzStruct}}, nil),
builder: array.NewStructBuilder(pool, timestampLtzStruct),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i := 0; i < convertedRec.Column(0).Len(); i++ {
ts := ArrowSnowflakeTimestampToTime(convertedRec.Column(0), types.GetSnowflakeType("timestamp_ltz"), 9, i, localTime.Location())
if !srcvs[i].Equal(*ts) {
return i
}
}
return -1
},
},
{
logical: "timestamp_tz",
physical: "struct2",
values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)},
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 3},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithoutFraction}}, nil),
builder: array.NewStructBuilder(pool, timestampTzStructWithoutFraction),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.UnixMilli())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(0))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_tz",
physical: "struct3",
values: []time.Time{time.Now(), localTime},
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithFraction}}, nil),
builder: array.NewStructBuilder(pool, timestampTzStructWithFraction),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
sb.FieldBuilder(2).(*array.Int32Builder).Append(int32(0))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Nanosecond)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_tz",
physical: "struct3",
values: []time.Time{time.Now().Truncate(time.Microsecond), localTime.Truncate(time.Microsecond)},
arrowBatchesTimestampOption: ia.UseMicrosecondTimestamp,
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithFraction}}, nil),
builder: array.NewStructBuilder(pool, timestampTzStructWithFraction),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
sb.FieldBuilder(2).(*array.Int32Builder).Append(int32(0))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Microsecond)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_tz",
physical: "struct3",
values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond)},
arrowBatchesTimestampOption: ia.UseMillisecondTimestamp,
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithFraction}}, nil),
builder: array.NewStructBuilder(pool, timestampTzStructWithFraction),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
sb.FieldBuilder(2).(*array.Int32Builder).Append(int32(0))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Millisecond)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_tz",
physical: "struct3",
values: []time.Time{time.Now().Truncate(time.Second), localTime.Truncate(time.Second)},
arrowBatchesTimestampOption: ia.UseSecondTimestamp,
nrows: 2,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithFraction}}, nil),
builder: array.NewStructBuilder(pool, timestampTzStructWithFraction),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
sb.FieldBuilder(2).(*array.Int32Builder).Append(int32(0))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i, t := range convertedRec.Column(0).(*array.Timestamp).TimestampValues() {
if !srcvs[i].Equal(t.ToTime(arrow.Second)) {
return i
}
}
return -1
},
},
{
logical: "timestamp_tz",
physical: "struct2 with original timestamp",
values: []time.Time{time.Now().Truncate(time.Millisecond), localTime.Truncate(time.Millisecond), localTimeFarIntoFuture.Truncate(time.Millisecond)},
arrowBatchesTimestampOption: ia.UseOriginalTimestamp,
nrows: 3,
rowType: query.ExecResponseRowType{Scale: 3},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithoutFraction}}, nil),
builder: array.NewStructBuilder(pool, timestampTzStructWithoutFraction),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.UnixMilli())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(0))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i := 0; i < convertedRec.Column(0).Len(); i++ {
ts := ArrowSnowflakeTimestampToTime(convertedRec.Column(0), types.GetSnowflakeType("timestamp_tz"), 3, i, nil)
if !srcvs[i].Equal(*ts) {
return i
}
}
return -1
},
},
{
logical: "timestamp_tz",
physical: "struct3 with original timestamp",
values: []time.Time{time.Now(), localTime, localTimeFarIntoFuture},
arrowBatchesTimestampOption: ia.UseOriginalTimestamp,
nrows: 3,
rowType: query.ExecResponseRowType{Scale: 9},
sc: arrow.NewSchema([]arrow.Field{{Type: timestampTzStructWithFraction}}, nil),
builder: array.NewStructBuilder(pool, timestampTzStructWithFraction),
append: func(b array.Builder, vs any) {
sb := b.(*array.StructBuilder)
valids = []bool{true, true, true}
sb.AppendValues(valids)
for _, t := range vs.([]time.Time) {
sb.FieldBuilder(0).(*array.Int64Builder).Append(t.Unix())
sb.FieldBuilder(1).(*array.Int32Builder).Append(int32(t.Nanosecond()))
sb.FieldBuilder(2).(*array.Int32Builder).Append(int32(0))
}
},
compare: func(src any, expected any, convertedRec arrow.Record) int {
srcvs := src.([]time.Time)
for i := 0; i < convertedRec.Column(0).Len(); i++ {
ts := ArrowSnowflakeTimestampToTime(convertedRec.Column(0), types.GetSnowflakeType("timestamp_tz"), 9, i, nil)
if !srcvs[i].Equal(*ts) {
return i
}
}
return -1
},
},
{
logical: "array",
values: [][]string{{"foo", "bar"}, {"baz", "quz", "quux"}},
nrows: 2,
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}}}, nil),
builder: array.NewStringBuilder(pool),
append: func(b array.Builder, vs any) {
for _, a := range vs.([][]string) {
b.(*array.StringBuilder).Append(fmt.Sprint(a))
}
},
},
{
logical: "object",
values: []testObj{{0, "foo"}, {1, "bar"}},
nrows: 2,
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}}}, nil),
builder: array.NewStringBuilder(pool),
append: func(b array.Builder, vs any) {
for _, o := range vs.([]testObj) {
b.(*array.StringBuilder).Append(fmt.Sprint(o))
}
},
},
} {
testName := tc.logical
if tc.physical != "" {
testName += " " + tc.physical
}
t.Run(testName, func(t *testing.T) {
scope := memory.NewCheckedAllocatorScope(pool)
defer scope.CheckSize(t)
b := tc.builder
defer b.Release()
tc.append(b, tc.values)
arr := b.NewArray()
defer arr.Release()
rawRec := array.NewRecord(tc.sc, []arrow.Array{arr}, int64(tc.nrows))
defer rawRec.Release()
meta := tc.rowType
meta.Type = tc.logical
ctx := context.Background()
switch tc.arrowBatchesTimestampOption {
case ia.UseOriginalTimestamp:
ctx = ia.WithTimestampOption(ctx, ia.UseOriginalTimestamp)
case ia.UseSecondTimestamp:
ctx = ia.WithTimestampOption(ctx, ia.UseSecondTimestamp)
case ia.UseMillisecondTimestamp:
ctx = ia.WithTimestampOption(ctx, ia.UseMillisecondTimestamp)
case ia.UseMicrosecondTimestamp:
ctx = ia.WithTimestampOption(ctx, ia.UseMicrosecondTimestamp)
default:
ctx = ia.WithTimestampOption(ctx, ia.UseNanosecondTimestamp)
}
if tc.enableArrowBatchesUtf8Validation {
ctx = ia.EnableUtf8Validation(ctx)
}
if tc.withHigherPrecision {
ctx = ia.WithHigherPrecision(ctx)
}
transformedRec, err := arrowToRecord(ctx, rawRec, pool, []query.ExecResponseRowType{meta}, localTime.Location())
if err != nil {
if tc.error == "" || !strings.Contains(err.Error(), tc.error) {
t.Fatalf("error: %s", err)
}
} else {
defer transformedRec.Release()
if tc.error != "" {
t.Fatalf("expected error: %s", tc.error)
}
if tc.compare != nil {
idx := tc.compare(tc.values, tc.expected, transformedRec)
if idx != -1 {
t.Fatalf("error: column array value mismatch at index %v", idx)
}
} else {
for i, c := range transformedRec.Columns() {
rawCol := rawRec.Column(i)
if rawCol != c {
t.Fatalf("error: expected column %s, got column %s", rawCol, c)
}
}
}
}
})
}
}
================================================
FILE: arrowbatches/schema.go
================================================
package arrowbatches
import (
"github.com/snowflakedb/gosnowflake/v2/internal/query"
"github.com/snowflakedb/gosnowflake/v2/internal/types"
"time"
ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow"
"github.com/apache/arrow-go/v18/arrow"
)
// recordToSchema derives the post-conversion schema for a record described by
// sc: every field whose Snowflake type requires a data-type conversion is
// replaced with an equivalent field carrying the converted Arrow type; all
// other fields are reused unchanged. Schema-level metadata is preserved.
// The returned error is always nil today; the signature keeps it for
// interface stability.
func recordToSchema(sc *arrow.Schema, rowType []query.ExecResponseRowType, loc *time.Location, timestampOption ia.TimestampOption, withHigherPrecision bool) (*arrow.Schema, error) {
	metadata := sc.Metadata()
	convertedFields := recordToSchemaRecursive(sc.Fields(), rowType, loc, timestampOption, withHigherPrecision)
	return arrow.NewSchema(convertedFields, &metadata), nil
}
// recordToSchemaRecursive maps each input field to its converted counterpart.
// Fields that need no conversion are appended as-is; converted fields keep
// their name, nullability, and metadata but carry the new data type.
// Assumes len(rowType) >= len(inFields); rowType[i] describes inFields[i].
func recordToSchemaRecursive(inFields []arrow.Field, rowType []query.ExecResponseRowType, loc *time.Location, timestampOption ia.TimestampOption, withHigherPrecision bool) []arrow.Field {
	var outFields []arrow.Field
	for i := range inFields {
		field := inFields[i]
		needsNewType, newType := recordToSchemaSingleField(rowType[i].ToFieldMetadata(), field, withHigherPrecision, timestampOption, loc)
		if !needsNewType {
			outFields = append(outFields, field)
			continue
		}
		outFields = append(outFields, arrow.Field{
			Name:     field.Name,
			Type:     newType,
			Nullable: field.Nullable,
			Metadata: field.Metadata,
		})
	}
	return outFields
}
// recordToSchemaSingleField determines whether a single Arrow field's data
// type must be converted to match what record conversion produces for the
// given Snowflake column metadata. It returns (true, newType) when a
// conversion is required and (false, originalType) when the field can be
// used unchanged.
//
// Conversion rules:
//   - FixedType: decimals become int64 (scale 0) or float64 (scale != 0)
//     unless higher precision was requested, in which case the raw type is
//     kept; non-decimal fixed values become float64 only when scale != 0.
//   - TimeType: always widened to nanosecond time64.
//   - TIMESTAMP_NTZ / TIMESTAMP_TZ: converted to a timestamp whose unit is
//     selected via timestampOption; UseOriginalTimestamp keeps the raw
//     Snowflake struct encoding.
//   - TIMESTAMP_LTZ: as above, additionally carrying loc as the time zone.
//   - Object / Array / Map: converted element-wise via recursion.
func recordToSchemaSingleField(fieldMetadata query.FieldMetadata, f arrow.Field, withHigherPrecision bool, timestampOption ia.TimestampOption, loc *time.Location) (bool, arrow.DataType) {
	t := f.Type
	converted := true
	switch types.GetSnowflakeType(fieldMetadata.Type) {
	case types.FixedType:
		switch f.Type.ID() {
		case arrow.DECIMAL:
			if withHigherPrecision {
				converted = false
			} else if fieldMetadata.Scale == 0 {
				t = &arrow.Int64Type{}
			} else {
				t = &arrow.Float64Type{}
			}
		default:
			if withHigherPrecision {
				converted = false
			} else if fieldMetadata.Scale != 0 {
				t = &arrow.Float64Type{}
			} else {
				converted = false
			}
		}
	case types.TimeType:
		t = &arrow.Time64Type{Unit: arrow.Nanosecond}
	case types.TimestampNtzType, types.TimestampTzType:
		switch timestampOption {
		case ia.UseOriginalTimestamp:
			// Keep the raw struct encoding; consumers read it via
			// ArrowSnowflakeTimestampToTime.
			converted = false
		case ia.UseMicrosecondTimestamp:
			t = &arrow.TimestampType{Unit: arrow.Microsecond}
		case ia.UseMillisecondTimestamp:
			t = &arrow.TimestampType{Unit: arrow.Millisecond}
		case ia.UseSecondTimestamp:
			t = &arrow.TimestampType{Unit: arrow.Second}
		default:
			t = &arrow.TimestampType{Unit: arrow.Nanosecond}
		}
	case types.TimestampLtzType:
		switch timestampOption {
		case ia.UseOriginalTimestamp:
			converted = false
		case ia.UseMicrosecondTimestamp:
			t = &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: loc.String()}
		case ia.UseMillisecondTimestamp:
			t = &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: loc.String()}
		case ia.UseSecondTimestamp:
			t = &arrow.TimestampType{Unit: arrow.Second, TimeZone: loc.String()}
		default:
			t = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: loc.String()}
		}
	case types.ObjectType:
		converted = false
		if f.Type.ID() == arrow.STRUCT {
			var internalFields []arrow.Field
			for idx, internalField := range f.Type.(*arrow.StructType).Fields() {
				internalConverted, convertedDataType := recordToSchemaSingleField(fieldMetadata.Fields[idx], internalField, withHigherPrecision, timestampOption, loc)
				// The struct is reported converted if any member changed.
				converted = converted || internalConverted
				if internalConverted {
					internalFields = append(internalFields, arrow.Field{
						Name:     internalField.Name,
						Type:     convertedDataType,
						Metadata: internalField.Metadata,
						Nullable: internalField.Nullable,
					})
				} else {
					internalFields = append(internalFields, internalField)
				}
			}
			t = arrow.StructOf(internalFields...)
		}
	case types.ArrayType:
		// BUGFIX: the previous code declared a new `converted` with `:=`
		// inside this case, shadowing the outer variable; arrays were
		// therefore always reported as converted (outer value stayed true)
		// even when the element type was unchanged.
		if listType, ok := f.Type.(*arrow.ListType); ok {
			elemConverted, elemType := recordToSchemaSingleField(fieldMetadata.Fields[0], listType.ElemField(), withHigherPrecision, timestampOption, loc)
			converted = elemConverted
			if elemConverted {
				t = arrow.ListOf(elemType)
			}
		} else {
			converted = false
		}
	case types.MapType:
		// Guarded assertion for consistency with the ArrayType branch and to
		// avoid a panic should the physical type ever not be an Arrow map.
		if mapType, ok := f.Type.(*arrow.MapType); ok {
			convertedKey, keyDataType := recordToSchemaSingleField(fieldMetadata.Fields[0], mapType.KeyField(), withHigherPrecision, timestampOption, loc)
			convertedValue, valueDataType := recordToSchemaSingleField(fieldMetadata.Fields[1], mapType.ItemField(), withHigherPrecision, timestampOption, loc)
			converted = convertedKey || convertedValue
			if converted {
				t = arrow.MapOf(keyDataType, valueDataType)
			}
		} else {
			converted = false
		}
	default:
		converted = false
	}
	return converted, t
}
================================================
FILE: assert_test.go
================================================
package gosnowflake
import (
"bytes"
"errors"
"fmt"
"math"
"reflect"
"regexp"
"slices"
"strings"
"testing"
"time"
)
// assertNilE marks the test as failed (t.Error) when actual is not nil.
func assertNilE(t *testing.T, actual any, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateNil(actual, descriptions...))
}

// assertNilF aborts the test (t.Fatal) when actual is not nil.
func assertNilF(t *testing.T, actual any, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateNil(actual, descriptions...))
}

// assertNotNilE marks the test as failed when actual is nil.
func assertNotNilE(t *testing.T, actual any, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateNotNil(actual, descriptions...))
}

// assertNotNilF aborts the test when actual is nil.
func assertNotNilF(t *testing.T, actual any, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateNotNil(actual, descriptions...))
}

// assertErrIsF aborts the test when errors.Is(actual, expected) is false.
func assertErrIsF(t *testing.T, actual, expected error, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateErrIs(actual, expected, descriptions...))
}

// assertErrIsE marks the test as failed when errors.Is(actual, expected) is false.
func assertErrIsE(t *testing.T, actual, expected error, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateErrIs(actual, expected, descriptions...))
}

// assertErrorsAsF aborts the test when errors.As(err, target) is false.
func assertErrorsAsF(t *testing.T, err error, target any, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateErrorsAs(err, target, descriptions...))
}
// assertEqualE marks the test as failed when actual != expected (shallow ==).
func assertEqualE(t *testing.T, actual any, expected any, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEqual(actual, expected, descriptions...))
}

// assertEqualF aborts the test when actual != expected (shallow ==).
func assertEqualF(t *testing.T, actual any, expected any, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateEqual(actual, expected, descriptions...))
}

// assertEqualIgnoringWhitespaceE compares the strings with all whitespace
// removed and marks the test as failed on mismatch.
func assertEqualIgnoringWhitespaceE(t *testing.T, actual string, expected string, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEqualIgnoringWhitespace(actual, expected, descriptions...))
}

// assertEqualEpsilonE marks the test as failed when |actual-expected| >= epsilon.
func assertEqualEpsilonE(t *testing.T, actual, expected, epsilon float64, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEqualEpsilon(actual, expected, epsilon, descriptions...))
}

// assertDeepEqualE marks the test as failed when reflect.DeepEqual is false.
func assertDeepEqualE(t *testing.T, actual any, expected any, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateDeepEqual(actual, expected, descriptions...))
}

// assertNotEqualF aborts the test when actual == expected.
func assertNotEqualF(t *testing.T, actual any, expected any, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateNotEqual(actual, expected, descriptions...))
}

// assertNotEqualE marks the test as failed when actual == expected.
func assertNotEqualE(t *testing.T, actual any, expected any, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateNotEqual(actual, expected, descriptions...))
}

// assertBytesEqualE marks the test as failed when the byte slices differ.
func assertBytesEqualE(t *testing.T, actual []byte, expected []byte, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateBytesEqual(actual, expected, descriptions...))
}

// assertTrueF aborts the test when actual is false.
func assertTrueF(t *testing.T, actual bool, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateEqual(actual, true, descriptions...))
}

// assertTrueE marks the test as failed when actual is false.
func assertTrueE(t *testing.T, actual bool, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEqual(actual, true, descriptions...))
}

// assertFalseF aborts the test when actual is true.
func assertFalseF(t *testing.T, actual bool, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateEqual(actual, false, descriptions...))
}

// assertFalseE marks the test as failed when actual is true.
func assertFalseE(t *testing.T, actual bool, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEqual(actual, false, descriptions...))
}
// assertStringContainsE marks the test as failed when actual does not
// contain expectedToContain.
func assertStringContainsE(t *testing.T, actual string, expectedToContain string, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...))
}

// assertStringContainsF aborts the test when actual does not contain
// expectedToContain.
func assertStringContainsF(t *testing.T, actual string, expectedToContain string, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...))
}

// assertEmptyStringE marks the test as failed when actual is not "".
func assertEmptyStringE(t *testing.T, actual string, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEmptyString(actual, descriptions...))
}

// assertHasPrefixF aborts the test when actual does not start with expectedPrefix.
func assertHasPrefixF(t *testing.T, actual string, expectedPrefix string, descriptions ...string) {
	t.Helper()
	fatalOnNonEmpty(t, validateHasPrefix(actual, expectedPrefix, descriptions...))
}

// assertHasPrefixE marks the test as failed when actual does not start with
// expectedPrefix.
func assertHasPrefixE(t *testing.T, actual string, expectedPrefix string, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateHasPrefix(actual, expectedPrefix, descriptions...))
}

// assertBetweenE marks the test as failed unless min < value < max
// (exclusive bounds).
func assertBetweenE(t *testing.T, value float64, min float64, max float64, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateValueBetween(value, min, max, descriptions...))
}

// assertBetweenInclusiveE marks the test as failed unless min <= value <= max.
func assertBetweenInclusiveE(t *testing.T, value float64, min float64, max float64, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateValueBetweenInclusive(value, min, max, descriptions...))
}

// assertEmptyE marks the test as failed when the slice has any elements.
func assertEmptyE[T any](t *testing.T, actual []T, descriptions ...string) {
	t.Helper()
	errorOnNonEmpty(t, validateEmpty(actual, descriptions...))
}
// fatalOnNonEmpty aborts the test with a timestamped, secret-masked message
// when errMsg reports a failed validation; an empty errMsg means success.
func fatalOnNonEmpty(t *testing.T, errMsg string) {
	if errMsg == "" {
		return
	}
	t.Helper()
	t.Fatal(formatErrorMessage(errMsg))
}
// errorOnNonEmpty marks the test as failed (without stopping it) with a
// timestamped, secret-masked message when errMsg reports a failed validation;
// an empty errMsg means success.
func errorOnNonEmpty(t *testing.T, errMsg string) {
	if errMsg == "" {
		return
	}
	t.Helper()
	t.Error(formatErrorMessage(errMsg))
}
// formatErrorMessage prefixes errMsg with an RFC3339Nano timestamp and masks
// any secrets the message may contain.
func formatErrorMessage(errMsg string) string {
	timestamp := time.Now().Format(time.RFC3339Nano)
	return fmt.Sprintf("[%s] %s", timestamp, maskSecrets(errMsg))
}
// validateNil returns "" when actual is nil (per isNil), otherwise a failure
// message containing the masked value and any caller-supplied descriptions.
func validateNil(actual any, descriptions ...string) string {
	if isNil(actual) {
		return ""
	}
	return fmt.Sprintf("expected \"%s\" to be nil but was not. %s",
		maskSecrets(fmt.Sprintf("%v", actual)),
		joinDescriptions(descriptions...))
}
// validateNotNil returns "" when actual is non-nil (per isNil), otherwise a
// failure message with any caller-supplied descriptions appended.
func validateNotNil(actual any, descriptions ...string) string {
	if !isNil(actual) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	// Fixed message: it previously read "expected to be not nil but was not",
	// which contradicted itself — this branch only runs when the value IS nil.
	return fmt.Sprintf("expected to be not nil but was nil. %s", desc)
}
// validateErrIs returns "" when errors.Is(actual, expected) holds, otherwise
// a failure message showing both errors (masked; "nil" for nil errors).
func validateErrIs(actual, expected error, descriptions ...string) string {
	if errors.Is(actual, expected) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	actualStr := "nil"
	expectedStr := "nil"
	if actual != nil {
		actualStr = maskSecrets(actual.Error())
	}
	if expected != nil {
		expectedStr = maskSecrets(expected.Error())
	}
	return fmt.Sprintf("expected %v to be %v. %s", actualStr, expectedStr, desc)
}

// validateErrorsAs returns "" when errors.As(err, target) holds, otherwise a
// failure message naming the target's reflected type.
func validateErrorsAs(err error, target any, descriptions ...string) string {
	if errors.As(err, target) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	errStr := "nil"
	if err != nil {
		errStr = maskSecrets(err.Error())
	}
	targetType := reflect.TypeOf(target)
	return fmt.Sprintf("expected error %v to be assignable to %v but was not. %s", errStr, targetType, desc)
}

// validateEqual returns "" when expected == actual using Go's interface
// equality (values must be the same dynamic type to compare equal).
// NOTE: comparing values of uncomparable dynamic types (slices, maps, funcs)
// with == panics at runtime; use validateDeepEqual for those.
func validateEqual(actual any, expected any, descriptions ...string) string {
	if expected == actual {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to be equal to \"%s\" but was not. %s",
		maskSecrets(fmt.Sprintf("%v", actual)),
		maskSecrets(fmt.Sprintf("%v", expected)),
		desc)
}
// whitespacePattern matches any run of whitespace characters. Compiled once
// at package scope so removeWhitespaces does not recompile (and no longer
// needs the error/panic path) on every call.
var whitespacePattern = regexp.MustCompile(`\s+`)

// removeWhitespaces returns s with all whitespace characters removed; used
// for whitespace-insensitive string comparison.
func removeWhitespaces(s string) string {
	return whitespacePattern.ReplaceAllString(s, "")
}
// validateEqualIgnoringWhitespace returns "" when the two strings are equal
// after all whitespace is stripped from both, otherwise a failure message.
func validateEqualIgnoringWhitespace(actual string, expected string, descriptions ...string) string {
	if removeWhitespaces(expected) == removeWhitespaces(actual) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to be equal to \"%s\" but was not. %s",
		maskSecrets(actual),
		maskSecrets(expected),
		desc)
}

// validateEqualEpsilon returns "" when |actual-expected| < epsilon.
// NOTE: a NaN on either side never satisfies the comparison, so NaN inputs
// always produce a failure message.
func validateEqualEpsilon(actual, expected, epsilon float64, descriptions ...string) string {
	if math.Abs(actual-expected) < epsilon {
		return ""
	}
	return fmt.Sprintf("expected \"%f\" to be equal to \"%f\" within epsilon \"%f\" but was not. %s", actual, expected, epsilon, joinDescriptions(descriptions...))
}

// validateDeepEqual returns "" when reflect.DeepEqual(actual, expected) holds.
func validateDeepEqual(actual any, expected any, descriptions ...string) string {
	if reflect.DeepEqual(actual, expected) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to be equal to \"%s\" but was not. %s",
		maskSecrets(fmt.Sprintf("%v", actual)),
		maskSecrets(fmt.Sprintf("%v", expected)),
		desc)
}

// validateNotEqual returns "" when expected != actual under Go's interface
// equality, otherwise a failure message.
func validateNotEqual(actual any, expected any, descriptions ...string) string {
	if expected != actual {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" not to be equal to \"%s\" but they were the same. %s",
		maskSecrets(fmt.Sprintf("%v", actual)),
		maskSecrets(fmt.Sprintf("%v", expected)),
		desc)
}

// validateBytesEqual returns "" when the two byte slices are equal
// (bytes.Equal treats nil and empty as equal).
func validateBytesEqual(actual []byte, expected []byte, descriptions ...string) string {
	if bytes.Equal(actual, expected) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to be equal to \"%s\" but was not. %s",
		maskSecrets(string(actual)),
		maskSecrets(string(expected)),
		desc)
}

// validateStringContains returns "" when actual contains expectedToContain.
func validateStringContains(actual string, expectedToContain string, descriptions ...string) string {
	if strings.Contains(actual, expectedToContain) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to contain \"%s\" but did not. %s",
		maskSecrets(actual),
		maskSecrets(expectedToContain),
		desc)
}

// validateEmptyString returns "" when actual is the empty string.
func validateEmptyString(actual string, descriptions ...string) string {
	if actual == "" {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to be empty, but was not. %s", maskSecrets(actual), desc)
}

// validateHasPrefix returns "" when actual starts with expectedPrefix.
func validateHasPrefix(actual string, expectedPrefix string, descriptions ...string) string {
	if strings.HasPrefix(actual, expectedPrefix) {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" to start with \"%s\" but did not. %s",
		maskSecrets(actual),
		maskSecrets(expectedPrefix),
		desc)
}
// validateValueBetween returns "" when min < value < max (exclusive bounds).
func validateValueBetween(value float64, min float64, max float64, descriptions ...string) string {
	if value > min && value < max {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" should be between \"%s\" and \"%s\" but did not. %s",
		fmt.Sprintf("%f", value),
		fmt.Sprintf("%f", min),
		fmt.Sprintf("%f", max),
		desc)
}

// validateValueBetweenInclusive returns "" when min <= value <= max.
func validateValueBetweenInclusive(value float64, min float64, max float64, descriptions ...string) string {
	if value >= min && value <= max {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%s\" should be between \"%s\" and \"%s\" inclusively but did not. %s",
		fmt.Sprintf("%f", value),
		fmt.Sprintf("%f", min),
		fmt.Sprintf("%f", max),
		desc)
}

// validateEmpty returns "" when the slice has no elements (nil or empty).
func validateEmpty[T any](value []T, descriptions ...string) string {
	if len(value) == 0 {
		return ""
	}
	desc := joinDescriptions(descriptions...)
	return fmt.Sprintf("expected \"%v\" to be empty. %s", maskSecrets(fmt.Sprintf("%v", value)), desc)
}

// joinDescriptions concatenates the optional assertion descriptions into a
// single space-separated string.
func joinDescriptions(descriptions ...string) string {
	return strings.Join(descriptions, " ")
}
// isNil reports whether value is nil, treating a typed nil pointer, slice,
// map, interface, or function held inside a non-nil interface as nil too.
// Values of any other kind are never nil.
func isNil(value any) bool {
	if value == nil {
		return true
	}
	nilableKinds := []reflect.Kind{reflect.Pointer, reflect.Slice, reflect.Map, reflect.Interface, reflect.Func}
	v := reflect.ValueOf(value)
	if !slices.Contains(nilableKinds, v.Kind()) {
		return false
	}
	return v.IsNil()
}
================================================
FILE: async.go
================================================
package gosnowflake
import (
"context"
"fmt"
"net/url"
"strconv"
"time"
)
// processAsync prepares a placeholder result/rows object for an asynchronous
// query and spawns a goroutine that fetches the real result in the
// background. The placeholder is attached to respd.Data (AsyncResult for
// exec calls, AsyncRows for queries); its errChannel is how the background
// goroutine signals completion or failure to the caller.
func (sr *snowflakeRestful) processAsync(
	ctx context.Context,
	respd *execResponse,
	headers map[string]string,
	timeout time.Duration,
	cfg *Config) (*execResponse, error) {
	// placeholder object to return to user while retrieving results
	rows := new(snowflakeRows)
	res := new(snowflakeResult)
	switch resType := getResultType(ctx); resType {
	case execResultType:
		// Exec path: caller waits on res.errChannel for completion.
		res.queryID = respd.Data.QueryID
		res.status = QueryStatusInProgress
		res.errChannel = make(chan error)
		respd.Data.AsyncResult = res
	case queryResultType:
		// Query path: caller waits on rows.errChannel for completion.
		rows.queryID = respd.Data.QueryID
		rows.status = QueryStatusInProgress
		rows.errChannel = make(chan error)
		rows.ctx = ctx
		respd.Data.AsyncRows = rows
	default:
		// Not an async result type: return the response untouched.
		return respd, nil
	}
	// spawn goroutine to retrieve asynchronous results
	go GoroutineWrapper(
		ctx,
		func() {
			err := sr.getAsync(ctx, headers, sr.getFullURL(respd.Data.GetResultURL, nil), timeout, res, rows, cfg)
			if err != nil {
				logger.WithContext(ctx).Errorf("error while calling getAsync. %v", err)
			}
		},
	)
	return respd, nil
}
// getAsync runs in a background goroutine and retrieves the result of an
// asynchronous query from the /results endpoint, populating either res (exec
// path) or rows (query path). Completion is signalled on the matching
// errChannel: a nil send means success, an error send means failure, and the
// deferred close releases any waiter in every exit path.
func (sr *snowflakeRestful) getAsync(
	ctx context.Context,
	headers map[string]string,
	URL *url.URL,
	timeout time.Duration,
	res *snowflakeResult,
	rows *snowflakeRows,
	cfg *Config) error {
	resType := getResultType(ctx)
	var errChannel chan error
	sfError := &SnowflakeError{
		Number: ErrAsync,
	}
	// Pick the channel/queryID matching the caller's result type.
	if resType == execResultType {
		errChannel = res.errChannel
		sfError.QueryID = res.queryID
	} else {
		errChannel = rows.errChannel
		sfError.QueryID = rows.queryID
	}
	// Closing the channel unblocks the waiter even on paths that return
	// without sending an explicit value.
	defer close(errChannel)
	// Refresh the auth header with the current session token before polling.
	token, _, _ := sr.TokenAccessor.GetTokens()
	headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
	respd, err := getQueryResultWithRetriesForAsyncMode(ctx, sr, URL, headers, timeout)
	if err != nil {
		logger.WithContext(ctx).Errorf("error: %v", err)
		sfError.Message = err.Error()
		errChannel <- sfError
		return err
	}
	sc := &snowflakeConn{rest: sr, cfg: cfg, currentTimeProvider: defaultTimeProvider}
	if respd.Success {
		if resType == execResultType {
			res.insertID = -1
			if isDml(respd.Data.StatementTypeID) {
				res.affectedRows, err = updateRows(respd.Data)
				if err != nil {
					// NOTE(review): this path returns without sending on
					// res.errChannel; the deferred close makes the waiter
					// receive the zero value (nil) — confirm intended.
					return err
				}
			} else if isMultiStmt(&respd.Data) {
				r, err := sc.handleMultiExec(ctx, respd.Data)
				if err != nil {
					res.errChannel <- err
					return err
				}
				res.affectedRows, err = r.RowsAffected()
				if err != nil {
					res.errChannel <- err
					return err
				}
			}
			res.queryID = respd.Data.QueryID
			res.errChannel <- nil // mark exec status complete
		} else {
			rows.sc = sc
			rows.queryID = respd.Data.QueryID
			if isMultiStmt(&respd.Data) {
				if err = sc.handleMultiQuery(ctx, respd.Data, rows); err != nil {
					rows.errChannel <- err
					return err
				}
			} else {
				rows.addDownloader(populateChunkDownloader(ctx, sc, respd.Data))
			}
			if err = rows.ChunkDownloader.start(); err != nil {
				rows.errChannel <- err
				return err
			}
			rows.errChannel <- nil // mark query status complete
		}
	} else {
		// Server reported failure: forward a SnowflakeError built from the
		// response. A non-numeric or missing code maps to -1.
		var code int
		if respd.Code != "" {
			code, err = strconv.Atoi(respd.Code)
			if err != nil {
				code = -1
			}
		} else {
			code = -1
		}
		errChannel <- &SnowflakeError{
			Number:   code,
			SQLState: respd.Data.SQLState,
			Message:  respd.Message,
			QueryID:  respd.Data.QueryID,
		}
	}
	return nil
}
// getQueryResultWithRetriesForAsyncMode polls the query-result endpoint until
// the query leaves the "still in progress" state, handling session-token
// expiration along the way. Backoff between polls follows retryPattern
// (x500ms, capping at 5s, then constant). On session expiration it first
// retries with a possibly newer cached token, then renews the session at most
// twice before giving up.
func getQueryResultWithRetriesForAsyncMode(
	ctx context.Context,
	sr *snowflakeRestful,
	URL *url.URL,
	headers map[string]string,
	timeout time.Duration) (respd *execResponse, err error) {
	retry := 0
	retryPattern := []int32{1, 1, 2, 3, 4, 8, 10}
	retryPatternIndex := 0
	retryCountForSessionRenewal := 0
	for {
		logger.WithContext(ctx).Debugf("Retry count for get query result request in async mode: %v", retry)
		respd, err = getExecResponse(ctx, sr, URL, headers, timeout)
		if err != nil {
			return respd, err
		}
		if respd.Code == sessionExpiredCode {
			// Update the session token in the header and retry
			token, _, _ := sr.TokenAccessor.GetTokens()
			if token != "" && headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, token) {
				// Another caller already refreshed the token; reuse it.
				headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
				logger.WithContext(ctx).Debug("Session token has been updated.")
				retry++
				continue
			}
			// Renew the session token
			if err = sr.renewExpiredSessionToken(ctx, timeout, token); err != nil {
				logger.WithContext(ctx).Errorf("failed to renew session token. err: %v", err)
				return respd, err
			}
			retryCountForSessionRenewal++
			// If this is the first response, go back to retry the query
			// since it failed due to session expiration
			logger.WithContext(ctx).Debugf("retry count for session renewal: %v", retryCountForSessionRenewal)
			if retryCountForSessionRenewal < 2 {
				retry++
				continue
			} else {
				// NOTE(review): err is nil here (renewal succeeded just above),
				// so this logs and returns a nil error with a session-expired
				// response — confirm intended.
				logger.WithContext(ctx).Errorf("failed to get query result with the renewed session token. err: %v", err)
				return respd, err
			}
		} else if respd.Code != queryInProgressAsyncCode {
			// If the query takes longer than 45 seconds to complete the results are not returned.
			// If the query is still in progress after 45 seconds, retry the request to the /results endpoint.
			// For all other scenarios continue processing results response
			break
		} else {
			// Sleep before retrying get result request. Exponential backoff up to 5 seconds.
			// Once 5 second backoff is reached it will keep retrying with this sleeptime.
			sleepTime := time.Millisecond * time.Duration(500*retryPattern[retryPatternIndex])
			logger.WithContext(ctx).Debugf("Query execution still in progress. Response code: %v, message: %v Sleep for %v ms", respd.Code, respd.Message, sleepTime)
			time.Sleep(sleepTime)
			retry++
			if retryPatternIndex < len(retryPattern)-1 {
				retryPatternIndex++
			}
		}
	}
	if len(respd.Data.RowType) > 0 {
		logger.Infof("[Server Response Validation]: RowType: %s, QueryResultFormat: %s", respd.Data.RowType[0].Name, respd.Data.QueryResultFormat)
	}
	return respd, nil
}
================================================
FILE: async_test.go
================================================
package gosnowflake
import (
"context"
"database/sql"
"fmt"
"testing"
)
// TestAsyncMode runs a large async query and verifies the full row count is
// delivered, then exercises the async exec path and checks RowsAffected.
func TestAsyncMode(t *testing.T) {
	ctx := WithAsyncMode(context.Background())
	numrows := 100000
	cnt := 0
	var idx int
	var v string
	runDBTest(t, func(dbt *DBTest) {
		rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows))
		defer rows.Close()
		// Next() will block and wait until results are available
		for rows.Next() {
			if err := rows.Scan(&idx, &v); err != nil {
				t.Fatal(err)
			}
			cnt++
		}
		logger.Infof("NextResultSet: %v", rows.NextResultSet())
		if cnt != numrows {
			t.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt)
		}
		// Exercise the async exec path with a single-row insert.
		dbt.mustExec("create or replace table test_async_exec (value boolean)")
		res := dbt.mustExecContext(ctx, "insert into test_async_exec values (true)")
		count, err := res.RowsAffected()
		if err != nil {
			t.Fatalf("res.RowsAffected() returned error: %v", err)
		}
		if count != 1 {
			t.Fatalf("expected 1 affected row, got %d", count)
		}
	})
}
// TestAsyncModePing verifies that PingContext succeeds and does not panic
// when the context has async mode enabled.
func TestAsyncModePing(t *testing.T) {
	asyncCtx := WithAsyncMode(context.Background())
	runDBTest(t, func(dbt *DBTest) {
		defer func() {
			if recovered := recover(); recovered != nil {
				t.Fatalf("panic during ping: %v", recovered)
			}
		}()
		if err := dbt.conn.PingContext(asyncCtx); err != nil {
			t.Fatal(err)
		}
	})
}
// TestAsyncModeMultiStatement runs a 6-statement batch (begin, delete,
// insert, two selects, rollback) in async mode and checks the combined
// affected-row count (1 delete + 2 inserts = 3).
func TestAsyncModeMultiStatement(t *testing.T) {
	withMultiStmtCtx := WithMultiStatement(context.Background(), 6)
	ctx := WithAsyncMode(withMultiStmtCtx)
	multiStmtQuery := "begin;\n" +
		"delete from test_multi_statement_async;\n" +
		"insert into test_multi_statement_async values (1, 'a'), (2, 'b');\n" +
		"select 1;\n" +
		"select 2;\n" +
		"rollback;"
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("drop table if exists test_multi_statement_async")
		dbt.mustExec(`create or replace table test_multi_statement_async(
			c1 number, c2 string) as select 10, 'z'`)
		defer dbt.mustExec("drop table if exists test_multi_statement_async")
		res := dbt.mustExecContext(ctx, multiStmtQuery)
		count, err := res.RowsAffected()
		if err != nil {
			t.Fatalf("res.RowsAffected() returned error: %v", err)
		}
		if count != 3 {
			t.Fatalf("expected 3 affected rows, got %d", count)
		}
	})
}
// TestAsyncModeCancel starts an async query and immediately cancels its
// context, ensuring cancellation of an in-flight async query is safe.
func TestAsyncModeCancel(t *testing.T) {
	baseCtx, cancel := context.WithCancel(context.Background())
	asyncCtx := WithAsyncMode(baseCtx)
	const numrows = 100000
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustQueryContext(asyncCtx, fmt.Sprintf(selectRandomGenerator, numrows))
		cancel()
	})
}
// TestAsyncQueryFail runs a syntactically invalid query in async mode and
// checks that no rows are produced and the failure surfaces via rows.Err().
func TestAsyncQueryFail(t *testing.T) {
	asyncCtx := WithAsyncMode(context.Background())
	runDBTest(t, func(dbt *DBTest) {
		rows := dbt.mustQueryContext(asyncCtx, "selectt 1")
		defer rows.Close()
		if rows.Next() {
			t.Fatal("should have no rows available")
		} else if err := rows.Err(); err == nil {
			t.Fatal("should return a syntax error")
		}
	})
}
// TestMultipleAsyncQueries validates that shorter async queries return before
// longer ones. The TIMELIMIT values (30 and 10) must have sufficient separation
// to avoid flaky ordering. Do not reduce these values significantly.
func TestMultipleAsyncQueries(t *testing.T) {
	ctx := WithAsyncMode(context.Background())
	s1 := "foo"
	s2 := "bar"
	ch1 := make(chan string)
	ch2 := make(chan string)
	// NOTE(review): db from openDB(t) is never closed here — consider
	// `defer db.Close()`.
	db := openDB(t)
	runDBTest(t, func(dbt *DBTest) {
		// Long query (30s) started first, short query (10s) second.
		rows1, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s1, 30))
		if err != nil {
			t.Fatalf("can't read rows1: %v", err)
		}
		defer rows1.Close()
		rows2, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s2, 10))
		if err != nil {
			t.Fatalf("can't read rows2: %v", err)
		}
		defer rows2.Close()
		go retrieveRows(rows1, ch1)
		go retrieveRows(rows2, ch2)
		// The short query's channel must deliver first.
		select {
		case res := <-ch1:
			t.Fatalf("value %v should not have been called earlier.", res)
		case res := <-ch2:
			if res != s2 {
				t.Fatalf("query failed. expected: %v, got: %v", s2, res)
			}
		}
	})
}
// retrieveRows drains rows and sends the last scanned string value on ch; if
// a scan fails, the error text is sent instead. The channel is always closed
// before returning.
func retrieveRows(rows *sql.Rows, ch chan string) {
	defer close(ch)
	var value string
	for rows.Next() {
		if err := rows.Scan(&value); err != nil {
			ch <- err.Error()
			return
		}
	}
	ch <- value
}
// TestLongRunningAsyncQuery validates the retry logic for async queries that
// exceed Snowflake's 45-second threshold. After 45 seconds, the /results
// endpoint returns "query in progress" (code 333334) and the driver must retry.
// The 50-second wait MUST exceed 45 seconds to exercise this code path.
func TestLongRunningAsyncQuery(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		// NOTE(review): 0 appears to mean "unspecified statement count" for
		// WithMultiStatement — confirm against its documentation.
		ctx := WithMultiStatement(context.Background(), 0)
		query := "CALL SYSTEM$WAIT(50, 'SECONDS');use snowflake_sample_data"
		rows := dbt.mustQueryContext(WithAsyncMode(ctx), query)
		defer rows.Close()
		var v string
		// i indexes the expected per-statement results, in submission order.
		i := 0
		for {
			for rows.Next() {
				err := rows.Scan(&v)
				if err != nil {
					t.Fatalf("failed to get result. err: %v", err)
				}
				if v == "" {
					t.Fatal("should have returned a result")
				}
				// One status row per statement: the CALL result, then the
				// USE statement's acknowledgement.
				results := []string{"waited 50 seconds", "Statement executed successfully."}
				if v != results[i] {
					t.Fatalf("unexpected result returned. expected: %v, but got: %v", results[i], v)
				}
				i++
			}
			// Advance to the next statement's result set; stop when exhausted.
			if !rows.NextResultSet() {
				break
			}
		}
	})
}
// TestLongRunningAsyncQueryFetchResultByID starts a long-running query in a
// goroutine, receives its query ID through WithQueryIDChan before the query
// finishes, and then fetches the result via WithFetchResultByID.
func TestLongRunningAsyncQueryFetchResultByID(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		// Buffered so the producing goroutine never blocks on the send.
		queryIDChan := make(chan string, 1)
		ctx := WithAsyncMode(context.Background())
		ctx = WithQueryIDChan(ctx, queryIDChan)
		// Run a long running query asynchronously
		go dbt.mustExecContext(ctx, "CALL SYSTEM$WAIT(50, 'SECONDS')")
		// Get the query ID without waiting for the query to finish
		queryID := <-queryIDChan
		assertNotNilF(t, queryID, "expected a nonempty query ID")
		// Query text is irrelevant when fetching by ID, hence the empty string.
		ctx = WithFetchResultByID(ctx, queryID)
		rows := dbt.mustQueryContext(ctx, "")
		defer rows.Close()
		var v string
		assertTrueF(t, rows.Next())
		err := rows.Scan(&v)
		assertNilF(t, err, fmt.Sprintf("failed to get result. err: %v", err))
		assertNotNilF(t, v, "should have returned a result")
		expected := "waited 50 seconds"
		if v != expected {
			t.Fatalf("unexpected result returned. expected: %v, but got: %v", expected, v)
		}
		// The CALL produces exactly one result set.
		assertFalseF(t, rows.NextResultSet())
	})
}
================================================
FILE: auth.go
================================================
package gosnowflake
import (
"context"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"runtime"
"slices"
"strconv"
"strings"
"time"
sferrors "github.com/snowflakedb/gosnowflake/v2/internal/errors"
"github.com/golang-jwt/jwt/v5"
"github.com/snowflakedb/gosnowflake/v2/internal/compilation"
sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config"
internalos "github.com/snowflakedb/gosnowflake/v2/internal/os"
)
const (
	// clientType identifies this driver ("Go") in login requests and HTTP headers.
	clientType = "Go"
)

const (
	// Session-parameter names sent at login to control client-side credential
	// caching (see authenticate).
	clientStoreTemporaryCredential = "CLIENT_STORE_TEMPORARY_CREDENTIAL"
	clientRequestMfaToken          = "CLIENT_REQUEST_MFA_TOKEN"
	// idTokenAuthenticator is the AUTHENTICATOR value used when logging in
	// with a cached ID token instead of a fresh external-browser flow.
	idTokenAuthenticator = "ID_TOKEN"
)
// AuthType indicates the type of authentication in Snowflake.
// It is an alias of the internal config package's type so that both packages
// share one set of values.
type AuthType = sfconfig.AuthType

// Re-exported authenticator values; the canonical definitions live in the
// internal config package.
const (
	// AuthTypeSnowflake is the general username password authentication
	AuthTypeSnowflake = sfconfig.AuthTypeSnowflake
	// AuthTypeOAuth is the OAuth authentication
	AuthTypeOAuth = sfconfig.AuthTypeOAuth
	// AuthTypeExternalBrowser is to use a browser to access an Fed and perform SSO authentication
	AuthTypeExternalBrowser = sfconfig.AuthTypeExternalBrowser
	// AuthTypeOkta is to use a native okta URL to perform SSO authentication on Okta
	AuthTypeOkta = sfconfig.AuthTypeOkta
	// AuthTypeJwt is to use Jwt to perform authentication
	AuthTypeJwt = sfconfig.AuthTypeJwt
	// AuthTypeTokenAccessor is to use the provided token accessor and bypass authentication
	AuthTypeTokenAccessor = sfconfig.AuthTypeTokenAccessor
	// AuthTypeUsernamePasswordMFA is to use username and password with mfa
	AuthTypeUsernamePasswordMFA = sfconfig.AuthTypeUsernamePasswordMFA
	// AuthTypePat is to use programmatic access token
	AuthTypePat = sfconfig.AuthTypePat
	// AuthTypeOAuthAuthorizationCode is to use browser-based OAuth2 flow
	AuthTypeOAuthAuthorizationCode = sfconfig.AuthTypeOAuthAuthorizationCode
	// AuthTypeOAuthClientCredentials is to use non-interactive OAuth2 flow
	AuthTypeOAuthClientCredentials = sfconfig.AuthTypeOAuthClientCredentials
	// AuthTypeWorkloadIdentityFederation is to use CSP identity for authentication
	AuthTypeWorkloadIdentityFederation = sfconfig.AuthTypeWorkloadIdentityFederation
)
// isOauthNativeFlow reports whether authType is one of the driver-native
// OAuth2 flows (authorization-code or client-credentials).
func isOauthNativeFlow(authType AuthType) bool {
	switch authType {
	case AuthTypeOAuthAuthorizationCode, AuthTypeOAuthClientCredentials:
		return true
	default:
		return false
	}
}
// refreshOAuthTokenErrorCodes lists server error codes indicating that the
// cached OAuth access token is missing, invalid, or expired, and a token
// refresh (or full re-authentication) should be attempted.
var refreshOAuthTokenErrorCodes = []string{
	strconv.Itoa(ErrMissingAccessATokenButRefreshTokenPresent),
	invalidOAuthAccessTokenCode,
	expiredOAuthAccessTokenCode,
}

// userAgent shows up in User-Agent HTTP header,
// e.g. "Go/<driver-version> (<os>-<arch>) gc/<go-version>".
var userAgent = fmt.Sprintf("%v/%v (%v-%v) %v/%v",
	clientType,
	SnowflakeGoDriverVersion,
	runtime.GOOS,
	runtime.GOARCH,
	runtime.Compiler,
	runtime.Version())
// authRequestClientEnvironment describes the client host and runtime and is
// sent as CLIENT_ENVIRONMENT in the login request body.
type authRequestClientEnvironment struct {
	Application     string            `json:"APPLICATION"`
	ApplicationPath string            `json:"APPLICATION_PATH"`
	Os              string            `json:"OS"`
	OsVersion       string            `json:"OS_VERSION"`
	OsDetails       map[string]string `json:"OS_DETAILS,omitempty"`
	Isa             string            `json:"ISA,omitempty"`
	OCSPMode        string            `json:"OCSP_MODE"`
	GoVersion       string            `json:"GO_VERSION"`
	// OAuthType is set only for the native OAuth flows (see authenticate).
	OAuthType               string   `json:"OAUTH_TYPE,omitempty"`
	CertRevocationCheckMode string   `json:"CERT_REVOCATION_CHECK_MODE,omitempty"`
	Platform                []string `json:"PLATFORM,omitempty"`
	// CoreVersion/CoreLoadError/CoreFileName report minicore status
	// (see newAuthRequestClientEnvironment).
	CoreVersion   string `json:"CORE_VERSION,omitempty"`
	CoreLoadError string `json:"CORE_LOAD_ERROR,omitempty"`
	CoreFileName  string `json:"CORE_FILE_NAME,omitempty"`
	CgoEnabled    bool   `json:"CGO_ENABLED,omitempty"`
	LinkingMode   string `json:"LINKING_MODE,omitempty"`
	LibcFamily    string `json:"LIBC_FAMILY,omitempty"`
	LibcVersion   string `json:"LIBC_VERSION,omitempty"`
}
// authRequestData is the payload of a login request. Which optional fields
// are populated depends on the authenticator in use (see createRequestBody).
type authRequestData struct {
	ClientAppID       string                       `json:"CLIENT_APP_ID"`
	ClientAppVersion  string                       `json:"CLIENT_APP_VERSION"`
	SvnRevision       string                       `json:"SVN_REVISION"`
	AccountName       string                       `json:"ACCOUNT_NAME"`
	LoginName         string                       `json:"LOGIN_NAME,omitempty"`
	Password          string                       `json:"PASSWORD,omitempty"`
	RawSAMLResponse   string                       `json:"RAW_SAML_RESPONSE,omitempty"`
	ExtAuthnDuoMethod string                       `json:"EXT_AUTHN_DUO_METHOD,omitempty"`
	Passcode          string                       `json:"PASSCODE,omitempty"`
	Authenticator     string                       `json:"AUTHENTICATOR,omitempty"`
	SessionParameters map[string]any               `json:"SESSION_PARAMETERS,omitempty"`
	ClientEnvironment authRequestClientEnvironment `json:"CLIENT_ENVIRONMENT"`
	// BrowserModeRedirectPort and ProofKey are used by the external-browser flow.
	BrowserModeRedirectPort string `json:"BROWSER_MODE_REDIRECT_PORT,omitempty"`
	ProofKey                string `json:"PROOF_KEY,omitempty"`
	// Token carries whichever token the chosen authenticator produced
	// (OAuth/JWT/PAT/MFA/ID token/WIF credential).
	Token string `json:"TOKEN,omitempty"`
	// Provider is the WIF provider type (see AuthTypeWorkloadIdentityFederation).
	Provider string `json:"PROVIDER,omitempty"`
}

// authRequest wraps authRequestData under the "data" key expected by the
// login endpoint.
type authRequest struct {
	Data authRequestData `json:"data"`
}
// nameValueParameter is a single session parameter returned by the server.
type nameValueParameter struct {
	Name  string `json:"name"`
	Value any    `json:"value"`
}

// authResponseSessionInfo describes the session context established at login.
type authResponseSessionInfo struct {
	DatabaseName  string `json:"databaseName"`
	SchemaName    string `json:"schemaName"`
	WarehouseName string `json:"warehouseName"`
	RoleName      string `json:"roleName"`
}

// authResponseMain is the "data" portion of a successful login response.
// NOTE(review): the *Validity fields are decoded from "...InSeconds" JSON
// numbers straight into time.Duration, so they hold raw second counts, not
// nanoseconds — confirm at call sites before doing Duration arithmetic.
type authResponseMain struct {
	Token               string                  `json:"token,omitempty"`
	Validity            time.Duration           `json:"validityInSeconds,omitempty"`
	MasterToken         string                  `json:"masterToken,omitempty"`
	MasterValidity      time.Duration           `json:"masterValidityInSeconds"`
	MfaToken            string                  `json:"mfaToken,omitempty"`
	MfaTokenValidity    time.Duration           `json:"mfaTokenValidityInSeconds"`
	IDToken             string                  `json:"idToken,omitempty"`
	IDTokenValidity     time.Duration           `json:"idTokenValidityInSeconds"`
	DisplayUserName     string                  `json:"displayUserName"`
	ServerVersion       string                  `json:"serverVersion"`
	FirstLogin          bool                    `json:"firstLogin"`
	RemMeToken          string                  `json:"remMeToken"`
	RemMeValidity       time.Duration           `json:"remMeValidityInSeconds"`
	HealthCheckInterval time.Duration           `json:"healthCheckInterval"`
	NewClientForUpgrade string                  `json:"newClientForUpgrade"`
	SessionID           int64                   `json:"sessionId"`
	Parameters          []nameValueParameter    `json:"parameters"`
	SessionInfo         authResponseSessionInfo `json:"sessionInfo"`
	// TokenURL/SSOURL/ProofKey are returned for browser-based SSO flows.
	TokenURL string `json:"tokenUrl,omitempty"`
	SSOURL   string `json:"ssoUrl,omitempty"`
	ProofKey string `json:"proofKey,omitempty"`
}

// authResponse is the full login response envelope.
type authResponse struct {
	Data    authResponseMain `json:"data"`
	Message string           `json:"message"`
	Code    string           `json:"code"`
	Success bool             `json:"success"`
}
// postAuth sends the login request to Snowflake's login endpoint and decodes
// the response. A 200 yields the decoded authResponse; 502/503/504 map to
// ErrCodeServiceUnavailable, 401/403 to ErrCodeFailedToConnect, and any other
// status is logged (with body) and reported as ErrFailedToAuth.
//
// Fixes: the response body was logged with %v on a []byte, which prints the
// raw byte values instead of text — it is now converted to string first. The
// separate StatusOK check was also folded into the status-code switch.
func postAuth(
	ctx context.Context,
	sr *snowflakeRestful,
	client *http.Client,
	params *url.Values,
	headers map[string]string,
	bodyCreator bodyCreatorType,
	timeout time.Duration) (
	data *authResponse, err error) {
	params.Set(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
	params.Set(requestGUIDKey, NewUUID().String())
	fullURL := sr.getFullURL(loginRequestPath, params)
	logger.WithContext(ctx).Infof("full URL: %v", fullURL)
	resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, sr.MaxRetryCount)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := resp.Body.Close(); closeErr != nil {
			logger.WithContext(ctx).Errorf("failed to close HTTP response body for %v. err: %v", fullURL, closeErr)
		}
	}()
	switch resp.StatusCode {
	case http.StatusOK:
		var respd authResponse
		if err = json.NewDecoder(resp.Body).Decode(&respd); err != nil {
			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
			return nil, err
		}
		return &respd, nil
	case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
		// service availability or connectivity issue. Most likely server side issue.
		return nil, &SnowflakeError{
			Number:      ErrCodeServiceUnavailable,
			SQLState:    SQLStateConnectionWasNotEstablished,
			Message:     sferrors.ErrMsgServiceUnavailable,
			MessageArgs: []any{resp.StatusCode, fullURL},
		}
	case http.StatusUnauthorized, http.StatusForbidden:
		// failed to connect to db. account name may be wrong
		return nil, &SnowflakeError{
			Number:      ErrCodeFailedToConnect,
			SQLState:    SQLStateConnectionRejected,
			Message:     sferrors.ErrMsgFailedToConnect,
			MessageArgs: []any{resp.StatusCode, fullURL},
		}
	}
	// Unexpected status: capture the body for diagnostics before failing.
	b, err := io.ReadAll(resp.Body)
	if err != nil {
		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
		return nil, err
	}
	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, string(b))
	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
	return nil, &SnowflakeError{
		Number:      ErrFailedToAuth,
		SQLState:    SQLStateConnectionRejected,
		Message:     sferrors.ErrMsgFailedToAuth,
		MessageArgs: []any{resp.StatusCode, fullURL},
	}
}
// getHeaders returns the HTTP headers required to authenticate with
// Snowflake: content negotiation, client identification, and User-Agent.
func getHeaders() map[string]string {
	return map[string]string{
		httpHeaderContentType:  headerContentTypeApplicationJSON,
		httpHeaderAccept:       headerAcceptTypeApplicationSnowflake,
		httpClientAppID:        clientType,
		httpClientAppVersion:   SnowflakeGoDriverVersion,
		httpHeaderUserAgent:    userAgent,
	}
}
// authenticate performs the login exchange with Snowflake and returns the
// session data on success.
//
// For AuthTypeTokenAccessor no network call is made: the tokens held by the
// configured TokenAccessor are returned directly. Otherwise a login request
// is built (body construction deferred to createRequestBody via bodyCreator)
// and posted. On a failed login the driver clears server tokens and any
// cached MFA/ID/OAuth credentials relevant to the configured authenticator;
// on success it stores the new session tokens and caches MFA/ID tokens when
// the corresponding session parameters were requested.
func authenticate(
	ctx context.Context,
	sc *snowflakeConn,
	samlResponse []byte,
	proofKey []byte,
) (resp *authResponseMain, err error) {
	if sc.cfg.Authenticator == AuthTypeTokenAccessor {
		logger.WithContext(ctx).Info("Bypass authentication using existing token from token accessor")
		sessionInfo := authResponseSessionInfo{
			DatabaseName:  sc.cfg.Database,
			SchemaName:    sc.cfg.Schema,
			WarehouseName: sc.cfg.Warehouse,
			RoleName:      sc.cfg.Role,
		}
		token, masterToken, sessionID := sc.cfg.TokenAccessor.GetTokens()
		return &authResponseMain{
			Token:       token,
			MasterToken: masterToken,
			SessionID:   sessionID,
			SessionInfo: sessionInfo,
		}, nil
	}
	headers := getHeaders()
	// Get the current application path (reported in client environment telemetry).
	applicationPath, err := os.Executable()
	if err != nil {
		logger.WithContext(ctx).Warnf("Failed to get executable path: %v", err)
		applicationPath = "unknown"
	}
	// Only the native OAuth flows report an OAUTH_TYPE.
	oauthType := ""
	switch sc.cfg.Authenticator {
	case AuthTypeOAuthAuthorizationCode:
		oauthType = "OAUTH_AUTHORIZATION_CODE"
	case AuthTypeOAuthClientCredentials:
		oauthType = "OAUTH_CLIENT_CREDENTIALS"
	}
	clientEnvironment := newAuthRequestClientEnvironment()
	clientEnvironment.Application = sc.cfg.Application
	clientEnvironment.ApplicationPath = applicationPath
	clientEnvironment.OAuthType = oauthType
	clientEnvironment.CertRevocationCheckMode = sc.cfg.CertRevocationCheckMode.String()
	clientEnvironment.Platform = getDetectedPlatforms()
	sessionParameters := make(map[string]any)
	for k, v := range sc.syncParams.All() {
		// upper casing to normalize keys
		sessionParameters[strings.ToUpper(k)] = v
	}
	sessionParameters[sessionClientValidateDefaultParameters] = sc.cfg.ValidateDefaultParameters != ConfigBoolFalse
	// Ask the server for cacheable MFA/ID tokens only when enabled in config.
	if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue {
		sessionParameters[clientRequestMfaToken] = true
	}
	if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
		sessionParameters[clientStoreTemporaryCredential] = true
	}
	// Body creation is deferred so retries can rebuild it (tokens may change).
	bodyCreator := func() ([]byte, error) {
		return createRequestBody(sc, sessionParameters, clientEnvironment, proofKey, samlResponse)
	}
	params := &url.Values{}
	if sc.cfg.Database != "" {
		params.Add("databaseName", sc.cfg.Database)
	}
	if sc.cfg.Schema != "" {
		params.Add("schemaName", sc.cfg.Schema)
	}
	if sc.cfg.Warehouse != "" {
		params.Add("warehouse", sc.cfg.Warehouse)
	}
	if sc.cfg.Role != "" {
		params.Add("roleName", sc.cfg.Role)
	}
	logger.WithContext(ctx).Infof("Information for Auth: Host: %v, User: %v, Authenticator: %v, Params: %v, Protocol: %v, Port: %v, LoginTimeout: %v",
		sc.rest.Host, sc.cfg.User, sc.cfg.Authenticator.String(), params, sc.rest.Protocol, sc.rest.Port, sc.rest.LoginTimeout)
	respd, err := sc.rest.FuncPostAuth(ctx, sc.rest, sc.rest.getClientFor(sc.cfg.Authenticator), params, headers, bodyCreator, sc.rest.LoginTimeout)
	if err != nil {
		return nil, err
	}
	if !respd.Success {
		logger.WithContext(ctx).Error("Authentication FAILED")
		// Invalidate in-memory tokens and any cached credentials tied to the
		// failed authenticator so the next attempt starts clean.
		sc.rest.TokenAccessor.SetTokens("", "", -1)
		if sessionParameters[clientRequestMfaToken] == true {
			credentialsStorage.deleteCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
		}
		if sessionParameters[clientStoreTemporaryCredential] == true && sc.cfg.Authenticator == AuthTypeExternalBrowser {
			credentialsStorage.deleteCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
		}
		if sessionParameters[clientStoreTemporaryCredential] == true && isOauthNativeFlow(sc.cfg.Authenticator) {
			credentialsStorage.deleteCredential(newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
		}
		code, err := strconv.Atoi(respd.Code)
		if err != nil {
			return nil, err
		}
		return nil, exceptionTelemetry(&SnowflakeError{
			Number:   code,
			SQLState: SQLStateConnectionRejected,
			Message:  respd.Message,
		}, sc)
	}
	logger.WithContext(ctx).Info("Authentication SUCCESS")
	sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID)
	// Cache freshly issued tokens for subsequent logins when requested.
	if sessionParameters[clientRequestMfaToken] == true {
		token := respd.Data.MfaToken
		credentialsStorage.setCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token)
	}
	if sessionParameters[clientStoreTemporaryCredential] == true {
		token := respd.Data.IDToken
		credentialsStorage.setCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token)
	}
	return &respd.Data, nil
}
// newAuthRequestClientEnvironment gathers host/runtime telemetry for the
// login request: OS details, Go version, libc info, linking mode, and the
// minicore version or — when minicore is disabled, still loading, or failed
// to load — a human-readable load-error string.
func newAuthRequestClientEnvironment() authRequestClientEnvironment {
	var coreVersion string
	var coreLoadError string
	// Try to get minicore version, but don't block if it's not loaded yet
	if !compilation.MinicoreEnabled {
		logger.Trace("minicore disabled at compile time")
		coreLoadError = "Minicore is disabled at compile time (built with -tags minicore_disabled)"
	} else if strings.EqualFold(os.Getenv(disableMinicoreEnv), "true") {
		logger.Trace("minicore loading disabled")
		coreLoadError = "Minicore is disabled with SF_DISABLE_MINICORE env variable"
	} else if mc := getMiniCore(); mc != nil {
		var err error
		coreVersion, err = mc.FullVersion()
		if err != nil {
			logger.Debugf("Minicore loading failed. %v", err)
			// Distinguish known load-failure types from unexpected errors.
			var mcErr *miniCoreError
			if errors.As(err, &mcErr) {
				coreLoadError = fmt.Sprintf("Failed to load binary: %v", mcErr.errorType)
			} else {
				coreLoadError = "Failed to load binary: unknown"
			}
		}
	} else {
		// Minicore not loaded yet - this is expected during startup
		coreVersion = ""
		coreLoadError = "Minicore is still loading"
		logger.Debugf("Minicore not yet loaded for client environment telemetry")
	}
	libcInfo := internalos.GetLibcInfo()
	linkingMode, err := compilation.CheckDynamicLinking()
	if err != nil {
		// Non-fatal: the linking mode is telemetry only.
		logger.Debugf("cannot determine if app is dynamically linked: %v", err)
	}
	return authRequestClientEnvironment{
		Os:            runtime.GOOS,
		OsVersion:     osVersion,
		OsDetails:     internalos.GetOsDetails(),
		Isa:           runtime.GOARCH,
		GoVersion:     runtime.Version(),
		CoreVersion:   coreVersion,
		CoreFileName:  getMiniCoreFileName(),
		CoreLoadError: coreLoadError,
		CgoEnabled:    compilation.CgoEnabled,
		LinkingMode:   linkingMode.String(),
		LibcFamily:    libcInfo.Family,
		LibcVersion:   libcInfo.Version,
	}
}
// createRequestBody builds the JSON login payload for the configured
// authenticator. The common fields (app id/version, account, session
// parameters, client environment) are always set; the switch below fills the
// authenticator-specific fields (token, password, SAML response, etc.).
// Some flows perform their own network round-trips here (Okta SAML, OAuth
// flows, WIF attestation) before the payload can be assembled.
func createRequestBody(sc *snowflakeConn, sessionParameters map[string]any,
	clientEnvironment authRequestClientEnvironment, proofKey []byte, samlResponse []byte,
) ([]byte, error) {
	requestMain := authRequestData{
		ClientAppID:       clientType,
		ClientAppVersion:  SnowflakeGoDriverVersion,
		AccountName:       sc.cfg.Account,
		SessionParameters: sessionParameters,
		ClientEnvironment: clientEnvironment,
	}
	switch sc.cfg.Authenticator {
	case AuthTypeExternalBrowser:
		// A cached ID token skips the browser round-trip entirely.
		if sc.idToken != "" {
			requestMain.Authenticator = idTokenAuthenticator
			requestMain.Token = sc.idToken
			requestMain.LoginName = sc.cfg.User
		} else {
			requestMain.ProofKey = string(proofKey)
			requestMain.Token = string(samlResponse)
			requestMain.LoginName = sc.cfg.User
			requestMain.Authenticator = AuthTypeExternalBrowser.String()
		}
	case AuthTypeOAuth:
		requestMain.LoginName = sc.cfg.User
		requestMain.Authenticator = AuthTypeOAuth.String()
		var err error
		if requestMain.Token, err = sfconfig.GetToken(sc.cfg); err != nil {
			return nil, fmt.Errorf("failed to get OAuth token: %w", err)
		}
	case AuthTypeOkta:
		// Perform the Okta SSO exchange now; the SAML response becomes the payload.
		samlResponse, err := authenticateBySAML(
			sc.ctx,
			sc.rest,
			sc.cfg.OktaURL,
			sc.cfg.Application,
			sc.cfg.Account,
			sc.cfg.User,
			sc.cfg.Password,
			sc.cfg.DisableSamlURLCheck)
		if err != nil {
			return nil, err
		}
		requestMain.RawSAMLResponse = string(samlResponse)
	case AuthTypeJwt:
		requestMain.Authenticator = AuthTypeJwt.String()
		jwtTokenString, err := prepareJWTToken(sc.cfg)
		if err != nil {
			return nil, err
		}
		requestMain.Token = jwtTokenString
	case AuthTypePat:
		logger.WithContext(sc.ctx).Info("Programmatic access token")
		requestMain.Authenticator = AuthTypePat.String()
		requestMain.LoginName = sc.cfg.User
		var err error
		if requestMain.Token, err = sfconfig.GetToken(sc.cfg); err != nil {
			return nil, fmt.Errorf("failed to get PAT token: %w", err)
		}
	case AuthTypeSnowflake:
		logger.WithContext(sc.ctx).Debug("Username and password")
		requestMain.LoginName = sc.cfg.User
		requestMain.Password = sc.cfg.Password
		// Duo passcode may be embedded in the password or supplied separately.
		switch {
		case sc.cfg.PasscodeInPassword:
			requestMain.ExtAuthnDuoMethod = "passcode"
		case sc.cfg.Passcode != "":
			requestMain.Passcode = sc.cfg.Passcode
			requestMain.ExtAuthnDuoMethod = "passcode"
		}
	case AuthTypeUsernamePasswordMFA:
		logger.WithContext(sc.ctx).Debug("Username and password MFA")
		requestMain.LoginName = sc.cfg.User
		requestMain.Password = sc.cfg.Password
		// A cached MFA token takes precedence over a passcode.
		switch {
		case sc.mfaToken != "":
			requestMain.Token = sc.mfaToken
		case sc.cfg.PasscodeInPassword:
			requestMain.ExtAuthnDuoMethod = "passcode"
		case sc.cfg.Passcode != "":
			requestMain.Passcode = sc.cfg.Passcode
			requestMain.ExtAuthnDuoMethod = "passcode"
		}
	case AuthTypeOAuthAuthorizationCode:
		logger.WithContext(sc.ctx).Debug("OAuth authorization code")
		token, err := authenticateByAuthorizationCode(sc)
		if err != nil {
			return nil, err
		}
		requestMain.LoginName = sc.cfg.User
		requestMain.Token = token
	case AuthTypeOAuthClientCredentials:
		logger.WithContext(sc.ctx).Debug("OAuth client credentials")
		oauthClient, err := newOauthClient(sc.ctx, sc.cfg, sc)
		if err != nil {
			return nil, err
		}
		token, err := oauthClient.authenticateByOAuthClientCredentials()
		if err != nil {
			return nil, err
		}
		requestMain.LoginName = sc.cfg.User
		requestMain.Token = token
	case AuthTypeWorkloadIdentityFederation:
		logger.WithContext(sc.ctx).Debug("Workload Identity Federation")
		wifAttestationProvider := createWifAttestationProvider(sc.ctx, sc.cfg, sc.telemetry)
		wifAttestation, err := wifAttestationProvider.getAttestation(sc.cfg.WorkloadIdentityProvider)
		if err != nil {
			return nil, err
		}
		if wifAttestation == nil {
			return nil, errors.New("workload identity federation attestation is not available, please check your configuration")
		}
		requestMain.Authenticator = AuthTypeWorkloadIdentityFederation.String()
		requestMain.Token = wifAttestation.Credential
		requestMain.Provider = wifAttestation.ProviderType
	}
	logger.WithContext(sc.ctx).Debugf("Request body is created for the authentication. Authenticator: %s, User: %s, Account: %s", sc.cfg.Authenticator.String(), sc.cfg.User, sc.cfg.Account)
	authRequest := authRequest{
		Data: requestMain,
	}
	jsonBody, err := json.Marshal(authRequest)
	if err != nil {
		logger.WithContext(sc.ctx).Errorf("Failed to marshal JSON. err: %v", err)
		return nil, err
	}
	return jsonBody, nil
}
// oauthLockKey identifies the synchronization lock used to serialize
// concurrent OAuth logins for one (token endpoint, user, flow) combination.
type oauthLockKey struct {
	tokenRequestURL string
	user            string
	flowType        string
}

// newOAuthAuthorizationCodeLockKey returns the key guarding the interactive
// authorization-code flow for the given token endpoint and user.
func newOAuthAuthorizationCodeLockKey(tokenRequestURL, user string) *oauthLockKey {
	return &oauthLockKey{
		tokenRequestURL: tokenRequestURL,
		user:            user,
		flowType:        "authorization_code",
	}
}

// newRefreshTokenLockKey returns the key guarding the refresh-token flow for
// the given token endpoint and user.
func newRefreshTokenLockKey(tokenRequestURL, user string) *oauthLockKey {
	return &oauthLockKey{
		tokenRequestURL: tokenRequestURL,
		user:            user,
		flowType:        "refresh_token",
	}
}

// lockID renders the key as a single pipe-separated identifier.
func (o *oauthLockKey) lockID() string {
	return strings.Join([]string{o.tokenRequestURL, o.user, o.flowType}, "|")
}
// authenticateByAuthorizationCode obtains an OAuth access token via the
// authorization-code flow. When parallel-login coordination is enabled, only
// one goroutine runs the interactive browser flow while the others wait on a
// valueAwaiter and pick up the cached token instead.
func authenticateByAuthorizationCode(sc *snowflakeConn) (string, error) {
	oauthClient, err := newOauthClient(sc.ctx, sc.cfg, sc)
	if err != nil {
		return "", err
	}
	// Without coordination, just run the full flow directly.
	if !isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientStoreTemporaryCredential) {
		return oauthClient.authenticateByOAuthAuthorizationCode()
	}
	lockKey := newOAuthAuthorizationCodeLockKey(oauthClient.tokenURL(), sc.cfg.User)
	valueAwaiter := valueAwaitHolder.get(lockKey)
	// Let the next waiting goroutine proceed when this one finishes.
	defer valueAwaiter.resumeOne()
	// Wait for a cached token; a non-empty value means another goroutine
	// already completed the flow.
	token, err := awaitValue(valueAwaiter, func() (string, error) {
		return credentialsStorage.getCredential(newOAuthAccessTokenSpec(oauthClient.tokenURL(), sc.cfg.User)), nil
	}, func(s string, err error) bool {
		return s != ""
	}, func() string {
		return ""
	})
	if err != nil || token != "" {
		return token, err
	}
	// No cached token: this goroutine runs the interactive flow itself.
	token, err = oauthClient.authenticateByOAuthAuthorizationCode()
	if err != nil {
		return "", err
	}
	// Signal waiters that a token is now available.
	valueAwaiter.done()
	return token, err
}
// prepareJWTToken generates the signed JWT used for keypair authentication.
// The issuer embeds a SHA-256 fingerprint of the public key, the subject is
// "<ACCOUNT>.<USER>", and expiry is iat + config.JWTExpireTimeout. Signed
// with RS256 using config.PrivateKey.
func prepareJWTToken(config *Config) (string, error) {
	if config.PrivateKey == nil {
		return "", errors.New("trying to use keypair authentication, but PrivateKey was not provided in the driver config")
	}
	logger.Debug("preparing JWT for keypair authentication")
	pubBytes, err := x509.MarshalPKIXPublicKey(config.PrivateKey.Public())
	if err != nil {
		return "", err
	}
	// Fingerprint of the public key, matching what Snowflake has registered.
	hash := sha256.Sum256(pubBytes)
	accountName := sfconfig.ExtractAccountName(config.Account)
	userName := strings.ToUpper(config.User)
	issueAtTime := time.Now().UTC()
	jwtClaims := jwt.MapClaims{
		"iss": fmt.Sprintf("%s.%s.%s", accountName, userName, "SHA256:"+base64.StdEncoding.EncodeToString(hash[:])),
		"sub": fmt.Sprintf("%s.%s", accountName, userName),
		"iat": issueAtTime.Unix(),
		// nbf is a fixed past date, so the token is always already valid
		// regardless of clock skew.
		"nbf": time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix(),
		"exp": issueAtTime.Add(config.JWTExpireTimeout).Unix(),
	}
	token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwtClaims)
	tokenString, err := token.SignedString(config.PrivateKey)
	if err != nil {
		return "", err
	}
	logger.Debugf("successfully generated JWT with following claims: %v", jwtClaims)
	return tokenString, err
}
// tokenLockKey identifies the lock used to coordinate concurrent logins that
// share a cached token for one (host, user, token type) combination.
type tokenLockKey struct {
	snowflakeHost string
	user          string
	tokenType     string
}

// newMfaTokenLockKey returns the key for the cached MFA token of host/user.
func newMfaTokenLockKey(snowflakeHost, user string) *tokenLockKey {
	return &tokenLockKey{
		snowflakeHost: snowflakeHost,
		user:          user,
		tokenType:     "MFA",
	}
}

// newIDTokenLockKey returns the key for the cached ID token of host/user.
func newIDTokenLockKey(snowflakeHost, user string) *tokenLockKey {
	return &tokenLockKey{
		snowflakeHost: snowflakeHost,
		user:          user,
		tokenType:     "ID",
	}
}

// lockID renders the key as a single pipe-separated identifier.
func (m *tokenLockKey) lockID() string {
	return strings.Join([]string{m.snowflakeHost, m.user, m.tokenType}, "|")
}
// authenticateWithConfig is the top-level login driver. It prepares
// authenticator-specific state (cached ID/MFA tokens, parallel-login
// awaiters, console-login default), runs the external-browser flow when
// needed, calls authenticate, and on certain OAuth errors clears the cached
// access token, attempts a refresh, and retries once. On success it applies
// the returned session parameters and stores the session ID on the context.
func authenticateWithConfig(sc *snowflakeConn) error {
	var authData *authResponseMain
	var samlResponse []byte
	var proofKey []byte
	var err error
	mfaTokenLockKey := newMfaTokenLockKey(sc.cfg.Host, sc.cfg.User)
	idTokenLockKey := newIDTokenLockKey(sc.cfg.Host, sc.cfg.User)
	if sc.cfg.Authenticator == AuthTypeExternalBrowser || sc.cfg.Authenticator == AuthTypeOAuthAuthorizationCode || sc.cfg.Authenticator == AuthTypeOAuthClientCredentials {
		// Credential caching defaults to on where a secure OS store exists
		// (Windows/macOS), unless explicitly configured.
		if (runtime.GOOS == "windows" || runtime.GOOS == "darwin") && sc.cfg.ClientStoreTemporaryCredential == sfconfig.BoolNotSet {
			sc.cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
		}
		if sc.cfg.Authenticator == AuthTypeExternalBrowser {
			if isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientStoreTemporaryCredential) {
				// Coordinate with concurrent logins: wait for a cached ID
				// token another goroutine may be in the middle of obtaining.
				valueAwaiter := valueAwaitHolder.get(idTokenLockKey)
				defer valueAwaiter.resumeOne()
				sc.idToken, _ = awaitValue(valueAwaiter, func() (string, error) {
					credential := credentialsStorage.getCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
					return credential, nil
				}, func(s string, err error) bool {
					return s != ""
				}, func() string {
					return ""
				})
			} else if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
				sc.idToken = credentialsStorage.getCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
			}
		}
		// Disable console login by default
		if sc.cfg.DisableConsoleLogin == sfconfig.BoolNotSet {
			sc.cfg.DisableConsoleLogin = ConfigBoolTrue
		}
	}
	if sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA {
		if (runtime.GOOS == "windows" || runtime.GOOS == "darwin") && sc.cfg.ClientRequestMfaToken == sfconfig.BoolNotSet {
			sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
		}
		if isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientRequestMfaToken) {
			// Same coordination as above, but for the cached MFA token.
			valueAwaiter := valueAwaitHolder.get(mfaTokenLockKey)
			defer valueAwaiter.resumeOne()
			sc.mfaToken, _ = awaitValue(valueAwaiter, func() (string, error) {
				credential := credentialsStorage.getCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
				return credential, nil
			}, func(s string, err error) bool {
				return s != ""
			}, func() string {
				return ""
			})
		} else if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue {
			sc.mfaToken = credentialsStorage.getCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
		}
	}
	logger.WithContext(sc.ctx).Infof("Authenticating via %v", sc.cfg.Authenticator.String())
	switch sc.cfg.Authenticator {
	case AuthTypeExternalBrowser:
		// Only open the browser when no cached ID token is available.
		if sc.idToken == "" {
			samlResponse, proofKey, err = authenticateByExternalBrowser(
				sc.ctx,
				sc.rest,
				sc.cfg.Authenticator.String(),
				sc.cfg.Application,
				sc.cfg.Account,
				sc.cfg.User,
				sc.cfg.ExternalBrowserTimeout,
				sc.cfg.DisableConsoleLogin)
			if err != nil {
				sc.cleanup()
				return err
			}
		}
	}
	authData, err = authenticate(
		sc.ctx,
		sc,
		samlResponse,
		proofKey)
	if err != nil {
		var se *SnowflakeError
		if errors.As(err, &se) && slices.Contains(refreshOAuthTokenErrorCodes, strconv.Itoa(se.Number)) {
			// The cached OAuth access token is stale: drop it and try to
			// refresh before retrying the login.
			credentialsStorage.deleteCredential(newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
			if sc.cfg.Authenticator == AuthTypeOAuthAuthorizationCode {
				doRefreshTokenWithLock(sc)
			}
			// if refreshing succeeds for authorization code, we will take a token from cache
			// if it fails, we will just run the full flow
			authData, err = authenticate(sc.ctx, sc, nil, nil)
		}
		if err != nil {
			sc.cleanup()
			return err
		}
	}
	// Release any goroutines waiting on the parallel-login awaiters now that
	// a token has been obtained (or the flow completed).
	if sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA && isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientRequestMfaToken) {
		valueAwaiter := valueAwaitHolder.get(mfaTokenLockKey)
		valueAwaiter.done()
	}
	if sc.cfg.Authenticator == AuthTypeExternalBrowser && isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientStoreTemporaryCredential) {
		valueAwaiter := valueAwaitHolder.get(idTokenLockKey)
		valueAwaiter.done()
	}
	sc.populateSessionParameters(authData.Parameters)
	sc.configureTelemetry()
	sc.ctx = context.WithValue(sc.ctx, SFSessionIDKey, authData.SessionID)
	return nil
}
// doRefreshTokenWithLock attempts an OAuth token refresh while holding the
// per-(endpoint, user) refresh lock so concurrent connections refresh only
// once. Failures are logged, not returned: on a failed refresh the cached
// refresh token is deleted and the caller falls back to the full flow.
func doRefreshTokenWithLock(sc *snowflakeConn) {
	if oauthClient, err := newOauthClient(sc.ctx, sc.cfg, sc); err != nil {
		logger.Warnf("failed to create oauth client. %v", err)
	} else {
		lockKey := newRefreshTokenLockKey(oauthClient.tokenURL(), sc.cfg.User)
		if _, err = getValueWithLock(chooseLockerForAuth(sc.cfg), lockKey, func() (string, error) {
			if err = oauthClient.refreshToken(); err != nil {
				logger.Warnf("cannot refresh token. %v", err)
				// A refresh token that no longer works is useless; drop it.
				credentialsStorage.deleteCredential(newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
				return "", err
			}
			return "", nil
		}); err != nil {
			logger.Warnf("failed to refresh token with lock. %v", err)
		}
	}
}
// chooseLockerForAuth selects the locker used to serialize authentication.
// When the single-prompt feature is disabled, or there is no user to key the
// lock on, locking is skipped via the no-op locker.
func chooseLockerForAuth(cfg *Config) locker {
	if cfg.SingleAuthenticationPrompt != ConfigBoolFalse && cfg.User != "" {
		return exclusiveLocker
	}
	return noopLocker
}
// isEligibleForParallelLogin reports whether concurrent logins may be
// coordinated through the shared awaiter: single-prompt mode must not be
// disabled, a user must be configured, and credential caching must be on.
func isEligibleForParallelLogin(cfg *Config, cacheEnabled ConfigBool) bool {
	if cfg.SingleAuthenticationPrompt == ConfigBoolFalse {
		return false
	}
	if cfg.User == "" {
		return false
	}
	return cacheEnabled == ConfigBoolTrue
}
================================================
FILE: auth_generic_test_methods_test.go
================================================
package gosnowflake
import (
"fmt"
"os"
"testing"
)
// getAuthTestConfigFromEnv builds the auth-test connection config from
// environment variables; account, Okta user, and Okta password are mandatory.
func getAuthTestConfigFromEnv() (*Config, error) {
	params := []*ConfigParam{
		{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true},
		{Name: "User", EnvName: "SNOWFLAKE_AUTH_TEST_OKTA_USER", FailOnMissing: true},
		{Name: "Password", EnvName: "SNOWFLAKE_AUTH_TEST_OKTA_PASS", FailOnMissing: true},
		{Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false},
		{Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false},
		{Name: "Protocol", EnvName: "SNOWFLAKE_AUTH_TEST_PROTOCOL", FailOnMissing: false},
		{Name: "Role", EnvName: "SNOWFLAKE_TEST_ROLE", FailOnMissing: false},
		{Name: "Warehouse", EnvName: "SNOWFLAKE_TEST_WAREHOUSE", FailOnMissing: false},
	}
	return GetConfigFromEnv(params)
}
// getAuthTestsConfig returns the env-based test config with the requested
// authenticator applied; the test is failed if the env config is incomplete.
func getAuthTestsConfig(t *testing.T, authMethod AuthType) (*Config, error) {
	cfg, err := getAuthTestConfigFromEnv()
	msg := fmt.Sprintf("failed to get config: %v", err)
	assertNilF(t, err, msg)
	cfg.Authenticator = authMethod
	return cfg, nil
}
// isTestRunningInDockerContainer reports whether the auth tests run inside
// the dockerized environment (AUTHENTICATION_TESTS_ENV=docker).
func isTestRunningInDockerContainer() bool {
	env := os.Getenv("AUTHENTICATION_TESTS_ENV")
	return env == "docker"
}
================================================
FILE: auth_oauth.go
================================================
package gosnowflake
import (
"bufio"
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"html"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)
const (
	// oauthSuccessHTML is the page served to the user's browser after the
	// OAuth redirect completes successfully.
	oauthSuccessHTML = `
OAuth for Snowflake
OAuth authentication completed successfully.
`
	// localApplicationClientCredentials is the client name reported for the
	// non-interactive client-credentials flow.
	localApplicationClientCredentials = "LOCAL_APPLICATION"
)

// defaultAuthorizationCodeProviderFactory creates the browser-based provider
// used for the authorization-code flow; tests can substitute their own factory.
var defaultAuthorizationCodeProviderFactory = func() authorizationCodeProvider {
	return &browserBasedAuthorizationCodeProvider{}
}
// oauthClient carries everything needed to run the OAuth flows: the config,
// an HTTP client with the driver's transport, and the local redirect
// endpoint description (port 0 means "pick an ephemeral port").
type oauthClient struct {
	ctx context.Context
	cfg *Config

	client *http.Client

	// port is parsed from cfg.OauthRedirectURI when one is configured.
	port int
	// redirectURITemplate is used to build a loopback redirect URI when no
	// explicit OauthRedirectURI is configured.
	redirectURITemplate string

	authorizationCodeProviderFactory func() authorizationCodeProvider
}
// newOauthClient builds an oauthClient for the given connection. When an
// explicit OauthRedirectURI is configured, its port (if present) is extracted
// so the local listener binds to it; otherwise a loopback redirect URI with an
// ephemeral port is used. The returned client's context carries an HTTP client
// built from the driver's transport so x/oauth2 requests honor it.
func newOauthClient(ctx context.Context, cfg *Config, sc *snowflakeConn) (*oauthClient, error) {
	port := 0
	if cfg.OauthRedirectURI != "" {
		logger.Debugf("Using oauthRedirectUri from config: %v", cfg.OauthRedirectURI)
		uri, err := url.Parse(cfg.OauthRedirectURI)
		if err != nil {
			return nil, err
		}
		portStr := uri.Port()
		if portStr != "" {
			if port, err = strconv.Atoi(portStr); err != nil {
				return nil, err
			}
		}
	}
	redirectURITemplate := ""
	if cfg.OauthRedirectURI == "" {
		// No explicit redirect URI: redirect to loopback on the bound port.
		redirectURITemplate = "http://127.0.0.1:%v"
	}
	logger.Debugf("Redirect URI template: %v, port: %v", redirectURITemplate, port)
	transport, err := newTransportFactory(cfg, sc.telemetry).createTransport(transportConfigFor(transportTypeOAuth))
	if err != nil {
		return nil, err
	}
	client := &http.Client{
		Transport: transport,
	}
	return &oauthClient{
		// x/oauth2 picks this client up via the oauth2.HTTPClient context key.
		ctx:                              context.WithValue(ctx, oauth2.HTTPClient, client),
		cfg:                              cfg,
		client:                           client,
		port:                             port,
		redirectURITemplate:              redirectURITemplate,
		authorizationCodeProviderFactory: defaultAuthorizationCodeProviderFactory,
	}, nil
}
// oauthBrowserResult is the outcome of a browser-based authorization-code
// flow: the obtained tokens, or the error that aborted the flow.
type oauthBrowserResult struct {
	accessToken  string // access token; empty on failure
	refreshToken string // refresh token; may be empty if the IdP issued none
	err          error  // non-nil when the flow failed
}
// authenticateByOAuthAuthorizationCode returns an OAuth access token via the
// authorization-code flow. With credential caching enabled, a cached access
// token is returned directly; if only a refresh token is cached, a dedicated
// error is reported so the caller can run the refresh flow instead. Otherwise
// a local TCP listener is opened for the IdP redirect and the browser-based
// flow runs in a goroutine, bounded by cfg.ExternalBrowserTimeout.
func (oauthClient *oauthClient) authenticateByOAuthAuthorizationCode() (string, error) {
	accessTokenSpec := oauthClient.accessTokenSpec()
	if oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
		if accessToken := credentialsStorage.getCredential(accessTokenSpec); accessToken != "" {
			logger.Debugf("Access token retrieved from cache")
			return accessToken, nil
		}
		if refreshToken := credentialsStorage.getCredential(oauthClient.refreshTokenSpec()); refreshToken != "" {
			return "", &SnowflakeError{Number: ErrMissingAccessATokenButRefreshTokenPresent}
		}
	}
	logger.Debugf("Access token not present in cache, running full auth code flow")
	resultChan := make(chan oauthBrowserResult, 1)
	tcpListener, callbackPort, err := oauthClient.setupListener()
	if err != nil {
		return "", err
	}
	defer func() {
		logger.Debug("Closing tcp listener")
		if err := tcpListener.Close(); err != nil {
			logger.Warnf("error while closing TCP listener. %v", err)
		}
	}()
	go GoroutineWrapper(oauthClient.ctx, func() {
		resultChan <- oauthClient.doAuthenticateByOAuthAuthorizationCode(tcpListener, callbackPort)
	})
	select {
	case <-time.After(oauthClient.cfg.ExternalBrowserTimeout):
		return "", errors.New("authentication via browser timed out")
	case result := <-resultChan:
		// Cache only successful results: storing the zero-value tokens of a
		// failed flow would overwrite the cache with empty credentials.
		if result.err == nil && oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
			logger.Debug("saving oauth access token in cache")
			credentialsStorage.setCredential(oauthClient.accessTokenSpec(), result.accessToken)
			credentialsStorage.setCredential(oauthClient.refreshTokenSpec(), result.refreshToken)
		}
		return result.accessToken, result.err
	}
}
// doAuthenticateByOAuthAuthorizationCode runs one full authorization-code
// exchange: it serves the IdP redirect on tcpListener, launches the browser
// via the configured provider, reads back the callback request and exchanges
// the authorization code for tokens.
//
// Channel protocol with handleOAuthSocket: the handler sends a read error (or
// nil) on errChan, then the raw request bytes on successChan, then blocks on
// responseBodyChan for the page to render in the browser, and finally signals
// closeListenerChan. The deferred sends/closes keep the handler from blocking
// forever on early-error paths.
func (oauthClient *oauthClient) doAuthenticateByOAuthAuthorizationCode(tcpListener *net.TCPListener, callbackPort int) oauthBrowserResult {
	authCodeProvider := oauthClient.authorizationCodeProviderFactory()
	successChan := make(chan []byte)
	errChan := make(chan error)
	responseBodyChan := make(chan string, 2)
	closeListenerChan := make(chan bool, 2)
	defer func() {
		closeListenerChan <- true
		close(successChan)
		close(errChan)
		close(responseBodyChan)
		close(closeListenerChan)
	}()
	logger.Debugf("opening socket on port %v", callbackPort)
	defer func(tcpListener *net.TCPListener) {
		// Block until the socket handler (or an error path) signals completion
		// before the caller proceeds to close the listener.
		<-closeListenerChan
	}(tcpListener)
	go handleOAuthSocket(tcpListener, successChan, errChan, responseBodyChan, closeListenerChan)
	oauth2cfg := oauthClient.buildAuthorizationCodeConfig(callbackPort)
	// PKCE: the S256 challenge derived from codeVerifier goes into the
	// authorization URL; the verifier itself is sent during token exchange.
	codeVerifier := authCodeProvider.createCodeVerifier()
	state := authCodeProvider.createState()
	authorizationURL := oauth2cfg.AuthCodeURL(state, oauth2.S256ChallengeOption(codeVerifier))
	if err := authCodeProvider.run(authorizationURL); err != nil {
		responseBodyChan <- err.Error()
		closeListenerChan <- true
		return oauthBrowserResult{"", "", err}
	}
	err := <-errChan
	if err != nil {
		responseBodyChan <- err.Error()
		return oauthBrowserResult{"", "", err}
	}
	codeReqBytes := <-successChan
	codeReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(codeReqBytes)))
	if err != nil {
		responseBodyChan <- err.Error()
		return oauthBrowserResult{"", "", err}
	}
	logger.Debugf("Received authorization code from %v", oauthClient.authorizationURL())
	tokenResponse, err := oauthClient.exchangeAccessToken(codeReq, state, oauth2cfg, codeVerifier, responseBodyChan)
	if err != nil {
		return oauthBrowserResult{"", "", err}
	}
	logger.Debugf("Received token from %v", oauthClient.tokenURL())
	return oauthBrowserResult{tokenResponse.AccessToken, tokenResponse.RefreshToken, err}
}
// setupListener opens a local TCP listener on the configured port (or an
// ephemeral one when the port is 0) and reports the port actually bound.
func (oauthClient *oauthClient) setupListener() (*net.TCPListener, int, error) {
	listener, err := createLocalTCPListener(oauthClient.port)
	if err != nil {
		return nil, 0, err
	}
	boundPort := listener.Addr().(*net.TCPAddr).Port
	logger.Debugf("oauthClient.port: %v, callbackPort: %v", oauthClient.port, boundPort)
	return listener, boundPort, nil
}
// exchangeAccessToken validates the authorization-code callback request
// (IdP error, anti-CSRF state) and exchanges the code for tokens at the IdP.
// Every message pushed to responseBodyChan is rendered in the user's browser,
// so all dynamic content is HTML-escaped to avoid reflecting
// attacker-controlled markup.
func (oauthClient *oauthClient) exchangeAccessToken(codeReq *http.Request, state string, oauth2cfg *oauth2.Config, codeVerifier string, responseBodyChan chan string) (*oauth2.Token, error) {
	queryParams := codeReq.URL.Query()
	errorMsg := queryParams.Get("error")
	if errorMsg != "" {
		errorDesc := queryParams.Get("error_description")
		errMsg := fmt.Sprintf("error while getting authentication from oauth: %v. Details: %v", errorMsg, errorDesc)
		responseBodyChan <- html.EscapeString(errMsg)
		return nil, errors.New(errMsg)
	}
	receivedState := queryParams.Get("state")
	if state != receivedState {
		errMsg := "invalid oauth state received"
		responseBodyChan <- errMsg
		return nil, errors.New(errMsg)
	}
	code := queryParams.Get("code")
	// PKCE verifier must accompany the exchange request.
	opts := []oauth2.AuthCodeOption{oauth2.VerifierOption(codeVerifier)}
	if oauthClient.cfg.EnableSingleUseRefreshTokens {
		opts = append(opts, oauth2.SetAuthURLParam("enable_single_use_refresh_tokens", "true"))
	}
	token, err := oauth2cfg.Exchange(oauthClient.ctx, code, opts...)
	if err != nil {
		// The exchange error may embed the raw IdP response body; escape it
		// before echoing into the browser.
		responseBodyChan <- html.EscapeString(err.Error())
		return nil, err
	}
	responseBodyChan <- oauthSuccessHTML
	return token, nil
}
// buildAuthorizationCodeConfig assembles the oauth2.Config for the
// authorization-code flow. When no client id/secret is configured and
// Snowflake itself is the IdP, the built-in LOCAL_APPLICATION credentials are
// substituted. Plain-HTTP endpoints are logged as a warning.
func (oauthClient *oauthClient) buildAuthorizationCodeConfig(callbackPort int) *oauth2.Config {
	clientID, clientSecret := oauthClient.cfg.OauthClientID, oauthClient.cfg.OauthClientSecret
	if oauthClient.eligibleForDefaultClientCredentials() {
		clientID, clientSecret = localApplicationClientCredentials, localApplicationClientCredentials
	}
	oauthClient.logIfHTTPInUse(oauthClient.authorizationURL())
	oauthClient.logIfHTTPInUse(oauthClient.tokenURL())
	return &oauth2.Config{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		RedirectURL:  oauthClient.buildRedirectURI(callbackPort),
		Scopes:       oauthClient.buildScopes(),
		Endpoint: oauth2.Endpoint{
			AuthURL:   oauthClient.authorizationURL(),
			TokenURL:  oauthClient.tokenURL(),
			AuthStyle: oauth2.AuthStyleInHeader,
		},
	}
}
// eligibleForDefaultClientCredentials reports whether the driver may fall back
// to the built-in LOCAL_APPLICATION client id/secret: no explicit client
// credentials are configured and Snowflake itself is the identity provider.
func (oauthClient *oauthClient) eligibleForDefaultClientCredentials() bool {
	noExplicitCredentials := oauthClient.cfg.OauthClientID == "" && oauthClient.cfg.OauthClientSecret == ""
	return noExplicitCredentials && oauthClient.isSnowflakeAsIDP()
}
// isSnowflakeAsIDP reports whether both OAuth endpoints either are unset
// (defaulting to the Snowflake host) or explicitly reference the Snowflake host.
func (oauthClient *oauthClient) isSnowflakeAsIDP() bool {
	pointsAtSnowflake := func(endpoint string) bool {
		return endpoint == "" || strings.Contains(endpoint, oauthClient.cfg.Host)
	}
	return pointsAtSnowflake(oauthClient.cfg.OauthAuthorizationURL) && pointsAtSnowflake(oauthClient.cfg.OauthTokenRequestURL)
}
// authorizationURL returns the configured authorization endpoint, falling back
// to the Snowflake default when none is set.
func (oauthClient *oauthClient) authorizationURL() string {
	if u := oauthClient.cfg.OauthAuthorizationURL; u != "" {
		return u
	}
	return oauthClient.defaultAuthorizationURL()
}
// defaultAuthorizationURL builds the Snowflake-hosted /oauth/authorize
// endpoint from the connection's protocol, host and port.
func (oauthClient *oauthClient) defaultAuthorizationURL() string {
	cfg := oauthClient.cfg
	return fmt.Sprintf("%v://%v:%v/oauth/authorize", cfg.Protocol, cfg.Host, cfg.Port)
}
// tokenURL returns the configured token endpoint, falling back to the
// Snowflake default when none is set.
func (oauthClient *oauthClient) tokenURL() string {
	if u := oauthClient.cfg.OauthTokenRequestURL; u != "" {
		return u
	}
	return oauthClient.defaultTokenURL()
}
// defaultTokenURL builds the Snowflake-hosted /oauth/token-request endpoint
// from the connection's protocol, host and port.
func (oauthClient *oauthClient) defaultTokenURL() string {
	cfg := oauthClient.cfg
	return fmt.Sprintf("%v://%v:%v/oauth/token-request", cfg.Protocol, cfg.Host, cfg.Port)
}
// buildRedirectURI returns the explicitly configured redirect URI or, when
// absent, a loopback URI on the port bound by the local listener.
func (oauthClient *oauthClient) buildRedirectURI(port int) string {
	if uri := oauthClient.cfg.OauthRedirectURI; uri != "" {
		return uri
	}
	return fmt.Sprintf(oauthClient.redirectURITemplate, port)
}
// buildScopes returns the OAuth scopes to request. Without an explicit
// OauthScope configuration the session role scope is requested; otherwise the
// configured scope string is split on whitespace. strings.Fields (instead of
// Split+TrimSpace) also drops the empty entries produced by consecutive
// spaces and handles tab separators.
func (oauthClient *oauthClient) buildScopes() []string {
	if oauthClient.cfg.OauthScope == "" {
		return []string{"session:role:" + oauthClient.cfg.Role}
	}
	return strings.Fields(oauthClient.cfg.OauthScope)
}
// handleOAuthSocket accepts a single connection on tcpListener (the IdP
// redirect to the local callback), reads the raw HTTP request, and coordinates
// with the flow goroutine: it sends nil (or a read error) on errChan, then the
// request bytes on successChan, then blocks on responseBodyChan for the page
// to show the user, writes that page back to the browser, and finally signals
// closeListenerChan so the listener can be shut down.
func handleOAuthSocket(tcpListener *net.TCPListener, successChan chan []byte, errChan chan error, responseBodyChan chan string, closeListenerChan chan bool) {
	conn, err := tcpListener.AcceptTCP()
	if err != nil {
		logger.Warnf("error creating socket. %v", err)
		return
	}
	defer func() {
		if err := conn.Close(); err != nil {
			logger.Warnf("error while closing connection (%v -> %v). %v", conn.LocalAddr(), conn.RemoteAddr(), err)
		}
	}()
	var buf [bufSize]byte
	codeResp := bytes.NewBuffer(nil)
	for {
		readBytes, err := conn.Read(buf[:])
		if err == io.EOF {
			break
		}
		if err != nil {
			errChan <- err
			return
		}
		codeResp.Write(buf[0:readBytes])
		// A short read is taken to mean the request has been fully received.
		if readBytes < bufSize {
			break
		}
	}
	errChan <- nil
	successChan <- codeResp.Bytes()
	// Wait for the flow goroutine to decide what page the user should see.
	responseBody := <-responseBodyChan
	respToBrowser, err := buildResponse(responseBody)
	if err != nil {
		logger.Warnf("cannot create response to browser. %v", err)
	}
	_, err = conn.Write(respToBrowser.Bytes())
	if err != nil {
		logger.Warnf("cannot write response to browser. %v", err)
	}
	closeListenerChan <- true
}
// authorizationCodeProvider abstracts the user-facing part of the
// authorization-code flow so tests can replace the real browser.
type authorizationCodeProvider interface {
	// run opens the authorization URL (e.g. in a browser).
	run(authorizationURL string) error
	// createState returns the opaque anti-CSRF state parameter.
	createState() string
	// createCodeVerifier returns the PKCE code verifier.
	createCodeVerifier() string
}
// browserBasedAuthorizationCodeProvider is the production provider: it opens
// the system browser and generates random state and PKCE values.
type browserBasedAuthorizationCodeProvider struct {
}

// run opens the authorization URL in the default system browser.
func (provider *browserBasedAuthorizationCodeProvider) run(authorizationURL string) error {
	return openBrowser(authorizationURL)
}

// createState returns a random UUID used as the anti-CSRF state parameter.
func (provider *browserBasedAuthorizationCodeProvider) createState() string {
	return NewUUID().String()
}

// createCodeVerifier returns a PKCE verifier generated by x/oauth2.
func (provider *browserBasedAuthorizationCodeProvider) createCodeVerifier() string {
	return oauth2.GenerateVerifier()
}
// authenticateByOAuthClientCredentials returns an access token via the OAuth
// client-credentials grant, consulting the credential cache first (and storing
// the new token afterwards) when caching is enabled.
func (oauthClient *oauthClient) authenticateByOAuthClientCredentials() (string, error) {
	tokenSpec := oauthClient.accessTokenSpec()
	cachingEnabled := oauthClient.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue
	if cachingEnabled {
		if cached := credentialsStorage.getCredential(tokenSpec); cached != "" {
			return cached, nil
		}
	}
	ccCfg, err := oauthClient.buildClientCredentialsConfig()
	if err != nil {
		return "", err
	}
	token, err := ccCfg.Token(oauthClient.ctx)
	if err != nil {
		return "", err
	}
	if cachingEnabled {
		credentialsStorage.setCredential(tokenSpec, token.AccessToken)
	}
	return token.AccessToken, nil
}
// buildClientCredentialsConfig assembles the x/oauth2 client-credentials
// configuration; an explicit token request URL is mandatory for this flow.
func (oauthClient *oauthClient) buildClientCredentialsConfig() (*clientcredentials.Config, error) {
	cfg := oauthClient.cfg
	if cfg.OauthTokenRequestURL == "" {
		return nil, errors.New("client credentials flow requires tokenRequestURL")
	}
	ccCfg := &clientcredentials.Config{
		ClientID:     cfg.OauthClientID,
		ClientSecret: cfg.OauthClientSecret,
		TokenURL:     cfg.OauthTokenRequestURL,
		Scopes:       oauthClient.buildScopes(),
	}
	return ccCfg, nil
}
// refreshToken attempts to obtain a fresh access token using a cached refresh
// token. It is a no-op (returning nil) when credential caching is disabled or
// no refresh token is cached; the caller then runs the full flow. On a non-OK
// response the stale refresh token is evicted from the cache and the response
// body is returned as the error.
func (oauthClient *oauthClient) refreshToken() error {
	if oauthClient.cfg.ClientStoreTemporaryCredential != ConfigBoolTrue {
		logger.Debug("credentials storage is disabled, cannot use refresh tokens")
		return nil
	}
	// Use the same spec (tokenURL with default fallback) under which the
	// authorization-code flow stores refresh tokens, so cache lookups stay
	// consistent even when OauthTokenRequestURL is unset.
	refreshTokenSpec := oauthClient.refreshTokenSpec()
	refreshToken := credentialsStorage.getCredential(refreshTokenSpec)
	if refreshToken == "" {
		logger.Debug("no refresh token in cache, full flow must be run")
		return nil
	}
	body := url.Values{}
	body.Add("grant_type", "refresh_token")
	body.Add("refresh_token", refreshToken)
	body.Add("scope", strings.Join(oauthClient.buildScopes(), " "))
	req, err := http.NewRequest("POST", oauthClient.tokenURL(), strings.NewReader(body.Encode()))
	if err != nil {
		return err
	}
	req.SetBasicAuth(oauthClient.cfg.OauthClientID, oauthClient.cfg.OauthClientSecret)
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
	resp, err := oauthClient.client.Do(req)
	if err != nil {
		return err
	}
	defer func() {
		if err := resp.Body.Close(); err != nil {
			logger.Warnf("error while closing response body for %v. %v", req.URL, err)
		}
	}()
	if resp.StatusCode != http.StatusOK {
		respBody, err := io.ReadAll(resp.Body)
		if err != nil {
			return err
		}
		// The cached refresh token was rejected; drop it so the next attempt
		// runs the full flow instead of retrying a dead token.
		credentialsStorage.deleteCredential(refreshTokenSpec)
		return errors.New(string(respBody))
	}
	var tokenResponse tokenExchangeResponseBody
	if err = json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
		return err
	}
	accessTokenSpec := oauthClient.accessTokenSpec()
	credentialsStorage.setCredential(accessTokenSpec, tokenResponse.AccessToken)
	// The IdP may rotate refresh tokens; keep the old one when none is issued.
	if tokenResponse.RefreshToken != "" {
		credentialsStorage.setCredential(refreshTokenSpec, tokenResponse.RefreshToken)
	}
	return nil
}
// tokenExchangeResponseBody is the subset of the IdP token-endpoint JSON
// response consumed by the refresh-token flow.
type tokenExchangeResponseBody struct {
	AccessToken  string `json:"access_token,omitempty"`
	RefreshToken string `json:"refresh_token"`
}
// accessTokenSpec identifies the cache slot holding this user's OAuth access
// token for the effective token endpoint.
func (oauthClient *oauthClient) accessTokenSpec() *secureTokenSpec {
	spec := newOAuthAccessTokenSpec(oauthClient.tokenURL(), oauthClient.cfg.User)
	return spec
}
// refreshTokenSpec identifies the cache slot holding this user's OAuth refresh
// token for the effective token endpoint.
func (oauthClient *oauthClient) refreshTokenSpec() *secureTokenSpec {
	spec := newOAuthRefreshTokenSpec(oauthClient.tokenURL(), oauthClient.cfg.User)
	return spec
}
// logIfHTTPInUse emits a warning when the given OAuth endpoint uses plain
// HTTP, which would expose tokens in transit.
func (oauthClient *oauthClient) logIfHTTPInUse(u string) {
	parsed, err := url.Parse(u)
	switch {
	case err != nil:
		logger.Warnf("Cannot parse URL: %v. %v", u, err)
	case parsed.Scheme == "http":
		logger.Warnf("OAuth URL uses insecure HTTP protocol: %v", u)
	}
}
================================================
FILE: auth_oauth_test.go
================================================
package gosnowflake
import (
"context"
"database/sql"
"errors"
sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config"
"io"
"net/http"
"sync"
"testing"
"time"
"golang.org/x/oauth2"
)
// TestUnitOAuthAuthorizationCode exercises the authorization-code flow against
// wiremock using a non-interactive provider instead of a real browser:
// success, credential caching, state tampering, IdP/provider errors, invalid
// codes and browser timeout.
func TestUnitOAuthAuthorizationCode(t *testing.T) {
	skipOnMac(t, "keychain requires password")
	roundTripper := newCountingRoundTripper(createTestNoRevocationTransport())
	httpClient := &http.Client{
		Transport: roundTripper,
	}
	cfg := &Config{
		User:                           "testUser",
		Role:                           "ANALYST",
		OauthClientID:                  "testClientId",
		OauthClientSecret:              "testClientSecret",
		OauthAuthorizationURL:          wiremock.baseURL() + "/oauth/authorize",
		OauthTokenRequestURL:           wiremock.baseURL() + "/oauth/token",
		OauthRedirectURI:               "http://localhost:1234/snowflake/oauth-redirect",
		Transporter:                    roundTripper,
		ClientStoreTemporaryCredential: ConfigBoolTrue,
		ExternalBrowserTimeout:         time.Duration(sfconfig.DefaultExternalBrowserTimeout),
	}
	client, err := newOauthClient(context.WithValue(context.Background(), oauth2.HTTPClient, httpClient), cfg, &snowflakeConn{})
	assertNilF(t, err)
	accessTokenSpec := newOAuthAccessTokenSpec(wiremock.connectionConfig().OauthTokenRequestURL, wiremock.connectionConfig().User)
	refreshTokenSpec := newOAuthRefreshTokenSpec(wiremock.connectionConfig().OauthTokenRequestURL, wiremock.connectionConfig().User)
	t.Run("Success", func(t *testing.T) {
		credentialsStorage.deleteCredential(accessTokenSpec)
		credentialsStorage.deleteCredential(refreshTokenSpec)
		wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"))
		authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t}
		client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
			return authCodeProvider
		}
		token, err := client.authenticateByOAuthAuthorizationCode()
		assertNilF(t, err)
		assertEqualE(t, token, "access-token-123")
		// Give the socket goroutine time to write the browser response.
		time.Sleep(100 * time.Millisecond)
		authCodeProvider.assertResponseBodyContains("OAuth authentication completed successfully.")
	})
	t.Run("Store access token in cache", func(t *testing.T) {
		skipOnMissingHome(t)
		roundTripper.reset()
		credentialsStorage.deleteCredential(accessTokenSpec)
		credentialsStorage.deleteCredential(refreshTokenSpec)
		wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"))
		authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t}
		client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
			return authCodeProvider
		}
		_, err = client.authenticateByOAuthAuthorizationCode()
		assertNilF(t, err)
		assertEqualE(t, credentialsStorage.getCredential(accessTokenSpec), "access-token-123")
	})
	t.Run("Use cache for consecutive calls", func(t *testing.T) {
		skipOnMissingHome(t)
		roundTripper.reset()
		credentialsStorage.setCredential(accessTokenSpec, "access-token-123")
		wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"))
		authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t}
		for range 3 {
			client, err := newOauthClient(context.WithValue(context.Background(), oauth2.HTTPClient, httpClient), cfg, &snowflakeConn{})
			assertNilF(t, err)
			client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
				return authCodeProvider
			}
			_, err = client.authenticateByOAuthAuthorizationCode()
			assertNilF(t, err)
		}
		// No browser interaction and no token-endpoint POSTs are expected
		// when every call is served from the cache.
		assertEqualE(t, authCodeProvider.responseBody, "")
		assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 0)
	})
	t.Run("InvalidState", func(t *testing.T) {
		credentialsStorage.deleteCredential(accessTokenSpec)
		credentialsStorage.deleteCredential(refreshTokenSpec)
		wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"))
		authCodeProvider := &nonInteractiveAuthorizationCodeProvider{
			tamperWithState: true,
			t:               t,
		}
		client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
			return authCodeProvider
		}
		_, err = client.authenticateByOAuthAuthorizationCode()
		assertEqualE(t, err.Error(), "invalid oauth state received")
		time.Sleep(100 * time.Millisecond)
		authCodeProvider.assertResponseBodyContains("invalid oauth state received")
	})
	t.Run("ErrorFromIdPWhileGettingCode", func(t *testing.T) {
		credentialsStorage.deleteCredential(accessTokenSpec)
		credentialsStorage.deleteCredential(refreshTokenSpec)
		wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/error_from_idp.json"))
		authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t}
		client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
			return authCodeProvider
		}
		_, err = client.authenticateByOAuthAuthorizationCode()
		assertEqualE(t, err.Error(), "error while getting authentication from oauth: some error. Details: some error desc")
		time.Sleep(100 * time.Millisecond)
		authCodeProvider.assertResponseBodyContains("error while getting authentication from oauth: some error. Details: some error desc")
	})
	t.Run("ErrorFromProviderWhileGettingCode", func(t *testing.T) {
		authCodeProvider := &nonInteractiveAuthorizationCodeProvider{
			triggerError: "test error",
		}
		client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
			return authCodeProvider
		}
		_, err = client.authenticateByOAuthAuthorizationCode()
		assertEqualE(t, err.Error(), "test error")
	})
	t.Run("InvalidCode", func(t *testing.T) {
		credentialsStorage.deleteCredential(accessTokenSpec)
		credentialsStorage.deleteCredential(refreshTokenSpec)
		wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/invalid_code.json"))
		authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t}
		client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
			return authCodeProvider
		}
		_, err = client.authenticateByOAuthAuthorizationCode()
		assertNotNilE(t, err)
		assertEqualE(t, err.(*oauth2.RetrieveError).ErrorCode, "invalid_grant")
		assertEqualE(t, err.(*oauth2.RetrieveError).ErrorDescription, "The authorization code is invalid or has expired.")
		time.Sleep(100 * time.Millisecond)
		authCodeProvider.assertResponseBodyContains("invalid_grant")
	})
	t.Run("timeout", func(t *testing.T) {
		credentialsStorage.deleteCredential(accessTokenSpec)
		credentialsStorage.deleteCredential(refreshTokenSpec)
		wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"))
		// Provider sleeps longer than the browser timeout to force the
		// timeout branch of the select.
		client.cfg.ExternalBrowserTimeout = 2 * time.Second
		authCodeProvider := &nonInteractiveAuthorizationCodeProvider{
			sleepTime:    3 * time.Second,
			triggerError: "timed out",
			t:            t,
		}
		client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
			return authCodeProvider
		}
		_, err = client.authenticateByOAuthAuthorizationCode()
		assertNotNilE(t, err)
		assertStringContainsE(t, err.Error(), "timed out")
		time.Sleep(2 * time.Second) // awaiting timeout
	})
}
// TestUnitOAuthClientCredentials exercises the client-credentials flow against
// wiremock: success, cache population and reuse, cache disablement and IdP
// error handling.
func TestUnitOAuthClientCredentials(t *testing.T) {
	skipOnMac(t, "keychain requires password")
	cacheTokenSpec := newOAuthAccessTokenSpec(wiremock.connectionConfig().OauthTokenRequestURL, wiremock.connectionConfig().User)
	crt := newCountingRoundTripper(createTestNoRevocationTransport())
	// NOTE(review): http.Client is stored by value here, so the
	// oauth2.HTTPClient context lookup (which asserts *http.Client) will not
	// match it; the counting transport is still wired via cfg.Transporter —
	// confirm this is intended.
	httpClient := http.Client{
		Transport: crt,
	}
	cfgFactory := func() *Config {
		return &Config{
			User:                           "testUser",
			Role:                           "ANALYST",
			OauthClientID:                  "testClientId",
			OauthClientSecret:              "testClientSecret",
			OauthTokenRequestURL:           wiremock.baseURL() + "/oauth/token",
			Transporter:                    crt,
			ClientStoreTemporaryCredential: ConfigBoolTrue,
		}
	}
	client, err := newOauthClient(context.WithValue(context.Background(), oauth2.HTTPClient, httpClient), cfgFactory(), &snowflakeConn{})
	assertNilF(t, err)
	t.Run("success", func(t *testing.T) {
		credentialsStorage.deleteCredential(cacheTokenSpec)
		wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json"))
		token, err := client.authenticateByOAuthClientCredentials()
		assertNilF(t, err)
		assertEqualE(t, token, "access-token-123")
	})
	t.Run("should store token in cache", func(t *testing.T) {
		skipOnMissingHome(t)
		crt.reset()
		credentialsStorage.deleteCredential(cacheTokenSpec)
		wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json"))
		token, err := client.authenticateByOAuthClientCredentials()
		assertNilF(t, err)
		assertEqualE(t, token, "access-token-123")
		client, err := newOauthClient(context.Background(), cfgFactory(), &snowflakeConn{})
		assertNilF(t, err)
		token, err = client.authenticateByOAuthClientCredentials()
		assertNilF(t, err)
		assertEqualE(t, token, "access-token-123")
		// Second client should have hit the cache: only one POST to the IdP.
		assertEqualE(t, crt.postReqCount[cfgFactory().OauthTokenRequestURL], 1)
	})
	t.Run("consecutive calls should take token from cache", func(t *testing.T) {
		skipOnMissingHome(t)
		crt.reset()
		credentialsStorage.setCredential(cacheTokenSpec, "access-token-123")
		for range 3 {
			client, err := newOauthClient(context.Background(), cfgFactory(), &snowflakeConn{})
			assertNilF(t, err)
			token, err := client.authenticateByOAuthClientCredentials()
			assertNilF(t, err)
			assertEqualE(t, token, "access-token-123")
		}
		assertEqualE(t, crt.postReqCount[cfgFactory().OauthTokenRequestURL], 0)
	})
	t.Run("disabling cache", func(t *testing.T) {
		skipOnMissingHome(t)
		cfg := cfgFactory()
		cfg.ClientStoreTemporaryCredential = ConfigBoolFalse
		credentialsStorage.deleteCredential(cacheTokenSpec)
		wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json"))
		client, err := newOauthClient(context.Background(), cfg, &snowflakeConn{})
		assertNilF(t, err)
		token, err := client.authenticateByOAuthClientCredentials()
		assertNilF(t, err)
		assertEqualE(t, token, "access-token-123")
		client, err = newOauthClient(context.Background(), cfg, &snowflakeConn{})
		assertNilF(t, err)
		token, err = client.authenticateByOAuthClientCredentials()
		assertNilF(t, err)
		assertEqualE(t, token, "access-token-123")
		// With caching disabled every client must contact the IdP.
		assertEqualE(t, crt.postReqCount[cfg.OauthTokenRequestURL], 2)
	})
	t.Run("invalid_client", func(t *testing.T) {
		credentialsStorage.deleteCredential(cacheTokenSpec)
		wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/client_credentials/invalid_client.json"))
		_, err = client.authenticateByOAuthClientCredentials()
		assertNotNilF(t, err)
		oauth2Err := err.(*oauth2.RetrieveError)
		assertEqualE(t, oauth2Err.ErrorCode, "invalid_client")
		assertEqualE(t, oauth2Err.ErrorDescription, "The client secret supplied for a confidential client is invalid.")
	})
}
func TestAuthorizationCodeFlow(t *testing.T) {
if runningOnGithubAction() && runningOnLinux() {
t.Skip("Github blocks writing to file system")
}
skipOnMac(t, "keychain requires password")
currentDefaultAuthorizationCodeProviderFactory := defaultAuthorizationCodeProviderFactory
defer func() {
defaultAuthorizationCodeProviderFactory = currentDefaultAuthorizationCodeProviderFactory
}()
defaultAuthorizationCodeProviderFactory = func() authorizationCodeProvider {
return &nonInteractiveAuthorizationCodeProvider{
t: t,
mu: sync.Mutex{},
}
}
roundTripper := newCountingRoundTripper(createTestNoRevocationTransport())
t.Run("successful flow", func(t *testing.T) {
wiremock.registerMappings(t,
newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"),
newWiremockMapping("auth/oauth2/login_request.json"),
newWiremockMapping("select1.json"))
cfg := wiremock.connectionConfig()
cfg.Role = "ANALYST"
cfg.Authenticator = AuthTypeOAuthAuthorizationCode
cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect"
cfg.Transporter = roundTripper
oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
credentialsStorage.deleteCredential(oauthAccessTokenSpec)
credentialsStorage.deleteCredential(oauthRefreshTokenSpec)
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
runSmokeQuery(t, db)
})
t.Run("successful flow with multiple threads", func(t *testing.T) {
for _, singleAuthenticationPrompt := range []ConfigBool{ConfigBoolFalse, ConfigBoolTrue, configBoolNotSet} {
t.Run("singleAuthenticationPrompt="+singleAuthenticationPrompt.String(), func(t *testing.T) {
currentDefaultAuthorizationCodeProviderFactory := defaultAuthorizationCodeProviderFactory
defer func() {
defaultAuthorizationCodeProviderFactory = currentDefaultAuthorizationCodeProviderFactory
}()
defaultAuthorizationCodeProviderFactory = func() authorizationCodeProvider {
return &nonInteractiveAuthorizationCodeProvider{
t: t,
mu: sync.Mutex{},
sleepTime: 500 * time.Millisecond,
}
}
roundTripper.reset()
wiremock.registerMappings(t,
newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"),
newWiremockMapping("auth/oauth2/login_request.json"),
newWiremockMapping("select1.json"),
newWiremockMapping("close_session.json"))
cfg := wiremock.connectionConfig()
cfg.Role = "ANALYST"
cfg.Authenticator = AuthTypeOAuthAuthorizationCode
cfg.Transporter = roundTripper
cfg.SingleAuthenticationPrompt = singleAuthenticationPrompt
oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
credentialsStorage.deleteCredential(oauthAccessTokenSpec)
credentialsStorage.deleteCredential(oauthRefreshTokenSpec)
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
initPoolWithSize(t, db, 20)
println(roundTripper.postReqCount[cfg.OauthTokenRequestURL])
if singleAuthenticationPrompt == ConfigBoolFalse {
assertTrueE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL] > 1)
} else {
assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1)
}
})
}
})
t.Run("successful flow with single-use refresh token enabled", func(t *testing.T) {
wiremock.registerMappings(t,
newWiremockMapping("auth/oauth2/authorization_code/successful_flow_with_single_use_refresh_token.json"),
newWiremockMapping("auth/oauth2/login_request.json"),
newWiremockMapping("select1.json"))
cfg := wiremock.connectionConfig()
cfg.Role = "ANALYST"
cfg.Authenticator = AuthTypeOAuthAuthorizationCode
cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect"
cfg.Transporter = roundTripper
cfg.EnableSingleUseRefreshTokens = true
oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
credentialsStorage.deleteCredential(oauthAccessTokenSpec)
credentialsStorage.deleteCredential(oauthRefreshTokenSpec)
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
runSmokeQuery(t, db)
})
t.Run("should use cached access token", func(t *testing.T) {
roundTripper.reset()
wiremock.registerMappings(t,
newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"),
newWiremockMapping("auth/oauth2/login_request.json"),
newWiremockMapping("select1.json"))
cfg := wiremock.connectionConfig()
cfg.Role = "ANALYST"
cfg.Authenticator = AuthTypeOAuthAuthorizationCode
cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect"
cfg.Transporter = roundTripper
oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
credentialsStorage.deleteCredential(oauthAccessTokenSpec)
credentialsStorage.deleteCredential(oauthRefreshTokenSpec)
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
conn1, err := db.Conn(context.Background())
assertNilF(t, err)
defer conn1.Close()
conn2, err := db.Conn(context.Background())
assertNilF(t, err)
defer conn2.Close()
runSmokeQueryWithConn(t, conn1)
runSmokeQueryWithConn(t, conn2)
assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1)
})
t.Run("should update cache with new token when the old one expired if refresh token is missing", func(t *testing.T) {
roundTripper.reset()
wiremock.registerMappings(t,
newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"),
newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"),
newWiremockMapping("auth/oauth2/login_request.json"),
newWiremockMapping("select1.json"))
cfg := wiremock.connectionConfig()
cfg.Role = "ANALYST"
cfg.Authenticator = AuthTypeOAuthAuthorizationCode
cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect"
cfg.Transporter = roundTripper
oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
credentialsStorage.setCredential(oauthAccessTokenSpec, "expired-token")
credentialsStorage.deleteCredential(oauthRefreshTokenSpec)
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
runSmokeQuery(t, db)
assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1)
assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123")
})
t.Run("if access token is missing and refresh token is present, should run refresh token flow", func(t *testing.T) {
roundTripper.reset()
cfg := wiremock.connectionConfig()
cfg.OauthScope = "session:role:ANALYST offline_access"
cfg.Authenticator = AuthTypeOAuthAuthorizationCode
cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect"
cfg.Transporter = roundTripper
oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
credentialsStorage.deleteCredential(oauthAccessTokenSpec)
credentialsStorage.setCredential(oauthRefreshTokenSpec, "refresh-token-123")
wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"),
newWiremockMapping("auth/oauth2/refresh_token/successful_flow.json"),
newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"),
newWiremockMapping("auth/oauth2/login_request.json"),
newWiremockMapping("select1.json"))
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
runSmokeQuery(t, db)
assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) // only refresh token
assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123")
assertEqualE(t, credentialsStorage.getCredential(oauthRefreshTokenSpec), "refresh-token-123a")
})
t.Run("if access token is expired and refresh token is present, should run refresh token flow", func(t *testing.T) {
roundTripper.reset()
cfg := wiremock.connectionConfig()
cfg.OauthScope = "session:role:ANALYST offline_access"
cfg.Authenticator = AuthTypeOAuthAuthorizationCode
cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect"
cfg.Transporter = roundTripper
oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
credentialsStorage.setCredential(oauthAccessTokenSpec, "expired-token")
credentialsStorage.setCredential(oauthRefreshTokenSpec, "refresh-token-123")
wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"),
newWiremockMapping("auth/oauth2/refresh_token/successful_flow.json"),
newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"),
newWiremockMapping("auth/oauth2/login_request.json"),
newWiremockMapping("select1.json"))
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
runSmokeQuery(t, db)
assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) // only refresh token
assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123")
assertEqualE(t, credentialsStorage.getCredential(oauthRefreshTokenSpec), "refresh-token-123a")
})
t.Run("if new refresh token is not returned, should keep old one", func(t *testing.T) {
roundTripper.reset()
cfg := wiremock.connectionConfig()
cfg.OauthScope = "session:role:ANALYST offline_access"
cfg.Authenticator = AuthTypeOAuthAuthorizationCode
cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect"
cfg.Transporter = roundTripper
oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
credentialsStorage.setCredential(oauthAccessTokenSpec, "expired-token")
credentialsStorage.setCredential(oauthRefreshTokenSpec, "refresh-token-123")
wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"),
newWiremockMapping("auth/oauth2/refresh_token/successful_flow_without_new_refresh_token.json"),
newWiremockMapping("auth/oauth2/authorization_code/successful_flow.json"),
newWiremockMapping("auth/oauth2/login_request.json"),
newWiremockMapping("select1.json"))
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
runSmokeQuery(t, db)
assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) // only refresh token
assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123")
assertEqualE(t, credentialsStorage.getCredential(oauthRefreshTokenSpec), "refresh-token-123")
})
t.Run("if refreshing token failed, run normal flow", func(t *testing.T) {
roundTripper.reset()
cfg := wiremock.connectionConfig()
cfg.OauthScope = "session:role:ANALYST offline_access"
cfg.Authenticator = AuthTypeOAuthAuthorizationCode
cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect"
cfg.Transporter = roundTripper
oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
credentialsStorage.setCredential(oauthAccessTokenSpec, "expired-token")
credentialsStorage.setCredential(oauthRefreshTokenSpec, "expired-refresh-token")
wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"),
newWiremockMapping("auth/oauth2/refresh_token/invalid_refresh_token.json"),
newWiremockMapping("auth/oauth2/authorization_code/successful_flow_with_offline_access.json"),
newWiremockMapping("auth/oauth2/login_request.json"),
newWiremockMapping("select1.json"))
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
runSmokeQuery(t, db)
assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 2) // only refresh token fails, then authorization code
assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123")
assertEqualE(t, credentialsStorage.getCredential(oauthRefreshTokenSpec), "refresh-token-123")
})
t.Run("if secure storage is disabled, run normal flow", func(t *testing.T) {
roundTripper.reset()
cfg := wiremock.connectionConfig()
cfg.OauthScope = "session:role:ANALYST offline_access"
cfg.Authenticator = AuthTypeOAuthAuthorizationCode
cfg.OauthRedirectURI = "http://localhost:1234/snowflake/oauth-redirect"
cfg.Transporter = roundTripper
cfg.ClientStoreTemporaryCredential = ConfigBoolFalse
oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
credentialsStorage.setCredential(oauthAccessTokenSpec, "old-access-token")
credentialsStorage.setCredential(oauthRefreshTokenSpec, "old-refresh-token")
wiremock.registerMappings(t, newWiremockMapping("auth/oauth2/authorization_code/successful_flow_with_offline_access.json"),
newWiremockMapping("auth/oauth2/login_request.json"),
newWiremockMapping("select1.json"))
connector := NewConnector(SnowflakeDriver{}, *cfg)
db := sql.OpenDB(connector)
runSmokeQuery(t, db)
assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1) // only access token token
assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "old-access-token")
assertEqualE(t, credentialsStorage.getCredential(oauthRefreshTokenSpec), "old-refresh-token")
})
}
// TestClientCredentialsFlow exercises the OAuth client-credentials
// authenticator against the wiremock server: fresh token acquisition,
// access-token caching, cache refresh on expiry, and cache bypass.
//
// NOTE(review): the subtests share roundTripper, cfg, and
// credentialsStorage, so they rely on t.Run executing them sequentially
// in declaration order.
func TestClientCredentialsFlow(t *testing.T) {
	if runningOnGithubAction() && runningOnLinux() {
		t.Skip("Github blocks writing to file system")
	}
	skipOnMac(t, "keychain requires password")
	// Swap in a non-interactive authorization-code provider for the duration
	// of the test; restore the original factory when the test finishes.
	currentDefaultAuthorizationCodeProviderFactory := defaultAuthorizationCodeProviderFactory
	defer func() {
		defaultAuthorizationCodeProviderFactory = currentDefaultAuthorizationCodeProviderFactory
	}()
	defaultAuthorizationCodeProviderFactory = func() authorizationCodeProvider {
		return &nonInteractiveAuthorizationCodeProvider{
			t:  t,
			mu: sync.Mutex{},
		}
	}
	// Counting round tripper lets each subtest assert how many POSTs actually
	// reached the token endpoint.
	roundTripper := newCountingRoundTripper(createTestNoRevocationTransport())
	cfg := wiremock.connectionConfig()
	cfg.Role = "ANALYST"
	cfg.Authenticator = AuthTypeOAuthClientCredentials
	cfg.Transporter = roundTripper
	oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
	oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
	t.Run("successful flow", func(t *testing.T) {
		// No cached token: the driver must go through the full flow.
		credentialsStorage.deleteCredential(oauthAccessTokenSpec)
		wiremock.registerMappings(t,
			newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json"),
			newWiremockMapping("auth/oauth2/login_request.json"),
			newWiremockMapping("select1.json"))
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		db := sql.OpenDB(connector)
		runSmokeQuery(t, db)
	})
	t.Run("should use cached access token", func(t *testing.T) {
		roundTripper.reset()
		wiremock.registerMappings(t,
			newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json"),
			newWiremockMapping("auth/oauth2/login_request.json"),
			newWiremockMapping("select1.json"))
		credentialsStorage.deleteCredential(oauthAccessTokenSpec)
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		db := sql.OpenDB(connector)
		// Two connections, but the token endpoint must be hit only once —
		// the second connection reuses the cached access token.
		conn1, err := db.Conn(context.Background())
		assertNilF(t, err)
		defer conn1.Close()
		conn2, err := db.Conn(context.Background())
		assertNilF(t, err)
		defer conn2.Close()
		runSmokeQueryWithConn(t, conn1)
		runSmokeQueryWithConn(t, conn2)
		assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1)
	})
	t.Run("should update cache with new token when the old one expired", func(t *testing.T) {
		roundTripper.reset()
		wiremock.registerMappings(t,
			newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"),
			newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json"),
			newWiremockMapping("auth/oauth2/login_request.json"),
			newWiremockMapping("select1.json"))
		// Seed the cache with an expired token; the driver should replace it.
		credentialsStorage.setCredential(oauthAccessTokenSpec, "expired-token")
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		db := sql.OpenDB(connector)
		runSmokeQuery(t, db)
		assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1)
		assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123")
	})
	t.Run("should not use refresh token, but ask for fresh access token", func(t *testing.T) {
		roundTripper.reset()
		wiremock.registerMappings(t,
			newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"),
			newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json"),
			newWiremockMapping("auth/oauth2/login_request.json"),
			newWiremockMapping("select1.json"))
		credentialsStorage.setCredential(oauthAccessTokenSpec, "expired-token")
		// A cached refresh token must be ignored by the client-credentials
		// flow (it should stay untouched and unused).
		credentialsStorage.setCredential(oauthRefreshTokenSpec, "refresh-token-123")
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		db := sql.OpenDB(connector)
		runSmokeQuery(t, db)
		assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1)
		assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123")
		assertEqualE(t, credentialsStorage.getCredential(oauthRefreshTokenSpec), "refresh-token-123")
	})
	t.Run("should not use access token if token cache is disabled", func(t *testing.T) {
		roundTripper.reset()
		wiremock.registerMappings(t,
			newWiremockMapping("auth/oauth2/login_request_with_expired_access_token.json"),
			newWiremockMapping("auth/oauth2/client_credentials/successful_flow.json"),
			newWiremockMapping("auth/oauth2/login_request.json"),
			newWiremockMapping("select1.json"))
		credentialsStorage.setCredential(oauthAccessTokenSpec, "access-token-123")
		// With the cache disabled the driver must request a token even though
		// a valid-looking one is stored, and must not modify the stored value.
		cfg.ClientStoreTemporaryCredential = ConfigBoolFalse
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		db := sql.OpenDB(connector)
		runSmokeQuery(t, db)
		assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1)
		assertEqualE(t, credentialsStorage.getCredential(oauthAccessTokenSpec), "access-token-123")
	})
}
// TestEligibleForDefaultClientCredentials verifies when an oauthClient may
// fall back to default client credentials: only when no explicit client
// ID/secret is configured and the IdP endpoints point at Snowflake (or are
// empty, which defaults to Snowflake).
func TestEligibleForDefaultClientCredentials(t *testing.T) {
	testCases := []struct {
		name        string
		oauthClient *oauthClient
		expected    bool
	}{
		{
			name: "Client credentials not supplied and Snowflake as IdP",
			oauthClient: &oauthClient{
				cfg: &Config{
					Host:                  "example.snowflakecomputing.com",
					OauthClientID:         "",
					OauthClientSecret:     "",
					OauthAuthorizationURL: "https://example.snowflakecomputing.com/oauth/authorize",
					OauthTokenRequestURL:  "https://example.snowflakecomputing.com/oauth/token",
				},
			},
			expected: true,
		},
		{
			name: "Client credentials not supplied and empty URLs (defaults to Snowflake)",
			oauthClient: &oauthClient{
				cfg: &Config{
					Host:                  "example.snowflakecomputing.com",
					OauthClientID:         "",
					OauthClientSecret:     "",
					OauthAuthorizationURL: "",
					OauthTokenRequestURL:  "",
				},
			},
			expected: true,
		},
		{
			name: "Client credentials supplied",
			oauthClient: &oauthClient{
				cfg: &Config{
					Host:                  "example.snowflakecomputing.com",
					OauthClientID:         "testClientID",
					OauthClientSecret:     "testClientSecret",
					OauthAuthorizationURL: "https://example.snowflakecomputing.com/oauth/authorize",
					OauthTokenRequestURL:  "https://example.snowflakecomputing.com/oauth/token",
				},
			},
			expected: false,
		},
		{
			name: "Only client ID supplied",
			oauthClient: &oauthClient{
				cfg: &Config{
					Host:                  "example.snowflakecomputing.com",
					OauthClientID:         "testClientID",
					OauthClientSecret:     "",
					OauthAuthorizationURL: "https://example.snowflakecomputing.com/oauth/authorize",
					OauthTokenRequestURL:  "https://example.snowflakecomputing.com/oauth/token",
				},
			},
			expected: false,
		},
		{
			name: "Non-Snowflake IdP",
			oauthClient: &oauthClient{
				cfg: &Config{
					Host:                  "example.snowflakecomputing.com",
					OauthClientID:         "",
					OauthClientSecret:     "",
					OauthAuthorizationURL: "https://example.com/oauth/authorize",
					OauthTokenRequestURL:  "https://example.com/oauth/token",
				},
			},
			expected: false,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			if got := tc.oauthClient.eligibleForDefaultClientCredentials(); got != tc.expected {
				t.Errorf("expected %v, got %v", tc.expected, got)
			}
		})
	}
}
// nonInteractiveAuthorizationCodeProvider is a test double for the
// authorization-code browser step: instead of opening a browser it fetches
// the authorization URL directly (see run) and records the response body.
type nonInteractiveAuthorizationCodeProvider struct {
	t               *testing.T
	tamperWithState bool          // when true, createState returns a mismatching state value
	triggerError    string        // when non-empty, run fails with this error message
	responseBody    string        // body of the authorization response; guarded by mu
	mu              sync.Mutex    // guards responseBody (written from a goroutine in run)
	sleepTime       time.Duration // optional delay before run proceeds
}
// run simulates a user completing the OAuth authorization step without a
// browser. It optionally sleeps (sleepTime), optionally fails with a
// configured error (triggerError), and otherwise fetches authorizationURL
// in the background, storing the response body for later assertions.
func (provider *nonInteractiveAuthorizationCodeProvider) run(authorizationURL string) error {
	if provider.sleepTime != 0 {
		time.Sleep(provider.sleepTime)
	}
	// A configured error aborts the flow whether or not a sleep was
	// requested (previously this check was duplicated inside the sleep
	// branch; one check after the optional sleep is equivalent).
	if provider.triggerError != "" {
		return errors.New(provider.triggerError)
	}
	go func() {
		resp, err := http.Get(authorizationURL)
		assertNilF(provider.t, err)
		defer resp.Body.Close() // close the body so the connection is not leaked
		assertEqualE(provider.t, resp.StatusCode, http.StatusOK)
		respBody, err := io.ReadAll(resp.Body)
		assertNilF(provider.t, err)
		provider.mu.Lock()
		defer provider.mu.Unlock()
		provider.responseBody = string(respBody)
	}()
	return nil
}
// createState returns a deterministic OAuth state value; when
// tamperWithState is set it returns a different value so the driver's
// state-mismatch validation can be exercised.
func (provider *nonInteractiveAuthorizationCodeProvider) createState() string {
	state := "testState"
	if provider.tamperWithState {
		state = "invalidState"
	}
	return state
}
// createCodeVerifier returns a fixed PKCE code verifier for deterministic tests.
func (provider *nonInteractiveAuthorizationCodeProvider) createCodeVerifier() string {
	return "testCodeVerifier"
}
// assertResponseBodyContains asserts that the recorded authorization
// response body contains str. It takes mu because responseBody is written
// from the goroutine started by run.
func (provider *nonInteractiveAuthorizationCodeProvider) assertResponseBodyContains(str string) {
	provider.mu.Lock()
	defer provider.mu.Unlock()
	assertStringContainsE(provider.t, provider.responseBody, str)
}
================================================
FILE: auth_test.go
================================================
package gosnowflake
import (
"cmp"
"context"
"crypto/rand"
"crypto/rsa"
"database/sql"
"encoding/json"
"errors"
"fmt"
sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config"
"net/http"
"net/url"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
)
// TestUnitPostAuth checks postAuth against a succeeding mock and several
// failing mocks (generic error, bad gateway, forbidden, unexpected error).
func TestUnitPostAuth(t *testing.T) {
	sr := &snowflakeRestful{
		TokenAccessor: getSimpleTokenAccessor(),
		FuncAuthPost:  postAuthTestAfterRenew,
	}
	bodyCreator := func() ([]byte, error) {
		return []byte{0x12, 0x34}, nil
	}
	// Happy path: the renew-token mock should succeed.
	if _, err := postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0); err != nil {
		t.Fatalf("err: %v", err)
	}
	// Each failing mock must surface an error from postAuth.
	sr.FuncAuthPost = postAuthTestError
	if _, err := postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0); err == nil {
		t.Fatal("should have failed to auth for unknown reason")
	}
	sr.FuncAuthPost = postAuthTestAppBadGatewayError
	if _, err := postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0); err == nil {
		t.Fatal("should have failed to auth for unknown reason")
	}
	sr.FuncAuthPost = postAuthTestAppForbiddenError
	if _, err := postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0); err == nil {
		t.Fatal("should have failed to auth for unknown reason")
	}
	sr.FuncAuthPost = postAuthTestAppUnexpectedError
	if _, err := postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0); err == nil {
		t.Fatal("should have failed to auth for unknown reason")
	}
}
// postAuthFailServiceIssue is a FuncPostAuth mock that always fails with
// ErrCodeServiceUnavailable.
func postAuthFailServiceIssue(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
	return nil, &SnowflakeError{
		Number: ErrCodeServiceUnavailable,
	}
}
// postAuthFailWrongAccount is a FuncPostAuth mock that always fails with
// ErrCodeFailedToConnect.
func postAuthFailWrongAccount(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
	return nil, &SnowflakeError{
		Number: ErrCodeFailedToConnect,
	}
}
// postAuthFailUnknown is a FuncPostAuth mock that always fails with
// ErrFailedToAuth.
func postAuthFailUnknown(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
	return nil, &SnowflakeError{
		Number: ErrFailedToAuth,
	}
}
// postAuthSuccessWithErrorCode is a FuncPostAuth mock that returns a
// non-nil response with Success=false and a numeric error code.
func postAuthSuccessWithErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
	return &authResponse{
		Success: false,
		Code:    "98765",
		Message: "wrong!",
	}, nil
}
// postAuthSuccessWithInvalidErrorCode is a FuncPostAuth mock that returns
// Success=false with a code that cannot be parsed as a number.
func postAuthSuccessWithInvalidErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
	return &authResponse{
		Success: false,
		Code:    "abcdef",
		Message: "wrong!",
	}, nil
}
// postAuthSuccess is a FuncPostAuth mock that returns a canned successful
// auth response with fixed tokens and session info.
func postAuthSuccess(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}
// postAuthCheckSAMLResponse is a FuncPostAuth mock that validates the
// request body carries a non-empty raw SAML response, then returns a
// canned successful auth response.
func postAuthCheckSAMLResponse(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	jsonBody, err := bodyCreator()
	if err != nil {
		return nil, err
	}
	var ar authRequest
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.RawSAMLResponse == "" {
		return nil, errors.New("SAML response is empty")
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}
// Checks that the request body generated when authenticating with OAuth
// contains all the necessary values: the OAUTH authenticator, a non-empty
// token, and a non-empty login name.
func postAuthCheckOAuth(
	_ context.Context,
	_ *snowflakeRestful,
	_ *http.Client,
	_ *url.Values, _ map[string]string,
	bodyCreator bodyCreatorType,
	_ time.Duration,
) (*authResponse, error) {
	var ar authRequest
	// Propagate body-creation failures instead of silently unmarshalling
	// a nil body (was: jsonBody, _ := bodyCreator()).
	jsonBody, err := bodyCreator()
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.Authenticator != AuthTypeOAuth.String() {
		return nil, errors.New("Authenticator is not OAUTH")
	}
	if ar.Data.Token == "" {
		return nil, errors.New("Token is empty")
	}
	if ar.Data.LoginName == "" {
		return nil, errors.New("Login name is empty")
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}
// postAuthCheckPasscode is a FuncPostAuth mock that validates the request
// carries the expected MFA passcode and the "passcode" Duo method.
func postAuthCheckPasscode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	// Propagate body-creation failures instead of silently unmarshalling
	// a nil body (was: jsonBody, _ := bodyCreator()).
	jsonBody, err := bodyCreator()
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.Passcode != "987654321" || ar.Data.ExtAuthnDuoMethod != "passcode" {
		return nil, fmt.Errorf("passcode didn't match. expected: 987654321, got: %v, duo: %v", ar.Data.Passcode, ar.Data.ExtAuthnDuoMethod)
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}
// postAuthCheckPasscodeInPassword is a FuncPostAuth mock that validates
// the passcode field is empty (it is embedded in the password instead)
// while the Duo method is still "passcode".
func postAuthCheckPasscodeInPassword(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	// Propagate body-creation failures instead of silently unmarshalling
	// a nil body (was: jsonBody, _ := bodyCreator()).
	jsonBody, err := bodyCreator()
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.Passcode != "" || ar.Data.ExtAuthnDuoMethod != "passcode" {
		return nil, fmt.Errorf("passcode must be empty, got: %v, duo: %v", ar.Data.Passcode, ar.Data.ExtAuthnDuoMethod)
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}
// postAuthCheckUsernamePasswordMfa is a FuncPostAuth mock that validates
// the CLIENT_REQUEST_MFA_TOKEN session parameter is set, then returns a
// successful response carrying a mocked MFA token.
func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	// Propagate body-creation failures instead of silently unmarshalling
	// a nil body (was: jsonBody, _ := bodyCreator()).
	jsonBody, err := bodyCreator()
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.SessionParameters["CLIENT_REQUEST_MFA_TOKEN"] != true {
		return nil, fmt.Errorf("expected client_request_mfa_token to be true but was %v", ar.Data.SessionParameters["CLIENT_REQUEST_MFA_TOKEN"])
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			MfaToken:    "mockedMfaToken",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}
// postAuthCheckUsernamePasswordMfaToken is a FuncPostAuth mock that
// validates the cached MFA token is sent on re-authentication.
func postAuthCheckUsernamePasswordMfaToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	// Propagate body-creation failures instead of silently unmarshalling
	// a nil body (was: jsonBody, _ := bodyCreator()).
	jsonBody, err := bodyCreator()
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.Token != "mockedMfaToken" {
		return nil, fmt.Errorf("unexpected mfa token: %v", ar.Data.Token)
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			MfaToken:    "mockedMfaToken",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}
// postAuthCheckUsernamePasswordMfaFailed is a FuncPostAuth mock that
// validates the cached MFA token is sent, then reports an auth failure
// (Success=false with code 260008).
func postAuthCheckUsernamePasswordMfaFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	// Propagate body-creation failures instead of silently unmarshalling
	// a nil body (was: jsonBody, _ := bodyCreator()).
	jsonBody, err := bodyCreator()
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.Token != "mockedMfaToken" {
		return nil, fmt.Errorf("unexpected mfa token: %v", ar.Data.Token)
	}
	return &authResponse{
		Success: false,
		Data:    authResponseMain{},
		Message: "auth failed",
		Code:    "260008",
	}, nil
}
// postAuthCheckExternalBrowser is a FuncPostAuth mock that validates the
// CLIENT_STORE_TEMPORARY_CREDENTIAL session parameter is set, then returns
// a successful response carrying a mocked ID token.
func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	// Propagate body-creation failures instead of silently unmarshalling
	// a nil body (was: jsonBody, _ := bodyCreator()).
	jsonBody, err := bodyCreator()
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"] != true {
		return nil, fmt.Errorf("expected client_store_temporary_credential to be true but was %v", ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"])
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			IDToken:     "mockedIDToken",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}
// postAuthCheckExternalBrowserToken is a FuncPostAuth mock that validates
// the cached ID token is sent on re-authentication.
func postAuthCheckExternalBrowserToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	// Propagate body-creation failures instead of silently unmarshalling
	// a nil body (was: jsonBody, _ := bodyCreator()).
	jsonBody, err := bodyCreator()
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.Token != "mockedIDToken" {
		return nil, fmt.Errorf("unexpected mfatoken: %v", ar.Data.Token)
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			IDToken:     "mockedIDToken",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}
// postAuthCheckExternalBrowserFailed is a FuncPostAuth mock that validates
// the CLIENT_STORE_TEMPORARY_CREDENTIAL session parameter is set, then
// reports an auth failure (Success=false with code 260008).
func postAuthCheckExternalBrowserFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	// Propagate body-creation failures instead of silently unmarshalling
	// a nil body (was: jsonBody, _ := bodyCreator()).
	jsonBody, err := bodyCreator()
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal(jsonBody, &ar); err != nil {
		return nil, err
	}
	if ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"] != true {
		return nil, fmt.Errorf("expected client_store_temporary_credential to be true but was %v", ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"])
	}
	return &authResponse{
		Success: false,
		Data:    authResponseMain{},
		Message: "auth failed",
		Code:    "260008",
	}, nil
}
// restfulTestWrapper carries the *testing.T that mocked REST functions
// defined as its methods need for assertions.
type restfulTestWrapper struct {
	t *testing.T
}
// postAuthOktaWithNewToken mocks the Okta auth endpoint: it drives a POST
// through the retry helper against a fake HTTP client that returns 429
// three times before succeeding, then returns a canned successful auth
// response carrying a mocked MFA token.
func (rtw restfulTestWrapper) postAuthOktaWithNewToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
	var ar authRequest
	cfg := &Config{
		Authenticator: AuthTypeOkta,
	}
	// Retry 3 times and success
	client := &fakeHTTPClient{
		cnt:        3,
		success:    true,
		statusCode: 429,
		t:          rtw.t,
	}
	urlPtr, err := url.Parse("https://fakeaccountretrylogin.snowflakecomputing.com:443/login-request?request_guid=testguid")
	if err != nil {
		return &authResponse{}, err
	}
	// Body creator handed to the retry helper; unmarshals into ar to verify
	// the payload is valid JSON before returning it.
	body := func() ([]byte, error) {
		jsonBody, _ := bodyCreator()
		if err := json.Unmarshal(jsonBody, &ar); err != nil {
			return nil, err
		}
		// NOTE(review): err here refers to the outer url.Parse error, which
		// is always nil at this point — effectively "return jsonBody, nil".
		return jsonBody, err
	}
	_, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, defaultTimeProvider, cfg).doPost().setBodyCreator(body).execute()
	if err != nil {
		return &authResponse{}, err
	}
	return &authResponse{
		Success: true,
		Data: authResponseMain{
			Token:       "t",
			MasterToken: "m",
			MfaToken:    "mockedMfaToken",
			SessionInfo: authResponseSessionInfo{
				DatabaseName: "dbn",
			},
		},
	}, nil
}
// getDefaultSnowflakeConn builds a minimal snowflakeConn for unit tests:
// a simple token accessor, a fully populated basic Config, and telemetry
// disabled.
func getDefaultSnowflakeConn() *snowflakeConn {
	return &snowflakeConn{
		rest: &snowflakeRestful{
			TokenAccessor: getSimpleTokenAccessor(),
		},
		cfg: &Config{
			Account:            "a",
			User:               "u",
			Password:           "p",
			Database:           "d",
			Schema:             "s",
			Warehouse:          "w",
			Role:               "r",
			Region:             "",
			PasscodeInPassword: false,
			Passcode:           "",
			Application:        "testapp",
		},
		telemetry: &snowflakeTelemetry{enabled: false},
	}
}
// TestUnitAuthenticateWithTokenAccessor verifies that AuthTypeTokenAccessor
// reuses tokens already held by the accessor and does not call FuncPostAuth
// at all (the mock is set to fail to prove it is never invoked).
func TestUnitAuthenticateWithTokenAccessor(t *testing.T) {
	expectedSessionID := int64(123)
	expectedMasterToken := "master_token"
	expectedToken := "auth_token"
	ta := getSimpleTokenAccessor()
	ta.SetTokens(expectedToken, expectedMasterToken, expectedSessionID)
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeTokenAccessor
	sc.cfg.TokenAccessor = ta
	sr := &snowflakeRestful{
		FuncPostAuth:  postAuthFailServiceIssue,
		TokenAccessor: ta,
	}
	sc.rest = sr
	// FuncPostAuth is set to fail, but AuthTypeTokenAccessor should not even make a call to FuncPostAuth
	resp, err := authenticate(context.Background(), sc, []byte{}, []byte{})
	if err != nil {
		t.Fatalf("should not have failed, err %v", err)
	}
	if resp.SessionID != expectedSessionID {
		t.Fatalf("Expected session id %v but got %v", expectedSessionID, resp.SessionID)
	}
	if resp.Token != expectedToken {
		t.Fatalf("Expected token %v but got %v", expectedToken, resp.Token)
	}
	if resp.MasterToken != expectedMasterToken {
		t.Fatalf("Expected master token %v but got %v", expectedMasterToken, resp.MasterToken)
	}
	// Session info should be echoed back from the connection config.
	if resp.SessionInfo.DatabaseName != sc.cfg.Database {
		t.Fatalf("Expected database %v but got %v", sc.cfg.Database, resp.SessionInfo.DatabaseName)
	}
	if resp.SessionInfo.WarehouseName != sc.cfg.Warehouse {
		t.Fatalf("Expected warehouse %v but got %v", sc.cfg.Warehouse, resp.SessionInfo.WarehouseName)
	}
	if resp.SessionInfo.RoleName != sc.cfg.Role {
		t.Fatalf("Expected role %v but got %v", sc.cfg.Role, resp.SessionInfo.RoleName)
	}
	if resp.SessionInfo.SchemaName != sc.cfg.Schema {
		t.Fatalf("Expected schema %v but got %v", sc.cfg.Schema, resp.SessionInfo.SchemaName)
	}
}
// TestUnitAuthenticate walks authenticate() through a sequence of
// FuncPostAuth mocks: service unavailable, wrong account, unknown failure,
// Success=false responses (valid and invalid error codes, both of which
// must reset the token accessor), and finally a successful auth that must
// install fresh tokens.
//
// NOTE(review): the phases share ta/sc/sr and the old*/new* token
// variables, so the statement order is significant.
func TestUnitAuthenticate(t *testing.T) {
	var err error
	var driverErr *SnowflakeError
	var ok bool
	ta := getSimpleTokenAccessor()
	sc := getDefaultSnowflakeConn()
	sr := &snowflakeRestful{
		FuncPostAuth:  postAuthFailServiceIssue,
		TokenAccessor: ta,
	}
	sc.rest = sr
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err == nil {
		t.Fatal("should have failed.")
	}
	driverErr, ok = err.(*SnowflakeError)
	if !ok || driverErr.Number != ErrCodeServiceUnavailable {
		t.Fatalf("Snowflake error is expected. err: %v", driverErr)
	}
	sr.FuncPostAuth = postAuthFailWrongAccount
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err == nil {
		t.Fatal("should have failed.")
	}
	driverErr, ok = err.(*SnowflakeError)
	if !ok || driverErr.Number != ErrCodeFailedToConnect {
		t.Fatalf("Snowflake error is expected. err: %v", driverErr)
	}
	sr.FuncPostAuth = postAuthFailUnknown
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err == nil {
		t.Fatal("should have failed.")
	}
	driverErr, ok = err.(*SnowflakeError)
	if !ok || driverErr.Number != ErrFailedToAuth {
		t.Fatalf("Snowflake error is expected. err: %v", driverErr)
	}
	// A Success=false response must clear any tokens previously held.
	ta.SetTokens("bad-token", "bad-master-token", 1)
	sr.FuncPostAuth = postAuthSuccessWithErrorCode
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err == nil {
		t.Fatal("should have failed.")
	}
	newToken, newMasterToken, newSessionID := ta.GetTokens()
	if newToken != "" || newMasterToken != "" || newSessionID != -1 {
		t.Fatalf("failed auth should have reset tokens: %v %v %v", newToken, newMasterToken, newSessionID)
	}
	driverErr, ok = err.(*SnowflakeError)
	if !ok || driverErr.Number != 98765 {
		t.Fatalf("Snowflake error is expected. err: %v", driverErr)
	}
	// Same reset expectation when the error code is not parseable.
	ta.SetTokens("bad-token", "bad-master-token", 1)
	sr.FuncPostAuth = postAuthSuccessWithInvalidErrorCode
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err == nil {
		t.Fatal("should have failed.")
	}
	oldToken, oldMasterToken, oldSessionID := ta.GetTokens()
	if oldToken != "" || oldMasterToken != "" || oldSessionID != -1 {
		t.Fatalf("failed auth should have reset tokens: %v %v %v", oldToken, oldMasterToken, oldSessionID)
	}
	// Successful auth must install fresh tokens differing from the old
	// (reset) values captured above.
	sr.FuncPostAuth = postAuthSuccess
	var resp *authResponseMain
	resp, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	if err != nil {
		t.Fatalf("failed to auth. err: %v", err)
	}
	if resp.SessionInfo.DatabaseName != "dbn" {
		t.Fatalf("failed to get response from auth")
	}
	newToken, newMasterToken, newSessionID = ta.GetTokens()
	if newToken == oldToken {
		t.Fatalf("new token was not set: %v", newToken)
	}
	if newMasterToken == oldMasterToken {
		t.Fatalf("new master token was not set: %v", newMasterToken)
	}
	if newSessionID == oldSessionID {
		t.Fatalf("new session id was not set: %v", newSessionID)
	}
}
func TestUnitAuthenticateSaml(t *testing.T) {
var err error
sr := &snowflakeRestful{
Protocol: "https",
Host: "abc.com",
Port: 443,
FuncPostAuthSAML: postAuthSAMLAuthSuccess,
FuncPostAuthOKTA: postAuthOKTASuccess,
FuncGetSSO: getSSOSuccess,
FuncPostAuth: postAuthCheckSAMLResponse,
TokenAccessor: getSimpleTokenAccessor(),
}
sc := getDefaultSnowflakeConn()
sc.cfg.Authenticator = AuthTypeOkta
sc.cfg.OktaURL = &url.URL{
Scheme: "https",
Host: "abc.com",
}
sc.rest = sr
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
assertNilF(t, err, "failed to run.")
}
// Unit test for OAuth.
func TestUnitAuthenticateOAuth(t *testing.T) {
var err error
sr := &snowflakeRestful{
FuncPostAuth: postAuthCheckOAuth,
TokenAccessor: getSimpleTokenAccessor(),
}
sc := getDefaultSnowflakeConn()
sc.cfg.Token = "oauthToken"
sc.cfg.Authenticator = AuthTypeOAuth
sc.rest = sr
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}
}
// TestUnitAuthenticatePasscode checks both MFA passcode modes: an explicit
// passcode in the request, and passcode-in-password (where the passcode
// field must be empty).
func TestUnitAuthenticatePasscode(t *testing.T) {
	sr := &snowflakeRestful{
		FuncPostAuth:  postAuthCheckPasscode,
		TokenAccessor: getSimpleTokenAccessor(),
	}
	sc := getDefaultSnowflakeConn()
	sc.cfg.Passcode = "987654321"
	sc.rest = sr
	if _, err := authenticate(context.Background(), sc, []byte{}, []byte{}); err != nil {
		t.Fatalf("failed to run. err: %v", err)
	}
	// Switch to passcode-in-password mode; sc.rest already points at sr.
	sr.FuncPostAuth = postAuthCheckPasscodeInPassword
	sc.cfg.PasscodeInPassword = true
	if _, err := authenticate(context.Background(), sc, []byte{}, []byte{}); err != nil {
		t.Fatalf("failed to run. err: %v", err)
	}
}
// Test JWT function in the local environment against the validation function in go
func TestUnitAuthenticateJWT(t *testing.T) {
	var err error
	// Generate a fresh private key for this unit test only
	localTestKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("Failed to generate test private key: %s", err.Error())
	}
	// Create custom JWT verification function that uses the local key
	postAuthCheckLocalJWTToken := func(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
		var ar authRequest
		jsonBody, _ := bodyCreator()
		if err := json.Unmarshal(jsonBody, &ar); err != nil {
			return nil, err
		}
		if ar.Data.Authenticator != AuthTypeJwt.String() {
			return nil, errors.New("Authenticator is not JWT")
		}
		tokenString := ar.Data.Token
		// Validate token using the local test key's public key
		_, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
			// Reject tokens signed with anything other than RSA.
			if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
				return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
			}
			return localTestKey.Public(), nil // Use local key for verification
		})
		if err != nil {
			return nil, err
		}
		return &authResponse{
			Success: true,
			Data: authResponseMain{
				Token:       "t",
				MasterToken: "m",
				SessionInfo: authResponseSessionInfo{
					DatabaseName: "dbn",
				},
			},
		}, nil
	}
	sr := &snowflakeRestful{
		FuncPostAuth:  postAuthCheckLocalJWTToken, // Use local verification function
		TokenAccessor: getSimpleTokenAccessor(),
	}
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeJwt
	sc.cfg.JWTExpireTimeout = time.Duration(sfconfig.DefaultJWTTimeout)
	sc.cfg.PrivateKey = localTestKey
	sc.rest = sr
	// A valid JWT token should pass
	if _, err = authenticate(context.Background(), sc, []byte{}, []byte{}); err != nil {
		t.Fatalf("failed to run. err: %v", err)
	}
	// An invalid JWT token should not pass
	// (signed with a different key than the verifier expects).
	invalidPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Error(err)
	}
	sc.cfg.PrivateKey = invalidPrivateKey
	if _, err = authenticate(context.Background(), sc, []byte{}, []byte{}); err == nil {
		t.Fatalf("invalid token passed")
	}
}
// TestUnitAuthenticateUsernamePasswordMfa covers the username/password MFA
// flow: an initial authentication, reuse of a cached MFA token, and a failing
// endpoint that must surface an error.
func TestUnitAuthenticateUsernamePasswordMfa(t *testing.T) {
	sr := &snowflakeRestful{
		FuncPostAuth:  postAuthCheckUsernamePasswordMfa,
		TokenAccessor: getSimpleTokenAccessor(),
	}
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA
	sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
	sc.rest = sr
	_, err := authenticate(context.Background(), sc, []byte{}, []byte{})
	assertNilF(t, err, "failed to run.")

	// With a cached MFA token, the request should carry it.
	sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaToken
	sc.mfaToken = "mockedMfaToken"
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	assertNilF(t, err, "failed to run.")

	// A failing endpoint must propagate an error to the caller.
	sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaFailed
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	assertNotNilF(t, err, "should have failed")
}
// TestUnitAuthenticateWithConfigMFA verifies authenticateWithConfig() with
// the username/password MFA authenticator against a mocked endpoint.
func TestUnitAuthenticateWithConfigMFA(t *testing.T) {
	sr := &snowflakeRestful{
		FuncPostAuth:  postAuthCheckUsernamePasswordMfa,
		TokenAccessor: getSimpleTokenAccessor(),
	}
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA
	sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
	sc.rest = sr
	sc.ctx = context.Background()
	// Use the file's assert helpers instead of the manual if/Fatalf pattern.
	err := authenticateWithConfig(sc)
	assertNilF(t, err, "failed to run.")
}
// This test creates two groups of scenarios:
// a) singleAuthenticationPrompt=true - in this case, we start authenticating threads at once,
// but due to locking mechanism only one should reach wiremock without MFA token.
// b) singleAuthenticationPrompt=false - in this case, there is no locking, so all threads should rush,
// but on Wiremock only first will be served with correct response (simulating a user confirming MFA only once).
// The remaining threads should return error.
func TestMfaParallelLogin(t *testing.T) {
	skipOnMissingHome(t)
	skipOnMac(t, "interactive keyring access not available on macOS runners")
	cfg := wiremock.connectionConfig()
	// Cache key for the MFA token, derived from host and user; cleared
	// before each scenario so every run starts without a cached token.
	tokenSpec := newMfaTokenSpec(cfg.Host, cfg.User)
	for _, singleAuthenticationPrompt := range []ConfigBool{ConfigBoolTrue, ConfigBoolFalse} {
		t.Run("starts without mfa token, singleAuthenticationPrompt="+singleAuthenticationPrompt.String(), func(t *testing.T) {
			wiremock.registerMappings(t, newWiremockMapping("auth/mfa/parallel_login_successful_flow.json"),
				newWiremockMapping("select1.json"),
				newWiremockMapping("close_session.json"))
			cfg := wiremock.connectionConfig()
			cfg.Authenticator = AuthTypeUsernamePasswordMFA
			cfg.SingleAuthenticationPrompt = singleAuthenticationPrompt
			cfg.ClientRequestMfaToken = ConfigBoolTrue
			connector := NewConnector(SnowflakeDriver{}, *cfg)
			db := sql.OpenDB(connector)
			defer db.Close()
			credentialsStorage.deleteCredential(tokenSpec)
			errs := initPoolWithSizeAndReturnErrors(db, 20)
			if singleAuthenticationPrompt == ConfigBoolTrue {
				// With locking, all connections succeed: one authenticates,
				// the rest wait and reuse the cached MFA token.
				assertEqualE(t, len(errs), 0)
			} else {
				// All threads except the one that actually retrieves the MFA
				// token should fail.
				assertEqualE(t, len(errs), 19)
			}
		})
		t.Run("starts without mfa token, first attempt fails, singleAuthenticationPrompt="+singleAuthenticationPrompt.String(), func(t *testing.T) {
			wiremock.registerMappings(t, newWiremockMapping("auth/mfa/parallel_login_first_fails_then_successful_flow.json"),
				newWiremockMapping("select1.json"),
				newWiremockMapping("close_session.json"))
			cfg := wiremock.connectionConfig()
			cfg.Authenticator = AuthTypeUsernamePasswordMFA
			cfg.SingleAuthenticationPrompt = singleAuthenticationPrompt
			cfg.ClientRequestMfaToken = ConfigBoolTrue
			credentialsStorage.deleteCredential(tokenSpec)
			connector := NewConnector(SnowflakeDriver{}, *cfg)
			db := sql.OpenDB(connector)
			defer db.Close()
			errs := initPoolWithSizeAndReturnErrors(db, 20)
			if singleAuthenticationPrompt == ConfigBoolTrue {
				// Only the first (wiremock-rejected) attempt fails; the retry
				// succeeds and everyone else reuses its token.
				assertEqualF(t, len(errs), 1)
				assertStringContainsE(t, errs[0].Error(), "MFA with TOTP is required")
			} else {
				assertEqualE(t, len(errs), 19)
			}
		})
	}
}
// TestUnitAuthenticateWithConfigOkta verifies the Okta (SAML) flow through
// authenticateWithConfig(): a successful round trip, then a SAML failure
// that must be propagated to the caller.
func TestUnitAuthenticateWithConfigOkta(t *testing.T) {
	var err error
	sr := &snowflakeRestful{
		Protocol:         "https",
		Host:             "abc.com",
		Port:             443,
		FuncPostAuthSAML: postAuthSAMLAuthSuccess,
		FuncPostAuthOKTA: postAuthOKTASuccess,
		FuncGetSSO:       getSSOSuccess,
		FuncPostAuth:     postAuthCheckSAMLResponse,
		TokenAccessor:    getSimpleTokenAccessor(),
	}
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeOkta
	sc.cfg.OktaURL = &url.URL{
		Scheme: "https",
		Host:   "abc.com",
	}
	sc.rest = sr
	sc.ctx = context.Background()
	err = authenticateWithConfig(sc)
	assertNilE(t, err, "expected to have no error.")
	// Make the SAML step fail and check the error is surfaced verbatim.
	sr.FuncPostAuthSAML = postAuthSAMLError
	err = authenticateWithConfig(sc)
	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
	assertEqualE(t, err.Error(), "failed to get SAML response")
}
// TestUnitAuthenticateWithExternalBrowserParallel covers ID-token caching for
// the external browser flow, including parallel connections, against
// wiremock. A non-interactive SAML provider stands in for the real browser.
func TestUnitAuthenticateWithExternalBrowserParallel(t *testing.T) {
	skipOnMissingHome(t)
	skipOnMac(t, "interactive keyring access not available on macOS runners")
	t.Run("no ID token cached", func(t *testing.T) {
		// Swap in a SAML provider that needs no real browser interaction;
		// restore the original on exit.
		origSamlResponseProvider := defaultSamlResponseProvider
		defer func() { defaultSamlResponseProvider = origSamlResponseProvider }()
		defaultSamlResponseProvider = func() samlResponseProvider {
			return &nonInteractiveSamlResponseProvider{t: t}
		}
		wiremock.registerMappings(t, newWiremockMapping("auth/external_browser/successful_flow.json"),
			newWiremockMapping("select1.json"),
			newWiremockMapping("close_session.json"))
		cfg := wiremock.connectionConfig()
		cfg.Authenticator = AuthTypeExternalBrowser
		cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		// Start from an empty cache so a full authentication happens.
		credentialsStorage.deleteCredential(newIDTokenSpec(cfg.Host, cfg.User))
		db := sql.OpenDB(connector)
		defer db.Close()
		runSmokeQuery(t, db)
		// The freshly retrieved ID token must have been written to the cache.
		assertEqualE(t, credentialsStorage.getCredential(newIDTokenSpec(cfg.Host, cfg.User)), "test-id-token")
	})
	t.Run("ID token cached", func(t *testing.T) {
		wiremock.registerMappings(t, newWiremockMapping("auth/external_browser/successful_flow.json"),
			newWiremockMapping("select1.json"),
			newWiremockMapping("close_session.json"))
		cfg := wiremock.connectionConfig()
		cfg.Authenticator = AuthTypeExternalBrowser
		cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		// Pre-populate the cache; no browser interaction should be needed
		// (no SAML provider override is installed in this subtest).
		credentialsStorage.setCredential(newIDTokenSpec(cfg.Host, cfg.User), "test-id-token")
		db := sql.OpenDB(connector)
		defer db.Close()
		runSmokeQuery(t, db)
	})
	t.Run("first connection retrieves ID token, second request uses cached ID token", func(t *testing.T) {
		origSamlResponseProvider := defaultSamlResponseProvider
		defer func() { defaultSamlResponseProvider = origSamlResponseProvider }()
		defaultSamlResponseProvider = func() samlResponseProvider {
			return &nonInteractiveSamlResponseProvider{t: t}
		}
		wiremock.registerMappings(t, newWiremockMapping("auth/external_browser/parallel_login_successful_flow.json"),
			newWiremockMapping("select1.json"),
			newWiremockMapping("close_session.json"))
		cfg := wiremock.connectionConfig()
		cfg.Authenticator = AuthTypeExternalBrowser
		cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		credentialsStorage.deleteCredential(newIDTokenSpec(cfg.Host, cfg.User))
		db := sql.OpenDB(connector)
		defer db.Close()
		conn1, err := db.Conn(context.Background())
		assertNilF(t, err)
		defer conn1.Close()
		runSmokeQueryWithConn(t, conn1)
		// Second connection should authenticate with the token cached by the first.
		conn2, err := db.Conn(context.Background())
		assertNilF(t, err)
		defer conn2.Close()
		runSmokeQueryWithConn(t, conn2)
	})
	t.Run("first connection retrieves ID token, remaining ones wait and reuse", func(t *testing.T) {
		origSamlResponseProvider := defaultSamlResponseProvider
		defer func() { defaultSamlResponseProvider = origSamlResponseProvider }()
		defaultSamlResponseProvider = func() samlResponseProvider {
			return &nonInteractiveSamlResponseProvider{t: t}
		}
		wiremock.registerMappings(t, newWiremockMapping("auth/external_browser/parallel_login_successful_flow.json"),
			newWiremockMapping("select1.json"),
			newWiremockMapping("close_session.json"))
		cfg := wiremock.connectionConfig()
		cfg.Authenticator = AuthTypeExternalBrowser
		cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		credentialsStorage.deleteCredential(newIDTokenSpec(cfg.Host, cfg.User))
		db := sql.OpenDB(connector)
		defer db.Close()
		// All 20 parallel connections should succeed: one authenticates,
		// the rest wait and reuse the cached token.
		errs := initPoolWithSizeAndReturnErrors(db, 20)
		assertEqualE(t, len(errs), 0)
	})
	t.Run("first connection fails, second retrieves ID token, remaining ones wait and reuse", func(t *testing.T) {
		origSamlResponseProvider := defaultSamlResponseProvider
		defer func() { defaultSamlResponseProvider = origSamlResponseProvider }()
		defaultSamlResponseProvider = func() samlResponseProvider {
			return &nonInteractiveSamlResponseProvider{t: t}
		}
		wiremock.registerMappings(t, newWiremockMapping("auth/external_browser/parallel_login_first_fails_then_successful_flow.json"),
			newWiremockMapping("select1.json"),
			newWiremockMapping("close_session.json"))
		cfg := wiremock.connectionConfig()
		cfg.Authenticator = AuthTypeExternalBrowser
		cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
		connector := NewConnector(SnowflakeDriver{}, *cfg)
		credentialsStorage.deleteCredential(newIDTokenSpec(cfg.Host, cfg.User))
		db := sql.OpenDB(connector)
		defer db.Close()
		// Exactly the one wiremock-rejected attempt should surface as an error.
		errs := initPoolWithSizeAndReturnErrors(db, 20)
		assertEqualE(t, len(errs), 1)
	})
}
// TestUnitAuthenticateWithConfigExternalBrowserWithFailedSAMLResponse checks
// that a SAML request failure in the external browser flow is propagated by
// authenticateWithConfig().
func TestUnitAuthenticateWithConfigExternalBrowserWithFailedSAMLResponse(t *testing.T) {
	sr := &snowflakeRestful{
		FuncPostAuthSAML: postAuthSAMLError,
		TokenAccessor:    getSimpleTokenAccessor(),
	}
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeExternalBrowser
	sc.cfg.ExternalBrowserTimeout = time.Duration(sfconfig.DefaultExternalBrowserTimeout)
	sc.rest = sr
	sc.ctx = context.Background()
	// Declare err at use instead of a needless `var err error` up front.
	err := authenticateWithConfig(sc)
	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
	assertEqualE(t, err.Error(), "failed to get SAML response")
}
// TestUnitAuthenticateExternalBrowser covers the external browser flow: a
// fresh authentication, reuse of a cached ID token, and a failing endpoint.
func TestUnitAuthenticateExternalBrowser(t *testing.T) {
	sr := &snowflakeRestful{
		FuncPostAuth:  postAuthCheckExternalBrowser,
		TokenAccessor: getSimpleTokenAccessor(),
	}
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeExternalBrowser
	sc.cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
	sc.rest = sr
	_, err := authenticate(context.Background(), sc, []byte{}, []byte{})
	assertNilF(t, err, "failed to run.")

	// With a cached ID token, the request should carry it.
	sr.FuncPostAuth = postAuthCheckExternalBrowserToken
	sc.idToken = "mockedIDToken"
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	assertNilF(t, err, "failed to run.")

	// A failing endpoint must propagate an error to the caller.
	sr.FuncPostAuth = postAuthCheckExternalBrowserFailed
	_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
	assertNotNilF(t, err, "should have failed")
}
// To run this test you need to set environment variables in parameters.json to a user with MFA authentication enabled
// Set any other snowflake_test variables needed for database, schema, role for this user
//
// TestUsernamePasswordMfaCaching checks that a cached MFA token is reused so
// the user is only prompted on the first of several logins.
func TestUsernamePasswordMfaCaching(t *testing.T) {
	t.Skip("manual test for MFA token caching")
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatal("Failed to parse dsn")
	}
	// connect with MFA authentication
	config.User = os.Getenv("SNOWFLAKE_TEST_MFA_USER")
	config.Password = os.Getenv("SNOWFLAKE_TEST_MFA_PASSWORD")
	config.Authenticator = AuthTypeUsernamePasswordMFA
	if runtime.GOOS == "linux" {
		config.ClientRequestMfaToken = ConfigBoolTrue
	}
	connector := NewConnector(SnowflakeDriver{}, *config)
	db := sql.OpenDB(connector)
	defer db.Close() // fix: the database handle was previously never closed
	for range 3 {
		// should only be prompted to authenticate first time around.
		rows, err := db.Query("select current_user()")
		if err != nil {
			t.Fatal(err)
		}
		rows.Close() // fix: result sets were previously leaked, pinning connections
	}
}
// TestUsernamePasswordMfaCachingWithPasscode is the MFA caching manual test
// with an explicit DUO passcode supplied alongside the password.
func TestUsernamePasswordMfaCachingWithPasscode(t *testing.T) {
	t.Skip("manual test for MFA token caching")
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatal("Failed to parse dsn")
	}
	// connect with MFA authentication
	config.User = os.Getenv("SNOWFLAKE_TEST_MFA_USER")
	config.Password = os.Getenv("SNOWFLAKE_TEST_MFA_PASSWORD")
	config.Passcode = "" // fill with your passcode from DUO app
	config.Authenticator = AuthTypeUsernamePasswordMFA
	if runtime.GOOS == "linux" {
		config.ClientRequestMfaToken = ConfigBoolTrue
	}
	connector := NewConnector(SnowflakeDriver{}, *config)
	db := sql.OpenDB(connector)
	defer db.Close() // fix: the database handle was previously never closed
	for range 3 {
		// should only be prompted to authenticate first time around.
		rows, err := db.Query("select current_user()")
		if err != nil {
			t.Fatal(err)
		}
		rows.Close() // fix: result sets were previously leaked, pinning connections
	}
}
// TestUsernamePasswordMfaCachingWithPasscodeInPassword is the MFA caching
// manual test with the DUO passcode appended to the password.
func TestUsernamePasswordMfaCachingWithPasscodeInPassword(t *testing.T) {
	t.Skip("manual test for MFA token caching")
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatal("Failed to parse dsn")
	}
	// connect with MFA authentication
	config.User = os.Getenv("SNOWFLAKE_TEST_MFA_USER")
	config.Password = os.Getenv("SNOWFLAKE_TEST_MFA_PASSWORD") + "" // fill with your passcode from DUO app
	config.PasscodeInPassword = true
	connector := NewConnector(SnowflakeDriver{}, *config)
	db := sql.OpenDB(connector)
	defer db.Close() // fix: the database handle was previously never closed
	for range 3 {
		// should only be prompted to authenticate first time around.
		rows, err := db.Query("select current_user()")
		if err != nil {
			t.Fatal(err)
		}
		rows.Close() // fix: result sets were previously leaked, pinning connections
	}
}
// To run this test you need to set environment variables in parameters.json to a user with MFA authentication enabled
// Set any other snowflake_test variables needed for database, schema, role for this user
//
// TestDisableUsernamePasswordMfaCaching checks that with caching disabled
// every login triggers a fresh MFA prompt.
func TestDisableUsernamePasswordMfaCaching(t *testing.T) {
	t.Skip("manual test for disabling MFA token caching")
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatal("Failed to parse dsn")
	}
	// connect with MFA authentication
	config.User = os.Getenv("SNOWFLAKE_TEST_MFA_USER")
	config.Password = os.Getenv("SNOWFLAKE_TEST_MFA_PASSWORD")
	config.Authenticator = AuthTypeUsernamePasswordMFA
	// disable MFA token caching
	config.ClientRequestMfaToken = ConfigBoolFalse
	connector := NewConnector(SnowflakeDriver{}, *config)
	db := sql.OpenDB(connector)
	defer db.Close() // fix: the database handle was previously never closed
	for range 3 {
		// should be prompted to authenticate 3 times.
		rows, err := db.Query("select current_user()")
		if err != nil {
			t.Fatal(err)
		}
		rows.Close() // fix: result sets were previously leaked, pinning connections
	}
}
// To run this test you need to set SNOWFLAKE_TEST_EXT_BROWSER_USER environment variable to an external browser user
// Set any other snowflake_test variables needed for database, schema, role for this user
//
// TestExternalBrowserCaching checks that a cached ID token is reused so the
// browser only opens on the first of several logins.
func TestExternalBrowserCaching(t *testing.T) {
	t.Skip("manual test for external browser token caching")
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatal("Failed to parse dsn")
	}
	// connect with external browser authentication
	config.User = os.Getenv("SNOWFLAKE_TEST_EXT_BROWSER_USER")
	config.Authenticator = AuthTypeExternalBrowser
	if runtime.GOOS == "linux" {
		config.ClientStoreTemporaryCredential = ConfigBoolTrue
	}
	connector := NewConnector(SnowflakeDriver{}, *config)
	db := sql.OpenDB(connector)
	defer db.Close() // fix: the database handle was previously never closed
	for range 3 {
		// should only be prompted to authenticate first time around.
		rows, err := db.Query("select current_user()")
		if err != nil {
			t.Fatal(err)
		}
		rows.Close() // fix: result sets were previously leaked, pinning connections
	}
}
// To run this test you need to set SNOWFLAKE_TEST_EXT_BROWSER_USER environment variable to an external browser user
// Set any other snowflake_test variables needed for database, schema, role for this user
//
// TestDisableExternalBrowserCaching checks that with caching disabled every
// login opens the browser again.
func TestDisableExternalBrowserCaching(t *testing.T) {
	t.Skip("manual test for disabling external browser token caching")
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatal("Failed to parse dsn")
	}
	// connect with external browser authentication
	config.User = os.Getenv("SNOWFLAKE_TEST_EXT_BROWSER_USER")
	config.Authenticator = AuthTypeExternalBrowser
	// disable external browser token caching
	config.ClientStoreTemporaryCredential = ConfigBoolFalse
	connector := NewConnector(SnowflakeDriver{}, *config)
	db := sql.OpenDB(connector)
	defer db.Close() // fix: the database handle was previously never closed
	for range 3 {
		// should be prompted to authenticate 3 times.
		rows, err := db.Query("select current_user()")
		if err != nil {
			t.Fatal(err)
		}
		rows.Close() // fix: result sets were previously leaked, pinning connections
	}
}
// TestOktaRetryWithNewToken verifies that when the Okta flow retries with a
// renewed token (postAuthOktaWithNewToken), authenticate() returns the tokens
// and session info from the final successful response.
func TestOktaRetryWithNewToken(t *testing.T) {
	expectedMasterToken := "m"
	expectedToken := "t"
	expectedMfaToken := "mockedMfaToken"
	expectedDatabaseName := "dbn"
	sr := &snowflakeRestful{
		Protocol:         "https",
		Host:             "abc.com",
		Port:             443,
		FuncPostAuthSAML: postAuthSAMLAuthSuccess,
		FuncPostAuthOKTA: postAuthOKTASuccess,
		FuncGetSSO:       getSSOSuccess,
		FuncPostAuth:     restfulTestWrapper{t: t}.postAuthOktaWithNewToken,
		TokenAccessor:    getSimpleTokenAccessor(),
	}
	sc := getDefaultSnowflakeConn()
	sc.cfg.Authenticator = AuthTypeOkta
	sc.cfg.OktaURL = &url.URL{
		Scheme: "https",
		Host:   "abc.com",
	}
	sc.rest = sr
	sc.ctx = context.Background()
	// Non-empty byte arguments are passed through to the mocked auth endpoint.
	authResponse, err := authenticate(context.Background(), sc, []byte{0x12, 0x34}, []byte{0x56, 0x78})
	assertNilF(t, err, "should not have failed to run authenticate()")
	assertEqualF(t, authResponse.MasterToken, expectedMasterToken)
	assertEqualF(t, authResponse.Token, expectedToken)
	assertEqualF(t, authResponse.MfaToken, expectedMfaToken)
	assertEqualF(t, authResponse.SessionInfo.DatabaseName, expectedDatabaseName)
}
// TestContextPropagatedToAuthWhenUsingOpen verifies that a short query
// context deadline also cancels the implicit authentication triggered by the
// first query on a sql.Open handle.
func TestContextPropagatedToAuthWhenUsingOpen(t *testing.T) {
	db, err := sql.Open("snowflake", dsn)
	assertNilF(t, err)
	defer db.Close()
	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
	// fix: defer the cancel so it runs even when an assertion below fails
	// (previously it was called only at the very end of the test).
	defer cancel()
	_, err = db.QueryContext(ctx, "SELECT 1")
	assertNotNilF(t, err)
	assertStringContainsE(t, err.Error(), "context deadline exceeded")
}
// TestContextPropagatedToAuthWhenUsingOpenDB verifies that a short query
// context deadline also cancels authentication when the handle comes from
// sql.OpenDB with an explicit connector.
func TestContextPropagatedToAuthWhenUsingOpenDB(t *testing.T) {
	cfg, err := ParseDSN(dsn)
	assertNilF(t, err)
	connector := NewConnector(&SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	defer db.Close()
	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
	// fix: defer the cancel so it runs even when an assertion below fails
	// (previously it was called only at the very end of the test).
	defer cancel()
	_, err = db.QueryContext(ctx, "SELECT 1")
	assertNotNilF(t, err)
	assertStringContainsE(t, err.Error(), "context deadline exceeded")
}
// TestPatSuccessfulFlow authenticates with a programmatic access token (PAT)
// against wiremock and runs a smoke query.
func TestPatSuccessfulFlow(t *testing.T) {
	cfg := wiremock.connectionConfig()
	cfg.Authenticator = AuthTypePat
	cfg.Token = "some PAT"
	wiremock.registerMappings(t,
		wiremockMapping{filePath: "auth/pat/successful_flow.json"},
		wiremockMapping{filePath: "select1.json"},
	)
	connector := NewConnector(SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	defer db.Close() // fix: the database handle was previously never closed
	rows, err := db.Query("SELECT 1")
	assertNilF(t, err)
	defer rows.Close() // fix: the result set was previously never closed
	var v int
	assertTrueE(t, rows.Next())
	assertNilF(t, rows.Scan(&v))
	assertEqualE(t, v, 1)
}
// TestPatTokenRotation checks that the PAT is re-read from the token file for
// each new physical connection, so rotating the file contents takes effect.
func TestPatTokenRotation(t *testing.T) {
	dir := t.TempDir()
	tokenFilePath := filepath.Join(dir, "tokenFile")
	assertNilF(t, os.WriteFile(tokenFilePath, []byte("some PAT"), 0644))
	cfg := wiremock.connectionConfig()
	cfg.Authenticator = AuthTypePat
	cfg.TokenFilePath = tokenFilePath
	wiremock.registerMappings(t,
		wiremockMapping{filePath: "auth/pat/reading_fresh_token.json"},
	)
	connector := NewConnector(SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	defer db.Close() // fix: the database handle was previously never closed
	conn1, err := db.Conn(context.Background())
	assertNilF(t, err)
	// fix: conns were previously leaked. Close them via defer (not before
	// opening conn2) so both are held concurrently and conn2 is forced to be
	// a new physical connection that re-reads the rotated token file.
	defer conn1.Close()
	assertNilF(t, os.WriteFile(tokenFilePath, []byte("some PAT 2"), 0644))
	conn2, err := db.Conn(context.Background())
	assertNilF(t, err)
	defer conn2.Close()
}
// TestPatInvalidToken verifies that an invalid PAT surfaces the expected
// SnowflakeError number and message.
func TestPatInvalidToken(t *testing.T) {
	wiremock.registerMappings(t,
		wiremockMapping{filePath: "auth/pat/invalid_token.json"},
	)
	cfg := wiremock.connectionConfig()
	cfg.Authenticator = AuthTypePat
	cfg.Token = "some PAT"
	connector := NewConnector(SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	defer db.Close() // fix: the database handle was previously never closed
	_, err := db.Query("SELECT 1")
	assertNotNilF(t, err)
	var se *SnowflakeError
	assertErrorsAsF(t, err, &se)
	assertEqualE(t, se.Number, 394400)
	assertEqualE(t, se.Message, "Programmatic access token is invalid.")
}
// TestWithOauthAuthorizationCodeFlowManual exercises the OAuth authorization
// code flow end-to-end with OKTA and Snowflake as identity providers.
// Manual test: requires environment configuration and browser interaction.
func TestWithOauthAuthorizationCodeFlowManual(t *testing.T) {
	t.Skip("manual test")
	for _, provider := range []string{"OKTA", "SNOWFLAKE"} {
		t.Run(provider, func(t *testing.T) {
			cfg, err := GetConfigFromEnv([]*ConfigParam{
				{Name: "OAuthClientId", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_CLIENT_ID", FailOnMissing: true},
				{Name: "OAuthClientSecret", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_CLIENT_SECRET", FailOnMissing: true},
				{Name: "OAuthAuthorizationURL", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_AUTHORIZATION_URL", FailOnMissing: false},
				{Name: "OAuthTokenRequestURL", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_TOKEN_REQUEST_URL", FailOnMissing: false},
				{Name: "OAuthRedirectURI", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_REDIRECT_URI", FailOnMissing: false},
				{Name: "OAuthScope", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_SCOPE", FailOnMissing: false},
				{Name: "User", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_USER", FailOnMissing: true},
				{Name: "Role", EnvName: "SNOWFLAKE_TEST_OAUTH_" + provider + "_ROLE", FailOnMissing: true},
				{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true},
			})
			assertNilF(t, err)
			cfg.Authenticator = AuthTypeOAuthAuthorizationCode
			// Fall back to the account's default token endpoint when none is configured.
			tokenRequestURL := cmp.Or(cfg.OauthTokenRequestURL, fmt.Sprintf("https://%v.snowflakecomputing.com:443/oauth/token-request", cfg.Account))
			credentialsStorage.deleteCredential(newOAuthAccessTokenSpec(tokenRequestURL, cfg.User))
			credentialsStorage.deleteCredential(newOAuthRefreshTokenSpec(tokenRequestURL, cfg.User))
			connector := NewConnector(&SnowflakeDriver{}, *cfg)
			db := sql.OpenDB(connector)
			defer db.Close()
			conn1, err := db.Conn(context.Background())
			assertNilF(t, err)
			defer conn1.Close()
			runSmokeQueryWithConn(t, conn1)
			conn2, err := db.Conn(context.Background())
			assertNilF(t, err)
			defer conn2.Close()
			runSmokeQueryWithConn(t, conn2)
			// fix: plant the expired token under the resolved tokenRequestURL
			// (matching the deleteCredential calls above) rather than the
			// possibly-empty cfg.OauthTokenRequestURL, so the refresh path is
			// actually exercised.
			credentialsStorage.setCredential(newOAuthAccessTokenSpec(tokenRequestURL, cfg.User), "expired-token")
			conn3, err := db.Conn(context.Background())
			assertNilF(t, err)
			defer conn3.Close()
			runSmokeQueryWithConn(t, conn3)
		})
	}
}
// TestWithOAuthClientCredentialsFlowManual exercises the OAuth client
// credentials flow against Okta. Manual test: requires env configuration.
func TestWithOAuthClientCredentialsFlowManual(t *testing.T) {
	t.Skip("manual test")
	cfg, err := GetConfigFromEnv([]*ConfigParam{
		{Name: "OAuthClientId", EnvName: "SNOWFLAKE_TEST_OAUTH_OKTA_CLIENT_ID", FailOnMissing: true},
		{Name: "OAuthClientSecret", EnvName: "SNOWFLAKE_TEST_OAUTH_OKTA_CLIENT_SECRET", FailOnMissing: true},
		{Name: "OAuthTokenRequestURL", EnvName: "SNOWFLAKE_TEST_OAUTH_OKTA_TOKEN_REQUEST_URL", FailOnMissing: true},
		{Name: "Role", EnvName: "SNOWFLAKE_TEST_OAUTH_OKTA_ROLE", FailOnMissing: true},
		{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true},
	})
	assertNilF(t, err)
	cfg.Authenticator = AuthTypeOAuthClientCredentials
	connector := NewConnector(&SnowflakeDriver{}, *cfg)
	db := sql.OpenDB(connector)
	defer db.Close()
	runSmokeQuery(t, db)
}
================================================
FILE: auth_wif.go
================================================
package gosnowflake
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/golang-jwt/jwt/v5"
sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config"
)
const (
	// Supported workload identity federation (WIF) provider identifiers.
	awsWif  wifProviderType = "AWS"
	gcpWif  wifProviderType = "GCP"
	azureWif wifProviderType = "AZURE"
	oidcWif wifProviderType = "OIDC"

	// Header name/value required by the GCP metadata server.
	gcpMetadataFlavorHeaderName = "Metadata-Flavor"
	gcpMetadataFlavor           = "Google"
	// Link-local instance metadata endpoint (used for both GCP and Azure here).
	defaultMetadataServiceBase = "http://169.254.169.254"
	// Base URL of the GCP IAM Credentials API (service account impersonation).
	defaultGcpIamCredentialsBase = "https://iamcredentials.googleapis.com"
	// Audience requested for identity tokens presented to Snowflake.
	snowflakeAudience = "snowflakecomputing.com"
)
// wifProviderType identifies a workload identity federation provider.
type wifProviderType string

// wifAttestation is the provider-specific credential material produced for
// workload identity federation login.
type wifAttestation struct {
	ProviderType string            `json:"providerType"`
	Credential   string            `json:"credential"`
	Metadata     map[string]string `json:"metadata"`
}

// wifAttestationCreator builds an attestation for one specific provider.
type wifAttestationCreator interface {
	createAttestation() (*wifAttestation, error)
}

// wifAttestationProvider dispatches attestation creation to the creator
// matching the configured identity provider.
type wifAttestationProvider struct {
	context context.Context
	cfg     *Config

	awsCreator   wifAttestationCreator
	gcpCreator   wifAttestationCreator
	azureCreator wifAttestationCreator
	oidcCreator  wifAttestationCreator
}
// createWifAttestationProvider wires up one attestation creator per supported
// identity provider (AWS, GCP, Azure, OIDC) with its default collaborators.
func createWifAttestationProvider(ctx context.Context, cfg *Config, telemetry *snowflakeTelemetry) *wifAttestationProvider {
	return &wifAttestationProvider{
		context: ctx,
		cfg:     cfg,
		awsCreator: &awsIdentityAttestationCreator{
			cfg: cfg,
			// Factory indirection: the metadata provider is created lazily at
			// attestation time.
			attestationServiceFactory: createDefaultAwsAttestationMetadataProvider,
			ctx:                       ctx,
		},
		gcpCreator: &gcpIdentityAttestationCreator{
			cfg:                    cfg,
			telemetry:              telemetry,
			metadataServiceBaseURL: defaultMetadataServiceBase,
			iamCredentialsURL:      defaultGcpIamCredentialsBase,
		},
		azureCreator: &azureIdentityAttestationCreator{
			azureAttestationMetadataProvider: &defaultAzureAttestationMetadataProvider{},
			cfg:                              cfg,
			telemetry:                        telemetry,
			workloadIdentityEntraResource:    determineEntraResource(cfg),
			azureMetadataServiceBaseURL:      defaultMetadataServiceBase,
		},
		// OIDC reads the token via the config helper at attestation time.
		oidcCreator: &oidcIdentityAttestationCreator{token: func() (string, error) { return sfconfig.GetToken(cfg) }},
	}
}
// getAttestation returns the attestation produced by the creator registered
// for the given identity provider name (matched case-insensitively). An
// error is returned for unrecognized provider names.
func (p *wifAttestationProvider) getAttestation(identityProvider string) (*wifAttestation, error) {
	creators := map[wifProviderType]wifAttestationCreator{
		awsWif:   p.awsCreator,
		gcpWif:   p.gcpCreator,
		azureWif: p.azureCreator,
		oidcWif:  p.oidcCreator,
	}
	creator, known := creators[wifProviderType(strings.ToUpper(identityProvider))]
	if !known {
		return nil, fmt.Errorf("unknown WorkloadIdentityProvider specified: %s. Valid values are: %s, %s, %s, %s", identityProvider, awsWif, gcpWif, azureWif, oidcWif)
	}
	return creator.createAttestation()
}
// awsAttestastationMetadataProviderFactory constructs the AWS metadata
// provider used at attestation time (indirection allows substitution).
type awsAttestastationMetadataProviderFactory func(ctx context.Context, cfg *Config) awsAttestationMetadataProvider

// awsIdentityAttestationCreator builds attestations from a SigV4-signed STS
// GetCallerIdentity request.
type awsIdentityAttestationCreator struct {
	cfg                       *Config
	attestationServiceFactory awsAttestastationMetadataProviderFactory
	ctx                       context.Context
}

// gcpIdentityAttestationCreator builds attestations from GCP identity tokens
// obtained via the metadata service or via service-account impersonation.
type gcpIdentityAttestationCreator struct {
	cfg                    *Config
	telemetry              *snowflakeTelemetry
	metadataServiceBaseURL string
	iamCredentialsURL      string
}

// oidcIdentityAttestationCreator builds attestations from a caller-supplied
// OIDC token source.
type oidcIdentityAttestationCreator struct {
	token func() (string, error)
}

// awsAttestationMetadataProvider abstracts access to AWS credentials and the
// resolved region.
type awsAttestationMetadataProvider interface {
	awsCredentials() (aws.Credentials, error)
	awsCredentialsViaRoleChaining() (aws.Credentials, error)
	awsRegion() string
}

// defaultAwsAttestationMetadataProvider is the AWS-SDK-backed implementation
// of awsAttestationMetadataProvider.
type defaultAwsAttestationMetadataProvider struct {
	ctx    context.Context
	cfg    *Config
	awsCfg aws.Config
}
// createDefaultAwsAttestationMetadataProvider loads the default AWS config
// (resolving the region from EC2 IMDS when needed) and wraps it in a
// metadata provider. Returns nil when the AWS config cannot be loaded; the
// caller must handle the nil result.
func createDefaultAwsAttestationMetadataProvider(ctx context.Context, cfg *Config) awsAttestationMetadataProvider {
	awsCfg, err := config.LoadDefaultConfig(ctx, config.WithEC2IMDSRegion())
	if err != nil {
		logger.Debugf("Unable to load AWS config: %v", err)
		return nil
	}
	return &defaultAwsAttestationMetadataProvider{
		awsCfg: awsCfg,
		cfg:    cfg,
		ctx:    ctx,
	}
}
// awsCredentials resolves credentials from the loaded AWS credential chain.
func (s *defaultAwsAttestationMetadataProvider) awsCredentials() (aws.Credentials, error) {
	return s.awsCfg.Credentials.Retrieve(s.ctx)
}
// awsCredentialsViaRoleChaining resolves base credentials and then assumes
// each role in WorkloadIdentityImpersonationPath in order, feeding the
// credentials returned by one AssumeRole call into the next.
func (s *defaultAwsAttestationMetadataProvider) awsCredentialsViaRoleChaining() (aws.Credentials, error) {
	creds, err := s.awsCredentials()
	if err != nil {
		return aws.Credentials{}, err
	}
	for _, roleArn := range s.cfg.WorkloadIdentityImpersonationPath {
		if creds, err = s.assumeRole(creds, roleArn); err != nil {
			return aws.Credentials{}, err
		}
	}
	return creds, nil
}
// assumeRole exchanges the given credentials for temporary credentials of
// roleArn via STS AssumeRole.
func (s *defaultAwsAttestationMetadataProvider) assumeRole(creds aws.Credentials, roleArn string) (aws.Credentials, error) {
	logger.Debugf("assuming role %v", roleArn)
	// Struct assignment copies the config, so the provider's shared awsCfg is
	// not mutated by the overrides below.
	awsCfg := s.awsCfg
	awsCfg.Credentials = credentials.StaticCredentialsProvider{Value: creds}
	awsCfg.Region = s.awsRegion()
	stsClient := sts.NewFromConfig(awsCfg)
	role, err := stsClient.AssumeRole(s.ctx, &sts.AssumeRoleInput{
		RoleArn:         aws.String(roleArn),
		RoleSessionName: aws.String("identity-federation-session"),
	})
	if err != nil {
		logger.Debugf("failed to assume role %v: %v", roleArn, err)
		return aws.Credentials{}, err
	}
	return aws.Credentials{
		AccessKeyID:     *role.Credentials.AccessKeyId,
		SecretAccessKey: *role.Credentials.SecretAccessKey,
		SessionToken:    *role.Credentials.SessionToken,
		Expires:         *role.Credentials.Expiration,
	}, nil
}
// awsRegion reports the region resolved when the AWS config was loaded.
func (s *defaultAwsAttestationMetadataProvider) awsRegion() string {
	return s.awsCfg.Region
}
// createAttestation builds an AWS identity attestation: it resolves AWS
// credentials (optionally via a role-chaining impersonation path), creates an
// STS GetCallerIdentity request, signs it with SigV4, and base64-encodes the
// signed request as the credential.
func (c *awsIdentityAttestationCreator) createAttestation() (*wifAttestation, error) {
	logger.Debug("Creating AWS identity attestation...")
	attestationService := c.attestationServiceFactory(c.ctx, c.cfg)
	if attestationService == nil {
		return nil, errors.New("AWS attestation service could not be created")
	}
	var creds aws.Credentials
	var err error
	if len(c.cfg.WorkloadIdentityImpersonationPath) == 0 {
		if creds, err = attestationService.awsCredentials(); err != nil {
			// fix: log message previously read "error while getting for aws credentials"
			logger.Debugf("error while getting aws credentials. %v", err)
			return nil, err
		}
	} else {
		if creds, err = attestationService.awsCredentialsViaRoleChaining(); err != nil {
			logger.Debugf("error while getting aws credentials via role chaining. %v", err)
			return nil, err
		}
	}
	if creds.AccessKeyID == "" || creds.SecretAccessKey == "" {
		// errors.New instead of fmt.Errorf for constant messages.
		return nil, errors.New("no AWS credentials were found")
	}
	region := attestationService.awsRegion()
	if region == "" {
		return nil, errors.New("no AWS region was found")
	}
	// fix: the local was previously named stsHostname, shadowing the
	// stsHostname function within this body.
	hostname := stsHostname(region)
	req, err := c.createStsRequest(hostname)
	if err != nil {
		return nil, err
	}
	if err = c.signRequestWithSigV4(c.ctx, req, creds, region); err != nil {
		return nil, err
	}
	credential, err := c.createBase64EncodedRequestCredential(req)
	if err != nil {
		return nil, err
	}
	return &wifAttestation{
		ProviderType: string(awsWif),
		Credential:   credential,
		Metadata:     map[string]string{},
	}, nil
}
// stsHostname returns the regional STS endpoint host for the given AWS
// region, using the China partition domain for cn-* regions.
func stsHostname(region string) string {
	domain := "amazonaws.com"
	if strings.HasPrefix(region, "cn-") {
		domain = "amazonaws.com.cn"
	}
	return "sts." + region + "." + domain
}
// createStsRequest builds an unsigned POST request for STS GetCallerIdentity
// against the given regional STS host, tagged with the Snowflake audience
// header.
func (c *awsIdentityAttestationCreator) createStsRequest(hostname string) (*http.Request, error) {
	endpoint := "https://" + hostname + "?Action=GetCallerIdentity&Version=2011-06-15"
	req, err := http.NewRequest("POST", endpoint, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Host", hostname)
	req.Header.Set("X-Snowflake-Audience", "snowflakecomputing.com")
	return req, nil
}
// signRequestWithSigV4 signs req in place with AWS Signature Version 4 for
// the STS service in the given region, using the provided credentials.
func (c *awsIdentityAttestationCreator) signRequestWithSigV4(ctx context.Context, req *http.Request, creds aws.Credentials, region string) error {
	signer := v4.NewSigner()
	// as per docs of SignHTTP, the payload hash must be present even if the payload is empty
	payloadHash := hex.EncodeToString(sha256.New().Sum(nil))
	return signer.SignHTTP(ctx, creds, req, payloadHash, "sts", region, time.Now())
}
// createBase64EncodedRequestCredential serializes the signed request (URL,
// method, and the first value of each header) as JSON and returns it
// base64-encoded.
func (c *awsIdentityAttestationCreator) createBase64EncodedRequestCredential(req *http.Request) (string, error) {
	flatHeaders := make(map[string]string, len(req.Header))
	for name, values := range req.Header {
		flatHeaders[name] = values[0]
	}
	payload, err := json.Marshal(map[string]any{
		"url":     req.URL.String(),
		"method":  req.Method,
		"headers": flatHeaders,
	})
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(payload), nil
}
// createAttestation builds a GCP identity attestation, either via
// service-account impersonation when an impersonation path is configured, or
// directly from the metadata service otherwise.
func (c *gcpIdentityAttestationCreator) createAttestation() (*wifAttestation, error) {
	logger.Debugf("Creating GCP identity attestation...")
	if len(c.cfg.WorkloadIdentityImpersonationPath) > 0 {
		return c.createGcpIdentityViaImpersonation()
	}
	return c.createGcpIdentityTokenFromMetadataService()
}
// createGcpIdentityTokenFromMetadataService fetches an identity token for the
// default service account from the GCP metadata service and wraps it in a
// wifAttestation. The token signature is not verified; only the "sub" claim
// is extracted for the attestation metadata.
func (c *gcpIdentityAttestationCreator) createGcpIdentityTokenFromMetadataService() (*wifAttestation, error) {
	request, err := c.createTokenRequest()
	if err != nil {
		return nil, fmt.Errorf("failed to create GCP token request: %w", err)
	}
	identityToken := fetchTokenFromMetadataService(request, c.cfg, c.telemetry)
	if identityToken == "" {
		return nil, fmt.Errorf("no GCP token was found")
	}
	subject, _, err := extractSubIssWithoutVerifyingSignature(identityToken)
	if err != nil {
		return nil, fmt.Errorf("could not extract claims from token: %v", err)
	}
	return &wifAttestation{
		ProviderType: string(gcpWif),
		Credential:   identityToken,
		Metadata:     map[string]string{"sub": subject},
	}, nil
}
// createTokenRequest builds the GET request against the metadata-service
// identity endpoint, asking for a token scoped to the Snowflake audience.
// The Metadata-Flavor header is mandatory for the GCP metadata service.
func (c *gcpIdentityAttestationCreator) createTokenRequest() (*http.Request, error) {
	endpoint := c.metadataServiceBaseURL +
		"/computeMetadata/v1/instance/service-accounts/default/identity?audience=" + snowflakeAudience
	request, err := http.NewRequest("GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP request: %v", err)
	}
	request.Header.Set(gcpMetadataFlavorHeaderName, gcpMetadataFlavor)
	return request, nil
}
// createGcpIdentityViaImpersonation obtains a GCP identity token by
// impersonating the service accounts listed in
// WorkloadIdentityImpersonationPath. The last entry is the target service
// account; any preceding entries act as delegates, in order.
func (c *gcpIdentityAttestationCreator) createGcpIdentityViaImpersonation() (*wifAttestation, error) {
	transport, err := newTransportFactory(c.cfg, c.telemetry).createTransport(transportConfigFor(transportTypeWIF))
	if err != nil {
		logger.Debugf("Failed to create HTTP transport: %v", err)
		return nil, err
	}
	httpClient := &http.Client{Transport: transport}

	// An access token for the default service account authorizes the
	// impersonation call below.
	accessToken, err := c.fetchServiceToken(httpClient)
	if err != nil {
		return nil, err
	}

	// Expand each configured name into a full service-account resource path.
	resourcePaths := make([]string, 0, len(c.cfg.WorkloadIdentityImpersonationPath))
	for _, name := range c.cfg.WorkloadIdentityImpersonationPath {
		resourcePaths = append(resourcePaths, "projects/-/serviceAccounts/"+name)
	}
	target := resourcePaths[len(resourcePaths)-1]
	delegates := resourcePaths[:len(resourcePaths)-1]

	impersonatedToken, err := c.fetchImpersonatedToken(target, delegates, accessToken, httpClient)
	if err != nil {
		return nil, err
	}

	// The impersonated token's signature is not verified; only the "sub"
	// claim is pulled out for metadata.
	subject, _, err := extractSubIssWithoutVerifyingSignature(impersonatedToken)
	if err != nil {
		return nil, fmt.Errorf("could not extract claims from token: %v", err)
	}
	return &wifAttestation{
		ProviderType: string(gcpWif),
		Credential:   impersonatedToken,
		Metadata:     map[string]string{"sub": subject},
	}, nil
}
// fetchServiceToken retrieves an OAuth access token for the default service
// account from the metadata service. The token is used as the bearer
// credential for the subsequent impersonation call.
func (c *gcpIdentityAttestationCreator) fetchServiceToken(client *http.Client) (string, error) {
	tokenURL := c.metadataServiceBaseURL + "/computeMetadata/v1/instance/service-accounts/default/token"
	request, err := http.NewRequest("GET", tokenURL, nil)
	if err != nil {
		logger.Debugf("cannot create token request for impersonation. %v", err)
		return "", err
	}
	request.Header.Set(gcpMetadataFlavorHeaderName, gcpMetadataFlavor)
	response, err := client.Do(request)
	if err != nil {
		logger.Debugf("cannot fetch token for impersonation. %v", err)
		return "", err
	}
	defer func() {
		if closeErr := response.Body.Close(); closeErr != nil {
			logger.Debugf("cannot close token response body for impersonation. %v", closeErr)
		}
	}()
	// Only a 200 response carries a valid token payload.
	if response.StatusCode != http.StatusOK {
		return "", fmt.Errorf("token response status is %v, not parsing", response.StatusCode)
	}
	var parsed struct {
		AccessToken string `json:"access_token"`
	}
	if err := json.NewDecoder(response.Body).Decode(&parsed); err != nil {
		logger.Debugf("cannot decode token for impersonation. %v", err)
		return "", err
	}
	return parsed.AccessToken, nil
}
// fetchImpersonatedToken calls the IAM Credentials generateIdToken endpoint to
// mint an identity token for targetServiceAccount, optionally hopping through
// the given delegate chain. accessToken authorizes the call as a bearer token.
func (c *gcpIdentityAttestationCreator) fetchImpersonatedToken(targetServiceAccount string, delegates []string, accessToken string, client *http.Client) (string, error) {
	endpoint := fmt.Sprintf("%v/v1/%v:generateIdToken", c.iamCredentialsURL, targetServiceAccount)
	requestBody := struct {
		Delegates []string `json:"delegates,omitempty"`
		Audience  string   `json:"audience"`
	}{
		Delegates: delegates,
		Audience:  snowflakeAudience,
	}
	encoded := new(bytes.Buffer)
	if err := json.NewEncoder(encoded).Encode(requestBody); err != nil {
		logger.Debugf("cannot encode impersonation request body. %v", err)
		return "", err
	}
	request, err := http.NewRequest("POST", endpoint, encoded)
	if err != nil {
		logger.Debugf("cannot create token request for impersonation. %v", err)
		return "", err
	}
	request.Header.Set("Authorization", "Bearer "+accessToken)
	request.Header.Set("Content-Type", "application/json")
	response, err := client.Do(request)
	if err != nil {
		logger.Debugf("cannot call impersonation service. %v", err)
		return "", err
	}
	defer func() {
		if closeErr := response.Body.Close(); closeErr != nil {
			logger.Debugf("cannot close token response body for impersonation. %v", closeErr)
		}
	}()
	// Anything other than 200 is treated as failure without parsing the body.
	if response.StatusCode != http.StatusOK {
		return "", fmt.Errorf("response status is %v, not parsing", response.StatusCode)
	}
	var parsed struct {
		Token string `json:"token"`
	}
	if err := json.NewDecoder(response.Body).Decode(&parsed); err != nil {
		logger.Debugf("cannot decode token response. %v", err)
		return "", err
	}
	return parsed.Token, nil
}
// fetchTokenFromMetadataService executes req against a cloud metadata service
// and returns the raw response body as a string. Every failure mode —
// transport setup error, request error, non-200 status, unreadable body — is
// logged at debug level and reported as an empty string, which callers treat
// as "no token available".
func fetchTokenFromMetadataService(req *http.Request, cfg *Config, telemetry *snowflakeTelemetry) string {
	transport, err := newTransportFactory(cfg, telemetry).createTransport(transportConfigFor(transportTypeWIF))
	if err != nil {
		logger.Debugf("Failed to create HTTP transport: %v", err)
		return ""
	}
	client := &http.Client{Transport: transport}
	resp, err := client.Do(req)
	if err != nil {
		logger.Debugf("Metadata server request was not successful: %v", err)
		return ""
	}
	defer func() {
		if err = resp.Body.Close(); err != nil {
			logger.Debugf("Failed to close response body: %v", err)
		}
	}()
	// Bug fix: previously the body was returned regardless of status, so a
	// non-200 error page could be mistaken for a token. Only 200 responses
	// carry a usable token payload.
	if resp.StatusCode != http.StatusOK {
		logger.Debugf("Metadata server returned status %v; not using response body", resp.StatusCode)
		return ""
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		logger.Debugf("Failed to read response body: %v", err)
		return ""
	}
	return string(body)
}
// extractSubIssWithoutVerifyingSignature pulls the "sub" and "iss" claims out
// of a JWT without verifying its signature. It fails when either claim is
// absent or is not a string.
func extractSubIssWithoutVerifyingSignature(token string) (subject string, issuer string, err error) {
	claims, err := extractClaimsMap(token)
	if err != nil {
		return "", "", err
	}
	rawIssuer, hasIssuer := claims["iss"]
	if !hasIssuer {
		return "", "", errors.New("missing issuer claim in JWT token")
	}
	rawSubject, hasSubject := claims["sub"]
	if !hasSubject {
		return "", "", errors.New("missing sub claim in JWT token")
	}
	subject, ok := rawSubject.(string)
	if !ok {
		return "", "", errors.New("sub claim is not a string in JWT token")
	}
	issuer, ok = rawIssuer.(string)
	if !ok {
		return "", "", errors.New("iss claim is not a string in JWT token")
	}
	return subject, issuer, nil
}
// extractClaimsMap parses a JWT token and returns its claims as a map.
// The signature is deliberately NOT verified, so callers must treat the
// returned claims as untrusted input.
func extractClaimsMap(token string) (map[string]any, error) {
	claims := jwt.MapClaims{}
	parser := jwt.NewParser(jwt.WithoutClaimsValidation())
	if _, _, err := parser.ParseUnverified(token, claims); err != nil {
		return nil, fmt.Errorf("unable to extract JWT claims from token: %w", err)
	}
	return claims, nil
}
// createAttestation builds an OIDC workload identity attestation from the
// token supplied by the configured token callback. The token must parse as a
// JWT with non-empty "sub" and "iss" claims; its signature is not verified.
func (c *oidcIdentityAttestationCreator) createAttestation() (*wifAttestation, error) {
	logger.Debugf("Creating OIDC identity attestation...")
	token, err := c.token()
	if err != nil {
		return nil, fmt.Errorf("failed to get OIDC token: %w", err)
	}
	if token == "" {
		return nil, fmt.Errorf("no OIDC token was specified")
	}
	subject, issuer, err := extractSubIssWithoutVerifyingSignature(token)
	if err != nil {
		return nil, err
	}
	if subject == "" || issuer == "" {
		return nil, errors.New("missing sub or iss claim in JWT token")
	}
	return &wifAttestation{
		ProviderType: string(oidcWif),
		Credential:   token,
		Metadata:     map[string]string{"sub": subject},
	}, nil
}
// azureAttestationMetadataProvider abstracts the sources of Azure managed
// identity configuration so they can be mocked in tests.
type azureAttestationMetadataProvider interface {
	identityEndpoint() string
	identityHeader() string
	clientID() string
}

// defaultAzureAttestationMetadataProvider reads managed identity settings
// from the standard Azure environment variables.
type defaultAzureAttestationMetadataProvider struct{}

// identityEndpoint returns the IDENTITY_ENDPOINT environment variable
// (empty when the workload is not running with an identity endpoint).
func (p *defaultAzureAttestationMetadataProvider) identityEndpoint() string {
	endpoint := os.Getenv("IDENTITY_ENDPOINT")
	return endpoint
}

// identityHeader returns the IDENTITY_HEADER environment variable used to
// authenticate calls to the identity endpoint.
func (p *defaultAzureAttestationMetadataProvider) identityHeader() string {
	header := os.Getenv("IDENTITY_HEADER")
	return header
}

// clientID returns the MANAGED_IDENTITY_CLIENT_ID environment variable, or
// "" when no user-assigned identity is configured.
func (p *defaultAzureAttestationMetadataProvider) clientID() string {
	id := os.Getenv("MANAGED_IDENTITY_CLIENT_ID")
	return id
}
// azureIdentityAttestationCreator builds workload identity attestations for
// workloads running on Azure.
type azureIdentityAttestationCreator struct {
	// azureAttestationMetadataProvider supplies the identity endpoint, header
	// and client ID (env vars in production, mocked in tests).
	azureAttestationMetadataProvider azureAttestationMetadataProvider
	cfg                              *Config
	telemetry                        *snowflakeTelemetry
	// workloadIdentityEntraResource is the Entra ID resource (audience) the
	// token is requested for; see determineEntraResource.
	workloadIdentityEntraResource string
	// azureMetadataServiceBaseURL is the base URL of the instance metadata
	// service; overridden in tests to point at wiremock.
	azureMetadataServiceBaseURL string
}
// createAttestation creates an attestation using Azure identity. When the
// metadata provider exposes an identity endpoint, the Azure Functions flow is
// used (requiring the identity header); otherwise the request goes to the VM
// instance metadata service. The returned JWT must carry non-empty "sub" and
// "iss" claims; its signature is not verified.
func (a *azureIdentityAttestationCreator) createAttestation() (*wifAttestation, error) {
	logger.Debug("Creating Azure identity attestation...")
	endpoint := a.azureAttestationMetadataProvider.identityEndpoint()
	var request *http.Request
	var err error
	if endpoint != "" {
		// Azure Functions flow: the identity header must also be present.
		header := a.azureAttestationMetadataProvider.identityHeader()
		if header == "" {
			return nil, fmt.Errorf("managed identity is not enabled on this Azure function")
		}
		request, err = a.azureFunctionsIdentityRequest(endpoint, header, a.azureAttestationMetadataProvider.clientID())
		if err != nil {
			return nil, fmt.Errorf("failed to create Azure Functions identity request: %v", err)
		}
	} else {
		// Azure VM flow via the instance metadata service.
		request, err = a.azureVMIdentityRequest()
		if err != nil {
			return nil, fmt.Errorf("failed to create Azure VM identity request: %v", err)
		}
	}
	tokenJSON := fetchTokenFromMetadataService(request, a.cfg, a.telemetry)
	if tokenJSON == "" {
		return nil, fmt.Errorf("could not fetch Azure token")
	}
	token, err := extractTokenFromJSON(tokenJSON)
	if err != nil {
		return nil, fmt.Errorf("failed to extract token from JSON: %v", err)
	}
	if token == "" {
		return nil, fmt.Errorf("no access token found in Azure response")
	}
	subject, issuer, err := extractSubIssWithoutVerifyingSignature(token)
	if err != nil {
		return nil, fmt.Errorf("failed to extract sub and iss claims from token: %v", err)
	}
	if subject == "" || issuer == "" {
		return nil, fmt.Errorf("missing sub or iss claim in JWT token")
	}
	return &wifAttestation{
		ProviderType: string(azureWif),
		Credential:   token,
		Metadata:     map[string]string{"sub": subject, "iss": issuer},
	}, nil
}
// determineEntraResource returns the Entra ID resource (token audience) to
// request: the user-configured WorkloadIdentityEntraResource when present,
// otherwise Snowflake's default resource ID.
func determineEntraResource(config *Config) string {
	if config == nil || config.WorkloadIdentityEntraResource == "" {
		// Default resource if none specified.
		return "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad"
	}
	return config.WorkloadIdentityEntraResource
}
// extractTokenFromJSON decodes an Azure token response and returns its
// access_token field ("" when the field is absent).
func extractTokenFromJSON(tokenJSON string) (string, error) {
	var parsed struct {
		AccessToken string `json:"access_token"`
	}
	if err := json.Unmarshal([]byte(tokenJSON), &parsed); err != nil {
		return "", err
	}
	return parsed.AccessToken, nil
}
// azureFunctionsIdentityRequest builds the token request against the Azure
// Functions identity endpoint. The optional managedIdentityClientID selects a
// user-assigned identity; the X-IDENTITY-HEADER value authenticates the
// caller to the endpoint.
func (a *azureIdentityAttestationCreator) azureFunctionsIdentityRequest(identityEndpoint, identityHeader, managedIdentityClientID string) (*http.Request, error) {
	query := "api-version=2019-08-01&resource=" + a.workloadIdentityEntraResource
	if managedIdentityClientID != "" {
		query += "&client_id=" + managedIdentityClientID
	}
	req, err := http.NewRequest("GET", identityEndpoint+"?"+query, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Add("X-IDENTITY-HEADER", identityHeader)
	return req, nil
}
// azureVMIdentityRequest builds the token request against the Azure VM
// instance metadata service. The "Metadata: true" header is required by IMDS.
func (a *azureIdentityAttestationCreator) azureVMIdentityRequest() (*http.Request, error) {
	endpoint := fmt.Sprintf("%s/metadata/identity/oauth2/token?api-version=2018-02-01&resource=%s",
		a.azureMetadataServiceBaseURL, a.workloadIdentityEntraResource)
	req, err := http.NewRequest("GET", endpoint, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Add("Metadata", "true")
	return req, nil
}
================================================
FILE: auth_wif_test.go
================================================
package gosnowflake
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
"strings"
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
)
// mockWifAttestationCreator is a test double that either fails with a fixed
// error or returns an attestation carrying only the provider type.
type mockWifAttestationCreator struct {
	providerType wifProviderType
	returnError  error
}

// createAttestation returns the configured error, or a minimal attestation
// for the configured provider type.
func (m *mockWifAttestationCreator) createAttestation() (*wifAttestation, error) {
	if err := m.returnError; err != nil {
		return nil, err
	}
	attestation := &wifAttestation{ProviderType: string(m.providerType)}
	return attestation, nil
}
// TestGetAttestation verifies that wifAttestationProvider.getAttestation
// dispatches to the creator matching the WorkloadIdentityProvider string
// (AWS/GCP/AZURE/OIDC), propagates creator errors unchanged, and rejects
// unknown or empty provider names.
func TestGetAttestation(t *testing.T) {
	awsError := errors.New("aws attestation error")
	gcpError := errors.New("gcp attestation error")
	azureError := errors.New("azure attestation error")
	oidcError := errors.New("oidc attestation error")
	// provider whose creators all succeed
	provider := &wifAttestationProvider{
		context:      context.Background(),
		awsCreator:   &mockWifAttestationCreator{providerType: awsWif},
		gcpCreator:   &mockWifAttestationCreator{providerType: gcpWif},
		azureCreator: &mockWifAttestationCreator{providerType: azureWif},
		oidcCreator:  &mockWifAttestationCreator{providerType: oidcWif},
	}
	// provider whose creators all fail with distinct, recognizable errors
	providerWithErrors := &wifAttestationProvider{
		context:      context.Background(),
		awsCreator:   &mockWifAttestationCreator{providerType: awsWif, returnError: awsError},
		gcpCreator:   &mockWifAttestationCreator{providerType: gcpWif, returnError: gcpError},
		azureCreator: &mockWifAttestationCreator{providerType: azureWif, returnError: azureError},
		oidcCreator:  &mockWifAttestationCreator{providerType: oidcWif, returnError: oidcError},
	}
	tests := []struct {
		name             string
		provider         *wifAttestationProvider
		identityProvider string
		expectedResult   *wifAttestation
		expectedError    error
	}{
		{
			name:             "AWS success",
			provider:         provider,
			identityProvider: "AWS",
			expectedResult:   &wifAttestation{ProviderType: string(awsWif)},
			expectedError:    nil,
		},
		{
			name:             "AWS error",
			provider:         providerWithErrors,
			identityProvider: "AWS",
			expectedResult:   nil,
			expectedError:    awsError,
		},
		{
			name:             "GCP success",
			provider:         provider,
			identityProvider: "GCP",
			expectedResult:   &wifAttestation{ProviderType: string(gcpWif)},
			expectedError:    nil,
		},
		{
			name:             "GCP error",
			provider:         providerWithErrors,
			identityProvider: "GCP",
			expectedResult:   nil,
			expectedError:    gcpError,
		},
		{
			name:             "AZURE success",
			provider:         provider,
			identityProvider: "AZURE",
			expectedResult:   &wifAttestation{ProviderType: string(azureWif)},
			expectedError:    nil,
		},
		{
			name:             "AZURE error",
			provider:         providerWithErrors,
			identityProvider: "AZURE",
			expectedResult:   nil,
			expectedError:    azureError,
		},
		{
			name:             "OIDC success",
			provider:         provider,
			identityProvider: "OIDC",
			expectedResult:   &wifAttestation{ProviderType: string(oidcWif)},
			expectedError:    nil,
		},
		{
			name:             "OIDC error",
			provider:         providerWithErrors,
			identityProvider: "OIDC",
			expectedResult:   nil,
			expectedError:    oidcError,
		},
		{
			name:             "Unknown provider",
			provider:         provider,
			identityProvider: "UNKNOWN",
			expectedResult:   nil,
			expectedError:    errors.New("unknown WorkloadIdentityProvider specified: UNKNOWN. Valid values are: AWS, GCP, AZURE, OIDC"),
		},
		{
			name:             "Empty provider",
			provider:         provider,
			identityProvider: "",
			expectedResult:   nil,
			expectedError:    errors.New("unknown WorkloadIdentityProvider specified: . Valid values are: AWS, GCP, AZURE, OIDC"),
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			attestation, err := test.provider.getAttestation(test.identityProvider)
			if test.expectedError != nil {
				assertNilE(t, attestation)
				assertNotNilF(t, err)
				assertEqualE(t, test.expectedError.Error(), err.Error())
			} else if test.expectedResult != nil {
				assertNilE(t, err)
				assertNotNilF(t, attestation)
				assertEqualE(t, test.expectedResult.ProviderType, attestation.ProviderType)
			} else {
				t.Fatal("test case must specify either expectedError or expectedResult")
			}
		})
	}
}
// TestAwsIdentityAttestationCreator exercises the AWS attestation creator
// against a mocked metadata provider: missing service / credentials / region
// failures, plain attestation in commercial and CN partitions, and role
// chaining (single role, multi-role, cross-account, CN, and chaining
// failures). On success the base64 credential is decoded and its signed STS
// URL is checked against the expected regional hostname.
func TestAwsIdentityAttestationCreator(t *testing.T) {
	tests := []struct {
		name             string
		config           Config
		attestationSvc   awsAttestationMetadataProvider
		expectedError    error
		expectedProvider string
		expectedStsHost  string
	}{
		{
			name:           "No attestation service",
			attestationSvc: nil,
			expectedError:  fmt.Errorf("AWS attestation service could not be created"),
		},
		{
			name: "No AWS credentials",
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds:  aws.Credentials{},
				region: "us-west-2",
			},
			expectedError: fmt.Errorf("no AWS credentials were found"),
		},
		{
			name: "No AWS region",
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds:  mockCreds,
				region: "",
			},
			expectedError: fmt.Errorf("no AWS region was found"),
		},
		{
			name: "Successful attestation",
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds:  mockCreds,
				region: "us-west-2",
			},
			expectedProvider: "AWS",
			expectedStsHost:  "sts.us-west-2.amazonaws.com",
		},
		{
			name: "Successful attestation for CN region",
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds:  mockCreds,
				region: "cn-northwest-1",
			},
			expectedProvider: "AWS",
			expectedStsHost:  "sts.cn-northwest-1.amazonaws.com.cn",
		},
		{
			name: "Successful attestation with single role chaining",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{"arn:aws:iam::123456789012:role/test-role"},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds: mockCreds,
				chainingCreds: aws.Credentials{
					AccessKeyID:     "chainedAccessKey",
					SecretAccessKey: "chainedSecretKey",
					SessionToken:    "chainedSessionToken",
				},
				region:          "us-east-1",
				useRoleChaining: true,
			},
			expectedProvider: "AWS",
			expectedStsHost:  "sts.us-east-1.amazonaws.com",
		},
		{
			name: "Successful attestation with multiple role chaining",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{
					"arn:aws:iam::123456789012:role/role1",
					"arn:aws:iam::123456789012:role/role2",
					"arn:aws:iam::123456789012:role/role3",
				},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds: mockCreds,
				chainingCreds: aws.Credentials{
					AccessKeyID:     "finalRoleAccessKey",
					SecretAccessKey: "finalRoleSecretKey",
					SessionToken:    "finalRoleSessionToken",
				},
				region:          "us-west-2",
				useRoleChaining: true,
			},
			expectedProvider: "AWS",
			expectedStsHost:  "sts.us-west-2.amazonaws.com",
		},
		{
			name: "Role chaining with no credentials",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{"arn:aws:iam::123456789012:role/test-role"},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds:           aws.Credentials{},
				region:          "us-west-2",
				useRoleChaining: true,
			},
			expectedError: fmt.Errorf("no AWS credentials were found"),
		},
		{
			name: "Role chaining with no region",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{"arn:aws:iam::123456789012:role/test-role"},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds: aws.Credentials{
					AccessKeyID:     "chainedAccessKey",
					SecretAccessKey: "chainedSecretKey",
					SessionToken:    "chainedSessionToken",
				},
				region:          "",
				useRoleChaining: true,
			},
			expectedError: fmt.Errorf("no AWS region was found"),
		},
		{
			name: "Role chaining failure",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{"arn:aws:iam::123456789012:role/test-role"},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds:           mockCreds,
				region:          "us-west-2",
				chainingError:   fmt.Errorf("failed to assume role: AccessDenied"),
				useRoleChaining: true,
			},
			expectedError: fmt.Errorf("failed to assume role: AccessDenied"),
		},
		{
			name: "Cross-account role chaining",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{
					"arn:aws:iam::111111111111:role/cross-account-role",
					"arn:aws:iam::222222222222:role/target-role",
				},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds: mockCreds,
				chainingCreds: aws.Credentials{
					AccessKeyID:     "crossAccountAccessKey",
					SecretAccessKey: "crossAccountSecretKey",
					SessionToken:    "crossAccountSessionToken",
				},
				region:          "us-east-1",
				useRoleChaining: true,
			},
			expectedProvider: "AWS",
			expectedStsHost:  "sts.us-east-1.amazonaws.com",
		},
		{
			name: "Role chaining in CN region",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{"arn:aws-cn:iam::123456789012:role/cn-role"},
			},
			attestationSvc: &mockAwsAttestationMetadataProvider{
				creds: mockCreds,
				chainingCreds: aws.Credentials{
					AccessKeyID:     "cnRoleAccessKey",
					SecretAccessKey: "cnRoleSecretKey",
					SessionToken:    "cnRoleSessionToken",
				},
				region:          "cn-north-1",
				useRoleChaining: true,
			},
			expectedProvider: "AWS",
			expectedStsHost:  "sts.cn-north-1.amazonaws.com.cn",
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// Inject the mocked attestation service via the factory hook.
			creator := &awsIdentityAttestationCreator{
				attestationServiceFactory: func(ctx context.Context, cfg *Config) awsAttestationMetadataProvider {
					return test.attestationSvc
				},
				ctx: context.Background(),
				cfg: &test.config,
			}
			attestation, err := creator.createAttestation()
			if test.expectedError != nil {
				assertNilF(t, attestation)
				assertNotNilE(t, err)
				assertEqualE(t, test.expectedError.Error(), err.Error())
			} else {
				assertNilE(t, err)
				assertNotNilE(t, attestation)
				assertNotNilE(t, attestation.Credential)
				assertEqualE(t, test.expectedProvider, attestation.ProviderType)
				// The credential is a base64-encoded JSON document describing
				// the signed STS request; verify the URL it targets.
				decoded, err := base64.StdEncoding.DecodeString(attestation.Credential)
				if err != nil {
					t.Fatalf("Failed to decode credential: %v", err)
				}
				var credentialMap map[string]any
				if err := json.Unmarshal(decoded, &credentialMap); err != nil {
					t.Fatalf("Failed to unmarshal credential JSON: %v", err)
				}
				assertEqualE(t, fmt.Sprintf("https://%s?Action=GetCallerIdentity&Version=2011-06-15", test.expectedStsHost), credentialMap["url"])
			}
		})
	}
}
// mockAwsAttestationMetadataProvider fakes AWS credential and region
// discovery for tests, optionally simulating role chaining behavior.
type mockAwsAttestationMetadataProvider struct {
	creds           aws.Credentials
	region          string
	chainingCreds   aws.Credentials
	chainingError   error
	useRoleChaining bool
}

// mockCreds is a reusable set of static AWS credentials for tests.
var mockCreds = aws.Credentials{
	AccessKeyID:     "mockAccessKey",
	SecretAccessKey: "mockSecretKey",
	SessionToken:    "mockSessionToken",
}

// awsCredentials returns the base (non-chained) credentials.
func (m *mockAwsAttestationMetadataProvider) awsCredentials() (aws.Credentials, error) {
	return m.creds, nil
}

// awsCredentialsViaRoleChaining simulates assuming roles: it fails with the
// configured error when set, returns the chained credentials when present,
// and falls back to the base credentials otherwise.
func (m *mockAwsAttestationMetadataProvider) awsCredentialsViaRoleChaining() (aws.Credentials, error) {
	switch {
	case m.chainingError != nil:
		return aws.Credentials{}, m.chainingError
	case m.chainingCreds.AccessKeyID != "":
		return m.chainingCreds, nil
	default:
		return m.creds, nil
	}
}

// awsRegion returns the configured region ("" simulates no region found).
func (m *mockAwsAttestationMetadataProvider) awsRegion() string {
	return m.region
}
// TestGcpIdentityAttestationCreator drives the GCP attestation creator
// against wiremock-backed metadata/IAM endpoints: the plain metadata-service
// flow, the impersonation flow, HTTP failure, and tokens with missing or
// unparsable claims.
func TestGcpIdentityAttestationCreator(t *testing.T) {
	tests := []struct {
		name                string
		wiremockMappingPath string
		config              Config
		expectedError       error
		expectedSub         string
	}{
		{
			name:                "Successful flow",
			wiremockMappingPath: "auth/wif/gcp/successful_flow.json",
			expectedError:       nil,
			expectedSub:         "some-subject",
		},
		{
			name:                "Successful impersonation flow",
			wiremockMappingPath: "auth/wif/gcp/successful_impersionation_flow.json",
			config: Config{
				WorkloadIdentityImpersonationPath: []string{
					"delegate1",
					"delegate2",
					"targetServiceAccount",
				},
			},
			expectedError: nil,
			expectedSub:   "some-impersonated-subject",
		},
		{
			name:                "No GCP credential - http error",
			wiremockMappingPath: "auth/wif/gcp/http_error.json",
			expectedError:       fmt.Errorf("no GCP token was found"),
			expectedSub:         "",
		},
		{
			name:                "missing issuer claim",
			wiremockMappingPath: "auth/wif/gcp/missing_issuer_claim.json",
			expectedError:       fmt.Errorf("could not extract claims from token: missing issuer claim in JWT token"),
			expectedSub:         "",
		},
		{
			name:                "missing sub claim",
			wiremockMappingPath: "auth/wif/gcp/missing_sub_claim.json",
			expectedError:       fmt.Errorf("could not extract claims from token: missing sub claim in JWT token"),
			expectedSub:         "",
		},
		{
			name:                "unparsable token",
			wiremockMappingPath: "auth/wif/gcp/unparsable_token.json",
			expectedError:       fmt.Errorf("could not extract claims from token: unable to extract JWT claims from token: token is malformed: token contains an invalid number of segments"),
			expectedSub:         "",
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// Point both the metadata service and the IAM credentials API at
			// the wiremock instance.
			creator := &gcpIdentityAttestationCreator{
				cfg:                    &test.config,
				metadataServiceBaseURL: wiremock.baseURL(),
				iamCredentialsURL:      wiremock.baseURL(),
			}
			wiremock.registerMappings(t, wiremockMapping{filePath: test.wiremockMappingPath})
			attestation, err := creator.createAttestation()
			if test.expectedError != nil {
				assertNilF(t, attestation)
				assertNotNilF(t, err)
				assertEqualE(t, test.expectedError.Error(), err.Error())
			} else {
				assertNilF(t, err)
				assertNotNilF(t, attestation)
				assertEqualE(t, string(gcpWif), attestation.ProviderType)
				assertEqualE(t, test.expectedSub, attestation.Metadata["sub"])
			}
		})
	}
}
// TestOidcIdentityAttestationCreator feeds canned JWTs through the OIDC
// attestation creator: a valid token, tokens missing the iss or sub claim,
// an unparsable token, and an empty token. The token constants below are
// expired/unsigned test fixtures; the decoded claims are shown in the
// comments preceding each one.
func TestOidcIdentityAttestationCreator(t *testing.T) {
	const (
		/*
		 * {
		 *   "sub": "some-subject",
		 *   "iat": 1743761213,
		 *   "exp": 1743764813,
		 *   "aud": "www.example.com"
		 * }
		 */
		missingIssuerClaimToken = "eyJ0eXAiOiJhdCtqd3QiLCJhbGciOiJFUzI1NiIsImtpZCI6ImU2M2I5NzA1OTRiY2NmZTAxMDlkOTg4OWM2MDk3OWEwIn0.eyJzdWIiOiJzb21lLXN1YmplY3QiLCJpYXQiOjE3NDM3NjEyMTMsImV4cCI6MTc0Mzc2NDgxMywiYXVkIjoid3d3LmV4YW1wbGUuY29tIn0.H6sN6kjA82EuijFcv-yCJTqau5qvVTCsk0ZQ4gvFQMkB7c71XPs4lkwTa7ZlNNlx9e6TpN1CVGnpCIRDDAZaDw" // pragma: allowlist secret
		/*
		 * {
		 *   "iss": "https://accounts.google.com",
		 *   "iat": 1743761213,
		 *   "exp": 1743764813,
		 *   "aud": "www.example.com"
		 * }
		 */
		missingSubClaimToken = "eyJ0eXAiOiJhdCtqd3QiLCJhbGciOiJFUzI1NiIsImtpZCI6ImU2M2I5NzA1OTRiY2NmZTAxMDlkOTg4OWM2MDk3OWEwIn0.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJpYXQiOjE3NDM3NjEyMTMsImV4cCI6MTc0Mzc2NDgxMywiYXVkIjoid3d3LmV4YW1wbGUuY29tIn0.w0njdpfWFETVK8Ktq9GdvuKRQJjvhOplcSyvQ_zHHwBUSMapqO1bjEWBx5VhGkdECZIGS1VY7db_IOqT45yOMA" // pragma: allowlist secret
		/*
		 * {
		 *   "iss": "https://oidc.eks.us-east-2.amazonaws.com/id/3B869BC5D12CEB5515358621D8085D58",
		 *   "iat": 1743692017,
		 *   "exp": 1775228014,
		 *   "aud": "www.example.com",
		 *   "sub": "system:serviceaccount:poc-namespace:oidc-sa"
		 * }
		 */
		validToken      = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJodHRwczovL29pZGMuZWtzLnVzLWVhc3QtMi5hbWF6b25hd3MuY29tL2lkLzNCODY5QkM1RDEyQ0VCNTUxNTM1ODYyMUQ4MDg1RDU4IiwiaWF0IjoxNzQ0Mjg3ODc4LCJleHAiOjE3NzU4MjM4NzgsImF1ZCI6Ind3dy5leGFtcGxlLmNvbSIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpwb2MtbmFtZXNwYWNlOm9pZGMtc2EifQ.a8H6KRIF1XmM8lkqL6kR8ccInr7wAzQrbKd3ZHFgiEg" // pragma: allowlist secret
		unparsableToken = "unparsable_token"
		emptyToken      = ""
	)
	type testCase struct {
		name          string
		token         string
		expectedError error
		expectedSub   string
	}
	tests := []testCase{
		{
			name:          "no token input",
			token:         emptyToken,
			expectedError: fmt.Errorf("no OIDC token was specified"),
		},
		{
			name:          "valid token returns proper attestation",
			token:         validToken,
			expectedError: nil,
			expectedSub:   "system:serviceaccount:poc-namespace:oidc-sa",
		},
		{
			name:          "missing issuer returns error",
			token:         missingIssuerClaimToken,
			expectedError: errors.New("missing issuer claim in JWT token"),
		},
		{
			name:          "missing sub returns error",
			token:         missingSubClaimToken,
			expectedError: errors.New("missing sub claim in JWT token"),
		},
		{
			name:          "unparsable token returns error",
			token:         unparsableToken,
			expectedError: errors.New("unable to extract JWT claims from token: token is malformed: token contains an invalid number of segments"),
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// The token callback simply returns the fixture token.
			creator := &oidcIdentityAttestationCreator{token: func() (string, error) {
				return test.token, nil
			}}
			attestation, err := creator.createAttestation()
			if test.expectedError != nil {
				assertNotNilE(t, err)
				assertNilF(t, attestation)
				assertEqualE(t, test.expectedError.Error(), err.Error())
			} else {
				assertNilE(t, err)
				assertNotNilE(t, attestation)
				assertEqualE(t, string(oidcWif), attestation.ProviderType)
				assertEqualE(t, test.expectedSub, attestation.Metadata["sub"])
			}
		})
	}
}
// TestAzureIdentityAttestationCreator drives the Azure attestation creator
// against wiremock-backed identity endpoints, covering the VM and Azure
// Functions flows (with/without client ID, custom Entra resource, v1/v2
// issuers) plus failure modes: non-JSON responses, missing identity header,
// unparsable tokens, HTTP errors, and missing claims. The decoded claims of
// each fixture token are shown in the comments preceding the cases.
func TestAzureIdentityAttestationCreator(t *testing.T) {
	tests := []struct {
		name                string
		wiremockMappingPath string
		metadataProvider    *mockAzureAttestationMetadataProvider
		cfg                 *Config
		expectedIss         string
		expectedError       error
	}{
		/*
		 * {
		 *   "iss": "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/",
		 *   "sub": "77213E30-E8CB-4595-B1B6-5F050E8308FD"
		 * }
		 */
		{
			name:                "Successful flow",
			wiremockMappingPath: "auth/wif/azure/successful_flow_basic.json",
			metadataProvider:    azureVMMetadataProvider(),
			expectedIss:         "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/",
			expectedError:       nil,
		},
		/*
		 * {
		 *   "iss": "https://login.microsoftonline.com/fa15d692-e9c7-4460-a743-29f29522229/",
		 *   "sub": "77213E30-E8CB-4595-B1B6-5F050E8308FD"
		 * }
		 */
		{
			name:                "Successful flow v2 issuer",
			wiremockMappingPath: "auth/wif/azure/successful_flow_v2_issuer.json",
			metadataProvider:    azureVMMetadataProvider(),
			expectedIss:         "https://login.microsoftonline.com/fa15d692-e9c7-4460-a743-29f29522229/",
			expectedError:       nil,
		},
		/*
		 * {
		 *   "iss": "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/",
		 *   "sub": "77213E30-E8CB-4595-B1B6-5F050E8308FD"
		 * }
		 */
		{
			name:                "Successful flow azure functions",
			wiremockMappingPath: "auth/wif/azure/successful_flow_azure_functions.json",
			metadataProvider:    azureFunctionsMetadataProvider(),
			expectedIss:         "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/",
			expectedError:       nil,
		},
		/*
		 * {
		 *   "iss": "https://login.microsoftonline.com/fa15d692-e9c7-4460-a743-29f29522229/",
		 *   "sub": "77213E30-E8CB-4595-B1B6-5F050E8308FD"
		 * }
		 */
		{
			name:                "Successful flow azure functions v2 issuer",
			wiremockMappingPath: "auth/wif/azure/successful_flow_azure_functions_v2_issuer.json",
			metadataProvider:    azureFunctionsMetadataProvider(),
			expectedIss:         "https://login.microsoftonline.com/fa15d692-e9c7-4460-a743-29f29522229/",
			expectedError:       nil,
		},
		/*
		 * {
		 *   "iss": "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/",
		 *   "sub": "77213E30-E8CB-4595-B1B6-5F050E8308FD"
		 * }
		 */
		{
			name:                "Successful flow azure functions no client ID",
			wiremockMappingPath: "auth/wif/azure/successful_flow_azure_functions_no_client_id.json",
			metadataProvider: &mockAzureAttestationMetadataProvider{
				identityEndpointValue: wiremock.baseURL() + "/metadata/identity/endpoint/from/env",
				identityHeaderValue:   "some-identity-header-from-env",
				clientIDValue:         "",
			},
			expectedIss:   "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/",
			expectedError: nil,
		},
		/*
		 * {
		 *   "iss": "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/",
		 *   "sub": "77213E30-E8CB-4595-B1B6-5F050E8308FD"
		 * }
		 */
		{
			name:                "Successful flow azure functions custom Entra resource",
			wiremockMappingPath: "auth/wif/azure/successful_flow_azure_functions_custom_entra_resource.json",
			metadataProvider:    azureFunctionsMetadataProvider(),
			cfg:                 &Config{WorkloadIdentityEntraResource: "api://1111111-2222-3333-44444-55555555"},
			expectedIss:         "https://sts.windows.net/fa15d692-e9c7-4460-a743-29f29522229/",
			expectedError:       nil,
		},
		{
			name:                "Non-json response",
			wiremockMappingPath: "auth/wif/azure/non_json_response.json",
			metadataProvider:    azureVMMetadataProvider(),
			expectedError:       fmt.Errorf("failed to extract token from JSON: invalid character 'o' in literal null (expecting 'u')"),
		},
		{
			name: "Identity endpoint but no identity header",
			metadataProvider: &mockAzureAttestationMetadataProvider{
				identityEndpointValue: wiremock.baseURL() + "/metadata/identity/endpoint/from/env",
				identityHeaderValue:   "",
				clientIDValue:         "managed-client-id-from-env",
			},
			expectedError: fmt.Errorf("managed identity is not enabled on this Azure function"),
		},
		{
			name:                "Unparsable token",
			wiremockMappingPath: "auth/wif/azure/unparsable_token.json",
			metadataProvider:    azureVMMetadataProvider(),
			expectedError:       fmt.Errorf("failed to extract sub and iss claims from token: unable to extract JWT claims from token: token is malformed: token contains an invalid number of segments"),
		},
		{
			name:                "HTTP error",
			metadataProvider:    azureVMMetadataProvider(),
			wiremockMappingPath: "auth/wif/azure/http_error.json",
			expectedError:       fmt.Errorf("could not fetch Azure token"),
		},
		{
			name:                "Missing sub or iss claim",
			wiremockMappingPath: "auth/wif/azure/missing_issuer_claim.json",
			metadataProvider:    azureVMMetadataProvider(),
			expectedError:       fmt.Errorf("failed to extract sub and iss claims from token: missing issuer claim in JWT token"),
		},
		{
			name:                "Missing sub claim",
			wiremockMappingPath: "auth/wif/azure/missing_sub_claim.json",
			metadataProvider:    azureVMMetadataProvider(),
			expectedError:       fmt.Errorf("failed to extract sub and iss claims from token: missing sub claim in JWT token"),
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			if test.wiremockMappingPath != "" {
				wiremock.registerMappings(t, wiremockMapping{filePath: test.wiremockMappingPath})
			}
			// The metadata service base URL points at wiremock; the Entra
			// resource is resolved exactly as production code does.
			creator := &azureIdentityAttestationCreator{
				cfg:                              test.cfg,
				azureMetadataServiceBaseURL:      wiremock.baseURL(),
				azureAttestationMetadataProvider: test.metadataProvider,
				workloadIdentityEntraResource:    determineEntraResource(test.cfg),
			}
			attestation, err := creator.createAttestation()
			if test.expectedError != nil {
				assertNilF(t, attestation)
				assertNotNilE(t, err)
				assertEqualE(t, test.expectedError.Error(), err.Error())
			} else {
				assertNilF(t, err)
				assertNotNilF(t, attestation)
				assertEqualE(t, string(azureWif), attestation.ProviderType)
				assertEqualE(t, test.expectedIss, attestation.Metadata["iss"])
				assertEqualE(t, "77213E30-E8CB-4595-B1B6-5F050E8308FD", attestation.Metadata["sub"])
			}
		})
	}
}
// mockAzureAttestationMetadataProvider is a test stub for the Azure attestation
// metadata source: each method simply returns the corresponding fixed field,
// letting tests simulate Azure Functions (values set) or Azure VM (all empty)
// environments.
type mockAzureAttestationMetadataProvider struct {
	identityEndpointValue string // returned by identityEndpoint()
	identityHeaderValue   string // returned by identityHeader()
	clientIDValue         string // returned by clientID()
}
// identityEndpoint returns the configured identity endpoint value (empty to
// simulate a plain Azure VM without an App Service identity endpoint).
func (m *mockAzureAttestationMetadataProvider) identityEndpoint() string {
	return m.identityEndpointValue
}
// identityHeader returns the configured identity header value.
func (m *mockAzureAttestationMetadataProvider) identityHeader() string {
	return m.identityHeaderValue
}
// clientID returns the configured managed identity client ID value.
func (m *mockAzureAttestationMetadataProvider) clientID() string {
	return m.clientIDValue
}
// azureFunctionsMetadataProvider builds a mock provider mimicking an Azure
// Functions environment: identity endpoint (pointed at the wiremock server),
// identity header, and managed client ID are all present, as if read from the
// environment.
func azureFunctionsMetadataProvider() *mockAzureAttestationMetadataProvider {
	return &mockAzureAttestationMetadataProvider{
		identityEndpointValue: wiremock.baseURL() + "/metadata/identity/endpoint/from/env",
		identityHeaderValue:   "some-identity-header-from-env",
		clientIDValue:         "managed-client-id-from-env",
	}
}
// azureVMMetadataProvider builds a mock provider with every value empty, as on
// a plain Azure VM where the App Service identity environment variables are
// not set.
func azureVMMetadataProvider() *mockAzureAttestationMetadataProvider {
	// The struct's zero values are exactly the empty strings we want.
	return &mockAzureAttestationMetadataProvider{}
}
// Running this test locally:
// * Push branch to repository
// * Set PARAMETERS_SECRET
// * Run ci/test_wif.sh
func TestWorkloadIdentityAuthOnCloudVM(t *testing.T) {
	account := os.Getenv("SNOWFLAKE_TEST_WIF_ACCOUNT")
	host := os.Getenv("SNOWFLAKE_TEST_WIF_HOST")
	provider := os.Getenv("SNOWFLAKE_TEST_WIF_PROVIDER")
	t.Log("provider = " + provider) // use the test logger, not the builtin println
	if account == "" || host == "" || provider == "" {
		t.Skip("Test can run only on cloud VM with env variables set")
	}
	testCases := []struct {
		name             string
		skip             func() (bool, string)
		setupCfg         func(*testing.T, *Config)
		expectedUsername string
	}{
		{
			name: "provider=" + provider,
			// Use the subtest's *testing.T (previously the parameter was
			// ignored and the closure captured the parent test's T, so
			// Fatalf would be attributed to the wrong test).
			setupCfg: func(t *testing.T, config *Config) {
				if provider != "GCP+OIDC" {
					config.WorkloadIdentityProvider = provider
				} else {
					// Plain OIDC flow: fetch an identity token from the GCP
					// metadata server and hand it over directly.
					config.WorkloadIdentityProvider = "OIDC"
					config.Token = func() string {
						cmd := exec.Command("wget", "-O", "-", "--header=Metadata-Flavor: Google", "http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience=snowflakecomputing.com")
						output, err := cmd.Output()
						if err != nil {
							t.Fatalf("error executing GCP metadata request: %v", err)
						}
						token := strings.TrimSpace(string(output))
						if token == "" {
							t.Fatal("failed to retrieve GCP access token: empty response")
						}
						return token
					}()
				}
			},
			expectedUsername: os.Getenv("SNOWFLAKE_TEST_WIF_USERNAME"),
		},
		{
			name: "provider=" + provider + ",impersonation",
			skip: func() (bool, string) {
				if provider != "AWS" && provider != "GCP" {
					return true, "Impersonation is supported only on AWS and GCP"
				}
				return false, ""
			},
			setupCfg: func(t *testing.T, config *Config) {
				config.WorkloadIdentityProvider = provider
				impersonationPath := os.Getenv("SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH")
				assertNotEqualF(t, impersonationPath, "", "SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH is not set")
				config.WorkloadIdentityImpersonationPath = strings.Split(impersonationPath, ",")
				assertNotEqualF(t, os.Getenv("SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION"), "", "SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION is not set")
			},
			expectedUsername: os.Getenv("SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION"),
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			if tc.skip != nil {
				if skip, msg := tc.skip(); skip {
					t.Skip(msg)
				}
			}
			config := &Config{
				Account:       account,
				Host:          host,
				Authenticator: AuthTypeWorkloadIdentityFederation,
			}
			tc.setupCfg(t, config)
			connector := NewConnector(SnowflakeDriver{}, *config)
			db := sql.OpenDB(connector)
			defer db.Close()
			currentUser := runSelectCurrentUser(t, db)
			assertEqualE(t, currentUser, tc.expectedUsername)
		})
	}
}
================================================
FILE: auth_with_external_browser_test.go
================================================
package gosnowflake
import (
"context"
"database/sql"
"fmt"
"log"
"os/exec"
"sync"
"testing"
"time"
)
// TestExternalBrowserSuccessful completes the external-browser login: a
// background goroutine drives the browser with valid credentials while the
// main goroutine opens the connection and expects success.
func TestExternalBrowserSuccessful(t *testing.T) {
	cfg := setupExternalBrowserTest(t)
	browserDone := make(chan struct{})
	go func() {
		defer close(browserDone)
		provideExternalBrowserCredentials(t, externalBrowserType.Success, cfg.User, cfg.Password)
	}()
	connErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNilE(t, connErr, fmt.Sprintf("Connection failed due to %v", connErr))
	<-browserDone
}
// TestExternalBrowserFailed submits wrong credentials in the browser and
// expects the driver to time out waiting for a successful login.
func TestExternalBrowserFailed(t *testing.T) {
	cfg := setupExternalBrowserTest(t)
	cfg.ExternalBrowserTimeout = 10 * time.Second
	browserDone := make(chan struct{})
	go func() {
		defer close(browserDone)
		provideExternalBrowserCredentials(t, externalBrowserType.Fail, "FakeAccount", "NotARealPassword")
	}()
	connErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNotNilF(t, connErr)
	assertEqualE(t, connErr.Error(), "authentication timed out")
	<-browserDone
}
// TestExternalBrowserTimeout uses a 1-second browser timeout together with a
// deliberately slow browser scenario, so authentication must time out.
func TestExternalBrowserTimeout(t *testing.T) {
	cfg := setupExternalBrowserTest(t)
	cfg.ExternalBrowserTimeout = 1 * time.Second
	browserDone := make(chan struct{})
	go func() {
		defer close(browserDone)
		provideExternalBrowserCredentials(t, externalBrowserType.Timeout, cfg.User, cfg.Password)
	}()
	connErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNotNilF(t, connErr)
	assertEqualE(t, connErr.Error(), "authentication timed out")
	<-browserDone
}
// TestExternalBrowserMismatchUser logs in through the browser as the real user
// while the driver config claims a different user; the server must reject the
// session with error 390191.
func TestExternalBrowserMismatchUser(t *testing.T) {
	cfg := setupExternalBrowserTest(t)
	correctUsername := cfg.User
	// The connection is attempted as a non-existent user, but the browser
	// authenticates as the real one.
	cfg.User = "fakeAccount"
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.Success, correctUsername, cfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		var snowflakeErr *SnowflakeError
		assertErrorsAsF(t, err, &snowflakeErr)
		// 390191: the authenticated identity does not match the requested user.
		assertEqualE(t, snowflakeErr.Number, 390191, fmt.Sprintf("Expected 390191, but got %v", snowflakeErr.Number))
	}()
	wg.Wait()
}
// TestClientStoreCredentials exercises ID-token caching for external-browser
// auth: the first subtest performs a browser login that populates the local
// credential cache; later subtests (which run with no browser helper) succeed
// only while caching is enabled. The subtests share cfg and depend on running
// in order.
func TestClientStoreCredentials(t *testing.T) {
	cfg := setupExternalBrowserTest(t)
	cfg.ClientStoreTemporaryCredential = 1
	cfg.ExternalBrowserTimeout = time.Duration(10) * time.Second
	t.Run("Obtains the ID token from the server and saves it on the local storage", func(t *testing.T) {
		cleanupBrowserProcesses(t)
		var wg sync.WaitGroup
		wg.Add(2)
		go func() {
			defer wg.Done()
			// Complete the browser login so the server issues an ID token.
			provideExternalBrowserCredentials(t, externalBrowserType.Success, cfg.User, cfg.Password)
		}()
		go func() {
			defer wg.Done()
			err := verifyConnectionToSnowflakeAuthTests(t, cfg)
			assertNilE(t, err, fmt.Sprintf("Connection failed: err %v", err))
		}()
		wg.Wait()
	})
	t.Run("Verify validation of ID token if option enabled", func(t *testing.T) {
		cleanupBrowserProcesses(t)
		cfg.ClientStoreTemporaryCredential = 1
		db := getDbHandlerFromConfig(t, cfg)
		// No browser is driven here; the connection must succeed using the
		// ID token cached by the previous subtest.
		conn, err := db.Conn(context.Background())
		assertNilE(t, err, fmt.Sprintf("Failed to connect to Snowflake. err: %v", err))
		defer conn.Close()
		rows, err := conn.QueryContext(context.Background(), "SELECT 1")
		assertNilE(t, err, fmt.Sprintf("Failed to run a query. err: %v", err))
		rows.Close()
	})
	t.Run("Verify validation of idToken if option disabled", func(t *testing.T) {
		cleanupBrowserProcesses(t)
		cfg.ClientStoreTemporaryCredential = 0
		db := getDbHandlerFromConfig(t, cfg)
		// Caching disabled and no browser available: authentication must time out.
		_, err := db.Conn(context.Background())
		assertNotNilF(t, err)
		assertEqualE(t, err.Error(), "authentication timed out", fmt.Sprintf("Expected timeout, but got %v", err))
	})
}
// ExternalBrowserProcessResult groups the scenario names accepted by the
// provideBrowserCredentials.js helper that drives the test browser.
type ExternalBrowserProcessResult struct {
	Success               string
	Fail                  string
	Timeout               string
	OauthOktaSuccess      string
	OauthSnowflakeSuccess string
}

// externalBrowserType holds the concrete scenario identifiers passed as the
// first argument to the Node browser helper.
var externalBrowserType = ExternalBrowserProcessResult{
	Success:               "success",
	Fail:                  "fail",
	Timeout:               "timeout",
	OauthOktaSuccess:      "externalOauthOktaSuccess",
	OauthSnowflakeSuccess: "internalOauthSnowflakeSuccess",
}
// cleanupBrowserProcesses kills any browser processes left over from an earlier
// run. Outside the Docker-based auth test environment it does nothing.
func cleanupBrowserProcesses(t *testing.T) {
	if !isTestRunningInDockerContainer() {
		return
	}
	const cleanBrowserProcessesPath = "/externalbrowser/cleanBrowserProcesses.js"
	_, err := exec.Command("node", cleanBrowserProcessesPath).CombinedOutput()
	assertNilE(t, err, fmt.Sprintf("failed to execute command: %v", err))
}
// provideExternalBrowserCredentials runs the Node helper that drives the test
// browser window, entering the given user and password. browserProcess selects
// the scenario (see externalBrowserType). Outside the Docker-based auth test
// environment it does nothing.
// Note: the parameter was renamed from ExternalBrowserProcess to browserProcess
// to follow Go naming conventions for locals/parameters (no caller impact).
func provideExternalBrowserCredentials(t *testing.T, browserProcess string, user string, password string) {
	if !isTestRunningInDockerContainer() {
		return
	}
	const provideBrowserCredentialsPath = "/externalbrowser/provideBrowserCredentials.js"
	output, err := exec.Command("node", provideBrowserCredentialsPath, browserProcess, user, password).CombinedOutput()
	log.Printf("Output: %s\n", output)
	assertNilE(t, err, fmt.Sprintf("failed to execute command: %v", err))
}
// verifyConnectionToSnowflakeAuthTests opens a connection from the given
// config and runs "SELECT 1". The query error (nil on success) is returned;
// DSN-building and driver-open failures are reported via assertions instead.
func verifyConnectionToSnowflakeAuthTests(t *testing.T, cfg *Config) error {
	dsn, dsnErr := DSN(cfg)
	assertNilE(t, dsnErr, "failed to create DSN from Config")
	db, openErr := sql.Open("snowflake", dsn)
	assertNilE(t, openErr, "failed to open Snowflake DB connection")
	defer db.Close()
	rows, queryErr := db.Query("SELECT 1")
	if queryErr != nil {
		log.Printf("failed to run a query. 'SELECT 1', err: %v", queryErr)
		return queryErr
	}
	defer rows.Close()
	assertTrueE(t, rows.Next(), "failed to get result", "There were no results for query: ")
	return nil
}
// setupExternalBrowserTest prepares a Config for the external-browser auth
// tests: skips when auth tests are disabled, kills stray browser processes,
// and loads the shared test config for AuthTypeExternalBrowser.
func setupExternalBrowserTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping External Browser tests")
	cleanupBrowserProcesses(t)
	cfg, err := getAuthTestsConfig(t, AuthTypeExternalBrowser)
	assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err))
	return cfg
}
================================================
FILE: auth_with_keypair_test.go
================================================
package gosnowflake
import (
"crypto/rsa"
"fmt"
"golang.org/x/crypto/ssh"
"os"
"testing"
)
// TestKeypairSuccessful connects using JWT key-pair auth with the valid
// registered private key.
func TestKeypairSuccessful(t *testing.T) {
	cfg := setupKeyPairTest(t)
	cfg.PrivateKey = loadRsaPrivateKeyForKeyPair(t, "SNOWFLAKE_AUTH_TEST_PRIVATE_KEY_PATH")
	connErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNilE(t, connErr, fmt.Sprintf("failed to connect. err: %v", connErr))
}
// TestKeypairInvalidKey connects with a private key that is not registered for
// the user and expects the JWT-invalid error 390144.
func TestKeypairInvalidKey(t *testing.T) {
	cfg := setupKeyPairTest(t)
	cfg.PrivateKey = loadRsaPrivateKeyForKeyPair(t, "SNOWFLAKE_AUTH_TEST_INVALID_PRIVATE_KEY_PATH")
	connErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	var snowflakeErr *SnowflakeError
	assertErrorsAsF(t, connErr, &snowflakeErr)
	assertEqualE(t, snowflakeErr.Number, 390144, fmt.Sprintf("Expected 390144, but got %v", snowflakeErr.Number))
}
// setupKeyPairTest prepares a Config for the key-pair (JWT) auth tests.
// Uses assertNilF, consistent with the other setup helpers: a fatal assert
// prevents returning a nil cfg that would panic in the caller (the previous
// non-fatal assertEqualE allowed execution to continue after a failure).
func setupKeyPairTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping KeyPair tests")
	cfg, err := getAuthTestsConfig(t, AuthTypeJwt)
	assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err))
	return cfg
}
// loadRsaPrivateKeyForKeyPair reads the private key file whose path is given
// by the environment variable envName and parses it as an RSA private key.
// Any failure (missing env, unreadable file, non-RSA key) fails the test.
func loadRsaPrivateKeyForKeyPair(t *testing.T, envName string) *rsa.PrivateKey {
	filePath, err := GetFromEnv(envName, true)
	assertNilF(t, err, fmt.Sprintf("failed to get env: %v", err))
	bytes, err := os.ReadFile(filePath)
	assertNilF(t, err, fmt.Sprintf("failed to read file: %v", err))
	key, err := ssh.ParseRawPrivateKey(bytes)
	assertNilF(t, err, fmt.Sprintf("failed to parse private key: %v", err))
	// Checked assertion: a non-RSA key (e.g. ed25519) previously caused a
	// panic here instead of a clear test failure.
	rsaKey, ok := key.(*rsa.PrivateKey)
	if !ok {
		t.Fatalf("private key from %v is not an RSA key (got %T)", filePath, key)
	}
	return rsaKey
}
================================================
FILE: auth_with_mfa_test.go
================================================
package gosnowflake
import (
"errors"
"fmt"
"log"
"os/exec"
"strings"
"testing"
)
// TestMfaSuccessful connects with username/password MFA using generated TOTP
// codes, then verifies the cached MFA token allows a second connection with
// no passcode at all.
func TestMfaSuccessful(t *testing.T) {
	cfg := setupMfaTest(t)
	// Enable MFA token caching
	cfg.ClientRequestMfaToken = ConfigBoolTrue
	//Provide your own TOTP code/codes here, to test manually
	//totpKeys := []string{"222222", "333333", "444444"}
	totpKeys := getTOPTcodes(t)
	verifyConnectionToSnowflakeUsingTotpCodes(t, cfg, totpKeys)
	log.Printf("Testing MFA token caching with second connection...")
	// Clear the passcode to force use of cached MFA token
	cfg.Passcode = ""
	// Attempt to connect using cached MFA token
	cacheErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNilF(t, cacheErr, "Failed to connect with cached MFA token")
}
// setupMfaTest prepares a Config for the MFA tests, overriding user and
// password with the dedicated MFA test account from the environment.
func setupMfaTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping MFA tests")
	cfg, err := getAuthTestsConfig(t, AuthTypeUsernamePasswordMFA)
	assertNilF(t, err, "failed to get config")
	mustEnv := func(name, failMsg string) string {
		value, envErr := GetFromEnv(name, true)
		assertNilF(t, envErr, failMsg)
		return value
	}
	cfg.User = mustEnv("SNOWFLAKE_AUTH_TEST_MFA_USER", "failed to get MFA user from environment")
	cfg.Password = mustEnv("SNOWFLAKE_AUTH_TEST_MFA_PASSWORD", "failed to get MFA password from environment")
	return cfg
}
// getTOPTcodes generates a batch of TOTP codes via the Node helper when
// running inside the Docker auth-test container; elsewhere it returns no
// codes. ("TOPT" in the name is a pre-existing typo for TOTP; renaming would
// touch callers, so it is kept.)
func getTOPTcodes(t *testing.T) []string {
	if !isTestRunningInDockerContainer() {
		// nil is the idiomatic "no codes" slice; the caller only checks len().
		return nil
	}
	const provideTotpPath = "/externalbrowser/totpGenerator.js"
	output, err := exec.Command("node", provideTotpPath).CombinedOutput()
	assertNilF(t, err, fmt.Sprintf("failed to execute command: %v", err))
	return strings.Fields(string(output))
}
// verifyConnectionToSnowflakeUsingTotpCodes tries each TOTP code in turn until
// a connection succeeds. MFA-specific errors (394633, 394507 — e.g. an expired
// or already-used code) move on to the next code; any other error aborts the
// loop. If no code succeeds, the test fails with the last error seen.
func verifyConnectionToSnowflakeUsingTotpCodes(t *testing.T, cfg *Config, totpKeys []string) {
	if len(totpKeys) == 0 {
		t.Fatalf("no TOTP codes provided")
	}
	var lastError error
	for i, totpKey := range totpKeys {
		cfg.Passcode = totpKey
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		if err == nil {
			return
		}
		lastError = err
		log.Printf("TOTP code %d failed: %v", i+1, err.Error())
		// Flattened control flow (was an if/else with continue): bail out on
		// the first non-MFA error, otherwise try the next code.
		var snowflakeErr *SnowflakeError
		if !errors.As(err, &snowflakeErr) || (snowflakeErr.Number != 394633 && snowflakeErr.Number != 394507) {
			log.Printf("Non-MFA error detected: %v", err.Error())
			break
		}
		log.Printf("MFA error detected (%d), trying next code...", snowflakeErr.Number)
	}
	assertNilF(t, lastError, "failed to connect with any TOTP code")
}
================================================
FILE: auth_with_oauth_okta_authorization_code_test.go
================================================
package gosnowflake
import (
"fmt"
"sync"
"testing"
"time"
)
// TestOauthOktaAuthorizationCodeSuccessful runs the Okta OAuth
// authorization-code flow: a background goroutine completes the browser login
// while the main goroutine opens the connection and expects success.
func TestOauthOktaAuthorizationCodeSuccessful(t *testing.T) {
	cfg := setupOauthOktaAuthorizationCodeTest(t)
	browserDone := make(chan struct{})
	go func() {
		defer close(browserDone)
		provideExternalBrowserCredentials(t, externalBrowserType.OauthOktaSuccess, cfg.User, cfg.Password)
	}()
	connErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNilE(t, connErr, fmt.Sprintf("Connection failed due to %v", connErr))
	<-browserDone
}
// TestOauthOktaAuthorizationCodeMismatchedUsername authenticates in the
// browser as the real user while the driver config claims a different one;
// the server must reject the session with error 390309.
func TestOauthOktaAuthorizationCodeMismatchedUsername(t *testing.T) {
	cfg := setupOauthOktaAuthorizationCodeTest(t)
	// Keep the real username for the browser before cfg.User is overwritten below.
	user := cfg.User
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthOktaSuccess, user, cfg.Password)
	}()
	go func() {
		defer wg.Done()
		cfg.User = "fakeUser@snowflake.com"
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		var snowflakeErr *SnowflakeError
		assertErrorsAsF(t, err, &snowflakeErr)
		// 390309: OAuth token user does not match the requested user.
		assertEqualE(t, snowflakeErr.Number, 390309, fmt.Sprintf("Expected 390309, but got %v", snowflakeErr.Number))
	}()
	wg.Wait()
}
// TestOauthOktaAuthorizationCodeOktaTimeout sets a 1-second browser timeout
// with no browser helper running, so authentication must time out.
// (Fixes assert-message typos: "should failed" and "Expecteed".)
func TestOauthOktaAuthorizationCodeOktaTimeout(t *testing.T) {
	cfg := setupOauthOktaAuthorizationCodeTest(t)
	cfg.ExternalBrowserTimeout = 1 * time.Second
	err := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNotNilF(t, err, "should fail due to timeout")
	assertEqualE(t, err.Error(), "authentication via browser timed out", fmt.Sprintf("Expected timeout, but got %v", err))
}
// TestOauthOktaAuthorizationCodeUsingTokenCache performs a browser login with
// token caching enabled, then reconnects with the browser killed and a
// 1-second timeout: the cached token must make the second connection succeed
// without any browser interaction.
func TestOauthOktaAuthorizationCodeUsingTokenCache(t *testing.T) {
	cfg := setupOauthOktaAuthorizationCodeTest(t)
	cfg.ClientStoreTemporaryCredential = 1
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthOktaSuccess, cfg.User, cfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
	}()
	wg.Wait()
	// No browser from here on; only the cached token can satisfy this connect.
	cleanupBrowserProcesses(t)
	cfg.ExternalBrowserTimeout = time.Duration(1) * time.Second
	err := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
}
// setupOauthOktaAuthorizationCodeTest prepares a Config for the Okta OAuth
// authorization-code tests, pulling client, redirect, and endpoint settings
// from the environment.
func setupOauthOktaAuthorizationCodeTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping Okta Authorization Code tests")
	cfg, err := getAuthTestsConfig(t, AuthTypeOAuthAuthorizationCode)
	assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err))
	cleanupBrowserProcesses(t)
	fromEnv := func(name string) string {
		value, envErr := GetFromEnv(name, true)
		assertNilF(t, envErr, fmt.Sprintf("failed to setup config: %v", envErr))
		return value
	}
	cfg.OauthClientID = fromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID")
	cfg.OauthClientSecret = fromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_SECRET")
	cfg.OauthRedirectURI = fromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_REDIRECT_URI")
	cfg.OauthAuthorizationURL = fromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_AUTH_URL")
	cfg.OauthTokenRequestURL = fromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_TOKEN")
	cfg.Role = fromEnv("SNOWFLAKE_AUTH_TEST_ROLE")
	return cfg
}
================================================
FILE: auth_with_oauth_okta_client_credentials_test.go
================================================
package gosnowflake
import (
"fmt"
"strings"
"testing"
)
// TestOauthOktaClientCredentialsSuccessful connects using the OAuth
// client-credentials grant against Okta and expects success.
func TestOauthOktaClientCredentialsSuccessful(t *testing.T) {
	cfg := setupOauthOktaClientCredentialsTest(t)
	connErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNilE(t, connErr, fmt.Sprintf("failed to connect. err: %v", connErr))
}
// TestOauthOktaClientCredentialsMismatchedUsername connects with a user that
// does not match the token identity and expects error 390309.
func TestOauthOktaClientCredentialsMismatchedUsername(t *testing.T) {
	cfg := setupOauthOktaClientCredentialsTest(t)
	cfg.User = "invalidUser"
	connErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	var snowflakeErr *SnowflakeError
	assertErrorsAsF(t, connErr, &snowflakeErr)
	assertEqualE(t, snowflakeErr.Number, 390309, fmt.Sprintf("Expected 390309, but got %v", snowflakeErr.Number))
}
// TestOauthOktaClientCredentialsUnauthorized uses an unknown OAuth client ID,
// so the IdP must reject the token request with "invalid_client".
func TestOauthOktaClientCredentialsUnauthorized(t *testing.T) {
	cfg := setupOauthOktaClientCredentialsTest(t)
	cfg.OauthClientID = "invalidClientID"
	connErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNotNilF(t, connErr, "Expected an error but got nil")
	errMsg := connErr.Error()
	assertTrueF(t, strings.Contains(errMsg, "invalid_client"), fmt.Sprintf("Expected error to contain 'invalid_client', but got: %v", errMsg))
}
// setupOauthOktaClientCredentialsTest prepares a Config for the Okta OAuth
// client-credentials tests; the Snowflake user equals the OAuth client ID.
func setupOauthOktaClientCredentialsTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping Okta Client Credentials tests")
	cfg, err := getAuthTestsConfig(t, AuthTypeOAuthClientCredentials)
	assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err))
	fromEnv := func(name string) string {
		value, envErr := GetFromEnv(name, true)
		assertNilF(t, envErr, fmt.Sprintf("failed to setup config: %v", envErr))
		return value
	}
	cfg.OauthClientID = fromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID")
	cfg.OauthClientSecret = fromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_SECRET")
	cfg.OauthTokenRequestURL = fromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_TOKEN")
	cfg.User = fromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID")
	cfg.Role = fromEnv("SNOWFLAKE_AUTH_TEST_ROLE")
	return cfg
}
================================================
FILE: auth_with_oauth_snowflake_authorization_code_test.go
================================================
package gosnowflake
import (
"fmt"
"sync"
"testing"
"time"
)
// TestOauthSnowflakeAuthorizationCodeSuccessful runs the Snowflake-internal
// OAuth authorization-code flow: one goroutine completes the browser login
// (using separate browser credentials) while the other opens the connection
// and expects success.
func TestOauthSnowflakeAuthorizationCodeSuccessful(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
	}()
	wg.Wait()
}
// TestOauthSnowflakeAuthorizationCodeMismatchedUsername authenticates in the
// browser as the real user while the driver config claims a different one;
// the server must reject the session with error 390309.
func TestOauthSnowflakeAuthorizationCodeMismatchedUsername(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		// Connect as a non-existent user while the browser logs in as the real one.
		cfg.User = "fakeUser@snowflake.com"
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		var snowflakeErr *SnowflakeError
		assertErrorsAsF(t, err, &snowflakeErr)
		assertEqualE(t, snowflakeErr.Number, 390309, fmt.Sprintf("Expected 390309, but got %v", snowflakeErr.Number))
	}()
	wg.Wait()
}
// TestOauthSnowflakeAuthorizationCodeTimeout sets a 1-second browser timeout
// with no browser helper running, so authentication must time out.
// (Fixes assert-message typos: "should failed" and "Expecteed".)
func TestOauthSnowflakeAuthorizationCodeTimeout(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeTest(t)
	cfg.ExternalBrowserTimeout = 1 * time.Second
	err := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNotNilF(t, err, "should fail due to timeout")
	assertEqualE(t, err.Error(), "authentication via browser timed out", fmt.Sprintf("Expected timeout, but got %v", err))
}
// TestOauthSnowflakeAuthorizationCodeUsingTokenCache performs a browser login
// with token caching enabled, then reconnects with the browser killed and a
// 1-second timeout: the cached token must make the second connection succeed.
func TestOauthSnowflakeAuthorizationCodeUsingTokenCache(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	cfg.ClientStoreTemporaryCredential = 1
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
	}()
	wg.Wait()
	// No browser from here on; only the cached token can satisfy this connect.
	cleanupBrowserProcesses(t)
	cfg.ExternalBrowserTimeout = time.Duration(1) * time.Second
	err = verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
}
// TestOauthSnowflakeAuthorizationCodeWithoutTokenCache disables token caching
// (ClientStoreTemporaryCredential = 2): after a successful browser login, a
// second connection with the browser killed must time out since no token was
// cached. (Fixes "Expecteed" typo in the assert message.)
func TestOauthSnowflakeAuthorizationCodeWithoutTokenCache(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	cfg.ClientStoreTemporaryCredential = 2
	var wg sync.WaitGroup
	cfg.DisableQueryContextCache = true
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
	}()
	wg.Wait()
	// With no cached token and no browser, the 1-second timeout must trigger.
	cleanupBrowserProcesses(t)
	cfg.ExternalBrowserTimeout = 1 * time.Second
	err = verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNotNilF(t, err, "Expected an error but got nil")
	assertEqualE(t, err.Error(), "authentication via browser timed out", fmt.Sprintf("Expected timeout, but got %v", err))
}
// setupOauthSnowflakeAuthorizationCodeTest prepares a Config for the
// Snowflake-internal OAuth authorization-code tests. Token caching is
// disabled by default (ClientStoreTemporaryCredential = 2).
func setupOauthSnowflakeAuthorizationCodeTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping Snowflake Authorization Code tests")
	cfg, err := getAuthTestsConfig(t, AuthTypeOAuthAuthorizationCode)
	assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err))
	cleanupBrowserProcesses(t)
	fromEnv := func(name string) string {
		value, envErr := GetFromEnv(name, true)
		assertNilF(t, envErr, fmt.Sprintf("failed to setup config: %v", envErr))
		return value
	}
	cfg.OauthClientID = fromEnv("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_CLIENT_ID")
	cfg.OauthClientSecret = fromEnv("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_CLIENT_SECRET")
	cfg.OauthRedirectURI = fromEnv("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_REDIRECT_URI")
	cfg.User = fromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID")
	cfg.Role = fromEnv("SNOWFLAKE_AUTH_TEST_ROLE")
	cfg.ClientStoreTemporaryCredential = 2
	return cfg
}
// getOauthSnowflakeAuthorizationCodeTestCredentials loads the user/password
// the browser helper should type during the Snowflake OAuth login.
func getOauthSnowflakeAuthorizationCodeTestCredentials() (*Config, error) {
	params := []*ConfigParam{
		{Name: "User", EnvName: "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID", FailOnMissing: true},
		{Name: "Password", EnvName: "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_USER_PASSWORD", FailOnMissing: true},
	}
	return GetConfigFromEnv(params)
}
================================================
FILE: auth_with_oauth_snowflake_authorization_code_wildcards_test.go
================================================
package gosnowflake
import (
"fmt"
"sync"
"testing"
"time"
)
// TestOauthSnowflakeAuthorizationCodeWildcardsSuccessful runs the Snowflake
// OAuth authorization-code flow using the wildcard-redirect client: a browser
// login in one goroutine, connection verification in the other.
func TestOauthSnowflakeAuthorizationCodeWildcardsSuccessful(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeWildcardsTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
	}()
	wg.Wait()
}
// TestOauthSnowflakeAuthorizationCodeWildcardsMismatchedUsername logs in via
// the browser as the real user while the driver config claims a different
// one; the server must reject the session with error 390309.
func TestOauthSnowflakeAuthorizationCodeWildcardsMismatchedUsername(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeWildcardsTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		// Connect as a non-existent user while the browser logs in as the real one.
		cfg.User = "fakeUser@snowflake.com"
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		var snowflakeErr *SnowflakeError
		assertErrorsAsF(t, err, &snowflakeErr)
		assertEqualE(t, snowflakeErr.Number, 390309, fmt.Sprintf("Expected 390309, but got %v", snowflakeErr.Number))
	}()
	wg.Wait()
}
// TestOauthSnowflakeAuthorizationWildcardsCodeTimeout sets a 1-second browser
// timeout with no browser helper running, so authentication must time out.
// (Fixes assert-message typos: "should failed" and "Expecteed".)
func TestOauthSnowflakeAuthorizationWildcardsCodeTimeout(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeWildcardsTest(t)
	cfg.ExternalBrowserTimeout = 1 * time.Second
	err := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNotNilF(t, err, "should fail due to timeout")
	assertEqualE(t, err.Error(), "authentication via browser timed out", fmt.Sprintf("Expected timeout, but got %v", err))
}
// TestOauthSnowflakeAuthorizationCodeWildcardsWithoutTokenCache disables token
// caching (ClientStoreTemporaryCredential = 2): after a successful browser
// login, a second connection with the browser killed must time out since no
// token was cached. (Fixes "Expecteed" typo in the assert message.)
func TestOauthSnowflakeAuthorizationCodeWildcardsWithoutTokenCache(t *testing.T) {
	cfg := setupOauthSnowflakeAuthorizationCodeWildcardsTest(t)
	browserCfg, err := getOauthSnowflakeAuthorizationCodeTestCredentials()
	assertNilF(t, err, fmt.Sprintf("failed to get browser config: %v", err))
	cfg.ClientStoreTemporaryCredential = 2
	var wg sync.WaitGroup
	cfg.DisableQueryContextCache = true
	wg.Add(2)
	go func() {
		defer wg.Done()
		provideExternalBrowserCredentials(t, externalBrowserType.OauthSnowflakeSuccess, browserCfg.User, browserCfg.Password)
	}()
	go func() {
		defer wg.Done()
		err := verifyConnectionToSnowflakeAuthTests(t, cfg)
		assertNilE(t, err, fmt.Sprintf("Connection failed due to %v", err))
	}()
	wg.Wait()
	// With no cached token and no browser, the 1-second timeout must trigger.
	cleanupBrowserProcesses(t)
	cfg.ExternalBrowserTimeout = 1 * time.Second
	err = verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNotNilF(t, err, "Expected an error but got nil")
	assertEqualE(t, err.Error(), "authentication via browser timed out", fmt.Sprintf("Expected timeout, but got %v", err))
}
// setupOauthSnowflakeAuthorizationCodeWildcardsTest prepares a Config for the
// Snowflake OAuth authorization-code tests using the wildcard-redirect client.
// Token caching is disabled by default (ClientStoreTemporaryCredential = 2).
func setupOauthSnowflakeAuthorizationCodeWildcardsTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping Snowflake Authorization Code tests")
	cfg, err := getAuthTestsConfig(t, AuthTypeOAuthAuthorizationCode)
	assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err))
	cleanupBrowserProcesses(t)
	fromEnv := func(name string) string {
		value, envErr := GetFromEnv(name, true)
		assertNilF(t, envErr, fmt.Sprintf("failed to setup config: %v", envErr))
		return value
	}
	cfg.OauthClientID = fromEnv("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_WILDCARDS_CLIENT_ID")
	cfg.OauthClientSecret = fromEnv("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_WILDCARDS_CLIENT_SECRET")
	cfg.User = fromEnv("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID")
	cfg.Role = fromEnv("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_ROLE")
	cfg.ClientStoreTemporaryCredential = 2
	return cfg
}
================================================
FILE: auth_with_oauth_test.go
================================================
package gosnowflake
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"testing"
)
// TestOauthSuccessful fetches a real access token from the IdP and connects
// with it, expecting success.
func TestOauthSuccessful(t *testing.T) {
	cfg := setupOauthTest(t)
	token, tokenErr := getOauthTestToken(t, cfg)
	assertNilE(t, tokenErr, fmt.Sprintf("failed to get token. err: %v", tokenErr))
	cfg.Token = token
	connErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNilE(t, connErr, fmt.Sprintf("failed to connect. err: %v", connErr))
}
// TestOauthInvalidToken connects with a bogus OAuth token and expects the
// invalid-token error 390303.
func TestOauthInvalidToken(t *testing.T) {
	cfg := setupOauthTest(t)
	cfg.Token = "invalid_token"
	connErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	var snowflakeErr *SnowflakeError
	assertErrorsAsF(t, connErr, &snowflakeErr)
	assertEqualE(t, snowflakeErr.Number, 390303, fmt.Sprintf("Expected 390303, but got %v", snowflakeErr.Number))
}
// TestOauthMismatchedUser connects with a valid token but a user that does not
// match the token identity, expecting error 390309.
func TestOauthMismatchedUser(t *testing.T) {
	cfg := setupOauthTest(t)
	token, err := getOauthTestToken(t, cfg)
	assertNilE(t, err, fmt.Sprintf("failed to get token. err: %v", err))
	cfg.Token = token
	// The token was issued for the real user; connect as someone else.
	cfg.User = "fakeaccount"
	err = verifyConnectionToSnowflakeAuthTests(t, cfg)
	var snowflakeErr *SnowflakeError
	assertErrorsAsF(t, err, &snowflakeErr)
	assertEqualE(t, snowflakeErr.Number, 390309, fmt.Sprintf("Expected 390309, but got %v", snowflakeErr.Number))
}
// setupOauthTest prepares a Config for the OAuth tests.
// The assert message previously said "failed to connect", which was wrong —
// this step loads config and does not connect; it now matches the sibling
// setup helpers.
func setupOauthTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping OAuth tests")
	cfg, err := getAuthTestsConfig(t, AuthTypeOAuth)
	assertNilF(t, err, fmt.Sprintf("failed to get config: %v", err))
	return cfg
}
// getOauthTestToken requests an access token from the test IdP's token
// endpoint using the password grant with HTTP basic client authentication.
// Returns the access token, or an error on a non-200 status or a malformed
// response body.
// NOTE(review): the client has no timeout configured — consider one if the
// IdP ever hangs; left unchanged here to avoid altering behavior.
func getOauthTestToken(t *testing.T, cfg *Config) (string, error) {
	client := &http.Client{}
	authURL, err := GetFromEnv("SNOWFLAKE_AUTH_TEST_OAUTH_URL", true)
	assertNilF(t, err, "SNOWFLAKE_AUTH_TEST_OAUTH_URL is not set")
	oauthClientID, err := GetFromEnv("SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_ID", true)
	assertNilF(t, err, "SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_ID is not set")
	oauthClientSecret, err := GetFromEnv("SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_SECRET", true)
	assertNilF(t, err, "SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_SECRET is not set")
	inputData := formData(cfg)
	req, err := http.NewRequest("POST", authURL, strings.NewReader(inputData.Encode()))
	assertNilF(t, err, fmt.Sprintf("Request failed %v", err))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded;charset=UTF-8")
	req.SetBasicAuth(oauthClientID, oauthClientSecret)
	resp, err := client.Do(req)
	assertNilF(t, err, fmt.Sprintf("Response failed %v", err))
	// Register the Close before any early return: previously a non-200
	// status returned before the defer was set up, leaking the body.
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("failed to get access token, status code: %d", resp.StatusCode)
	}
	var response OAuthTokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		return "", fmt.Errorf("failed to decode response: %v", err)
	}
	return response.Token, nil
}
// formData builds the x-www-form-urlencoded payload for the password-grant
// token request, scoping the token to the configured role (lowercased).
func formData(cfg *Config) url.Values {
	return url.Values{
		"username":   {cfg.User},
		"password":   {cfg.Password},
		"grant_type": {"password"},
		"scope":      {fmt.Sprintf("session:role:%s", strings.ToLower(cfg.Role))},
	}
}
// OAuthTokenResponse models the JSON body returned by the IdP token endpoint
// for the password grant used in these tests.
type OAuthTokenResponse struct {
	Type       string `json:"token_type"`
	Expiration int    `json:"expires_in"`
	Token      string `json:"access_token"`
	Scope      string `json:"scope"`
}
================================================
FILE: auth_with_okta_test.go
================================================
package gosnowflake
import (
"fmt"
"net/url"
"testing"
)
// TestOktaSuccessful connects through native Okta authentication and expects
// success.
func TestOktaSuccessful(t *testing.T) {
	cfg := setupOktaTest(t)
	connErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	assertNilE(t, connErr, fmt.Sprintf("failed to connect. err: %v", connErr))
}
// TestOktaWrongCredentials connects with a wrong password and expects the
// Okta authentication failure 261006.
func TestOktaWrongCredentials(t *testing.T) {
	cfg := setupOktaTest(t)
	cfg.Password = "fakePassword"
	connErr := verifyConnectionToSnowflakeAuthTests(t, cfg)
	var snowflakeErr *SnowflakeError
	assertErrorsAsF(t, connErr, &snowflakeErr)
	assertEqualE(t, snowflakeErr.Number, 261006, fmt.Sprintf("Expected 261006, but got %v", snowflakeErr.Number))
}
// TestOktaWrongAuthenticator expects error 390139 when the Okta URL points at
// an account that is not registered as the authenticator.
func TestOktaWrongAuthenticator(t *testing.T) {
	config := setupOktaTest(t)
	fakeOkta, parseErr := url.Parse("https://fake-account-0000.okta.com")
	assertNilF(t, parseErr, fmt.Sprintf("failed to parse: %v", parseErr))
	config.OktaURL = fakeOkta
	connErr := verifyConnectionToSnowflakeAuthTests(t, config)
	var sfErr *SnowflakeError
	assertErrorsAsF(t, connErr, &sfErr)
	assertEqualE(t, sfErr.Number, 390139, fmt.Sprintf("Expected 390139, but got %v", sfErr.Number))
}
// setupOktaTest skips unless auth tests are enabled, then builds an Okta test
// config whose Okta URL is taken from SNOWFLAKE_AUTH_TEST_OKTA_AUTH.
func setupOktaTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping Okta tests")
	oktaURLEnv, envErr := GetFromEnv("SNOWFLAKE_AUTH_TEST_OKTA_AUTH", true)
	assertNilF(t, envErr, fmt.Sprintf("failed to get env: %v", envErr))
	config, cfgErr := getAuthTestsConfig(t, AuthTypeOkta)
	assertNilF(t, cfgErr, fmt.Sprintf("failed to get config: %v", cfgErr))
	parsed, parseErr := url.Parse(oktaURLEnv)
	assertNilF(t, parseErr, fmt.Sprintf("failed to parse: %v", parseErr))
	config.OktaURL = parsed
	return config
}
================================================
FILE: auth_with_pat_test.go
================================================
package gosnowflake
import (
"database/sql"
"fmt"
"log"
"strings"
"testing"
"time"
)
// PatToken is a programmatic access token created for the end-to-end PAT
// tests: Name is the identifier used in ALTER USER statements, Value is the
// secret used as the connection token.
type PatToken struct {
	Name  string
	Value string
}
// TestEndToEndPatSuccessful provisions a fresh PAT, connects with it, and
// tears the token down afterwards.
func TestEndToEndPatSuccessful(t *testing.T) {
	config := setupEndToEndPatTest(t)
	token := createEndToEndPatToken(t)
	defer removeEndToEndPatToken(t, token.Name)
	config.Token = token.Value
	connErr := verifyConnectionToSnowflakeAuthTests(t, config)
	assertNilE(t, connErr, fmt.Sprintf("failed to connect. err: %v", connErr))
}
// TestEndToEndPatMismatchedUser expects error 394400 when a valid PAT is used
// with a user it was not issued for.
func TestEndToEndPatMismatchedUser(t *testing.T) {
	config := setupEndToEndPatTest(t)
	token := createEndToEndPatToken(t)
	defer removeEndToEndPatToken(t, token.Name)
	config.Token = token.Value
	config.User = "invalidUser"
	connErr := verifyConnectionToSnowflakeAuthTests(t, config)
	var sfErr *SnowflakeError
	assertErrorsAsF(t, connErr, &sfErr)
	assertEqualE(t, sfErr.Number, 394400, fmt.Sprintf("Expected 394400, but got %v", sfErr.Number))
}
// TestEndToEndPatInvalidToken expects error 394400 when the PAT value itself
// is bogus.
func TestEndToEndPatInvalidToken(t *testing.T) {
	config := setupEndToEndPatTest(t)
	config.Token = "invalidToken"
	connErr := verifyConnectionToSnowflakeAuthTests(t, config)
	var sfErr *SnowflakeError
	assertErrorsAsF(t, connErr, &sfErr)
	assertEqualE(t, sfErr.Number, 394400, fmt.Sprintf("Expected 394400, but got %v", sfErr.Number))
}
// setupEndToEndPatTest skips unless auth tests are enabled and returns the
// base PAT test configuration.
func setupEndToEndPatTest(t *testing.T) *Config {
	skipAuthTests(t, "Skipping PAT tests")
	config, cfgErr := getAuthTestsConfig(t, AuthTypePat)
	assertNilF(t, cfgErr, fmt.Sprintf("failed to parse: %v", cfgErr))
	return config
}
// getEndToEndPatSetupCommandVariables reads the user and role that the PAT
// setup/teardown ALTER USER statements interpolate, from the auth-test
// environment variables. Both variables are mandatory (FailOnMissing).
func getEndToEndPatSetupCommandVariables() (*Config, error) {
	return GetConfigFromEnv([]*ConfigParam{
		{Name: "User", EnvName: "SNOWFLAKE_AUTH_TEST_SNOWFLAKE_USER", FailOnMissing: true},
		{Name: "Role", EnvName: "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_ROLE", FailOnMissing: true},
	})
}
// createEndToEndPatToken provisions a one-day programmatic access token for
// the auth-test user via an Okta-authenticated session and returns it. The
// token name embeds a millisecond timestamp to stay unique across runs.
func createEndToEndPatToken(t *testing.T) *PatToken {
	oktaCfg := setupOktaTest(t)
	tokenName := fmt.Sprintf("PAT_GOLANG_%s", strings.ReplaceAll(time.Now().Format("20060102150405.000"), ".", ""))
	vars, varsErr := getEndToEndPatSetupCommandVariables()
	assertNilE(t, varsErr, "failed to get PAT command variables")
	createSQL := fmt.Sprintf(
		"alter user %s add programmatic access token %s ROLE_RESTRICTION = '%s' DAYS_TO_EXPIRY=1;",
		vars.User,
		tokenName,
		vars.Role,
	)
	token, createErr := connectUsingOktaConnectionAndExecuteCustomCommand(t, oktaCfg, createSQL, true)
	assertNilE(t, createErr, "failed to create PAT command variables")
	return token
}
// removeEndToEndPatToken drops a previously created programmatic access token
// through an Okta-authenticated session (run under the "analyst" role).
func removeEndToEndPatToken(t *testing.T, patTokenName string) {
	oktaCfg := setupOktaTest(t)
	oktaCfg.Role = "analyst"
	vars, varsErr := getEndToEndPatSetupCommandVariables()
	assertNilE(t, varsErr, "failed to get PAT command variables")
	dropSQL := fmt.Sprintf(
		"alter user %s remove programmatic access token %s;",
		vars.User,
		patTokenName,
	)
	_, dropErr := connectUsingOktaConnectionAndExecuteCustomCommand(t, oktaCfg, dropSQL, false)
	assertNilE(t, dropErr, "failed to remove PAT command variables")
}
// connectUsingOktaConnectionAndExecuteCustomCommand opens a Snowflake
// connection from cfg, runs query, and — when returnToken is true — scans the
// first result row into a PatToken (name, value). It returns the query error
// on failure and fails the test when a token was expected but no row came back.
func connectUsingOktaConnectionAndExecuteCustomCommand(t *testing.T, cfg *Config, query string, returnToken bool) (*PatToken, error) {
	dsn, err := DSN(cfg)
	assertNilE(t, err, "failed to create DSN from Config")
	db, err := sql.Open("snowflake", dsn)
	assertNilE(t, err, "failed to open Snowflake DB connection")
	defer db.Close()
	rows, err := db.Query(query)
	if err != nil {
		log.Printf("failed to run a query: %v, err: %v", query, err)
		return nil, err
	}
	// BUGFIX: close the result set promptly instead of leaking it until
	// db.Close releases it.
	defer rows.Close()
	var patTokenName, patTokenValue string
	if returnToken && rows.Next() {
		if err := rows.Scan(&patTokenName, &patTokenValue); err != nil {
			t.Fatalf("failed to scan token: %v", err)
		}
		return &PatToken{Name: patTokenName, Value: patTokenValue}, nil
	}
	// Surface iteration errors (rows.Next may return false because of a
	// failure rather than an empty result).
	if err := rows.Err(); err != nil {
		return nil, err
	}
	if returnToken {
		t.Fatalf("no results found for query: %s", query)
	}
	// err is necessarily nil here; return an explicit nil for clarity.
	return nil, nil
}
================================================
FILE: authexternalbrowser.go
================================================
package gosnowflake
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
errors2 "github.com/snowflakedb/gosnowflake/v2/internal/errors"
"io"
"log"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/pkg/browser"
)
const (
	// samlSuccessHTML is rendered in the user's browser once the SAML
	// response has been received; %v is substituted with the application name.
	samlSuccessHTML = `
SAML Response for Snowflake
Your identity was confirmed and propagated to Snowflake %v.
You can close this window now and go back where you started from.
`
	// bufSize is the chunk size used when reading the browser's callback
	// request from the local socket.
	bufSize = 8192
)
// Builds a response to show to the user after successfully
// getting a response from Snowflake.
func buildResponse(body string) (bytes.Buffer, error) {
t := &http.Response{
Status: "200 OK",
StatusCode: 200,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Body: io.NopCloser(bytes.NewBufferString(body)),
ContentLength: int64(len(body)),
Request: nil,
Header: make(http.Header),
}
var b bytes.Buffer
err := t.Write(&b)
return b, err
}
// createLocalTCPListener binds a TCP listener on localhost at the requested
// port (0 selects a free port). Before binding, it briefly opens and closes a
// listener on 0.0.0.0 at the same port, mirroring the original probe of all
// unicast/anycast addresses.
func createLocalTCPListener(port int) (*net.TCPListener, error) {
	logger.Debugf("creating local TCP listener on port %v", port)
	probe, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%v", port))
	if err != nil {
		logger.Warnf("error while setting up 0.0.0.0 listener: %v", err)
		return nil, err
	}
	logger.Debug("Closing 0.0.0.0 tcp listener")
	if closeErr := probe.Close(); closeErr != nil {
		logger.Errorf("error while closing TCP listener. %v", closeErr)
		return nil, closeErr
	}
	localListener, err := net.Listen("tcp", fmt.Sprintf("localhost:%v", port))
	if err != nil {
		logger.Warnf("error while setting up listener: %v", err)
		return nil, err
	}
	tcpListener, ok := localListener.(*net.TCPListener)
	if !ok {
		return nil, fmt.Errorf("failed to assert type as *net.TCPListener")
	}
	return tcpListener, nil
}
// openBrowser opens the default browser (or a new tab) at browserURL after
// checking that it is a well-formed http/https URL. This can fail when no
// display is available, e.g. when ssh'd into a headless box.
func openBrowser(browserURL string) error {
	parsed, err := url.ParseRequestURI(browserURL)
	if err != nil {
		logger.Errorf("error parsing url %v, err: %v", browserURL, err)
		return err
	}
	switch parsed.Scheme {
	case "http", "https":
		// acceptable schemes; anything else is rejected below
	default:
		return fmt.Errorf("invalid browser URL: %v", browserURL)
	}
	if err = browser.OpenURL(browserURL); err != nil {
		logger.Errorf("failed to open a browser. err: %v", err)
		return err
	}
	return nil
}
// Gets the IDP Url and Proof Key from Snowflake.
// Note: FuncPostAuthSaml will return a fully qualified error if
// there is something wrong getting data from Snowflake.
//
// Returns (SSO URL, proof key, error). On an unsuccessful auth response the
// cached tokens are cleared and the server-reported code/message are wrapped
// in a SnowflakeError.
func getIdpURLProofKey(
	ctx context.Context,
	sr *snowflakeRestful,
	authenticator string,
	application string,
	account string,
	user string,
	callbackPort int) (string, string, error) {
	// Standard JSON login-request headers.
	headers := make(map[string]string)
	headers[httpHeaderContentType] = headerContentTypeApplicationJSON
	headers[httpHeaderAccept] = headerContentTypeApplicationJSON
	headers[httpHeaderUserAgent] = userAgent
	clientEnvironment := newAuthRequestClientEnvironment()
	clientEnvironment.Application = application
	requestMain := authRequestData{
		ClientAppID:       clientType,
		ClientAppVersion:  SnowflakeGoDriverVersion,
		AccountName:       account,
		LoginName:         user,
		ClientEnvironment: clientEnvironment,
		Authenticator:     authenticator,
		// Port on localhost the IDP should redirect the browser back to.
		BrowserModeRedirectPort: strconv.Itoa(callbackPort),
	}
	authRequest := authRequest{
		Data: requestMain,
	}
	jsonBody, err := json.Marshal(authRequest)
	if err != nil {
		logger.WithContext(ctx).Errorf("failed to serialize json. err: %v", err)
		return "", "", err
	}
	respd, err := sr.FuncPostAuthSAML(ctx, sr, headers, jsonBody, sr.LoginTimeout)
	if err != nil {
		return "", "", err
	}
	if !respd.Success {
		logger.WithContext(ctx).Error("Authentication FAILED")
		// Invalidate any cached tokens before surfacing the failure.
		sr.TokenAccessor.SetTokens("", "", -1)
		code, err := strconv.Atoi(respd.Code)
		if err != nil {
			return "", "", err
		}
		return "", "", &SnowflakeError{
			Number:   code,
			SQLState: SQLStateConnectionRejected,
			Message:  respd.Message,
		}
	}
	return respd.Data.SSOURL, respd.Data.ProofKey, nil
}
// getLoginURL builds the console-login URL for the multiple-SAML flow and
// returns it together with a freshly generated proof key. The error result is
// always nil today; it is kept for signature symmetry with the IDP path.
func getLoginURL(sr *snowflakeRestful, user string, callbackPort int) (string, string, error) {
	key := generateProofKey()
	query := &url.Values{}
	query.Add("login_name", user)
	query.Add("browser_mode_redirect_port", strconv.Itoa(callbackPort))
	query.Add("proof_key", key)
	loginURL := sr.getFullURL(consoleLoginRequestPath, query)
	return loginURL.String(), key, nil
}
// generateProofKey returns 32 bytes of secure randomness encoded with
// standard (padded) base64, used as the console-login proof key.
func generateProofKey() string {
	// base64.StdEncoding already uses standard '=' padding, so the previous
	// WithPadding(base64.StdPadding) call was a no-op; encode directly.
	return base64.StdEncoding.EncodeToString(getSecureRandom(32))
}
// The browser redirect that hits our local listener looks like:
//
//	GET /?token=encodedSamlToken HTTP/1.1
//	Host: localhost:54001
//	Connection: keep-alive
//	...various browser headers...
//
// getTokenFromResponse extracts the token query value from the request line
// of such a raw HTTP request. It returns ErrFailedToParseResponse when the
// request line does not start with the expected "GET /?token=" prefix.
func getTokenFromResponse(response string) (string, error) {
	const prefix = "GET /?token="
	requestLine := strings.Split(response, "\r\n")[0]
	if !strings.HasPrefix(requestLine, prefix) {
		logger.Errorf("response is malformed. ")
		return "", &SnowflakeError{
			Number:      ErrFailedToParseResponse,
			SQLState:    SQLStateConnectionRejected,
			Message:     errors2.ErrMsgFailedToParseResponse,
			MessageArgs: []any{response},
		}
	}
	// Drop the prefix, then keep everything up to the first space
	// (which separates the path from "HTTP/1.1").
	token, _, _ := strings.Cut(strings.TrimPrefix(requestLine, prefix), " ")
	return token, nil
}
// authenticateByExternalBrowserResult bundles the outcome of the browser
// flow so the worker goroutine can deliver it over a single channel.
type authenticateByExternalBrowserResult struct {
	escapedSamlResponse []byte // URL-unescaped SAML response bytes
	proofKey            []byte // proof key paired with the login request
	err                 error  // non-nil when any step of the flow failed
}
// authenticateByExternalBrowser runs the browser-based SAML flow in a
// goroutine and waits for it at most externalBrowserTimeout. Returns the
// escaped SAML response and the proof key, or an error.
func authenticateByExternalBrowser(ctx context.Context, sr *snowflakeRestful, authenticator string, application string,
	account string, user string, externalBrowserTimeout time.Duration, disableConsoleLogin ConfigBool) ([]byte, []byte, error) {
	// Buffered (size 1) so the worker can always deliver its result and
	// exit even after a timeout has abandoned the receive below.
	resultChan := make(chan authenticateByExternalBrowserResult, 1)
	go GoroutineWrapper(
		ctx,
		func() {
			resultChan <- doAuthenticateByExternalBrowser(ctx, sr, authenticator, application, account, user, disableConsoleLogin)
		},
	)
	select {
	case <-time.After(externalBrowserTimeout):
		return nil, nil, errors.New("authentication timed out")
	case result := <-resultChan:
		return result.escapedSamlResponse, result.proofKey, result.err
	}
}
// Authentication by an external browser takes place via the following:
// - the golang snowflake driver communicates to Snowflake that the user wishes to
// authenticate via external browser
// - snowflake sends back the IDP Url configured at the Snowflake side for the
// provided account, or use the multiple SAML way via console login
// - the default browser is opened to that URL
// - user authenticates at the IDP, and is redirected to Snowflake
// - Snowflake directs the user back to the driver
// - authenticate is complete!
func doAuthenticateByExternalBrowser(ctx context.Context, sr *snowflakeRestful, authenticator string, application string, account string, user string, disableConsoleLogin ConfigBool) authenticateByExternalBrowserResult {
	// Listen on a free localhost port for the browser redirect.
	l, err := createLocalTCPListener(0)
	if err != nil {
		return authenticateByExternalBrowserResult{nil, nil, err}
	}
	defer func() {
		if err = l.Close(); err != nil {
			logger.Errorf("error while closing TCP listener for external browser (%v). %v", l.Addr().String(), err)
		}
	}()
	callbackPort := l.Addr().(*net.TCPAddr).Port
	var loginURL string
	var proofKey string
	if disableConsoleLogin == ConfigBoolTrue {
		// Gets the IDP URL and Proof Key from Snowflake
		loginURL, proofKey, err = getIdpURLProofKey(ctx, sr, authenticator, application, account, user, callbackPort)
	} else {
		// Multiple SAML way to do authentication via console login
		loginURL, proofKey, err = getLoginURL(sr, user, callbackPort)
	}
	if err != nil {
		return authenticateByExternalBrowserResult{nil, nil, err}
	}
	// Open the browser (or the test substitute) at the login URL.
	if err = defaultSamlResponseProvider().run(loginURL); err != nil {
		return authenticateByExternalBrowserResult{nil, nil, err}
	}
	encodedSamlResponseChan := make(chan string)
	errChan := make(chan error)
	var encodedSamlResponse string
	var errFromGoroutine error
	conn, err := l.Accept()
	if err != nil {
		// BUGFIX: this previously called log.Fatal(err), which terminates the
		// entire host process. Library code must report the failure to the
		// caller instead of exiting; the stdlib log line preserves the
		// original stderr visibility of the failure.
		logger.WithContext(ctx).Errorf("unable to accept connection. err: %v", err)
		log.Printf("unable to accept connection. err: %v", err)
		return authenticateByExternalBrowserResult{nil, nil, err}
	}
	go func(c net.Conn) {
		var buf bytes.Buffer
		total := 0
		encodedSamlResponse := ""
		var errAccept error
		for {
			b := make([]byte, bufSize)
			n, err := c.Read(b)
			if err != nil {
				if err != io.EOF {
					logger.WithContext(ctx).Infof("error reading from socket. err: %v", err)
					errAccept = &SnowflakeError{
						Number:      ErrFailedToGetExternalBrowserResponse,
						SQLState:    SQLStateConnectionRejected,
						Message:     errors2.ErrMsgFailedToGetExternalBrowserResponse,
						MessageArgs: []any{err},
					}
				}
				break
			}
			total += n
			buf.Write(b)
			if n < bufSize {
				// We successfully read all data
				s := string(buf.Bytes()[:total])
				encodedSamlResponse, errAccept = getTokenFromResponse(s)
				break
			}
			buf.Grow(bufSize)
		}
		// Show the success page in the browser when a token was extracted.
		if encodedSamlResponse != "" {
			body := fmt.Sprintf(samlSuccessHTML, application)
			httpResponse, err := buildResponse(body)
			if err != nil && errAccept == nil {
				errAccept = err
			}
			if _, err = c.Write(httpResponse.Bytes()); err != nil && errAccept == nil {
				errAccept = err
			}
		}
		if err := c.Close(); err != nil {
			logger.Warnf("error while closing browser connection. %v", err)
		}
		encodedSamlResponseChan <- encodedSamlResponse
		errChan <- errAccept
	}(conn)
	// Receive in the same order the goroutine sends: token first, then error.
	encodedSamlResponse = <-encodedSamlResponseChan
	errFromGoroutine = <-errChan
	if errFromGoroutine != nil {
		return authenticateByExternalBrowserResult{nil, nil, errFromGoroutine}
	}
	escapedSamlResponse, err := url.QueryUnescape(encodedSamlResponse)
	if err != nil {
		logger.WithContext(ctx).Errorf("unable to unescape saml response. err: %v", err)
		return authenticateByExternalBrowserResult{nil, nil, err}
	}
	return authenticateByExternalBrowserResult{[]byte(escapedSamlResponse), []byte(proofKey), nil}
}
// samlResponseProvider abstracts how the login URL is presented to the user,
// so tests can substitute a non-interactive implementation for the browser.
type samlResponseProvider interface {
	run(url string) error
}
// externalBrowserSamlResponseProvider is the production samlResponseProvider:
// it opens the login URL in the user's default browser.
type externalBrowserSamlResponseProvider struct {
}

// run opens url in the default browser.
func (e externalBrowserSamlResponseProvider) run(url string) error {
	return openBrowser(url)
}
// defaultSamlResponseProvider yields the provider used to present the login
// URL; it is a package-level var so tests can swap in a fake implementation.
var defaultSamlResponseProvider = func() samlResponseProvider {
	return &externalBrowserSamlResponseProvider{}
}
================================================
FILE: authexternalbrowser_test.go
================================================
package gosnowflake
import (
"context"
"errors"
"fmt"
sfconfig "github.com/snowflakedb/gosnowflake/v2/internal/config"
"net/http"
"net/url"
"strings"
"testing"
"time"
)
// TestGetTokenFromResponseFail feeds a request whose query parameter is not
// "token" and expects the parser to reject it.
func TestGetTokenFromResponseFail(t *testing.T) {
	malformed := "GET /?fakeToken=fakeEncodedSamlToken HTTP/1.1\r\n" +
		"Host: localhost:54001\r\n" +
		"Connection: keep-alive\r\n" +
		"Upgrade-Insecure-Requests: 1\r\n" +
		"User-Agent: userAgentStr\r\n" +
		"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8\r\n" +
		"Referer: https://myaccount.snowflakecomputing.com/fed/login\r\n" +
		"Accept-Encoding: gzip, deflate, br\r\n" +
		"Accept-Language: en-US,en;q=0.9\r\n\r\n"
	if _, err := getTokenFromResponse(malformed); err == nil {
		t.Errorf("Should have failed parsing the malformed response.")
	}
}
// TestGetTokenFromResponse checks that the token query value is extracted
// from a well-formed browser callback request.
func TestGetTokenFromResponse(t *testing.T) {
	rawRequest := "GET /?token=GETtokenFromResponse HTTP/1.1\r\n" +
		"Host: localhost:54001\r\n" +
		"Connection: keep-alive\r\n" +
		"Upgrade-Insecure-Requests: 1\r\n" +
		"User-Agent: userAgentStr\r\n" +
		"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8\r\n" +
		"Referer: https://myaccount.snowflakecomputing.com/fed/login\r\n" +
		"Accept-Encoding: gzip, deflate, br\r\n" +
		"Accept-Language: en-US,en;q=0.9\r\n\r\n"
	const want = "GETtokenFromResponse"
	got, err := getTokenFromResponse(rawRequest)
	if err != nil {
		t.Errorf("Failed to get the token. Err: %#v", err)
	}
	if got != want {
		t.Errorf("Expected: %s, found: %s", want, got)
	}
}
// TestBuildResponse verifies the serialized success page embeds the
// application name passed through samlSuccessHTML.
func TestBuildResponse(t *testing.T) {
	resp, err := buildResponse(fmt.Sprintf(samlSuccessHTML, "Go"))
	assertNilF(t, err)
	rendered := resp.String()
	if !strings.Contains(rendered, "Your identity was confirmed and propagated to Snowflake Go.\nYou can close this window now and go back where you started from.") {
		t.Fatalf("failed to build response")
	}
}
// postAuthExternalBrowserError is a FuncPostAuthSAML stub that always fails
// with a transport-level error.
func postAuthExternalBrowserError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
	return &authResponse{}, errors.New("failed to get SAML response")
}
// postAuthExternalBrowserErrorDelayed is a FuncPostAuthSAML stub that sleeps
// for 2s before failing, used to exercise the authentication timeout path.
func postAuthExternalBrowserErrorDelayed(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
	time.Sleep(2 * time.Second)
	return &authResponse{}, errors.New("failed to get SAML response")
}
// postAuthExternalBrowserFail is a FuncPostAuthSAML stub that returns a
// non-successful auth response without a server error code.
func postAuthExternalBrowserFail(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
	return &authResponse{
		Success: false,
		Message: "external browser auth failed",
	}, nil
}
// postAuthExternalBrowserFailWithCode is a FuncPostAuthSAML stub that returns
// a non-successful auth response carrying error code 260008
// (ErrCodeFailedToConnect).
func postAuthExternalBrowserFailWithCode(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
	return &authResponse{
		Success: false,
		Message: "failed to connect to db",
		Code:    "260008",
	}, nil
}
// TestUnitAuthenticateByExternalBrowser drives the browser flow against three
// failing FuncPostAuthSAML stubs: transport error, plain failure, and failure
// with an error code, verifying the last surfaces as a SnowflakeError.
func TestUnitAuthenticateByExternalBrowser(t *testing.T) {
	const (
		authenticator = "externalbrowser"
		application   = "testapp"
		account       = "testaccount"
		user          = "u"
	)
	timeout := sfconfig.DefaultExternalBrowserTimeout
	sr := &snowflakeRestful{
		Protocol:         "https",
		Host:             "abc.com",
		Port:             443,
		FuncPostAuthSAML: postAuthExternalBrowserError,
		TokenAccessor:    getSimpleTokenAccessor(),
	}
	if _, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, timeout, ConfigBoolTrue); err == nil {
		t.Fatal("should have failed.")
	}
	sr.FuncPostAuthSAML = postAuthExternalBrowserFail
	if _, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, timeout, ConfigBoolTrue); err == nil {
		t.Fatal("should have failed.")
	}
	sr.FuncPostAuthSAML = postAuthExternalBrowserFailWithCode
	_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, timeout, ConfigBoolTrue)
	if err == nil {
		t.Fatal("should have failed.")
	}
	sfErr, ok := err.(*SnowflakeError)
	if !ok {
		t.Fatalf("should be snowflake error. err: %v", err)
	}
	if sfErr.Number != ErrCodeFailedToConnect {
		t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeFailedToConnect, sfErr.Number)
	}
}
// TestAuthenticationTimeout checks that a 1s external-browser timeout fires
// before the 2s-delayed SAML stub completes.
func TestAuthenticationTimeout(t *testing.T) {
	sr := &snowflakeRestful{
		Protocol:         "https",
		Host:             "abc.com",
		Port:             443,
		FuncPostAuthSAML: postAuthExternalBrowserErrorDelayed,
		TokenAccessor:    getSimpleTokenAccessor(),
	}
	_, _, err := authenticateByExternalBrowser(context.Background(), sr, "externalbrowser", "testapp", "testaccount", "u", 1*time.Second, ConfigBoolTrue)
	assertEqualE(t, err.Error(), "authentication timed out", err.Error())
}
// Test_createLocalTCPListener verifies a listener can be bound to a free
// localhost port and is returned non-nil.
func Test_createLocalTCPListener(t *testing.T) {
	l, err := createLocalTCPListener(0)
	if err != nil {
		t.Fatalf("createLocalTCPListener() failed: %v", err)
	}
	if l == nil {
		t.Fatal("createLocalTCPListener() returned nil listener")
	}
	// Release the port once the test finishes.
	defer l.Close()
}
// TestUnitGetLoginURL checks that the console-login URL carries the expected
// scheme, host, path, and the three mandatory query parameters.
func TestUnitGetLoginURL(t *testing.T) {
	const (
		expectedScheme = "https"
		expectedHost   = "abc.com:443"
	)
	sr := &snowflakeRestful{
		Protocol:      "https",
		Host:          "abc.com",
		Port:          443,
		TokenAccessor: getSimpleTokenAccessor(),
	}
	loginURL, proofKey, err := getLoginURL(sr, "u", 123)
	assertNilF(t, err, "failed to get login URL")
	assertNotNilF(t, len(proofKey), "proofKey should be non-empty string")
	parsed, err := url.Parse(loginURL)
	assertNilF(t, err, "failed to parse the login URL")
	assertEqualF(t, parsed.Scheme, expectedScheme)
	assertEqualF(t, parsed.Host, expectedHost)
	assertEqualF(t, parsed.Path, consoleLoginRequestPath)
	for _, param := range []string{"login_name", "browser_mode_redirect_port", "proof_key"} {
		assertStringContainsF(t, parsed.RawQuery, param)
	}
}
// nonInteractiveSamlResponseProvider is a test samlResponseProvider that
// fetches the login URL over HTTP in a background goroutine instead of
// opening a real browser.
type nonInteractiveSamlResponseProvider struct {
	t *testing.T
}

// run fires an asynchronous GET against url and asserts a 200 response.
func (provider *nonInteractiveSamlResponseProvider) run(url string) error {
	go func() {
		resp, err := http.Get(url)
		assertNilF(provider.t, err)
		// BUGFIX: the response body was never closed, leaking the body and
		// preventing connection reuse by the transport.
		defer resp.Body.Close()
		assertEqualE(provider.t, resp.StatusCode, http.StatusOK)
	}()
	return nil
}
================================================
FILE: authokta.go
================================================
package gosnowflake
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/snowflakedb/gosnowflake/v2/internal/errors"
"html"
"io"
"net/http"
"net/url"
"strconv"
"time"
)
// authOKTARequest is the credentials payload POSTed to the Okta token
// endpoint (step 3 of the SAML flow in authenticateBySAML).
type authOKTARequest struct {
	Username string `json:"username"`
	Password string `json:"password"`
}
// authOKTAResponse is the Okta token endpoint's reply; authenticateBySAML
// prefers SessionToken and falls back to CookieToken as the one-time token.
type authOKTAResponse struct {
	CookieToken  string `json:"cookieToken"`
	SessionToken string `json:"sessionToken"`
}
/*
authenticateBySAML authenticates a user by SAML
SAML Authentication
 1. query GS to obtain IDP token and SSO url
 2. IMPORTANT Client side validation:
    validate both token url and sso url contains same prefix
    (protocol + host + port) as the given authenticator url.
    Explanation:
    This provides a way for the user to 'authenticate' the IDP it is
    sending his/her credentials to. Without such a check, the user could
    be coerced to provide credentials to an IDP impersonator.
 3. query IDP token url to authenticate and retrieve access token
 4. given access token, query IDP URL snowflake app to get SAML response
 5. IMPORTANT Client side validation:
    validate the post back url come back with the SAML response
    contains the same prefix as the Snowflake's server url, which is the
    intended destination url to Snowflake.
    Explanation:
    This emulates the behavior of IDP initiated login flow in the user
    browser where the IDP instructs the browser to POST the SAML
    assertion to the specific SP endpoint. This is critical in
    preventing a SAML assertion issued to one SP from being sent to
    another SP.
*/
func authenticateBySAML(
	ctx context.Context,
	sr *snowflakeRestful,
	oktaURL *url.URL,
	application string,
	account string,
	user string,
	password string,
	disableSamlURLCheck ConfigBool,
) (samlResponse []byte, err error) {
	logger.WithContext(ctx).Info("step 1: query GS to obtain IDP token and SSO url")
	// Standard JSON login-request headers.
	headers := make(map[string]string)
	headers[httpHeaderContentType] = headerContentTypeApplicationJSON
	headers[httpHeaderAccept] = headerContentTypeApplicationJSON
	headers[httpHeaderUserAgent] = userAgent
	clientEnvironment := newAuthRequestClientEnvironment()
	clientEnvironment.Application = application
	requestMain := authRequestData{
		ClientAppID:       clientType,
		ClientAppVersion:  SnowflakeGoDriverVersion,
		AccountName:       account,
		ClientEnvironment: clientEnvironment,
		// The Okta URL acts as the authenticator identifier for GS.
		Authenticator: oktaURL.String(),
	}
	authRequest := authRequest{
		Data: requestMain,
	}
	// params is only used for the debug log below; the login request itself
	// is carried entirely in the JSON body.
	params := &url.Values{}
	jsonBody, err := json.Marshal(authRequest)
	if err != nil {
		return nil, err
	}
	logger.WithContext(ctx).Infof("PARAMS for Auth: %v, %v", params, sr)
	respd, err := sr.FuncPostAuthSAML(ctx, sr, headers, jsonBody, sr.LoginTimeout)
	if err != nil {
		return nil, err
	}
	if !respd.Success {
		logger.WithContext(ctx).Error("Authentication FAILED")
		// Invalidate cached tokens before surfacing the server failure.
		sr.TokenAccessor.SetTokens("", "", -1)
		code, err := strconv.Atoi(respd.Code)
		if err != nil {
			return nil, err
		}
		return nil, &SnowflakeError{
			Number:   code,
			SQLState: SQLStateConnectionRejected,
			Message:  respd.Message,
		}
	}
	logger.WithContext(ctx).Info("step 2: validate Token and SSO URL has the same prefix as oktaURL")
	var tokenURL *url.URL
	var ssoURL *url.URL
	if tokenURL, err = url.Parse(respd.Data.TokenURL); err != nil {
		return nil, fmt.Errorf("failed to parse token URL. %v", respd.Data.TokenURL)
	}
	if ssoURL, err = url.Parse(respd.Data.SSOURL); err != nil {
		return nil, fmt.Errorf("failed to parse SSO URL. %v", respd.Data.SSOURL)
	}
	// Anti-impersonation check: both IDP URLs must share the authenticator's
	// protocol + host + port prefix (see the function comment, step 2).
	if !isPrefixEqual(oktaURL, ssoURL) || !isPrefixEqual(oktaURL, tokenURL) {
		return nil, &SnowflakeError{
			Number:      ErrCodeIdpConnectionError,
			SQLState:    SQLStateConnectionRejected,
			Message:     errors.ErrMsgIdpConnectionError,
			MessageArgs: []any{oktaURL, respd.Data.TokenURL, respd.Data.SSOURL},
		}
	}
	logger.WithContext(ctx).Info("step 3: query IDP token url to authenticate and retrieve access token")
	jsonBody, err = json.Marshal(authOKTARequest{
		Username: user,
		Password: password,
	})
	if err != nil {
		return nil, err
	}
	respa, err := sr.FuncPostAuthOKTA(ctx, sr, headers, jsonBody, respd.Data.TokenURL, sr.LoginTimeout)
	if err != nil {
		return nil, err
	}
	logger.WithContext(ctx).Info("step 4: query IDP URL snowflake app to get SAML response")
	params = &url.Values{}
	params.Add("RelayState", "/some/deep/link")
	// Prefer the session token; fall back to the cookie token when absent.
	var oneTimeToken string
	if respa.SessionToken != "" {
		oneTimeToken = respa.SessionToken
	} else {
		oneTimeToken = respa.CookieToken
	}
	params.Add("onetimetoken", oneTimeToken)
	headers = make(map[string]string)
	headers[httpHeaderAccept] = "*/*"
	bd, err := sr.FuncGetSSO(ctx, sr, params, headers, respd.Data.SSOURL, sr.LoginTimeout)
	if err != nil {
		return nil, err
	}
	if disableSamlURLCheck == ConfigBoolFalse {
		logger.WithContext(ctx).Info("step 5: validate post_back_url matches Snowflake URL")
		tgtURL, err := postBackURL(bd)
		if err != nil {
			return nil, err
		}
		fullURL := sr.getURL()
		logger.WithContext(ctx).Infof("tgtURL: %v, origURL: %v", tgtURL, fullURL)
		// Prevent a SAML assertion issued for another SP from being accepted
		// (see the function comment, step 5).
		if !isPrefixEqual(tgtURL, fullURL) {
			return nil, &SnowflakeError{
				Number:      ErrCodeSSOURLNotMatch,
				SQLState:    SQLStateConnectionRejected,
				Message:     errors.ErrMsgSSOURLNotMatch,
				MessageArgs: []any{tgtURL, fullURL},
			}
		}
	}
	return bd, nil
}
func postBackURL(htmlData []byte) (url *url.URL, err error) {
idx0 := bytes.Index(htmlData, []byte("