Repository: travelaudience/go-sx Branch: master Commit: 27002439c7dc Files: 19 Total size: 87.1 KB Directory structure: gitextract_m_sz7qj6/ ├── .circleci/ │ └── config.yml ├── .github/ │ ├── CODE_OF_CONDUCT.md │ ├── ISSUE_TEMPLATE.md │ └── PULL_REQUEST_TEMPLATE.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── doc.go ├── example_test.go ├── go.mod ├── go.sum ├── helpers.go ├── helpers_test.go ├── matching.go ├── matching_test.go ├── placeholder.go ├── placeholder_test.go ├── tx.go └── tx_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .circleci/config.yml ================================================ version: 2.1 orbs: ta-go: travelaudience/go@0.9 executors: golang-executor: docker: - image: cimg/go:1.23 workflows: build_and_test: jobs: - ta-go/checks: name: check exec: golang-executor run-static-analysis: true - ta-go/test_and_coverage: name: test exec: golang-executor ================================================ FILE: .github/CODE_OF_CONDUCT.md ================================================ # travel audience Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. 
## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. ### Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at opensource@travelaudience.com. 
All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [https://contributor-covenant.org/version/1/4][version] [homepage]: https://contributor-covenant.org [version]: https://contributor-covenant.org/version/1/4/ ================================================ FILE: .github/ISSUE_TEMPLATE.md ================================================ ### Expected Behaviour ### Actual Behaviour ### Steps to Reproduce ### Optional additional info: #### Platform and Version #### Sample Code that illustrates the problem #### Logs taken while reproducing problem ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ **What this PR does / why we need it**: **Special notes for your reviewer**: **If applicable**: - [ ] this PR contains documentation - [ ] this PR contains unit tests - [ ] this PR has been tested for backwards compatibility ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing Thanks for choosing to contribute! The following are a set of guidelines to follow when contributing to this project. ## Have A Question? Start by filing an issue. The existing committers on this project work to reach consensus around project direction and issue solutions within issue threads (when appropriate). ## How to Contribute Code 1. 
Fork the repo, develop and test your code changes. 1. Submit a pull request. Lastly, please follow the [pull request template](.github/PULL_REQUEST_TEMPLATE.md) when submitting a pull request! #### Documentation PRs Documentation PRs will follow the same lifecycle as other PRs. They should also be labeled with the `docs` label. For documentation, special attention will be paid to spelling, grammar, and clarity (whereas those things don't matter *as* much for comments in code). ## Code Of Conduct This project adheres to the travel audience [code of conduct](.github/CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. Please report unacceptable behavior to [opensource@travelaudience.com](mailto:opensource@travelaudience.com). ================================================ FILE: LICENSE.txt ================================================ MIT License © Copyright 2019 travel audience. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================ # Some simple SQL extensions for Go [![Go Reference](https://pkg.go.dev/badge/github.com/travelaudience/go-sx.svg)](https://pkg.go.dev/github.com/travelaudience/go-sx) [![CircleCI](https://circleci.com/gh/travelaudience/go-sx.svg?style=svg)](https://circleci.com/gh/travelaudience/go-sx) **go-sx** provides some extensions to the standard library `database/sql` package. It is designed for those who wish to use the full power of SQL without a heavy abstraction layer. This library is actively maintained. Contributions are welcome. ## Quickstart This is what application code typically looks like. The example here is for fetching multiple database rows into a slice of structs. Note that the order of exported struct fields needs to match the columns in the SQL result set. ```go type exam struct { ID int64 Name string Score float64 } db, err := sql.Open(...) if err != nil { ... } var scores []exam err = sx.Do(db, func(tx *sx.Tx) { tx.MustQuery("SELECT id, name, score FROM exam_scores ORDER BY name").Each(func(r *sx.Rows) { var e exam r.MustScans(&e) scores = append(scores, e) }) }) if err != nil { ... } ``` For more in-depth details, please keep reading! ## Goals The primary goal of **go-sx** is to eliminate boilerplate code. Specifically, **go-sx** attempts to address the following pain points: 1. Transactions are clumsy. It would be nice to have a simple function to run a callback in a transaction. 2. Error handling is clumsy. It would be nice to have errors within a transaction automatically exit the transaction and trigger a rollback. (This is nearly always what we want to do.) 3. Scanning multiple columns is clumsy. It would be nice to have a simple way to scan into multiple struct fields at once. 4. Constructing queries is clumsy, especially when there are a lot of columns. 5. Iterating over result sets is clumsy. 
## Non-goals These are considered to be out of scope: 1. Be an ORM. 2. Write your queries for you. 3. Suggest that we need a 1:1 relationship between struct types and tables. 4. Maintain database schemas. 5. Abstract away differences between SQL dialects. 6. Perform automatic type manipulation. 7. Magic. ## Pain point #1: Transactions are clumsy. **go-sx** provides a function `Do` to run a transaction in a callback, automatically committing on success or rolling back on failure. Here is some simple code to run two queries in a transaction. The second query returns two values, which are read into variables `x` and `y`. ```go tx, err := db.Begin() if err != nil { return err } if _, err := tx.Exec(query0); err != nil { tx.Rollback() return err } if err := tx.QueryRow(query1).Scan(&x, &y); err != nil { tx.Rollback() return err } if err := tx.Commit(); err != nil { return err } ``` Using the `Do` function, we put the business logic into a callback function and have **go-sx** take care of the transaction logic. The `sx.Tx` object provided to the callback is the `sql.Tx` transaction object, extended with a few methods. If we call `tx.Fail()`, then the transaction is immediately aborted and rolled back. ```go err := sx.Do(db, func(tx *sx.Tx) { if _, err := tx.Exec(query0); err != nil { tx.Fail(err) } if err := tx.QueryRow(query1).Scan(&x, &y); err != nil { tx.Fail(err) } }) ``` Under the hood, `tx.Fail()` generates a panic which is recovered by `Do`. ## Pain point #2: Error handling is clumsy. **go-sx** provides a collection of `Must***` methods which may be used inside the callback to `Do`. Any error encountered while in a `Must***` method causes the transaction to be aborted and rolled back. Here is the code above, rewritten to use `Do`'s error handling. It's simple and readable. ```go err := sx.Do(db, func(tx *sx.Tx) { tx.MustExec(query0) tx.MustQueryRow(query1).MustScan(&x, &y) }) ``` ## Pain point #3: Scanning multiple columns is clumsy.
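Scanning a row into a struct comes down to handing `Scan` one pointer per column, in column order. As a rough illustration of how such a pointer slice can be assembled with `reflect`, here is a stdlib-only sketch. The names `addrsOf`, `box` and `fakeScan` are hypothetical, and unlike go-sx's `Addrs`, this sketch ignores `sx` tags and simply includes every exported field.

```go
package main

import (
	"fmt"
	"reflect"
)

// addrsOf returns pointers to the exported fields of the struct that dest
// points at, in declaration order. It panics if dest is not a non-nil
// pointer to a struct.
func addrsOf(dest any) []any {
	v := reflect.ValueOf(dest)
	if v.Kind() != reflect.Pointer || v.IsNil() || v.Elem().Kind() != reflect.Struct {
		panic("addrsOf: dest must be a non-nil pointer to a struct")
	}
	v = v.Elem()
	t := v.Type()
	addrs := make([]any, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		if t.Field(i).IsExported() {
			addrs = append(addrs, v.Field(i).Addr().Interface())
		}
	}
	return addrs
}

type box struct{ Width, Height, Depth float64 }

// fakeScan mimics Row.Scan for float64 columns: it copies src[i] into the
// i-th destination pointer.
func fakeScan(src []float64, dest ...any) {
	for i, d := range dest {
		*d.(*float64) = src[i]
	}
}

func main() {
	var a box
	fakeScan([]float64{1.5, 2, 3}, addrsOf(&a)...)
	fmt.Println(a) // {1.5 2 3}
}
```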
**go-sx** provides an `Addrs` function, which takes a pointer to a struct and returns a slice of pointers to its fields. So instead of: ```go row.Scan(&a.Width, &a.Height, &a.Depth) ``` We can write: ```go row.Scan(sx.Addrs(&a)...) ``` Or better yet, let **go-sx** handle the errors: ```go row.MustScan(sx.Addrs(&a)...) ``` This is such a common pattern that we provide a shortcut to do this all in one step: ```go row.MustScans(&a) ``` ## Pain point #4: Constructing queries is clumsy. We would like **go-sx** to be able to construct some common queries for us. To this end, we define a simple way to match struct fields with database columns, and then provide some helper functions that use this matching to construct queries. By default, all exported struct fields match database columns whose name is the snake_cased field name. The default can be overridden by explicitly tagging fields, much as with the standard `encoding/json` package. Note that we don't care about the name of the table at this point. Here is a struct that can be used to scan columns `violin`, `viola`, `cello` and `contrabass`. ```go type orchestra struct { Violin string Viola string Cello string Bass string `sx:"contrabass"` } ``` We can use the helper function `SelectQuery` to construct a simple query. Then we can add the WHERE clause that we need and scan the result set into our struct. ```go var spo orchestra wantID := 123 query := sx.SelectQuery("symphony", &spo) + " WHERE id=?" // SELECT violin,viola,cello,contrabass FROM symphony WHERE id=? tx.MustQueryRow(query, wantID).MustScans(&spo) ``` Note that a struct need not follow the database schema exactly. It's entirely possible to have various structs mapped to different columns of the same table, or even one struct that maps to a query on joined tables. On the other hand, it's essential that the columns in the query match the fields of the struct, and **go-sx** guarantees this, as we'll see below.
In some cases it's useful to have a struct that is used for both selects and inserts, with some of the fields being used just for selects. This can be accomplished with the "readonly" tag. ```go type orchestra1 struct { Violin string `sx:",readonly"` Viola string Cello string Bass string `sx:"contrabass"` } ``` It's also useful in some cases to have a struct field that is ignored by **go-sx**. This can be accomplished with the "-" tag. ```go type orchestra2 struct { Violin string `sx:",readonly"` Viola string `sx:"-"` Cello string Bass string `sx:"contrabass"` } ``` We can construct insert queries in a similar manner. Violin is read-only and Viola is ignored, so we only need to provide values for Cello and Bass. (If you need Postgres-style `$n` placeholders, see `sx.SetNumberedPlaceholders()`.) ```go spo := orchestra2{Cello: "Strad", Bass: "Cecilio"} query := sx.InsertQuery("symphony", &spo) // INSERT INTO symphony (cello,contrabass) VALUES (?,?) tx.MustExec(query, sx.Values(&spo)...) ``` We can construct update queries this way too. `UpdateQuery` skips fields that hold zero values, and update structs may use pointer fields, so a pointer to a zero value still forces an update while a nil pointer is skipped. ```go spoChanges := orchestra2{Bass: "Strad"} wantID := 123 query, values := sx.UpdateQuery("symphony", &spoChanges) query += " WHERE id=?" // UPDATE symphony SET contrabass=? WHERE id=? tx.MustExec(query, append(values, wantID)...) ``` It is entirely possible to construct all of these queries by hand, and you're welcome to do so. Using the query generators, however, ensures that the fields match correctly, something that is particularly useful with a large number of columns. ## Pain point #5: Iterating over result sets is clumsy. **go-sx** provides an iterator called `Each` which runs a callback function on each row of a result set.
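A callback iterator of this kind is mostly a loop that guarantees the `Close` call and the final `Err` check. The sketch below shows that shape with generics and an in-memory fake. The names `rowIter`, `each` and `sliceRows` are hypothetical, and this sketch simply returns the error reported by `Err` instead of aborting a transaction.

```go
package main

import "fmt"

// rowIter is a hypothetical subset of the *sql.Rows method set: just what an
// Each-style helper needs.
type rowIter interface {
	Next() bool
	Err() error
	Close() error
}

// each calls fn once per row, always closes rows, and returns any error the
// iterator reported once the loop ended.
func each[R rowIter](rows R, fn func(R)) error {
	defer rows.Close()
	for rows.Next() {
		fn(rows)
	}
	return rows.Err()
}

// sliceRows is an in-memory rowIter used only for demonstration.
type sliceRows struct {
	data   []string
	pos    int
	closed bool
}

func (s *sliceRows) Next() bool {
	if s.pos >= len(s.data) {
		return false
	}
	s.pos++
	return true
}

func (s *sliceRows) Err() error      { return nil }
func (s *sliceRows) Close() error    { s.closed = true; return nil }
func (s *sliceRows) Current() string { return s.data[s.pos-1] }

func main() {
	rows := &sliceRows{data: []string{"Strad", "Amati", "Guarneri"}}
	var names []string
	err := each(rows, func(r *sliceRows) {
		names = append(names, r.Current())
	})
	fmt.Println(names, err, rows.closed) // [Strad Amati Guarneri] <nil> true
}
```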
Using the iterator simplifies this code: ```go var orchestras []orchestra query := "SELECT violin,viola,cello,contrabass FROM symphony ORDER BY viola" // Or we could use sx.SelectQuery() rows := tx.MustQuery(query) defer rows.Close() for rows.Next() { var o orchestra rows.MustScans(&o) orchestras = append(orchestras, o) } if err := rows.Err(); err != nil { tx.Fail(err) } ``` To this: ```go var orchestras []orchestra query := "SELECT violin,viola,cello,contrabass FROM symphony ORDER BY viola" tx.MustQuery(query).Each(func(r *sx.Rows) { var o orchestra r.MustScans(&o) orchestras = append(orchestras, o) }) ``` ## Contributing Contributions are welcome! Read the [Contributing Guide](CONTRIBUTING.md) for more information. ## Licensing This project is licensed under the MIT License - see the [LICENSE](LICENSE.txt) file for details. ================================================ FILE: doc.go ================================================ // Package sx provides some simple extensions to the database/sql package to reduce the amount of boilerplate code. // // Transactions and error handling // // Package sx provides a function called Do, which runs a callback function inside a transaction. The callback // function is provided with a Tx object, which is an sql.Tx object that has been extended with some Must*** methods. // When a Must*** method encounters an error, it panics, and the panic is caught by Do and returned to the caller // as an error value. // // Do automatically commits or rolls back the transaction based on whether or not the callback function completed // successfully. // // Query helpers and struct matching // // Package sx provides functions to generate frequently-used queries, based on a simple matching between struct // fields and database columns. // // By default, every exported field in a struct corresponds to the database column whose name is the snake-cased version of // the field name, e.g. 
the field HelloWorld corresponds to the "hello_world" column.
Acronyms are treated as words, // so HelloRPCWorld becomes "hello_rpc_world". // // The column name can also be specified explicitly by tagging the field with the desired name, and fields can be // excluded altogether by tagging with "-". // // Fields that should be used for scanning but excluded for inserts and updates are additionally tagged "readonly". // // Examples: // // // Field is called "field" in the database. // Field int // // // Field is called "hage" in the database. // Field int `sx:"hage"` // // // Field is called "hage" in the database and should be skipped for inserts and updates. // Field int `sx:"hage,readonly"` // // // Field is called "field" in the database and should be skipped for inserts and updates. // Field int `sx:",readonly"` // // // Field should be ignored by sx. // Field int `sx:"-"` package sx ================================================ FILE: example_test.go ================================================ package sx_test import ( "database/sql" "fmt" _ "github.com/mattn/go-sqlite3" sx "github.com/travelaudience/go-sx" ) func Example() { db, err := sql.Open("sqlite3", ":memory:") if err != nil { fmt.Println(err) return } _, err = db.Exec("CREATE TABLE numbers (foo integer, bar string)") if err != nil { fmt.Println(err) return } // This is the default, but other examples set numbered placeholders to true, so // we need to make sure here that it's false. In practice, this would only be set // during initialization, and then only when $n-style placeholders are needed. sx.SetNumberedPlaceholders(false) type abc struct { Foo int32 Bar string } var data = []abc{ {Foo: 1, Bar: "one"}, {Foo: 2, Bar: "two"}, {Foo: 3, Bar: "three"}, } // Use Do to run a transaction. if err = sx.Do(db, func(tx *sx.Tx) { // Use MustPrepare with Do to insert rows into the table. query := sx.InsertQuery("numbers", &abc{}) tx.MustPrepare(query).Do(func(s *sx.Stmt) { for _, x := range data { s.MustExec(sx.Values(&x)...)
} }) }); err != nil { // Any database-level error will be caught and printed here. fmt.Println(err) return } var dataRead []abc if err = sx.Do(db, func(tx *sx.Tx) { // Use MustQuery with Each to read the rows back in alphabetical order. query := sx.SelectQuery("numbers", &abc{}) + " ORDER BY bar" tx.MustQuery(query).Each(func(r *sx.Rows) { var x abc r.MustScans(&x) dataRead = append(dataRead, x) }) }); err != nil { fmt.Println(err) return } fmt.Println(dataRead) // Output: // [{1 one} {3 three} {2 two}] } func ExampleSelectQuery() { type abc struct { Field1 int64 FieldTwo string Field3 bool `sx:"gigo"` } query := sx.SelectQuery("sometable", &abc{}) fmt.Println(query) // Output: // SELECT field1,field_two,gigo FROM sometable } func ExampleSelectAliasQuery() { type abc struct { Foo, Bar string } query := sx.SelectAliasQuery("sometable", "s", &abc{}) fmt.Println(query) // Output: // SELECT s.foo,s.bar FROM sometable s } func ExampleWhere() { conditions := []string{ "ordered", "NOT sent", } query := "SELECT * FROM sometable" + sx.Where(conditions...) 
fmt.Println(query) // Output: // SELECT * FROM sometable WHERE (ordered) AND (NOT sent) } func ExampleLimitOffset() { query := "SELECT * FROM sometable" + sx.LimitOffset(100, 0) fmt.Println(query) // Output: // SELECT * FROM sometable LIMIT 100 } func ExampleInsertQuery() { sx.SetNumberedPlaceholders(true) type abc struct { Foo, Bar string Baz int64 `sx:",readonly"` } query := sx.InsertQuery("sometable", &abc{}) fmt.Println(query) // Output: // INSERT INTO sometable (foo,bar) VALUES ($1,$2) } func ExampleUpdateQuery() { sx.SetNumberedPlaceholders(true) type updateABC struct { Foo string // cannot update to "" Bar *string // can update to "" Baz int64 // cannot update to 0 Qux *int64 // can update to 0 } s1, i1 := "hello", int64(0) x := updateABC{Bar: &s1, Baz: 42, Qux: &i1} query, values := sx.UpdateQuery("sometable", &x) query += " WHERE id=$1" fmt.Println(query) fmt.Println(values) query, values = sx.UpdateQuery("sometable", &updateABC{}) fmt.Println(query == "", len(values)) // Output: // UPDATE sometable SET bar=$2,baz=$3,qux=$4 WHERE id=$1 // [hello 42 0] // true 0 } func ExampleUpdateAllQuery() { sx.SetNumberedPlaceholders(true) type abc struct { Foo, Bar string Baz int64 `sx:",readonly"` } query := sx.UpdateAllQuery("sometable", &abc{}) + " WHERE id=$1" fmt.Println(query) // Output: // UPDATE sometable SET foo=$2,bar=$3 WHERE id=$1 } func ExampleUpdateFieldsQuery() { sx.SetNumberedPlaceholders(true) type abc struct { Foo, Bar string Baz int64 } x := abc{Foo: "hello", Bar: "Goodbye", Baz: 42} query, values := sx.UpdateFieldsQuery("sometable", &x, "Bar", "Baz") query += " WHERE id=$1" fmt.Println(query) fmt.Println(values) // Output: // UPDATE sometable SET bar=$2,baz=$3 WHERE id=$1 // [Goodbye 42] } ================================================ FILE: go.mod ================================================ module github.com/travelaudience/go-sx go 1.23 require ( github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/mattn/go-sqlite3 v1.14.24 ) 
================================================ FILE: go.sum ================================================ github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= ================================================ FILE: helpers.go ================================================ package sx import ( "errors" "reflect" "strconv" "strings" ) // SelectQuery returns a query string of the form // // SELECT <columns> FROM <table> // // where <columns> is the list of columns defined by the struct pointed at by datatype, and <table>
is the table name // given. func SelectQuery(table string, datatype interface{}) string { bob := strings.Builder{} bob.WriteString("SELECT") var sep byte = ' ' for _, c := range matchingOf(datatype).columns { bob.WriteByte(sep) bob.WriteString(c.name) sep = ',' } bob.WriteString(" FROM ") bob.WriteString(table) return bob.String() } // SelectAliasQuery returns a query string like that of SelectQuery except that a table alias is included, e.g. // // SELECT <alias>.<column>, <alias>.<column>, ..., <alias>.<column> FROM <table> <alias>
func SelectAliasQuery(table, alias string, datatype interface{}) string { bob := strings.Builder{} bob.WriteString("SELECT") var sep byte = ' ' for _, c := range matchingOf(datatype).columns { bob.WriteByte(sep) bob.WriteString(alias) bob.WriteByte('.') bob.WriteString(c.name) sep = ',' } bob.WriteString(" FROM ") bob.WriteString(table) bob.WriteByte(' ') bob.WriteString(alias) return bob.String() } // Where returns a string of the form // // WHERE (<condition>) AND (<condition>) ... // // with a leading space. // // If no conditions are given, then Where returns the empty string. func Where(conditions ...string) string { if len(conditions) == 0 { return "" } return " WHERE (" + strings.Join(conditions, ") AND (") + ")" } // LimitOffset returns a string of the form // // LIMIT <limit> OFFSET <offset> // // with a leading space. // // If either limit or offset is zero, then that part of the string is omitted. If both limit and offset are zero, // then LimitOffset returns the empty string. func LimitOffset(limit, offset int64) string { x := "" if limit != 0 { x = " LIMIT " + strconv.FormatInt(limit, 10) } if offset != 0 { x += " OFFSET " + strconv.FormatInt(offset, 10) } return x } // InsertQuery returns a query string of the form // // INSERT INTO <table>
(<columns>) VALUES (?,?,...) // INSERT INTO <table>
(<columns>) VALUES ($1,$2,...) (numbered placeholders) // // where <table>
is the table name given, and <columns> is the list of the columns defined by the struct pointed at by // datatype. Struct fields tagged "readonly" are skipped. // // Panics if all fields are tagged "readonly". func InsertQuery(table string, datatype interface{}) string { columns := matchingOf(datatype).columns bob := strings.Builder{} bob.WriteString("INSERT INTO ") bob.WriteString(table) bob.WriteByte(' ') var sep byte = '(' var n int for _, c := range columns { if !c.readonly { bob.WriteByte(sep) bob.WriteString(c.name) sep = ',' n++ } } if n == 0 { panic("sx: struct " + matchingOf(datatype).reflectType.Name() + " has no writeable fields") } bob.WriteString(") VALUES ") sep = '(' for p := Placeholder(0); p < Placeholder(n); { bob.WriteByte(sep) bob.WriteString(p.Next()) sep = ',' } bob.WriteByte(')') return bob.String() } // UpdateQuery returns a query string and a list of values from the struct pointed at by data. This is the preferred // way to do updates, as it allows pointer fields in the struct and automatically skips zero values. // // The query string is of the form // // UPDATE <table>
SET <column>=?,<column>=?,... // UPDATE <table>
SET <column>=$2,<column>=$3,... (numbered placeholders) // // where <table>
is the table name given, and each <column> is a column name defined by the struct pointed at by data. // // Note: // - An optional *Placeholder may be passed as ph to control placeholder numbering. // - If none is passed, numbered placeholders start at $2, so that $1 can be used in the WHERE clause (e.g. WHERE id = $1). // // The list of values contains values from the struct to match the placeholders. For pointer fields, the values // pointed at are used. // // UpdateQuery takes all the writeable fields (not tagged "readonly") from the struct, looks up their values, and if // it finds a zero value, the field is skipped. This allows the caller to set only those values that need updating. // If it is necessary to update a field to a zero value, then a pointer field should be used. A pointer to a zero // value will force an update, and a nil pointer will be skipped. // // The struct used for UpdateQuery will normally be a different struct from that used for select or insert on the // same table. This is okay. // // If there are no applicable fields, UpdateQuery returns ("", nil). func UpdateQuery(table string, data interface{}, ph ...*Placeholder) (string, []interface{}) { m := matchingOf(data) instance := reflect.ValueOf(data).Elem() columns := make([]string, 0) values := make([]interface{}, 0) // check for optional placeholder provided in the parameters var p *Placeholder if len(ph) > 0 { p = ph[0] } else { // if not, use default one that starts from 2 var defaultPh Placeholder = 1 p = &defaultPh } for _, c := range m.columns { if !c.readonly { if val := instance.Field(c.index); !val.IsZero() { columns = append(columns, c.name+"="+p.Next()) if val.Kind() == reflect.Ptr { val = val.Elem() } values = append(values, val.Interface()) } } } if len(columns) == 0 { return "", nil } return "UPDATE " + table + " SET " + strings.Join(columns, ","), values } // UpdateAllQuery returns a query string of the form // // UPDATE <table>
SET <column>=?,<column>=?,... // UPDATE <table>
SET <column>=$2,<column>=$3,... (numbered placeholders) // // where <table>
is the table name given, and each <column> is a column name defined by the struct pointed at by // data. All writeable fields (those not tagged "readonly") are included. Fields are in the order of the struct. // // With numbered placeholders, numbering starts at $2. This allows $1 to be used in the WHERE clause. // // Use with the Values function to write to all writeable fields. func UpdateAllQuery(table string, data interface{}) string { m := matchingOf(data) columns := make([]string, 0) var p Placeholder = 1 // start from 2 for _, c := range m.columns { if !c.readonly { columns = append(columns, c.name+"="+p.Next()) } } return "UPDATE " + table + " SET " + strings.Join(columns, ",") } // UpdateFieldsQuery returns a query string and a list of values for the specified fields of the struct pointed at by data. // // The query string is of the form // // UPDATE <table>
SET <column>=?,<column>=?,... // UPDATE <table>
SET <column>=$2,<column>=$3,... (numbered placeholders) // // where <table>
is the table name given, and each <column> is a column name defined by the struct pointed at by data. // // The list of values contains values from the struct to match the placeholders. The order matches the order of // fields provided by the caller. // // With numbered placeholders, numbering starts at $2. This allows $1 to be used in the WHERE clause. // // UpdateFieldsQuery panics if no field names are provided or if any of the requested fields do not exist. If it is // necessary to validate field names, use ColumnOf. func UpdateFieldsQuery(table string, data interface{}, fields ...string) (string, []interface{}) { m := matchingOf(data) instance := reflect.ValueOf(data).Elem() columns := make([]string, 0) values := make([]interface{}, 0) var p Placeholder = 1 // start from 2 if len(fields) == 0 { panic("UpdateFieldsQuery requires at least one field") } for _, field := range fields { if c, ok := m.columnMap[field]; ok { columns = append(columns, c.name+"="+p.Next()) values = append(values, instance.Field(c.index).Interface()) } else { panic("struct " + m.reflectType.Name() + " has no usable field " + field) } } return "UPDATE " + table + " SET " + strings.Join(columns, ","), values } // Addrs returns a slice of pointers to the fields of the struct pointed at by dest. Use for scanning rows from a // SELECT query. // // Panics if dest does not point at a struct. func Addrs(dest interface{}) []interface{} { m := matchingOf(dest) val := reflect.ValueOf(dest).Elem() addrs := make([]interface{}, 0, len(m.columns)) for _, c := range m.columns { addrs = append(addrs, val.Field(c.index).Addr().Interface()) } return addrs } // Values returns a slice of values from the struct pointed at by data, excluding those from fields tagged "readonly". // Use for providing values to an INSERT query. // // Panics if data does not point at a struct.
func Values(data interface{}) []interface{} { m := matchingOf(data) val := reflect.ValueOf(data).Elem() values := make([]interface{}, 0, len(m.columns)) for _, c := range m.columns { if !c.readonly { values = append(values, val.Field(c.index).Interface()) } } return values } // ValueOf returns the value of the specified field of the struct pointed at by data. Panics if data does not // point at a struct, or if the requested field doesn't exist. func ValueOf(data interface{}, field string) interface{} { // This step verifies data and field and might panic. c := matchingOf(data).columnOf(field) // If there is a panic, then the reflection here will not be attempted. return reflect.ValueOf(data).Elem().Field(c.index).Interface() } // Columns returns the names of the database columns that correspond to the fields in the struct pointed at by // datatype. The order of returned fields matches the order of the struct. func Columns(datatype interface{}) []string { return matchingOf(datatype).columnList() } // ColumnsWriteable returns the names of the database columns that correspond to the fields in the struct pointed at // by datatype, excluding those tagged "readonly". The order of returned fields matches the order of the struct. func ColumnsWriteable(datatype interface{}) []string { return matchingOf(datatype).columnWriteableList() } // ColumnOf returns the name of the database column that corresponds to the specified field of the struct pointed // at by datatype. // // ColumnOf returns an error if the provided field name is missing from the struct. 
func ColumnOf(datatype interface{}, field string) (string, error) { m := matchingOf(datatype) if c, ok := m.columnMap[field]; ok { return c.name, nil } return "", errors.New("struct " + m.reflectType.Name() + " has no usable field " + field) } ================================================ FILE: helpers_test.go ================================================ package sx_test import ( "reflect" "testing" sx "github.com/travelaudience/go-sx" ) // Test structs type menagerie0 struct { Platypus string Rhinoceros float64 } type menagerie1 struct { Chimpanzee int64 `sx:"human"` Flamingo string `sx:",readonly"` Warthog string `sx:"-"` } func TestSelectInsertUpdateAll(t *testing.T) { var testCases = []struct { name string table string datatype interface{} numberedPlaceholders bool wantSelect string wantInsert string wantUpdate string }{ { name: "menagerie0", table: "zoo", datatype: &menagerie0{}, numberedPlaceholders: false, wantSelect: "SELECT platypus,rhinoceros FROM zoo", wantInsert: "INSERT INTO zoo (platypus,rhinoceros) VALUES (?,?)", wantUpdate: "UPDATE zoo SET platypus=?,rhinoceros=?", }, { name: "menagerie0 numbered", table: "zoo", datatype: &menagerie0{}, numberedPlaceholders: true, wantSelect: "SELECT platypus,rhinoceros FROM zoo", wantInsert: "INSERT INTO zoo (platypus,rhinoceros) VALUES ($1,$2)", wantUpdate: "UPDATE zoo SET platypus=$2,rhinoceros=$3", }, { name: "menagerie1", table: "jungle", datatype: &menagerie1{}, numberedPlaceholders: false, wantSelect: "SELECT human,flamingo FROM jungle", wantInsert: "INSERT INTO jungle (human) VALUES (?)", wantUpdate: "UPDATE jungle SET human=?", }, { name: "menagerie1 numbered", table: "jungle", datatype: &menagerie1{}, numberedPlaceholders: true, wantSelect: "SELECT human,flamingo FROM jungle", wantInsert: "INSERT INTO jungle (human) VALUES ($1)", wantUpdate: "UPDATE jungle SET human=$2", }, } for _, c := range testCases { sx.SetNumberedPlaceholders(c.numberedPlaceholders) if a, b := c.wantSelect, 
sx.SelectQuery(c.table, c.datatype); a != b { t.Errorf("case %s select: expected \"%s\", got \"%s\"", c.name, a, b) } if a, b := c.wantInsert, sx.InsertQuery(c.table, c.datatype); a != b { t.Errorf("case %s insert: expected \"%s\", got \"%s\"", c.name, a, b) } if a, b := c.wantUpdate, sx.UpdateAllQuery(c.table, c.datatype); a != b { t.Errorf("case %s update all: expected \"%s\", got \"%s\"", c.name, a, b) } } } func TestSelectAlias(t *testing.T) { var testCases = []struct { name string table string alias string datatype interface{} wantSelect string }{ { name: "menagerie0", table: "zoo", alias: "home", datatype: &menagerie0{}, wantSelect: "SELECT home.platypus,home.rhinoceros FROM zoo home", }, { name: "menagerie1", table: "jungle", alias: "a", datatype: &menagerie1{}, wantSelect: "SELECT a.human,a.flamingo FROM jungle a", }, } for _, c := range testCases { if a, b := c.wantSelect, sx.SelectAliasQuery(c.table, c.alias, c.datatype); a != b { t.Errorf("case %s select alias: expected \"%s\", got \"%s\"", c.name, a, b) } } } func TestWhere(t *testing.T) { var testCases = []struct { name string conditions []string want string }{ { name: "empty", }, { name: "one condition", conditions: []string{"a=5"}, want: " WHERE (a=5)", }, { name: "two conditions", conditions: []string{"a=5", "b=6"}, want: " WHERE (a=5) AND (b=6)", }, { name: "three conditions", conditions: []string{"a=5", "b=6", "c=7"}, want: " WHERE (a=5) AND (b=6) AND (c=7)", }, } for _, c := range testCases { if a, b := c.want, sx.Where(c.conditions...); a != b { t.Errorf("case %s: expected \"%s\", got \"%s\"", c.name, a, b) } } } func TestLimitOffset(t *testing.T) { var testCases = []struct { name string limit int64 offset int64 want string }{ { name: "empty", }, { name: "limit only", limit: 100, want: " LIMIT 100", }, { name: "offset only", offset: 200, want: " OFFSET 200", }, { name: "limit and offset", limit: 123, offset: 456, want: " LIMIT 123 OFFSET 456", }, { name: "negative numbers", // no practical use 
for this, but the function should still process it limit: -3, offset: -999999, want: " LIMIT -3 OFFSET -999999", }, } for _, c := range testCases { if a, b := c.want, sx.LimitOffset(c.limit, c.offset); a != b { t.Errorf("case %s: expected \"%s\", got \"%s\"", c.name, a, b) } } } func TestInsertPanic(t *testing.T) { // InsertQuery should panic if all of a struct's fields are tagged readonly. type ohNo struct { _ int `sx:",readonly"` _ int `sx:",readonly"` } const wantPanic = "sx: struct ohNo has no usable fields" func() { defer func() { r := recover() if r == nil { t.Errorf("expected a panic") return } if s, ok := r.(string); ok { if s != wantPanic { t.Errorf("expected \"%s\", got \"%s\"", wantPanic, s) } return } panic(r) }() sx.InsertQuery("zoo", &ohNo{}) }() } func TestUpdate(t *testing.T) { type menagerie2 struct { Cougar string Cheetah *string `sx:"cat"` Grizzly int32 Wilderbeast []int32 Ant bool `sx:"-"` Bee bool `sx:",readonly"` } var ( x = []int32{1, 2, 3} y = int32(5) s = "abcde" ) var testCases = []struct { name string table string data interface{} numberedPlaceholders bool wantQuery string wantValues []interface{} }{ { name: "empty", table: "africa", data: &menagerie2{}, numberedPlaceholders: false, wantQuery: "", wantValues: nil, }, { name: "ignore ant and bee", table: "asia", data: &menagerie2{Ant: true, Bee: true}, numberedPlaceholders: true, wantQuery: "", wantValues: nil, }, { name: "cougar and wilderbeast", table: "australia", data: &menagerie2{Cougar: "abc", Wilderbeast: x}, numberedPlaceholders: false, wantQuery: "UPDATE australia SET cougar=?,wilderbeast=?", wantValues: []interface{}{"abc", x}, }, { name: "cougar and wilderbeast numbered", table: "australia", data: &menagerie2{Cougar: "abc", Wilderbeast: x}, numberedPlaceholders: true, wantQuery: "UPDATE australia SET cougar=$2,wilderbeast=$3", wantValues: []interface{}{"abc", x}, }, { name: "cheetah and grizzly", table: "siberia", data: &menagerie2{Cheetah: &s, Grizzly: y}, numberedPlaceholders: 
false, wantQuery: "UPDATE siberia SET cat=?,grizzly=?", wantValues: []interface{}{s, y}, }, { name: "cheetah and grizzly numbered", table: "siberia", data: &menagerie2{Cheetah: &s, Grizzly: y}, numberedPlaceholders: true, wantQuery: "UPDATE siberia SET cat=$2,grizzly=$3", wantValues: []interface{}{s, y}, }, { name: "everyone except ant", table: "berlin", data: &menagerie2{Cougar: "roar", Cheetah: &s, Grizzly: y, Wilderbeast: x, Bee: true}, numberedPlaceholders: false, wantQuery: "UPDATE berlin SET cougar=?,cat=?,grizzly=?,wilderbeast=?", wantValues: []interface{}{"roar", s, y, x}, }, { name: "everyone except bee", table: "berlin", data: &menagerie2{Cougar: "roar", Cheetah: &s, Grizzly: y, Wilderbeast: x, Ant: true}, numberedPlaceholders: true, wantQuery: "UPDATE berlin SET cougar=$2,cat=$3,grizzly=$4,wilderbeast=$5", wantValues: []interface{}{"roar", s, y, x}, }, } for _, c := range testCases { sx.SetNumberedPlaceholders(c.numberedPlaceholders) query, values := sx.UpdateQuery(c.table, c.data) if a, b := c.wantQuery, query; a != b { t.Errorf("case %s query: expected \"%s\", got \"%s\"", c.name, a, b) } if a, b := c.wantValues, values; !reflect.DeepEqual(a, b) { t.Errorf("case %s values: expected %v, got %v", c.name, a, b) } } } func TestUpdateFields(t *testing.T) { var testCases = []struct { name string table string data interface{} fields []string numberedPlaceholders bool wantQuery string wantValues []interface{} wantPanic string }{ { name: "no fields", table: "irrelevant", data: &menagerie0{}, wantPanic: "UpdateFieldsQuery requires at least one field", }, { name: "invalid field", table: "irrelevant", data: &menagerie0{}, fields: []string{"Research"}, wantPanic: "struct menagerie0 has no usable field Research", }, { name: "ignored field", table: "irrelevant", data: &menagerie1{}, fields: []string{"Warthog"}, wantPanic: "struct menagerie1 has no usable field Warthog", }, { name: "rhinoceros", table: "swamp", data: &menagerie0{Platypus: "abc", Rhinoceros: -5.0}, 
fields: []string{"Rhinoceros"}, numberedPlaceholders: false, wantQuery: "UPDATE swamp SET rhinoceros=?", wantValues: []interface{}{float64(-5.0)}, }, { name: "rhinoceros numbered", table: "swamp", data: &menagerie0{Platypus: "abc", Rhinoceros: -10.0}, fields: []string{"Rhinoceros"}, numberedPlaceholders: true, wantQuery: "UPDATE swamp SET rhinoceros=$2", wantValues: []interface{}{float64(-10.0)}, }, { name: "rhinoceros and platypus", table: "swamp", data: &menagerie0{Platypus: "abc", Rhinoceros: -5.0}, fields: []string{"Rhinoceros", "Platypus"}, numberedPlaceholders: false, wantQuery: "UPDATE swamp SET rhinoceros=?,platypus=?", wantValues: []interface{}{float64(-5.0), "abc"}, }, { name: "rhinoceros and platypus numbered", table: "swamp", data: &menagerie0{Platypus: "abc", Rhinoceros: 10.0}, fields: []string{"Rhinoceros", "Platypus"}, numberedPlaceholders: true, wantQuery: "UPDATE swamp SET rhinoceros=$2,platypus=$3", wantValues: []interface{}{float64(10.0), "abc"}, }, { name: "can update readonly field", // TODO: should we allow this? 
table: "forest", data: &menagerie1{Chimpanzee: 123, Flamingo: "foo"}, fields: []string{"Chimpanzee", "Flamingo"}, numberedPlaceholders: false, wantQuery: "UPDATE forest SET human=?,flamingo=?", wantValues: []interface{}{int64(123), "foo"}, }, { name: "can update readonly field numbered", table: "forest", data: &menagerie1{Chimpanzee: 123, Flamingo: "foo"}, fields: []string{"Chimpanzee", "Flamingo"}, numberedPlaceholders: true, wantQuery: "UPDATE forest SET human=$2,flamingo=$3", wantValues: []interface{}{int64(123), "foo"}, }, } type _ struct { Chimpanzee int64 `sx:"human"` Flamingo string `sx:",readonly"` Warthog string `sx:"-"` } for _, c := range testCases { var ( query string values []interface{} gotPanic string ) sx.SetNumberedPlaceholders(c.numberedPlaceholders) func() { gotPanic = "" defer func() { r := recover() if r == nil { return } if s, ok := r.(string); ok { gotPanic = s return } panic(r) }() query, values = sx.UpdateFieldsQuery(c.table, c.data, c.fields...) }() if gotPanic != c.wantPanic { if c.wantPanic == "" { t.Errorf("case %s: unexpected panic %q", c.name, gotPanic) } else if gotPanic == "" { t.Errorf("case %s: expected panic %q but got none", c.name, c.wantPanic) } else { t.Errorf("case %s: expected panic %q, got %q", c.name, c.wantPanic, gotPanic) } continue } if a, b := c.wantQuery, query; a != b { t.Errorf("case %s query: expected \"%s\", got \"%s\"", c.name, a, b) } if a, b := c.wantValues, values; !reflect.DeepEqual(a, b) { t.Errorf("case %s values: expected %v, got %v", c.name, a, b) } } } func TestAddrsValues(t *testing.T) { var ( data0 = menagerie0{Platypus: "yes", Rhinoceros: 1.0} data1 = menagerie1{Chimpanzee: 64, Flamingo: "maybe", Warthog: "no"} ) var testCases = []struct { name string data interface{} wantAddrs []interface{} wantValues []interface{} }{ { name: "menagerie0", data: &data0, wantAddrs: []interface{}{&data0.Platypus, &data0.Rhinoceros}, wantValues: []interface{}{"yes", float64(1.0)}, }, { name: "menagerie1", data: &data1, 
wantAddrs: []interface{}{&data1.Chimpanzee, &data1.Flamingo}, wantValues: []interface{}{int64(64)}, }, } // What's returned from Addrs is a slice of pointers, and we need to test that these are the exact pointers // that we want. DeepEqual alone could report a false positive, since it considers two distinct pointers to be // equal when their pointed-at values are equal. shallowEqual := func(a, b []interface{}) bool { if len(a) != len(b) { return false } for i := range a { if a[i] != b[i] { return false } } return true } for _, c := range testCases { if a, b := c.wantAddrs, sx.Addrs(c.data); !shallowEqual(a, b) { t.Errorf("case %s addrs: expected %v, got %v", c.name, a, b) } if a, b := c.wantValues, sx.Values(c.data); !reflect.DeepEqual(a, b) { t.Errorf("case %s values: expected %v, got %v", c.name, a, b) } } } func TestValueOf(t *testing.T) { var testCases = []struct { name string data interface{} field string wantValue interface{} }{ { name: "platypus", data: &menagerie0{Platypus: "pus", Rhinoceros: 2.0}, field: "Platypus", wantValue: "pus", }, { name: "rhinoceros", data: &menagerie0{Platypus: "pus", Rhinoceros: 2.0}, field: "Rhinoceros", wantValue: float64(2.0), }, { name: "chimpanzee", data: &menagerie1{Chimpanzee: 641, Flamingo: "goth", Warthog: "hog"}, field: "Chimpanzee", wantValue: int64(641), }, { name: "flamingo", data: &menagerie1{Chimpanzee: 641, Flamingo: "goth", Warthog: "hog"}, field: "Flamingo", wantValue: "goth", }, } for _, c := range testCases { if a, b := c.wantValue, sx.ValueOf(c.data, c.field); a != b { t.Errorf("case %s: expected %v, got %v", c.name, a, b) } } } func TestColumnsColumnsWriteable(t *testing.T) { var testCases = []struct { name string datatype interface{} wantColumns []string wantColumnsWriteable []string }{ { name: "menagerie0", datatype: &menagerie0{}, wantColumns: []string{"platypus", "rhinoceros"}, wantColumnsWriteable: []string{"platypus", "rhinoceros"}, }, { name: "menagerie1", datatype: &menagerie1{}, wantColumns:
[]string{"human", "flamingo"}, wantColumnsWriteable: []string{"human"}, }, } for _, c := range testCases { if a, b := c.wantColumns, sx.Columns(c.datatype); !reflect.DeepEqual(a, b) { t.Errorf("case %s columns: expected %v, got %v", c.name, a, b) } if a, b := c.wantColumnsWriteable, sx.ColumnsWriteable(c.datatype); !reflect.DeepEqual(a, b) { t.Errorf("case %s columns writeable: expected %v, got %v", c.name, a, b) } } } func TestColumnOf(t *testing.T) { var testCases = []struct { name string datatype interface{} field string wantColumn string wantErr string }{ { name: "platypus", datatype: &menagerie0{}, field: "Platypus", wantColumn: "platypus", }, { name: "rhinoceros", datatype: &menagerie0{}, field: "Rhinoceros", wantColumn: "rhinoceros", }, { name: "unknown column", datatype: &menagerie0{}, field: "Hippopotamus", wantErr: "struct menagerie0 has no usable field Hippopotamus", }, { name: "chimpanzee", datatype: &menagerie1{}, field: "Chimpanzee", wantColumn: "human", }, { name: "flamingo", datatype: &menagerie1{}, field: "Flamingo", wantColumn: "flamingo", }, { name: "ignored field", datatype: &menagerie1{}, field: "Warthog", wantErr: "struct menagerie1 has no usable field Warthog", }, } for _, c := range testCases { column, err := sx.ColumnOf(c.datatype, c.field) if err != nil { if c.wantErr == "" { t.Errorf("case %s: unexpected error %q", c.name, err.Error()) } else if a, b := c.wantErr, err.Error(); a != b { t.Errorf("case %s: expected error %q, got %q", c.name, a, b) } continue } if a, b := c.wantColumn, column; a != b { t.Errorf("case %s: expected %v, got %v", c.name, a, b) } } } ================================================ FILE: matching.go ================================================ package sx import ( "reflect" "regexp" "strings" "sync" ) // A matching is between struct fields and database columns. 
type matching struct { reflectType reflect.Type columns []*column // an ordered list of columns columnMap map[string]*column // columns keyed by field name } type column struct { index int // index of this field in the struct name string // name of the corresponding db column readonly bool // flag to skip this column on insert/update operations (e.g. for primary key or automatic timestamp) } // columnList returns the names of the database columns in the order of the struct. func (m *matching) columnList() []string { list := make([]string, 0, len(m.columns)) for _, c := range m.columns { list = append(list, c.name) } return list } // columnWriteableList returns the names of the database columns in the order of the struct, excluding read-only columns. func (m *matching) columnWriteableList() []string { list := make([]string, 0, len(m.columns)) for _, c := range m.columns { if !c.readonly { list = append(list, c.name) } } return list } // columnOf returns the column which matches the named field. Panics if the field doesn't exist. func (m *matching) columnOf(field string) *column { if c, ok := m.columnMap[field]; ok { return c } panic("sx: struct " + m.reflectType.Name() + " has no usable field " + field) } // matchingOf returns a matching for the given struct type, generating it if necessary. matchingOf looks only at the // structure of datatype and ignores its values. // // Panics if datatype does not point at a struct, or if the struct has no usable fields. func matchingOf(datatype interface{}) *matching { matchingCacheMu.Lock() defer matchingCacheMu.Unlock() v := reflect.ValueOf(datatype) if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { panic("sx: expected a pointer to a struct") } // First look for a cached matching. reflectType := v.Elem().Type() if m, ok := matchingCache[reflectType]; ok { return m } // Nothing cached, generate a new matching and cache it.
n := reflectType.NumField() cols := make([]*column, 0) colmap := make(map[string]*column) for i := 0; i < n; i++ { field := reflectType.Field(i) tags := strings.Split(field.Tag.Get("sx"), ",") colname := tags[0] if colname == "-" || field.PkgPath != "" { continue // skip excluded and unexported fields. } if colname == "" { colname = snakeCase(field.Name) // default column name based on field name } col := &column{ index: i, name: colname, } // See if there's a readonly tag. A readonly tag would have to be in at least the second position, since // the first position is always interpreted as a column name. for _, tag := range tags[1:] { if tag == "readonly" { col.readonly = true break } } cols = append(cols, col) colmap[field.Name] = col } if len(cols) == 0 { panic("sx: struct " + reflectType.Name() + " has no usable fields") } m := &matching{ reflectType: reflectType, columns: cols, columnMap: colmap, } matchingCache[reflectType] = m return m } // Cache to keep track of struct types that have been seen and therefore analyzed. var matchingCache = make(map[reflect.Type]*matching) var matchingCacheMu sync.Mutex // Snake-casing logic. var ( matchWord = regexp.MustCompile(`(.)([A-Z][a-z]+)`) matchAcronym = regexp.MustCompile(`([a-z0-9])([A-Z])`) ) func snakeCase(in string) string { const r = `${1}_${2}` return strings.ToLower(matchAcronym.ReplaceAllString(matchWord.ReplaceAllString(in, r), r)) } ================================================ FILE: matching_test.go ================================================ package sx_test import ( "testing" sx "github.com/travelaudience/go-sx" ) // The tests in this file test that the correct panics are generated. The tests in helpers_test.go test for // the correct results. 
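When a field has no `sx` tag, matchingOf derives its column name with snakeCase. To see what names that produces, the function and its two regular expressions can be copied verbatim from matching.go into a standalone program. The field names below are illustrative (two of them appear in the matching tests). The outputs follow from applying matchWord first and matchAcronym second.

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// These two patterns and snakeCase are copied from matching.go.
var (
	matchWord    = regexp.MustCompile(`(.)([A-Z][a-z]+)`)  // split before a capitalized word
	matchAcronym = regexp.MustCompile(`([a-z0-9])([A-Z])`) // split a lowercase/digit followed by a capital
)

func snakeCase(in string) string {
	const r = `${1}_${2}`
	return strings.ToLower(matchAcronym.ReplaceAllString(matchWord.ReplaceAllString(in, r), r))
}

func main() {
	for _, name := range []string{"Platypus", "ChocolateID", "FOOBarBAZ", "HTTPServer"} {
		fmt.Printf("%s -> %s\n", name, snakeCase(name))
	}
	// Platypus -> platypus
	// ChocolateID -> chocolate_id
	// FOOBarBAZ -> foo_bar_baz
	// HTTPServer -> http_server
}
```

A struct tag such as `sx:"human"` bypasses this conversion entirely, which is how the tests map the Chimpanzee field to the column human.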
func TestMatching(t *testing.T) { type test1 struct { A int Lollipop bool ChocolateID float64 FOOBarBAZ string } type test2 struct { A int `sx:"-"` } type test3 struct { } type test4 struct { A int `sx:"-"` B int `sx:"foo"` C int `sx:"bar"` _ int `sx:"baz"` } t.Run("panics on bad input", func(t *testing.T) { var testCases = []struct { name string data interface{} wantPanic string }{ { name: "pass a struct, not a pointer", data: test1{}, wantPanic: "sx: expected a pointer to a struct", }, { name: "pass nil", data: nil, wantPanic: "sx: expected a pointer to a struct", }, { name: "pass something else", data: "hello", wantPanic: "sx: expected a pointer to a struct", }, { name: "no usable fields", data: &test2{}, wantPanic: "sx: struct test2 has no usable fields", }, { name: "no exported fields", data: &test3{}, wantPanic: "sx: struct test3 has no usable fields", }, } for _, c := range testCases { func() { defer func() { r := recover() if r == nil { t.Errorf("case %s: expected a panic", c.name) return } if s, ok := r.(string); ok { if s != c.wantPanic { t.Errorf("case %s: expected \"%s\", got \"%s\"", c.name, c.wantPanic, s) } return } panic(r) }() // this calls matchingOf(c.data) straight away sx.Values(c.data) }() } }) t.Run("ColumnOf panics on unknown field", func(t *testing.T) { var testCases = []struct { name string data interface{} field string wantPanic string }{ { name: "unknown field", data: &test1{}, field: "Zzzzz", wantPanic: "sx: struct test1 has no usable field Zzzzz", }, { name: "ignored field", data: &test4{}, field: "A", wantPanic: "sx: struct test4 has no usable field A", }, { name: "unexported field", data: &test4{}, field: "d", wantPanic: "sx: struct test4 has no usable field d", }, } for _, c := range testCases { func() { defer func() { r := recover() if r == nil { t.Errorf("case %s: expected a panic", c.name) return } if s, ok := r.(string); ok { if s != c.wantPanic { t.Errorf("case %s: expected %q, got %q", c.name, c.wantPanic, s) } return } 
panic(r) }() // this calls matchingOf(c.data).columnOf(c.field) sx.ValueOf(c.data, c.field) }() } }) } ================================================ FILE: placeholder.go ================================================ package sx import "strconv" var numberedPlaceholders bool // SetNumberedPlaceholders sets the style of placeholders to be used for generated queries. If yes is true, then // postgres-style "$n" placeholders will be used for all future queries. If yes is false, then mysql-style "?" // placeholders will be used. The default is false. The setting is a package-level variable that is read without // synchronization, so change it only while no other goroutine is generating queries. func SetNumberedPlaceholders(yes bool) { numberedPlaceholders = yes } // A Placeholder is a generator for the currently selected placeholder type. See SetNumberedPlaceholders(). type Placeholder int // String displays the current placeholder value in its chosen format (either "?" or "$n"). func (p Placeholder) String() string { if numberedPlaceholders { return "$" + strconv.Itoa(int(p)) } return "?" } // Next increments the placeholder value and returns the string value of the next placeholder in sequence. // // When using numbered placeholders, a zero-valued placeholder will return "$1" on its first call to Next(). // When using ?-style placeholders, Next always returns "?". func (p *Placeholder) Next() string { *p++ return p.String() } ================================================ FILE: placeholder_test.go ================================================ package sx_test import ( "testing" sx "github.com/travelaudience/go-sx" ) func TestPlaceholders(t *testing.T) { sx.SetNumberedPlaceholders(false) t.Run("?
placeholders", func(t *testing.T) { want := []string{"?", "?", "?"} var p sx.Placeholder for i, x := range want { y := p.Next() if x != y { t.Errorf("case a-%d: expected %s, got %s", i, x, y) } } }) sx.SetNumberedPlaceholders(true) t.Run("numbered placeholders", func(t *testing.T) { want := []string{"$1", "$2", "$3"} var p sx.Placeholder for i, x := range want { y := p.Next() if x != y { t.Errorf("case b-%d: expected %s, got %s", i, x, y) } } }) } ================================================ FILE: tx.go ================================================ package sx import ( "context" "database/sql" ) // Tx extends sql.Tx with some Must*** methods that panic instead of returning an error code. Tx objects are used // inside of transactions managed by Do. Panics are caught by Do and returned as errors. type Tx struct { *sql.Tx } // An sxError is used to wrap errors that we want to send back to the caller of Do. type sxError struct { err error } // MustExec executes a query without returning any rows. The args are for any placeholder parameters in the query. // In case of error, the transaction is aborted and Do returns the error code. func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result { return tx.MustExecContext(context.Background(), query, args...) } // MustExecContext executes a query without returning any rows. The args are for any placeholder parameters in the // query. In case of error, the transaction is aborted and Do returns the error code. func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { res, err := tx.ExecContext(ctx, query, args...) if err != nil { panic(sxError{err}) } return res } // MustQuery executes a query that returns rows. The args are for any placeholder parameters in the query. // In case of error, the transaction is aborted and Do returns the error code. 
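Because Next increments the counter before formatting it, a single Placeholder can number every parameter in a hand-written clause. The sketch below copies the Placeholder type and its global style flag from placeholder.go into a standalone program. The `inList` helper is hypothetical (not part of go-sx) and builds an IN list that continues the numbering.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// numberedPlaceholders mirrors the package-level flag set by sx.SetNumberedPlaceholders.
var numberedPlaceholders = true

// Placeholder, String and Next are copied from placeholder.go.
type Placeholder int

func (p Placeholder) String() string {
	if numberedPlaceholders {
		return "$" + strconv.Itoa(int(p))
	}
	return "?"
}

func (p *Placeholder) Next() string {
	*p++
	return p.String()
}

// inList is a hypothetical helper: it emits "column IN (...)" with n placeholders,
// continuing the numbering from p.
func inList(column string, n int, p *Placeholder) string {
	parts := make([]string, n)
	for i := range parts {
		parts[i] = p.Next()
	}
	return column + " IN (" + strings.Join(parts, ",") + ")"
}

func main() {
	var p Placeholder // zero value: the first Next() yields $1
	cond := "status=" + p.Next()
	cond += " AND " + inList("id", 3, &p)
	fmt.Println(cond) // status=$1 AND id IN ($2,$3,$4)
}
```

With `numberedPlaceholders` set to false, the same code would produce `status=? AND id IN (?,?,?)`, so the caller's argument order is what ties values to placeholders in that style.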
func (tx *Tx) MustQuery(query string, args ...interface{}) *Rows { return tx.MustQueryContext(context.Background(), query, args...) } // MustQueryContext executes a query that returns rows. The args are for any placeholder parameters in the query. // In case of error, the transaction is aborted and Do returns the error code. func (tx *Tx) MustQueryContext(ctx context.Context, query string, args ...interface{}) *Rows { rows, err := tx.QueryContext(ctx, query, args...) if err != nil { panic(sxError{err}) } return &Rows{rows} } // MustQueryRow executes a query that is expected to return at most one row. MustQueryRow always returns a non-nil // value. Errors are deferred until one of the Row's scan methods is called. func (tx *Tx) MustQueryRow(query string, args ...interface{}) *Row { return &Row{tx.QueryRowContext(context.Background(), query, args...)} } // MustQueryRowContext executes a query that is expected to return at most one row. MustQueryRowContext always // returns a non-nil value. Errors are deferred until one of the Row's scan methods is called. func (tx *Tx) MustQueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { return &Row{tx.QueryRowContext(ctx, query, args...)} } // MustPrepare creates a prepared statement for later queries or executions. Multiple queries or executions may be // run concurrently from the returned statement. In case of error, the transaction is aborted and Do returns the // error code. // // The caller must call the statement's Close method when the statement is no longer needed. func (tx *Tx) MustPrepare(query string) *Stmt { return tx.MustPrepareContext(context.Background(), query) } // MustPrepareContext creates a prepared statement for later queries or executions. Multiple queries or executions // may be run concurrently from the returned statement. In case of error, the transaction is aborted and Do returns // the error code.
// // The caller must call the statement's Close method when the statement is no longer needed. func (tx *Tx) MustPrepareContext(ctx context.Context, query string) *Stmt { stmt, err := tx.PrepareContext(ctx, query) if err != nil { panic(sxError{err}) } return &Stmt{stmt} } // Fail aborts and rolls back the transaction, returning the given error code to the caller of Do. Fail always // rolls back the transaction, even if err is nil. func (tx *Tx) Fail(err error) { panic(sxError{err}) } // Stmt extends sql.Stmt with some Must*** methods that panic instead of returning an error code. Stmt objects are // used inside of transactions managed by Do. Panics are caught by Do and returned as errors. type Stmt struct { *sql.Stmt } // MustExec executes a prepared statement with the given arguments and returns an sql.Result summarizing the effect // of the statement. In case of error, the transaction is aborted and Do returns the error code. func (stmt *Stmt) MustExec(args ...interface{}) sql.Result { return stmt.MustExecContext(context.Background(), args...) } // MustExecContext executes a prepared statement with the given arguments and returns an sql.Result summarizing the // effect of the statement. In case of error, the transaction is aborted and Do returns the error code. func (stmt *Stmt) MustExecContext(ctx context.Context, args ...interface{}) sql.Result { res, err := stmt.ExecContext(ctx, args...) if err != nil { panic(sxError{err}) } return res } // MustQuery executes a prepared query statement with the given arguments and returns the query results as a *Rows. // In case of error, the transaction is aborted and Do returns the error code. func (stmt *Stmt) MustQuery(args ...interface{}) *Rows { return stmt.MustQueryContext(context.Background(), args...) } // MustQueryContext executes a prepared query statement with the given arguments and returns the query results as // a *Rows. In case of error, the transaction is aborted and Do returns the error code. 
func (stmt *Stmt) MustQueryContext(ctx context.Context, args ...interface{}) *Rows { rows, err := stmt.QueryContext(ctx, args...) if err != nil { panic(sxError{err}) } return &Rows{rows} } // MustQueryRow executes a prepared query that is expected to return at most one row. MustQueryRow always returns // a non-nil value. Errors are deferred until one of the Row's scan methods is called. func (stmt *Stmt) MustQueryRow(args ...interface{}) *Row { return &Row{stmt.QueryRowContext(context.Background(), args...)} } // MustQueryRowContext executes a prepared query that is expected to return at most one row. MustQueryRowContext // always returns a non-nil value. Errors are deferred until one of the Row's scan methods is called. func (stmt *Stmt) MustQueryRowContext(ctx context.Context, args ...interface{}) *Row { return &Row{stmt.QueryRowContext(ctx, args...)} } // Do runs a callback function f, providing f with the prepared statement, and then closing the prepared statement // after f returns. func (stmt *Stmt) Do(f func(*Stmt)) { defer stmt.Close() f(stmt) } // Row is the result of calling MustQueryRow to select a single row. Row extends sql.Row with some useful // scan methods. type Row struct { *sql.Row } // MustScan copies the columns in the current row into the values pointed at by dest. In case of error, the // transaction is aborted and Do returns the error code. func (row *Row) MustScan(dest ...interface{}) { err := row.Scan(dest...) if err != nil { panic(sxError{err}) } } // MustScans copies the columns in the current row into the struct pointed at by dest. In case of error, the // transaction is aborted and Do returns the error code. func (row *Row) MustScans(dest interface{}) { row.MustScan(Addrs(dest)...) } // Rows is the result of calling MustQuery to select a set of rows. Rows extends sql.Rows with some useful // scan methods. type Rows struct { *sql.Rows } // MustScan calls Scan to read in a row of the result set. 
In case of error, the transaction is aborted and Do // returns the error code. func (rows *Rows) MustScan(dest ...interface{}) { err := rows.Scan(dest...) if err != nil { panic(sxError{err}) } } // MustScans copies the columns in the current row into the struct pointed at by dest. In case of error, the // transaction is aborted and Do returns the error code. func (rows *Rows) MustScans(dest interface{}) { rows.MustScan(Addrs(dest)...) } // Each iterates over all of the rows in a result set and runs a callback function on each row. Each closes the // rows when it finishes. If iteration fails, the transaction is aborted and Do returns the error code. func (rows *Rows) Each(f func(*Rows)) { defer rows.Close() for rows.Next() { f(rows) } err := rows.Err() if err != nil { panic(sxError{err}) } } // Do runs the function f in a transaction. Within f, if Fail() is invoked or if any Must*** method encounters // an error, then the transaction is rolled back and Do returns the error. If f runs to completion, then the // transaction is committed, and Do returns the result of the commit (nil on success). // // Internally, the Must*** methods panic on error, and Fail() always panics. The panic aborts execution of f. // f should not attempt to recover from the panic. Instead, Do will catch the panic and return it as an error. // // A TxOptions may be provided to specify isolation level and/or read-only status. If no TxOptions is provided, // then the default options are used. Extra TxOptions are ignored. func Do(db *sql.DB, f func(*Tx), opts ...sql.TxOptions) error { return DoContext(context.Background(), db, f, opts...) } // DoContext runs the function f in a transaction. Within f, if Fail() is invoked or if any Must*** method encounters // an error, then the transaction is rolled back and DoContext returns the error. If f runs to completion, then the // transaction is committed, and DoContext returns the result of the commit (nil on success). 
// // Internally, the Must*** methods panic on error, and Fail() always panics. The panic aborts execution of f. // f should not attempt to recover from the panic. Instead, DoContext will catch the panic and return it as an error.
// // A TxOptions may be provided to specify isolation level and/or read-only status. If no TxOptions is provided, // then the default options are used. Extra TxOptions are ignored. func DoContext(ctx context.Context, db *sql.DB, f func(*Tx), opts ...sql.TxOptions) (err error) { var opt *sql.TxOptions if len(opts) > 0 { opt = &opts[0] } var tx *sql.Tx tx, err = db.BeginTx(ctx, opt) if err != nil { return } defer func() { if r := recover(); r != nil { if ourerr, ok := r.(sxError); ok { // Our panic. Unwrap it and return it as an error code. tx.Rollback() err = ourerr.err } else { // Not our panic. Roll back so the transaction is not left open, then propagate the panic. tx.Rollback() panic(r) } } }() // This runs the queries. f(&Tx{tx}) err = tx.Commit() return } ================================================ FILE: tx_test.go ================================================ package sx_test import ( "context" "database/sql" "errors" "math/rand" "os" "strings" "testing" sqlmock "github.com/DATA-DOG/go-sqlmock" sx "github.com/travelaudience/go-sx" ) func TestMain(m *testing.M) { os.Exit(m.Run()) } // helper functions func newMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock) { t.Helper() db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatalf("error creating mock database: %v", err) } return db, mock } func endMock(t *testing.T, mock sqlmock.Sqlmock) { t.Helper() err := mock.ExpectationsWereMet() if err != nil { t.Errorf("mocked expectations were not met: %v", err) } } func TestMustExec(t *testing.T) { t.Run("MustExec with result", func(t *testing.T) { db, mock := newMock(t) a, b := rand.Int63(), rand.Int63() const query = "SELECT alpha" mock.ExpectBegin() mock.ExpectExec(query).WillReturnResult(sqlmock.NewResult(a, b)) mock.ExpectCommit() err := sx.Do(db, func(tx *sx.Tx) { res := tx.MustExec(query) a0, _ := res.LastInsertId() b0, _ := res.RowsAffected() if a0 != a || b0 != b { t.Errorf("Expected result (%d, %d), got (%d, %d)", a, b, a0, b0) } }) if err != nil { t.Errorf("unexpected
error: %v", err) } endMock(t, mock) }) t.Run("MustExec with error", func(t *testing.T) { db, mock := newMock(t) const query = "SELECT bravo" err0 := errors.New("bravo error") mock.ExpectBegin() mock.ExpectExec(query).WillReturnError(err0) mock.ExpectRollback() err := sx.Do(db, func(tx *sx.Tx) { tx.MustExec(query) }) if err != err0 { t.Errorf("expected error %v, got %v", err0, err) } endMock(t, mock) }) t.Run("MustExec with 1 argument and result", func(t *testing.T) { db, mock := newMock(t) x, a, b := rand.Int63(), rand.Int63(), rand.Int63() const query = "SELECT charlie" mock.ExpectBegin() mock.ExpectExec(query).WithArgs(x).WillReturnResult(sqlmock.NewResult(a, b)) mock.ExpectCommit() err := sx.Do(db, func(tx *sx.Tx) { res := tx.MustExec(query, x) a0, _ := res.LastInsertId() b0, _ := res.RowsAffected() if a0 != a || b0 != b { t.Errorf("Expected result (%d, %d), got (%d, %d)", a, b, a0, b0) } }) if err != nil { t.Errorf("unexpected error: %v", err) } endMock(t, mock) }) t.Run("MustExec with 2 arguments and error", func(t *testing.T) { db, mock := newMock(t) x, y := rand.Int63(), rand.Int63() const query = "SELECT delta" err0 := errors.New("delta error") mock.ExpectBegin() mock.ExpectExec(query).WithArgs(x, y).WillReturnError(err0) mock.ExpectRollback() err := sx.Do(db, func(tx *sx.Tx) { tx.MustExec(query, x, y) }) if err != err0 { t.Errorf("expected error %v, got %v", err0, err) } endMock(t, mock) }) t.Run("MustExecContext with result", func(t *testing.T) { db, mock := newMock(t) a, b := rand.Int63(), rand.Int63() const query = "SELECT alpha_context" mock.ExpectBegin() mock.ExpectExec(query).WillReturnResult(sqlmock.NewResult(a, b)) mock.ExpectCommit() err := sx.Do(db, func(tx *sx.Tx) { res := tx.MustExecContext(context.Background(), query) a0, _ := res.LastInsertId() b0, _ := res.RowsAffected() if a0 != a || b0 != b { t.Errorf("Expected result (%d, %d), got (%d, %d)", a, b, a0, b0) } }) if err != nil { t.Errorf("unexpected error: %v", err) } endMock(t, mock) }) 
	t.Run("MustExec with isolation level and error", func(t *testing.T) {
		db, mock := newMock(t)
		const query = "SELECT bravissimo"
		err0 := errors.New("bravissimo error")
		mock.ExpectBegin()
		mock.ExpectExec(query).WillReturnError(err0)
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustExec(query)
		}, sql.TxOptions{Isolation: sql.LevelSerializable})
		if err != err0 {
			t.Errorf("expected error %v, got %v", err0, err)
		}
		endMock(t, mock)
	})
}

func TestMustQueryRow(t *testing.T) {
	t.Run("MustQueryRow with result", func(t *testing.T) {
		db, mock := newMock(t)
		a, b := rand.Int63(), rand.Int63()
		const query = "SELECT echo"
		rows := sqlmock.NewRows([]string{"a", "b"}).AddRow(a, b)
		mock.ExpectBegin()
		mock.ExpectQuery(query).WillReturnRows(rows)
		mock.ExpectCommit()
		err := sx.Do(db, func(tx *sx.Tx) {
			var a0, b0 int64
			tx.MustQueryRow(query).MustScan(&a0, &b0)
			if a0 != a || b0 != b {
				t.Errorf("Expected result (%d, %d), got (%d, %d)", a, b, a0, b0)
			}
		})
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		endMock(t, mock)
	})

	t.Run("MustQueryRow with no rows", func(t *testing.T) {
		db, mock := newMock(t)
		const query = "SELECT foxtrot"
		rows := sqlmock.NewRows([]string{"a", "b"})
		mock.ExpectBegin()
		mock.ExpectQuery(query).WillReturnRows(rows)
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			var a0, b0 int64
			tx.MustQueryRow(query).MustScan(&a0, &b0)
		})
		if err != sql.ErrNoRows {
			t.Errorf("expected error %v, got %v", sql.ErrNoRows, err)
		}
		endMock(t, mock)
	})

	t.Run("MustQueryRow with error", func(t *testing.T) {
		db, mock := newMock(t)
		const query = "SELECT golf"
		err0 := errors.New("golf error")
		mock.ExpectBegin()
		mock.ExpectQuery(query).WillReturnError(err0)
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			var a0, b0 int64
			tx.MustQueryRow(query).MustScan(&a0, &b0)
		})
		if err != err0 {
			t.Errorf("expected error %v, got %v", err0, err)
		}
		endMock(t, mock)
	})

	t.Run("MustQueryRow with 1 argument and error", func(t *testing.T) {
		db, mock := newMock(t)
		x := rand.Int63()
		const query = "SELECT hotel"
		err0 := errors.New("hotel error")
		mock.ExpectBegin()
		mock.ExpectQuery(query).WithArgs(x).WillReturnError(err0)
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			var a0, b0 int64
			tx.MustQueryRow(query, x).MustScan(&a0, &b0)
		})
		if err != err0 {
			t.Errorf("expected error %v, got %v", err0, err)
		}
		endMock(t, mock)
	})

	t.Run("MustQueryRow with 3 arguments and struct result", func(t *testing.T) {
		db, mock := newMock(t)
		a, b, x, y, z := rand.Int63(), rand.Int63(), rand.Int63(), rand.Int63(), rand.Int63()
		const query = "SELECT indigo"
		rows := sqlmock.NewRows([]string{"a", "b"}).AddRow(a, b)
		mock.ExpectBegin()
		mock.ExpectQuery(query).WithArgs(x, y, z).WillReturnRows(rows)
		mock.ExpectCommit()
		err := sx.Do(db, func(tx *sx.Tx) {
			var res struct{ A, B int64 }
			tx.MustQueryRow(query, x, y, z).MustScans(&res)
			if res.A != a || res.B != b {
				t.Errorf("Expected result (%d, %d), got (%d, %d)", a, b, res.A, res.B)
			}
		})
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		endMock(t, mock)
	})

	t.Run("MustQueryRowContext with result", func(t *testing.T) {
		db, mock := newMock(t)
		a, b := rand.Int63(), rand.Int63()
		const query = "SELECT echo_context"
		rows := sqlmock.NewRows([]string{"a", "b"}).AddRow(a, b)
		mock.ExpectBegin()
		mock.ExpectQuery(query).WillReturnRows(rows)
		mock.ExpectCommit()
		err := sx.Do(db, func(tx *sx.Tx) {
			var a0, b0 int64
			tx.MustQueryRowContext(context.TODO(), query).MustScan(&a0, &b0)
			if a0 != a || b0 != b {
				t.Errorf("Expected result (%d, %d), got (%d, %d)", a, b, a0, b0)
			}
		})
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		endMock(t, mock)
	})
}

func TestMustQuery(t *testing.T) {
	t.Run("MustQuery with error", func(t *testing.T) {
		db, mock := newMock(t)
		const query = "SELECT juliett"
		err0 := errors.New("juliett error")
		mock.ExpectBegin()
		mock.ExpectQuery(query).WillReturnError(err0)
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustQuery(query)
		})
		if err != err0 {
			t.Errorf("expected error %v, got %v", err0, err)
		}
		endMock(t, mock)
	})

	t.Run("MustQuery with 1 argument and error", func(t *testing.T) {
		db, mock := newMock(t)
		x := rand.Int63()
		const query = "SELECT kilo"
		err0 := errors.New("kilo error")
		mock.ExpectBegin()
		mock.ExpectQuery(query).WithArgs(x).WillReturnError(err0)
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustQuery(query, x)
		})
		if err != err0 {
			t.Errorf("expected error %v, got %v", err0, err)
		}
		endMock(t, mock)
	})

	t.Run("MustQuery with 1 argument and 1 result row", func(t *testing.T) {
		db, mock := newMock(t)
		a, b, x := rand.Int63(), rand.Int63(), rand.Int63()
		const query = "SELECT lima"
		rows := sqlmock.NewRows([]string{"a", "b"}).AddRow(a, b)
		mock.ExpectBegin()
		mock.ExpectQuery(query).WithArgs(x).WillReturnRows(rows)
		mock.ExpectCommit()
		var a0, b0 int64
		n := 0
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustQuery(query, x).Each(func(r *sx.Rows) {
				r.MustScan(&a0, &b0)
				n++
			})
		})
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if n != 1 {
			t.Errorf("Expected 1 row, got %d", n)
		} else if a0 != a || b0 != b {
			t.Errorf("Expected result (%d, %d), got (%d, %d)", a, b, a0, b0)
		}
		endMock(t, mock)
	})

	t.Run("MustQuery with 1 argument and 1 result row with error", func(t *testing.T) {
		db, mock := newMock(t)
		x := rand.Int63()
		const query = "SELECT mike"
		rows := sqlmock.NewRows([]string{"a", "b"}).AddRow("scan", "error")
		mock.ExpectBegin()
		mock.ExpectQuery(query).WithArgs(x).WillReturnRows(rows)
		mock.ExpectRollback()
		var a0, b0 int64
		n := 0
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustQuery(query, x).Each(func(r *sx.Rows) {
				r.MustScan(&a0, &b0)
				n++
			})
		})
		if n != 0 {
			t.Errorf("Expected no rows, got %d", n)
		} else if err == nil || !strings.Contains(err.Error(), "Scan error") {
			t.Errorf("unexpected error: %v", err)
		}
		endMock(t, mock)
	})

	t.Run("MustQuery with 1 argument and 2 struct result rows", func(t *testing.T) {
		type ab struct{ A, B int64 }
		db, mock := newMock(t)
		dat, x := [2]ab{{A: rand.Int63(), B: rand.Int63()}, {A: rand.Int63(), B: rand.Int63()}}, rand.Int63()
		const query = "SELECT november"
		rows := sqlmock.NewRows([]string{"a", "b"}).AddRow(dat[0].A, dat[0].B).AddRow(dat[1].A, dat[1].B)
		mock.ExpectBegin()
		mock.ExpectQuery(query).WithArgs(x).WillReturnRows(rows)
		mock.ExpectCommit()
		var res [2]ab
		n := 0
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustQuery(query, x).Each(func(r *sx.Rows) {
				r.MustScans(&res[n])
				n++
			})
		})
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if n != 2 {
			t.Errorf("Expected 2 rows, got %d", n)
		} else if res != dat {
			t.Errorf("Expected results (%d, %d), (%d, %d) got (%d, %d), (%d, %d)", dat[0].A, dat[0].B, dat[1].A, dat[1].B, res[0].A, res[0].B, res[1].A, res[1].B)
		}
		endMock(t, mock)
	})

	t.Run("MustQuery with 2 arguments, 2 result rows and row error", func(t *testing.T) {
		db, mock := newMock(t)
		a, b, x, y := [2]int64{rand.Int63(), 0}, [2]int64{rand.Int63(), 0}, rand.Int63(), rand.Int63()
		const query = "SELECT oscar"
		err0 := errors.New("oscar error")
		rows := sqlmock.NewRows([]string{"a", "b"}).AddRow(a[0], b[0]).AddRow(a[1], b[1]).RowError(1, err0)
		mock.ExpectBegin()
		mock.ExpectQuery(query).WithArgs(x, y).WillReturnRows(rows)
		mock.ExpectRollback()
		var aa, bb [2]int64
		n := 0
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustQuery(query, x, y).Each(func(r *sx.Rows) {
				r.MustScan(&aa[n], &bb[n])
				n++
			})
		})
		if n != 1 {
			t.Errorf("Expected 1 row before the row error, got %d", n)
		} else if err != err0 {
			t.Errorf("unexpected error: %v", err)
		} else if aa != a || bb != b {
			t.Errorf("Expected result (%d, %d) before the row error, got (%d, %d)", a[0], b[0], aa[0], bb[0])
		}
		endMock(t, mock)
	})

	t.Run("MustQuery with 2 arguments, 2 result rows and scan error", func(t *testing.T) {
		db, mock := newMock(t)
		a, b := [3]int64{rand.Int63(), rand.Int63(), 0}, [3]int64{rand.Int63(), rand.Int63(), 0}
		x, y := rand.Int63(), rand.Int63()
		const query = "SELECT papa"
		rows := sqlmock.NewRows([]string{"a", "b"}).AddRow(a[0], b[0]).AddRow(a[1], b[1]).AddRow("scan", "error")
		mock.ExpectBegin()
		mock.ExpectQuery(query).WithArgs(x, y).WillReturnRows(rows)
		mock.ExpectRollback()
		var aa, bb [3]int64
		n := 0
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustQuery(query, x, y).Each(func(r *sx.Rows) {
				r.MustScan(&aa[n], &bb[n])
				n++
			})
		})
		if n != 2 {
			t.Errorf("Expected 2 rows before the scan error, got %d", n)
		} else if err == nil || !strings.Contains(err.Error(), "Scan error") {
			t.Errorf("unexpected error: %v", err)
		} else if aa != a || bb != b {
			t.Errorf("Expected results (%d, %d), (%d, %d) before the scan error, got (%d, %d), (%d, %d)", a[0], b[0], a[1], b[1], aa[0], bb[0], aa[1], bb[1])
		}
		endMock(t, mock)
	})

	t.Run("MustQueryContext with error", func(t *testing.T) {
		db, mock := newMock(t)
		const query = "SELECT juliett_context"
		err0 := errors.New("juliett_context error")
		mock.ExpectBegin()
		mock.ExpectQuery(query).WillReturnError(err0)
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustQueryContext(context.TODO(), query)
		})
		if err != err0 {
			t.Errorf("expected error %v, got %v", err0, err)
		}
		endMock(t, mock)
	})
}

func TestMustPrepare(t *testing.T) {
	t.Run("MustPrepare with MustExec and result", func(t *testing.T) {
		db, mock := newMock(t)
		a, b, x := rand.Int63(), rand.Int63(), rand.Int63()
		const query = "SELECT quebec"
		mock.ExpectBegin()
		mock.ExpectPrepare(query).ExpectExec().WithArgs(x).WillReturnResult(sqlmock.NewResult(a, b))
		mock.ExpectCommit()
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustPrepare(query).Do(func(stmt *sx.Stmt) {
				res := stmt.MustExec(x)
				a0, _ := res.LastInsertId()
				b0, _ := res.RowsAffected()
				if a0 != a || b0 != b {
					t.Errorf("Expected result (%d, %d), got (%d, %d)", a, b, a0, b0)
				}
			})
		})
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		endMock(t, mock)
	})

	t.Run("MustPrepare with MustExec and error", func(t *testing.T) {
		db, mock := newMock(t)
		const query = "SELECT romeo"
		err0 := errors.New("romeo error")
		mock.ExpectBegin()
		mock.ExpectPrepare(query).ExpectExec().WillReturnError(err0)
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustPrepare(query).Do(func(stmt *sx.Stmt) {
				stmt.MustExec()
			})
		})
		if err != err0 {
			t.Errorf("expected error %v, got %v", err0, err)
		}
		endMock(t, mock)
	})

	t.Run("MustPrepare with MustQueryRow and result", func(t *testing.T) {
		db, mock := newMock(t)
		a, b, x := rand.Int63(), rand.Int63(), rand.Int63()
		const query = "SELECT sierra"
		rows := sqlmock.NewRows([]string{"a", "b"}).AddRow(a, b)
		mock.ExpectBegin()
		mock.ExpectPrepare(query).ExpectQuery().WithArgs(x).WillReturnRows(rows)
		mock.ExpectCommit()
		err := sx.Do(db, func(tx *sx.Tx) {
			var a0, b0 int64
			tx.MustPrepare(query).Do(func(stmt *sx.Stmt) {
				stmt.MustQueryRow(x).MustScan(&a0, &b0)
				if a0 != a || b0 != b {
					t.Errorf("Expected result (%d, %d), got (%d, %d)", a, b, a0, b0)
				}
			})
		})
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		endMock(t, mock)
	})

	t.Run("MustPrepare with MustQueryRow and error", func(t *testing.T) {
		db, mock := newMock(t)
		const query = "SELECT tango"
		err0 := errors.New("tango error")
		mock.ExpectBegin()
		mock.ExpectPrepare(query).ExpectQuery().WillReturnError(err0)
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			var a0, b0 int64
			tx.MustPrepare(query).Do(func(stmt *sx.Stmt) {
				stmt.MustQueryRow().MustScan(&a0, &b0)
			})
		})
		if err != err0 {
			t.Errorf("expected error %v, got %v", err0, err)
		}
		endMock(t, mock)
	})

	t.Run("MustPrepare with MustQuery and result", func(t *testing.T) {
		db, mock := newMock(t)
		a, b, x := rand.Int63(), rand.Int63(), rand.Int63()
		const query = "SELECT uniform"
		rows := sqlmock.NewRows([]string{"a", "b"}).AddRow(a, b)
		mock.ExpectBegin()
		mock.ExpectPrepare(query).ExpectQuery().WithArgs(x).WillReturnRows(rows)
		mock.ExpectCommit()
		var a0, b0 int64
		n := 0
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustPrepare(query).Do(func(stmt *sx.Stmt) {
				stmt.MustQuery(x).Each(func(r *sx.Rows) {
					r.MustScan(&a0, &b0)
					n++
				})
			})
		})
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if n != 1 {
			t.Errorf("Expected 1 row, got %d", n)
		} else if a0 != a || b0 != b {
			t.Errorf("Expected result (%d, %d), got (%d, %d)", a, b, a0, b0)
		}
		endMock(t, mock)
	})

	t.Run("MustPrepare with MustQuery and error", func(t *testing.T) {
		db, mock := newMock(t)
		const query = "SELECT victor"
		err0 := errors.New("victor error")
		mock.ExpectBegin()
		mock.ExpectPrepare(query).ExpectQuery().WillReturnError(err0)
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustPrepare(query).Do(func(stmt *sx.Stmt) {
				stmt.MustQuery()
			})
		})
		if err != err0 {
			t.Errorf("expected error %v, got %v", err0, err)
		}
		endMock(t, mock)
	})

	t.Run("MustPrepare with error", func(t *testing.T) {
		db, mock := newMock(t)
		const query = "SELECT whiskey"
		err0 := errors.New("whiskey error")
		mock.ExpectBegin()
		mock.ExpectPrepare(query).WillReturnError(err0)
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustPrepare(query)
		})
		if err != err0 {
			t.Errorf("expected error %v, got %v", err0, err)
		}
		endMock(t, mock)
	})

	t.Run("MustPrepareContext with MustQueryRowContext and result", func(t *testing.T) {
		db, mock := newMock(t)
		a, b, x := rand.Int63(), rand.Int63(), rand.Int63()
		const query = "SELECT sierra_context"
		rows := sqlmock.NewRows([]string{"a", "b"}).AddRow(a, b)
		mock.ExpectBegin()
		mock.ExpectPrepare(query).ExpectQuery().WithArgs(x).WillReturnRows(rows)
		mock.ExpectCommit()
		err := sx.Do(db, func(tx *sx.Tx) {
			var a0, b0 int64
			tx.MustPrepareContext(context.TODO(), query).Do(func(stmt *sx.Stmt) {
				stmt.MustQueryRowContext(context.TODO(), x).MustScan(&a0, &b0)
				if a0 != a || b0 != b {
					t.Errorf("Expected result (%d, %d), got (%d, %d)", a, b, a0, b0)
				}
			})
		})
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		endMock(t, mock)
	})

	t.Run("MustPrepareContext with MustQueryContext and error", func(t *testing.T) {
		db, mock := newMock(t)
		const query = "SELECT victor_context"
		err0 := errors.New("victor_context error")
		mock.ExpectBegin()
		mock.ExpectPrepare(query).ExpectQuery().WillReturnError(err0)
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.MustPrepareContext(context.TODO(), query).Do(func(stmt *sx.Stmt) {
				stmt.MustQueryContext(context.TODO())
			})
		})
		if err != err0 {
			t.Errorf("expected error %v, got %v", err0, err)
		}
		endMock(t, mock)
	})
}

func TestFail(t *testing.T) {
	t.Run("explicit fail", func(t *testing.T) {
		db, mock := newMock(t)
		err0 := errors.New("x-ray error")
		mock.ExpectBegin()
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.Fail(err0)
		})
		if err != err0 {
			t.Errorf("expected error %v, got %v", err0, err)
		}
		endMock(t, mock)
	})

	t.Run("begin transaction fail", func(t *testing.T) {
		db, mock := newMock(t)
		err0 := errors.New("yankee error")
		mock.ExpectBegin().WillReturnError(err0)
		err := sx.Do(db, func(tx *sx.Tx) {
			tx.Fail(errors.New("should never happen"))
		})
		if err != err0 {
			t.Errorf("expected error %v, got %v", err0, err)
		}
		endMock(t, mock)
	})

	t.Run("panic inside transaction", func(t *testing.T) {
		// This test ensures that an arbitrary panic inside a transaction is not erroneously caught by us and instead
		// gets propagated back up as a panic.
		db, mock := newMock(t)
		err0 := errors.New("zulu error")
		mock.ExpectBegin()
		var err error
		func() {
			defer func() {
				if r := recover(); r != nil {
					if e, ok := r.(error); ok {
						err = e
					}
				}
			}()
			sx.Do(db, func(tx *sx.Tx) {
				panic(err0)
			})
		}()
		if err != err0 {
			t.Errorf("expected panic %v, got %v", err0, err)
		}
		endMock(t, mock)
	})

	t.Run("explicit nil fail", func(t *testing.T) {
		db, mock := newMock(t)
		mock.ExpectBegin()
		mock.ExpectRollback()
		err := sx.Do(db, func(tx *sx.Tx) {
			// This should roll back the transaction and return a nil error
			tx.Fail(nil)
		})
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		endMock(t, mock)
	})
}