Repository: qustavo/sqlhooks Branch: master Commit: 7875602513fa Files: 28 Total size: 51.5 KB Directory structure: gitextract_1t1_l1wt/ ├── .github/ │ └── workflows/ │ ├── lint.yml │ └── test.yml ├── .gitignore ├── .golangci.yml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── benchmark_test.go ├── compose.go ├── compose_test.go ├── doc.go ├── go.mod ├── go.sum ├── hooks/ │ ├── loghooks/ │ │ ├── example_test.go │ │ ├── examples/ │ │ │ └── main.go │ │ └── loghooks.go │ └── othooks/ │ ├── examples/ │ │ └── main.go │ ├── othooks.go │ └── othooks_test.go ├── sqlhooks.go ├── sqlhooks_1_10.go ├── sqlhooks_1_10_interface_test.go ├── sqlhooks_interface_test.go ├── sqlhooks_mysql_test.go ├── sqlhooks_postgres_test.go ├── sqlhooks_pre_1_10.go ├── sqlhooks_sqlite3_test.go └── sqlhooks_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/lint.yml ================================================ name: lint on: pull_request: jobs: golangci: name: lint runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: golangci-lint uses: golangci/golangci-lint-action@v2 with: # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version version: latest ================================================ FILE: .github/workflows/test.yml ================================================ name: "test" on: ["push","pull_request"] jobs: test: name: "Run unit tests" strategy: matrix: os: [ubuntu-latest] go-version: ["1.15.x", "1.16.x", "1.17.x"] runs-on: ${{ matrix.os }} services: mysql: image: mysql env: MYSQL_USER: test MYSQL_PASSWORD: test MYSQL_DATABASE: sqlhooks MYSQL_ALLOW_EMPTY_PASSWORD: true ports: - 3306:3306 options: >- --health-cmd="mysqladmin -v ping" --health-interval=10s --health-timeout=5s --health-retries=5 postgres: image: postgres env: POSTGRES_PASSWORD: test POSTGRES_DB: sqlhooks ports: - 
5432:5432 options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 steps: - name: Install Go uses: actions/setup-go@v2 with: go-version: ${{ matrix.go-version }} - name: Checkout code uses: actions/checkout@v2 with: fetch-depth: 1 - uses: actions/cache@v2 with: path: | ~/go/pkg/mod ~/.cache/go-build key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - name: Test env: SQLHOOKS_MYSQL_DSN: "test:test@/sqlhooks?interpolateParams=true" SQLHOOKS_POSTGRES_DSN: "postgres://postgres:test@localhost/sqlhooks?sslmode=disable" run: go test -race -covermode atomic -coverprofile=covprofile ./... - name: Install goveralls run: go get github.com/mattn/goveralls@v0.0.11 - name: Send coverage env: COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: goveralls -coverprofile=covprofile -service=github ================================================ FILE: .gitignore ================================================ # Compiled Object files, Static and Dynamic libs (Shared Objects) *.o *.a *.so # Folders _obj _test # Architecture specific extensions/prefixes *.[568vq] [568vq].out *.cgo1.go *.cgo2.c _cgo_defun.c _cgo_gotypes.go _cgo_export.* _testmain.go *.exe *.test *.prof ================================================ FILE: .golangci.yml ================================================ linters-settings: staticcheck: checks: ["all", "-SA1019"] issues: exclude-rules: - path: example_test.go linters: - errcheck ================================================ FILE: CHANGELOG.md ================================================ # Change Log ## [Unreleased](https://github.com/qustavo/sqlhooks/tree/HEAD) [Full Changelog](https://github.com/qustavo/sqlhooks/compare/v1.0.0...HEAD) **Closed issues:** - Add Benchmarks [\#9](https://github.com/qustavo/sqlhooks/issues/9) ## [v1.0.0](https://github.com/qustavo/sqlhooks/tree/v1.0.0) (2017-05-08) [Full Changelog](https://github.com/qustavo/sqlhooks/compare/v0.4...v1.0.0) 
**Merged pull requests:** - Godoc [\#7](https://github.com/qustavo/sqlhooks/pull/7) ([qustavo](https://github.com/qustavo)) - Make covermode=count [\#6](https://github.com/qustavo/sqlhooks/pull/6) ([qustavo](https://github.com/qustavo)) - V1 [\#5](https://github.com/qustavo/sqlhooks/pull/5) ([qustavo](https://github.com/qustavo)) - Expose a WrapDriver function [\#4](https://github.com/qustavo/sqlhooks/issues/4) - Implement new 1.8 interfaces [\#3](https://github.com/qustavo/sqlhooks/issues/3) ## [v0.4](https://github.com/qustavo/sqlhooks/tree/v0.4) (2017-03-23) [Full Changelog](https://github.com/qustavo/sqlhooks/compare/v0.3...v0.4) ## [v0.3](https://github.com/qustavo/sqlhooks/tree/v0.3) (2016-06-02) [Full Changelog](https://github.com/qustavo/sqlhooks/compare/v0.2...v0.3) **Closed issues:** - Change Notifications [\#2](https://github.com/qustavo/sqlhooks/issues/2) ## [v0.2](https://github.com/qustavo/sqlhooks/tree/v0.2) (2016-05-01) [Full Changelog](https://github.com/qustavo/sqlhooks/compare/v0.1...v0.2) ## [v0.1](https://github.com/qustavo/sqlhooks/tree/v0.1) (2016-04-25) **Merged pull requests:** - Sqlite3 [\#1](https://github.com/qustavo/sqlhooks/pull/1) ([qustavo](https://github.com/qustavo)) \* *This Change Log was automatically generated by [github_changelog_generator](https://github.com/skywinder/Github-Changelog-Generator)* ================================================ FILE: LICENSE ================================================ The MIT License (MIT) Copyright (c) 2016 Gustavo Chaín Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this 
permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # sqlhooks ![Build Status](https://github.com/qustavo/sqlhooks/actions/workflows/test.yml/badge.svg) [![Go Report Card](https://goreportcard.com/badge/github.com/qustavo/sqlhooks)](https://goreportcard.com/report/github.com/qustavo/sqlhooks) [![Coverage Status](https://coveralls.io/repos/github/qustavo/sqlhooks/badge.svg?branch=master)](https://coveralls.io/github/qustavo/sqlhooks?branch=master) Attach hooks to any database/sql driver. The purpose of sqlhooks is to provide a way to instrument your SQL statements, making it really easy to log queries or measure execution time without modifying your actual code.
# Install ```bash go get github.com/qustavo/sqlhooks/v2 ``` Requires Go >= 1.14.x ## Breaking changes `V2` isn't backward compatible with previous versions; if you want to fetch old versions, you can use Go modules or get them from [gopkg.in](http://gopkg.in/) ```bash go get github.com/qustavo/sqlhooks go get gopkg.in/qustavo/sqlhooks.v1 ``` # Usage [![GoDoc](https://godoc.org/github.com/qustavo/sqlhooks?status.svg)](https://godoc.org/github.com/qustavo/sqlhooks) ```go // This example shows how to instrument sql queries in order to display the time that they consume package main import ( "context" "database/sql" "fmt" "time" "github.com/qustavo/sqlhooks/v2" "github.com/mattn/go-sqlite3" ) // Hooks satisfies the sqlhooks.Hooks interface type Hooks struct {} // Before hook will print the query with its args and return the context with the timestamp func (h *Hooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { fmt.Printf("> %s %q", query, args) return context.WithValue(ctx, "begin", time.Now()), nil } // After hook will get the timestamp registered on the Before hook and print the elapsed time func (h *Hooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { begin := ctx.Value("begin").(time.Time) fmt.Printf(". took: %s\n", time.Since(begin)) return ctx, nil } func main() { // First, register the wrapper sql.Register("sqlite3WithHooks", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &Hooks{})) // Connect to the registered wrapped driver db, _ := sql.Open("sqlite3WithHooks", ":memory:") // Do your stuff db.Exec("CREATE TABLE t (id INTEGER, text VARCHAR(16))") db.Exec("INSERT into t (text) VALUES(?), (?)", "foo", "bar") db.Query("SELECT id, text FROM t") } /* Output should look like: > CREATE TABLE t (id INTEGER, text VARCHAR(16)) []. took: 121.238µs > INSERT into t (text) VALUES(?), (?) ["foo" "bar"]. took: 36.364µs > SELECT id, text FROM t [].
took: 4.653µs
*/
```

# Benchmarks

```
go test -bench=. -benchmem
goos: linux
goarch: amd64
pkg: github.com/qustavo/sqlhooks/v2
cpu: Intel(R) Xeon(R) W-10885M CPU @ 2.40GHz
BenchmarkSQLite3/Without_Hooks-16 191196 6163 ns/op 456 B/op 14 allocs/op
BenchmarkSQLite3/With_Hooks-16 189997 6329 ns/op 456 B/op 14 allocs/op
BenchmarkMySQL/Without_Hooks-16 13278 83462 ns/op 309 B/op 7 allocs/op
BenchmarkMySQL/With_Hooks-16 13460 87331 ns/op 309 B/op 7 allocs/op
BenchmarkPostgres/Without_Hooks-16 13016 91421 ns/op 401 B/op 10 allocs/op
BenchmarkPostgres/With_Hooks-16 12339 94033 ns/op 401 B/op 10 allocs/op
PASS
ok github.com/qustavo/sqlhooks/v2 10.294s
```

================================================
FILE: benchmark_test.go
================================================
package sqlhooks

import (
	"database/sql"
	"os"
	"testing"

	"github.com/go-sql-driver/mysql"
	"github.com/lib/pq"
	"github.com/mattn/go-sqlite3"
	"github.com/stretchr/testify/require"
)

// init registers hook-wrapped copies of the sqlite3, mysql and postgres
// drivers under "<name>-benchmark", so each benchmark can compare a bare
// driver against the same driver wrapped by sqlhooks.
// NOTE(review): testHooks and reset() live in another test file; presumably
// reset installs no-op callbacks so the hooks add minimal work — confirm.
func init() {
	hooks := &testHooks{}
	hooks.reset()

	sql.Register("sqlite3-benchmark", Wrap(&sqlite3.SQLiteDriver{}, hooks))
	sql.Register("mysql-benchmark", Wrap(&mysql.MySQLDriver{}, hooks))
	sql.Register("postgres-benchmark", Wrap(&pq.Driver{}, hooks))
}

// benchmark opens driver with dsn and times b.N executions of a trivial
// SELECT, closing each result set before the next iteration.
func benchmark(b *testing.B, driver, dsn string) {
	db, err := sql.Open(driver, dsn)
	require.NoError(b, err)
	defer db.Close()

	var query = "SELECT 'hello'"

	// Reset the timer so the sql.Open setup above is not measured.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		rows, err := db.Query(query)
		require.NoError(b, err)
		require.NoError(b, rows.Close())
	}
}

// BenchmarkSQLite3 compares the bare sqlite3 driver with its hook-wrapped
// counterpart on an in-memory database.
func BenchmarkSQLite3(b *testing.B) {
	b.Run("Without Hooks", func(b *testing.B) {
		benchmark(b, "sqlite3", ":memory:")
	})

	b.Run("With Hooks", func(b *testing.B) {
		benchmark(b, "sqlite3-benchmark", ":memory:")
	})
}

// BenchmarkMySQL compares the bare mysql driver with its hook-wrapped
// counterpart; it is skipped unless SQLHOOKS_MYSQL_DSN is set.
func BenchmarkMySQL(b *testing.B) {
	dsn := os.Getenv("SQLHOOKS_MYSQL_DSN")
	if dsn == "" {
		b.Skipf("SQLHOOKS_MYSQL_DSN not set")
	}

	b.Run("Without Hooks", func(b *testing.B) {
		benchmark(b, "mysql", dsn)
	})

	b.Run("With Hooks", func(b *testing.B) {
		benchmark(b, "mysql-benchmark", dsn)
	})
}

// BenchmarkPostgres compares the bare postgres driver with its hook-wrapped
// counterpart; it is skipped unless SQLHOOKS_POSTGRES_DSN is set.
func
BenchmarkPostgres(b *testing.B) { dsn := os.Getenv("SQLHOOKS_POSTGRES_DSN") if dsn == "" { b.Skipf("SQLHOOKS_POSTGRES_DSN not set") } b.Run("Without Hooks", func(b *testing.B) { benchmark(b, "postgres", dsn) }) b.Run("With Hooks", func(b *testing.B) { benchmark(b, "postgres-benchmark", dsn) }) } ================================================ FILE: compose.go ================================================ package sqlhooks import ( "context" "fmt" ) // Compose allows for composing multiple Hooks into one. // It runs every callback on every hook in argument order, // even if previous hooks return an error. // If multiple hooks return errors, the error return value will be // MultipleErrors, which allows for introspecting the errors if necessary. func Compose(hooks ...Hooks) Hooks { return composed(hooks) } type composed []Hooks func (c composed) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { var errors []error for _, hook := range c { c, err := hook.Before(ctx, query, args...) if err != nil { errors = append(errors, err) } if c != nil { ctx = c } } return ctx, wrapErrors(nil, errors) } func (c composed) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { var errors []error for _, hook := range c { var err error c, err := hook.After(ctx, query, args...) 
if err != nil { errors = append(errors, err) } if c != nil { ctx = c } } return ctx, wrapErrors(nil, errors) } func (c composed) OnError(ctx context.Context, cause error, query string, args ...interface{}) error { var errors []error for _, hook := range c { if onErrorer, ok := hook.(OnErrorer); ok { if err := onErrorer.OnError(ctx, cause, query, args...); err != nil && err != cause { errors = append(errors, err) } } } return wrapErrors(cause, errors) } func wrapErrors(def error, errors []error) error { switch len(errors) { case 0: return def case 1: return errors[0] default: return MultipleErrors(errors) } } // MultipleErrors is an error that contains multiple errors. type MultipleErrors []error func (m MultipleErrors) Error() string { return fmt.Sprint("multiple errors:", []error(m)) } ================================================ FILE: compose_test.go ================================================ package sqlhooks import ( "context" "errors" "reflect" "testing" ) var ( oops = errors.New("oops") oopsHook = &testHooks{ before: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { return ctx, oops }, after: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { return ctx, oops }, onError: func(ctx context.Context, err error, query string, args ...interface{}) error { return oops }, } okHook = &testHooks{ before: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { return ctx, nil }, after: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { return ctx, nil }, onError: func(ctx context.Context, err error, query string, args ...interface{}) error { return nil }, } ) func TestCompose(t *testing.T) { for _, it := range []struct { name string hooks Hooks want error }{ {"happy case", Compose(okHook, okHook), nil}, {"no hooks", Compose(), nil}, {"multiple errors", Compose(oopsHook, okHook, oopsHook), MultipleErrors([]error{oops, 
oops})},
		{"single error", Compose(okHook, oopsHook, okHook), oops},
	} {
		t.Run(it.name, func(t *testing.T) {
			t.Run("Before", func(t *testing.T) {
				_, got := it.hooks.Before(context.Background(), "query")
				if !reflect.DeepEqual(it.want, got) {
					t.Errorf("unexpected error. want: %q, got: %q", it.want, got)
				}
			})
			t.Run("After", func(t *testing.T) {
				_, got := it.hooks.After(context.Background(), "query")
				if !reflect.DeepEqual(it.want, got) {
					t.Errorf("unexpected error. want: %q, got: %q", it.want, got)
				}
			})
			t.Run("OnError", func(t *testing.T) {
				cause := errors.New("crikey")
				// When no hook fails, composed OnError is expected to hand back
				// the original cause rather than nil.
				want := it.want
				if want == nil {
					want = cause
				}
				got := it.hooks.(OnErrorer).OnError(context.Background(), cause, "query")
				if !reflect.DeepEqual(want, got) {
					t.Errorf("unexpected error. want: %q, got: %q", want, got)
				}
			})
		})
	}
}

// TestWrapErrors checks the three wrapErrors cases: no errors yields the
// default, one error yields that error, several yield MultipleErrors.
func TestWrapErrors(t *testing.T) {
	var (
		err1 = errors.New("oops")
		err2 = errors.New("oops2")
	)

	for _, it := range []struct {
		name   string
		def    error
		errors []error
		want   error
	}{
		{"no errors", err1, nil, err1},
		{"single error", nil, []error{err1}, err1},
		{"multiple errors", nil, []error{err1, err2}, MultipleErrors([]error{err1, err2})},
	} {
		t.Run(it.name, func(t *testing.T) {
			if want, got := it.want, wrapErrors(it.def, it.errors); !reflect.DeepEqual(want, got) {
				t.Errorf("unexpected wrapping. want: %q, got %q", want, got)
			}
		})
	}
}

================================================
FILE: doc.go
================================================
// Package sqlhooks allows you to attach hooks to any database/sql driver.
// The purpose of sqlhooks is to provide a way to instrument your SQL statements, making it really easy to log queries or measure execution time without modifying your actual code.
// This example shows how to instrument sql queries in order to display the time that they consume // package main // // import ( // "context" // "database/sql" // "fmt" // "time" // // "github.com/qustavo/sqlhooks/v2" // "github.com/mattn/go-sqlite3" // ) // // // Hooks satisfies the sqlhooks.Hooks interface // type Hooks struct {} // // // Before hook will print the query with its args and return the context with the timestamp // func (h *Hooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { // fmt.Printf("> %s %q", query, args) // return context.WithValue(ctx, "begin", time.Now()), nil // } // // // After hook will get the timestamp registered on the Before hook and print the elapsed time // func (h *Hooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { // begin := ctx.Value("begin").(time.Time) // fmt.Printf(". took: %s\n", time.Since(begin)) // return ctx, nil // } // // func main() { // // First, register the wrapper // sql.Register("sqlite3WithHooks", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &Hooks{})) // // // Connect to the registered wrapped driver // db, _ := sql.Open("sqlite3WithHooks", ":memory:") // // // Do your stuff // db.Exec("CREATE TABLE t (id INTEGER, text VARCHAR(16))") // db.Exec("INSERT into t (text) VALUES(?), (?)", "foo", "bar") // db.Query("SELECT id, text FROM t") // } // // /* // Output should look like: // > CREATE TABLE t (id INTEGER, text VARCHAR(16)) []. took: 121.238µs // > INSERT into t (text) VALUES(?), (?) ["foo" "bar"]. took: 36.364µs // > SELECT id, text FROM t [].
took: 4.653µs // */ package sqlhooks ================================================ FILE: go.mod ================================================ module github.com/qustavo/sqlhooks/v2 go 1.13 require ( github.com/go-sql-driver/mysql v1.4.1 github.com/lib/pq v1.2.0 github.com/mattn/go-sqlite3 v1.10.0 github.com/opentracing/opentracing-go v1.1.0 github.com/stretchr/testify v1.4.0 golang.org/x/tools v0.1.7 // indirect ) ================================================ FILE: go.sum ================================================ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q= github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod 
h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e h1:WUoyKPm6nCo1BnNUvPGnFG3T5DUVem42yDJZZ4CNxMA= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.7 h1:6j8CgantCy3yc8JGBqkDLMKWqZ0RDU2g1HVgacojGWQ=
golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
================================================
FILE: hooks/loghooks/example_test.go
================================================
package loghooks

import (
	"database/sql"

	"github.com/qustavo/sqlhooks/v2"

	sqlite3 "github.com/mattn/go-sqlite3"
)

// Example wraps the sqlite3 driver with a loghooks Hook, registers it as
// "sqlite3-logger" and runs one query through it.
func Example() {
	driver := sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, New())
	sql.Register("sqlite3-logger", driver)
	db, _ := sql.Open("sqlite3-logger", ":memory:")

	// This query will output logs
	db.Query("SELECT 1+1")
}

================================================
FILE: hooks/loghooks/examples/main.go
================================================
package main

import (
	"database/sql"
	"log"

	"github.com/qustavo/sqlhooks/v2"
	"github.com/qustavo/sqlhooks/v2/hooks/loghooks"
	"github.com/mattn/go-sqlite3"
)

// main registers a logging-wrapped sqlite3 driver and runs a few statements
// against an in-memory database, exiting on the first error.
func main() {
	sql.Register("sqlite3log", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, loghooks.New()))
	db, err := sql.Open("sqlite3log", ":memory:")
	if err != nil {
		log.Fatal(err)
	}

	if _, err := db.Exec("CREATE TABLE users(ID int, name text)"); err != nil {
		log.Fatal(err)
	}

	if _, err := db.Exec(`INSERT INTO users (id, name) VALUES(?, ?)`, 1, "gus"); err != nil {
		log.Fatal(err)
	}

	// NOTE(review): the returned *sql.Rows is discarded without Close; this is
	// only harmless because the program exits right afterwards.
	if _, err := db.Query(`SELECT id, name FROM users`); err != nil {
		log.Fatal(err)
	}
}

================================================
FILE: hooks/loghooks/loghooks.go
================================================
package loghooks

import (
	"context"
	"log"
	"os"
	"time"
)

// started is never read directly; its address serves as a package-private
// context key under which Before stores the query start time.
var started int

// logger is the subset of *log.Logger that Hook needs.
type logger interface {
	Printf(string, ...interface{})
}

// Hook is a sqlhooks hook that logs each query, its arguments and how long
// it took.
type Hook struct {
	log logger
}

// New returns a Hook that logs to os.Stderr with log.LstdFlags.
func New() *Hook {
	return &Hook{
		log: log.New(os.Stderr, "", log.LstdFlags),
	}
}

// Before stores the current time in the returned context so that After and
// OnError can compute the elapsed time. It never returns an error.
func (h *Hook) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	return context.WithValue(ctx, &started, time.Now()), nil
}

// After logs the query, its arguments and the time elapsed since Before.
// The unchecked type assertion panics if ctx does not carry the start time
// stored by Before.
func (h *Hook) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	h.log.Printf("Query: `%s`, Args: `%q`. took: %s", query, args, time.Since(ctx.Value(&started).(time.Time)))
	return ctx, nil
}

// OnError logs err together with the query, its arguments and the time
// elapsed since Before, then returns err unchanged. Like After, it panics if
// ctx lacks the start time stored by Before.
func (h *Hook) OnError(ctx context.Context, err error, query string, args ...interface{}) error {
	h.log.Printf("Error: %v, Query: `%s`, Args: `%q`, Took: %s", err, query, args, time.Since(ctx.Value(&started).(time.Time)))
	return err
}

================================================
FILE: hooks/othooks/examples/main.go
================================================
package main

import (
	"context"
	"database/sql"
	"log"

	"github.com/qustavo/sqlhooks/v2"
	"github.com/qustavo/sqlhooks/v2/hooks/othooks"
	"github.com/mattn/go-sqlite3"
	"github.com/opentracing/opentracing-go"
)

// main wraps sqlite3 with the OpenTracing hook and issues its statements with
// a context carrying a parent "sql" span, so the hook starts a child span for
// each statement.
func main() {
	tracer := opentracing.GlobalTracer()
	hooks := othooks.New(tracer)
	sql.Register("sqlite3ot", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, hooks))
	db, err := sql.Open("sqlite3ot", ":memory:")
	if err != nil {
		log.Fatal(err)
	}

	span := tracer.StartSpan("sql")
	defer span.Finish()

	ctx := opentracing.ContextWithSpan(context.Background(), span)
	if _, err := db.ExecContext(ctx, "CREATE TABLE users(ID int, name text)"); err != nil {
		log.Fatal(err)
	}

	if _, err := db.ExecContext(ctx, `INSERT INTO users (id, name) VALUES(?, ?)`, 1, "gus"); err != nil {
		log.Fatal(err)
	}

	if _, err :=
db.QueryContext(ctx, `SELECT id, name FROM users`); err != nil {
		log.Fatal(err)
	}
}

================================================
FILE: hooks/othooks/othooks.go
================================================
package othooks

import (
	"context"

	"github.com/opentracing/opentracing-go"
	"github.com/opentracing/opentracing-go/log"
)

// Hook is a sqlhooks hook that records queries as OpenTracing spans.
type Hook struct {
	tracer opentracing.Tracer
}

// New returns a Hook that starts its spans with tracer.
func New(tracer opentracing.Tracer) *Hook {
	return &Hook{tracer: tracer}
}

// Before starts a child span named "sql" and logs the query and its args on
// it. A span is only started when ctx already carries a parent span;
// otherwise ctx is returned unchanged and the query is not traced.
func (h *Hook) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	parent := opentracing.SpanFromContext(ctx)
	if parent == nil {
		return ctx, nil
	}

	span := h.tracer.StartSpan("sql", opentracing.ChildOf(parent.Context()))
	span.LogFields(
		log.String("query", query),
		log.Object("args", args),
	)

	return opentracing.ContextWithSpan(ctx, span), nil
}

// After finishes the span found in ctx, if any, and returns ctx unchanged.
func (h *Hook) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	span := opentracing.SpanFromContext(ctx)
	if span != nil {
		defer span.Finish()
	}

	return ctx, nil
}

// OnError sets the "error" tag on the span found in ctx (if any), logs err
// on it and finishes it. err is returned unchanged.
func (h *Hook) OnError(ctx context.Context, err error, query string, args ...interface{}) error {
	span := opentracing.SpanFromContext(ctx)
	if span != nil {
		defer span.Finish()
		span.SetTag("error", true)
		span.LogFields(
			log.Error(err),
		)
	}

	return err
}

================================================
FILE: hooks/othooks/othooks_test.go
================================================
package othooks

import (
	"context"
	"database/sql"
	"testing"

	"github.com/qustavo/sqlhooks/v2"

	sqlite3 "github.com/mattn/go-sqlite3"
	opentracing "github.com/opentracing/opentracing-go"
	"github.com/opentracing/opentracing-go/mocktracer"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

var (
	tracer *mocktracer.MockTracer
)

// init registers the "ot" driver: sqlite3 wrapped with a Hook backed by a
// mock tracer whose finished spans the tests inspect.
func init() {
	tracer = mocktracer.New()
	driver := sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, New(tracer))
	sql.Register("ot", driver)
}

// TestSpansAreRecorded runs two queries under a parent span, expects three
// finished spans (one per query plus the parent) and checks the name, log
// fields and finish time of the second one.
func TestSpansAreRecorded(t *testing.T) {
	db, err := sql.Open("ot", ":memory:")
require.NoError(t, err) defer db.Close() tracer.Reset() parent := tracer.StartSpan("parent") ctx := opentracing.ContextWithSpan(context.Background(), parent) { rows, err := db.QueryContext(ctx, "SELECT 1+?", "1") require.NoError(t, err) rows.Close() } { rows, err := db.QueryContext(ctx, "SELECT 1+?", "1") require.NoError(t, err) rows.Close() } parent.Finish() spans := tracer.FinishedSpans() require.Len(t, spans, 3) span := spans[1] assert.Equal(t, "sql", span.OperationName) logFields := span.Logs()[0].Fields assert.Equal(t, "query", logFields[0].Key) assert.Equal(t, "SELECT 1+?", logFields[0].ValueString) assert.Equal(t, "args", logFields[1].Key) assert.Equal(t, "[1]", logFields[1].ValueString) assert.NotEmpty(t, span.FinishTime) } func TestNoSpansAreRecorded(t *testing.T) { db, err := sql.Open("ot", ":memory:") require.NoError(t, err) defer db.Close() tracer.Reset() rows, err := db.QueryContext(context.Background(), "SELECT 1") require.NoError(t, err) rows.Close() assert.Empty(t, tracer.FinishedSpans()) } ================================================ FILE: sqlhooks.go ================================================ package sqlhooks import ( "context" "database/sql/driver" "errors" ) // Hook is the hook callback signature type Hook func(ctx context.Context, query string, args ...interface{}) (context.Context, error) // ErrorHook is the error handling callback signature type ErrorHook func(ctx context.Context, err error, query string, args ...interface{}) error // Hooks instances may be passed to Wrap() to define an instrumented driver type Hooks interface { Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) } // OnErrorer instances will be called if any error happens type OnErrorer interface { OnError(ctx context.Context, err error, query string, args ...interface{}) error } func handlerErr(ctx context.Context, hooks Hooks, err error, 
query string, args ...interface{}) error { h, ok := hooks.(OnErrorer) if !ok { return err } if err := h.OnError(ctx, err, query, args...); err != nil { return err } return err } // Driver implements a database/sql/driver.Driver type Driver struct { driver.Driver hooks Hooks } // Open opens a connection func (drv *Driver) Open(name string) (driver.Conn, error) { conn, err := drv.Driver.Open(name) if err != nil { return conn, err } // Drivers that don't implement driver.ConnBeginTx are not supported. if _, ok := conn.(driver.ConnBeginTx); !ok { return nil, errors.New("driver must implement driver.ConnBeginTx") } wrapped := &Conn{conn, drv.hooks} if isExecer(conn) && isQueryer(conn) && isSessionResetter(conn) { return &ExecerQueryerContextWithSessionResetter{wrapped, &ExecerContext{wrapped}, &QueryerContext{wrapped}, &SessionResetter{wrapped}}, nil } else if isExecer(conn) && isQueryer(conn) { return &ExecerQueryerContext{wrapped, &ExecerContext{wrapped}, &QueryerContext{wrapped}}, nil } else if isExecer(conn) { // If conn implements an Execer interface, return a driver.Conn which // also implements Execer return &ExecerContext{wrapped}, nil } else if isQueryer(conn) { // If conn implements an Queryer interface, return a driver.Conn which // also implements Queryer return &QueryerContext{wrapped}, nil } return wrapped, nil } // Conn implements a database/sql.driver.Conn type Conn struct { Conn driver.Conn hooks Hooks } func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { var ( stmt driver.Stmt err error ) if c, ok := conn.Conn.(driver.ConnPrepareContext); ok { stmt, err = c.PrepareContext(ctx, query) } else { stmt, err = conn.Prepare(query) } if err != nil { return stmt, err } return &Stmt{stmt, conn.hooks, query}, nil } func (conn *Conn) Prepare(query string) (driver.Stmt, error) { return conn.Conn.Prepare(query) } func (conn *Conn) Close() error { return conn.Conn.Close() } func (conn *Conn) Begin() (driver.Tx, error) { return 
conn.Conn.Begin() } func (conn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { return conn.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts) } // ExecerContext implements a database/sql.driver.ExecerContext type ExecerContext struct { *Conn } func isExecer(conn driver.Conn) bool { switch conn.(type) { case driver.ExecerContext: return true case driver.Execer: return true default: return false } } func (conn *ExecerContext) execContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { switch c := conn.Conn.Conn.(type) { case driver.ExecerContext: return c.ExecContext(ctx, query, args) case driver.Execer: dargs, err := namedValueToValue(args) if err != nil { return nil, err } return c.Exec(query, dargs) default: // This should not happen return nil, errors.New("ExecerContext created for a non Execer driver.Conn") } } func (conn *ExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { var err error list := namedToInterface(args) // Exec `Before` Hooks if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil { return nil, err } results, err := conn.execContext(ctx, query, args) if err != nil { return results, handlerErr(ctx, conn.hooks, err, query, list...) } if _, err := conn.hooks.After(ctx, query, list...); err != nil { return nil, err } return results, err } func (conn *ExecerContext) Exec(query string, args []driver.Value) (driver.Result, error) { // We have to implement Exec since it is required in the current version of // Go for it to run ExecContext. From Go 10 it will be optional. However, // this code should never run since database/sql always prefers to run // ExecContext. 
return nil, errors.New("Exec was called when ExecContext was implemented") } // QueryerContext implements a database/sql.driver.QueryerContext type QueryerContext struct { *Conn } func isQueryer(conn driver.Conn) bool { switch conn.(type) { case driver.QueryerContext: return true case driver.Queryer: return true default: return false } } func (conn *QueryerContext) queryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { switch c := conn.Conn.Conn.(type) { case driver.QueryerContext: return c.QueryContext(ctx, query, args) case driver.Queryer: dargs, err := namedValueToValue(args) if err != nil { return nil, err } return c.Query(query, dargs) default: // This should not happen return nil, errors.New("QueryerContext created for a non Queryer driver.Conn") } } func (conn *QueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { var err error list := namedToInterface(args) // Query `Before` Hooks if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil { return nil, err } results, err := conn.queryContext(ctx, query, args) if err != nil { return results, handlerErr(ctx, conn.hooks, err, query, list...) 
} if _, err := conn.hooks.After(ctx, query, list...); err != nil { return nil, err } return results, err } // ExecerQueryerContext implements database/sql.driver.ExecerContext and // database/sql.driver.QueryerContext type ExecerQueryerContext struct { *Conn *ExecerContext *QueryerContext } // ExecerQueryerContext implements database/sql.driver.ExecerContext and // database/sql.driver.QueryerContext type ExecerQueryerContextWithSessionResetter struct { *Conn *ExecerContext *QueryerContext *SessionResetter } type SessionResetter struct { *Conn } // Stmt implements a database/sql/driver.Stmt type Stmt struct { Stmt driver.Stmt hooks Hooks query string } func (stmt *Stmt) execContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { if s, ok := stmt.Stmt.(driver.StmtExecContext); ok { return s.ExecContext(ctx, args) } values := make([]driver.Value, len(args)) for _, arg := range args { values[arg.Ordinal-1] = arg.Value } return stmt.Exec(values) } func (stmt *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { var err error list := namedToInterface(args) // Exec `Before` Hooks if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil { return nil, err } results, err := stmt.execContext(ctx, args) if err != nil { return results, handlerErr(ctx, stmt.hooks, err, stmt.query, list...) 
} if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil { return nil, err } return results, err } func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { if s, ok := stmt.Stmt.(driver.StmtQueryContext); ok { return s.QueryContext(ctx, args) } values := make([]driver.Value, len(args)) for _, arg := range args { values[arg.Ordinal-1] = arg.Value } return stmt.Query(values) } func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { var err error list := namedToInterface(args) // Exec Before Hooks if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil { return nil, err } rows, err := stmt.queryContext(ctx, args) if err != nil { return rows, handlerErr(ctx, stmt.hooks, err, stmt.query, list...) } if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil { return nil, err } return rows, err } func (stmt *Stmt) Close() error { return stmt.Stmt.Close() } func (stmt *Stmt) NumInput() int { return stmt.Stmt.NumInput() } func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { return stmt.Stmt.Exec(args) } func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { return stmt.Stmt.Query(args) } // Wrap is used to create a new instrumented driver, it takes a vendor specific driver, and a Hooks instance to produce a new driver instance. 
// It's usually used inside a sql.Register() statement func Wrap(driver driver.Driver, hooks Hooks) driver.Driver { return &Driver{driver, hooks} } func namedToInterface(args []driver.NamedValue) []interface{} { list := make([]interface{}, len(args)) for i, a := range args { list[i] = a.Value } return list } // namedValueToValue copied from database/sql func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { dargs := make([]driver.Value, len(named)) for n, param := range named { if len(param.Name) > 0 { return nil, errors.New("sql: driver does not support the use of Named Parameters") } dargs[n] = param.Value } return dargs, nil } /* type hooks struct { } func (h *hooks) Before(ctx context.Context, query string, args ...interface{}) error { log.Printf("before> ctx = %+v, q=%s, args = %+v\n", ctx, query, args) return nil } func (h *hooks) After(ctx context.Context, query string, args ...interface{}) error { log.Printf("after> ctx = %+v, q=%s, args = %+v\n", ctx, query, args) return nil } func main() { sql.Register("sqlite3-proxy", Wrap(&sqlite3.SQLiteDriver{}, &hooks{})) db, err := sql.Open("sqlite3-proxy", ":memory:") if err != nil { log.Fatalln(err) } if _, ok := driver.Stmt(&Stmt{}).(driver.StmtExecContext); !ok { panic("NOPE") } if _, err := db.Exec("CREATE table users(id int)"); err != nil { log.Printf("|err| = %+v\n", err) } if _, err := db.QueryContext(context.Background(), "SELECT * FROM users WHERE id = ?", 1); err != nil { log.Printf("err = %+v\n", err) } } */ ================================================ FILE: sqlhooks_1_10.go ================================================ // +build go1.10 package sqlhooks import ( "context" "database/sql/driver" ) func isSessionResetter(conn driver.Conn) bool { _, ok := conn.(driver.SessionResetter) return ok } func (s *SessionResetter) ResetSession(ctx context.Context) error { c := s.Conn.Conn.(driver.SessionResetter) return c.ResetSession(ctx) } ================================================ 
FILE: sqlhooks_1_10_interface_test.go ================================================ // +build go1.10 package sqlhooks import "database/sql/driver" func init() { interfaceTestCases = append(interfaceTestCases, struct { name string expectedInterfaces []interface{} }{ "ExecerQueryerContextSessionResetter", []interface{}{ (*driver.ExecerContext)(nil), (*driver.QueryerContext)(nil), (*driver.SessionResetter)(nil)}}) } ================================================ FILE: sqlhooks_interface_test.go ================================================ package sqlhooks import ( "context" "database/sql/driver" "errors" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) var interfaceTestCases = []struct { name string expectedInterfaces []interface{} }{ {"Basic", []interface{}{(*driver.Conn)(nil)}}, {"Execer", []interface{}{(*driver.Execer)(nil)}}, {"ExecerContext", []interface{}{(*driver.ExecerContext)(nil)}}, {"Queryer", []interface{}{(*driver.QueryerContext)(nil)}}, {"QueryerContext", []interface{}{(*driver.QueryerContext)(nil)}}, {"ExecerQueryerContext", []interface{}{ (*driver.ExecerContext)(nil), (*driver.QueryerContext)(nil)}}, } type fakeDriver struct{} func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { switch dsn { case "Basic": return &struct{ *FakeConnBasic }{}, nil case "Execer": return &struct { *FakeConnBasic *FakeConnExecer }{}, nil case "ExecerContext": return &struct { *FakeConnBasic *FakeConnExecerContext }{}, nil case "Queryer": return &struct { *FakeConnBasic *FakeConnQueryer }{}, nil case "QueryerContext": return &struct { *FakeConnBasic *FakeConnQueryerContext }{}, nil case "ExecerQueryerContext": return &struct { *FakeConnBasic *FakeConnExecerContext *FakeConnQueryerContext }{}, nil case "ExecerQueryerContextSessionResetter": return &struct { *FakeConnBasic *FakeConnExecer *FakeConnQueryer *FakeConnSessionResetter }{}, nil case "NonConnBeginTx": return &FakeConnUnsupported{}, nil } return nil, 
errors.New("Fake driver not implemented") } // Conn implements a database/sql.driver.Conn type FakeConnBasic struct{} func (*FakeConnBasic) Prepare(query string) (driver.Stmt, error) { return nil, errors.New("Not implemented") } func (*FakeConnBasic) Close() error { return errors.New("Not implemented") } func (*FakeConnBasic) Begin() (driver.Tx, error) { return nil, errors.New("Not implemented") } func (*FakeConnBasic) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) { return nil, errors.New("Not implemented") } type FakeConnExecer struct{} func (*FakeConnExecer) Exec(query string, args []driver.Value) (driver.Result, error) { return nil, errors.New("Not implemented") } type FakeConnExecerContext struct{} func (*FakeConnExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { return nil, errors.New("Not implemented") } type FakeConnQueryer struct{} func (*FakeConnQueryer) Query(query string, args []driver.Value) (driver.Rows, error) { return nil, errors.New("Not implemented") } type FakeConnQueryerContext struct{} func (*FakeConnQueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { return nil, errors.New("Not implemented") } type FakeConnSessionResetter struct{} func (*FakeConnSessionResetter) ResetSession(ctx context.Context) error { return errors.New("Not implemented") } // FakeConnUnsupported implements a database/sql.driver.Conn but doesn't implement // driver.ConnBeginTx. 
type FakeConnUnsupported struct{} func (*FakeConnUnsupported) Prepare(query string) (driver.Stmt, error) { return nil, errors.New("Not implemented") } func (*FakeConnUnsupported) Close() error { return errors.New("Not implemented") } func (*FakeConnUnsupported) Begin() (driver.Tx, error) { return nil, errors.New("Not implemented") } func TestInterfaces(t *testing.T) { drv := Wrap(&fakeDriver{}, &testHooks{}) for _, c := range interfaceTestCases { conn, err := drv.Open(c.name) require.NoErrorf(t, err, "Driver name %s", c.name) for _, i := range c.expectedInterfaces { assert.Implements(t, i, conn) } } } func TestUnsupportedDrivers(t *testing.T) { drv := Wrap(&fakeDriver{}, &testHooks{}) _, err := drv.Open("NonConnBeginTx") require.EqualError(t, err, "driver must implement driver.ConnBeginTx") } ================================================ FILE: sqlhooks_mysql_test.go ================================================ package sqlhooks import ( "database/sql" "os" "testing" "github.com/go-sql-driver/mysql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func setUpMySQL(t *testing.T, dsn string) { db, err := sql.Open("mysql", dsn) require.NoError(t, err) require.NoError(t, db.Ping()) defer db.Close() _, err = db.Exec("CREATE table IF NOT EXISTS users(id int, name text)") require.NoError(t, err) } func TestMySQL(t *testing.T) { dsn := os.Getenv("SQLHOOKS_MYSQL_DSN") if dsn == "" { t.Skipf("SQLHOOKS_MYSQL_DSN not set") } setUpMySQL(t, dsn) s := newSuite(t, &mysql.MySQLDriver{}, dsn) s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1) s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? 
AND name = ?", int64(1), "Gus") s.TestHooksErrors(t, "SELECT 1+1") s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS") t.Run("DBWorks", func(t *testing.T) { s.hooks.reset() if _, err := s.db.Exec("DELETE FROM users"); err != nil { t.Fatal(err) } stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES(?, ?)") require.NoError(t, err) for i := range [5]struct{}{} { _, err := stmt.Exec(i, "gus") require.NoError(t, err) } var count int require.NoError(t, s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count), ) assert.Equal(t, 5, count) }) } ================================================ FILE: sqlhooks_postgres_test.go ================================================ package sqlhooks import ( "database/sql" "os" "testing" "github.com/lib/pq" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func setUpPostgres(t *testing.T, dsn string) { db, err := sql.Open("postgres", dsn) require.NoError(t, err) require.NoError(t, db.Ping()) defer db.Close() _, err = db.Exec("CREATE table IF NOT EXISTS users(id int, name text)") require.NoError(t, err) } func TestPostgres(t *testing.T) { dsn := os.Getenv("SQLHOOKS_POSTGRES_DSN") if dsn == "" { t.Skipf("SQLHOOKS_POSTGRES_DSN not set") } setUpPostgres(t, dsn) s := newSuite(t, &pq.Driver{}, dsn) s.TestHooksExecution(t, "SELECT * FROM users WHERE id = $1", 1) s.TestHooksArguments(t, "SELECT * FROM users WHERE id = $1 AND name = $2", int64(1), "Gus") s.TestHooksErrors(t, "SELECT 1+1") s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS") t.Run("DBWorks", func(t *testing.T) { s.hooks.reset() if _, err := s.db.Exec("DELETE FROM users"); err != nil { t.Fatal(err) } stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES($1, $2)") require.NoError(t, err) for i := range [5]struct{}{} { _, err := stmt.Exec(i, "gus") require.NoError(t, err) } var count int require.NoError(t, s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count), ) assert.Equal(t, 5, count) }) 
} ================================================ FILE: sqlhooks_pre_1_10.go ================================================ // +build !go1.10 package sqlhooks import ( "context" "database/sql/driver" "errors" ) func isSessionResetter(conn driver.Conn) bool { return false } func (s *SessionResetter) ResetSession(ctx context.Context) error { return errors.New("SessionResetter not implemented") } ================================================ FILE: sqlhooks_sqlite3_test.go ================================================ package sqlhooks import ( "database/sql" "os" "testing" "time" "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func setUp(t *testing.T) func() { dbName := "sqlite3test.db" db, err := sql.Open("sqlite3", dbName) require.NoError(t, err) defer db.Close() _, err = db.Exec("CREATE table users(id int, name text)") require.NoError(t, err) return func() { os.Remove(dbName) } } func TestSQLite3(t *testing.T) { defer setUp(t)() s := newSuite(t, &sqlite3.SQLiteDriver{}, "sqlite3test.db") s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1) s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? 
AND name = ?", int64(1), "Gus") s.TestHooksErrors(t, "SELECT 1+1") s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS") t.Run("DBWorks", func(t *testing.T) { s.hooks.reset() if _, err := s.db.Exec("DELETE FROM users"); err != nil { t.Fatal(err) } stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES(?, ?)") require.NoError(t, err) for range [5]struct{}{} { _, err := stmt.Exec(time.Now().UnixNano(), "gus") require.NoError(t, err) } var count int require.NoError(t, s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count), ) assert.Equal(t, 5, count) }) } ================================================ FILE: sqlhooks_test.go ================================================ package sqlhooks import ( "context" "database/sql" "database/sql/driver" "errors" "fmt" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type testHooks struct { before Hook after Hook onError ErrorHook } func newTestHooks() *testHooks { th := &testHooks{} th.reset() return th } func (h *testHooks) reset() { noop := func(ctx context.Context, _ string, _ ...interface{}) (context.Context, error) { return ctx, nil } noopErr := func(_ context.Context, err error, _ string, _ ...interface{}) error { return err } h.before, h.after, h.onError = noop, noop, noopErr } func (h *testHooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { return h.before(ctx, query, args...) } func (h *testHooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { return h.after(ctx, query, args...) } func (h *testHooks) OnError(ctx context.Context, err error, query string, args ...interface{}) error { return h.onError(ctx, err, query, args...) 
} type suite struct { db *sql.DB hooks *testHooks } func newSuite(t *testing.T, driver driver.Driver, dsn string) *suite { hooks := newTestHooks() driverName := fmt.Sprintf("sqlhooks-%s", time.Now().String()) sql.Register(driverName, Wrap(driver, hooks)) db, err := sql.Open(driverName, dsn) require.NoError(t, err) require.NoError(t, db.Ping()) return &suite{db, hooks} } func (s *suite) TestHooksExecution(t *testing.T, query string, args ...interface{}) { var before, after bool s.hooks.before = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) { before = true return ctx, nil } s.hooks.after = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) { after = true return ctx, nil } t.Run("Query", func(t *testing.T) { before, after = false, false _, err := s.db.Query(query, args...) require.NoError(t, err) assert.True(t, before, "Before Hook did not run for query: "+query) assert.True(t, after, "After Hook did not run for query: "+query) }) t.Run("QueryContext", func(t *testing.T) { before, after = false, false _, err := s.db.QueryContext(context.Background(), query, args...) require.NoError(t, err) assert.True(t, before, "Before Hook did not run for query: "+query) assert.True(t, after, "After Hook did not run for query: "+query) }) t.Run("Exec", func(t *testing.T) { before, after = false, false _, err := s.db.Exec(query, args...) require.NoError(t, err) assert.True(t, before, "Before Hook did not run for query: "+query) assert.True(t, after, "After Hook did not run for query: "+query) }) t.Run("ExecContext", func(t *testing.T) { before, after = false, false _, err := s.db.ExecContext(context.Background(), query, args...) 
require.NoError(t, err) assert.True(t, before, "Before Hook did not run for query: "+query) assert.True(t, after, "After Hook did not run for query: "+query) }) t.Run("Statements", func(t *testing.T) { before, after = false, false stmt, err := s.db.Prepare(query) require.NoError(t, err) // Hooks just run when the stmt is executed (Query or Exec) assert.False(t, before, "Before Hook run before execution: "+query) assert.False(t, after, "After Hook run before execution: "+query) _, err = stmt.Query(args...) require.NoError(t, err) assert.True(t, before, "Before Hook did not run for query: "+query) assert.True(t, after, "After Hook did not run for query: "+query) }) } func (s *suite) testHooksArguments(t *testing.T, query string, args ...interface{}) { hook := func(ctx context.Context, q string, a ...interface{}) (context.Context, error) { assert.Equal(t, query, q) assert.Equal(t, args, a) assert.Equal(t, "val", ctx.Value("key").(string)) return ctx, nil } s.hooks.before = hook s.hooks.after = hook ctx := context.WithValue(context.Background(), "key", "val") //nolint:staticcheck { _, err := s.db.QueryContext(ctx, query, args...) require.NoError(t, err) } { _, err := s.db.ExecContext(ctx, query, args...) require.NoError(t, err) } } func (s *suite) TestHooksArguments(t *testing.T, query string, args ...interface{}) { t.Run("TestHooksArguments", func(t *testing.T) { s.testHooksArguments(t, query, args...) 
}) } func (s *suite) testHooksErrors(t *testing.T, query string) { boom := errors.New("boom") s.hooks.before = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { return ctx, boom } s.hooks.after = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { assert.False(t, true, "this should not run") return ctx, nil } _, err := s.db.Query(query) assert.Equal(t, boom, err) } func (s *suite) TestHooksErrors(t *testing.T, query string) { t.Run("TestHooksErrors", func(t *testing.T) { s.testHooksErrors(t, query) }) } func (s *suite) testErrHookHook(t *testing.T, query string, args ...interface{}) { s.hooks.before = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { return ctx, nil } s.hooks.after = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { assert.False(t, true, "after hook should not run") return ctx, nil } s.hooks.onError = func(ctx context.Context, err error, query string, args ...interface{}) error { assert.True(t, true, "onError hook should run") return err } _, err := s.db.Query(query) require.Error(t, err) } func (s *suite) TestErrHookHook(t *testing.T, query string, args ...interface{}) { t.Run("TestErrHookHook", func(t *testing.T) { s.testErrHookHook(t, query, args...) }) } func TestNamedValueToValue(t *testing.T) { named := []driver.NamedValue{ {Ordinal: 1, Value: "foo"}, {Ordinal: 2, Value: 42}, } want := []driver.Value{"foo", 42} dargs, err := namedValueToValue(named) require.NoError(t, err) assert.Equal(t, want, dargs) }