Repository: bokwoon95/sq Branch: main Commit: eae3b0c03361 Files: 44 Total size: 745.8 KB Directory structure: gitextract_lej4qbg4/ ├── .github/ │ ├── mddocs │ └── workflows/ │ ├── neocities.yml │ └── tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── START_HERE.md ├── builtins.go ├── builtins_test.go ├── colors.go ├── cte.go ├── cte_test.go ├── delete_query.go ├── delete_query_test.go ├── fetch_exec.go ├── fetch_exec_test.go ├── fields.go ├── fields_test.go ├── fmt.go ├── fmt_test.go ├── go.mod ├── go.sum ├── insert_query.go ├── insert_query_test.go ├── integration_test.go ├── internal/ │ ├── googleuuid/ │ │ └── googleuuid.go │ ├── pqarray/ │ │ └── pqarray.go │ └── testutil/ │ └── testutil.go ├── joins.go ├── joins_test.go ├── logger.go ├── logger_test.go ├── misc.go ├── misc_test.go ├── row_column.go ├── select_query.go ├── select_query_test.go ├── sq.go ├── sq.md ├── sq_test.go ├── update_query.go ├── update_query_test.go ├── window.go └── window_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/neocities.yml ================================================ name: Deploy docs to Neocities on: push: branches: [main] jobs: deploy_to_neocities: runs-on: ubuntu-latest steps: - name: Clone repo uses: actions/checkout@v3 - run: mkdir public && .github/mddocs sq.md public/sq.html - name: Deploy to neocities uses: bcomnes/deploy-to-neocities@v1 with: api_token: ${{ secrets.NEOCITIES_API_KEY }} dist_dir: public ================================================ FILE: .github/workflows/tests.yml ================================================ name: tests on: push: branches: [main] pull_request: branches: [main] jobs: run_sq_tests: runs-on: ubuntu-latest services: postgres: image: postgres env: POSTGRES_USER: 'user1' POSTGRES_PASSWORD: 'Hunter2!' 
POSTGRES_DB: 'sakila' options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 ports: - '5456:5432' mysql: image: mysql env: MYSQL_ROOT_PASSWORD: 'Hunter2!' MYSQL_USER: 'user1' MYSQL_PASSWORD: 'Hunter2!' MYSQL_DATABASE: 'sakila' options: >- --health-cmd "mysqladmin ping" --health-interval 10s --health-timeout 5s --health-retries 5 --health-start-period 30s ports: - '3330:3306' sqlserver: image: 'mcr.microsoft.com/azure-sql-edge' env: ACCEPT_EULA: 'Y' MSSQL_SA_PASSWORD: 'Hunter2!' options: >- --health-cmd "/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P Hunter2! -Q 'select 1' -b -o /dev/null" --health-interval 10s --health-timeout 5s --health-retries 5 --health-start-period 30s ports: - '1447:1433' steps: - name: Install go uses: actions/setup-go@v3 with: go-version: '>=1.18.0' - name: Clone repo uses: actions/checkout@v3 - run: go test . -tags=fts5 -failfast -shuffle on -coverprofile coverage -race -postgres 'postgres://user1:Hunter2!@localhost:5456/sakila?sslmode=disable' -mysql 'root:Hunter2!@tcp(localhost:3330)/sakila?multiStatements=true&parseTime=true' -sqlserver 'sqlserver://sa:Hunter2!@localhost:1447' - name: Convert coverage to coverage.lcov uses: jandelgado/gcov2lcov-action@v1.0.0 with: infile: coverage outfile: coverage.lcov - name: Upload coverage.lcov to Coveralls uses: coverallsapp/github-action@master with: github-token: ${{ secrets.GITHUB_TOKEN }} path-to-lcov: coverage.lcov ================================================ FILE: .gitignore ================================================ *.sqlite* .idea coverage.out coverage ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2022 Chua Bok Woon Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to 
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ [![GoDoc](https://img.shields.io/badge/pkg.go.dev-sq-blue)](https://pkg.go.dev/github.com/bokwoon95/sq) ![tests](https://github.com/bokwoon95/sq/actions/workflows/tests.yml/badge.svg?branch=main) [![Go Report Card](https://goreportcard.com/badge/github.com/bokwoon95/sq)](https://goreportcard.com/report/github.com/bokwoon95/sq) [![Coverage Status](https://coveralls.io/repos/github/bokwoon95/sq/badge.svg?branch=main)](https://coveralls.io/github/bokwoon95/sq?branch=main) code example of a select query using sq, to give viewers a quick idea of what the library is about # sq (Structured Query) [one-page documentation](https://bokwoon.neocities.org/sq.html) sq is a type-safe data mapper and query builder for Go. Its concept is simple: you provide a callback function that maps a row to a struct, generics ensure that you get back a slice of structs at the end. 
Additionally, mentioning a column in the callback function automatically adds it to the SELECT clause so you don't even have to explicitly mention what columns you want to select: the [act of mapping a column is the same as selecting it](#select-example-raw-sql). This eliminates a source of errors where you have to specify the columns twice (once in the query itself, once in the call to rows.Scan) and end up missing a column, getting the column order wrong or mistyping a column name. Notable features: - Works across SQLite, Postgres, MySQL and SQL Server. [[more info](https://bokwoon.neocities.org/sq.html#set-query-dialect)] - Each dialect has its own query builder, allowing you to use dialect-specific features. [[more info](https://bokwoon.neocities.org/sq.html#dialect-specific-features)] - Declarative schema migrations. [[more info](https://bokwoon.neocities.org/sq.html#declarative-schema)] - Supports arrays, enums, JSON and UUID. [[more info](https://bokwoon.neocities.org/sq.html#arrays-enums-json-uuid)] - Query logging. [[more info](https://bokwoon.neocities.org/sq.html#logging)] # Installation This package only supports Go 1.19 and above.
```shell $ go get github.com/bokwoon95/sq $ go install -tags=fts5 github.com/bokwoon95/sqddl@latest ``` # Features - IN - [In Slice](https://bokwoon.neocities.org/sq.html#in-slice) - `a IN (1, 2, 3)` - [In RowValues](https://bokwoon.neocities.org/sq.html#in-rowvalues) - `(a, b, c) IN ((1, 2, 3), (4, 5, 6), (7, 8, 9))` - [In Subquery](https://bokwoon.neocities.org/sq.html#in-subquery) - `(a, b) IN (SELECT a, b FROM tbl WHERE condition)` - CASE - [Predicate Case](https://bokwoon.neocities.org/sq.html#predicate-case) - `CASE WHEN a THEN b WHEN c THEN d ELSE e END` - [Simple case](https://bokwoon.neocities.org/sq.html#simple-case) - `CASE expr WHEN a THEN b WHEN c THEN d ELSE e END` - EXISTS - [Where Exists](https://bokwoon.neocities.org/sq.html#where-exists) - [Where Not Exists](https://bokwoon.neocities.org/sq.html#where-not-exists) - [Select Exists](https://bokwoon.neocities.org/sq.html#querybuilder-fetch-exists) - [Subqueries](https://bokwoon.neocities.org/sq.html#subqueries) - [WITH (Common Table Expressions)](https://bokwoon.neocities.org/sq.html#common-table-expressions) - [Aggregate functions](https://bokwoon.neocities.org/sq.html#aggregate-functions) - [Window functions](https://bokwoon.neocities.org/sq.html#window-functions) - [UNION, INTERSECT, EXCEPT](https://bokwoon.neocities.org/sq.html#union-intersect-except) - [INSERT from SELECT](https://bokwoon.neocities.org/sq.html#querybuilder-insert-from-select) - RETURNING - [SQLite RETURNING](https://bokwoon.neocities.org/sq.html#sqlite-returning) - [Postgres RETURNING](https://bokwoon.neocities.org/sq.html#postgres-returning) - LastInsertId - [SQLite LastInsertId](https://bokwoon.neocities.org/sq.html#sqlite-last-insert-id) - [MySQL LastInsertId](https://bokwoon.neocities.org/sq.html#mysql-last-insert-id) - Insert ignore duplicates - [SQLite Insert ignore duplicates](https://bokwoon.neocities.org/sq.html#sqlite-insert-ignore-duplicates) - [Postgres Insert ignore 
duplicates](https://bokwoon.neocities.org/sq.html#postgres-insert-ignore-duplicates) - [MySQL Insert ignore duplicates](https://bokwoon.neocities.org/sq.html#mysql-insert-ignore-duplicates) - [SQL Server Insert ignore duplicates](https://bokwoon.neocities.org/sq.html#sqlserver-insert-ignore-duplicates) - Upsert - [SQLite Upsert](https://bokwoon.neocities.org/sq.html#sqlite-upsert) - [Postgres Upsert](https://bokwoon.neocities.org/sq.html#postgres-upsert) - [MySQL Upsert](https://bokwoon.neocities.org/sq.html#mysql-upsert) - [SQL Server Upsert](https://bokwoon.neocities.org/sq.html#sqlserver-upsert) - Update with Join - [SQLite Update with Join](https://bokwoon.neocities.org/sq.html#sqlite-update-with-join) - [Postgres Update with Join](https://bokwoon.neocities.org/sq.html#postgres-update-with-join) - [MySQL Update with Join](https://bokwoon.neocities.org/sq.html#mysql-update-with-join) - [SQL Server Update with Join](https://bokwoon.neocities.org/sq.html#sqlserver-update-with-join) - Delete with Join - [SQLite Delete with Join](https://bokwoon.neocities.org/sq.html#sqlite-delete-with-join) - [Postgres Delete with Join](https://bokwoon.neocities.org/sq.html#postgres-delete-with-join) - [MySQL Delete with Join](https://bokwoon.neocities.org/sq.html#mysql-delete-with-join) - [SQL Server Delete with Join](https://bokwoon.neocities.org/sq.html#sqlserver-delete-with-join) - Bulk Update - [SQLite Bulk Update](https://bokwoon.neocities.org/sq.html#sqlite-bulk-update) - [Postgres Bulk Update](https://bokwoon.neocities.org/sq.html#postgres-bulk-update) - [MySQL Bulk Update](https://bokwoon.neocities.org/sq.html#mysql-bulk-update) - [SQL Server Bulk Update](https://bokwoon.neocities.org/sq.html#sqlserver-bulk-update) ## SELECT example (Raw SQL) ```go db, err := sql.Open("postgres", "postgres://username:password@localhost:5432/sakila?sslmode=disable") actors, err := sq.FetchAll(db, sq. Queryf("SELECT {*} FROM actor AS a WHERE a.actor_id IN ({})", []int{1, 2, 3, 4, 5}, ). 
SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { return Actor{ ActorID: row.Int("a.actor_id"), FirstName: row.String("a.first_name"), LastName: row.String("a.last_name"), LastUpdate: row.Time("a.last_update"), } }, ) ``` ## SELECT example (Query Builder) To use the query builder, you must first [define your table structs](https://bokwoon.neocities.org/sq.html#table-structs). ```go type ACTOR struct { sq.TableStruct ACTOR_ID sq.NumberField FIRST_NAME sq.StringField LAST_NAME sq.StringField LAST_UPDATE sq.TimeField } db, err := sql.Open("postgres", "postgres://username:password@localhost:5432/sakila?sslmode=disable") a := sq.New[ACTOR]("a") actors, err := sq.FetchAll(db, sq. From(a). Where(a.ACTOR_ID.In([]int{1, 2, 3, 4, 5})). SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), LastUpdate: row.TimeField(a.LAST_UPDATE), } }, ) ``` ## INSERT example (Raw SQL) ```go db, err := sql.Open("postgres", "postgres://username:password@localhost:5432/sakila?sslmode=disable") _, err := sq.Exec(db, sq. Queryf("INSERT INTO actor (actor_id, first_name, last_name) VALUES {}", sq.RowValues{ {18, "DAN", "TORN"}, {56, "DAN", "HARRIS"}, {166, "DAN", "STREEP"}, }). SetDialect(sq.DialectPostgres), ) ``` ## INSERT example (Query Builder) To use the query builder, you must first [define your table structs](https://bokwoon.neocities.org/sq.html#table-structs). ```go type ACTOR struct { sq.TableStruct ACTOR_ID sq.NumberField FIRST_NAME sq.StringField LAST_NAME sq.StringField LAST_UPDATE sq.TimeField } db, err := sql.Open("postgres", "postgres://username:password@localhost:5432/sakila?sslmode=disable") a := sq.New[ACTOR]("a") _, err := sq.Exec(db, sq. InsertInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). Values(18, "DAN", "TORN"). Values(56, "DAN", "HARRIS"). Values(166, "DAN", "STREEP"). 
SetDialect(sq.DialectPostgres), ) ``` For a more detailed overview, look at the [Quickstart](https://bokwoon.neocities.org/sq.html#quickstart). ## Project Status sq is done for my use case (hence it may seem inactive, but it's just complete). At this point I'm just waiting for people to ask questions or file feature requests under [discussions](https://github.com/bokwoon95/sq/discussions). ## Contributing See [START\_HERE.md](https://github.com/bokwoon95/sq/blob/main/START_HERE.md). ================================================ FILE: START_HERE.md ================================================ This document describes how the codebase is organized. It is meant for people who are contributing to the codebase (or are just casually browsing). Files are written in such a way that **each successive file in the list below only depends on files that come before it**. This self-enforced restriction makes deep architectural changes trivial because you can essentially blow away the entire codebase and rewrite it from scratch file-by-file, complete with working tests every step of the way. Please adhere to this file order when submitting pull requests. - [**sq.go**](https://github.com/bokwoon95/sq/blob/main/sq.go) - Core interfaces: SQLWriter, DB, Query, Table, PolicyTable, Window, Field, Predicate, Assignment, Any, Array, Binary, Boolean, Enum, JSON, Number, String, UUID, Time, Enumeration, DialectValuer, - Data types: Result, TableStruct, ViewStruct. - Misc utility functions. - [**fmt.go**](https://github.com/bokwoon95/sq/blob/main/fmt.go) - Two important string building functions that everything else is built on: [Writef](https://pkg.go.dev/github.com/bokwoon95/sq#Writef) and [WriteValue](https://pkg.go.dev/github.com/bokwoon95/sq#WriteValue). - Data types: Parameter, BinaryParameter, BooleanParameter, NumberParameter, StringParameter, TimeParameter. - Utility functions: QuoteIdentifier, EscapeQuote, Sprintf, Sprint. 
- [**builtins.go**](https://github.com/bokwoon95/sq/blob/main/builtins.go) - Builtin data types that are built on top of Writef and WriteValue: Expression (Expr), CustomQuery (Queryf), VariadicPredicate, assignment, RowValue, RowValues, Fields. - Builtin functions that are built on top of Writef and WriteValue: Eq, Ne, Lt, Le, Gt, Ge, Exists, NotExists, In. - [**fields.go**](https://github.com/bokwoon95/sq/blob/main/fields.go) - All of the field types: AnyField, ArrayField, BinaryField, BooleanField, EnumField, JSONField, NumberField, StringField, UUIDField, TimeField. - Data types: Identifier, Timestamp. - Functions: [New](https://pkg.go.dev/github.com/bokwoon95/sq#New), ArrayValue, EnumValue, JSONValue, UUIDValue. - [**cte.go**](https://github.com/bokwoon95/sq/blob/main/cte.go) - CTE represents an SQL common table expression (CTE). - UNION, INTERSECT, EXCEPT. - [**joins.go**](https://github.com/bokwoon95/sq/blob/main/joins.go) - The various SQL joins. - [**row_column.go**](https://github.com/bokwoon95/sq/blob/main/row_column.go) - Row and Column methods. - [**window.go**](https://github.com/bokwoon95/sq/blob/main/window.go) - SQL windows and window functions. - [**select_query.go**](https://github.com/bokwoon95/sq/blob/main/select_query.go) - SQL SELECT query builder. - [**insert_query.go**](https://github.com/bokwoon95/sq/blob/main/insert_query.go) - SQL INSERT query builder. - [**update_query.go**](https://github.com/bokwoon95/sq/blob/main/update_query.go) - SQL UPDATE query builder. - [**delete_query.go**](https://github.com/bokwoon95/sq/blob/main/delete_query.go) - SQL DELETE query builder. - [**logger.go**](https://github.com/bokwoon95/sq/blob/main/logger.go) - sq.Log and sq.VerboseLog. - [**fetch_exec.go**](https://github.com/bokwoon95/sq/blob/main/fetch_exec.go) - FetchCursor, FetchOne, FetchAll, Exec. - CompiledFetch, CompiledExec. - PreparedFetch, PreparedExec. - [**misc.go**](https://github.com/bokwoon95/sq/blob/main/misc.go) - Misc SQL constructs. 
- ValueExpression, LiteralValue, DialectExpression, CaseExpression, SimpleCaseExpression. - SelectValues (`SELECT ... UNION ALL SELECT ... UNION ALL SELECT ...`) - TableValues (`VALUES (...), (...), (...)`). - [**integration_test.go**](https://github.com/bokwoon95/sq/blob/main/integration_test.go) - Tests that interact with a live database i.e. SQLite, Postgres, MySQL and SQL Server. ## Testing Add tests if you add code. To run tests, use: ```shell $ go test . # -failfast -shuffle=on -coverprofile=coverage ``` There are tests that require a live database connection. They will only run if you provide the corresponding database URL in the test flags: ```shell $ go test . -postgres $POSTGRES_URL -mysql $MYSQL_URL -sqlserver $SQLSERVER_URL # -failfast -shuffle=on -coverprofile=coverage ``` You can consider using the [docker-compose.yml defined in the sqddl repo](https://github.com/bokwoon95/sqddl/blob/main/docker-compose.yml) to spin up Postgres, MySQL and SQL Server databases that are reachable at the following URLs: ```shell # docker-compose up -d POSTGRES_URL='postgres://user1:Hunter2!@localhost:5456/sakila?sslmode=disable' MYSQL_URL='root:Hunter2!@tcp(localhost:3330)/sakila?multiStatements=true&parseTime=true' MARIADB_URL='root:Hunter2!@tcp(localhost:3340)/sakila?multiStatements=true&parseTime=true' SQLSERVER_URL='sqlserver://sa:Hunter2!@localhost:1447' ``` ## Documentation Documentation is contained entirely within [sq.md](https://github.com/bokwoon95/sq/blob/main/sq.md) in the project root directory. You can view the output at [https://bokwoon.neocities.org/sq.html](https://bokwoon.neocities.org/sq.html). The documentation is regenerated every time a new commit is pushed to the main branch, so to change the documentation just change sq.md and submit a pull request. You can preview the output of sq.md locally by installing [github.com/bokwoon95/mddocs](https://github.com/bokwoon95/mddocs) and running it with sq.md as the argument.
```shell $ go install github.com/bokwoon95/mddocs@latest $ mddocs Usage: mddocs project.md # serves project.md on a localhost connection mddocs project.md project.html # render project.md into project.html $ mddocs sq.md serving sq.md at localhost:6060 ``` To add a new section and register it in the table of contents, append a `#headerID` to the end of a header (replace `headerID` with the actual header ID). The header ID should only contain Unicode letters, digits, hyphen `-` and underscore `_`. ```text ## This is a header. ## This is a header with a headerID. #header-id <-- added to table of contents ``` ================================================ FILE: builtins.go ================================================ package sq import ( "bytes" "context" "fmt" "strings" ) // Expression is an SQL expression that satisfies the Table, Field, Predicate, // Binary, Boolean, Number, String and Time interfaces. type Expression struct { format string values []any alias string } var _ interface { Table Field Predicate Any Assignment } = (*Expression)(nil) // Expr creates a new Expression using Writef syntax. func Expr(format string, values ...any) Expression { return Expression{format: format, values: values} } // WriteSQL implements the SQLWriter interface. func (expr Expression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { err := Writef(ctx, dialect, buf, args, params, expr.format, expr.values) if err != nil { return err } return nil } // As returns a new Expression with the given alias. func (expr Expression) As(alias string) Expression { expr.alias = alias return expr } // In returns an 'expr IN (value)' Predicate. func (expr Expression) In(value any) Predicate { return In(expr, value) } // NotIn returns an 'expr NOT IN (value)' Predicate. func (expr Expression) NotIn(value any) Predicate { return NotIn(expr, value) } // Eq returns an 'expr = value' Predicate.
func (expr Expression) Eq(value any) Predicate { return cmp("=", expr, value) } // Ne returns an 'expr <> value' Predicate. func (expr Expression) Ne(value any) Predicate { return cmp("<>", expr, value) } // Lt returns an 'expr < value' Predicate. func (expr Expression) Lt(value any) Predicate { return cmp("<", expr, value) } // Le returns an 'expr <= value' Predicate. func (expr Expression) Le(value any) Predicate { return cmp("<=", expr, value) } // Gt returns an 'expr > value' Predicate. func (expr Expression) Gt(value any) Predicate { return cmp(">", expr, value) } // Ge returns an 'expr >= value' Predicate. func (expr Expression) Ge(value any) Predicate { return cmp(">=", expr, value) } // GetAlias returns the alias of the Expression. func (expr Expression) GetAlias() string { return expr.alias } // IsTable implements the Table interface. func (expr Expression) IsTable() {} // IsField implements the Field interface. func (expr Expression) IsField() {} // IsArray implements the Array interface. func (expr Expression) IsArray() {} // IsBinary implements the Binary interface. func (expr Expression) IsBinary() {} // IsBoolean implements the Boolean interface. func (expr Expression) IsBoolean() {} // IsEnum implements the Enum interface. func (expr Expression) IsEnum() {} // IsJSON implements the JSON interface. func (expr Expression) IsJSON() {} // IsNumber implements the Number interface. func (expr Expression) IsNumber() {} // IsString implements the String interface. func (expr Expression) IsString() {} // IsTime implements the Time interface. func (expr Expression) IsTime() {} // IsUUID implements the UUID interface. func (expr Expression) IsUUID() {} func (e Expression) IsAssignment() {} // CustomQuery represents a user-defined query. type CustomQuery struct { Dialect string Format string Values []any fields []Field } var _ Query = (*CustomQuery)(nil) // Queryf creates a new query using Writef syntax. 
func Queryf(format string, values ...any) CustomQuery {
	return CustomQuery{Format: format, Values: values}
}

// Queryf creates a new SQLite query using Writef syntax.
func (b sqliteQueryBuilder) Queryf(format string, values ...any) CustomQuery {
	return CustomQuery{Dialect: DialectSQLite, Format: format, Values: values}
}

// Queryf creates a new Postgres query using Writef syntax.
func (b postgresQueryBuilder) Queryf(format string, values ...any) CustomQuery {
	return CustomQuery{Dialect: DialectPostgres, Format: format, Values: values}
}

// Queryf creates a new MySQL query using Writef syntax.
func (b mysqlQueryBuilder) Queryf(format string, values ...any) CustomQuery {
	return CustomQuery{Dialect: DialectMySQL, Format: format, Values: values}
}

// Queryf creates a new SQL Server query using Writef syntax.
func (b sqlserverQueryBuilder) Queryf(format string, values ...any) CustomQuery {
	return CustomQuery{Dialect: DialectSQLServer, Format: format, Values: values}
}

// Append returns a new CustomQuery with the format string (preceded by a
// single space) and values slice appended to the current CustomQuery.
//
// The receiver is never modified. The values are always appended to a fresh
// copy of Values. Otherwise several Appends on the same base query (or on a
// query built from a caller-owned slice via Queryf(format, vals...)) could
// write into one shared backing array and overwrite each other's values.
func (q CustomQuery) Append(format string, values ...any) CustomQuery {
	q.Format += " " + format
	// The full slice expression caps capacity at the current length, so
	// append must allocate a new backing array instead of reusing spare
	// capacity that other queries (or the caller) may still be referencing.
	q.Values = append(q.Values[:len(q.Values):len(q.Values)], values...)
	return q
}

// WriteSQL implements the SQLWriter interface.
func (q CustomQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { var err error format := q.Format splitAt := -1 for i := strings.IndexByte(format, '{'); i >= 0; i = strings.IndexByte(format, '{') { if i+2 <= len(format) && format[i:i+2] == "{{" { format = format[i+2:] continue } if i+3 <= len(format) && format[i:i+3] == "{*}" { splitAt = len(q.Format) - len(format[i:]) break } format = format[i+1:] } if splitAt < 0 { return Writef(ctx, dialect, buf, args, params, q.Format, q.Values) } runningValuesIndex := 0 ordinalIndices := make(map[int]int) err = writef(ctx, dialect, buf, args, params, q.Format[:splitAt], q.Values, &runningValuesIndex, ordinalIndices) if err != nil { return err } err = writeFields(ctx, dialect, buf, args, params, q.fields, true) if err != nil { return err } err = writef(ctx, dialect, buf, args, params, q.Format[splitAt+3:], q.Values, &runningValuesIndex, ordinalIndices) if err != nil { return err } return nil } // SetFetchableFields sets the fetchable fields of the query. func (q CustomQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { format := q.Format for i := strings.IndexByte(format, '{'); i >= 0; i = strings.IndexByte(format, '{') { if i+2 <= len(format) && format[i:i+2] == "{{" { format = format[i+2:] continue } if i+3 <= len(format) && format[i:i+3] == "{*}" { q.fields = fields return q, true } format = format[i+1:] } return q, false } // GetFetchableFields gets the fetchable fields of the query. func (q CustomQuery) GetFetchableFields() []Field { return q.fields } // GetDialect gets the dialect of the query. func (q CustomQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the query. func (q CustomQuery) SetDialect(dialect string) CustomQuery { q.Dialect = dialect return q } // VariadicPredicate represents the 'x AND y AND z...' or 'x OR Y OR z...' SQL // construct. 
type VariadicPredicate struct { // Toplevel indicates if the VariadicPredicate can skip writing the // (surrounding brackets). Toplevel bool alias string // If IsDisjunction is true, the Predicates are joined using OR. If false, // the Predicates are joined using AND. The default is AND. IsDisjunction bool // Predicates holds the predicates inside the VariadicPredicate Predicates []Predicate } var _ Predicate = (*VariadicPredicate)(nil) // And joins the predicates together with the AND operator. func And(predicates ...Predicate) VariadicPredicate { return VariadicPredicate{IsDisjunction: false, Predicates: predicates} } // Or joins the predicates together with the OR operator. func Or(predicates ...Predicate) VariadicPredicate { return VariadicPredicate{IsDisjunction: true, Predicates: predicates} } // WriteSQL implements the SQLWriter interface. func (p VariadicPredicate) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { var err error if len(p.Predicates) == 0 { return fmt.Errorf("VariadicPredicate empty") } if len(p.Predicates) == 1 { switch p1 := p.Predicates[0].(type) { case nil: return fmt.Errorf("predicate #1 is nil") case VariadicPredicate: p1.Toplevel = p.Toplevel err = p1.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return err } default: err = p.Predicates[0].WriteSQL(ctx, dialect, buf, args, params) if err != nil { return err } } return nil } if !p.Toplevel { buf.WriteString("(") } for i, predicate := range p.Predicates { if i > 0 { if p.IsDisjunction { buf.WriteString(" OR ") } else { buf.WriteString(" AND ") } } switch predicate := predicate.(type) { case nil: return fmt.Errorf("predicate #%d is nil", i+1) case VariadicPredicate: predicate.Toplevel = false err = predicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("predicate #%d: %w", i+1, err) } default: err = predicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return 
fmt.Errorf("predicate #%d: %w", i+1, err) } } } if !p.Toplevel { buf.WriteString(")") } return nil } // As returns a new VariadicPredicate with the given alias. func (p VariadicPredicate) As(alias string) VariadicPredicate { p.alias = alias return p } // GetAlias returns the alias of the VariadicPredicate. func (p VariadicPredicate) GetAlias() string { return p.alias } // IsField implements the Field interface. func (p VariadicPredicate) IsField() {} // IsBooleanType implements the Predicate interface. func (p VariadicPredicate) IsBoolean() {} // assignment represents assigning a value to a Field. type assignment struct { field Field value any } var _ Assignment = (*assignment)(nil) // Set creates a new Assignment assigning the value to a field. func Set(field Field, value any) Assignment { return assignment{field: field, value: value} } // Setf creates a new Assignment assigning a custom expression to a Field. func Setf(field Field, format string, values ...any) Assignment { return assignment{field: field, value: Expr(format, values...)} } // WriteSQL implements the SQLWriter interface. func (a assignment) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { if a.field == nil { return fmt.Errorf("field is nil") } var err error if dialect == DialectMySQL { err = a.field.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return err } } else { err = withPrefix(a.field, "").WriteSQL(ctx, dialect, buf, args, params) if err != nil { return err } } buf.WriteString(" = ") _, isQuery := a.value.(Query) if isQuery { buf.WriteString("(") } err = WriteValue(ctx, dialect, buf, args, params, a.value) if err != nil { return err } if isQuery { buf.WriteString(")") } return nil } // IsAssignment implements the Assignment interface. func (a assignment) IsAssignment() {} // Assignments represents a list of Assignments e.g. x = 1, y = 2, z = 3. 
type Assignments []Assignment // WriteSQL implements the SQLWriter interface. func (as Assignments) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { var err error for i, assignment := range as { if assignment == nil { return fmt.Errorf("assignment #%d is nil", i+1) } if i > 0 { buf.WriteString(", ") } err = assignment.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("assignment #%d: %w", i+1, err) } } return nil } // RowValue represents an SQL row value expression e.g. (x, y, z). type RowValue []any // WriteSQL implements the SQLWriter interface. func (r RowValue) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { buf.WriteString("(") var err error for i, value := range r { if i > 0 { buf.WriteString(", ") } _, isQuery := value.(Query) if isQuery { buf.WriteString("(") } err = WriteValue(ctx, dialect, buf, args, params, value) if err != nil { return fmt.Errorf("rowvalue #%d: %w", i+1, err) } if isQuery { buf.WriteString(")") } } buf.WriteString(")") return nil } // In returns an 'rowvalue IN (value)' Predicate. func (r RowValue) In(v any) Predicate { return In(r, v) } // NotIn returns an 'rowvalue NOT IN (value)' Predicate. func (r RowValue) NotIn(v any) Predicate { return NotIn(r, v) } // Eq returns an 'rowvalue = value' Predicate. func (r RowValue) Eq(v any) Predicate { return cmp("=", r, v) } // RowValues represents a list of RowValues e.g. (x, y, z), (a, b, c). type RowValues []RowValue // WriteSQL implements the SQLWriter interface. func (rs RowValues) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { var err error for i, r := range rs { if i > 0 { buf.WriteString(", ") } err = r.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("rowvalues #%d: %w", i+1, err) } } return nil } // Fields represents a list of Fields e.g. 
tbl.field1, tbl.field2, tbl.field3. type Fields []Field // WriteSQL implements the SQLWriter interface. func (fs Fields) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { var err error for i, field := range fs { if field == nil { return fmt.Errorf("field #%d is nil", i+1) } if i > 0 { buf.WriteString(", ") } _, isQuery := field.(Query) if isQuery { buf.WriteString("(") } err = field.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("field #%d: %w", i+1, err) } if isQuery { buf.WriteString(")") } } return nil } type ( sqliteQueryBuilder struct{ ctes []CTE } postgresQueryBuilder struct{ ctes []CTE } mysqlQueryBuilder struct{ ctes []CTE } sqlserverQueryBuilder struct{ ctes []CTE } ) // Dialect-specific query builder variables. var ( SQLite sqliteQueryBuilder Postgres postgresQueryBuilder MySQL mysqlQueryBuilder SQLServer sqlserverQueryBuilder ) // With sets the CTEs in the SQLiteQueryBuilder. func (b sqliteQueryBuilder) With(ctes ...CTE) sqliteQueryBuilder { b.ctes = ctes return b } // With sets the CTEs in the PostgresQueryBuilder. func (b postgresQueryBuilder) With(ctes ...CTE) postgresQueryBuilder { b.ctes = ctes return b } // With sets the CTEs in the MySQLQueryBuilder. func (b mysqlQueryBuilder) With(ctes ...CTE) mysqlQueryBuilder { b.ctes = ctes return b } // With sets the CTEs in the SQLServerQueryBuilder. func (b sqlserverQueryBuilder) With(ctes ...CTE) sqlserverQueryBuilder { b.ctes = ctes return b } // ToSQL converts an SQLWriter into a query string and args slice. // // The params map is used to hold the mappings between named parameters in the // query to the corresponding index in the args slice and is used for rebinding // args by their parameter name. If you don't need to track this, you can pass // in a nil map. 
func ToSQL(dialect string, w SQLWriter, params map[string][]int) (query string, args []any, err error) {
	return ToSQLContext(context.Background(), dialect, w, params)
}

// ToSQLContext is like ToSQL but additionally requires a context.Context.
//
// When dialect is empty and w is a Query, the Query's own dialect is used.
// The query text is returned even when rendering fails, alongside the error.
func ToSQLContext(ctx context.Context, dialect string, w SQLWriter, params map[string][]int) (query string, args []any, err error) {
	if w == nil {
		return "", nil, fmt.Errorf("SQLWriter is nil")
	}
	if q, isQuery := w.(Query); isQuery && dialect == "" {
		dialect = q.GetDialect()
	}
	buf := bufpool.Get().(*bytes.Buffer)
	defer bufpool.Put(buf)
	buf.Reset()
	err = w.WriteSQL(ctx, dialect, buf, &args, params)
	// buf.String() copies the bytes, and the return values are evaluated
	// before the deferred Put runs, so handing buf back to the pool is safe.
	return buf.String(), args, err
}

// Eq returns an 'x = y' Predicate.
func Eq(x, y any) Predicate { return cmp("=", x, y) }

// Ne returns an 'x <> y' Predicate.
func Ne(x, y any) Predicate { return cmp("<>", x, y) }

// Lt returns an 'x < y' Predicate.
func Lt(x, y any) Predicate { return cmp("<", x, y) }

// Le returns an 'x <= y' Predicate.
func Le(x, y any) Predicate { return cmp("<=", x, y) }

// Gt returns an 'x > y' Predicate.
func Gt(x, y any) Predicate { return cmp(">", x, y) }

// Ge returns an 'x >= y' Predicate.
func Ge(x, y any) Predicate { return cmp(">=", x, y) }

// Exists returns an 'EXISTS (query)' Predicate.
func Exists(query Query) Predicate { return Expr("EXISTS ({})", query) }

// NotExists returns a 'NOT EXISTS (query)' Predicate.
func NotExists(query Query) Predicate { return Expr("NOT EXISTS ({})", query) }

// In returns an 'x IN (y)' Predicate.
//
// A Query on the left is parenthesized. A RowValue on the right already
// writes its own parentheses, so it is not wrapped a second time.
func In(x, y any) Predicate {
	_, xIsQuery := x.(Query)
	_, yIsRowValue := y.(RowValue)
	lhs := "{}"
	if xIsQuery {
		lhs = "({})"
	}
	rhs := "({})"
	if yIsRowValue {
		rhs = "{}"
	}
	return Expr(lhs+" IN "+rhs, x, y)
}

// NotIn returns an 'x NOT IN (y)' Predicate.
func NotIn(x, y any) Predicate { _, isQueryA := x.(Query) _, isRowValueB := y.(RowValue) if !isQueryA && !isRowValueB { return Expr("{} NOT IN ({})", x, y) } else if !isQueryA && isRowValueB { return Expr("{} NOT IN {}", x, y) } else if isQueryA && !isRowValueB { return Expr("({}) NOT IN ({})", x, y) } else { return Expr("({}) NOT IN {}", x, y) } } // cmp returns an 'x y' Predicate. func cmp(operator string, x, y any) Expression { _, isQueryA := x.(Query) _, isQueryB := y.(Query) if !isQueryA && !isQueryB { return Expr("{} "+operator+" {}", x, y) } else if !isQueryA && isQueryB { return Expr("{} "+operator+" ({})", x, y) } else if isQueryA && !isQueryB { return Expr("({}) "+operator+" {}", x, y) } else { return Expr("({}) "+operator+" ({})", x, y) } } // appendPolicy will append a policy from a Table (if it implements // PolicyTable) to a slice of policies. The resultant slice is returned. func appendPolicy(ctx context.Context, dialect string, policies []Predicate, table Table) ([]Predicate, error) { policyTable, ok := table.(PolicyTable) if !ok { return policies, nil } policy, err := policyTable.Policy(ctx, dialect) if err != nil { return nil, err } if policy != nil { policies = append(policies, policy) } return policies, nil } // appendPredicates will append a slices of predicates into a predicate. func appendPredicates(predicate Predicate, predicates []Predicate) VariadicPredicate { if predicate == nil { return And(predicates...) } if p1, ok := predicate.(VariadicPredicate); ok && !p1.IsDisjunction { p1.Predicates = append(p1.Predicates, predicates...) 
return p1 } p2 := VariadicPredicate{Predicates: make([]Predicate, 1+len(predicates))} p2.Predicates[0] = predicate copy(p2.Predicates[1:], predicates) return p2 } func writeTop(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, topLimit, topPercentLimit any, withTies bool) error { var err error if topLimit != nil { buf.WriteString("TOP (") err = WriteValue(ctx, dialect, buf, args, params, topLimit) if err != nil { return fmt.Errorf("TOP: %w", err) } buf.WriteString(") ") } else if topPercentLimit != nil { buf.WriteString("TOP (") err = WriteValue(ctx, dialect, buf, args, params, topPercentLimit) if err != nil { return fmt.Errorf("TOP PERCENT: %w", err) } buf.WriteString(") PERCENT ") } if (topLimit != nil || topPercentLimit != nil) && withTies { buf.WriteString("WITH TIES ") } return nil } ================================================ FILE: builtins_test.go ================================================ package sq import ( "bytes" "context" "database/sql" "errors" "testing" "github.com/bokwoon95/sq/internal/testutil" ) func TestExpression(t *testing.T) { t.Run("schema, name and alias", func(t *testing.T) { t.Parallel() expr := Expr("COUNT(*)").As("total") if diff := testutil.Diff(expr.GetAlias(), "total"); diff != "" { t.Error(testutil.Callers(), diff) } }) tests := []TestTable{{ description: "basic", dialect: DialectSQLServer, item: Expr("CONCAT(CONCAT(name, {}), {})", "abc", sql.Named("xyz", "def")), wantQuery: "CONCAT(CONCAT(name, @p1), @xyz)", wantArgs: []any{"abc", sql.Named("xyz", "def")}, wantParams: map[string][]int{"xyz": {1}}, }, { description: "In", item: Expr("age").In([]int{18, 21, 32}), wantQuery: "age IN (?, ?, ?)", wantArgs: []any{18, 21, 32}, }, { description: "NotIn", item: Expr("age").NotIn([]int{18, 21, 32}), wantQuery: "age NOT IN (?, ?, ?)", wantArgs: []any{18, 21, 32}, }, { description: "Eq", item: Expr("age").Eq(34), wantQuery: "age = ?", wantArgs: []any{34}, }, { description: "Ne", item: 
Expr("age").Ne(34), wantQuery: "age <> ?", wantArgs: []any{34}, }, { description: "Lt", item: Expr("age").Lt(34), wantQuery: "age < ?", wantArgs: []any{34}, }, { description: "Le", item: Expr("age").Le(34), wantQuery: "age <= ?", wantArgs: []any{34}, }, { description: "Gt", item: Expr("age").Gt(34), wantQuery: "age > ?", wantArgs: []any{34}, }, { description: "Ge", item: Expr("age").Ge(34), wantQuery: "age >= ?", wantArgs: []any{34}, }, { description: "Exists", item: Exists(Queryf("SELECT 1 FROM tbl WHERE 1 = 1")), wantQuery: "EXISTS (SELECT 1 FROM tbl WHERE 1 = 1)", }, { description: "NotExists", item: NotExists(Queryf("SELECT 1 FROM tbl WHERE 1 = 1")), wantQuery: "NOT EXISTS (SELECT 1 FROM tbl WHERE 1 = 1)", }, { description: "Count", item: Count(Expr("name")), wantQuery: "COUNT(name)", }, { description: "CountStar", item: CountStar(), wantQuery: "COUNT(*)", }, { description: "Sum", item: Sum(Expr("score")), wantQuery: "SUM(score)", }, { description: "Avg", item: Avg(Expr("score")), wantQuery: "AVG(score)", }, { description: "Min", item: Min(Expr("score")), wantQuery: "MIN(score)", }, { description: "Max", item: Max(Expr("score")), wantQuery: "MAX(score)", }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } t.Run("FaultySQL", func(t *testing.T) { t.Parallel() TestTable{item: Expr("SELECT {}", FaultySQL{})}.assertNotOK(t) }) } func TestVariadicPredicate(t *testing.T) { t.Run("name and alias", func(t *testing.T) { t.Parallel() p := And(Expr("True"), Expr("FALSE")).As("is_false") if diff := testutil.Diff(p.GetAlias(), "is_false"); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("empty", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = VariadicPredicate{} tt.assertNotOK(t) }) t.Run("nil predicate", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = And(nil) tt.assertNotOK(t) }) t.Run("1 predicate", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = And(cmp("=", 
Expr("score"), 21)) tt.wantQuery = "score = ?" tt.wantArgs = []any{21} tt.assert(t) }) t.Run("2 predicate", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = And(cmp("=", Expr("score"), 21), cmp("=", Expr("name"), "bob")) tt.wantQuery = "(score = ? AND name = ?)" tt.wantArgs = []any{21, "bob"} tt.assert(t) }) t.Run("multiple nested VariadicPredicate collapses into one", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = And(And(And(And(cmp("=", Expr("score"), 21))))) tt.wantQuery = "score = ?" tt.wantArgs = []any{21} tt.assert(t) }) t.Run("multiple predicates", func(t *testing.T) { t.Parallel() var tt TestTable user_id, name, age := Expr("user_id"), Expr("name"), Expr("age") tt.item = Or( Expr("{} IS NULL", name), cmp("=", age, age), And(cmp("=", age, age)), And( cmp("=", user_id, 1), cmp("<>", user_id, 2), cmp("<", user_id, 3), cmp("<=", user_id, 4), cmp(">", user_id, 5), cmp(">=", user_id, 6), ), ) tt.wantQuery = "(name IS NULL" + " OR age = age" + " OR age = age" + " OR (" + "user_id = ?" + " AND user_id <> ?" + " AND user_id < ?" + " AND user_id <= ?" + " AND user_id > ?" + " AND user_id >= ?" 
+ "))" tt.wantArgs = []any{1, 2, 3, 4, 5, 6} tt.assert(t) }) t.Run("multiple predicates with nil", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = Or( Expr("1 = 1"), And(Expr("TRUE"), Predicate(nil)), ) tt.assertNotOK(t) }) t.Run("VariadicPredicate alias", func(t *testing.T) { t.Parallel() p1 := And(Expr("TRUE")).As("abc") if diff := testutil.Diff(p1.GetAlias(), "abc"); diff != "" { t.Error(testutil.Callers(), diff) } p2 := p1.As("def") if diff := testutil.Diff(p1.GetAlias(), "abc"); diff != "" { t.Error(testutil.Callers(), diff) } if diff := testutil.Diff(p2.GetAlias(), "def"); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("VariadicPredicate FaultySQL", func(t *testing.T) { t.Parallel() var tt TestTable // AND, 1 predicate tt.item = And(FaultySQL{}) tt.assertErr(t, ErrFaultySQL) // AND, multiple predicates tt.item = And(Expr("FALSE"), FaultySQL{}) tt.assertErr(t, ErrFaultySQL) // nested AND tt.item = And(And(FaultySQL{})) tt.assertErr(t, ErrFaultySQL) }) } func TestQueryf(t *testing.T) { t.Run("basic", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = Queryf("SELECT {field} FROM {table} WHERE {predicate}", sql.Named("field", Expr("name")), sql.Named("table", Expr("users")), sql.Named("predicate", Expr("user_id = {}", 5)), ) tt.wantQuery = "SELECT name FROM users WHERE user_id = ?" tt.wantArgs = []any{5} tt.assert(t) }) t.Run("select star", func(t *testing.T) { t.Parallel() var tt TestTable q := Queryf("SELECT {*} FROM {table} WHERE {predicate}", sql.Named("table", Expr("users")), sql.Named("predicate", Expr("user_id = {}", 5)), ) if diff := testutil.Diff(q.GetDialect(), ""); diff != "" { t.Error(testutil.Callers(), diff) } q2, ok := q.SetFetchableFields([]Field{Expr("name"), Expr("age")}) if !ok { t.Fatal(testutil.Callers(), "not ok") } tt.item = q2 tt.wantQuery = "SELECT name, age FROM users WHERE user_id = ?" 
tt.wantArgs = []any{5} tt.assert(t) }) t.Run("escape curly brace", func(t *testing.T) { t.Parallel() var tt TestTable q := Queryf(`WITH cte AS (SELECT '{{*}' AS name) SELECT {*} FROM cte`) q2, ok := q.SetFetchableFields([]Field{Expr("name")}) if !ok { t.Fatal(testutil.Callers(), "not ok") } tt.item = q2 tt.wantQuery = "WITH cte AS (SELECT '{*}' AS name) SELECT name FROM cte" tt.assert(t) }) t.Run("mixed", func(t *testing.T) { t.Parallel() var tt TestTable q := Queryf(`{1} {3} {name} {} SELECT {*} FROM {1} {} {name} {}`, 5, sql.Named("name", "bob"), 10) q2, ok := q.SetFetchableFields([]Field{Expr("alpha"), Expr("SUBSTR({}, {})", "apple", 77), Expr("beta")}) if !ok { t.Fatal(testutil.Callers(), "not ok") } tt.dialect = DialectPostgres tt.item = q2 tt.wantQuery = "$1 $2 $3 $4 SELECT alpha, SUBSTR($5, $6), beta FROM $1 $3 $3 $7" tt.wantArgs = []any{5, 10, "bob", 5, "apple", 77, 10} tt.wantParams = map[string][]int{"name": {2}} tt.assert(t) }) t.Run("append", func(t *testing.T) { t.Parallel() var tt TestTable q := Queryf("SELECT {*} FROM tbl WHERE 1 = 1") q = q.Append("AND name = {}", "bob") q = q.Append("AND email = {email}", sql.Named("email", "bob@email.com")) q = q.Append("AND age = {age}", sql.Named("age", 27)) q2, ok := q.SetFetchableFields([]Field{Expr("name"), Expr("email")}) if !ok { t.Fatal(testutil.Callers(), "not ok") } tt.item = q2 tt.wantQuery = "SELECT name, email FROM tbl WHERE 1 = 1 AND name = ? AND email = ? AND age = ?" 
tt.wantArgs = []any{"bob", "bob@email.com", 27} tt.wantParams = map[string][]int{"email": {1}, "age": {2}} tt.assert(t) }) } func TestAssign(t *testing.T) { t.Run("AssignValue nil field", func(t *testing.T) { t.Parallel() _, _, err := ToSQL("", Set(nil, 1), nil) if err == nil { t.Error(testutil.Callers(), "expected error but got nil") } }) t.Run("AssignValue", func(t *testing.T) { t.Parallel() TestTable{ item: Set(tmpfield("tbl.field"), 1), wantQuery: "field = ?", wantArgs: []any{1}, }.assert(t) }) t.Run("mysql AssignValue", func(t *testing.T) { t.Parallel() TestTable{ dialect: DialectMySQL, item: Set(tmpfield("tbl.field"), 1), wantQuery: "tbl.field = ?", wantArgs: []any{1}, }.assert(t) }) t.Run("AssignValue err", func(t *testing.T) { t.Parallel() TestTable{ item: Set(tmpfield("tbl.field"), FaultySQL{}), }.assertErr(t, ErrFaultySQL) }) t.Run("Assignf nil field", func(t *testing.T) { t.Parallel() _, _, err := ToSQL("", Setf(nil, ""), nil) if err == nil { t.Error(testutil.Callers(), "expected error but got nil") } }) t.Run("Assignf", func(t *testing.T) { t.Parallel() TestTable{ item: Setf(tmpfield("tbl.field"), "EXCLUDED.field"), wantQuery: "field = EXCLUDED.field", }.assert(t) }) t.Run("Assignf err", func(t *testing.T) { t.Parallel() TestTable{ item: Setf(tmpfield("tbl.field"), "EXCLUDED.{}", FaultySQL{}), }.assertErr(t, ErrFaultySQL) }) } func TestRowValuesFieldsAssignments(t *testing.T) { tests := []TestTable{{ description: "RowValue", item: RowValue{1, 2, 3}, wantQuery: "(?, ?, ?)", wantArgs: []any{1, 2, 3}, }, { description: "RowValue with query", item: RowValue{1, 2, Queryf("SELECT {}", 3)}, wantQuery: "(?, ?, (SELECT ?))", wantArgs: []any{1, 2, 3}, }, { description: "RowValue In", item: RowValue{1, 2, 3}.In(RowValues{{4, 5, 6}, {7, 8, 9}}), wantQuery: "(?, ?, ?) IN ((?, ?, ?), (?, ?, ?))", wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9}, }, { description: "RowValue NotIn", item: RowValue{1, 2, 3}.NotIn(RowValues{{4, 5, 6}, {7, 8, 9}}), wantQuery: "(?, ?, ?) 
NOT IN ((?, ?, ?), (?, ?, ?))", wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9}, }, { description: "RowValue Eq", item: RowValue{1, 2, 3}.Eq(RowValue{4, 5, 6}), wantQuery: "(?, ?, ?) = (?, ?, ?)", wantArgs: []any{1, 2, 3, 4, 5, 6}, }, { description: "Fields", item: Fields{Expr("tbl.f1"), Expr("tbl.f2"), Expr("tbl.f3")}, wantQuery: "tbl.f1, tbl.f2, tbl.f3", }, { description: "Assignments", item: Assignments{Set(Expr("f1"), 1), Set(Expr("f2"), 2), Set(Expr("f3"), 3)}, wantQuery: "f1 = ?, f2 = ?, f3 = ?", wantArgs: []any{1, 2, 3}, }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } t.Run("nil fields/assignments", func(t *testing.T) { t.Parallel() // Fields TestTable{item: Fields{Expr("f1"), Expr("f2"), nil}}.assertNotOK(t) // Assignments TestTable{item: Assignments{Set(Expr("f1"), 1), Set(Expr("f2"), 2), nil}}.assertNotOK(t) }) errTests := []TestTable{{ description: "RowValue err", item: RowValue{1, 2, FaultySQL{}}, }, { description: "RowValues err", item: RowValues{{1, 2, FaultySQL{}}}, }, { description: "Fields err", item: Fields{Expr("f1"), Expr("f2"), FaultySQL{}}, }, { description: "Assignments err", item: Assignments{Set(Expr("f1"), 1), Set(Expr("f2"), 2), FaultySQL{}}, }} for _, tt := range errTests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assertErr(t, ErrFaultySQL) }) } } func TestToSQL(t *testing.T) { t.Run("basic", func(t *testing.T) { gotQuery, _, err := ToSQL("", Queryf("SELECT {fields} FROM {table}", sql.Named("fields", Fields{Expr("f1"), Expr("f2"), Expr("f3")}), sql.Named("table", Expr("tbl")), ), nil) if err != nil { t.Error(testutil.Callers(), err) } wantQuery := "SELECT f1, f2, f3 FROM tbl" if diff := testutil.Diff(gotQuery, wantQuery); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("nil SQLWriter", func(t *testing.T) { _, _, err := ToSQL("", nil, nil) if err == nil { t.Error(testutil.Callers(), "expected error but got nil") } }) t.Run("err", func(t 
*testing.T) { _, _, err := ToSQL("", Queryf("SELECT {}", FaultySQL{}), nil) if !errors.Is(err, ErrFaultySQL) { t.Errorf(testutil.Callers()+"expected '%v' but got '%v'", ErrFaultySQL, err) } }) } func Test_in_cmp(t *testing.T) { tests := []TestTable{{ description: "!Query IN !RowValue", item: In(Expr("{}", "tom"), Queryf("SELECT name FROM users WHERE name LIKE {}", "t%")), wantQuery: "? IN (SELECT name FROM users WHERE name LIKE ?)", wantArgs: []any{"tom", "t%"}, }, { description: "!Query IN RowValue", item: In(Expr("name"), RowValue{"tom", "dick", "harry"}), wantQuery: "name IN (?, ?, ?)", wantArgs: []any{"tom", "dick", "harry"}, }, { description: "Query IN !RowValue", item: In(Queryf("SELECT {}", "tom"), []string{"tom", "dick", "harry"}), wantQuery: "(SELECT ?) IN (?, ?, ?)", wantArgs: []any{"tom", "tom", "dick", "harry"}, }, { description: "Query IN RowValue", item: In(Queryf("SELECT {}", "tom"), RowValue{"tom", "dick", "harry"}), wantQuery: "(SELECT ?) IN (?, ?, ?)", wantArgs: []any{"tom", "tom", "dick", "harry"}, }, { description: "!Query NOT IN !RowValue", item: NotIn(Expr("{}", "tom"), Queryf("SELECT name FROM users WHERE name LIKE {}", "t%")), wantQuery: "? NOT IN (SELECT name FROM users WHERE name LIKE ?)", wantArgs: []any{"tom", "t%"}, }, { description: "!Query NOT IN RowValue", item: NotIn(Expr("name"), RowValue{"tom", "dick", "harry"}), wantQuery: "name NOT IN (?, ?, ?)", wantArgs: []any{"tom", "dick", "harry"}, }, { description: "Query NOT IN !RowValue", item: NotIn(Queryf("SELECT {}", "tom"), []string{"tom", "dick", "harry"}), wantQuery: "(SELECT ?) NOT IN (?, ?, ?)", wantArgs: []any{"tom", "tom", "dick", "harry"}, }, { description: "Query NOT IN RowValue", item: NotIn(Queryf("SELECT {}", "tom"), RowValue{"tom", "dick", "harry"}), wantQuery: "(SELECT ?) NOT IN (?, ?, ?)", wantArgs: []any{"tom", "tom", "dick", "harry"}, }, { description: "!Query = !Query", item: cmp("=", 1, 1), wantQuery: "? 
= ?", wantArgs: []any{1, 1}, }, { description: "!Query = Query", item: cmp("=", Expr("score"), Queryf("SELECT score FROM users WHERE id = {}", 5)), wantQuery: "score = (SELECT score FROM users WHERE id = ?)", wantArgs: []any{5}, }, { description: "Query = !Query", item: cmp("=", Queryf("SELECT score FROM users WHERE id = {}", 5), Expr("{}", 7)), wantQuery: "(SELECT score FROM users WHERE id = ?) = ?", wantArgs: []any{5, 7}, }, { description: "Query = Query", item: cmp("=", Queryf("SELECT 1"), Queryf("SELECT 2")), wantQuery: "(SELECT 1) = (SELECT 2)", }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } } type policyTableStub struct { policy Predicate err error } var _ PolicyTable = (*policyTableStub)(nil) func (tbl policyTableStub) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { buf.WriteString("policy_table_stub") return nil } func (tbl policyTableStub) GetAlias() string { return "" } func (tbl policyTableStub) Policy(ctx context.Context, dialect string) (Predicate, error) { return tbl.policy, tbl.err } func (tbl policyTableStub) IsTable() {} func Test_appendPolicy(t *testing.T) { type TT struct { description string table Table wantPolicies []Predicate wantErr error } tests := []TT{{ description: "table doesn't implement PolicyTable", table: Expr("tbl"), }, { description: "PolicyTable returns err", table: policyTableStub{err: ErrFaultySQL}, wantErr: ErrFaultySQL, }, { description: "PolicyTable returns policy", table: policyTableStub{policy: Expr("TRUE")}, wantPolicies: []Predicate{Expr("TRUE")}, }, { description: "PolicyTable returns nil policy", table: policyTableStub{}, }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() policies, err := appendPolicy(context.Background(), "", nil, tt.table) if !errors.Is(err, tt.wantErr) { t.Errorf(testutil.Callers()+"expected error '%v' but got '%v'", tt.wantErr, 
err) } if diff := testutil.Diff(policies, tt.wantPolicies); diff != "" { t.Error(testutil.Callers(), diff) } }) } } func Test_appendPredicates(t *testing.T) { type TT struct { description string predicate Predicate predicates []Predicate wantPredicate VariadicPredicate } p1, p2, p3 := Expr("p1"), Expr("p2"), Expr("p3") tests := []TT{{ description: "nil predicate", predicate: nil, predicates: []Predicate{p2, p3}, wantPredicate: And(p2, p3), }, { description: "AND predicate", predicate: And(p1), predicates: []Predicate{p2, p3}, wantPredicate: And(p1, p2, p3), }, { description: "OR predicate", predicate: Or(p1), predicates: []Predicate{p2, p3}, wantPredicate: And(Or(p1), p2, p3), }, { description: "non-VariadicPredicate predicate", predicate: p1, predicates: []Predicate{p2, p3}, wantPredicate: And(p1, p2, p3), }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() gotPredicate := appendPredicates(tt.predicate, tt.predicates) if diff := testutil.Diff(gotPredicate, tt.wantPredicate); diff != "" { t.Error(testutil.Callers(), diff) } }) } } func Test_writeTop(t *testing.T) { type TT struct { description string topLimit any topPercentLimit any withTies bool wantQuery string wantArgs []any wantParams map[string][]int } t.Run("err", func(t *testing.T) { t.Parallel() ctx := context.Background() dialect := DialectSQLServer buf := &bytes.Buffer{} args := &[]any{} params := make(map[string][]int) // topLimit err := writeTop(ctx, dialect, buf, args, params, FaultySQL{}, nil, false) if !errors.Is(err, ErrFaultySQL) { t.Errorf(testutil.Callers()+"expected error '%v' but got '%v'", ErrFaultySQL, err) } // topPercentLimit err = writeTop(ctx, dialect, buf, args, params, nil, FaultySQL{}, false) if !errors.Is(err, ErrFaultySQL) { t.Errorf(testutil.Callers()+"expected error '%v' but got '%v'", ErrFaultySQL, err) } }) tests := []TT{{ description: "empty", topLimit: nil, topPercentLimit: nil, withTies: true, wantQuery: "", wantArgs: nil, wantParams: 
nil, }, { description: "TOP n", topLimit: 5, wantQuery: "TOP (@p1) ", wantArgs: []any{5}, }, { description: "TOP n PERCENT", topPercentLimit: 10, wantQuery: "TOP (@p1) PERCENT ", wantArgs: []any{10}, }, { description: "TOP n WITH TIES", topLimit: 5, withTies: true, wantQuery: "TOP (@p1) WITH TIES ", wantArgs: []any{5}, }, { description: "TOP expr", topLimit: Expr("5"), wantQuery: "TOP (5) ", }, { description: "TOP param", topLimit: IntParam("limit", 5), wantQuery: "TOP (@limit) ", wantArgs: []any{sql.Named("limit", 5)}, wantParams: map[string][]int{"limit": {0}}, }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() ctx := context.Background() dialect := DialectSQLServer buf := &bytes.Buffer{} args := &[]any{} params := make(map[string][]int) err := writeTop(ctx, dialect, buf, args, params, tt.topLimit, tt.topPercentLimit, tt.withTies) if err != nil { t.Error(testutil.Callers(), err) } if diff := testutil.Diff(buf.String(), tt.wantQuery); diff != "" { t.Error(testutil.Callers(), diff) } if diff := testutil.Diff(*args, tt.wantArgs); diff != "" { t.Error(testutil.Callers(), diff) } if len(params) > 0 || len(tt.wantParams) > 0 { if diff := testutil.Diff(params, tt.wantParams); diff != "" { t.Error(testutil.Callers(), diff) } } }) } } // TODO: remove all uses of TestTable, manually write each assert. Fewer levels // of BS indirection when you need to debug something. 
type TestTable struct { description string ctx context.Context dialect string item any wantQuery string wantArgs []any wantParams map[string][]int } func (tt TestTable) assert(t *testing.T) { if tt.ctx == nil { tt.ctx = context.Background() } if tt.dialect == "" { if query, ok := tt.item.(Query); ok { tt.dialect = query.GetDialect() } } buf := &bytes.Buffer{} args := &[]any{} params := make(map[string][]int) err := WriteValue(tt.ctx, tt.dialect, buf, args, params, tt.item) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(buf.String(), tt.wantQuery); diff != "" { t.Error(testutil.Callers(), diff) } if len(*args) > 0 || len(tt.wantArgs) > 0 { if diff := testutil.Diff(*args, tt.wantArgs); diff != "" { t.Error(testutil.Callers(), diff) } } if len(params) > 0 || len(tt.wantParams) > 0 { if diff := testutil.Diff(params, tt.wantParams); diff != "" { t.Error(testutil.Callers(), diff) } } } func (tt TestTable) assertErr(t *testing.T, wantErr error) { if tt.ctx == nil { tt.ctx = context.Background() } if tt.dialect == "" { if query, ok := tt.item.(Query); ok { tt.dialect = query.GetDialect() } } buf := &bytes.Buffer{} args := &[]any{} params := make(map[string][]int) gotErr := WriteValue(tt.ctx, tt.dialect, buf, args, params, tt.item) if !errors.Is(gotErr, wantErr) { t.Fatalf(testutil.Callers()+"expected error '%v' but got '%v'", wantErr, gotErr) } } func (tt TestTable) assertNotOK(t *testing.T) { if tt.ctx == nil { tt.ctx = context.Background() } if tt.dialect == "" { if query, ok := tt.item.(Query); ok { tt.dialect = query.GetDialect() } } buf := &bytes.Buffer{} args := &[]any{} params := make(map[string][]int) gotErr := WriteValue(tt.ctx, tt.dialect, buf, args, params, tt.item) if gotErr == nil { t.Fatal(testutil.Callers(), "expected error but got nil") } } ================================================ FILE: colors.go ================================================ //go:build windows package sq import ( "os" "syscall" ) func init() { // 
https://stackoverflow.com/a/69542231 const ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x4 var stderrMode uint32 stderr := syscall.Handle(os.Stderr.Fd()) syscall.GetConsoleMode(stderr, &stderrMode) syscall.MustLoadDLL("kernel32").MustFindProc("SetConsoleMode").Call(uintptr(stderr), uintptr(stderrMode|ENABLE_VIRTUAL_TERMINAL_PROCESSING)) var stdoutMode uint32 stdout := syscall.Handle(os.Stdout.Fd()) syscall.GetConsoleMode(stdout, &stdoutMode) syscall.MustLoadDLL("kernel32").MustFindProc("SetConsoleMode").Call(uintptr(stdout), uintptr(stdoutMode|ENABLE_VIRTUAL_TERMINAL_PROCESSING)) } ================================================ FILE: cte.go ================================================ package sq import ( "bytes" "context" "database/sql" "fmt" ) // CTE represents an SQL common table expression (CTE). type CTE struct { query Query columns []string recursive bool materialized sql.NullBool name string alias string } var _ Table = (*CTE)(nil) // NewCTE creates a new CTE. func NewCTE(name string, columns []string, query Query) CTE { return CTE{name: name, columns: columns, query: query} } // NewRecursiveCTE creates a new recursive CTE. func NewRecursiveCTE(name string, columns []string, query Query) CTE { return CTE{name: name, columns: columns, query: query, recursive: true} } // WriteSQL implements the SQLWriter interface. func (cte CTE) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { buf.WriteString(QuoteIdentifier(dialect, cte.name)) return nil } // As returns a new CTE with the given alias. func (cte CTE) As(alias string) CTE { cte.alias = alias return cte } // Materialized returns a new CTE marked as MATERIALIZED. This only works on // postgres. func (cte CTE) Materialized() CTE { cte.materialized.Valid = true cte.materialized.Bool = true return cte } // Materialized returns a new CTE marked as NOT MATERIALIZED. This only works // on postgres. 
func (cte CTE) NotMaterialized() CTE { cte.materialized.Valid = true cte.materialized.Bool = false return cte } // Field returns a Field from the CTE. func (cte CTE) Field(name string) AnyField { return NewAnyField(name, NewTableStruct("", cte.name, cte.alias)) } // GetAlias returns the alias of the CTE. func (cte CTE) GetAlias() string { return cte.alias } // AssertTable implements the Table interface. func (cte CTE) IsTable() {} func writeCTEs(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, ctes []CTE) error { var hasRecursiveCTE bool for _, cte := range ctes { if cte.recursive { hasRecursiveCTE = true break } } if hasRecursiveCTE { buf.WriteString("WITH RECURSIVE ") } else { buf.WriteString("WITH ") } for i, cte := range ctes { if i > 0 { buf.WriteString(", ") } if cte.name == "" { return fmt.Errorf("CTE #%d has no name", i+1) } buf.WriteString(QuoteIdentifier(dialect, cte.name)) if len(cte.columns) > 0 { buf.WriteString(" (") for j, column := range cte.columns { if j > 0 { buf.WriteString(", ") } buf.WriteString(QuoteIdentifier(dialect, column)) } buf.WriteString(")") } buf.WriteString(" AS ") if dialect == DialectPostgres && cte.materialized.Valid { if cte.materialized.Bool { buf.WriteString("MATERIALIZED ") } else { buf.WriteString("NOT MATERIALIZED ") } } buf.WriteString("(") switch query := cte.query.(type) { case nil: return fmt.Errorf("CTE #%d query is nil", i+1) case VariadicQuery: query.Toplevel = true err := query.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("CTE #%d failed to build query: %w", i+1, err) } default: err := query.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("CTE #%d failed to build query: %w", i+1, err) } } buf.WriteString(")") } buf.WriteString(" ") return nil } // VariadicQueryOperator represents a variadic query operator. type VariadicQueryOperator string // VariadicQuery operators. 
const ( QueryUnion VariadicQueryOperator = "UNION" QueryUnionAll VariadicQueryOperator = "UNION ALL" QueryIntersect VariadicQueryOperator = "INTERSECT" QueryIntersectAll VariadicQueryOperator = "INTERSECT ALL" QueryExcept VariadicQueryOperator = "EXCEPT" QueryExceptAll VariadicQueryOperator = "EXCEPT ALL" ) // VariadicQuery represents the 'x UNION y UNION z...' etc SQL constructs. type VariadicQuery struct { Toplevel bool Operator VariadicQueryOperator Queries []Query } var _ Query = (*VariadicQuery)(nil) // Union joins the queries together with the UNION operator. func Union(queries ...Query) VariadicQuery { return VariadicQuery{Operator: QueryUnion, Queries: queries} } // UnionAll joins the queries together with the UNION ALL operator. func UnionAll(queries ...Query) VariadicQuery { return VariadicQuery{Operator: QueryUnionAll, Queries: queries} } // Intersect joins the queries together with the INTERSECT operator. func Intersect(queries ...Query) VariadicQuery { return VariadicQuery{Operator: QueryIntersect, Queries: queries} } // IntersectAll joins the queries together with the INTERSECT ALL operator. func IntersectAll(queries ...Query) VariadicQuery { return VariadicQuery{Operator: QueryIntersectAll, Queries: queries} } // Except joins the queries together with the EXCEPT operator. func Except(queries ...Query) VariadicQuery { return VariadicQuery{Operator: QueryExcept, Queries: queries} } // ExceptAll joins the queries together with the EXCEPT ALL operator. func ExceptAll(queries ...Query) VariadicQuery { return VariadicQuery{Operator: QueryExceptAll, Queries: queries} } // WriteSQL implements the SQLWriter interface. 
func (q VariadicQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { var err error if q.Operator == "" { q.Operator = QueryUnion } if len(q.Queries) == 0 { return fmt.Errorf("VariadicQuery empty") } if len(q.Queries) == 1 { switch q1 := q.Queries[0].(type) { case nil: return fmt.Errorf("query #1 is nil") case VariadicQuery: q1.Toplevel = q.Toplevel err = q1.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return err } default: err = q.Queries[0].WriteSQL(ctx, dialect, buf, args, params) if err != nil { return err } } return nil } if !q.Toplevel { buf.WriteString("(") } for i, query := range q.Queries { if i > 0 { buf.WriteString(" " + string(q.Operator) + " ") } if query == nil { return fmt.Errorf("query #%d is nil", i+1) } err = query.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("query #%d: %w", i+1, err) } } if !q.Toplevel { buf.WriteString(")") } return nil } // SetFetchableFields implements the Query interface. func (q VariadicQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { return q, false } // GetFetchableFields implements the Query interface. func (q VariadicQuery) GetFetchableFields() []Field { return nil } // GetDialect returns the SQL dialect of the VariadicQuery. 
func (q VariadicQuery) GetDialect() string {
	// The dialect comes from the first query only; later queries are not
	// consulted. An empty list or a nil first query yields "".
	if len(q.Queries) == 0 {
		return ""
	}
	q1 := q.Queries[0]
	if q1 == nil {
		return ""
	}
	return q1.GetDialect()
}

================================================
FILE: cte_test.go
================================================
package sq

import (
	"bytes"
	"context"
	"database/sql"
	"errors"
	"testing"

	"github.com/bokwoon95/sq/internal/testutil"
)

// TestCTE checks the accessors of a single CTE: rendering, Field, the
// materialized flag and the alias.
func TestCTE(t *testing.T) {
	t.Run("basic", func(t *testing.T) {
		// Materialized() followed by NotMaterialized(): the later call wins,
		// as the materialized assertion below expects.
		cte := NewCTE("cte", []string{"n"}, Queryf("SELECT 1")).Materialized().NotMaterialized().As("c")
		TestTable{item: cte, wantQuery: "cte"}.assert(t)
		field := NewAnyField("ff", TableStruct{name: "cte", alias: "c"})
		if diff := testutil.Diff(cte.Field("ff"), field); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		if diff := testutil.Diff(cte.materialized, sql.NullBool{Valid: true, Bool: false}); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		if diff := testutil.Diff(cte.GetAlias(), "c"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})
}

// TestCTEs exercises writeCTEs: plain, recursive and (not) materialized CTEs,
// invalid CTEs, and error propagation from the CTE body.
func TestCTEs(t *testing.T) {
	type TT struct {
		description string
		dialect     string
		ctes        []CTE
		wantQuery   string
		wantArgs    []any
		wantParams  map[string][]int
	}

	tests := []TT{{
		description: "basic",
		ctes:        []CTE{NewCTE("cte", nil, Queryf("SELECT 1"))},
		wantQuery:   "WITH cte AS (SELECT 1) ",
	}, {
		// A single recursive CTE makes the whole WITH clause RECURSIVE.
		description: "recursive",
		ctes: []CTE{
			NewCTE("cte", nil, Queryf("SELECT 1")),
			NewRecursiveCTE("nums", []string{"n"}, Union(
				Queryf("SELECT 1"),
				Queryf("SELECT n+1 FROM nums WHERE n < 10"),
			)),
		},
		wantQuery: "WITH RECURSIVE cte AS (SELECT 1)" +
			", nums (n) AS (SELECT 1 UNION SELECT n+1 FROM nums WHERE n < 10) ",
	}, {
		// MySQL output has no MATERIALIZED keyword even when requested.
		description: "mysql materialized",
		dialect:     DialectMySQL,
		ctes:        []CTE{NewCTE("cte", nil, Queryf("SELECT 1")).Materialized()},
		wantQuery:   "WITH cte AS (SELECT 1) ",
	}, {
		description: "postgres materialized",
		dialect:     DialectPostgres,
		ctes:        []CTE{NewCTE("cte", nil, Queryf("SELECT 1")).Materialized()},
		wantQuery:   "WITH cte AS MATERIALIZED (SELECT 1) ",
	}, {
		description: "postgres not materialized",
		dialect:     DialectPostgres,
		ctes:        []CTE{NewCTE("cte", nil, Queryf("SELECT 1")).NotMaterialized()},
		wantQuery:   "WITH cte AS NOT MATERIALIZED (SELECT 1) ",
	}}

	for _, tt := range tests {
		// Copy the range variable so each parallel subtest sees its own case.
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			buf, args, params := bufpool.Get().(*bytes.Buffer), &[]any{}, make(map[string][]int)
			buf.Reset()
			defer bufpool.Put(buf)
			err := writeCTEs(context.Background(), tt.dialect, buf, args, params, tt.ctes)
			if err != nil {
				t.Fatal(testutil.Callers(), err)
			}
			if diff := testutil.Diff(buf.String(), tt.wantQuery); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
			if diff := testutil.Diff(*args, tt.wantArgs); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
			if diff := testutil.Diff(params, tt.wantParams); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
		})
	}

	t.Run("invalid cte", func(t *testing.T) {
		t.Parallel()
		buf, args, params := bufpool.Get().(*bytes.Buffer), &[]any{}, make(map[string][]int)
		buf.Reset()
		defer bufpool.Put(buf)

		// no name
		err := writeCTEs(context.Background(), "", buf, args, params, []CTE{
			NewCTE("", nil, Queryf("SELECT 1")),
		})
		if err == nil {
			t.Fatal(testutil.Callers(), "expected error but got nil")
		}

		// no query
		err = writeCTEs(context.Background(), "", buf, args, params, []CTE{
			NewCTE("cte", nil, nil),
		})
		if err == nil {
			t.Fatal(testutil.Callers(), "expected error but got nil")
		}
	})

	t.Run("err", func(t *testing.T) {
		t.Parallel()
		buf, args, params := bufpool.Get().(*bytes.Buffer), &[]any{}, make(map[string][]int)
		buf.Reset()
		defer bufpool.Put(buf)

		// VariadicQuery: the error from the faulty member must be wrapped, not
		// replaced, so errors.Is can still find ErrFaultySQL.
		err := writeCTEs(context.Background(), "", buf, args, params, []CTE{
			NewCTE("cte", nil, Union(
				Queryf("SELECT 1"),
				Queryf("SELECT {}", FaultySQL{}),
			)),
		})
		if !errors.Is(err, ErrFaultySQL) {
			t.Errorf(testutil.Callers()+"expected error %q but got %q", ErrFaultySQL, err)
		}

		// Query
		err = writeCTEs(context.Background(), "", buf, args, params, []CTE{
			NewCTE("cte", nil, Queryf("SELECT {}", FaultySQL{})),
		})
		if !errors.Is(err, ErrFaultySQL) {
			t.Errorf(testutil.Callers()+"expected error %q but got %q", ErrFaultySQL, err)
		}
	})
}

// TestVariadicQuery covers every set operator, the default operator, nesting,
// the single-query shortcut, invalid inputs and error propagation.
func TestVariadicQuery(t *testing.T) {
	q1, q2, q3 := Queryf("SELECT 1"), Queryf("SELECT 2"), Queryf("SELECT 3")
	tests := []TestTable{{
		description: "Union",
		item:        Union(q1, q2, q3),
		wantQuery:   "(SELECT 1 UNION SELECT 2 UNION SELECT 3)",
	}, {
		description: "UnionAll",
		item:        UnionAll(q1, q2, q3),
		wantQuery:   "(SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3)",
	}, {
		description: "Intersect",
		item:        Intersect(q1, q2, q3),
		wantQuery:   "(SELECT 1 INTERSECT SELECT 2 INTERSECT SELECT 3)",
	}, {
		description: "IntersectAll",
		item:        IntersectAll(q1, q2, q3),
		wantQuery:   "(SELECT 1 INTERSECT ALL SELECT 2 INTERSECT ALL SELECT 3)",
	}, {
		description: "Except",
		item:        Except(q1, q2, q3),
		wantQuery:   "(SELECT 1 EXCEPT SELECT 2 EXCEPT SELECT 3)",
	}, {
		description: "ExceptAll",
		item:        ExceptAll(q1, q2, q3),
		wantQuery:   "(SELECT 1 EXCEPT ALL SELECT 2 EXCEPT ALL SELECT 3)",
	}, {
		// An empty Operator falls back to UNION.
		description: "No operator specified",
		item:        VariadicQuery{Queries: []Query{q1, q2, q3}},
		wantQuery:   "(SELECT 1 UNION SELECT 2 UNION SELECT 3)",
	}, {
		// Single-element wrappers collapse, so only one pair of parentheses.
		description: "nested VariadicQuery",
		item:        Union(Union(Union(q1, q2, q3))),
		wantQuery:   "(SELECT 1 UNION SELECT 2 UNION SELECT 3)",
	}, {
		description: "1 query",
		item:        Union(q1),
		wantQuery:   "SELECT 1",
	}}

	for _, tt := range tests {
		// Copy the range variable so each parallel subtest sees its own case.
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assert(t)
		})
	}

	t.Run("invalid VariadicQuery", func(t *testing.T) {
		t.Parallel()
		// empty
		TestTable{item: Union()}.assertNotOK(t)
		// nil query
		TestTable{item: Union(nil)}.assertNotOK(t)
		// nil query
		TestTable{item: Union(q1, q2, nil)}.assertNotOK(t)
	})

	t.Run("err", func(t *testing.T) {
		t.Parallel()
		// VariadicQuery
		TestTable{
			item: Union(
				Union(
					Queryf("SELECT 1"),
					Queryf("SELECT {}", FaultySQL{}),
				),
			),
		}.assertErr(t, ErrFaultySQL)
		// Query
		TestTable{
			item: Union(Queryf("SELECT {}", FaultySQL{})),
		}.assertErr(t, ErrFaultySQL)
	})

	t.Run("SetFetchableFields", func(t *testing.T) {
		t.Parallel()
		_, ok := Union().SetFetchableFields([]Field{Expr("f1")})
		if ok {
			t.Error(testutil.Callers(), "expected not ok but got ok")
		}
	})

	t.Run("GetDialect", func(t *testing.T) {
		// empty VariadicQuery
		if diff := testutil.Diff(Union().GetDialect(), ""); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// nil query
		if diff := testutil.Diff(Union(nil).GetDialect(), ""); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// empty dialect propagated
		if diff := testutil.Diff(Union(Queryf("SELECT 1")).GetDialect(), ""); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})
}

================================================
FILE: delete_query.go
================================================
package sq

import (
	"bytes"
	"context"
	"fmt"
)

// DeleteQuery represents an SQL DELETE query.
type DeleteQuery struct {
	// Dialect is read by GetDialect and by Set/GetFetchableFields; WriteSQL
	// renders according to its own dialect argument instead.
	Dialect string
	// WITH
	CTEs []CTE
	// DELETE FROM
	// DeleteTable is the single target table (DELETE FROM <table>).
	// DeleteTables is the target list of the "DELETE t1, t2 FROM ..." form,
	// which WriteSQL only honours for MySQL (several) and SQL Server (one).
	DeleteTable  Table
	DeleteTables []Table
	// USING
	// UsingTable is rendered as USING (Postgres) or FROM (MySQL, SQL
	// Server); JoinTables require a UsingTable.
	UsingTable Table
	JoinTables []JoinTable
	// WHERE
	WherePredicate Predicate
	// ORDER BY
	OrderByFields Fields
	// LIMIT
	LimitRows any
	// OFFSET
	// NOTE(review): DeleteQuery.WriteSQL never reads OffsetRows, so any value
	// set here is silently dropped — confirm whether that is intended.
	OffsetRows any
	// RETURNING
	ReturningFields []Field
}

var _ Query = (*DeleteQuery)(nil)

// WriteSQL implements the SQLWriter interface.
func (q DeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { var err error // Table Policies var policies []Predicate policies, err = appendPolicy(ctx, dialect, policies, q.DeleteTable) if err != nil { return fmt.Errorf("DELETE FROM %s Policy: %w", toString(q.Dialect, q.DeleteTable), err) } policies, err = appendPolicy(ctx, dialect, policies, q.UsingTable) if err != nil { return fmt.Errorf("USING %s Policy: %w", toString(q.Dialect, q.UsingTable), err) } for _, joinTable := range q.JoinTables { policies, err = appendPolicy(ctx, dialect, policies, joinTable.Table) if err != nil { return fmt.Errorf("%s %s Policy: %w", joinTable.JoinOperator, joinTable.Table, err) } } if len(policies) > 0 { if q.WherePredicate != nil { policies = append(policies, q.WherePredicate) } q.WherePredicate = And(policies...) } // WITH if len(q.CTEs) > 0 { err = writeCTEs(ctx, dialect, buf, args, params, q.CTEs) if err != nil { return fmt.Errorf("WITH: %w", err) } } // DELETE FROM if (dialect == DialectMySQL || dialect == DialectSQLServer) && len(q.DeleteTables) > 0 { buf.WriteString("DELETE ") if len(q.DeleteTables) > 1 && dialect != DialectMySQL { return fmt.Errorf("dialect %q does not support multi-table DELETE", dialect) } for i, table := range q.DeleteTables { if i > 0 { buf.WriteString(", ") } if alias := getAlias(table); alias != "" { buf.WriteString(alias) } else { err = table.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("table #%d: %w", i+1, err) } } } } else { buf.WriteString("DELETE FROM ") if q.DeleteTable == nil { return fmt.Errorf("no table provided to DELETE FROM") } err = q.DeleteTable.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("DELETE FROM: %w", err) } if dialect != DialectSQLServer { if alias := getAlias(q.DeleteTable); alias != "" { buf.WriteString(" AS " + QuoteIdentifier(dialect, alias)) } } } if q.UsingTable != nil || len(q.JoinTables) > 0 
{ if dialect != DialectPostgres && dialect != DialectMySQL && dialect != DialectSQLServer { return fmt.Errorf("%s DELETE does not support JOIN", dialect) } } // OUTPUT if len(q.ReturningFields) > 0 && dialect == DialectSQLServer { buf.WriteString(" OUTPUT ") err = writeFieldsWithPrefix(ctx, dialect, buf, args, params, q.ReturningFields, "DELETED", true) if err != nil { return err } } // USING/FROM if q.UsingTable != nil { switch dialect { case DialectPostgres: buf.WriteString(" USING ") err = q.UsingTable.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("USING: %w", err) } case DialectMySQL, DialectSQLServer: buf.WriteString(" FROM ") err = q.UsingTable.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("FROM: %w", err) } } if alias := getAlias(q.UsingTable); alias != "" { buf.WriteString(" AS " + QuoteIdentifier(dialect, alias)) } } // JOIN if len(q.JoinTables) > 0 { if q.UsingTable == nil { return fmt.Errorf("%s can't JOIN without a USING/FROM table", dialect) } buf.WriteString(" ") err = writeJoinTables(ctx, dialect, buf, args, params, q.JoinTables) if err != nil { return fmt.Errorf("JOIN: %w", err) } } // WHERE if q.WherePredicate != nil { buf.WriteString(" WHERE ") switch predicate := q.WherePredicate.(type) { case VariadicPredicate: predicate.Toplevel = true err = predicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("WHERE: %w", err) } default: err = q.WherePredicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("WHERE: %w", err) } } } // ORDER BY if len(q.OrderByFields) > 0 { if dialect != DialectMySQL { return fmt.Errorf("%s UPDATE does not support ORDER BY", dialect) } buf.WriteString(" ORDER BY ") err = q.OrderByFields.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("ORDER BY: %w", err) } } // LIMIT if q.LimitRows != nil { if dialect != DialectMySQL { return fmt.Errorf("%s UPDATE does not support LIMIT", dialect) } 
buf.WriteString(" LIMIT ") err = WriteValue(ctx, dialect, buf, args, params, q.LimitRows) if err != nil { return fmt.Errorf("LIMIT: %w", err) } } // RETURNING if len(q.ReturningFields) > 0 && dialect != DialectSQLServer { if dialect != DialectPostgres && dialect != DialectSQLite && dialect != DialectMySQL { return fmt.Errorf("%s UPDATE does not support RETURNING", dialect) } buf.WriteString(" RETURNING ") err = writeFields(ctx, dialect, buf, args, params, q.ReturningFields, true) if err != nil { return fmt.Errorf("RETURNING: %w", err) } } return nil } // DeleteFrom returns a new DeleteQuery. func DeleteFrom(table Table) DeleteQuery { return DeleteQuery{DeleteTable: table} } // Where appends to the WherePredicate field of the DeleteQuery. func (q DeleteQuery) Where(predicates ...Predicate) DeleteQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // SetFetchableFields implements the Query interface. func (q DeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { switch q.Dialect { case DialectPostgres, DialectSQLite: if len(q.ReturningFields) == 0 { q.ReturningFields = fields return q, true } return q, false default: return q, false } } // GetFetchableFields returns the fetchable fields of the query. func (q DeleteQuery) GetFetchableFields() []Field { switch q.Dialect { case DialectPostgres, DialectSQLite: return q.ReturningFields default: return nil } } // GetDialect implements the Query interface. func (q DeleteQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the query. func (q DeleteQuery) SetDialect(dialect string) DeleteQuery { q.Dialect = dialect return q } // SQLiteDeleteQuery represents an SQLite DELETE query. type SQLiteDeleteQuery DeleteQuery var _ Query = (*SQLiteDeleteQuery)(nil) // WriteSQL implements the SQLWriter interface. 
func (q SQLiteDeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return DeleteQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // DeleteFrom returns a new SQLiteDeleteQuery. func (b sqliteQueryBuilder) DeleteFrom(table Table) SQLiteDeleteQuery { return SQLiteDeleteQuery{ Dialect: DialectSQLite, CTEs: b.ctes, DeleteTable: table, } } // Where appends to the WherePredicate field of the SQLiteDeleteQuery. func (q SQLiteDeleteQuery) Where(predicates ...Predicate) SQLiteDeleteQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // Returning appends fields to the RETURNING clause of the SQLiteDeleteQuery. func (q SQLiteDeleteQuery) Returning(fields ...Field) SQLiteDeleteQuery { q.ReturningFields = append(q.ReturningFields, fields...) return q } // SetFetchableFields implements the Query interface. func (q SQLiteDeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { return DeleteQuery(q).SetFetchableFields(fields) } // GetFetchableFields returns the fetchable fields of the query. func (q SQLiteDeleteQuery) GetFetchableFields() []Field { return DeleteQuery(q).GetFetchableFields() } // GetDialect implements the Query interface. func (q SQLiteDeleteQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the query. func (q SQLiteDeleteQuery) SetDialect(dialect string) SQLiteDeleteQuery { q.Dialect = dialect return q } // PostgresDeleteQuery represents a Postgres DELETE query. type PostgresDeleteQuery DeleteQuery var _ Query = (*PostgresDeleteQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q PostgresDeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return DeleteQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // DeleteFrom returns a new PostgresDeleteQuery. 
func (b postgresQueryBuilder) DeleteFrom(table Table) PostgresDeleteQuery {
	q := PostgresDeleteQuery{Dialect: DialectPostgres}
	q.CTEs = b.ctes
	q.DeleteTable = table
	return q
}

// Using sets the table of the USING clause, replacing any earlier one.
func (q PostgresDeleteQuery) Using(table Table) PostgresDeleteQuery {
	q.UsingTable = table
	return q
}

// Join adds a JOIN of table ON predicates to the PostgresDeleteQuery.
func (q PostgresDeleteQuery) Join(table Table, predicates ...Predicate) PostgresDeleteQuery {
	joined := Join(table, predicates...)
	q.JoinTables = append(q.JoinTables, joined)
	return q
}

// LeftJoin adds a LEFT JOIN of table ON predicates to the PostgresDeleteQuery.
func (q PostgresDeleteQuery) LeftJoin(table Table, predicates ...Predicate) PostgresDeleteQuery {
	joined := LeftJoin(table, predicates...)
	q.JoinTables = append(q.JoinTables, joined)
	return q
}

// FullJoin adds a FULL JOIN of table ON predicates to the PostgresDeleteQuery.
func (q PostgresDeleteQuery) FullJoin(table Table, predicates ...Predicate) PostgresDeleteQuery {
	joined := FullJoin(table, predicates...)
	q.JoinTables = append(q.JoinTables, joined)
	return q
}

// CrossJoin adds a CROSS JOIN of table to the PostgresDeleteQuery.
func (q PostgresDeleteQuery) CrossJoin(table Table) PostgresDeleteQuery {
	joined := CrossJoin(table)
	q.JoinTables = append(q.JoinTables, joined)
	return q
}

// CustomJoin adds a join of table to the PostgresDeleteQuery using the
// caller-supplied joinOperator text.
func (q PostgresDeleteQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) PostgresDeleteQuery {
	joined := CustomJoin(joinOperator, table, predicates...)
	q.JoinTables = append(q.JoinTables, joined)
	return q
}

// JoinUsing adds a JOIN of table USING (fields...) to the PostgresDeleteQuery.
func (q PostgresDeleteQuery) JoinUsing(table Table, fields ...Field) PostgresDeleteQuery {
	joined := JoinUsing(table, fields...)
	q.JoinTables = append(q.JoinTables, joined)
	return q
}

// Where appends to the WherePredicate field of the PostgresDeleteQuery.
func (q PostgresDeleteQuery) Where(predicates ...Predicate) PostgresDeleteQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // Returning appends fields to the RETURNING clause of the PostgresDeleteQuery. func (q PostgresDeleteQuery) Returning(fields ...Field) PostgresDeleteQuery { q.ReturningFields = append(q.ReturningFields, fields...) return q } // SetFetchableFields implements the Query interface. func (q PostgresDeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { return DeleteQuery(q).SetFetchableFields(fields) } // GetFetchableFields returns the fetchable fields of the query. func (q PostgresDeleteQuery) GetFetchableFields() []Field { return DeleteQuery(q).GetFetchableFields() } // GetDialect implements the Query interface. func (q PostgresDeleteQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the query. func (q PostgresDeleteQuery) SetDialect(dialect string) PostgresDeleteQuery { q.Dialect = dialect return q } // MySQLDeleteQuery represents a MySQL DELETE query. type MySQLDeleteQuery DeleteQuery var _ Query = (*MySQLDeleteQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q MySQLDeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return DeleteQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // DeleteFrom returns a new MySQLDeleteQuery. func (b mysqlQueryBuilder) DeleteFrom(table Table) MySQLDeleteQuery { return MySQLDeleteQuery{ Dialect: DialectMySQL, CTEs: b.ctes, DeleteTable: table, } } // Delete returns a new MySQLDeleteQuery. func (b mysqlQueryBuilder) Delete(tables ...Table) MySQLDeleteQuery { return MySQLDeleteQuery{ Dialect: DialectMySQL, CTEs: b.ctes, DeleteTables: tables, } } // From sets the UsingTable of the MySQLDeleteQuery. func (q MySQLDeleteQuery) From(table Table) MySQLDeleteQuery { q.UsingTable = table return q } // Join joins a new Table to the MySQLDeleteQuery. 
func (q MySQLDeleteQuery) Join(table Table, predicates ...Predicate) MySQLDeleteQuery { q.JoinTables = append(q.JoinTables, Join(table, predicates...)) return q } // LeftJoin left joins a new Table to the MySQLDeleteQuery. func (q MySQLDeleteQuery) LeftJoin(table Table, predicates ...Predicate) MySQLDeleteQuery { q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) return q } // FullJoin full joins a new Table to the MySQLDeleteQuery. func (q MySQLDeleteQuery) FullJoin(table Table, predicates ...Predicate) MySQLDeleteQuery { q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...)) return q } // CrossJoin cross joins a new Table to the MySQLDeleteQuery. func (q MySQLDeleteQuery) CrossJoin(table Table) MySQLDeleteQuery { q.JoinTables = append(q.JoinTables, CrossJoin(table)) return q } // CustomJoin joins a new Table to the MySQLDeleteQuery with a custom join // operator. func (q MySQLDeleteQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) MySQLDeleteQuery { q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) return q } // JoinUsing joins a new Table to the MySQLDeleteQuery with the USING operator. func (q MySQLDeleteQuery) JoinUsing(table Table, fields ...Field) MySQLDeleteQuery { q.JoinTables = append(q.JoinTables, JoinUsing(table, fields...)) return q } // Where appends to the WherePredicate field of the MySQLDeleteQuery. func (q MySQLDeleteQuery) Where(predicates ...Predicate) MySQLDeleteQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // OrderBy sets the OrderByFields field of the MySQLDeleteQuery. func (q MySQLDeleteQuery) OrderBy(fields ...Field) MySQLDeleteQuery { q.OrderByFields = append(q.OrderByFields, fields...) return q } // Limit sets the LimitRows field of the MySQLDeleteQuery. 
func (q MySQLDeleteQuery) Limit(limit any) MySQLDeleteQuery { q.LimitRows = limit return q } // Returning appends fields to the RETURNING clause of the MySQLDeleteQuery. func (q MySQLDeleteQuery) Returning(fields ...Field) MySQLDeleteQuery { q.ReturningFields = append(q.ReturningFields, fields...) return q } // SetFetchableFields implements the Query interface. func (q MySQLDeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { return DeleteQuery(q).SetFetchableFields(fields) } // GetFetchableFields returns the fetchable fields of the query. func (q MySQLDeleteQuery) GetFetchableFields() []Field { return DeleteQuery(q).GetFetchableFields() } // GetDialect implements the Query interface. func (q MySQLDeleteQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the query. func (q MySQLDeleteQuery) SetDialect(dialect string) MySQLDeleteQuery { q.Dialect = dialect return q } // SQLServerDeleteQuery represents an SQL Server DELETE query. type SQLServerDeleteQuery DeleteQuery var _ Query = (*SQLServerDeleteQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q SQLServerDeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return DeleteQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // DeleteFrom returns a new SQLServerDeleteQuery. func (b sqlserverQueryBuilder) DeleteFrom(table Table) SQLServerDeleteQuery { return SQLServerDeleteQuery{ Dialect: DialectSQLServer, CTEs: b.ctes, DeleteTable: table, } } // Delete returns a new SQLServerDeleteQuery. func (b sqlserverQueryBuilder) Delete(table Table) SQLServerDeleteQuery { return SQLServerDeleteQuery{ Dialect: DialectSQLServer, CTEs: b.ctes, DeleteTables: []Table{table}, } } // From sets the UsingTable of the SQLServerDeleteQuery. func (q SQLServerDeleteQuery) From(table Table) SQLServerDeleteQuery { q.UsingTable = table return q } // Join joins a new Table to the SQLServerDeleteQuery. 
func (q SQLServerDeleteQuery) Join(table Table, predicates ...Predicate) SQLServerDeleteQuery { q.JoinTables = append(q.JoinTables, Join(table, predicates...)) return q } // LeftJoin left joins a new Table to the SQLServerDeleteQuery. func (q SQLServerDeleteQuery) LeftJoin(table Table, predicates ...Predicate) SQLServerDeleteQuery { q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) return q } // FullJoin full joins a new Table to the SQLServerDeleteQuery. func (q SQLServerDeleteQuery) FullJoin(table Table, predicates ...Predicate) SQLServerDeleteQuery { q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...)) return q } // CrossJoin cross joins a new Table to the SQLServerDeleteQuery. func (q SQLServerDeleteQuery) CrossJoin(table Table) SQLServerDeleteQuery { q.JoinTables = append(q.JoinTables, CrossJoin(table)) return q } // CustomJoin joins a new Table to the SQLServerDeleteQuery with a custom join // operator. func (q SQLServerDeleteQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) SQLServerDeleteQuery { q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) return q } // Where appends to the WherePredicate field of the SQLServerDeleteQuery. func (q SQLServerDeleteQuery) Where(predicates ...Predicate) SQLServerDeleteQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // SetFetchableFields implements the Query interface. func (q SQLServerDeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { return DeleteQuery(q).SetFetchableFields(fields) } // GetFetchableFields returns the fetchable fields of the query. func (q SQLServerDeleteQuery) GetFetchableFields() []Field { return DeleteQuery(q).GetFetchableFields() } // GetDialect implements the Query interface. func (q SQLServerDeleteQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the query. 
func (q SQLServerDeleteQuery) SetDialect(dialect string) SQLServerDeleteQuery { q.Dialect = dialect return q } ================================================ FILE: delete_query_test.go ================================================ package sq import ( "testing" "github.com/bokwoon95/sq/internal/testutil" ) func TestSQLiteDeleteQuery(t *testing.T) { type ACTOR struct { TableStruct ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField } a := New[ACTOR]("a") t.Run("basic", func(t *testing.T) { t.Parallel() q1 := SQLite.DeleteFrom(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum") if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } q1 = q1.SetDialect(DialectSQLite) fields := q1.GetFetchableFields() if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" { t.Error(testutil.Callers(), diff) } _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } q1.ReturningFields = q1.ReturningFields[:0] _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) if !ok { t.Fatal(testutil.Callers(), "field should have been set") } }) t.Run("Delete Returning", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLite. With(NewCTE("cte", nil, Queryf("SELECT 1"))). DeleteFrom(a). Where(a.ACTOR_ID.EqInt(1)). 
Returning(a.FIRST_NAME, a.LAST_NAME) tt.wantQuery = "WITH cte AS (SELECT 1)" + " DELETE FROM actor AS a" + " WHERE a.actor_id = $1" + " RETURNING a.first_name, a.last_name" tt.wantArgs = []any{1} tt.assert(t) }) } func TestPostgresDeleteQuery(t *testing.T) { type ACTOR struct { TableStruct ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField } a := New[ACTOR]("a") t.Run("basic", func(t *testing.T) { t.Parallel() q1 := Postgres.DeleteFrom(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum") if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } q1 = q1.SetDialect(DialectPostgres) fields := q1.GetFetchableFields() if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" { t.Error(testutil.Callers(), diff) } _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } q1.ReturningFields = q1.ReturningFields[:0] _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) if !ok { t.Fatal(testutil.Callers(), "field should have been set") } }) t.Run("Delete Returning", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = Postgres. With(NewCTE("cte", nil, Queryf("SELECT 1"))). DeleteFrom(a). Where(a.ACTOR_ID.EqInt(1)). Returning(a.FIRST_NAME, a.LAST_NAME) tt.wantQuery = "WITH cte AS (SELECT 1)" + " DELETE FROM actor AS a" + " WHERE a.actor_id = $1" + " RETURNING a.first_name, a.last_name" tt.wantArgs = []any{1} tt.assert(t) }) t.Run("Join", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = Postgres. DeleteFrom(a). Using(a). Join(a, Expr("1 = 1")). LeftJoin(a, Expr("1 = 1")). FullJoin(a, Expr("1 = 1")). CrossJoin(a). CustomJoin(",", a). 
JoinUsing(a, a.FIRST_NAME, a.LAST_NAME) tt.wantQuery = "DELETE FROM actor AS a" + " USING actor AS a" + " JOIN actor AS a ON 1 = 1" + " LEFT JOIN actor AS a ON 1 = 1" + " FULL JOIN actor AS a ON 1 = 1" + " CROSS JOIN actor AS a" + " , actor AS a" + " JOIN actor AS a USING (first_name, last_name)" tt.assert(t) }) } func TestMySQLDeleteQuery(t *testing.T) { type ACTOR struct { TableStruct ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField } a := New[ACTOR]("") t.Run("basic", func(t *testing.T) { t.Parallel() q1 := MySQL.DeleteFrom(a).SetDialect("lorem ipsum") if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } q1 = q1.SetDialect(DialectMySQL) fields := q1.GetFetchableFields() if len(fields) != 0 { t.Error(testutil.Callers(), "expected 0 fields but got %v", fields) } _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } q1.ReturningFields = q1.ReturningFields[:0] _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } }) t.Run("Where", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = MySQL. With(NewCTE("cte", nil, Queryf("SELECT 1"))). DeleteFrom(a). Where(a.ACTOR_ID.EqInt(1)) tt.wantQuery = "WITH cte AS (SELECT 1)" + " DELETE FROM actor" + " WHERE actor.actor_id = ?" tt.wantArgs = []any{1} tt.assert(t) }) t.Run("OrderBy Limit", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = MySQL. DeleteFrom(a). OrderBy(a.ACTOR_ID). Limit(5) tt.wantQuery = "DELETE FROM actor" + " ORDER BY actor.actor_id" + " LIMIT ?" tt.wantArgs = []any{5} tt.assert(t) }) t.Run("Delete Returning", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = MySQL. With(NewCTE("cte", nil, Queryf("SELECT 1"))). DeleteFrom(a). Where(a.ACTOR_ID.EqInt(1)). 
Returning(a.FIRST_NAME, a.LAST_NAME) tt.wantQuery = "WITH cte AS (SELECT 1)" + " DELETE FROM actor" + " WHERE actor.actor_id = ?" + " RETURNING actor.first_name, actor.last_name" tt.wantArgs = []any{1} tt.assert(t) }) t.Run("Join", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = MySQL. Delete(a). From(a). Join(a, Expr("1 = 1")). LeftJoin(a, Expr("1 = 1")). FullJoin(a, Expr("1 = 1")). CrossJoin(a). CustomJoin(",", a). JoinUsing(a, a.FIRST_NAME, a.LAST_NAME) tt.wantQuery = "DELETE actor" + " FROM actor" + " JOIN actor ON 1 = 1" + " LEFT JOIN actor ON 1 = 1" + " FULL JOIN actor ON 1 = 1" + " CROSS JOIN actor" + " , actor" + " JOIN actor USING (first_name, last_name)" tt.assert(t) }) } func TestSQLServerDeleteQuery(t *testing.T) { type ACTOR struct { TableStruct ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField } a := New[ACTOR]("") t.Run("basic", func(t *testing.T) { t.Parallel() q1 := SQLServer.DeleteFrom(a).SetDialect("lorem ipsum") if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } q1 = q1.SetDialect(DialectSQLServer) q1 = q1.SetDialect(DialectMySQL) fields := q1.GetFetchableFields() if len(fields) != 0 { t.Error(testutil.Callers(), "expected 0 fields but got %v", fields) } _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } q1.ReturningFields = q1.ReturningFields[:0] _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } }) t.Run("Where", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLServer. With(NewCTE("cte", nil, Queryf("SELECT 1"))). DeleteFrom(a). Where(a.ACTOR_ID.EqInt(1)) tt.wantQuery = "WITH cte AS (SELECT 1)" + " DELETE FROM actor" + " WHERE actor.actor_id = @p1" tt.wantArgs = []any{1} tt.assert(t) }) t.Run("Join", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLServer. 
DeleteFrom(a). From(a). Join(a, Expr("1 = 1")). LeftJoin(a, Expr("1 = 1")). FullJoin(a, Expr("1 = 1")). CrossJoin(a). CustomJoin(",", a) tt.wantQuery = "DELETE FROM actor" + " FROM actor" + " JOIN actor ON 1 = 1" + " LEFT JOIN actor ON 1 = 1" + " FULL JOIN actor ON 1 = 1" + " CROSS JOIN actor" + " , actor" tt.assert(t) }) } func TestDeleteQuery(t *testing.T) { t.Run("basic", func(t *testing.T) { t.Parallel() q1 := DeleteQuery{DeleteTable: Expr("tbl"), Dialect: "lorem ipsum"} if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("PolicyTable", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = DeleteQuery{ DeleteTable: policyTableStub{policy: And(Expr("1 = 1"), Expr("2 = 2"))}, WherePredicate: Expr("3 = 3"), } tt.wantQuery = "DELETE FROM policy_table_stub WHERE (1 = 1 AND 2 = 2) AND 3 = 3" tt.assert(t) }) notOKTests := []TestTable{{ description: "nil FromTable not allowed", item: DeleteQuery{ DeleteTable: nil, }, }, { description: "sqlite does not support JOIN", item: DeleteQuery{ Dialect: DialectSQLite, DeleteTable: Expr("tbl"), UsingTable: Expr("tbl"), JoinTables: []JoinTable{ Join(Expr("tbl"), Expr("1 = 1")), }, }, }, { description: "postgres does not allow JOIN without USING", item: DeleteQuery{ Dialect: DialectPostgres, DeleteTable: Expr("tbl"), JoinTables: []JoinTable{ Join(Expr("tbl"), Expr("1 = 1")), }, }, }, { description: "dialect does not support ORDER BY", item: DeleteQuery{ Dialect: DialectPostgres, DeleteTable: Expr("tbl"), OrderByFields: Fields{Expr("f1")}, }, }, { description: "dialect does not support LIMIT", item: DeleteQuery{ Dialect: DialectPostgres, DeleteTable: Expr("tbl"), LimitRows: 5, }, }} for _, tt := range notOKTests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assertNotOK(t) }) } errTests := []TestTable{{ description: "FromTable Policy err", item: DeleteQuery{ DeleteTable: policyTableStub{err: ErrFaultySQL}, }, }, { description: "UsingTable 
Policy err", item: DeleteQuery{ DeleteTable: Expr("tbl"), UsingTable: policyTableStub{err: ErrFaultySQL}, }, }, { description: "JoinTables Policy err", item: DeleteQuery{ DeleteTable: Expr("tbl"), UsingTable: Expr("tbl"), JoinTables: []JoinTable{ Join(policyTableStub{err: ErrFaultySQL}, Expr("1 = 1")), }, }, }, { description: "CTEs err", item: DeleteQuery{ CTEs: []CTE{NewCTE("cte", nil, Queryf("SELECT {}", FaultySQL{}))}, DeleteTable: Expr("tbl"), }, }, { description: "FromTable err", item: DeleteQuery{ DeleteTable: FaultySQL{}, }, }, { description: "postgres UsingTable err", item: DeleteQuery{ Dialect: DialectPostgres, DeleteTable: Expr("tbl"), UsingTable: FaultySQL{}, }, }, { description: "sqlserver UsingTable err", item: DeleteQuery{ Dialect: DialectSQLServer, DeleteTable: Expr("tbl"), UsingTable: FaultySQL{}, }, }, { description: "JoinTables err", item: DeleteQuery{ Dialect: DialectPostgres, DeleteTable: Expr("tbl"), UsingTable: Expr("tbl"), JoinTables: []JoinTable{ Join(Expr("tbl"), FaultySQL{}), }, }, }, { description: "WherePredicate Variadic err", item: DeleteQuery{ DeleteTable: Expr("tbl"), WherePredicate: And(FaultySQL{}), }, }, { description: "WherePredicate err", item: DeleteQuery{ DeleteTable: Expr("tbl"), WherePredicate: FaultySQL{}, }, }, { description: "OrderByFields err", item: DeleteQuery{ Dialect: DialectMySQL, DeleteTable: Expr("tbl"), OrderByFields: Fields{FaultySQL{}}, }, }, { description: "LimitRows err", item: DeleteQuery{ Dialect: DialectMySQL, DeleteTable: Expr("tbl"), OrderByFields: Fields{Expr("f1")}, LimitRows: FaultySQL{}, }, }, { description: "ReturningFields err", item: DeleteQuery{ Dialect: DialectPostgres, DeleteTable: Expr("tbl"), ReturningFields: Fields{FaultySQL{}}, }, }} for _, tt := range errTests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assertErr(t, ErrFaultySQL) }) } } ================================================ FILE: fetch_exec.go ================================================ package sq 
import ( "bytes" "context" "database/sql" "fmt" "reflect" "runtime" "strconv" "strings" "sync/atomic" "time" ) // Default dialect used by all queries (if no dialect is explicitly provided). var DefaultDialect atomic.Pointer[string] // A Cursor represents a database cursor. type Cursor[T any] struct { ctx context.Context row *Row rowmapper func(*Row) T queryStats QueryStats logSettings LogSettings logger SqLogger logged int32 fieldNames []string resultsBuffer *bytes.Buffer } // FetchCursor returns a new cursor. func FetchCursor[T any](db DB, query Query, rowmapper func(*Row) T) (*Cursor[T], error) { return fetchCursor(context.Background(), db, query, rowmapper, 1) } // FetchCursorContext is like FetchCursor but additionally requires a context.Context. func FetchCursorContext[T any](ctx context.Context, db DB, query Query, rowmapper func(*Row) T) (*Cursor[T], error) { return fetchCursor(ctx, db, query, rowmapper, 1) } func fetchCursor[T any](ctx context.Context, db DB, query Query, rowmapper func(*Row) T, skip int) (cursor *Cursor[T], err error) { if db == nil { return nil, fmt.Errorf("db is nil") } if query == nil { return nil, fmt.Errorf("query is nil") } if rowmapper == nil { return nil, fmt.Errorf("rowmapper is nil") } dialect := query.GetDialect() if dialect == "" { defaultDialect := DefaultDialect.Load() if defaultDialect != nil { dialect = *defaultDialect } } // If we can't set the fetchable fields, the query is static. _, ok := query.SetFetchableFields(nil) cursor = &Cursor[T]{ ctx: ctx, rowmapper: rowmapper, row: &Row{ dialect: dialect, queryIsStatic: !ok, }, queryStats: QueryStats{ Dialect: dialect, Params: make(map[string][]int), RowCount: sql.NullInt64{Valid: true}, }, } // If the query is dynamic, call the rowmapper to populate row.fields and // row.scanDest. Then, insert those fields back into the query. 
if !cursor.row.queryIsStatic { defer mapperFunctionPanicked(&err) _ = cursor.rowmapper(cursor.row) query, _ = query.SetFetchableFields(cursor.row.fields) } // Build query. buf := bufpool.Get().(*bytes.Buffer) buf.Reset() defer bufpool.Put(buf) err = query.WriteSQL(ctx, dialect, buf, &cursor.queryStats.Args, cursor.queryStats.Params) cursor.queryStats.Query = buf.String() if err != nil { return nil, err } // Setup logger. cursor.logger, _ = db.(SqLogger) if cursor.logger == nil { logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats)) if logQuery != nil { logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings)) cursor.logger = &sqLogStruct{ logSettings: logSettings, logQuery: logQuery, } } } if cursor.logger != nil { cursor.logger.SqLogSettings(ctx, &cursor.logSettings) if cursor.logSettings.IncludeCaller { cursor.queryStats.CallerFile, cursor.queryStats.CallerLine, cursor.queryStats.CallerFunction = caller(skip + 1) } } // Run query. if cursor.logSettings.IncludeTime { cursor.queryStats.StartedAt = time.Now() } cursor.row.sqlRows, cursor.queryStats.Err = db.QueryContext(ctx, cursor.queryStats.Query, cursor.queryStats.Args...) if cursor.logSettings.IncludeTime { cursor.queryStats.TimeTaken = time.Since(cursor.queryStats.StartedAt) } if cursor.queryStats.Err != nil { cursor.log() return nil, cursor.queryStats.Err } // If the query is static, we now know the number of columns returned by // the query and can allocate the values slice and scanDest slice for // scanning later. 
if cursor.row.queryIsStatic { cursor.row.columns, err = cursor.row.sqlRows.Columns() if err != nil { return nil, err } cursor.row.columnTypes, err = cursor.row.sqlRows.ColumnTypes() if err != nil { return nil, err } cursor.row.columnIndex = make(map[string]int) for index, column := range cursor.row.columns { cursor.row.columnIndex[column] = index } cursor.row.values = make([]any, len(cursor.row.columns)) cursor.row.scanDest = make([]any, len(cursor.row.columns)) for index := range cursor.row.values { cursor.row.scanDest[index] = &cursor.row.values[index] } } // Allocate the resultsBuffer. if cursor.logSettings.IncludeResults > 0 { cursor.resultsBuffer = bufpool.Get().(*bytes.Buffer) cursor.resultsBuffer.Reset() } return cursor, nil } // Next advances the cursor to the next result. func (cursor *Cursor[T]) Next() bool { hasNext := cursor.row.sqlRows.Next() if hasNext { cursor.queryStats.RowCount.Int64++ } else { cursor.log() } return hasNext } // RowCount returns the current row number so far. func (cursor *Cursor[T]) RowCount() int64 { return cursor.queryStats.RowCount.Int64 } // Result returns the cursor result. func (cursor *Cursor[T]) Result() (result T, err error) { err = cursor.row.sqlRows.Scan(cursor.row.scanDest...) if err != nil { cursor.log() fieldMappings := getFieldMappings(cursor.queryStats.Dialect, cursor.row.fields, cursor.row.scanDest) return result, fmt.Errorf("please check if your mapper function is correct:%s\n%w", fieldMappings, err) } // If results should be logged, write the row into the resultsBuffer. 
if cursor.resultsBuffer != nil && cursor.queryStats.RowCount.Int64 <= int64(cursor.logSettings.IncludeResults) { if len(cursor.fieldNames) == 0 { cursor.fieldNames = getFieldNames(cursor.ctx, cursor.row) } cursor.resultsBuffer.WriteString("\n----[ Row " + strconv.FormatInt(cursor.queryStats.RowCount.Int64, 10) + " ]----") for i := range cursor.row.scanDest { cursor.resultsBuffer.WriteString("\n") if i < len(cursor.fieldNames) { cursor.resultsBuffer.WriteString(cursor.fieldNames[i]) } cursor.resultsBuffer.WriteString(": ") scanDest := cursor.row.scanDest[i] rhs, err := Sprint(cursor.queryStats.Dialect, scanDest) if err != nil { cursor.resultsBuffer.WriteString("%!(error=" + err.Error() + ")") continue } cursor.resultsBuffer.WriteString(rhs) } } cursor.row.runningIndex = 0 defer mapperFunctionPanicked(&err) result = cursor.rowmapper(cursor.row) return result, nil } func (cursor *Cursor[T]) log() { if !atomic.CompareAndSwapInt32(&cursor.logged, 0, 1) { return } if cursor.resultsBuffer != nil { cursor.queryStats.Results = cursor.resultsBuffer.String() bufpool.Put(cursor.resultsBuffer) } if cursor.logger == nil { return } if cursor.logSettings.LogAsynchronously { go cursor.logger.SqLogQuery(cursor.ctx, cursor.queryStats) } else { cursor.logger.SqLogQuery(cursor.ctx, cursor.queryStats) } } // Close closes the cursor. func (cursor *Cursor[T]) Close() error { cursor.log() if err := cursor.row.sqlRows.Close(); err != nil { return err } if err := cursor.row.sqlRows.Err(); err != nil { return err } return nil } // FetchOne returns the first result from running the given Query on the given // DB. func FetchOne[T any](db DB, query Query, rowmapper func(*Row) T) (T, error) { cursor, err := fetchCursor(context.Background(), db, query, rowmapper, 1) if err != nil { return *new(T), err } defer cursor.Close() return cursorResult(cursor) } // FetchOneContext is like FetchOne but additionally requires a context.Context. 
func FetchOneContext[T any](ctx context.Context, db DB, query Query, rowmapper func(*Row) T) (T, error) { cursor, err := fetchCursor(ctx, db, query, rowmapper, 1) if err != nil { return *new(T), err } defer cursor.Close() return cursorResult(cursor) } // FetchAll returns all results from running the given Query on the given DB. func FetchAll[T any](db DB, query Query, rowmapper func(*Row) T) ([]T, error) { cursor, err := fetchCursor(context.Background(), db, query, rowmapper, 1) if err != nil { return nil, err } defer cursor.Close() return cursorResults(cursor) } // FetchAllContext is like FetchAll but additionally requires a context.Context. func FetchAllContext[T any](ctx context.Context, db DB, query Query, rowmapper func(*Row) T) ([]T, error) { cursor, err := fetchCursor(ctx, db, query, rowmapper, 1) if err != nil { return nil, err } defer cursor.Close() return cursorResults(cursor) } // CompiledFetch is the result of compiling a Query down into a query string // and args slice. A CompiledFetch can be safely executed in parallel. type CompiledFetch[T any] struct { dialect string query string args []any params map[string][]int rowmapper func(*Row) T // if queryIsStatic is true, the rowmapper doesn't actually know what // columns are in the query and it must be determined at runtime after // running the query. queryIsStatic bool } // NewCompiledFetch returns a new CompiledFetch. func NewCompiledFetch[T any](dialect string, query string, args []any, params map[string][]int, rowmapper func(*Row) T) *CompiledFetch[T] { return &CompiledFetch[T]{ dialect: dialect, query: query, args: args, params: params, rowmapper: rowmapper, } } // CompileFetch returns a new CompileFetch. func CompileFetch[T any](q Query, rowmapper func(*Row) T) (*CompiledFetch[T], error) { return CompileFetchContext(context.Background(), q, rowmapper) } // CompileFetchContext is like CompileFetch but accepts a context.Context. 
func CompileFetchContext[T any](ctx context.Context, query Query, rowmapper func(*Row) T) (compiledFetch *CompiledFetch[T], err error) { if query == nil { return nil, fmt.Errorf("query is nil") } if rowmapper == nil { return nil, fmt.Errorf("rowmapper is nil") } dialect := query.GetDialect() if dialect == "" { defaultDialect := DefaultDialect.Load() if defaultDialect != nil { dialect = *defaultDialect } } // If we can't set the fetchable fields, the query is static. _, ok := query.SetFetchableFields(nil) compiledFetch = &CompiledFetch[T]{ dialect: dialect, params: make(map[string][]int), rowmapper: rowmapper, queryIsStatic: !ok, } row := &Row{ dialect: dialect, queryIsStatic: !ok, } // If the query is dynamic, call the rowmapper to populate row.fields. // Then, insert those fields back into the query. if !row.queryIsStatic { defer mapperFunctionPanicked(&err) _ = rowmapper(row) query, _ = query.SetFetchableFields(row.fields) } // Build query. buf := bufpool.Get().(*bytes.Buffer) buf.Reset() defer bufpool.Put(buf) err = query.WriteSQL(ctx, dialect, buf, &compiledFetch.args, compiledFetch.params) compiledFetch.query = buf.String() if err != nil { return nil, err } return compiledFetch, nil } // FetchCursor returns a new cursor. func (compiledFetch *CompiledFetch[T]) FetchCursor(db DB, params Params) (*Cursor[T], error) { return compiledFetch.fetchCursor(context.Background(), db, params, 1) } // FetchCursorContext is like FetchCursor but additionally requires a context.Context. 
func (compiledFetch *CompiledFetch[T]) FetchCursorContext(ctx context.Context, db DB, params Params) (*Cursor[T], error) { return compiledFetch.fetchCursor(ctx, db, params, 1) } func (compiledFetch *CompiledFetch[T]) fetchCursor(ctx context.Context, db DB, params Params, skip int) (cursor *Cursor[T], err error) { if db == nil { return nil, fmt.Errorf("db is nil") } cursor = &Cursor[T]{ ctx: ctx, rowmapper: compiledFetch.rowmapper, row: &Row{ dialect: compiledFetch.dialect, queryIsStatic: compiledFetch.queryIsStatic, }, queryStats: QueryStats{ Dialect: compiledFetch.dialect, Query: compiledFetch.query, Args: compiledFetch.args, Params: compiledFetch.params, }, } // Call the rowmapper to populate row.scanDest. if !cursor.row.queryIsStatic { defer mapperFunctionPanicked(&err) _ = cursor.rowmapper(cursor.row) } // Substitute params. cursor.queryStats.Args, err = substituteParams(cursor.queryStats.Dialect, cursor.queryStats.Args, cursor.queryStats.Params, params) if err != nil { return nil, err } // Setup logger. cursor.queryStats.RowCount.Valid = true cursor.logger, _ = db.(SqLogger) if cursor.logger == nil { logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats)) if logQuery != nil { logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings)) cursor.logger = &sqLogStruct{ logSettings: logSettings, logQuery: logQuery, } } } if cursor.logger != nil { cursor.logger.SqLogSettings(ctx, &cursor.logSettings) if cursor.logSettings.IncludeCaller { cursor.queryStats.CallerFile, cursor.queryStats.CallerLine, cursor.queryStats.CallerFunction = caller(skip + 1) } } // Run query. if cursor.logSettings.IncludeTime { cursor.queryStats.StartedAt = time.Now() } cursor.row.sqlRows, cursor.queryStats.Err = db.QueryContext(ctx, cursor.queryStats.Query, cursor.queryStats.Args...) 
if cursor.logSettings.IncludeTime { cursor.queryStats.TimeTaken = time.Since(cursor.queryStats.StartedAt) } if cursor.queryStats.Err != nil { return nil, cursor.queryStats.Err } // If the query is static, we now know the number of columns returned by // the query and can allocate the values slice and scanDest slice for // scanning later. if cursor.row.queryIsStatic { cursor.row.columns, err = cursor.row.sqlRows.Columns() if err != nil { return nil, err } cursor.row.columnTypes, err = cursor.row.sqlRows.ColumnTypes() if err != nil { return nil, err } cursor.row.columnIndex = make(map[string]int) for index, column := range cursor.row.columns { cursor.row.columnIndex[column] = index } cursor.row.values = make([]any, len(cursor.row.columns)) cursor.row.scanDest = make([]any, len(cursor.row.columns)) for index := range cursor.row.values { cursor.row.scanDest[index] = &cursor.row.values[index] } } // Allocate the resultsBuffer. if cursor.logSettings.IncludeResults > 0 { cursor.resultsBuffer = bufpool.Get().(*bytes.Buffer) cursor.resultsBuffer.Reset() } return cursor, nil } // FetchOne returns the first result from running the CompiledFetch on the // given DB with the give params. func (compiledFetch *CompiledFetch[T]) FetchOne(db DB, params Params) (T, error) { cursor, err := compiledFetch.fetchCursor(context.Background(), db, params, 1) if err != nil { return *new(T), err } defer cursor.Close() return cursorResult(cursor) } // FetchOneContext is like FetchOne but additionally requires a context.Context. func (compiledFetch *CompiledFetch[T]) FetchOneContext(ctx context.Context, db DB, params Params) (T, error) { cursor, err := compiledFetch.fetchCursor(ctx, db, params, 1) if err != nil { return *new(T), err } defer cursor.Close() return cursorResult(cursor) } // FetchAll returns all the results from running the CompiledFetch on the given // DB with the give params. 
func (compiledFetch *CompiledFetch[T]) FetchAll(db DB, params Params) ([]T, error) { cursor, err := compiledFetch.fetchCursor(context.Background(), db, params, 1) if err != nil { return nil, err } defer cursor.Close() return cursorResults(cursor) } // FetchAllContext is like FetchAll but additionally requires a context.Context. func (compiledFetch *CompiledFetch[T]) FetchAllContext(ctx context.Context, db DB, params Params) ([]T, error) { cursor, err := compiledFetch.fetchCursor(ctx, db, params, 1) if err != nil { return nil, err } defer cursor.Close() return cursorResults(cursor) } // GetSQL returns a copy of the dialect, query, args, params and rowmapper that // make up the CompiledFetch. func (compiledFetch *CompiledFetch[T]) GetSQL() (dialect string, query string, args []any, params map[string][]int, rowmapper func(*Row) T) { dialect = compiledFetch.dialect query = compiledFetch.query args = make([]any, len(compiledFetch.args)) params = make(map[string][]int) copy(args, compiledFetch.args) for name, indexes := range compiledFetch.params { indexes2 := make([]int, len(indexes)) copy(indexes2, indexes) params[name] = indexes2 } return dialect, query, args, params, compiledFetch.rowmapper } // Prepare creates a PreparedFetch from a CompiledFetch by preparing it on // the given DB. func (compiledFetch *CompiledFetch[T]) Prepare(db DB) (*PreparedFetch[T], error) { return compiledFetch.PrepareContext(context.Background(), db) } // PrepareContext is like Prepare but additionally requires a context.Context. 
func (compiledFetch *CompiledFetch[T]) PrepareContext(ctx context.Context, db DB) (*PreparedFetch[T], error) { var err error preparedFetch := &PreparedFetch[T]{ compiledFetch: NewCompiledFetch(compiledFetch.GetSQL()), } preparedFetch.compiledFetch.queryIsStatic = compiledFetch.queryIsStatic if db == nil { return nil, fmt.Errorf("db is nil") } preparedFetch.stmt, err = db.PrepareContext(ctx, compiledFetch.query) if err != nil { return nil, err } preparedFetch.logger, _ = db.(SqLogger) if preparedFetch.logger == nil { logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats)) if logQuery != nil { logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings)) preparedFetch.logger = &sqLogStruct{ logSettings: logSettings, logQuery: logQuery, } } } return preparedFetch, nil } // PreparedFetch is the result of preparing a CompiledFetch on a DB. type PreparedFetch[T any] struct { compiledFetch *CompiledFetch[T] stmt *sql.Stmt logger SqLogger } // PrepareFetch returns a new PreparedFetch. func PrepareFetch[T any](db DB, q Query, rowmapper func(*Row) T) (*PreparedFetch[T], error) { return PrepareFetchContext(context.Background(), db, q, rowmapper) } // PrepareFetchContext is like PrepareFetch but additionally requires a context.Context. func PrepareFetchContext[T any](ctx context.Context, db DB, q Query, rowmapper func(*Row) T) (*PreparedFetch[T], error) { compiledFetch, err := CompileFetchContext(ctx, q, rowmapper) if err != nil { return nil, err } return compiledFetch.PrepareContext(ctx, db) } // FetchCursor returns a new cursor. func (preparedFetch PreparedFetch[T]) FetchCursor(params Params) (*Cursor[T], error) { return preparedFetch.fetchCursor(context.Background(), params, 1) } // FetchCursorContext is like FetchCursor but additionally requires a context.Context. 
func (preparedFetch PreparedFetch[T]) FetchCursorContext(ctx context.Context, params Params) (*Cursor[T], error) { return preparedFetch.fetchCursor(ctx, params, 1) } func (preparedFetch *PreparedFetch[T]) fetchCursor(ctx context.Context, params Params, skip int) (cursor *Cursor[T], err error) { cursor = &Cursor[T]{ ctx: ctx, rowmapper: preparedFetch.compiledFetch.rowmapper, row: &Row{ dialect: preparedFetch.compiledFetch.dialect, queryIsStatic: preparedFetch.compiledFetch.queryIsStatic, }, queryStats: QueryStats{ Dialect: preparedFetch.compiledFetch.dialect, Query: preparedFetch.compiledFetch.query, Args: preparedFetch.compiledFetch.args, Params: preparedFetch.compiledFetch.params, RowCount: sql.NullInt64{Valid: true}, }, logger: preparedFetch.logger, } // If the query is dynamic, call the rowmapper to populate row.scanDest. if !cursor.row.queryIsStatic { defer mapperFunctionPanicked(&err) _ = cursor.rowmapper(cursor.row) } // Substitute params. cursor.queryStats.Args, err = substituteParams(cursor.queryStats.Dialect, cursor.queryStats.Args, cursor.queryStats.Params, params) if err != nil { return nil, err } // Setup logger. if cursor.logger != nil { cursor.logger.SqLogSettings(ctx, &cursor.logSettings) if cursor.logSettings.IncludeCaller { cursor.queryStats.CallerFile, cursor.queryStats.CallerLine, cursor.queryStats.CallerFunction = caller(skip + 1) } } // Run query. if cursor.logSettings.IncludeTime { cursor.queryStats.StartedAt = time.Now() } cursor.row.sqlRows, cursor.queryStats.Err = preparedFetch.stmt.QueryContext(ctx, cursor.queryStats.Args...) if cursor.logSettings.IncludeTime { cursor.queryStats.TimeTaken = time.Since(cursor.queryStats.StartedAt) } if cursor.queryStats.Err != nil { return nil, cursor.queryStats.Err } // If the query is static, we now know the number of columns returned by // the query and can allocate the values slice and scanDest slice for // scanning later. 
if cursor.row.queryIsStatic { cursor.row.columns, err = cursor.row.sqlRows.Columns() if err != nil { return nil, err } cursor.row.columnTypes, err = cursor.row.sqlRows.ColumnTypes() if err != nil { return nil, err } cursor.row.columnIndex = make(map[string]int) for index, column := range cursor.row.columns { cursor.row.columnIndex[column] = index } cursor.row.values = make([]any, len(cursor.row.columns)) cursor.row.scanDest = make([]any, len(cursor.row.columns)) for index := range cursor.row.values { cursor.row.scanDest[index] = &cursor.row.values[index] } } // Allocate the resultsBuffer. if cursor.logSettings.IncludeResults > 0 { cursor.resultsBuffer = bufpool.Get().(*bytes.Buffer) cursor.resultsBuffer.Reset() } return cursor, nil } // FetchOne returns the first result from running the PreparedFetch with the // give params. func (preparedFetch *PreparedFetch[T]) FetchOne(params Params) (T, error) { cursor, err := preparedFetch.fetchCursor(context.Background(), params, 1) if err != nil { return *new(T), err } defer cursor.Close() return cursorResult(cursor) } // FetchOneContext is like FetchOne but additionally requires a context.Context. func (preparedFetch *PreparedFetch[T]) FetchOneContext(ctx context.Context, params Params) (T, error) { cursor, err := preparedFetch.fetchCursor(ctx, params, 1) if err != nil { return *new(T), err } defer cursor.Close() return cursorResult(cursor) } // FetchAll returns all the results from running the PreparedFetch with the // give params. func (preparedFetch *PreparedFetch[T]) FetchAll(params Params) ([]T, error) { cursor, err := preparedFetch.fetchCursor(context.Background(), params, 1) if err != nil { return nil, err } defer cursor.Close() return cursorResults(cursor) } // FetchAllContext is like FetchAll but additionally requires a context.Context. 
func (preparedFetch *PreparedFetch[T]) FetchAllContext(ctx context.Context, params Params) ([]T, error) { cursor, err := preparedFetch.fetchCursor(ctx, params, 1) if err != nil { return nil, err } defer cursor.Close() return cursorResults(cursor) } // GetCompiled returns a copy of the underlying CompiledFetch. func (preparedFetch *PreparedFetch[T]) GetCompiled() *CompiledFetch[T] { compiledFetch := NewCompiledFetch(preparedFetch.compiledFetch.GetSQL()) compiledFetch.queryIsStatic = preparedFetch.compiledFetch.queryIsStatic return compiledFetch } // Close closes the PreparedFetch. func (preparedFetch *PreparedFetch[T]) Close() error { if preparedFetch.stmt == nil { return nil } return preparedFetch.stmt.Close() } // Exec executes the given Query on the given DB. func Exec(db DB, query Query) (Result, error) { return exec(context.Background(), db, query, 1) } // ExecContext is like Exec but additionally requires a context.Context. func ExecContext(ctx context.Context, db DB, query Query) (Result, error) { return exec(ctx, db, query, 1) } func exec(ctx context.Context, db DB, query Query, skip int) (result Result, err error) { if db == nil { return result, fmt.Errorf("db is nil") } if query == nil { return result, fmt.Errorf("query is nil") } dialect := query.GetDialect() if dialect == "" { defaultDialect := DefaultDialect.Load() if defaultDialect != nil { dialect = *defaultDialect } } queryStats := QueryStats{ Dialect: dialect, Params: make(map[string][]int), } // Build query. buf := bufpool.Get().(*bytes.Buffer) buf.Reset() defer bufpool.Put(buf) err = query.WriteSQL(ctx, dialect, buf, &queryStats.Args, queryStats.Params) queryStats.Query = buf.String() if err != nil { return result, err } // Setup logger. 
var logSettings LogSettings logger, _ := db.(SqLogger) if logger == nil { logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats)) if logQuery != nil { logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings)) logger = &sqLogStruct{ logSettings: logSettings, logQuery: logQuery, } } } if logger != nil { logger.SqLogSettings(ctx, &logSettings) if logSettings.IncludeCaller { queryStats.CallerFile, queryStats.CallerLine, queryStats.CallerFunction = caller(skip + 1) } defer func() { if logSettings.LogAsynchronously { go logger.SqLogQuery(ctx, queryStats) } else { logger.SqLogQuery(ctx, queryStats) } }() } // Run query. if logSettings.IncludeTime { queryStats.StartedAt = time.Now() } var sqlResult sql.Result sqlResult, queryStats.Err = db.ExecContext(ctx, queryStats.Query, queryStats.Args...) if logSettings.IncludeTime { queryStats.TimeTaken = time.Since(queryStats.StartedAt) } if queryStats.Err != nil { return result, queryStats.Err } return execResult(sqlResult, &queryStats) } // CompiledExec is the result of compiling a Query down into a query string and // args slice. A CompiledExec can be safely executed in parallel. type CompiledExec struct { dialect string query string args []any params map[string][]int } // NewCompiledExec returns a new CompiledExec. func NewCompiledExec(dialect string, query string, args []any, params map[string][]int) *CompiledExec { return &CompiledExec{ dialect: dialect, query: query, args: args, params: params, } } // CompileExec returns a new CompiledExec. func CompileExec(query Query) (*CompiledExec, error) { return CompileExecContext(context.Background(), query) } // CompileExecContext is like CompileExec but additionally requires a context.Context. 
func CompileExecContext(ctx context.Context, query Query) (*CompiledExec, error) { if query == nil { return nil, fmt.Errorf("query is nil") } dialect := query.GetDialect() if dialect == "" { defaultDialect := DefaultDialect.Load() if defaultDialect != nil { dialect = *defaultDialect } } compiledExec := &CompiledExec{ dialect: dialect, params: make(map[string][]int), } // Build query. buf := bufpool.Get().(*bytes.Buffer) buf.Reset() defer bufpool.Put(buf) err := query.WriteSQL(ctx, dialect, buf, &compiledExec.args, compiledExec.params) compiledExec.query = buf.String() if err != nil { return nil, err } return compiledExec, nil } // Exec executes the CompiledExec on the given DB with the given params. func (compiledExec *CompiledExec) Exec(db DB, params Params) (Result, error) { return compiledExec.exec(context.Background(), db, params, 1) } // ExecContext is like Exec but additionally requires a context.Context. func (compiledExec *CompiledExec) ExecContext(ctx context.Context, db DB, params Params) (Result, error) { return compiledExec.exec(ctx, db, params, 1) } func (compiledExec *CompiledExec) exec(ctx context.Context, db DB, params Params, skip int) (result Result, err error) { if db == nil { return result, fmt.Errorf("db is nil") } queryStats := QueryStats{ Dialect: compiledExec.dialect, Query: compiledExec.query, Args: compiledExec.args, Params: compiledExec.params, } // Setup logger. 
var logSettings LogSettings logger, _ := db.(SqLogger) if logger == nil { logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats)) if logQuery != nil { logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings)) logger = &sqLogStruct{ logSettings: logSettings, logQuery: logQuery, } } } if logger != nil { logger.SqLogSettings(ctx, &logSettings) if logSettings.IncludeCaller { queryStats.CallerFile, queryStats.CallerLine, queryStats.CallerFunction = caller(skip + 1) } defer func() { if logSettings.LogAsynchronously { go logger.SqLogQuery(ctx, queryStats) } else { logger.SqLogQuery(ctx, queryStats) } }() } // Substitute params. queryStats.Args, err = substituteParams(queryStats.Dialect, queryStats.Args, queryStats.Params, params) if err != nil { return result, err } // Run query. if logSettings.IncludeTime { queryStats.StartedAt = time.Now() } var sqlResult sql.Result sqlResult, queryStats.Err = db.ExecContext(ctx, queryStats.Query, queryStats.Args...) if logSettings.IncludeTime { queryStats.TimeTaken = time.Since(queryStats.StartedAt) } if queryStats.Err != nil { return result, queryStats.Err } return execResult(sqlResult, &queryStats) } // GetSQL returns a copy of the dialect, query, args, params and rowmapper that // make up the CompiledExec. func (compiledExec *CompiledExec) GetSQL() (dialect string, query string, args []any, params map[string][]int) { dialect = compiledExec.dialect query = compiledExec.query args = make([]any, len(compiledExec.args)) params = make(map[string][]int) copy(args, compiledExec.args) for name, indexes := range compiledExec.params { indexes2 := make([]int, len(indexes)) copy(indexes2, indexes) params[name] = indexes2 } return dialect, query, args, params } // Prepare creates a PreparedExec from a CompiledExec by preparing it on the // given DB. 
func (compiledExec *CompiledExec) Prepare(db DB) (*PreparedExec, error) {
	return compiledExec.PrepareContext(context.Background(), db)
}

// PrepareContext is like Prepare but additionally requires a context.Context.
// It returns an error if db is nil.
func (compiledExec *CompiledExec) PrepareContext(ctx context.Context, db DB) (*PreparedExec, error) {
	// Reject a nil db up front, as CompiledFetch.PrepareContext does, rather
	// than panicking on the method call below.
	if db == nil {
		return nil, fmt.Errorf("db is nil")
	}
	var err error
	preparedExec := &PreparedExec{
		compiledExec: NewCompiledExec(compiledExec.GetSQL()),
	}
	preparedExec.stmt, err = db.PrepareContext(ctx, compiledExec.query)
	if err != nil {
		return nil, err
	}
	// Prefer the db's own SqLogger; otherwise use the package-level default
	// logger, if one has been set.
	preparedExec.logger, _ = db.(SqLogger)
	if preparedExec.logger == nil {
		logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats))
		if logQuery != nil {
			logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings))
			preparedExec.logger = &sqLogStruct{
				logSettings: logSettings,
				logQuery:    logQuery,
			}
		}
	}
	return preparedExec, nil
}

// PreparedExec is the result of preparing a CompiledExec on a DB.
type PreparedExec struct {
	compiledExec *CompiledExec
	stmt         *sql.Stmt
	logger       SqLogger
}

// PrepareExec returns a new PreparedExec.
func PrepareExec(db DB, q Query) (*PreparedExec, error) {
	return PrepareExecContext(context.Background(), db, q)
}

// PrepareExecContext is like PrepareExec but additionally requires a
// context.Context.
func PrepareExecContext(ctx context.Context, db DB, q Query) (*PreparedExec, error) {
	compiledExec, err := CompileExecContext(ctx, q)
	if err != nil {
		return nil, err
	}
	return compiledExec.PrepareContext(ctx, db)
}

// Close closes the PreparedExec's statement. It is a no-op when there is no
// statement.
func (preparedExec *PreparedExec) Close() error {
	if preparedExec.stmt == nil {
		return nil
	}
	return preparedExec.stmt.Close()
}

// Exec executes the PreparedExec with the given params.
func (preparedExec *PreparedExec) Exec(params Params) (Result, error) {
	return preparedExec.exec(context.Background(), params, 1)
}

// ExecContext is like Exec but additionally requires a context.Context.
func (preparedExec *PreparedExec) ExecContext(ctx context.Context, params Params) (Result, error) { return preparedExec.exec(ctx, params, 1) } func (preparedExec *PreparedExec) exec(ctx context.Context, params Params, skip int) (result Result, err error) { queryStats := QueryStats{ Dialect: preparedExec.compiledExec.dialect, Query: preparedExec.compiledExec.query, Args: preparedExec.compiledExec.args, Params: preparedExec.compiledExec.params, } // Setup logger. var logSettings LogSettings if preparedExec.logger != nil { preparedExec.logger.SqLogSettings(ctx, &logSettings) if logSettings.IncludeCaller { queryStats.CallerFile, queryStats.CallerLine, queryStats.CallerFunction = caller(skip + 1) } defer func() { if logSettings.LogAsynchronously { go preparedExec.logger.SqLogQuery(ctx, queryStats) } else { preparedExec.logger.SqLogQuery(ctx, queryStats) } }() } // Substitute params. queryStats.Args, err = substituteParams(queryStats.Dialect, queryStats.Args, queryStats.Params, params) if err != nil { return result, err } // Run query. if logSettings.IncludeTime { queryStats.StartedAt = time.Now() } var sqlResult sql.Result sqlResult, queryStats.Err = preparedExec.stmt.ExecContext(ctx, queryStats.Args...) 
if logSettings.IncludeTime { queryStats.TimeTaken = time.Since(queryStats.StartedAt) } if queryStats.Err != nil { return result, queryStats.Err } return execResult(sqlResult, &queryStats) } func getFieldNames(ctx context.Context, row *Row) []string { if len(row.fields) == 0 { columns, _ := row.sqlRows.Columns() return columns } buf := bufpool.Get().(*bytes.Buffer) buf.Reset() defer bufpool.Put(buf) var args []any fieldNames := make([]string, 0, len(row.fields)) for _, field := range row.fields { if alias := getAlias(field); alias != "" { fieldNames = append(fieldNames, alias) continue } buf.Reset() args = args[:0] err := field.WriteSQL(ctx, row.dialect, buf, &args, nil) if err != nil { fieldNames = append(fieldNames, "%!(error="+err.Error()+")") continue } fieldName, err := Sprintf(row.dialect, buf.String(), args) if err != nil { fieldNames = append(fieldNames, "%!(error="+err.Error()+")") continue } fieldNames = append(fieldNames, fieldName) } return fieldNames } func getFieldMappings(dialect string, fields []Field, scanDest []any) string { var buf bytes.Buffer var args []any var b strings.Builder for i, field := range fields { b.WriteString(fmt.Sprintf("\n %02d. ", i+1)) buf.Reset() args = args[:0] err := field.WriteSQL(context.Background(), dialect, &buf, &args, nil) if err != nil { buf.WriteString("%!(error=" + err.Error() + ")") continue } fieldName, err := Sprintf(dialect, buf.String(), args) if err != nil { b.WriteString("%!(error=" + err.Error() + ")") continue } b.WriteString(fieldName + " => " + reflect.TypeOf(scanDest[i]).String()) } return b.String() } // TODO: inline cursorResult, cursorResults and execResult. 
func cursorResult[T any](cursor *Cursor[T]) (result T, err error) { for cursor.Next() { result, err = cursor.Result() if err != nil { return result, err } break } if cursor.RowCount() == 0 { return result, sql.ErrNoRows } return result, cursor.Close() } func cursorResults[T any](cursor *Cursor[T]) (results []T, err error) { var result T for cursor.Next() { result, err = cursor.Result() if err != nil { return results, err } results = append(results, result) } return results, cursor.Close() } func execResult(sqlResult sql.Result, queryStats *QueryStats) (Result, error) { var err error var result Result if queryStats.Dialect == DialectSQLite || queryStats.Dialect == DialectMySQL { result.LastInsertId, err = sqlResult.LastInsertId() if err != nil { return result, err } queryStats.LastInsertId.Valid = true queryStats.LastInsertId.Int64 = result.LastInsertId } result.RowsAffected, err = sqlResult.RowsAffected() if err != nil { return result, err } queryStats.RowsAffected.Valid = true queryStats.RowsAffected.Int64 = result.RowsAffected return result, nil } // FetchExists returns a boolean indicating if running the given Query on the // given DB returned any results. func FetchExists(db DB, query Query) (exists bool, err error) { return fetchExists(context.Background(), db, query, 1) } // FetchExistsContext is like FetchExists but additionally requires a // context.Context. func FetchExistsContext(ctx context.Context, db DB, query Query) (exists bool, err error) { return fetchExists(ctx, db, query, 1) } func fetchExists(ctx context.Context, db DB, query Query, skip int) (exists bool, err error) { dialect := query.GetDialect() if dialect == "" { defaultDialect := DefaultDialect.Load() if defaultDialect != nil { dialect = *defaultDialect } } queryStats := QueryStats{ Dialect: dialect, Params: make(map[string][]int), Exists: sql.NullBool{Valid: true}, } // Build query. 
buf := bufpool.Get().(*bytes.Buffer) buf.Reset() defer bufpool.Put(buf) if dialect == DialectSQLServer { query = Queryf("SELECT CASE WHEN EXISTS ({}) THEN 1 ELSE 0 END", query) } else { query = Queryf("SELECT EXISTS ({})", query) } err = query.WriteSQL(ctx, dialect, buf, &queryStats.Args, queryStats.Params) queryStats.Query = buf.String() if err != nil { return false, err } // Setup logger. var logSettings LogSettings logger, _ := db.(SqLogger) if logger == nil { logQuery, _ := defaultLogQuery.Load().(func(context.Context, QueryStats)) if logQuery != nil { logSettings, _ := defaultLogSettings.Load().(func(context.Context, *LogSettings)) logger = &sqLogStruct{ logSettings: logSettings, logQuery: logQuery, } } } if logger != nil { logger.SqLogSettings(ctx, &logSettings) if logSettings.IncludeCaller { queryStats.CallerFile, queryStats.CallerLine, queryStats.CallerFunction = caller(skip + 1) } defer func() { if logSettings.LogAsynchronously { go logger.SqLogQuery(ctx, queryStats) } else { logger.SqLogQuery(ctx, queryStats) } }() } // Run query. if logSettings.IncludeTime { queryStats.StartedAt = time.Now() } var sqlRows *sql.Rows sqlRows, queryStats.Err = db.QueryContext(ctx, queryStats.Query, queryStats.Args...) if logSettings.IncludeTime { queryStats.TimeTaken = time.Since(queryStats.StartedAt) } if queryStats.Err != nil { return false, queryStats.Err } for sqlRows.Next() { err = sqlRows.Scan(&exists) if err != nil { return false, err } break } queryStats.Exists.Bool = exists if err := sqlRows.Close(); err != nil { return exists, err } if err := sqlRows.Err(); err != nil { return exists, err } return exists, nil } // substituteParams will return a new args slice by substituting values from // the given paramValues. The input args slice is untouched. 
func substituteParams(dialect string, args []any, paramIndexes map[string][]int, paramValues map[string]any) ([]any, error) { if len(paramValues) == 0 { return args, nil } newArgs := make([]any, len(args)) copy(newArgs, args) var err error for name, value := range paramValues { indexes := paramIndexes[name] for _, index := range indexes { switch arg := newArgs[index].(type) { case sql.NamedArg: arg.Value, err = preprocessValue(dialect, value) if err != nil { return nil, err } newArgs[index] = arg default: value, err = preprocessValue(dialect, value) if err != nil { return nil, err } newArgs[index] = value } } } return newArgs, nil } func caller(skip int) (file string, line int, function string) { pc, file, line, _ := runtime.Caller(skip + 1) fn := runtime.FuncForPC(pc) function = fn.Name() return file, line, function } ================================================ FILE: fetch_exec_test.go ================================================ package sq import ( "database/sql" "testing" "time" "github.com/bokwoon95/sq/internal/testutil" _ "github.com/mattn/go-sqlite3" ) var ACTOR = New[struct { TableStruct `sq:"actor"` ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField }]("") type Actor struct { ActorID int FirstName string LastName string LastUpdate time.Time } func actorRowMapper(row *Row) Actor { var actor Actor actorID, _ := row.Value("actor.actor_id").(int64) actor.ActorID = int(actorID) actor.FirstName = row.StringField(ACTOR.FIRST_NAME) actor.LastName = row.StringField(ACTOR.LAST_NAME) actor.LastUpdate, _ = row.Value("actor.last_update").(time.Time) return actor } func actorRowMapperRawSQL(row *Row) Actor { result := make(map[string]any) values := row.Values() for i, column := range row.Columns() { result[column] = values[i] } var actor Actor actorID, _ := result["actor_id"].(int64) actor.ActorID = int(actorID) actor.FirstName, _ = result["first_name"].(string) actor.LastName, _ = result["last_name"].(string) 
actor.LastUpdate, _ = result["last_update"].(time.Time) return actor } func Test_substituteParams(t *testing.T) { t.Run("no params provided", func(t *testing.T) { t.Parallel() args := []any{1, 2, 3} params := map[string][]int{"one": {0}, "two": {1}, "three": {2}} gotArgs, err := substituteParams("", args, params, nil) if err != nil { t.Fatal(testutil.Callers(), err) } wantArgs := []any{1, 2, 3} if diff := testutil.Diff(gotArgs, wantArgs); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("not all params provided", func(t *testing.T) { t.Parallel() args := []any{1, 2, 3} params := map[string][]int{"one": {0}, "two": {1}, "three": {2}} paramValues := Params{"one": "One", "two": "Two"} gotArgs, err := substituteParams("", args, params, paramValues) if err != nil { t.Fatal(testutil.Callers(), err) } wantArgs := []any{"One", "Two", 3} if diff := testutil.Diff(gotArgs, wantArgs); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("params substituted", func(t *testing.T) { t.Parallel() type Data struct { id int name string } args := []any{ 0, sql.Named("one", 1), sql.Named("two", 2), 3, } params := map[string][]int{ "zero": {0}, "one": {1}, "two": {2}, "three": {3}, } paramValues := Params{ "one": "[one]", "two": "[two]", "three": "[three]", } wantArgs := []any{ 0, sql.Named("one", "[one]"), sql.Named("two", "[two]"), "[three]", } gotArgs, err := substituteParams("", args, params, paramValues) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(gotArgs, wantArgs); diff != "" { t.Error(testutil.Callers(), diff) } }) } func Test_getFieldMappings(t *testing.T) { type TestTable struct { description string dialect string fields []Field scanDest []any wantFieldMappings string } var tests = []TestTable{{ description: "empty", wantFieldMappings: "", }, { description: "basic", fields: []Field{ Expr("actor_id"), Expr("first_name || {} || last_name", " "), Expr("last_update"), }, scanDest: []any{ &sql.NullInt64{}, &sql.NullString{}, 
&sql.NullTime{}, }, wantFieldMappings: "" + "\n 01. actor_id => *sql.NullInt64" + "\n 02. first_name || ' ' || last_name => *sql.NullString" + "\n 03. last_update => *sql.NullTime", }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() gotFieldMappings := getFieldMappings(tt.dialect, tt.fields, tt.scanDest) if diff := testutil.Diff(gotFieldMappings, tt.wantFieldMappings); diff != "" { t.Error(testutil.Callers(), diff) } }) } } func TestFetchExec(t *testing.T) { t.Parallel() db := newDB(t) var referenceActors = []Actor{ {ActorID: 1, FirstName: "PENELOPE", LastName: "GUINESS", LastUpdate: time.Unix(1, 0).UTC()}, {ActorID: 2, FirstName: "NICK", LastName: "WAHLBERG", LastUpdate: time.Unix(1, 0).UTC()}, {ActorID: 3, FirstName: "ED", LastName: "CHASE", LastUpdate: time.Unix(1, 0).UTC()}, {ActorID: 4, FirstName: "JENNIFER", LastName: "DAVIS", LastUpdate: time.Unix(1, 0).UTC()}, {ActorID: 5, FirstName: "JOHNNY", LastName: "LOLLOBRIGIDA", LastUpdate: time.Unix(1, 0).UTC()}, } // Exec. res, err := Exec(Log(db), SQLite. InsertInto(ACTOR). ColumnValues(func(col *Column) { for _, actor := range referenceActors { col.SetInt(ACTOR.ACTOR_ID, actor.ActorID) col.SetString(ACTOR.FIRST_NAME, actor.FirstName) col.SetString(ACTOR.LAST_NAME, actor.LastName) col.SetTime(ACTOR.LAST_UPDATE, actor.LastUpdate) } }), ) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(res.RowsAffected, int64(len(referenceActors))); diff != "" { t.Fatal(testutil.Callers(), diff) } // FetchOne. actor, err := FetchOne(Log(db), SQLite. From(ACTOR). Where(ACTOR.ACTOR_ID.EqInt(1)), actorRowMapper, ) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(actor, referenceActors[0]); diff != "" { t.Fatal(testutil.Callers(), diff) } // FetchOne (Raw SQL). 
actor, err = FetchOne(Log(db), SQLite.Queryf("SELECT * FROM actor WHERE actor_id = {}", 1), actorRowMapperRawSQL, ) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(actor, referenceActors[0]); diff != "" { t.Fatal(testutil.Callers(), diff) } // FetchAll. actors, err := FetchAll(VerboseLog(db), SQLite. From(ACTOR). OrderBy(ACTOR.ACTOR_ID), actorRowMapper, ) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(actors, referenceActors); diff != "" { t.Fatal(testutil.Callers(), err) } // FetchAll (RawSQL). actors, err = FetchAll(VerboseLog(db), SQLite.Queryf("SELECT * FROM actor ORDER BY actor_id"), actorRowMapperRawSQL, ) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(actors, referenceActors); diff != "" { t.Fatal(testutil.Callers(), err) } } func TestCompiledFetchExec(t *testing.T) { t.Parallel() db := newDB(t) var referenceActors = []Actor{ {ActorID: 1, FirstName: "PENELOPE", LastName: "GUINESS", LastUpdate: time.Unix(1, 0).UTC()}, {ActorID: 2, FirstName: "NICK", LastName: "WAHLBERG", LastUpdate: time.Unix(1, 0).UTC()}, {ActorID: 3, FirstName: "ED", LastName: "CHASE", LastUpdate: time.Unix(1, 0).UTC()}, {ActorID: 4, FirstName: "JENNIFER", LastName: "DAVIS", LastUpdate: time.Unix(1, 0).UTC()}, {ActorID: 5, FirstName: "JOHNNY", LastName: "LOLLOBRIGIDA", LastUpdate: time.Unix(1, 0).UTC()}, } // CompiledExec. compiledExec, err := CompileExec(SQLite. InsertInto(ACTOR). 
ColumnValues(func(col *Column) { col.Set(ACTOR.ACTOR_ID, IntParam("actor_id", 0)) col.Set(ACTOR.FIRST_NAME, StringParam("first_name", "")) col.Set(ACTOR.LAST_NAME, StringParam("last_name", "")) col.Set(ACTOR.LAST_UPDATE, TimeParam("last_update", time.Time{})) }), ) if err != nil { t.Fatal(testutil.Callers(), err) } for _, actor := range referenceActors { _, err = compiledExec.Exec(Log(db), Params{ "actor_id": actor.ActorID, "first_name": actor.FirstName, "last_name": actor.LastName, "last_update": actor.LastUpdate, }) if err != nil { t.Fatal(testutil.Callers(), err) } } // CompiledFetch FetchOne. compiledFetch, err := CompileFetch(SQLite. From(ACTOR). Where(ACTOR.ACTOR_ID.Eq(IntParam("actor_id", 0))), actorRowMapper, ) if err != nil { t.Fatal(testutil.Callers(), err) } actor, err := compiledFetch.FetchOne(Log(db), Params{"actor_id": 1}) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(actor, referenceActors[0]); diff != "" { t.Fatal(testutil.Callers(), diff) } // CompiledFetch FetchOne (Raw SQL). compiledFetch, err = CompileFetch( SQLite.Queryf("SELECT * FROM actor WHERE actor_id = {actor_id}", IntParam("actor_id", 0)), actorRowMapperRawSQL, ) if err != nil { t.Fatal(testutil.Callers(), err) } actor, err = compiledFetch.FetchOne(Log(db), Params{"actor_id": 1}) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(actor, referenceActors[0]); diff != "" { t.Fatal(testutil.Callers(), diff) } // CompiledFetch FetchAll. compiledFetch, err = CompileFetch(SQLite. From(ACTOR). OrderBy(ACTOR.ACTOR_ID), actorRowMapper, ) if err != nil { t.Fatal(testutil.Callers(), err) } actors, err := compiledFetch.FetchAll(VerboseLog(db), nil) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(actors, referenceActors); diff != "" { t.Fatal(testutil.Callers(), diff) } // CompiledFetch FetchAll (Raw SQL). 
compiledFetch, err = CompileFetch( SQLite.Queryf("SELECT * FROM actor ORDER BY actor_id"), actorRowMapperRawSQL, ) if err != nil { t.Fatal(testutil.Callers(), err) } actors, err = compiledFetch.FetchAll(VerboseLog(db), nil) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(actors, referenceActors); diff != "" { t.Fatal(testutil.Callers(), diff) } } func TestPreparedFetchExec(t *testing.T) { t.Parallel() db := newDB(t) var referenceActors = []Actor{ {ActorID: 1, FirstName: "PENELOPE", LastName: "GUINESS", LastUpdate: time.Unix(1, 0).UTC()}, {ActorID: 2, FirstName: "NICK", LastName: "WAHLBERG", LastUpdate: time.Unix(1, 0).UTC()}, {ActorID: 3, FirstName: "ED", LastName: "CHASE", LastUpdate: time.Unix(1, 0).UTC()}, {ActorID: 4, FirstName: "JENNIFER", LastName: "DAVIS", LastUpdate: time.Unix(1, 0).UTC()}, {ActorID: 5, FirstName: "JOHNNY", LastName: "LOLLOBRIGIDA", LastUpdate: time.Unix(1, 0).UTC()}, } // PreparedExec. preparedExec, err := PrepareExec(Log(db), SQLite. InsertInto(ACTOR). ColumnValues(func(col *Column) { col.Set(ACTOR.ACTOR_ID, IntParam("actor_id", 0)) col.Set(ACTOR.FIRST_NAME, StringParam("first_name", "")) col.Set(ACTOR.LAST_NAME, StringParam("last_name", "")) col.Set(ACTOR.LAST_UPDATE, TimeParam("last_update", time.Time{})) }), ) if err != nil { t.Fatal(testutil.Callers(), err) } for _, actor := range referenceActors { _, err = preparedExec.Exec(Params{ "actor_id": actor.ActorID, "first_name": actor.FirstName, "last_name": actor.LastName, "last_update": actor.LastUpdate, }) if err != nil { t.Fatal(testutil.Callers(), err) } } // PreparedFetch FetchOne. preparedFetch, err := PrepareFetch(Log(db), SQLite. From(ACTOR). 
Where(ACTOR.ACTOR_ID.Eq(IntParam("actor_id", 0))), actorRowMapper, ) if err != nil { t.Fatal(testutil.Callers(), err) } actor, err := preparedFetch.FetchOne(Params{"actor_id": 1}) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(actor, referenceActors[0]); diff != "" { t.Fatal(testutil.Callers(), diff) } // PreparedFetch FetchOne (Raw SQL). preparedFetch, err = PrepareFetch( Log(db), SQLite.Queryf("SELECT * FROM actor WHERE actor_id = {actor_id}", IntParam("actor_id", 0)), actorRowMapperRawSQL, ) if err != nil { t.Fatal(testutil.Callers(), err) } actor, err = preparedFetch.FetchOne(Params{"actor_id": 1}) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(actor, referenceActors[0]); diff != "" { t.Fatal(testutil.Callers(), diff) } // PreparedFetch FetchAll. preparedFetch, err = PrepareFetch(VerboseLog(db), SQLite. From(ACTOR). OrderBy(ACTOR.ACTOR_ID), actorRowMapper, ) if err != nil { t.Fatal(testutil.Callers(), err) } actors, err := preparedFetch.FetchAll(nil) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(actors, referenceActors); diff != "" { t.Fatal(testutil.Callers(), diff) } // PreparedFetch FetchAll (Raw SQL). 
preparedFetch, err = PrepareFetch( VerboseLog(db), SQLite.Queryf("SELECT * FROM actor ORDER BY actor_id"), actorRowMapperRawSQL, ) if err != nil { t.Fatal(testutil.Callers(), err) } actors, err = preparedFetch.FetchAll(nil) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(actors, referenceActors); diff != "" { t.Fatal(testutil.Callers(), diff) } } func newDB(t *testing.T) *sql.DB { db, err := sql.Open("sqlite3", ":memory:") if err != nil { t.Fatal(testutil.Callers(), err) } _, err = db.Exec(`CREATE TABLE actor ( actor_id INTEGER PRIMARY KEY AUTOINCREMENT ,first_name TEXT NOT NULL ,last_name TEXT NOT NULL ,last_update DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP )`) if err != nil { t.Fatal(testutil.Callers(), err) } return db } ================================================ FILE: fields.go ================================================ package sq import ( "bytes" "context" "database/sql" "database/sql/driver" "fmt" "reflect" "strconv" "strings" "time" ) // Identifier represents an SQL identifier. If necessary, it will be quoted // according to the dialect. type Identifier string var _ Field = (*Identifier)(nil) // WriteSQL implements the SQLWriter interface. func (id Identifier) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { buf.WriteString(QuoteIdentifier(dialect, string(id))) return nil } // IsField implements the Field interface. func (id Identifier) IsField() {} // AnyField is a catch-all field type that satisfies the Any interface. type AnyField struct { table TableStruct name string alias string desc sql.NullBool nullsfirst sql.NullBool } var _ interface { Field Any WithPrefix(string) Field } = (*AnyField)(nil) // NewAnyField returns a new AnyField. func NewAnyField(name string, tbl TableStruct) AnyField { return AnyField{table: tbl, name: name} } // WriteSQL implements the SQLWriter interface. 
func (field AnyField) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { writeFieldIdentifier(ctx, dialect, buf, args, params, field.table, field.name) writeFieldOrder(ctx, dialect, buf, args, params, field.desc, field.nullsfirst) return nil } // As returns a new AnyField with the given alias. func (field AnyField) As(alias string) AnyField { field.alias = alias return field } // Asc returns a new AnyField indicating that it should be ordered in ascending // order i.e. 'ORDER BY field ASC'. func (field AnyField) Asc() AnyField { field.desc.Valid = true field.desc.Bool = false return field } // Desc returns a new AnyField indicating that it should be ordered in descending // order i.e. 'ORDER BY field DESC'. func (field AnyField) Desc() AnyField { field.desc.Valid = true field.desc.Bool = true return field } // NullsLast returns a new NumberField indicating that it should be ordered // with nulls last i.e. 'ORDER BY field NULLS LAST'. func (field AnyField) NullsLast() AnyField { field.nullsfirst.Valid = true field.nullsfirst.Bool = false return field } // NullsFirst returns a new NumberField indicating that it should be ordered // with nulls first i.e. 'ORDER BY field NULLS FIRST'. func (field AnyField) NullsFirst() AnyField { field.nullsfirst.Valid = true field.nullsfirst.Bool = true return field } // WithPrefix returns a new Field that with the given prefix. func (field AnyField) WithPrefix(prefix string) Field { field.table.alias = "" field.table.name = prefix return field } // IsNull returns a 'field IS NULL' Predicate. func (field AnyField) IsNull() Predicate { return Expr("{} IS NULL", field) } // IsNotNull returns a 'field IS NOT NULL' Predicate. func (field AnyField) IsNotNull() Predicate { return Expr("{} IS NOT NULL", field) } // In returns a 'field IN (value)' Predicate. The value can be a slice, which // corresponds to the expression 'field IN (x, y, z)'. 
func (field AnyField) In(value any) Predicate { return In(field, value) }

// NotIn returns a 'field NOT IN (value)' Predicate. The value can be a slice,
// which corresponds to the expression 'field NOT IN (x, y, z)'.
func (field AnyField) NotIn(value any) Predicate { return NotIn(field, value) }

// Eq returns a 'field = value' Predicate.
func (field AnyField) Eq(value any) Predicate { return Eq(field, value) }

// Ne returns a 'field <> value' Predicate.
func (field AnyField) Ne(value any) Predicate { return Ne(field, value) }

// Lt returns a 'field < value' Predicate.
func (field AnyField) Lt(value any) Predicate { return Lt(field, value) }

// Le returns a 'field <= value' Predicate.
func (field AnyField) Le(value any) Predicate { return Le(field, value) }

// Gt returns a 'field > value' Predicate.
func (field AnyField) Gt(value any) Predicate { return Gt(field, value) }

// Ge returns a 'field >= value' Predicate.
func (field AnyField) Ge(value any) Predicate { return Ge(field, value) }

// Expr returns an expression where the field is prepended to the front of the
// expression. The field is passed as an extra trailing value and referenced
// by its ordinal placeholder.
func (field AnyField) Expr(format string, values ...any) Expression {
	// The full slice expression caps capacity at len, so append always
	// allocates a fresh backing array instead of writing into the caller's
	// slice when Expr is called with values....
	values = append(values[:len(values):len(values)], field)
	ordinal := len(values)
	return Expr("{"+strconv.Itoa(ordinal)+"} "+format, values...)
}

// Set returns an Assignment assigning the value to the field.
func (field AnyField) Set(value any) Assignment { return Set(field, value) }

// Setf returns an Assignment assigning an expression to the field.
func (field AnyField) Setf(format string, values ...any) Assignment {
	return Setf(field, format, values...)
}

// GetAlias returns the alias of the AnyField.
func (field AnyField) GetAlias() string { return field.alias }

// IsField implements the Field interface.
func (field AnyField) IsField() {}

// IsArray implements the Array interface.
func (field AnyField) IsArray() {}

// IsBinary implements the Binary interface.
func (field AnyField) IsBinary() {}

// IsBoolean implements the Boolean interface.
func (field AnyField) IsBoolean() {} // IsEnum implements the Enum interface. func (field AnyField) IsEnum() {} // IsJSON implements the JSONValue interface. func (field AnyField) IsJSON() {} // IsNumber implements the Number interface. func (field AnyField) IsNumber() {} // IsString implements the String interface. func (field AnyField) IsString() {} // IsTime implements the Time interface. func (field AnyField) IsTime() {} // IsUUIDType implements the UUID interface. func (field AnyField) IsUUID() {} // ArrayField represents an SQL array field. type ArrayField struct { table TableStruct name string alias string } var _ interface { Field Array WithPrefix(string) Field } = (*ArrayField)(nil) // NewArrayField returns a new ArrayField. func NewArrayField(fieldName string, tableName TableStruct) ArrayField { return ArrayField{table: tableName, name: fieldName} } // WriteSQL implements the SQLWriter interface. func (field ArrayField) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { writeFieldIdentifier(ctx, dialect, buf, args, params, field.table, field.name) return nil } // As returns a new ArrayField with the given alias. func (field ArrayField) As(alias string) ArrayField { field.alias = alias return field } // WithPrefix returns a new Field that with the given prefix. func (field ArrayField) WithPrefix(prefix string) Field { field.table.alias = "" field.table.name = prefix return field } // IsNull returns a 'field IS NULL' Predicate. func (field ArrayField) IsNull() Predicate { return Expr("{} IS NULL", field) } // IsNull returns a 'field IS NOT NULL' Predicate. func (field ArrayField) IsNotNull() Predicate { return Expr("{} IS NOT NULL", field) } // Set returns an Assignment assigning the value to the field. 
func (field ArrayField) Set(value any) Assignment { switch value.(type) { case SQLWriter: return Set(field, value) case []string, []int, []int64, []int32, []float64, []float32, []bool: return Set(field, ArrayValue(value)) } return Set(field, value) } // SetArray returns an Assignment assigning the value to the field. It wraps // the value with ArrayValue(). func (field ArrayField) SetArray(value any) Assignment { return Set(field, ArrayValue(value)) } // Setf returns an Assignment assigning an expression to the field. func (field ArrayField) Setf(format string, values ...any) Assignment { return Set(field, Expr(format, values...)) } // GetAlias returns the alias of the ArrayField. func (field ArrayField) GetAlias() string { return field.alias } // IsField implements the Field interface. func (field ArrayField) IsField() {} // IsArray implements the Array interface. func (field ArrayField) IsArray() {} // BinaryField represents an SQL binary field. type BinaryField struct { table TableStruct name string alias string desc sql.NullBool nullsfirst sql.NullBool } var _ interface { Field Binary WithPrefix(string) Field } = (*BinaryField)(nil) // NewBinaryField returns a new BinaryField. func NewBinaryField(fieldName string, tableName TableStruct) BinaryField { return BinaryField{table: tableName, name: fieldName} } // WriteSQL implements the SQLWriter interface. func (field BinaryField) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { writeFieldIdentifier(ctx, dialect, buf, args, params, field.table, field.name) writeFieldOrder(ctx, dialect, buf, args, params, field.desc, field.nullsfirst) return nil } // As returns a new BinaryField with the given alias. func (field BinaryField) As(alias string) BinaryField { field.alias = alias return field } // Asc returns a new BinaryField indicating that it should be ordered in ascending // order i.e. 'ORDER BY field ASC'. 
func (field BinaryField) Asc() BinaryField { field.desc.Valid = true field.desc.Bool = false return field } // Desc returns a new BinaryField indicating that it should be ordered in ascending // order i.e. 'ORDER BY field DESC'. func (field BinaryField) Desc() BinaryField { field.desc.Valid = true field.desc.Bool = true return field } // NullsLast returns a new BinaryField indicating that it should be ordered // with nulls last i.e. 'ORDER BY field NULLS LAST'. func (field BinaryField) NullsLast() BinaryField { field.nullsfirst.Valid = true field.nullsfirst.Bool = false return field } // NullsFirst returns a new BinaryField indicating that it should be ordered // with nulls first i.e. 'ORDER BY field NULLS FIRST'. func (field BinaryField) NullsFirst() BinaryField { field.nullsfirst.Valid = true field.nullsfirst.Bool = true return field } // WithPrefix returns a new Field that with the given prefix. func (field BinaryField) WithPrefix(prefix string) Field { field.table.alias = "" field.table.name = prefix return field } // IsNull returns a 'field IS NULL' Predicate. func (field BinaryField) IsNull() Predicate { return Expr("{} IS NULL", field) } // IsNotNull returns a 'field IS NOT NULL' Predicate. func (field BinaryField) IsNotNull() Predicate { return Expr("{} IS NOT NULL", field) } // Eq returns a 'field = value' Predicate. func (field BinaryField) Eq(value Binary) Predicate { return Eq(field, value) } // Ne returns a 'field <> value' Predicate. func (field BinaryField) Ne(value Binary) Predicate { return Ne(field, value) } // EqBytes returns a 'field = b' Predicate. func (field BinaryField) EqBytes(b []byte) Predicate { return Eq(field, b) } // NeBytes returns a 'field <> b' Predicate. func (field BinaryField) NeBytes(b []byte) Predicate { return Ne(field, b) } // Set returns an Assignment assigning the value to the field. 
func (field BinaryField) Set(value any) Assignment { return Set(field, value) } // Setf returns an Assignment assigning an expression to the field. func (field BinaryField) Setf(format string, values ...any) Assignment { return Setf(field, format, values...) } // SetBytes returns an Assignment assigning a []byte to the field. func (field BinaryField) SetBytes(b []byte) Assignment { return Set(field, b) } // GetAlias returns the alias of the BinaryField. func (field BinaryField) GetAlias() string { return field.alias } // IsField implements the Field interface. func (field BinaryField) IsField() {} // IsBinary implements the Binary interface. func (field BinaryField) IsBinary() {} // BooleanField represents an SQL boolean field. type BooleanField struct { table TableStruct name string alias string desc sql.NullBool nullsfirst sql.NullBool } var _ interface { Field Boolean Predicate WithPrefix(string) Field } = (*BooleanField)(nil) // NewBooleanField returns a new BooleanField. func NewBooleanField(fieldName string, tableName TableStruct) BooleanField { return BooleanField{table: tableName, name: fieldName} } // WriteSQL implements the SQLWriter interface. func (field BooleanField) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { writeFieldIdentifier(ctx, dialect, buf, args, params, field.table, field.name) writeFieldOrder(ctx, dialect, buf, args, params, field.desc, field.nullsfirst) return nil } // As returns a new BooleanField with the given alias. func (field BooleanField) As(alias string) BooleanField { field.alias = alias return field } // Asc returns a new BooleanField indicating that it should be ordered in ascending // order i.e. 'ORDER BY field ASC'. func (field BooleanField) Asc() BooleanField { field.desc.Valid = true field.desc.Bool = false return field } // Desc returns a new BooleanField indicating that it should be ordered in // descending order i.e. 'ORDER BY field DESC'. 
func (f BooleanField) Desc() BooleanField { f.desc.Valid = true f.desc.Bool = true return f } // NullsLast returns a new BooleanField indicating that it should be ordered // with nulls last i.e. 'ORDER BY field NULLS LAST'. func (field BooleanField) NullsLast() BooleanField { field.nullsfirst.Valid = true field.nullsfirst.Bool = false return field } // NullsFirst returns a new BooleanField indicating that it should be ordered // with nulls first i.e. 'ORDER BY field NULLS FIRST'. func (field BooleanField) NullsFirst() BooleanField { field.nullsfirst.Valid = true field.nullsfirst.Bool = true return field } // WithPrefix returns a new Field that with the given prefix. func (field BooleanField) WithPrefix(prefix string) Field { field.table.alias = "" field.table.name = prefix return field } // IsNull returns a 'field IS NULL' Predicate. func (field BooleanField) IsNull() Predicate { return Expr("{} IS NULL", field) } // IsNotNull returns a 'field IS NOT NULL' Predicate. func (field BooleanField) IsNotNull() Predicate { return Expr("{} IS NOT NULL", field) } // Eq returns a 'field = value' Predicate. func (field BooleanField) Eq(value Boolean) Predicate { return Eq(field, value) } // Ne returns a 'field <> value' Predicate. func (field BooleanField) Ne(value Boolean) Predicate { return Ne(field, value) } // EqBool returns a 'field = b' Predicate. func (field BooleanField) EqBool(b bool) Predicate { return Eq(field, b) } // NeBool returns a 'field <> b' Predicate. func (field BooleanField) NeBool(b bool) Predicate { return Ne(field, b) } // Set returns an Assignment assigning the value to the field. func (field BooleanField) Set(value any) Assignment { return Set(field, value) } // Setf returns an Assignment assigning an expression to the field. func (field BooleanField) Setf(format string, values ...any) Assignment { return Setf(field, format, values...) } // SetBool returns an Assignment assigning a bool to the field i.e. 'field = // b'. 
func (field BooleanField) SetBool(b bool) Assignment { return Set(field, b) } // GetAlias returns the alias of the BooleanField. func (field BooleanField) GetAlias() string { return field.alias } // IsField implements the Field interface. func (field BooleanField) IsField() {} // IsBoolean implements the Boolean interface. func (field BooleanField) IsBoolean() {} // EnumField represents an SQL enum field. type EnumField struct { table TableStruct name string alias string } var _ interface { Field Enum WithPrefix(string) Field } = (*EnumField)(nil) // NewEnumField returns a new EnumField. func NewEnumField(name string, tbl TableStruct) EnumField { return EnumField{table: tbl, name: name} } // WriteSQL implements the SQLWriter interface. func (field EnumField) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { writeFieldIdentifier(ctx, dialect, buf, args, params, field.table, field.name) return nil } // As returns a new EnumField with the given alias. func (field EnumField) As(alias string) EnumField { field.alias = alias return field } // WithPrefix returns a new Field that with the given prefix. func (field EnumField) WithPrefix(prefix string) Field { field.table.alias = "" field.table.name = prefix return field } // IsNull returns a 'field IS NULL' Predicate. func (field EnumField) IsNull() Predicate { return Expr("{} IS NULL", field) } // IsNotNull returns a 'field IS NOT NULL' Predicate. func (field EnumField) IsNotNull() Predicate { return Expr("{} IS NOT NULL", field) } // In returns a 'field IN (value)' Predicate. The value can be a slice, which // corresponds to the expression 'field IN (x, y, z)'. func (field EnumField) In(value any) Predicate { return In(field, value) } // NotIn returns a 'field NOT IN (value)' Predicate. The value can be a slice, which // corresponds to the expression 'field NOT IN (x, y, z)'. 
func (field EnumField) NotIn(value any) Predicate { return NotIn(field, value) } // Eq returns a 'field = value' Predicate. func (field EnumField) Eq(value any) Predicate { return Eq(field, value) } // Ne returns a 'field <> value' Predicate. func (field EnumField) Ne(value any) Predicate { return Ne(field, value) } // EqEnum returns a 'field = value' Predicate. It wraps the value with // EnumValue(). func (field EnumField) EqEnum(value Enumeration) Predicate { return Eq(field, EnumValue(value)) } // NeEnum returns a 'field <> value' Predicate. it wraps the value with // EnumValue(). func (field EnumField) NeEnum(value Enumeration) Predicate { return Ne(field, EnumValue(value)) } // Set returns an Assignment assigning the value to the field. func (field EnumField) Set(value any) Assignment { return Set(field, value) } // SetEnum returns an Assignment assigning the value to the field. It wraps the // value with EnumValue(). func (field EnumField) SetEnum(value Enumeration) Assignment { return Set(field, EnumValue(value)) } // Setf returns an Assignment assigning an expression to the field. func (field EnumField) Setf(format string, values ...any) Assignment { return Setf(field, format, values...) } // GetAlias returns the alias of the EnumField. func (field EnumField) GetAlias() string { return field.alias } // IsField implements the Field interface. func (field EnumField) IsField() {} // IsEnum implements the Enum interface. func (field EnumField) IsEnum() {} // JSONField represents an SQL JSON field. type JSONField struct { table TableStruct name string alias string } var _ interface { Field Binary JSON String WithPrefix(string) Field } = (*JSONField)(nil) // NewJSONField returns a new JSONField. func NewJSONField(name string, tbl TableStruct) JSONField { return JSONField{table: tbl, name: name} } // WriteSQL implements the SQLWriter interface. 
func (field JSONField) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { writeFieldIdentifier(ctx, dialect, buf, args, params, field.table, field.name) return nil } // As returns a new JSONField with the given alias. func (field JSONField) As(alias string) JSONField { field.alias = alias return field } // WithPrefix returns a new Field that with the given prefix. func (field JSONField) WithPrefix(prefix string) Field { field.table.alias = "" field.table.name = prefix return field } // IsNull returns a 'field IS NULL' Predicate. func (field JSONField) IsNull() Predicate { return Expr("{} IS NULL", field) } // IsNotNull returns a 'field IS NOT NULL' Predicate. func (field JSONField) IsNotNull() Predicate { return Expr("{} IS NOT NULL", field) } // Set returns an Assignment assigning the value to the field. func (field JSONField) Set(value any) Assignment { switch value.(type) { case []byte, driver.Valuer, SQLWriter: return Set(field, value) } switch reflect.TypeOf(value).Kind() { case reflect.Map, reflect.Struct, reflect.Slice, reflect.Array: return Set(field, JSONValue(value)) } return Set(field, value) } // SetJSON returns an Assignment assigning the value to the field. It wraps the // value in JSONValue(). func (field JSONField) SetJSON(value any) Assignment { return Set(field, JSONValue(value)) } // Setf returns an Assignment assigning an expression to the field. func (field JSONField) Setf(format string, values ...any) Assignment { return Setf(field, format, values...) } // GetAlias returns the alias of the JSONField. func (field JSONField) GetAlias() string { return field.alias } // IsField implements the Field interface. func (field JSONField) IsField() {} // IsBinary implements the Binary interface. func (field JSONField) IsBinary() {} // IsJSON implements the JSON interface. func (field JSONField) IsJSON() {} // IsString implements the String interface. 
func (field JSONField) IsString() {} // NumberField represents an SQL number field. type NumberField struct { table TableStruct name string alias string desc sql.NullBool nullsfirst sql.NullBool } var _ interface { Field Number WithPrefix(string) Field } = (*NumberField)(nil) // NewNumberField returns a new NumberField. func NewNumberField(name string, tbl TableStruct) NumberField { return NumberField{table: tbl, name: name} } // WriteSQL implements the SQLWriter interface. func (field NumberField) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { writeFieldIdentifier(ctx, dialect, buf, args, params, field.table, field.name) writeFieldOrder(ctx, dialect, buf, args, params, field.desc, field.nullsfirst) return nil } // As returns a new NumberField with the given alias. func (field NumberField) As(alias string) NumberField { field.alias = alias return field } // Asc returns a new NumberField indicating that it should be ordered in ascending // order i.e. 'ORDER BY field ASC'. func (field NumberField) Asc() NumberField { field.desc.Valid = true field.desc.Bool = false return field } // Desc returns a new NumberField indicating that it should be ordered in ascending // order i.e. 'ORDER BY field DESC'. func (field NumberField) Desc() NumberField { field.desc.Valid = true field.desc.Bool = true return field } // NullsLast returns a new NumberField indicating that it should be ordered // with nulls last i.e. 'ORDER BY field NULLS LAST'. func (field NumberField) NullsLast() NumberField { field.nullsfirst.Valid = true field.nullsfirst.Bool = false return field } // NullsFirst returns a new NumberField indicating that it should be ordered // with nulls first i.e. 'ORDER BY field NULLS FIRST'. func (field NumberField) NullsFirst() NumberField { field.nullsfirst.Valid = true field.nullsfirst.Bool = true return field } // WithPrefix returns a new Field that with the given prefix. 
func (field NumberField) WithPrefix(prefix string) Field { field.table.alias = "" field.table.name = prefix return field } // IsNull returns a 'field IS NULL' Predicate. func (field NumberField) IsNull() Predicate { return Expr("{} IS NULL", field) } // IsNotNull returns a 'field IS NOT NULL' Predicate. func (field NumberField) IsNotNull() Predicate { return Expr("{} IS NOT NULL", field) } // In returns a 'field IN (value)' Predicate. The value can be a slice, which // corresponds to the expression 'field IN (x, y, z)'. func (field NumberField) In(value any) Predicate { return In(field, value) } // NotIn returns a 'field NOT IN (value)' Predicate. The value can be a slice, // which corresponds to the expression 'field IN (x, y, z)'. func (field NumberField) NotIn(value any) Predicate { return NotIn(field, value) } // Eq returns a 'field = value' Predicate. func (field NumberField) Eq(value Number) Predicate { return Eq(field, value) } // Ne returns a 'field <> value' Predicate. func (field NumberField) Ne(value Number) Predicate { return Ne(field, value) } // Lt returns a 'field < value' Predicate. func (field NumberField) Lt(value Number) Predicate { return Lt(field, value) } // Le returns a 'field <= value' Predicate. func (field NumberField) Le(value Number) Predicate { return Le(field, value) } // Gt returns a 'field > value' Predicate. func (field NumberField) Gt(value Number) Predicate { return Gt(field, value) } // Ge returns a 'field >= value' Predicate. func (field NumberField) Ge(value Number) Predicate { return Ge(field, value) } // EqInt returns a 'field = num' Predicate. func (field NumberField) EqInt(num int) Predicate { return Eq(field, num) } // NeInt returns a 'field <> num' Predicate. func (field NumberField) NeInt(num int) Predicate { return Ne(field, num) } // LtInt returns a 'field < num' Predicate. func (field NumberField) LtInt(num int) Predicate { return Lt(field, num) } // LeInt returns a 'field <= num' Predicate. 
func (field NumberField) LeInt(num int) Predicate {
	return Le(field, num)
}

// LeInt64 returns a 'field <= num' Predicate for an int64 num.
func (field NumberField) LeInt64(num int64) Predicate {
	return Le(field, num)
}

// LeFloat64 returns a 'field <= num' Predicate for a float64 num.
func (field NumberField) LeFloat64(num float64) Predicate {
	return Le(field, num)
}

// GtInt returns a 'field > num' Predicate for an int num.
func (field NumberField) GtInt(num int) Predicate {
	return Gt(field, num)
}

// GtInt64 returns a 'field > num' Predicate for an int64 num.
func (field NumberField) GtInt64(num int64) Predicate {
	return Gt(field, num)
}

// GtFloat64 returns a 'field > num' Predicate for a float64 num.
func (field NumberField) GtFloat64(num float64) Predicate {
	return Gt(field, num)
}

// GeInt returns a 'field >= num' Predicate for an int num.
func (field NumberField) GeInt(num int) Predicate {
	return Ge(field, num)
}

// GeInt64 returns a 'field >= num' Predicate for an int64 num.
func (field NumberField) GeInt64(num int64) Predicate {
	return Ge(field, num)
}

// GeFloat64 returns a 'field >= num' Predicate for a float64 num.
func (field NumberField) GeFloat64(num float64) Predicate {
	return Ge(field, num)
}

// EqInt64 returns a 'field = num' Predicate for an int64 num.
func (field NumberField) EqInt64(num int64) Predicate {
	return Eq(field, num)
}

// EqFloat64 returns a 'field = num' Predicate for a float64 num.
func (field NumberField) EqFloat64(num float64) Predicate {
	return Eq(field, num)
}

// NeInt64 returns a 'field <> num' Predicate for an int64 num.
func (field NumberField) NeInt64(num int64) Predicate {
	return Ne(field, num)
}

// NeFloat64 returns a 'field <> num' Predicate for a float64 num.
func (field NumberField) NeFloat64(num float64) Predicate {
	return Ne(field, num)
}

// LtInt64 returns a 'field < num' Predicate for an int64 num.
func (field NumberField) LtInt64(num int64) Predicate {
	return Lt(field, num)
}

// LtFloat64 returns a 'field < num' Predicate for a float64 num.
func (field NumberField) LtFloat64(num float64) Predicate {
	return Lt(field, num)
}

// Set returns an Assignment assigning the value to the field.
func (field NumberField) Set(value any) Assignment { return Set(field, value) } // Setf returns an Assignment assigning an expression to the field. func (field NumberField) Setf(format string, values ...any) Assignment { return Setf(field, format, values...) } // SetBytes returns an Assignment assigning an int to the field. func (field NumberField) SetInt(num int) Assignment { return Set(field, num) } // SetBytes returns an Assignment assigning an int64 to the field. func (field NumberField) SetInt64(num int64) Assignment { return Set(field, num) } // SetBytes returns an Assignment assigning an float64 to the field. func (field NumberField) SetFloat64(num float64) Assignment { return Set(field, num) } // GetAlias returns the alias of the NumberField. func (field NumberField) GetAlias() string { return field.alias } // IsField implements the Field interface. func (field NumberField) IsField() {} // IsNumber implements the Number interface. func (field NumberField) IsNumber() {} // StringField represents an SQL string field. type StringField struct { table TableStruct name string alias string collation string desc sql.NullBool nullsfirst sql.NullBool } var _ interface { Field String WithPrefix(string) Field } = (*StringField)(nil) // NewStringField returns a new StringField. func NewStringField(name string, tbl TableStruct) StringField { return StringField{table: tbl, name: name} } // WriteSQL implements the SQLWriter interface. 
func (field StringField) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { writeFieldIdentifier(ctx, dialect, buf, args, params, field.table, field.name) if field.collation != "" { buf.WriteString(" COLLATE ") if dialect == DialectPostgres { buf.WriteString(`"` + EscapeQuote(field.collation, '"') + `"`) } else { buf.WriteString(QuoteIdentifier(dialect, field.collation)) } } writeFieldOrder(ctx, dialect, buf, args, params, field.desc, field.nullsfirst) return nil } // As returns a new StringField with the given alias. func (field StringField) As(alias string) StringField { field.alias = alias return field } // Collate returns a new StringField using the given collation. func (field StringField) Collate(collation string) StringField { field.collation = collation return field } // Asc returns a new StringField indicating that it should be ordered in // ascending order i.e. 'ORDER BY field ASC'. func (field StringField) Asc() StringField { field.desc.Valid = true field.desc.Bool = false return field } // Desc returns a new StringField indicating that it should be ordered in // descending order i.e. 'ORDER BY field DESC'. func (field StringField) Desc() StringField { field.desc.Valid = true field.desc.Bool = true return field } // NullsLast returns a new StringField indicating that it should be ordered // with nulls last i.e. 'ORDER BY field NULLS LAST'. func (field StringField) NullsLast() StringField { field.nullsfirst.Valid = true field.nullsfirst.Bool = false return field } // NullsFirst returns a new StringField indicating that it should be ordered // with nulls first i.e. 'ORDER BY field NULLS FIRST'. func (field StringField) NullsFirst() StringField { field.nullsfirst.Valid = true field.nullsfirst.Bool = true return field } // WithPrefix returns a new Field that with the given prefix. 
func (field StringField) WithPrefix(prefix string) Field { field.table.alias = "" field.table.name = prefix return field } // IsNull returns a 'field IS NULL' Predicate. func (field StringField) IsNull() Predicate { return Expr("{} IS NULL", field) } // IsNotNull returns a 'field IS NOT NULL' Predicate. func (field StringField) IsNotNull() Predicate { return Expr("{} IS NOT NULL", field) } // In returns a 'field IN (value)' Predicate. The value can be a slice, which // corresponds to the expression 'field IN (x, y, z)'. func (field StringField) In(value any) Predicate { return In(field, value) } // In returns a 'field NOT IN (value)' Predicate. The value can be a slice, // which corresponds to the expression 'field NOT IN (x, y, z)'. func (field StringField) NotIn(value any) Predicate { return NotIn(field, value) } // Eq returns a 'field = value' Predicate. func (field StringField) Eq(value String) Predicate { return Eq(field, value) } // Ne returns a 'field <> value' Predicate. func (field StringField) Ne(value String) Predicate { return Ne(field, value) } // Lt returns a 'field < value' Predicate. func (field StringField) Lt(value String) Predicate { return Lt(field, value) } // Le returns a 'field <= value' Predicate. func (field StringField) Le(value String) Predicate { return Le(field, value) } // Gt returns a 'field > value' Predicate. func (field StringField) Gt(value String) Predicate { return Gt(field, value) } // Ge returns a 'field >= value' Predicate. func (field StringField) Ge(value String) Predicate { return Ge(field, value) } // EqString returns a 'field = str' Predicate. func (field StringField) EqString(str string) Predicate { return Eq(field, str) } // NeString returns a 'field <> str' Predicate. func (field StringField) NeString(str string) Predicate { return Ne(field, str) } // LtString returns a 'field < str' Predicate. 
func (field StringField) LtString(str string) Predicate {
	return Lt(field, str)
}

// LeString returns a 'field <= str' Predicate.
func (field StringField) LeString(str string) Predicate {
	return Le(field, str)
}

// GtString returns a 'field > str' Predicate.
func (field StringField) GtString(str string) Predicate {
	return Gt(field, str)
}

// GeString returns a 'field >= str' Predicate.
func (field StringField) GeString(str string) Predicate {
	return Ge(field, str)
}

// LikeString returns a 'field LIKE str' Predicate.
func (field StringField) LikeString(str string) Predicate {
	return Expr("{} LIKE {}", field, str)
}

// NotLikeString returns a 'field NOT LIKE str' Predicate.
func (field StringField) NotLikeString(str string) Predicate {
	return Expr("{} NOT LIKE {}", field, str)
}

// ILikeString returns a 'field ILIKE str' Predicate. ILIKE (case-insensitive
// LIKE) is a PostgreSQL extension rather than standard SQL; the keyword is
// emitted unchanged for every dialect, so other databases may reject it.
func (field StringField) ILikeString(str string) Predicate {
	return Expr("{} ILIKE {}", field, str)
}

// NotILikeString returns a 'field NOT ILIKE str' Predicate. Like ILikeString,
// it relies on the PostgreSQL-specific ILIKE operator.
func (field StringField) NotILikeString(str string) Predicate {
	return Expr("{} NOT ILIKE {}", field, str)
}

// Set returns an Assignment assigning the value to the field.
func (field StringField) Set(value any) Assignment {
	return Set(field, value)
}

// Setf returns an Assignment assigning an expression to the field.
func (field StringField) Setf(format string, values ...any) Assignment {
	return Setf(field, format, values...)
}

// SetString returns an Assignment assigning a string to the field.
func (field StringField) SetString(str string) Assignment {
	return Set(field, str)
}

// GetAlias returns the alias of the StringField.
func (field StringField) GetAlias() string {
	return field.alias
}

// IsField implements the Field interface.
func (field StringField) IsField() {}

// IsString implements the String interface.
func (field StringField) IsString() {}

// TimeField represents an SQL time field.
type TimeField struct { table TableStruct name string alias string desc sql.NullBool nullsfirst sql.NullBool } var _ interface { Field Time WithPrefix(string) Field } = (*TimeField)(nil) // NewTimeField returns a new TimeField. func NewTimeField(name string, tbl TableStruct) TimeField { return TimeField{table: tbl, name: name} } // WriteSQL implements the SQLWriter interface. func (field TimeField) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { writeFieldIdentifier(ctx, dialect, buf, args, params, field.table, field.name) writeFieldOrder(ctx, dialect, buf, args, params, field.desc, field.nullsfirst) return nil } // As returns a new TimeField with the given alias. func (field TimeField) As(alias string) TimeField { field.alias = alias return field } // Asc returns a new TimeField indicating that it should be ordered in ascending // order i.e. 'ORDER BY field ASC'. func (field TimeField) Asc() TimeField { field.desc.Valid = true field.desc.Bool = false return field } // Desc returns a new TimeField indicating that it should be ordered in ascending // order i.e. 'ORDER BY field DESC'. func (field TimeField) Desc() TimeField { field.desc.Valid = true field.desc.Bool = true return field } // NullsLast returns a new TimeField indicating that it should be ordered // with nulls last i.e. 'ORDER BY field NULLS LAST'. func (field TimeField) NullsLast() TimeField { field.nullsfirst.Valid = true field.nullsfirst.Bool = false return field } // NullsFirst returns a new TimeField indicating that it should be ordered // with nulls first i.e. 'ORDER BY field NULLS FIRST'. func (field TimeField) NullsFirst() TimeField { field.nullsfirst.Valid = true field.nullsfirst.Bool = true return field } // WithPrefix returns a new Field that with the given prefix. func (field TimeField) WithPrefix(prefix string) Field { field.table.alias = "" field.table.name = prefix return field } // IsNull returns a 'field IS NULL' Predicate. 
func (field TimeField) IsNull() Predicate {
	return Expr("{} IS NULL", field)
}

// IsNotNull builds the predicate 'field IS NOT NULL'.
func (field TimeField) IsNotNull() Predicate {
	return Expr("{} IS NOT NULL", field)
}

// In builds the predicate 'field IN (value)'. Passing a slice yields
// 'field IN (x, y, z)'.
func (field TimeField) In(value any) Predicate {
	return In(field, value)
}

// NotIn builds the predicate 'field NOT IN (value)'. Passing a slice yields
// 'field NOT IN (x, y, z)'.
func (field TimeField) NotIn(value any) Predicate {
	return NotIn(field, value)
}

// Eq builds the predicate 'field = value'.
func (field TimeField) Eq(value Time) Predicate {
	return Eq(field, value)
}

// EqTime builds the predicate 'field = t' for a time.Time t.
func (field TimeField) EqTime(t time.Time) Predicate {
	return Eq(field, t)
}

// Ne builds the predicate 'field <> value'.
func (field TimeField) Ne(value Time) Predicate {
	return Ne(field, value)
}

// NeTime builds the predicate 'field <> t' for a time.Time t.
func (field TimeField) NeTime(t time.Time) Predicate {
	return Ne(field, t)
}

// Lt builds the predicate 'field < value'.
func (field TimeField) Lt(value Time) Predicate {
	return Lt(field, value)
}

// LtTime builds the predicate 'field < t' for a time.Time t.
func (field TimeField) LtTime(t time.Time) Predicate {
	return Lt(field, t)
}

// Le builds the predicate 'field <= value'.
func (field TimeField) Le(value Time) Predicate {
	return Le(field, value)
}

// LeTime builds the predicate 'field <= t' for a time.Time t.
func (field TimeField) LeTime(t time.Time) Predicate {
	return Le(field, t)
}

// Gt builds the predicate 'field > value'.
func (field TimeField) Gt(value Time) Predicate {
	return Gt(field, value)
}

// Ge builds the predicate 'field >= value'.
func (field TimeField) Ge(value Time) Predicate {
	return Ge(field, value)
}

// GtTime returns a 'field > t' Predicate.
func (field TimeField) GtTime(t time.Time) Predicate {
	return Gt(field, t)
}

// GeTime returns a 'field >= t' Predicate.
func (field TimeField) GeTime(t time.Time) Predicate {
	return Ge(field, t)
}

// Set returns an Assignment assigning the value to the field.
func (field TimeField) Set(value any) Assignment {
	return Set(field, value)
}

// Setf returns an Assignment assigning an expression to the field.
func (field TimeField) Setf(format string, values ...any) Assignment {
	return Setf(field, format, values...)
}

// SetTime returns an Assignment assigning a time.Time to the field.
func (field TimeField) SetTime(t time.Time) Assignment {
	return Set(field, t)
}

// GetAlias returns the alias of the TimeField.
func (field TimeField) GetAlias() string {
	return field.alias
}

// IsField implements the Field interface.
func (field TimeField) IsField() {}

// IsTime implements the Time interface.
func (field TimeField) IsTime() {}

// Timestamp is a replacement for sql.NullTime with the following
// enhancements:
//
// 1. Timestamp.Value() returns an int64 unix timestamp if the dialect is
// SQLite, otherwise it returns a time.Time (similar to sql.NullTime).
//
// 2. Timestamp.Scan() additionally supports scanning from int64 and text
// (string/[]byte) values on top of what sql.NullTime already supports. The
// following text timestamp formats are supported:
//
//	var timestampFormats = []string{
//		"2006-01-02 15:04:05.999999999-07:00",
//		"2006-01-02T15:04:05.999999999-07:00",
//		"2006-01-02 15:04:05.999999999",
//		"2006-01-02T15:04:05.999999999",
//		"2006-01-02 15:04:05",
//		"2006-01-02T15:04:05",
//		"2006-01-02 15:04",
//		"2006-01-02T15:04",
//		"2006-01-02",
//	}
type Timestamp struct {
	time.Time

	// Valid is false when the Timestamp represents SQL NULL.
	Valid bool

	// dialect is set by DialectValuer and consulted by Value to decide the
	// driver representation.
	dialect string
}

// NewTimestamp creates a new Timestamp from a time.Time.
func NewTimestamp(t time.Time) Timestamp { return Timestamp{Time: t, Valid: true} } // copied from https://pkg.go.dev/github.com/mattn/go-sqlite3#pkg-variables var timestampFormats = []string{ "2006-01-02 15:04:05.999999999-07:00", "2006-01-02T15:04:05.999999999-07:00", "2006-01-02 15:04:05.999999999", "2006-01-02T15:04:05.999999999", "2006-01-02 15:04:05", "2006-01-02T15:04:05", "2006-01-02 15:04", "2006-01-02T15:04", "2006-01-02", } // Scan implements the sql.Scanner interface. It additionally supports scanning // from int64 and text (string/[]byte) values on top of what sql.NullTime // already supports. The following text timestamp formats are supported: // // var timestampFormats = []string{ // "2006-01-02 15:04:05.999999999-07:00", // "2006-01-02T15:04:05.999999999-07:00", // "2006-01-02 15:04:05.999999999", // "2006-01-02T15:04:05.999999999", // "2006-01-02 15:04:05", // "2006-01-02T15:04:05", // "2006-01-02 15:04", // "2006-01-02T15:04", // "2006-01-02", // } func (ts *Timestamp) Scan(value any) error { if value == nil { ts.Time, ts.Valid = time.Time{}, false return nil } // int64 and string handling copied from // https://github.com/mattn/go-sqlite3/issues/748#issuecomment-538643131 switch value := value.(type) { case int64: // Assume a millisecond unix timestamp if it's 13 digits -- too // large to be a reasonable timestamp in seconds. 
if value > 1e12 || value < -1e12 { value *= int64(time.Millisecond) // convert ms to nsec ts.Time = time.Unix(0, value) } else { ts.Time = time.Unix(value, 0) } ts.Valid = true return nil case string: if len(value) == 0 { ts.Time, ts.Valid = time.Time{}, false return nil } var err error var timeVal time.Time value = strings.TrimSuffix(value, "Z") for _, format := range timestampFormats { if timeVal, err = time.ParseInLocation(format, value, time.UTC); err == nil { ts.Time, ts.Valid = timeVal, true return nil } } return fmt.Errorf("could not convert %q into time", value) case []byte: if len(value) == 0 { ts.Time, ts.Valid = time.Time{}, false return nil } var err error var timeVal time.Time value = bytes.TrimSuffix(value, []byte("Z")) for _, format := range timestampFormats { if timeVal, err = time.ParseInLocation(format, string(value), time.UTC); err == nil { ts.Time, ts.Valid = timeVal, true return nil } } return fmt.Errorf("could not convert %q into time", value) default: var nulltime sql.NullTime err := nulltime.Scan(value) if err != nil { return err } ts.Time, ts.Valid = nulltime.Time, nulltime.Valid return nil } } // Value implements the driver.Valuer interface. It returns an int64 unix // timestamp if the dialect is SQLite, otherwise it returns a time.Time // (similar to sql.NullTime). func (ts Timestamp) Value() (driver.Value, error) { if !ts.Valid { return nil, nil } if ts.dialect == DialectSQLite { return ts.Time.UTC().Unix(), nil } return ts.Time, nil } // DialectValuer implements the DialectValuer interface. func (ts Timestamp) DialectValuer(dialect string) (driver.Valuer, error) { ts.dialect = dialect return ts, nil } // UUIDField represents an SQL UUID field. type UUIDField struct { table TableStruct name string alias string desc sql.NullBool nullsfirst sql.NullBool } var _ interface { Field UUID WithPrefix(string) Field } = (*UUIDField)(nil) // NewUUIDField returns a new UUIDField. 
func NewUUIDField(name string, tbl TableStruct) UUIDField {
	return UUIDField{table: tbl, name: name}
}

// WriteSQL implements the SQLWriter interface. It writes the field identifier
// followed by any ordering set via Asc/Desc/NullsFirst/NullsLast.
func (field UUIDField) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	writeFieldIdentifier(ctx, dialect, buf, args, params, field.table, field.name)
	writeFieldOrder(ctx, dialect, buf, args, params, field.desc, field.nullsfirst)
	return nil
}

// As returns a new UUIDField with the given alias.
func (field UUIDField) As(alias string) UUIDField {
	field.alias = alias
	return field
}

// Asc returns a new UUIDField indicating that it should be ordered in
// ascending order i.e. 'ORDER BY field ASC'.
func (field UUIDField) Asc() UUIDField {
	field.desc.Valid = true
	field.desc.Bool = false
	return field
}

// Desc returns a new UUIDField indicating that it should be ordered in
// descending order i.e. 'ORDER BY field DESC'.
func (field UUIDField) Desc() UUIDField {
	field.desc.Valid = true
	field.desc.Bool = true
	return field
}

// NullsLast returns a new UUIDField indicating that it should be ordered
// with nulls last i.e. 'ORDER BY field NULLS LAST'.
func (field UUIDField) NullsLast() UUIDField {
	field.nullsfirst.Valid = true
	field.nullsfirst.Bool = false
	return field
}

// NullsFirst returns a new UUIDField indicating that it should be ordered
// with nulls first i.e. 'ORDER BY field NULLS FIRST'.
func (field UUIDField) NullsFirst() UUIDField {
	field.nullsfirst.Valid = true
	field.nullsfirst.Bool = true
	return field
}

// WithPrefix returns a new Field whose table name is replaced by the given
// prefix. Any table alias is cleared.
func (field UUIDField) WithPrefix(prefix string) Field {
	field.table.alias = ""
	field.table.name = prefix
	return field
}

// IsNull returns a 'field IS NULL' Predicate.
func (field UUIDField) IsNull() Predicate {
	return Expr("{} IS NULL", field)
}

// IsNotNull returns a 'field IS NOT NULL' Predicate.
func (field UUIDField) IsNotNull() Predicate { return Expr("{} IS NOT NULL", field) } // In returns a 'field IN (value)' Predicate. The value can be a slice, which // corresponds to the expression 'field IN (x, y, z)'. func (field UUIDField) In(value any) Predicate { return In(field, value) } // NotIn returns a 'field NOT IN (value)' Predicate. The value can be a slice, // which corresponds to the expression 'field NOT IN (x, y, z)'. func (field UUIDField) NotIn(value any) Predicate { return NotIn(field, value) } // Eq returns a 'field = value' Predicate. func (field UUIDField) Eq(value any) Predicate { return Eq(field, value) } // Ne returns a 'field <> value' Predicate. func (field UUIDField) Ne(value any) Predicate { return Ne(field, value) } // EqUUID returns a 'field = value' Predicate. The value is wrapped in // UUIDValue(). func (field UUIDField) EqUUID(value any) Predicate { return Eq(field, UUIDValue(value)) } // NeUUID returns a 'field <> value' Predicate. The value is wrapped in // UUIDValue(). func (field UUIDField) NeUUID(value any) Predicate { return Ne(field, UUIDValue(value)) } // Set returns an Assignment assigning the value to the field. func (field UUIDField) Set(value any) Assignment { return Set(field, value) } // SetUUID returns an Assignment assigning the value to the field. It wraps the // value in UUIDValue(). func (field UUIDField) SetUUID(value any) Assignment { return Set(field, UUIDValue(value)) } // Set returns an Assignment assigning the value to the field. func (field UUIDField) Setf(format string, values ...any) Assignment { return Setf(field, format, values...) } // GetAlias returns the alias of the UUIDField. func (field UUIDField) GetAlias() string { return field.alias } // IsField implements the Field interface. func (field UUIDField) IsField() {} // IsUUID implements the UUID interface. func (field UUIDField) IsUUID() {} // New instantiates a new table struct with the given alias. 
Passing in an // empty string is equivalent to giving no alias to the table. func New[T Table](alias string) T { var tbl T ptrvalue := reflect.ValueOf(&tbl) value := reflect.Indirect(ptrvalue) typ := value.Type() if typ.Kind() != reflect.Struct { return tbl } if value.NumField() == 0 { return tbl } firstfield := value.Field(0) firstfieldType := typ.Field(0) if !firstfield.CanInterface() { return tbl } _, ok := firstfield.Interface().(TableStruct) if !ok { return tbl } if !firstfield.CanSet() { return tbl } tag := firstfieldType.Tag.Get("sq") tableSchema, tableName, ok := strings.Cut(tag, ".") if !ok { tableSchema, tableName = "", tableSchema } if tableName == "" { tableName = strings.ToLower(typ.Name()) } tableStruct := NewTableStruct(tableSchema, tableName, alias) firstfield.Set(reflect.ValueOf(tableStruct)) for i := 1; i < value.NumField(); i++ { v := value.Field(i) if !v.CanInterface() { continue } if !v.CanSet() { continue } fieldType := typ.Field(i) name := fieldType.Tag.Get("sq") if name == "" { name = strings.ToLower(fieldType.Name) } switch v.Interface().(type) { case AnyField: v.Set(reflect.ValueOf(NewAnyField(name, tableStruct))) case ArrayField: v.Set(reflect.ValueOf(NewArrayField(name, tableStruct))) case BinaryField: v.Set(reflect.ValueOf(NewBinaryField(name, tableStruct))) case BooleanField: v.Set(reflect.ValueOf(NewBooleanField(name, tableStruct))) case EnumField: v.Set(reflect.ValueOf(NewEnumField(name, tableStruct))) case JSONField: v.Set(reflect.ValueOf(NewJSONField(name, tableStruct))) case NumberField: v.Set(reflect.ValueOf(NewNumberField(name, tableStruct))) case StringField: v.Set(reflect.ValueOf(NewStringField(name, tableStruct))) case TimeField: v.Set(reflect.ValueOf(NewTimeField(name, tableStruct))) case UUIDField: v.Set(reflect.ValueOf(NewUUIDField(name, tableStruct))) } } return tbl } func writeFieldIdentifier(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, table TableStruct, fieldName string) 
{ tableQualifier, _, _ := strings.Cut(table.alias, "(") tableQualifier = strings.TrimRight(tableQualifier, " ") if tableQualifier == "" { tableQualifier = table.name } if tableQualifier != "" { buf.WriteString(QuoteIdentifier(dialect, tableQualifier) + ".") } buf.WriteString(QuoteIdentifier(dialect, fieldName)) } func writeFieldOrder(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, desc, nullsfirst sql.NullBool) { if desc.Valid { if desc.Bool { buf.WriteString(" DESC") } else { buf.WriteString(" ASC") } } if nullsfirst.Valid { if nullsfirst.Bool { buf.WriteString(" NULLS FIRST") } else { buf.WriteString(" NULLS LAST") } } } ================================================ FILE: fields_test.go ================================================ package sq import ( "bytes" "context" "database/sql/driver" "strings" "testing" "time" "github.com/bokwoon95/sq/internal/testutil" "github.com/google/uuid" ) func TestTableStruct(t *testing.T) { t.Parallel() tbl := NewTableStruct("public", "users", "u") if diff := testutil.Diff(tbl.GetAlias(), "u"); diff != "" { t.Error(testutil.Callers(), diff) } gotQuery, _, err := ToSQL("", tbl, nil) if err != nil { t.Error(testutil.Callers(), err) } if diff := testutil.Diff(gotQuery, "public.users"); diff != "" { t.Error(testutil.Callers(), diff) } } func TestArrayField(t *testing.T) { t.Run("basic", func(t *testing.T) { tbl := NewTableStruct("", "tbl", "") f1 := NewArrayField("field", tbl).As("f") if diff := testutil.Diff(f1.GetAlias(), "f"); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("alias brackets removed", func(t *testing.T) { tbl := NewTableStruct("", "tbl", "t (id, name, email)") f1 := NewArrayField("field", tbl) gotQuery, _, err := ToSQL("", f1, nil) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(tbl.GetAlias(), "t (id, name, email)"); diff != "" { t.Error(testutil.Callers(), diff) } if diff := testutil.Diff(gotQuery, "t.field"); diff != "" { 
t.Error(testutil.Callers(), diff) } }) field := NewArrayField("field", NewTableStruct("", "tbl", "")) tests := []TestTable{{ description: "IsNull", item: field.IsNull(), wantQuery: "tbl.field IS NULL", }, { description: "IsNotNull", item: field.IsNotNull(), wantQuery: "tbl.field IS NOT NULL", }, { description: "Set", item: field.Set(Expr("NULL")), wantQuery: "field = NULL", }, { description: "Set", item: field.SetArray([]int{1, 2, 3}), wantQuery: "field = ?", wantArgs: []any{`[1,2,3]`}, }, { description: "Setf", item: field.Setf("VALUES({})", field.WithPrefix("")), wantQuery: "field = VALUES(field)", }, { description: "Set EXCLUDED", item: field.Set(field.WithPrefix("EXCLUDED")), wantQuery: "field = EXCLUDED.field", }, { description: "Set self", item: field.Set(field), dialect: DialectMySQL, wantQuery: "tbl.field = tbl.field", }, { description: "Set with alias", item: field.Set(field.WithPrefix("new")), wantQuery: "field = new.field", }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } } func TestBinaryField(t *testing.T) { t.Run("basic", func(t *testing.T) { tbl := NewTableStruct("", "tbl", "") f1 := NewBinaryField("field", tbl).As("f") if diff := testutil.Diff(f1.GetAlias(), "f"); diff != "" { t.Error(testutil.Callers(), diff) } }) field := NewBinaryField("field", NewTableStruct("", "tbl", "")) tests := []TestTable{{ description: "IsNull", item: field.IsNull(), wantQuery: "tbl.field IS NULL", }, { description: "IsNotNull", item: field.IsNotNull(), wantQuery: "tbl.field IS NOT NULL", }, { description: "Asc NullsLast", item: field.Asc().NullsLast(), wantQuery: "tbl.field ASC NULLS LAST", }, { description: "Desc NullsFirst", item: field.Desc().NullsFirst(), wantQuery: "tbl.field DESC NULLS FIRST", }, { description: "Eq", item: field.Eq(field), wantQuery: "tbl.field = tbl.field", }, { description: "Ne", item: field.Ne(field), wantQuery: "tbl.field <> tbl.field", }, { description: "EqBytes", item: 
field.EqBytes([]byte{0xff, 0xff}), wantQuery: "tbl.field = ?", wantArgs: []any{[]byte{0xff, 0xff}}, }, { description: "NeBytes", item: field.NeBytes([]byte{0xff, 0xff}), wantQuery: "tbl.field <> ?", wantArgs: []any{[]byte{0xff, 0xff}}, }, { description: "Set", item: field.Set(Expr("NULL")), wantQuery: "field = NULL", }, { description: "Setf", item: field.Setf("VALUES({})", field.WithPrefix("")), wantQuery: "field = VALUES(field)", }, { description: "Set EXCLUDED", item: field.Set(field.WithPrefix("EXCLUDED")), wantQuery: "field = EXCLUDED.field", }, { description: "Set self", item: field.Set(field), dialect: DialectMySQL, wantQuery: "tbl.field = tbl.field", }, { description: "Set with alias", item: field.Set(field.WithPrefix("new")), wantQuery: "field = new.field", }, { description: "SetBytes", item: field.SetBytes([]byte{0xff, 0xff}), wantQuery: "field = ?", wantArgs: []any{[]byte{0xff, 0xff}}, }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } } func TestBooleanField(t *testing.T) { t.Run("basic", func(t *testing.T) { tbl := NewTableStruct("", "tbl", "") f1 := NewBooleanField("field", tbl).As("f") if diff := testutil.Diff(f1.GetAlias(), "f"); diff != "" { t.Error(testutil.Callers(), diff) } }) field := NewBooleanField("field", NewTableStruct("", "tbl", "")) tests := []TestTable{{ description: "IsNull", item: field.IsNull(), wantQuery: "tbl.field IS NULL", }, { description: "IsNotNull", item: field.IsNotNull(), wantQuery: "tbl.field IS NOT NULL", }, { description: "Asc NullsLast", item: field.Asc().NullsLast(), wantQuery: "tbl.field ASC NULLS LAST", }, { description: "Desc NullsFirst", item: field.Desc().NullsFirst(), wantQuery: "tbl.field DESC NULLS FIRST", }, { description: "Eq", item: field.Eq(field), wantQuery: "tbl.field = tbl.field", }, { description: "Ne", item: field.Ne(field), wantQuery: "tbl.field <> tbl.field", }, { description: "EqBytes", item: field.EqBool(true), wantQuery: "tbl.field = ?", 
wantArgs: []any{true}, }, { description: "NeBytes", item: field.NeBool(true), wantQuery: "tbl.field <> ?", wantArgs: []any{true}, }, { description: "Set", item: field.Set(Expr("NULL")), wantQuery: "field = NULL", }, { description: "Setf", item: field.Setf("VALUES({})", field.WithPrefix("")), wantQuery: "field = VALUES(field)", }, { description: "Set EXCLUDED", item: field.Set(field.WithPrefix("EXCLUDED")), wantQuery: "field = EXCLUDED.field", }, { description: "Set self", item: field.Set(field), dialect: DialectMySQL, wantQuery: "tbl.field = tbl.field", }, { description: "Set with alias", item: field.Set(field.WithPrefix("new")), wantQuery: "field = new.field", }, { description: "SetBool", item: field.SetBool(true), wantQuery: "field = ?", wantArgs: []any{true}, }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } } func TestCustomField(t *testing.T) { t.Run("basic", func(t *testing.T) { tbl := NewTableStruct("", "tbl", "") f1 := NewAnyField("field", tbl).As("f") if diff := testutil.Diff(f1.GetAlias(), "f"); diff != "" { t.Error(testutil.Callers(), diff) } }) field := NewAnyField("field", NewTableStruct("", "tbl", "")) tests := []TestTable{{ description: "IsNull", item: field.IsNull(), wantQuery: "tbl.field IS NULL", }, { description: "IsNotNull", item: field.IsNotNull(), wantQuery: "tbl.field IS NOT NULL", }, { description: "Asc NullsLast", item: field.Asc().NullsLast(), wantQuery: "tbl.field ASC NULLS LAST", }, { description: "Desc NullsFirst", item: field.Desc().NullsFirst(), wantQuery: "tbl.field DESC NULLS FIRST", }, { description: "In", item: field.In(RowValue{1, 2, 3}), wantQuery: "tbl.field IN (?, ?, ?)", wantArgs: []any{1, 2, 3}, }, { description: "NotIn", item: field.NotIn(RowValue{1, 2, 3}), wantQuery: "tbl.field NOT IN (?, ?, ?)", wantArgs: []any{1, 2, 3}, }, { description: "Eq", item: field.Eq(field), wantQuery: "tbl.field = tbl.field", }, { description: "Ne", item: field.Ne(field), wantQuery: 
"tbl.field <> tbl.field", }, { description: "Lt", item: field.Lt(field), wantQuery: "tbl.field < tbl.field", }, { description: "Le", item: field.Le(field), wantQuery: "tbl.field <= tbl.field", }, { description: "Gt", item: field.Gt(field), wantQuery: "tbl.field > tbl.field", }, { description: "Ge", item: field.Ge(field), wantQuery: "tbl.field >= tbl.field", }, { description: "Expr", item: field.Expr("&& ARRAY[1, 2, 3]"), wantQuery: "tbl.field && ARRAY[1, 2, 3]", }, { description: "Set", item: field.Set(Expr("NULL")), wantQuery: "field = NULL", }, { description: "Setf", item: field.Setf("VALUES({})", field.WithPrefix("")), wantQuery: "field = VALUES(field)", }, { description: "Set EXCLUDED", item: field.Set(field.WithPrefix("EXCLUDED")), wantQuery: "field = EXCLUDED.field", }, { description: "Set Self", item: field.Set(field), dialect: DialectMySQL, wantQuery: "tbl.field = tbl.field", }, { description: "Set with alias", item: field.Set(field.WithPrefix("new")), wantQuery: "field = new.field", }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } } func TestEnumField(t *testing.T) { t.Run("basic", func(t *testing.T) { tbl := NewTableStruct("", "tbl", "") f1 := NewEnumField("field", tbl).As("f") if diff := testutil.Diff(f1.GetAlias(), "f"); diff != "" { t.Error(testutil.Callers(), diff) } }) field := NewEnumField("field", NewTableStruct("", "tbl", "")) tests := []TestTable{{ description: "IsNull", item: field.IsNull(), wantQuery: "tbl.field IS NULL", }, { description: "IsNotNull", item: field.IsNotNull(), wantQuery: "tbl.field IS NOT NULL", }, { description: "In", item: field.In(RowValue{1, 2, 3}), wantQuery: "tbl.field IN (?, ?, ?)", wantArgs: []any{1, 2, 3}, }, { description: "NotIn", item: field.NotIn(RowValue{1, 2, 3}), wantQuery: "tbl.field NOT IN (?, ?, ?)", wantArgs: []any{1, 2, 3}, }, { description: "Eq", item: field.Eq(field), wantQuery: "tbl.field = tbl.field", }, { description: "Ne", item: 
field.Ne(field), wantQuery: "tbl.field <> tbl.field", }, { description: "EqEnum", item: field.Eq(Monday), wantQuery: "tbl.field = ?", wantArgs: []any{"Monday"}, }, { description: "NeEnum", item: field.Ne(Monday), wantQuery: "tbl.field <> ?", wantArgs: []any{"Monday"}, }, { description: "Set", item: field.Set(Expr("NULL")), wantQuery: "field = NULL", }, { description: "SetEnum", item: field.Set(Monday), wantQuery: "field = ?", wantArgs: []any{"Monday"}, }, { description: "Setf", item: field.Setf("VALUES({})", field.WithPrefix("")), wantQuery: "field = VALUES(field)", }, { description: "Set EXCLUDED", item: field.Set(field.WithPrefix("EXCLUDED")), wantQuery: "field = EXCLUDED.field", }, { description: "Set self", item: field.Set(field), dialect: DialectMySQL, wantQuery: "tbl.field = tbl.field", }, { description: "Set with alias", item: field.Set(field.WithPrefix("new")), wantQuery: "field = new.field", }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } } func TestJSONField(t *testing.T) { t.Run("basic", func(t *testing.T) { tbl := NewTableStruct("", "tbl", "") f1 := NewJSONField("field", tbl).As("f") if diff := testutil.Diff(f1.GetAlias(), "f"); diff != "" { t.Error(testutil.Callers(), diff) } }) type Data struct { ID int `json:"id"` Name string `json:"name"` } field := NewJSONField("field", NewTableStruct("", "tbl", "")) tests := []TestTable{{ description: "IsNull", item: field.IsNull(), wantQuery: "tbl.field IS NULL", }, { description: "IsNotNull", item: field.IsNotNull(), wantQuery: "tbl.field IS NOT NULL", }, { description: "Set", item: field.Set(Expr("NULL")), wantQuery: "field = NULL", }, { description: "Set", item: field.SetJSON([]int{1, 2, 3}), wantQuery: "field = ?", wantArgs: []any{`[1,2,3]`}, }, { description: "Setf", item: field.Setf("VALUES({})", field.WithPrefix("")), wantQuery: "field = VALUES(field)", }, { description: "Set EXCLUDED", item: field.Set(field.WithPrefix("EXCLUDED")), 
wantQuery: "field = EXCLUDED.field", }, { description: "Set self", item: field.Set(field), dialect: DialectMySQL, wantQuery: "tbl.field = tbl.field", }, { description: "Set with alias", item: field.Set(field.WithPrefix("new")), wantQuery: "field = new.field", }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } } func TestNumberField(t *testing.T) { t.Run("basic", func(t *testing.T) { tbl := NewTableStruct("", "tbl", "") f1 := NewNumberField("field", tbl).As("f") if diff := testutil.Diff(f1.GetAlias(), "f"); diff != "" { t.Error(testutil.Callers(), diff) } }) field := NewNumberField("field", NewTableStruct("", "tbl", "")) tests := []TestTable{{ description: "IsNull", item: field.IsNull(), wantQuery: "tbl.field IS NULL", }, { description: "IsNotNull", item: field.IsNotNull(), wantQuery: "tbl.field IS NOT NULL", }, { description: "Asc NullsLast", item: field.Asc().NullsLast(), wantQuery: "tbl.field ASC NULLS LAST", }, { description: "Desc NullsFirst", item: field.Desc().NullsFirst(), wantQuery: "tbl.field DESC NULLS FIRST", }, { description: "In", item: field.In(RowValue{1, 2, 3}), wantQuery: "tbl.field IN (?, ?, ?)", wantArgs: []any{1, 2, 3}, }, { description: "NotIn", item: field.NotIn(RowValue{1, 2, 3}), wantQuery: "tbl.field NOT IN (?, ?, ?)", wantArgs: []any{1, 2, 3}, }, { description: "Eq", item: field.Eq(field), wantQuery: "tbl.field = tbl.field", }, { description: "Ne", item: field.Ne(field), wantQuery: "tbl.field <> tbl.field", }, { description: "Lt", item: field.Lt(field), wantQuery: "tbl.field < tbl.field", }, { description: "Le", item: field.Le(field), wantQuery: "tbl.field <= tbl.field", }, { description: "Gt", item: field.Gt(field), wantQuery: "tbl.field > tbl.field", }, { description: "Ge", item: field.Ge(field), wantQuery: "tbl.field >= tbl.field", }, { description: "EqInt", item: field.EqInt(3), wantQuery: "tbl.field = ?", wantArgs: []any{3}, }, { description: "NeInt", item: field.NeInt(3), 
wantQuery: "tbl.field <> ?", wantArgs: []any{3}, }, { description: "LtInt", item: field.LtInt(3), wantQuery: "tbl.field < ?", wantArgs: []any{3}, }, { description: "LeInt", item: field.LeInt(3), wantQuery: "tbl.field <= ?", wantArgs: []any{3}, }, { description: "GtInt", item: field.GtInt(3), wantQuery: "tbl.field > ?", wantArgs: []any{3}, }, { description: "GeInt", item: field.GeInt(3), wantQuery: "tbl.field >= ?", wantArgs: []any{3}, }, { description: "EqInt64", item: field.EqInt64(5), wantQuery: "tbl.field = ?", wantArgs: []any{int64(5)}, }, { description: "NeInt64", item: field.NeInt64(5), wantQuery: "tbl.field <> ?", wantArgs: []any{int64(5)}, }, { description: "LtInt64", item: field.LtInt64(5), wantQuery: "tbl.field < ?", wantArgs: []any{int64(5)}, }, { description: "LeInt64", item: field.LeInt64(5), wantQuery: "tbl.field <= ?", wantArgs: []any{int64(5)}, }, { description: "GtInt64", item: field.GtInt64(5), wantQuery: "tbl.field > ?", wantArgs: []any{int64(5)}, }, { description: "GeInt64", item: field.GeInt64(5), wantQuery: "tbl.field >= ?", wantArgs: []any{int64(5)}, }, { description: "EqFloat64", item: field.EqFloat64(7.11), wantQuery: "tbl.field = ?", wantArgs: []any{float64(7.11)}, }, { description: "NeFloat64", item: field.NeFloat64(7.11), wantQuery: "tbl.field <> ?", wantArgs: []any{float64(7.11)}, }, { description: "LtFloat64", item: field.LtFloat64(7.11), wantQuery: "tbl.field < ?", wantArgs: []any{float64(7.11)}, }, { description: "LeFloat64", item: field.LeFloat64(7.11), wantQuery: "tbl.field <= ?", wantArgs: []any{float64(7.11)}, }, { description: "GtFloat64", item: field.GtFloat64(7.11), wantQuery: "tbl.field > ?", wantArgs: []any{float64(7.11)}, }, { description: "GeFloat64", item: field.GeFloat64(7.11), wantQuery: "tbl.field >= ?", wantArgs: []any{float64(7.11)}, }, { description: "Set", item: field.Set(Expr("NULL")), wantQuery: "field = NULL", }, { description: "Setf", item: field.Setf("VALUES({})", field.WithPrefix("")), wantQuery: "field = 
VALUES(field)", }, { description: "Set EXCLUDED", item: field.Set(field.WithPrefix("EXCLUDED")), wantQuery: "field = EXCLUDED.field", }, { description: "Set self", item: field.Set(field), dialect: DialectMySQL, wantQuery: "tbl.field = tbl.field", }, { description: "Set with alias", item: field.Set(field.WithPrefix("new")), wantQuery: "field = new.field", }, { description: "SetInt", item: field.SetInt(3), wantQuery: "field = ?", wantArgs: []any{3}, }, { description: "SetInt64", item: field.SetInt64(5), wantQuery: "field = ?", wantArgs: []any{int64(5)}, }, { description: "SetFloat64", item: field.SetFloat64(7.11), wantQuery: "field = ?", wantArgs: []any{float64(7.11)}, }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } } func TestStringField(t *testing.T) { t.Run("basic", func(t *testing.T) { tbl := NewTableStruct("", "tbl", "") f1 := NewStringField("field", tbl).As("f") if diff := testutil.Diff(f1.GetAlias(), "f"); diff != "" { t.Error(testutil.Callers(), diff) } }) field := NewStringField("field", NewTableStruct("", "tbl", "")) tests := []TestTable{{ description: "IsNull", item: field.IsNull(), wantQuery: "tbl.field IS NULL", }, { description: "IsNotNull", item: field.IsNotNull(), wantQuery: "tbl.field IS NOT NULL", }, { description: "Asc NullsLast", item: field.Asc().NullsLast(), wantQuery: "tbl.field ASC NULLS LAST", }, { description: "Desc NullsFirst", item: field.Desc().NullsFirst(), wantQuery: "tbl.field DESC NULLS FIRST", }, { description: "In", item: field.In(RowValue{1, 2, 3}), wantQuery: "tbl.field IN (?, ?, ?)", wantArgs: []any{1, 2, 3}, }, { description: "NotIn", item: field.NotIn(RowValue{1, 2, 3}), wantQuery: "tbl.field NOT IN (?, ?, ?)", wantArgs: []any{1, 2, 3}, }, { description: "Eq", item: field.Eq(field), wantQuery: "tbl.field = tbl.field", }, { description: "Ne", item: field.Ne(field), wantQuery: "tbl.field <> tbl.field", }, { description: "Lt", item: field.Lt(field), wantQuery: 
"tbl.field < tbl.field", }, { description: "Le", item: field.Le(field), wantQuery: "tbl.field <= tbl.field", }, { description: "Gt", item: field.Gt(field), wantQuery: "tbl.field > tbl.field", }, { description: "Ge", item: field.Ge(field), wantQuery: "tbl.field >= tbl.field", }, { description: "EqString", item: field.EqString("lorem ipsum"), wantQuery: "tbl.field = ?", wantArgs: []any{"lorem ipsum"}, }, { description: "NeString", item: field.NeString("lorem ipsum"), wantQuery: "tbl.field <> ?", wantArgs: []any{"lorem ipsum"}, }, { description: "LtString", item: field.LtString("lorem ipsum"), wantQuery: "tbl.field < ?", wantArgs: []any{"lorem ipsum"}, }, { description: "LeString", item: field.LeString("lorem ipsum"), wantQuery: "tbl.field <= ?", wantArgs: []any{"lorem ipsum"}, }, { description: "GtString", item: field.GtString("lorem ipsum"), wantQuery: "tbl.field > ?", wantArgs: []any{"lorem ipsum"}, }, { description: "GeString", item: field.GeString("lorem ipsum"), wantQuery: "tbl.field >= ?", wantArgs: []any{"lorem ipsum"}, }, { description: "LikeString", item: field.LikeString("lorem%"), wantQuery: "tbl.field LIKE ?", wantArgs: []any{"lorem%"}, }, { description: "NotLikeString", item: field.NotLikeString("lorem%"), wantQuery: "tbl.field NOT LIKE ?", wantArgs: []any{"lorem%"}, }, { description: "ILikeString", item: field.ILikeString("lorem%"), wantQuery: "tbl.field ILIKE ?", wantArgs: []any{"lorem%"}, }, { description: "NotILikeString", item: field.NotILikeString("lorem%"), wantQuery: "tbl.field NOT ILIKE ?", wantArgs: []any{"lorem%"}, }, { description: "Set", item: field.Set(Expr("NULL")), wantQuery: "field = NULL", }, { description: "Setf", item: field.Setf("VALUES({})", field.WithPrefix("")), wantQuery: "field = VALUES(field)", }, { description: "Set EXCLUDED", item: field.Set(field.WithPrefix("EXCLUDED")), wantQuery: "field = EXCLUDED.field", }, { description: "Set self", item: field.Set(field), dialect: DialectMySQL, wantQuery: "tbl.field = tbl.field", }, { 
description: "Set with alias", item: field.Set(field.WithPrefix("new")), wantQuery: "field = new.field", }, { description: "SetString", item: field.SetString("lorem ipsum"), wantQuery: "field = ?", wantArgs: []any{"lorem ipsum"}, }, { description: "postgres Collate", item: field.Collate("c").LtString("lorem ipsum"), dialect: DialectPostgres, wantQuery: `tbl.field COLLATE "c" < $1`, wantArgs: []any{"lorem ipsum"}, }, { description: "mysql Collate", item: field.Collate("latin1_swedish_ci").LtString("lorem ipsum"), wantQuery: "tbl.field COLLATE latin1_swedish_ci < ?", wantArgs: []any{"lorem ipsum"}, }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } } func TestTimeField(t *testing.T) { t.Run("basic", func(t *testing.T) { tbl := NewTableStruct("", "tbl", "") f1 := NewTimeField("field", tbl).As("f") if diff := testutil.Diff(f1.GetAlias(), "f"); diff != "" { t.Error(testutil.Callers(), diff) } }) field := NewTimeField("field", NewTableStruct("", "tbl", "")) zeroTime := time.Unix(0, 0).UTC() tests := []TestTable{{ description: "IsNull", item: field.IsNull(), wantQuery: "tbl.field IS NULL", }, { description: "IsNotNull", item: field.IsNotNull(), wantQuery: "tbl.field IS NOT NULL", }, { description: "Asc NullsLast", item: field.Asc().NullsLast(), wantQuery: "tbl.field ASC NULLS LAST", }, { description: "Desc NullsFirst", item: field.Desc().NullsFirst(), wantQuery: "tbl.field DESC NULLS FIRST", }, { description: "In", item: field.In(RowValue{1, 2, 3}), wantQuery: "tbl.field IN (?, ?, ?)", wantArgs: []any{1, 2, 3}, }, { description: "NotIn", item: field.NotIn(RowValue{1, 2, 3}), wantQuery: "tbl.field NOT IN (?, ?, ?)", wantArgs: []any{1, 2, 3}, }, { description: "Eq", item: field.Eq(field), wantQuery: "tbl.field = tbl.field", }, { description: "Ne", item: field.Ne(field), wantQuery: "tbl.field <> tbl.field", }, { description: "Lt", item: field.Lt(field), wantQuery: "tbl.field < tbl.field", }, { description: "Le", 
item: field.Le(field), wantQuery: "tbl.field <= tbl.field", }, { description: "Gt", item: field.Gt(field), wantQuery: "tbl.field > tbl.field", }, { description: "Ge", item: field.Ge(field), wantQuery: "tbl.field >= tbl.field", }, { description: "EqTime", item: field.EqTime(zeroTime), wantQuery: "tbl.field = ?", wantArgs: []any{zeroTime}, }, { description: "NeTime", item: field.NeTime(zeroTime), wantQuery: "tbl.field <> ?", wantArgs: []any{zeroTime}, }, { description: "LtTime", item: field.LtTime(zeroTime), wantQuery: "tbl.field < ?", wantArgs: []any{zeroTime}, }, { description: "LeTime", item: field.LeTime(zeroTime), wantQuery: "tbl.field <= ?", wantArgs: []any{zeroTime}, }, { description: "GtTime", item: field.GtTime(zeroTime), wantQuery: "tbl.field > ?", wantArgs: []any{zeroTime}, }, { description: "GeTime", item: field.GeTime(zeroTime), wantQuery: "tbl.field >= ?", wantArgs: []any{zeroTime}, }, { description: "Set", item: field.Set(Expr("NULL")), wantQuery: "field = NULL", }, { description: "Setf", item: field.Setf("VALUES({})", field.WithPrefix("")), wantQuery: "field = VALUES(field)", }, { description: "Set EXCLUDED", item: field.Set(field.WithPrefix("EXCLUDED")), wantQuery: "field = EXCLUDED.field", }, { description: "Set self", item: field.Set(field), dialect: DialectMySQL, wantQuery: "tbl.field = tbl.field", }, { description: "Set with alias", item: field.Set(field.WithPrefix("new")), wantQuery: "field = new.field", }, { description: "SetTime", item: field.SetTime(zeroTime), wantQuery: "field = ?", wantArgs: []any{zeroTime}, }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } } func TestTimestamp(t *testing.T) { t.Run("Value", func(t *testing.T) { type TestTable struct { description string timestamp Timestamp wantValue driver.Value } tests := []TestTable{{ description: "empty", timestamp: Timestamp{}, wantValue: nil, }, { description: "sqlite", timestamp: Timestamp{ Valid: true, Time: time.Unix(1, 
0), dialect: DialectSQLite, }, wantValue: time.Unix(1, 0).Unix(), }, { description: "non-sqlite", timestamp: Timestamp{ Valid: true, Time: time.Unix(1, 0), dialect: DialectPostgres, }, wantValue: time.Unix(1, 0), }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() gotValue, err := tt.timestamp.Value() if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(gotValue, tt.wantValue); diff != "" { t.Error(testutil.Callers(), diff) } }) } }) t.Run("Scan", func(t *testing.T) { type TestTable struct { description string value any wantTimestamp Timestamp } parseTime := func(value string) time.Time { value = strings.TrimSuffix(value, "Z") for _, format := range timestampFormats { if timeVal, err := time.ParseInLocation(format, value, time.UTC); err == nil { return timeVal } } t.Fatalf(testutil.Callers()+" could not convert %q into time", value) return time.Time{} } tests := []TestTable{{ description: "empty", value: nil, wantTimestamp: Timestamp{}, }, { description: "empty string", value: "", wantTimestamp: Timestamp{}, }, { description: "empty bytes", value: []byte{}, wantTimestamp: Timestamp{}, }, { description: "time.Time", value: time.Unix(1, 0), wantTimestamp: NewTimestamp(time.Unix(1, 0)), }, { description: "time.Time", value: time.Unix(1, 0), wantTimestamp: NewTimestamp(time.Unix(1, 0)), }, { description: "2006-01-02 15:04:05", value: "2006-01-02 15:04:05", wantTimestamp: NewTimestamp(parseTime("2006-01-02 15:04:05")), }, { description: "2006-01-02 15:04:05Z", value: []byte("2006-01-02 15:04:05Z"), wantTimestamp: NewTimestamp(parseTime("2006-01-02 15:04:05Z")), }, { description: "2006-01-02 15:04:05-07:00", value: "2006-01-02 15:04:05-07:00", wantTimestamp: NewTimestamp(parseTime("2006-01-02 15:04:05-07:00")), }, { description: "2006-01-02 15:04:05.9999", value: []byte("2006-01-02 15:04:05.9999"), wantTimestamp: NewTimestamp(parseTime("2006-01-02 15:04:05.9999")), }, { description: "int64 seconds", value: 
int64(123456), wantTimestamp: NewTimestamp(time.Unix(123456, 0)), }, { description: "int64 milliseconds", value: int64(1e12 + 1), wantTimestamp: NewTimestamp(time.Unix(0, int64(1e12+1)*int64(time.Millisecond))), }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() var gotTimestamp Timestamp err := gotTimestamp.Scan(tt.value) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(gotTimestamp, tt.wantTimestamp); diff != "" { t.Error(testutil.Callers(), diff) } }) } }) } func TestUUIDField(t *testing.T) { t.Run("basic", func(t *testing.T) { tbl := NewTableStruct("", "tbl", "") f1 := NewUUIDField("field", tbl).As("f") if diff := testutil.Diff(f1.GetAlias(), "f"); diff != "" { t.Error(testutil.Callers(), diff) } }) field := NewUUIDField("field", NewTableStruct("", "tbl", "")) tests := []TestTable{{ description: "IsNull", item: field.IsNull(), wantQuery: "tbl.field IS NULL", }, { description: "IsNotNull", item: field.IsNotNull(), wantQuery: "tbl.field IS NOT NULL", }, { description: "Asc NullsLast", item: field.Asc().NullsLast(), wantQuery: "tbl.field ASC NULLS LAST", }, { description: "Desc NullsFirst", item: field.Desc().NullsFirst(), wantQuery: "tbl.field DESC NULLS FIRST", }, { description: "In", item: field.In(RowValue{1, 2, 3}), wantQuery: "tbl.field IN (?, ?, ?)", wantArgs: []any{1, 2, 3}, }, { description: "NotIn", item: field.NotIn(RowValue{1, 2, 3}), wantQuery: "tbl.field NOT IN (?, ?, ?)", wantArgs: []any{1, 2, 3}, }, { description: "Eq", item: field.Eq(field), wantQuery: "tbl.field = tbl.field", }, { description: "Ne", item: field.Ne(field), wantQuery: "tbl.field <> tbl.field", }, { description: "EqUUID", item: func() any { id, err := uuid.Parse("ffffffff-ffff-ffff-ffff-ffffffffffff") if err != nil { t.Fatal(testutil.Callers(), err) } return field.EqUUID(id) }(), wantQuery: "tbl.field = ?", wantArgs: []any{[]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 
0xff, 0xff}}, }, { description: "NeUUID", item: func() any { id, err := uuid.Parse("ffffffff-ffff-ffff-ffff-ffffffffffff") if err != nil { t.Fatal(testutil.Callers(), err) } return field.NeUUID(id) }(), wantQuery: "tbl.field <> ?", wantArgs: []any{[]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, }, { description: "Eq id", item: func() any { id, err := uuid.Parse("ffffffff-ffff-ffff-ffff-ffffffffffff") if err != nil { t.Fatal(testutil.Callers(), err) } return field.Eq(id) }(), wantQuery: "tbl.field = ?", wantArgs: []any{"ffffffff-ffff-ffff-ffff-ffffffffffff"}, }, { description: "Ne id", item: func() any { id, err := uuid.Parse("ffffffff-ffff-ffff-ffff-ffffffffffff") if err != nil { t.Fatal(testutil.Callers(), err) } return field.Ne(id) }(), wantQuery: "tbl.field <> ?", wantArgs: []any{"ffffffff-ffff-ffff-ffff-ffffffffffff"}, }, { description: "Set", item: func() any { id, err := uuid.Parse("ffffffff-ffff-ffff-ffff-ffffffffffff") if err != nil { t.Fatal(testutil.Callers(), err) } return field.Set(id) }(), wantQuery: "field = ?", wantArgs: []any{"ffffffff-ffff-ffff-ffff-ffffffffffff"}, }, { description: "SetUUID", item: func() any { id, err := uuid.Parse("ffffffff-ffff-ffff-ffff-ffffffffffff") if err != nil { t.Fatal(testutil.Callers(), err) } return field.SetUUID(id) }(), wantQuery: "field = ?", wantArgs: []any{[]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, }, { description: "Setf", item: field.Setf("VALUES({})", field.WithPrefix("")), wantQuery: "field = VALUES(field)", }, { description: "Set EXCLUDED", item: field.Set(field.WithPrefix("EXCLUDED")), wantQuery: "field = EXCLUDED.field", }, { description: "Set self", item: field.Set(field), dialect: DialectMySQL, wantQuery: "tbl.field = tbl.field", }, { description: "Set with alias", item: field.Set(field.WithPrefix("new")), wantQuery: "field = new.field", }, { description: "Set id", item: func() any { id, 
err := uuid.Parse("ffffffff-ffff-ffff-ffff-ffffffffffff") if err != nil { t.Fatal(testutil.Callers(), err) } return field.Set(id) }(), wantQuery: "field = ?", wantArgs: []any{"ffffffff-ffff-ffff-ffff-ffffffffffff"}, }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } } type DummyTable struct{} var _ Table = (*DummyTable)(nil) func (t DummyTable) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return nil } func (t DummyTable) GetAlias() string { return "" } func (t DummyTable) IsTable() {} func TestNew(t *testing.T) { t.Run("basic", func(t *testing.T) { type USER struct { TableStruct USER_ID NumberField NAME StringField EMAIL StringField } u := New[USER]("u") TestTable{ item: Queryf("SELECT {} FROM {} AS {}", Fields{u.USER_ID, u.NAME, u.EMAIL}, u, Expr(u.GetAlias())), wantQuery: "SELECT u.user_id, u.name, u.email FROM user AS u", }.assert(t) }) t.Run("name", func(t *testing.T) { type USER struct { TableStruct `sq:"User"` USER_ID NumberField `sq:"UserId"` NAME StringField `sq:"Name"` EMAIL StringField `sq:"Email"` private int } u := New[USER]("u") TestTable{ item: Queryf("SELECT {} FROM {} AS {}", Fields{u.USER_ID, u.NAME, u.EMAIL}, u, Expr(u.GetAlias())), wantQuery: `SELECT u."UserId", u."Name", u."Email" FROM "User" AS u`, }.assert(t) }) t.Run("schema", func(t *testing.T) { type USER struct { TableStruct `sq:"public.user"` USER_ID NumberField NAME StringField EMAIL StringField Public int } u := New[USER]("u") TestTable{ item: Queryf("SELECT {} FROM {} AS {}", Fields{u.USER_ID, u.NAME, u.EMAIL}, u, Expr(u.GetAlias())), wantQuery: "SELECT u.user_id, u.name, u.email FROM public.user AS u", }.assert(t) }) t.Run("first field not a struct", func(t *testing.T) { tbl := New[tmptable]("") if diff := testutil.Diff(tbl, tmptable("")); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("struct has no fields", func(t *testing.T) { tbl := 
New[DummyTable]("") if diff := testutil.Diff(tbl, DummyTable{}); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("first field is unexported", func(t *testing.T) { tbl := New[Expression]("") if diff := testutil.Diff(tbl, Expression{}); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("first field is not TableStruct", func(t *testing.T) { tbl := New[struct{ DummyTable }]("") if diff := testutil.Diff(tbl, struct{ DummyTable }{}); diff != "" { t.Error(testutil.Callers(), diff) } }) } ================================================ FILE: fmt.go ================================================ package sq import ( "bytes" "context" "database/sql" "database/sql/driver" "encoding/hex" "fmt" "reflect" "sort" "strconv" "strings" "time" "unicode" ) // Writef is a fmt.Sprintf-style function that will write a format string and // values slice into an Output. The only recognized placeholder is '{}'. // Placeholders can be anonymous (e.g. {}), ordinal (e.g. {1}, {2}, {3}) or // named (e.g. {name}, {email}, {age}). // // - Anonymous placeholders refer to successive values in the values slice. // Anonymous placeholders are treated like a series of incrementing ordinal // placeholders. // // - Ordinal placeholders refer to a specific value in the values slice using // 1-based indexing. // // - Named placeholders refer to their corresponding sql.NamedArg value in the // values slice. If there are multiple sql.NamedArg values with the same name, // the last one wins. // // If a value is an SQLWriter, its WriteSQL method will be called. Else if a // value is a slice, it will undergo slice expansion // (https://bokwoon.neocities.org/sq.html#value-expansion). Otherwise, the // value is added to the query args slice. 
func Writef(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, format string, values []any) error { return writef(ctx, dialect, buf, args, params, format, values, nil, nil) } func writef(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, format string, values []any, runningValuesIndex *int, ordinalIndex map[int]int) error { // optimized case when the format string does not contain any '{}' // placeholders if i := strings.IndexByte(format, '{'); i < 0 { buf.WriteString(format) return nil } // namedIndex tracks the indexes of the namedArgs that are inside the // values slice namedIndex := make(map[string]int) for i, value := range values { var name string switch arg := value.(type) { case sql.NamedArg: name = arg.Name case Parameter: name = arg.Name case ArrayParameter: name = arg.Name case BinaryParameter: name = arg.Name case BooleanParameter: name = arg.Name case EnumParameter: name = arg.Name case JSONParameter: name = arg.Name case NumberParameter: name = arg.Name case StringParameter: name = arg.Name case TimeParameter: name = arg.Name case UUIDParameter: name = arg.Name } if name != "" { if _, ok := namedIndex[name]; ok { return fmt.Errorf("named parameter {%s} provided more than once", name) } namedIndex[name] = i } } buf.Grow(len(format)) if runningValuesIndex == nil { n := 0 runningValuesIndex = &n } // ordinalIndex tracks the index of the ordinals that have already been // written into the args slice if ordinalIndex == nil { ordinalIndex = make(map[int]int) } // jump to each '{' character in the format string for i := strings.IndexByte(format, '{'); i >= 0; i = strings.IndexByte(format, '{') { // Unescape '{{' to '{' if i+1 <= len(format) && format[i+1] == '{' { buf.WriteString(format[:i]) buf.WriteByte('{') format = format[i+2:] continue } buf.WriteString(format[:i]) format = format[i:] // If we can't find the terminating '}' return an error j := 
strings.IndexByte(format, '}') if j < 0 { return fmt.Errorf("no '}' found") } paramName := format[1:j] format = format[j+1:] for _, char := range paramName { if char != '_' && !unicode.IsLetter(char) && !unicode.IsDigit(char) { return fmt.Errorf("%q is not a valid param name (only letters, digits and '_' are allowed)", paramName) } } // is it an anonymous placeholder? e.g. {} if paramName == "" { if *runningValuesIndex >= len(values) { return fmt.Errorf("too few values passed in to Writef, expected more than %d", runningValuesIndex) } value := values[*runningValuesIndex] *runningValuesIndex++ err := WriteValue(ctx, dialect, buf, args, params, value) if err != nil { return err } continue } // is it an ordinal placeholder? e.g. {1}, {2}, {3} ordinal, err := strconv.Atoi(paramName) if err == nil { err = writeOrdinalValue(ctx, dialect, buf, args, params, values, ordinal, ordinalIndex) if err != nil { return err } continue } // is it a named placeholder? e.g. {name}, {age}, {email} index, ok := namedIndex[paramName] if !ok { availableParams := make([]string, 0, len(namedIndex)) for name := range namedIndex { availableParams = append(availableParams, name) } sort.Strings(availableParams) return fmt.Errorf("named parameter {%s} not provided (available params: %s)", paramName, strings.Join(availableParams, ", ")) } value := values[index] err = WriteValue(ctx, dialect, buf, args, params, value) if err != nil { return err } } buf.WriteString(format) return nil } // WriteValue is the equivalent of Writef but for writing a single value into // the Output. 
func WriteValue(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, value any) error { if namedArg, ok := value.(sql.NamedArg); ok { return writeNamedArg(ctx, dialect, buf, args, params, namedArg) } if w, ok := value.(SQLWriter); ok { return w.WriteSQL(ctx, dialect, buf, args, params) } if isExpandableSlice(value) { return expandSlice(ctx, dialect, buf, args, params, value) } value, err := preprocessValue(dialect, value) if err != nil { return err } *args = append(*args, value) index := len(*args) - 1 switch dialect { case DialectPostgres, DialectSQLite: buf.WriteString("$" + strconv.Itoa(index+1)) case DialectSQLServer: buf.WriteString("@p" + strconv.Itoa(index+1)) default: buf.WriteString("?") } return nil } // QuoteIdentifier quotes an identifier if necessary using dialect-specific // quoting rules. func QuoteIdentifier(dialect string, identifier string) string { var needsQuoting bool switch identifier { case "": needsQuoting = true case "EXCLUDED", "INSERTED", "DELETED", "NEW", "OLD": needsQuoting = false default: for i, char := range identifier { if i == 0 && (char >= '0' && char <= '9') { // first character cannot be a number needsQuoting = true break } if char == '_' || (char >= '0' && char <= '9') || (char >= 'a' && char <= 'z') { continue } // If there are capital letters, the identifier is quoted to preserve // capitalization information (because databases treat capital letters // differently based on their dialect or configuration). // If the character is anything else, we also quote. In general there // may be some special characters that are allowed in unquoted // identifiers (e.g. '$'), but different databases allow different // things. We only recognize _a-z0-9 as the true standard. 
needsQuoting = true break } if !needsQuoting && dialect != "" { switch dialect { case DialectSQLite: _, needsQuoting = sqliteKeywords[strings.ToLower(identifier)] case DialectPostgres: _, needsQuoting = postgresKeywords[strings.ToLower(identifier)] case DialectMySQL: _, needsQuoting = mysqlKeywords[strings.ToLower(identifier)] case DialectSQLServer: _, needsQuoting = sqlserverKeywords[strings.ToLower(identifier)] } } } if !needsQuoting { return identifier } switch dialect { case DialectMySQL: return "`" + EscapeQuote(identifier, '`') + "`" case DialectSQLServer: return "[" + EscapeQuote(identifier, ']') + "]" default: return `"` + EscapeQuote(identifier, '"') + `"` } } // EscapeQuote will escape the relevant quote in a string by doubling up on it // (as per SQL rules). func EscapeQuote(str string, quote byte) string { i := strings.IndexByte(str, quote) if i < 0 { return str } var b strings.Builder b.Grow(len(str) + strings.Count(str, string(quote))) for i >= 0 { b.WriteString(str[:i]) b.WriteByte(quote) b.WriteByte(quote) if len(str[i:]) > 2 && str[i] == quote && str[i+1] == quote { str = str[i+2:] } else { str = str[i+1:] } i = strings.IndexByte(str, quote) } b.WriteString(str) return b.String() } // Sprintf will interpolate SQL args into a query string containing prepared // statement parameters. It returns an error if an argument cannot be properly // represented in SQL. This function may be vulnerable to SQL injection and // should be used for logging purposes only. 
func Sprintf(dialect string, query string, args []any) (string, error) { if len(args) == 0 { return query, nil } buf := bufpool.Get().(*bytes.Buffer) buf.Reset() defer bufpool.Put(buf) buf.Grow(len(query)) namedIndices := make(map[string]int) for i, arg := range args { switch arg := arg.(type) { case sql.NamedArg: namedIndices[arg.Name] = i } } runningArgsIndex := 0 mustWriteCharAt := -1 insideStringOrIdentifier := false var openingQuote rune var paramName []rune for i, char := range query { // do we unconditionally write in the current char? if mustWriteCharAt == i { buf.WriteRune(char) continue } // are we currently inside a string or identifier? if insideStringOrIdentifier { buf.WriteRune(char) switch openingQuote { case '\'', '"', '`': // does the current char terminate the current string or identifier? if char == openingQuote { // is the next char the same as the current char, which // escapes it and prevents it from terminating the current // string or identifier? if i+1 < len(query) && rune(query[i+1]) == openingQuote { mustWriteCharAt = i + 1 } else { insideStringOrIdentifier = false } } case '[': // does the current char terminate the current string or identifier? if char == ']' { // is the next char the same as the current char, which // escapes it and prevents it from terminating the current // string or identifier? if i+1 < len(query) && query[i+1] == ']' { mustWriteCharAt = i + 1 } else { insideStringOrIdentifier = false } } } continue } // does the current char mark the start of a new string or identifier? if char == '\'' || char == '"' || (char == '`' && dialect == DialectMySQL) || (char == '[' && dialect == DialectSQLServer) { insideStringOrIdentifier = true openingQuote = char buf.WriteRune(char) continue } // are we currently inside a parameter name? if len(paramName) > 0 { // does the current char terminate the current parameter name? 
if char != '_' && !unicode.IsLetter(char) && !unicode.IsDigit(char) { paramValue, err := lookupParam(dialect, args, paramName, namedIndices, runningArgsIndex) if err != nil { return buf.String(), err } buf.WriteString(paramValue) buf.WriteRune(char) if len(paramName) == 1 && paramName[0] == '?' { runningArgsIndex++ } paramName = paramName[:0] } else { paramName = append(paramName, char) } continue } // does the current char mark the start of a new parameter name? if (char == '$' && (dialect == DialectSQLite || dialect == DialectPostgres)) || (char == ':' && dialect == DialectSQLite) || (char == '@' && (dialect == DialectSQLite || dialect == DialectSQLServer)) { paramName = append(paramName, char) continue } // is the current char the anonymous '?' parameter? if char == '?' && dialect != DialectPostgres { // for sqlite, just because we encounter a '?' doesn't mean it // is an anonymous param. sqlite also supports using '?' for // ordinal params (e.g. ?1, ?2, ?3) or named params (?foo, // ?bar, ?baz). Hence we treat it as an ordinal/named param // first, and handle the edge case later when it isn't. 
if dialect == DialectSQLite { paramName = append(paramName, char) continue } if runningArgsIndex >= len(args) { return buf.String(), fmt.Errorf("too few args provided, expected more than %d", runningArgsIndex+1) } paramValue, err := Sprint(dialect, args[runningArgsIndex]) if err != nil { return buf.String(), err } buf.WriteString(paramValue) runningArgsIndex++ continue } // if all the above questions answer false, we just write the current // char in and continue buf.WriteRune(char) } // flush the paramName buffer (to handle edge case where the query ends with a parameter name) if len(paramName) > 0 { paramValue, err := lookupParam(dialect, args, paramName, namedIndices, runningArgsIndex) if err != nil { return buf.String(), err } buf.WriteString(paramValue) } if insideStringOrIdentifier { return buf.String(), fmt.Errorf("unclosed string or identifier") } return buf.String(), nil } // Sprint is the equivalent of Sprintf but for converting a single value into // its SQL representation. func Sprint(dialect string, v any) (string, error) { const ( timestamp = "2006-01-02 15:04:05" timestampWithTimezone = "2006-01-02 15:04:05.9999999-07:00" ) switch v := v.(type) { case nil: return "NULL", nil case bool: if v { if dialect == DialectSQLServer { return "1", nil } return "TRUE", nil } if dialect == DialectSQLServer { return "0", nil } return "FALSE", nil case []byte: switch dialect { case DialectPostgres: // https://www.postgresql.org/docs/current/datatype-binary.html // (see 8.4.1. 
bytea Hex Format) return `'\x` + hex.EncodeToString(v) + `'`, nil case DialectSQLServer: return `0x` + hex.EncodeToString(v), nil default: return `x'` + hex.EncodeToString(v) + `'`, nil } case string: str := v i := strings.IndexAny(str, "\r\n") if i < 0 { return `'` + strings.ReplaceAll(str, `'`, `''`) + `'`, nil } var b strings.Builder if dialect == DialectMySQL || dialect == DialectSQLServer { b.WriteString("CONCAT(") } for i >= 0 { if str[:i] != "" { b.WriteString(`'` + strings.ReplaceAll(str[:i], `'`, `''`) + `'`) if dialect == DialectMySQL || dialect == DialectSQLServer { b.WriteString(", ") } else { b.WriteString(" || ") } } switch str[i] { case '\r': if dialect == DialectPostgres { b.WriteString("CHR(13)") } else { b.WriteString("CHAR(13)") } case '\n': if dialect == DialectPostgres { b.WriteString("CHR(10)") } else { b.WriteString("CHAR(10)") } } if str[i+1:] != "" { if dialect == DialectMySQL || dialect == DialectSQLServer { b.WriteString(", ") } else { b.WriteString(" || ") } } str = str[i+1:] i = strings.IndexAny(str, "\r\n") } if str != "" { b.WriteString(`'` + strings.ReplaceAll(str, `'`, `''`) + `'`) } if dialect == DialectMySQL || dialect == DialectSQLServer { b.WriteString(")") } return b.String(), nil case time.Time: if dialect == DialectPostgres || dialect == DialectSQLServer { return `'` + v.Format(timestampWithTimezone) + `'`, nil } return `'` + v.UTC().Format(timestamp) + `'`, nil case int: return strconv.FormatInt(int64(v), 10), nil case int8: return strconv.FormatInt(int64(v), 10), nil case int16: return strconv.FormatInt(int64(v), 10), nil case int32: return strconv.FormatInt(int64(v), 10), nil case int64: return strconv.FormatInt(v, 10), nil case uint: return strconv.FormatUint(uint64(v), 10), nil case uint8: return strconv.FormatUint(uint64(v), 10), nil case uint16: return strconv.FormatUint(uint64(v), 10), nil case uint32: return strconv.FormatUint(uint64(v), 10), nil case uint64: return strconv.FormatUint(v, 10), nil case float32: return 
strconv.FormatFloat(float64(v), 'g', -1, 64), nil case float64: return strconv.FormatFloat(v, 'g', -1, 64), nil case sql.NamedArg: return Sprint(dialect, v.Value) case sql.NullBool: if !v.Valid { return "NULL", nil } if v.Bool { if dialect == DialectSQLServer { return "1", nil } return "TRUE", nil } if dialect == DialectSQLServer { return "0", nil } return "FALSE", nil case sql.NullFloat64: if !v.Valid { return "NULL", nil } return strconv.FormatFloat(v.Float64, 'g', -1, 64), nil case sql.NullInt64: if !v.Valid { return "NULL", nil } return strconv.FormatInt(v.Int64, 10), nil case sql.NullInt32: if !v.Valid { return "NULL", nil } return strconv.FormatInt(int64(v.Int32), 10), nil case sql.NullString: if !v.Valid { return "NULL", nil } return Sprint(dialect, v.String) case sql.NullTime: if !v.Valid { return "NULL", nil } if dialect == DialectPostgres || dialect == DialectSQLServer { return `'` + v.Time.Format(timestampWithTimezone) + `'`, nil } return `'` + v.Time.UTC().Format(timestamp) + `'`, nil case driver.Valuer: vv, err := v.Value() if err != nil { return "", fmt.Errorf("error when calling Value(): %w", err) } switch vv.(type) { case int64, float64, bool, []byte, string, time.Time, nil: return Sprint(dialect, vv) default: return "", fmt.Errorf("invalid driver.Value type %T (must be one of int64, float64, bool, []byte, string, time.Time, nil)", vv) } } rv := reflect.ValueOf(v) if rv.Kind() == reflect.Pointer { rv = rv.Elem() if !rv.IsValid() { return "NULL", nil } } switch v := rv.Interface().(type) { case bool, []byte, string, time.Time, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, sql.NamedArg, sql.NullBool, sql.NullFloat64, sql.NullInt64, sql.NullInt32, sql.NullString, sql.NullTime, driver.Valuer: return Sprint(dialect, v) default: return "", fmt.Errorf("%T has no SQL representation", v) } } // isExpandableSlice checks if a value is an expandable slice. 
func isExpandableSlice(value any) bool { // treat byte slices as a special case that we never want to expand if _, ok := value.([]byte); ok { return false } valueType := reflect.TypeOf(value) if valueType == nil { return false } return valueType.Kind() == reflect.Slice } // expandSlice expands a slice value into Output. Make sure the value is an // expandable slice first by checking it with isExpandableSlice(). func expandSlice(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, value any) error { slice := reflect.ValueOf(value) var err error for i := 0; i < slice.Len(); i++ { if i > 0 { buf.WriteString(", ") } arg := slice.Index(i).Interface() if v, ok := arg.(SQLWriter); ok { err = v.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return err } continue } switch dialect { case DialectPostgres, DialectSQLite: buf.WriteString("$" + strconv.Itoa(len(*args)+1)) case DialectSQLServer: buf.WriteString("@p" + strconv.Itoa(len(*args)+1)) default: buf.WriteString("?") } arg, err = preprocessValue(dialect, arg) if err != nil { return err } *args = append(*args, arg) } return nil } // writeNamedArg writes an sql.NamedArg into the Output. 
func writeNamedArg(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, namedArg sql.NamedArg) error { if w, ok := namedArg.Value.(SQLWriter); ok { return w.WriteSQL(ctx, dialect, buf, args, params) } if isExpandableSlice(namedArg.Value) { return expandSlice(ctx, dialect, buf, args, params, namedArg.Value) } var err error namedArg.Value, err = preprocessValue(dialect, namedArg.Value) if err != nil { return err } paramIndices := params[namedArg.Name] if len(paramIndices) > 0 { index := paramIndices[0] switch dialect { case DialectSQLite: (*args)[index] = namedArg buf.WriteString("$" + namedArg.Name) return nil case DialectPostgres: (*args)[index] = namedArg.Value buf.WriteString("$" + strconv.Itoa(index+1)) return nil case DialectSQLServer: (*args)[index] = namedArg buf.WriteString("@" + namedArg.Name) return nil default: for _, index := range paramIndices { (*args)[index] = namedArg.Value } } } switch dialect { case DialectSQLite: *args = append(*args, namedArg) if params != nil { index := len(*args) - 1 params[namedArg.Name] = []int{index} } buf.WriteString("$" + namedArg.Name) case DialectPostgres: *args = append(*args, namedArg.Value) index := len(*args) - 1 if params != nil { params[namedArg.Name] = []int{index} } buf.WriteString("$" + strconv.Itoa(index+1)) case DialectSQLServer: *args = append(*args, namedArg) if params != nil { index := len(*args) - 1 params[namedArg.Name] = []int{index} } buf.WriteString("@" + namedArg.Name) default: *args = append(*args, namedArg.Value) if params != nil { index := len(*args) - 1 params[namedArg.Name] = append(paramIndices, index) } buf.WriteString("?") } return nil } // writeOrdinalValue writes an ordinal value into the Output. The // ordinalIndices map is there to keep track of which ordinal values we have // already appended to args (which we do not want to append again). 
func writeOrdinalValue(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, values []any, ordinal int, ordinalIndices map[int]int) error { index := ordinal - 1 if index < 0 || index >= len(values) { return fmt.Errorf("ordinal parameter {%d} is out of bounds", ordinal) } value := values[index] if namedArg, ok := value.(sql.NamedArg); ok { return writeNamedArg(ctx, dialect, buf, args, params, namedArg) } if w, ok := value.(SQLWriter); ok { return w.WriteSQL(ctx, dialect, buf, args, params) } if isExpandableSlice(value) { return expandSlice(ctx, dialect, buf, args, params, value) } var err error value, err = preprocessValue(dialect, value) if err != nil { return err } switch dialect { case DialectSQLite, DialectPostgres, DialectSQLServer: index, ok := ordinalIndices[ordinal] if !ok { *args = append(*args, value) index = len(*args) - 1 ordinalIndices[ordinal] = index } switch dialect { case DialectSQLite, DialectPostgres: buf.WriteString("$" + strconv.Itoa(index+1)) case DialectSQLServer: buf.WriteString("@p" + strconv.Itoa(index+1)) } default: err := WriteValue(ctx, dialect, buf, args, params, value) if err != nil { return err } } return nil } // lookupParam returns the SQL representation of a paramName (inside the args // slice). func lookupParam(dialect string, args []any, paramName []rune, namedIndices map[string]int, runningArgsIndex int) (paramValue string, err error) { var maybeNum string if paramName[0] == '@' && dialect == DialectSQLServer && len(paramName) >= 2 && (paramName[1] == 'p' || paramName[1] == 'P') { maybeNum = string(paramName[2:]) } else { maybeNum = string(paramName[1:]) } // is paramName an anonymous parameter? if maybeNum == "" { if paramName[0] != '?' { return "", fmt.Errorf("parameter name missing") } paramValue, err = Sprint(dialect, args[runningArgsIndex]) if err != nil { return "", err } return paramValue, nil } // is paramName an ordinal paramater? 
ordinal, err := strconv.Atoi(maybeNum) if err == nil { index := ordinal - 1 if index < 0 || index >= len(args) { return "", fmt.Errorf("args index %d out of bounds", ordinal) } paramValue, err = Sprint(dialect, args[index]) if err != nil { return "", err } return paramValue, nil } // if we reach here, we know that the paramName is not an ordinal parameter // i.e. it is a named parameter if dialect == DialectPostgres || dialect == DialectMySQL { return "", fmt.Errorf("%s does not support %s named parameter", dialect, string(paramName)) } index, ok := namedIndices[string(paramName[1:])] if !ok { return "", fmt.Errorf("named parameter %s not provided", string(paramName)) } if index < 0 || index >= len(args) { return "", fmt.Errorf("args index %d out of bounds", ordinal) } paramValue, err = Sprint(dialect, args[index]) if err != nil { return "", err } return paramValue, nil } func quoteTableColumns(dialect string, table Table) string { tableWithColumns, ok := table.(interface{ GetColumns() []string }) if !ok { return "" } columns := tableWithColumns.GetColumns() if len(columns) == 0 { return "" } var b strings.Builder b.WriteString(" (") for i, column := range columns { if i > 0 { b.WriteString(", ") } b.WriteString(QuoteIdentifier(dialect, column)) } b.WriteString(")") return b.String() } // Params is a shortcut for typing map[string]interface{}. type Params = map[string]any // Parameter is identical to sql.NamedArg, but implements the Field interface. type Parameter sql.NamedArg var _ Field = (*Parameter)(nil) // Param creates a new Parameter. func Param(name string, value any) Parameter { return Parameter{Name: name, Value: value} } // WriteSQL implements the SQLWriter interface. func (p Parameter) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return writeNamedArg(ctx, dialect, buf, args, params, sql.NamedArg(p)) } // IsField implements the Field interface. 
func (Parameter) IsField() {}

// ArrayParameter is an sql.NamedArg that also satisfies the Array interface.
type ArrayParameter sql.NamedArg

var _ Field = (*ArrayParameter)(nil)

// ArrayParam returns an ArrayParameter named name whose value is
// ArrayValue(value).
func ArrayParam(name string, value any) ArrayParameter {
	return ArrayParameter(sql.Named(name, ArrayValue(value)))
}

// WriteSQL writes the parameter as a named argument (SQLWriter interface).
func (p ArrayParameter) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	arg := sql.NamedArg(p)
	return writeNamedArg(ctx, dialect, buf, args, params, arg)
}

// IsField is a marker method for the Field interface.
func (ArrayParameter) IsField() {}

// IsArray is a marker method for the Array interface.
func (ArrayParameter) IsArray() {}

// BinaryParameter is an sql.NamedArg that also satisfies the Binary
// interface.
type BinaryParameter sql.NamedArg

var _ Binary = (*BinaryParameter)(nil)

// BytesParam returns a BinaryParameter named name holding the []byte b.
func BytesParam(name string, b []byte) BinaryParameter {
	return BinaryParameter(sql.Named(name, b))
}

// WriteSQL writes the parameter as a named argument (SQLWriter interface).
func (p BinaryParameter) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	arg := sql.NamedArg(p)
	return writeNamedArg(ctx, dialect, buf, args, params, arg)
}

// IsField is a marker method for the Field interface.
func (BinaryParameter) IsField() {}

// IsBinary is a marker method for the Binary interface.
func (BinaryParameter) IsBinary() {}

// BooleanParameter is an sql.NamedArg that also satisfies the Boolean
// interface.
type BooleanParameter sql.NamedArg

var _ Boolean = (*BooleanParameter)(nil)

// BoolParam returns a BooleanParameter named name holding the bool b.
func BoolParam(name string, b bool) BooleanParameter {
	return BooleanParameter(sql.Named(name, b))
}

// WriteSQL writes the parameter as a named argument (SQLWriter interface).
func (p BooleanParameter) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	arg := sql.NamedArg(p)
	return writeNamedArg(ctx, dialect, buf, args, params, arg)
}

// IsField is a marker method for the Field interface.
func (BooleanParameter) IsField() {}

// IsBoolean is a marker method for the Boolean interface.
func (BooleanParameter) IsBoolean() {}

// EnumParameter is an sql.NamedArg that also satisfies the Enum interface.
type EnumParameter sql.NamedArg

var _ Field = (*EnumParameter)(nil)

// EnumParam returns an EnumParameter named name whose value is
// EnumValue(value).
func EnumParam(name string, value Enumeration) EnumParameter {
	return EnumParameter(sql.Named(name, EnumValue(value)))
}

// WriteSQL writes the parameter as a named argument (SQLWriter interface).
func (p EnumParameter) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	arg := sql.NamedArg(p)
	return writeNamedArg(ctx, dialect, buf, args, params, arg)
}

// IsField is a marker method for the Field interface.
func (EnumParameter) IsField() {}

// IsEnum is a marker method for the Enum interface.
func (EnumParameter) IsEnum() {}

// JSONParameter is an sql.NamedArg that also satisfies the JSON interface.
type JSONParameter sql.NamedArg

var _ Field = (*JSONParameter)(nil)

// JSONParam returns a JSONParameter named name whose value is
// JSONValue(value).
func JSONParam(name string, value any) JSONParameter {
	return JSONParameter(sql.Named(name, JSONValue(value)))
}

// WriteSQL writes the parameter as a named argument (SQLWriter interface).
func (p JSONParameter) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	arg := sql.NamedArg(p)
	return writeNamedArg(ctx, dialect, buf, args, params, arg)
}

// IsField is a marker method for the Field interface.
func (JSONParameter) IsField() {}

// IsJSON is a marker method for the JSON interface.
func (JSONParameter) IsJSON() {}

// NumberParameter is an sql.NamedArg that also satisfies the Number
// interface.
type NumberParameter sql.NamedArg var _ Number = (*NumberParameter)(nil) // IntParam creates a new NumberParameter from an int value. func IntParam(name string, num int) NumberParameter { return NumberParameter{Name: name, Value: num} } // Int64Param creates a new NumberParameter from an int64 value. func Int64Param(name string, num int64) NumberParameter { return NumberParameter{Name: name, Value: num} } // Float64Param creates a new NumberParameter from an float64 value. func Float64Param(name string, num float64) NumberParameter { return NumberParameter{Name: name, Value: num} } // WriteSQL implements the SQLWriter interface. func (p NumberParameter) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return writeNamedArg(ctx, dialect, buf, args, params, sql.NamedArg(p)) } // IsField implements the Field interface. func (p NumberParameter) IsField() {} // IsNumber implements the Number interface. func (p NumberParameter) IsNumber() {} // StringParameter is identical to sql.NamedArg, but implements the String // interface. type StringParameter sql.NamedArg var _ String = (*StringParameter)(nil) // StringParam creates a new StringParameter from a string value. func StringParam(name string, s string) StringParameter { return StringParameter{Name: name, Value: s} } // WriteSQL implements the SQLWriter interface. func (p StringParameter) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return writeNamedArg(ctx, dialect, buf, args, params, sql.NamedArg(p)) } // IsField implements the Field interface. func (p StringParameter) IsField() {} // IsString implements the String interface. func (p StringParameter) IsString() {} // TimeParameter is identical to sql.NamedArg, but implements the Time // interface. type TimeParameter sql.NamedArg var _ Time = (*TimeParameter)(nil) // TimeParam creates a new TimeParameter from a time.Time value. 
func TimeParam(name string, t time.Time) TimeParameter {
	return TimeParameter(sql.Named(name, t))
}

// WriteSQL writes the parameter as a named argument (SQLWriter interface).
func (p TimeParameter) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	arg := sql.NamedArg(p)
	return writeNamedArg(ctx, dialect, buf, args, params, arg)
}

// IsField is a marker method for the Field interface.
func (TimeParameter) IsField() {}

// IsTime is a marker method for the Time interface.
func (TimeParameter) IsTime() {}

// UUIDParameter is an sql.NamedArg that also satisfies the UUID interface.
type UUIDParameter sql.NamedArg

var _ Field = (*UUIDParameter)(nil)

// UUIDParam returns a UUIDParameter named name whose value is
// UUIDValue(value).
func UUIDParam(name string, value any) UUIDParameter {
	return UUIDParameter(sql.Named(name, UUIDValue(value)))
}

// WriteSQL writes the parameter as a named argument (SQLWriter interface).
func (p UUIDParameter) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	arg := sql.NamedArg(p)
	return writeNamedArg(ctx, dialect, buf, args, params, arg)
}

// IsField is a marker method for the Field interface.
func (UUIDParameter) IsField() {}

// IsUUID is a marker method for the UUID interface.
func (p UUIDParameter) IsUUID() {}

// sqliteKeywords is the set of SQLite keywords, stored as lowercase keys.
// NOTE(review): presumably consulted to decide whether an identifier must be
// quoted for the SQLite dialect; confirm against callers.
//
// SQLite keyword reference: https://www.sqlite.org/lang_keywords.html
var sqliteKeywords = map[string]struct{}{
	"abort": {}, "action": {}, "add": {}, "after": {}, "all": {}, "alter": {},
	"always": {}, "analyze": {}, "and": {}, "as": {}, "asc": {}, "attach": {},
	"autoincrement": {}, "before": {}, "begin": {}, "between": {}, "by": {},
	"cascade": {}, "case": {}, "cast": {}, "check": {}, "collate": {},
	"column": {}, "commit": {}, "conflict": {}, "constraint": {}, "create": {},
	"cross": {}, "current": {}, "current_date": {}, "current_time": {},
	"current_timestamp": {}, "database": {}, "default": {}, "deferrable": {},
	"deferred": {}, "delete": {}, "desc": {}, "detach": {}, "distinct": {},
	"do": {}, "drop": {}, "each": {}, "else": {}, "end": {}, "escape": {},
	"except": {}, "exclude": {}, "exclusive": {}, "exists": {}, "explain": {},
	"fail": {}, "filter": {}, "first": {}, "following": {}, "for": {},
	"foreign": {}, "from": {}, "full": {}, "generated": {}, "glob": {},
	"group": {}, "groups": {}, "having": {}, "if": {}, "ignore": {},
	"immediate": {}, "in": {}, "index": {}, "indexed": {}, "initially": {},
	"inner": {}, "insert": {}, "instead": {}, "intersect": {}, "into": {},
	"is": {}, "isnull": {}, "join": {}, "key": {}, "last": {}, "left": {},
	"like": {}, "limit": {}, "match": {}, "materialized": {}, "natural": {},
	"no": {}, "not": {}, "nothing": {}, "notnull": {}, "null": {}, "nulls": {},
	"of": {}, "offset": {}, "on": {}, "or": {}, "order": {}, "others": {},
	"outer": {}, "over": {}, "partition": {}, "plan": {}, "pragma": {},
	"preceding": {}, "primary": {}, "query": {}, "raise": {}, "range": {},
	"recursive": {}, "references": {}, "regexp": {}, "reindex": {},
	"release": {}, "rename": {}, "replace": {}, "restrict": {},
	"returning": {}, "right": {}, "rollback": {}, "row": {}, "rows": {},
	"savepoint": {}, "select": {}, "set": {}, "table": {}, "temp": {},
	"temporary": {}, "then": {}, "ties": {}, "to": {}, "transaction": {},
	"trigger": {}, "unbounded": {}, "union": {}, "unique": {}, "update": {},
	"using": {}, "vacuum": {}, "values": {}, "view": {}, "virtual": {},
	"when": {}, "where": {}, "window": {}, "with": {}, "without": {},
}

// postgresKeywords is the set of PostgreSQL reserved keywords, stored as
// lowercase keys.
// NOTE(review): the reference below tracks the "current" docs, so newer
// PostgreSQL releases may reserve additional words (e.g. SYSTEM_USER appears
// to be reserved as of PostgreSQL 16) — verify and re-sync if needed.
//
// Postgres keyword reference:
// https://www.postgresql.org/docs/current/sql-keywords-appendix.html
var postgresKeywords = map[string]struct{}{
	"all": {}, "analyse": {}, "analyze": {}, "and": {}, "any": {}, "array": {},
	"as": {}, "asc": {}, "asymmetric": {}, "authorization": {}, "binary": {},
	"both": {}, "case": {}, "cast": {}, "check": {}, "collate": {},
	"collation": {}, "column": {}, "concurrently": {}, "constraint": {},
	"create": {}, "cross": {}, "current_catalog": {}, "current_date": {},
	"current_role": {}, "current_schema": {}, "current_time": {},
	"current_timestamp": {}, "current_user": {}, "default": {},
	"deferrable": {}, "desc": {}, "distinct": {}, "do": {}, "else": {},
	"end": {}, "except": {}, "false": {}, "fetch": {}, "for": {},
	"foreign": {}, "freeze": {}, "from": {}, "full": {}, "grant": {},
	"group": {}, "having": {}, "ilike": {}, "in": {}, "initially": {},
	"inner": {}, "intersect": {}, "into": {}, "is": {}, "isnull": {},
	"join": {}, "lateral": {}, "leading": {}, "left": {}, "like": {},
	"limit": {}, "localtime": {}, "localtimestamp": {}, "natural": {},
	"not": {}, "notnull": {}, "null": {}, "offset": {}, "on": {}, "only": {},
	"or": {}, "order": {}, "outer": {}, "overlaps": {}, "placing": {},
	"primary": {}, "references": {}, "returning": {}, "right": {},
	"select": {}, "session_user": {}, "similar": {}, "some": {},
	"symmetric": {}, "table": {}, "tablesample": {}, "then": {}, "to": {},
	"trailing": {}, "true": {}, "union": {}, "unique": {}, "user": {},
	"using": {}, "variadic": {}, "verbose": {}, "when": {}, "where": {},
	"window": {}, "with": {},
}

// mysqlKeywords is the set of MySQL 8.0 reserved keywords, stored as
// lowercase keys.
//
// MySQL keyword reference:
// https://dev.mysql.com/doc/refman/8.0/en/keywords.html
var mysqlKeywords = map[string]struct{}{
	"accessible": {}, "add": {}, "all": {}, "alter": {}, "analyze": {},
	"and": {}, "as": {}, "asc": {}, "asensitive": {}, "before": {},
	"between": {}, "bigint": {}, "binary": {}, "blob": {}, "both": {},
	"by": {}, "call": {}, "cascade": {}, "case": {}, "change": {}, "char": {},
	"character": {}, "check": {}, "collate": {}, "column": {},
	"condition": {}, "constraint": {}, "continue": {}, "convert": {},
	"create": {}, "cross": {}, "cube": {}, "cume_dist": {},
	"current_date": {}, "current_time": {}, "current_timestamp": {},
	"current_user": {}, "cursor": {}, "database": {}, "databases": {},
	"day_hour": {}, "day_microsecond": {}, "day_minute": {},
	"day_second": {}, "dec": {}, "decimal": {}, "declare": {},
	"default": {}, "delayed": {}, "delete": {}, "dense_rank": {}, "desc": {},
	"describe": {}, "deterministic": {}, "distinct": {}, "distinctrow": {},
	"div": {}, "double": {}, "drop": {}, "dual": {}, "each": {}, "else": {},
	"elseif": {}, "empty": {}, "enclosed": {}, "escaped": {}, "except": {},
	"exists": {}, "exit": {}, "explain": {}, "false": {}, "fetch": {},
	"first_value": {}, "float": {}, "float4": {}, "float8": {}, "for": {},
	"force": {}, "foreign": {}, "from": {}, "fulltext": {}, "function": {},
	"generated": {}, "get": {}, "grant": {}, "group": {}, "grouping": {},
	"groups": {}, "having": {}, "high_priority": {}, "hour_microsecond": {},
	"hour_minute": {}, "hour_second": {}, "if": {}, "ignore": {}, "in": {},
	"index": {}, "infile": {}, "inner": {}, "inout": {}, "insensitive": {},
	"insert": {}, "int": {}, "int1": {}, "int2": {}, "int3": {}, "int4": {},
	"int8": {}, "integer": {}, "intersect": {}, "interval": {}, "into": {},
	"io_after_gtids": {}, "io_before_gtids": {}, "is": {}, "iterate": {},
	"join": {}, "json_table": {}, "key": {}, "keys": {}, "kill": {},
	"lag": {}, "last_value": {}, "lateral": {}, "lead": {}, "leading": {},
	"leave": {}, "left": {}, "like": {}, "limit": {}, "linear": {},
	"lines": {}, "load": {}, "localtime": {}, "localtimestamp": {},
	"lock": {}, "long": {}, "longblob": {}, "longtext": {}, "loop": {},
	"low_priority": {}, "master_bind": {},
	"master_ssl_verify_server_cert": {}, "match": {}, "maxvalue": {},
	"mediumblob": {}, "mediumint": {}, "mediumtext": {}, "middleint": {},
	"minute_microsecond": {}, "minute_second": {}, "mod": {}, "modifies": {},
	"natural": {}, "not": {}, "no_write_to_binlog": {}, "nth_value": {},
	"ntile": {}, "null": {}, "numeric": {}, "of": {}, "on": {},
	"optimize": {}, "optimizer_costs": {}, "option": {}, "optionally": {},
	"or": {}, "order": {}, "out": {}, "outer": {}, "outfile": {}, "over": {},
	"partition": {}, "percent_rank": {}, "precision": {}, "primary": {},
	"procedure": {}, "purge": {}, "range": {}, "rank": {}, "read": {},
	"reads": {}, "read_write": {}, "real": {}, "recursive": {},
	"references": {}, "regexp": {}, "release": {}, "rename": {},
	"repeat": {}, "replace": {}, "require": {}, "resignal": {},
	"restrict": {}, "return": {}, "revoke": {}, "right": {}, "rlike": {},
	"row": {}, "rows": {}, "row_number": {}, "schema": {}, "schemas": {},
	"second_microsecond": {}, "select": {}, "sensitive": {},
	"separator": {}, "set": {}, "show": {}, "signal": {}, "smallint": {},
	"spatial": {}, "specific": {}, "sql": {}, "sqlexception": {},
	"sqlstate": {}, "sqlwarning": {}, "sql_big_result": {},
	"sql_calc_found_rows": {}, "sql_small_result": {}, "ssl": {},
	"starting": {}, "stored": {}, "straight_join": {}, "system": {},
	"table": {}, "terminated": {}, "then": {}, "tinyblob": {},
	"tinyint": {}, "tinytext": {}, "to": {}, "trailing": {}, "trigger": {},
	"true": {}, "undo": {}, "union": {}, "unique": {}, "unlock": {},
	"unsigned": {}, "update": {}, "usage": {}, "use": {}, "using": {},
	"utc_date": {}, "utc_time": {}, "utc_timestamp": {}, "values": {},
	"varbinary": {}, "varchar": {}, "varcharacter": {}, "varying": {},
	"virtual": {}, "when": {}, "where": {}, "while": {}, "window": {},
	"with": {}, "write": {}, "xor": {}, "year_month": {}, "zerofill": {},
}

// SQLServer keyword reference:
// https://learn.microsoft.com/en-us/sql/t-sql/language-elements/reserved-keywords-transact-sql?view=sql-server-ver16
var sqlserverKeywords = map[string]struct{}{ "add": {}, "external": {}, "procedure": {}, "all": {}, "fetch": {}, "public": {}, "alter": {}, "file": {}, "raiserror": {}, "and": {}, "fillfactor": {}, "read": {}, "any": {}, "for": {}, "readtext": {}, "as": {}, "foreign": {}, "reconfigure": {}, "asc": {}, "freetext": {}, "references": {}, "authorization": {}, "freetexttable": {}, "replication": {}, "backup": {}, "from": {}, "restore": {}, "begin": {}, "full": {}, "restrict": {}, "between": {}, "function": {}, "return": {}, "break": {}, "goto": {}, "revert": {}, "browse": {}, "grant": {}, "revoke": {}, "bulk": {}, "group": {}, "right": {}, "by": {}, "having": {}, "rollback": {}, "cascade": {}, "holdlock": {}, "rowcount": {}, "case": {}, "identity": {}, "rowguidcol": {}, "check": {}, "identity_insert": {}, "rule": {}, "checkpoint": {}, "identitycol": {}, "save": {}, "close": {}, "if": {}, "schema": {}, "clustered": {}, "in": {}, "securityaudit": {}, "coalesce": {}, "index": {}, "select": {}, "collate": {}, "inner": {}, "semantickeyphrasetable": {}, "column": {}, "insert": {}, "semanticsimilaritydetailstable": {}, "commit": {}, "intersect": {}, "semanticsimilaritytable": {}, "compute": {}, "into": {}, "session_user": {}, "constraint": {}, "is": {}, "set": {}, "contains": {}, "join": {}, "setuser": {}, "containstable": {}, "key": {}, "shutdown": {}, "continue": {}, "kill": {}, "some": {}, "convert": {}, "left": {}, "statistics": {}, "create": {}, "like": {}, "system_user": {}, "cross": {}, "lineno": {}, "table": {}, "current": {}, "load": {}, "tablesample": {}, "current_date": {}, "merge": {}, "textsize": {}, "current_time": {}, "national": {}, "then": {}, "current_timestamp": {}, "nocheck": {}, "to": {}, "current_user": {}, "nonclustered": {}, "top": {}, "cursor": {}, "not": {}, "tran": {}, "database": {}, "null": {}, "transaction": {}, "dbcc": {}, "nullif": {}, "trigger": {}, "deallocate": {}, "of": {}, "truncate": {}, "declare": {}, "off": {}, "try_convert": {}, 
"default": {}, "offsets": {}, "tsequal": {}, "delete": {}, "on": {}, "union": {}, "deny": {}, "open": {}, "unique": {}, "desc": {}, "opendatasource": {}, "unpivot": {}, "disk": {}, "openquery": {}, "update": {}, "distinct": {}, "openrowset": {}, "updatetext": {}, "distributed": {}, "openxml": {}, "use": {}, "double": {}, "option": {}, "user": {}, "drop": {}, "or": {}, "values": {}, "dump": {}, "order": {}, "varying": {}, "else": {}, "outer": {}, "view": {}, "end": {}, "over": {}, "waitfor": {}, "errlvl": {}, "percent": {}, "when": {}, "escape": {}, "pivot": {}, "where": {}, "except": {}, "plan": {}, "while": {}, "exec": {}, "precision": {}, "with": {}, "execute": {}, "primary": {}, "within group": {}, "exists": {}, "print": {}, "writetext": {}, "exit": {}, "proc": {}, } ================================================ FILE: fmt_test.go ================================================ package sq import ( "bytes" "context" "database/sql" "database/sql/driver" "errors" "flag" "strings" "testing" "time" "github.com/bokwoon95/sq/internal/testutil" ) var ( postgresDSN = flag.String("postgres", "", "") mysqlDSN = flag.String("mysql", "", "") sqlserverDSN = flag.String("sqlserver", "", "") ) func TestWritef(t *testing.T) { type TT struct { ctx context.Context dialect string format string values []any wantQuery string wantArgs []any wantParams map[string][]int } assert := func(t *testing.T, tt TT) { if tt.ctx == nil { tt.ctx = context.Background() } buf := new(bytes.Buffer) args := new([]any) params := make(map[string][]int) err := Writef(tt.ctx, tt.dialect, buf, args, params, tt.format, tt.values) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(buf.String(), tt.wantQuery); diff != "" { t.Error(testutil.Callers(), diff) } if len(*args) > 0 || len(tt.wantArgs) > 0 { if diff := testutil.Diff(*args, tt.wantArgs); diff != "" { t.Error(testutil.Callers(), diff) } } if len(params) > 0 || len(tt.wantParams) > 0 { if diff := testutil.Diff(params, 
tt.wantParams); diff != "" { t.Error(testutil.Callers(), diff) } } } t.Run("empty", func(t *testing.T) { t.Parallel() var tt TT tt.format = "" tt.values = []any{} tt.wantQuery = "" tt.wantArgs = []any{} assert(t, tt) }) t.Run("escape curly bracket {{", func(t *testing.T) { t.Parallel() var tt TT tt.format = "SELECT {} = '{{}'" tt.values = []any{"{}"} tt.wantQuery = `SELECT ? = '{}'` tt.wantArgs = []any{"{}"} assert(t, tt) }) t.Run("expr", func(t *testing.T) { t.Parallel() var tt TT tt.format = "(MAX(AVG({one}), AVG({two}), SUM({three})) + {incr}) IN ({slice})" tt.values = []any{ sql.Named("one", tmpfield("user_id")), sql.Named("two", tmpfield("age")), sql.Named("three", tmpfield("age")), sql.Named("incr", 1), sql.Named("slice", []int{1, 2, 3}), } tt.wantQuery = "(MAX(AVG(user_id), AVG(age), SUM(age)) + ?) IN (?, ?, ?)" tt.wantArgs = []any{1, 1, 2, 3} tt.wantParams = map[string][]int{"incr": {0}} assert(t, tt) }) t.Run("Field slice expansion", func(t *testing.T) { t.Parallel() var tt TT tt.format = "SELECT {} FROM {}" tt.values = []any{ []Field{ tmpfield("111.aaa"), tmpfield("222.bbb"), tmpfield("333.ccc"), }, tmptable("public.222"), } tt.wantQuery = `SELECT "111".aaa, "222".bbb, "333".ccc FROM public."222"` assert(t, tt) }) t.Run("params", func(t *testing.T) { t.Parallel() var tt TT tt.format = "{param}, {param}" + ", {array}, {array}" + ", {bytes}, {bytes}" + ", {bool}, {bool}" + ", {enum}, {enum}" + ", {json}, {json}" + ", {int}, {int}" + ", {int64}, {float64}" + ", {string}, {string}" + ", {time}, {time}" + ", {uuid}, {uuid}" tt.values = []any{ Param("param", nil), ArrayParam("array", []int{1, 2, 3}), BytesParam("bytes", []byte{0xFF, 0xFF, 0xFF}), BoolParam("bool", true), EnumParam("enum", Monday), JSONParam("json", map[string]string{"lorem": "ipsum"}), IntParam("int", 5), Int64Param("int64", 7), Float64Param("float64", 11.0), StringParam("string", "lorem ipsum"), TimeParam("time", time.Unix(0, 0)), UUIDParam("uuid", [16]byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 
0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20}), } tt.wantQuery = "?, ?" + ", ?, ?" + ", ?, ?" + ", ?, ?" + ", ?, ?" + ", ?, ?" + ", ?, ?" + ", ?, ?" + ", ?, ?" + ", ?, ?" + ", ?, ?" tt.wantArgs = []any{ nil, nil, "[1,2,3]", "[1,2,3]", []byte{0xFF, 0xFF, 0xFF}, []byte{0xFF, 0xFF, 0xFF}, true, true, "Monday", "Monday", `{"lorem":"ipsum"}`, `{"lorem":"ipsum"}`, 5, 5, int64(7), float64(11.0), "lorem ipsum", "lorem ipsum", time.Unix(0, 0), time.Unix(0, 0), []byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20}, []byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20}, } tt.wantParams = map[string][]int{ "param": {0, 1}, "array": {2, 3}, "bytes": {4, 5}, "bool": {6, 7}, "enum": {8, 9}, "json": {10, 11}, "int": {12, 13}, "int64": {14}, "float64": {15}, "string": {16, 17}, "time": {18, 19}, "uuid": {20, 21}, } assert(t, tt) }) t.Run("duplicate params should error", func(t *testing.T) { t.Parallel() var tt TT tt.format = "{param}, {param}" tt.values = []any{ Param("param", 1), Param("param", 1), } var buf bytes.Buffer var args []any params := make(map[string][]int) format := "{param}, {param}" values := []any{ Param("param", 1), Param("param", 1), } err := Writef(context.Background(), "", &buf, &args, params, format, values) if err == nil { t.Errorf(testutil.Callers() + " expected error but got nil") } }) t.Run("sqlite,postgres QuoteIdentifier", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLite tt.format = "SELECT {}" tt.values = []any{ tmpfield(`"; ""; DROP TABLE users --`), } tt.wantQuery = `SELECT """; ""; DROP TABLE users --"` assert(t, tt) tt.dialect = DialectPostgres assert(t, tt) }) t.Run("sqlite,postgres anonymous params", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLite tt.format = "SELECT {}" + " FROM {}" + " WHERE {} = {}" + " AND {} <> {}" + " AND {} IN ({})" tt.values = []any{ tmpfield("name"), tmptable("users"), 
tmpfield("age"), 5, tmpfield("email"), "bob@email.com", tmpfield("name"), []string{"tom", "dick", "harry"}, } tt.wantQuery = "SELECT name" + " FROM users" + " WHERE age = $1" + " AND email <> $2" + " AND name IN ($3, $4, $5)" tt.wantArgs = []any{5, "bob@email.com", "tom", "dick", "harry"} assert(t, tt) tt.dialect = DialectPostgres assert(t, tt) }) t.Run("sqlite,postgres ordinal params", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLite tt.format = "SELECT {}" + " FROM {}" + " WHERE {} = {5}" + " AND {} <> {5}" + " AND {1} IN ({6})" + " AND {4} IN ({6})" tt.values = []any{ tmpfield("name"), tmptable("users"), tmpfield("age"), tmpfield("email"), "bob@email.com", []string{"tom", "dick", "harry"}, } tt.wantQuery = "SELECT name" + " FROM users" + " WHERE age = $1" + " AND email <> $1" + " AND name IN ($2, $3, $4)" + " AND email IN ($5, $6, $7)" tt.wantArgs = []any{ "bob@email.com", "tom", "dick", "harry", "tom", "dick", "harry", } assert(t, tt) tt.dialect = DialectPostgres assert(t, tt) }) t.Run("sqlite named params", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLite tt.format = "SELECT {}" + " FROM {}" + " WHERE {3} = {age}" + " AND {3} > {6}" + " AND {4} <> {email}" + " AND {1} IN ({names})" + " AND {4} IN ({names})" tt.values = []any{ tmpfield("name"), tmptable("users"), tmpfield("age"), tmpfield("email"), sql.Named("email", "bob@email.com"), sql.Named("age", 5), sql.Named("names", []string{"tom", "dick", "harry"}), } tt.wantQuery = "SELECT name" + " FROM users" + " WHERE age = $age" + " AND age > $age" + " AND email <> $email" + " AND name IN ($3, $4, $5)" + " AND email IN ($6, $7, $8)" tt.wantArgs = []any{ sql.Named("age", 5), sql.Named("email", "bob@email.com"), "tom", "dick", "harry", "tom", "dick", "harry", } tt.wantParams = map[string][]int{"age": {0}, "email": {1}} assert(t, tt) }) t.Run("postgres named params", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectPostgres tt.format = "SELECT {}" + " FROM 
{}" + " WHERE {3} = {age}" + " AND {3} > {6}" + " AND {4} <> {email}" + " AND {1} IN ({names})" + " AND {4} IN ({names})" tt.values = []any{ tmpfield("name"), tmptable("users"), tmpfield("age"), tmpfield("email"), sql.Named("email", "bob@email.com"), sql.Named("age", 5), sql.Named("names", []string{"tom", "dick", "harry"}), } tt.wantQuery = "SELECT name" + " FROM users" + " WHERE age = $1" + " AND age > $1" + " AND email <> $2" + " AND name IN ($3, $4, $5)" + " AND email IN ($6, $7, $8)" tt.wantArgs = []any{ 5, "bob@email.com", "tom", "dick", "harry", "tom", "dick", "harry", } tt.wantParams = map[string][]int{"age": {0}, "email": {1}} assert(t, tt) }) t.Run("sqlite,postgres SQLWriter in named param", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLite tt.format = "SELECT {field} FROM {tbl} WHERE {field} IN ({nums})" tt.values = []any{ sql.Named("nums", []int{1, 2, 3}), sql.Named("tbl", tmptable("public.tbl")), sql.Named("field", tmpfield("tbl.field")), } tt.wantQuery = `SELECT tbl.field FROM public.tbl WHERE tbl.field IN ($1, $2, $3)` tt.wantArgs = []any{1, 2, 3} assert(t, tt) tt.dialect = DialectPostgres assert(t, tt) }) t.Run("mysql QuoteIdentifier", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectMySQL tt.format = "SELECT {}" tt.values = []any{ tmpfield("`; ``; DROP TABLE users --"), } tt.wantQuery = "SELECT ```; ``; DROP TABLE users --`" assert(t, tt) }) t.Run("mysql anonymous params", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectMySQL tt.format = "SELECT {}" + " FROM {}" + " WHERE {} = {}" + " AND {} <> {}" + " AND {} IN ({})" tt.values = []any{ tmpfield("name"), tmptable("users"), tmpfield("age"), 5, tmpfield("email"), "bob@email.com", tmpfield("name"), []string{"tom", "dick", "harry"}, } tt.wantQuery = "SELECT name" + " FROM users" + " WHERE age = ?" + " AND email <> ?" 
+ " AND name IN (?, ?, ?)" tt.wantArgs = []any{5, "bob@email.com", "tom", "dick", "harry"} assert(t, tt) }) t.Run("mysql ordinal params", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectMySQL tt.format = "SELECT {}" + " FROM {}" + " WHERE {} = {5}" + " AND {} <> {5}" + " AND {1} IN ({6})" + " AND {4} IN ({6})" tt.values = []any{ tmpfield("name"), tmptable("users"), tmpfield("age"), tmpfield("email"), "bob@email.com", []string{"tom", "dick", "harry"}, } tt.wantQuery = "SELECT name" + " FROM users" + " WHERE age = ?" + " AND email <> ?" + " AND name IN (?, ?, ?)" + " AND email IN (?, ?, ?)" tt.wantArgs = []any{ "bob@email.com", "bob@email.com", "tom", "dick", "harry", "tom", "dick", "harry", } assert(t, tt) }) t.Run("mysql named params", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectMySQL tt.format = "SELECT {}" + " FROM {}" + " WHERE {3} = {age}" + " AND {3} > {6}" + " AND {4} <> {email}" + " AND {1} IN ({names})" + " AND {4} IN ({names})" tt.values = []any{ tmpfield("name"), tmptable("users"), tmpfield("age"), tmpfield("email"), sql.Named("email", "bob@email.com"), sql.Named("age", 5), sql.Named("names", []string{"tom", "dick", "harry"}), } tt.wantQuery = "SELECT name" + " FROM users" + " WHERE age = ?" + " AND age > ?" + " AND email <> ?" 
+ " AND name IN (?, ?, ?)" + " AND email IN (?, ?, ?)" tt.wantArgs = []any{ 5, 5, "bob@email.com", "tom", "dick", "harry", "tom", "dick", "harry", } tt.wantParams = map[string][]int{"age": {0, 1}, "email": {2}} assert(t, tt) }) t.Run("mysql SQLWriter in named param", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectMySQL tt.format = "SELECT {field} FROM {tbl} WHERE {field} IN ({nums})" tt.values = []any{ sql.Named("nums", []int{1, 2, 3}), sql.Named("tbl", tmptable("public.tbl")), sql.Named("field", tmpfield("tbl.field")), } tt.wantQuery = `SELECT tbl.field FROM public.tbl WHERE tbl.field IN (?, ?, ?)` tt.wantArgs = []any{1, 2, 3} assert(t, tt) }) t.Run("sqlserver QuoteIdentifier", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLServer tt.format = "SELECT {}" tt.values = []any{ tmpfield("]; ]]; DROP TABLE users --"), } tt.wantQuery = "SELECT []]; ]]; DROP TABLE users --]" assert(t, tt) }) t.Run("sqlserver anonymous params", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLServer tt.format = "SELECT {}" + " FROM {}" + " WHERE {} = {}" + " AND {} <> {}" + " AND {} IN ({})" tt.values = []any{ tmpfield("name"), tmptable("users"), tmpfield("age"), 5, tmpfield("email"), "bob@email.com", tmpfield("name"), []string{"tom", "dick", "harry"}, } tt.wantQuery = "SELECT name" + " FROM users" + " WHERE age = @p1" + " AND email <> @p2" + " AND name IN (@p3, @p4, @p5)" tt.wantArgs = []any{5, "bob@email.com", "tom", "dick", "harry"} assert(t, tt) }) t.Run("sqlserver ordinal params", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLServer tt.format = "SELECT {}" + " FROM {}" + " WHERE {} = {5}" + " AND {} <> {5}" + " AND {1} IN ({6})" + " AND {4} IN ({6})" tt.values = []any{ tmpfield("name"), tmptable("users"), tmpfield("age"), tmpfield("email"), "bob@email.com", []string{"tom", "dick", "harry"}, } tt.wantQuery = "SELECT name" + " FROM users" + " WHERE age = @p1" + " AND email <> @p1" + " AND name IN (@p2, @p3, 
@p4)" + " AND email IN (@p5, @p6, @p7)" tt.wantArgs = []any{ "bob@email.com", "tom", "dick", "harry", "tom", "dick", "harry", } assert(t, tt) }) t.Run("sqlserver named params", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLServer tt.format = "SELECT {}" + " FROM {}" + " WHERE {3} = {age}" + " AND {3} > {6}" + " AND {4} <> {email}" + " AND {1} IN ({names})" + " AND {4} IN ({names})" tt.values = []any{ tmpfield("name"), tmptable("users"), tmpfield("age"), tmpfield("email"), sql.Named("email", "bob@email.com"), sql.Named("age", 5), sql.Named("names", []string{"tom", "dick", "harry"}), } tt.wantQuery = "SELECT name" + " FROM users" + " WHERE age = @age" + " AND age > @age" + " AND email <> @email" + " AND name IN (@p3, @p4, @p5)" + " AND email IN (@p6, @p7, @p8)" tt.wantArgs = []any{ sql.Named("age", 5), sql.Named("email", "bob@email.com"), "tom", "dick", "harry", "tom", "dick", "harry", } tt.wantParams = map[string][]int{"age": {0}, "email": {1}} assert(t, tt) }) t.Run("sqlserver SQLWriter in named param", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLServer tt.format = "SELECT {field} FROM {tbl} WHERE {field} IN ({nums})" tt.values = []any{ sql.Named("nums", []int{1, 2, 3}), sql.Named("tbl", tmptable("dbo.tbl")), sql.Named("field", tmpfield("tbl.field")), } tt.wantQuery = `SELECT tbl.field FROM dbo.tbl WHERE tbl.field IN (@p1, @p2, @p3)` tt.wantArgs = []any{1, 2, 3} assert(t, tt) }) t.Run("preprocessValue kicks in for anonymous, ordinal params, named params and slices", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLite tt.format = "SELECT {}, {2}, {foo}, {3}, {bar}" tt.values = []any{ Monday, sql.Named("foo", Tuesday), Wednesday, sql.Named("bar", []Weekday{Thursday, Friday, Saturday}), } tt.wantQuery = "SELECT $1, $foo, $foo, $3, $4, $5, $6" tt.wantArgs = []any{ "Monday", sql.NamedArg{Name: "foo", Value: "Tuesday"}, "Wednesday", "Thursday", "Friday", "Saturday", } tt.wantParams = map[string][]int{"foo": 
{1}} assert(t, tt) }) t.Run("no closing curly brace }", func(t *testing.T) { t.Parallel() var tt TT tt.format = "SELECT {field" buf := new(bytes.Buffer) args := new([]any) params := make(map[string][]int) err := Writef(tt.ctx, tt.dialect, buf, args, params, tt.format, tt.values) if err == nil { t.Error(testutil.Callers(), "expected error but got nil") } }) t.Run("too few values passed in", func(t *testing.T) { t.Parallel() var tt TT tt.format = "SELECT {}, {}, {}, {}" tt.values = []any{1, 2} buf := new(bytes.Buffer) args := new([]any) params := make(map[string][]int) err := Writef(tt.ctx, tt.dialect, buf, args, params, tt.format, tt.values) if err == nil { t.Error(testutil.Callers(), "expected error but got nil") } }) t.Run("anonymous param faulty SQL", func(t *testing.T) { t.Parallel() var tt TT tt.format = "SELECT {}" tt.values = []any{FaultySQL{}} buf := new(bytes.Buffer) args := new([]any) params := make(map[string][]int) err := Writef(tt.ctx, tt.dialect, buf, args, params, tt.format, tt.values) if !errors.Is(err, ErrFaultySQL) { t.Error(testutil.Callers(), "expected ErrFaultySQL but got %v", err) } }) t.Run("ordinal param faulty SQL", func(t *testing.T) { t.Parallel() var tt TT tt.format = "SELECT {1}" tt.values = []any{FaultySQL{}} buf := new(bytes.Buffer) args := new([]any) params := make(map[string][]int) err := Writef(tt.ctx, tt.dialect, buf, args, params, tt.format, tt.values) if !errors.Is(err, ErrFaultySQL) { t.Error(testutil.Callers(), "expected ErrFaultySQL but got %v", err) } }) t.Run("named param faulty SQL", func(t *testing.T) { t.Parallel() var tt TT tt.format = "SELECT {field}" tt.values = []any{sql.Named("field", FaultySQL{})} buf := new(bytes.Buffer) args := new([]any) params := make(map[string][]int) err := Writef(tt.ctx, tt.dialect, buf, args, params, tt.format, tt.values) if !errors.Is(err, ErrFaultySQL) { t.Error(testutil.Callers(), "expected ErrFaultySQL but got %v", err) } }) t.Run("ordinal param out of bounds", func(t *testing.T) { 
t.Parallel() var tt TT tt.format = "SELECT {1}, {2}, {99}" tt.values = []any{1, 2, 3} buf := new(bytes.Buffer) args := new([]any) params := make(map[string][]int) err := Writef(tt.ctx, tt.dialect, buf, args, params, tt.format, tt.values) if err == nil { t.Error(testutil.Callers(), "expected error but got nil") } }) t.Run("nonexistent named param", func(t *testing.T) { t.Parallel() var tt TT tt.format = "SELECT {A}, {B}, {C}" tt.values = []any{ sql.Named("A", 1), sql.Named("B", 2), sql.Named("E", 5), } buf := new(bytes.Buffer) args := new([]any) params := make(map[string][]int) err := Writef(tt.ctx, tt.dialect, buf, args, params, tt.format, tt.values) if err == nil { t.Error(testutil.Callers(), "expected error but got nil") } }) t.Run("expandSlice faulty SQL", func(t *testing.T) { t.Parallel() var tt TT tt.format = "SELECT {}" tt.values = []any{ []Field{tmpfield("name"), tmpfield("age"), FaultySQL{}}, } buf := new(bytes.Buffer) args := new([]any) params := make(map[string][]int) err := Writef(tt.ctx, tt.dialect, buf, args, params, tt.format, tt.values) if !errors.Is(err, ErrFaultySQL) { t.Error(testutil.Callers(), "expected ErrFaultySQL but got %v", err) } }) } func TestSprintf(t *testing.T) { type TT struct { dialect string query string args []any wantString string } assert := func(t *testing.T, tt TT) { gotString, err := Sprintf(tt.dialect, tt.query, tt.args) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(gotString, tt.wantString); diff != "" { t.Error(testutil.Callers(), diff) } } assertNotOK := func(t *testing.T, tt TT) { _, err := Sprintf(tt.dialect, tt.query, tt.args) if err == nil { t.Fatal(testutil.Callers(), "expected error but got nil") } } t.Run("empty", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = "" tt.query = "" tt.args = []any{} tt.wantString = "" assert(t, tt) }) t.Run("insideString, insideIdentifier and escaping single quotes", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = "" tt.query = `SELECT ?` 
+ `, 'do not "rebind" ? ? ?'` + // string `, "do not 'rebind' ? ? ?"` + // identifier `, ?` + `, ?` tt.args = []any{ "normal string", "string with 'quotes' must be escaped", "string with already escaped ''quotes'' except for 'this'", } tt.wantString = `SELECT 'normal string'` + `, 'do not "rebind" ? ? ?'` + `, "do not 'rebind' ? ? ?"` + `, 'string with ''quotes'' must be escaped'` + `, 'string with already escaped ''''quotes'''' except for ''this'''` assert(t, tt) }) t.Run("insideString, insideIdentifier and escaping single quotes (dialect == mysql)", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectMySQL tt.query = `SELECT ?` + `, 'do not "rebind" ? ? ?'` + // string ", `do not \" 'rebind' ? ? ?`" + // identifier ", \"do not ``` 'rebind' ? ? ?\"" + // identifier `, ?` + `, ?` tt.args = []any{ "normal string", "string with 'quotes' must be escaped", "string with already escaped ''quotes'' except for 'this'", } tt.wantString = `SELECT 'normal string'` + `, 'do not "rebind" ? ? ?'` + ", `do not \" 'rebind' ? ? ?`" + ", \"do not ``` 'rebind' ? ? 
?\"" + `, 'string with ''quotes'' must be escaped'` + `, 'string with already escaped ''''quotes'''' except for ''this'''` assert(t, tt) }) t.Run("insideString, insideIdentifier and escaping single quotes (dialect == sqlserver)", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLServer tt.query = `SELECT ?` + `, 'do not [[rebind] @p1 @p2 @name'` + // string ", [do not \" 'rebind' [[[[[@pp]] @p3 @p1]" + // identifier ", \"do not [[[ 'rebind' [[[[[@pp]] @p3 @p1\"" + // identifier `, ?` + `, @p3` tt.args = []any{ "normal string", "string with 'quotes' must be escaped", "string with already escaped ''quotes'' except for 'this'", } tt.wantString = `SELECT 'normal string'` + `, 'do not [[rebind] @p1 @p2 @name'` + ", [do not \" 'rebind' [[[[[@pp]] @p3 @p1]" + ", \"do not [[[ 'rebind' [[[[[@pp]] @p3 @p1\"" + `, 'string with ''quotes'' must be escaped'` + `, 'string with already escaped ''''quotes'''' except for ''this'''` assert(t, tt) }) t.Run("mysql", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectMySQL tt.query = "SELECT name FROM users WHERE age = ? AND email <> ? AND name IN (?, ?, ?)" tt.args = []any{5, "bob@email.com", "tom", "dick", "harry"} tt.wantString = "SELECT name FROM users WHERE age = 5 AND email <> 'bob@email.com' AND name IN ('tom', 'dick', 'harry')" assert(t, tt) }) t.Run("mysql insideString", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectMySQL tt.query = "SELECT name FROM users WHERE age = ? AND email <> '? ? ? ? ''bruh ?' AND name IN (?, ?) ?" tt.args = []any{5, "tom", "dick", "harry"} tt.wantString = "SELECT name FROM users WHERE age = 5 AND email <> '? ? ? ? ''bruh ?' AND name IN ('tom', 'dick') 'harry'" assert(t, tt) }) t.Run("omitted dialect insideString", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = "" tt.query = "SELECT name FROM users WHERE age = ? AND email <> '? ? ? ? ''bruh ?' AND name IN (?, ?) ?" 
tt.args = []any{5, "tom", "dick", "harry"} tt.wantString = "SELECT name FROM users WHERE age = 5 AND email <> '? ? ? ? ''bruh ?' AND name IN ('tom', 'dick') 'harry'" assert(t, tt) }) t.Run("postgres", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectPostgres tt.query = "SELECT name FROM users WHERE age = $1 AND email <> $2 AND name IN ($2, $3, $4, $1)" tt.args = []any{5, "tom", "dick", "harry"} tt.wantString = "SELECT name FROM users WHERE age = 5 AND email <> 'tom' AND name IN ('tom', 'dick', 'harry', 5)" assert(t, tt) }) t.Run("postgres insideString", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectPostgres tt.query = "SELECT name FROM users WHERE age = $1 AND email <> '$2 $2 $3 $4 ''bruh $1' AND name IN ($2, $3) $4" tt.args = []any{5, "tom", "dick", "harry"} tt.wantString = "SELECT name FROM users WHERE age = 5 AND email <> '$2 $2 $3 $4 ''bruh $1' AND name IN ('tom', 'dick') 'harry'" assert(t, tt) }) t.Run("sqlite", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLite tt.query = "SELECT name FROM users WHERE age = $1 AND email <> $2 AND name IN ($2, $3, $4, $1)" tt.args = []any{5, "tom", "dick", "harry"} tt.wantString = "SELECT name FROM users WHERE age = 5 AND email <> 'tom' AND name IN ('tom', 'dick', 'harry', 5)" assert(t, tt) }) t.Run("sqlite insideString", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLite tt.query = "SELECT name FROM users WHERE age = $1 AND email <> '$2 $2 $3 $4 ''bruh $1' AND name IN ($2, $3) $4" tt.args = []any{5, "tom", "dick", "harry"} tt.wantString = "SELECT name FROM users WHERE age = 5 AND email <> '$2 $2 $3 $4 ''bruh $1' AND name IN ('tom', 'dick') 'harry'" assert(t, tt) }) t.Run("sqlite mixing ordinal param and named param", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLite tt.query = "SELECT name FROM users WHERE age = $age AND age > $1 AND email <> $email" tt.args = []any{sql.Named("age", 5), sql.Named("email", "bob@email.com")} 
tt.wantString = "SELECT name FROM users WHERE age = 5 AND age > 5 AND email <> 'bob@email.com'" assert(t, tt) }) t.Run("sqlite supports everything", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLite tt.query = "SELECT name FROM users WHERE age = ?age AND email <> :email AND name IN (@3, ?4, $5, :5) ? ?" tt.args = []any{sql.Named("age", 5), sql.Named("email", "bob@email.com"), "tom", "dick", "harry"} tt.wantString = "SELECT name FROM users WHERE age = 5 AND email <> 'bob@email.com' AND name IN ('tom', 'dick', 'harry', 'harry') 5 'bob@email.com'" assert(t, tt) }) t.Run("sqlserver", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLServer tt.query = "SELECT name FROM users WHERE age = @p1 AND email <> @P2 AND name IN (@p2, @p3, @p4, @P1)" tt.args = []any{5, "tom", "dick", "harry"} tt.wantString = "SELECT name FROM users WHERE age = 5 AND email <> 'tom' AND name IN ('tom', 'dick', 'harry', 5)" assert(t, tt) }) t.Run("sqlserver insideString", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLServer tt.query = "SELECT name FROM users WHERE age = @p1 AND email <> '@p2 @p2 @p3 @p4 ''bruh @p1' AND name IN (@p2, @p3) @p4" tt.args = []any{5, "tom", "dick", "harry"} tt.wantString = "SELECT name FROM users WHERE age = 5 AND email <> '@p2 @p2 @p3 @p4 ''bruh @p1' AND name IN ('tom', 'dick') 'harry'" assert(t, tt) }) t.Run("sqlserver mixing ordinal param and named param", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLServer tt.query = "SELECT name FROM users WHERE age = @age AND age > @p1 AND email <> @email" tt.args = []any{sql.Named("age", 5), sql.Named("email", "bob@email.com")} tt.wantString = "SELECT name FROM users WHERE age = 5 AND age > 5 AND email <> 'bob@email.com'" assert(t, tt) }) t.Run("unclosed string and identifier", func(t *testing.T) { t.Parallel() var tt TT // unclosed string tt.query = `SELECT ?, 'mary had a little', 'lamb` tt.args = []any{1} assertNotOK(t, tt) // unclosed identifier 
tt.query = `SELECT ?, "one", "two", "three` tt.args = []any{2} assertNotOK(t, tt) }) t.Run("sqlite invalid anonymous param", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLite tt.args = []any{23} tt.wantString = "SELECT 23" // ?1 is valid tt.query = "SELECT ?1" assert(t, tt) // ? is valid tt.query = "SELECT ?" assert(t, tt) // $1 is valid tt.query = "SELECT $1" assert(t, tt) // $ is invalid tt.query = "SELECT $" assertNotOK(t, tt) }) t.Run("not enough params", func(t *testing.T) { t.Parallel() var tt TT tt.query = "SELECT ?, ?, ?" tt.args = []any{1, 2} assertNotOK(t, tt) }) t.Run("functions cannot be printed", func(t *testing.T) { t.Parallel() var tt TT tt.query = "SELECT ?, ?" tt.args = []any{1, func() {}} assertNotOK(t, tt) }) t.Run("channels cannot be printed", func(t *testing.T) { t.Parallel() var tt TT tt.query = "SELECT ?, ?" tt.args = []any{make(chan int), 2} assertNotOK(t, tt) }) t.Run("non driver.Valuer types cannot be printed", func(t *testing.T) { t.Parallel() var tt TT tt.query = "SELECT ?, ?" 
tt.args = []any{struct{}{}, any(nil)} assertNotOK(t, tt) }) t.Run("ordinal param out of bounds", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLite tt.query = "SELECT @1, @2, @3" tt.args = []any{1, 2} assertNotOK(t, tt) }) t.Run("dialect that does not support sql.NamedArg", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectPostgres tt.query = "SELECT $test" tt.args = []any{sql.Named("test", 123)} assertNotOK(t, tt) }) t.Run("sql.NamedArg not provided", func(t *testing.T) { t.Parallel() var tt TT tt.dialect = DialectSQLite tt.query = "SELECT :one, :two, :three" tt.args = []any{ sql.Named("one", 1), sql.Named("two", 2), sql.Named("four", 4), } assertNotOK(t, tt) }) } func TestSprint(t *testing.T) { type TT struct { description string dialect string value any wantString string } singaporeLocation, _ := time.LoadLocation("Asia/Singapore") tests := []TT{{ description: "nil", value: nil, wantString: "NULL", }, { description: "true", value: true, wantString: "TRUE", }, { description: "false", value: false, wantString: "FALSE", }, { description: "sqlserver true", dialect: DialectSQLServer, value: true, wantString: "1", }, { description: "sqlserver false", dialect: DialectSQLServer, value: false, wantString: "0", }, { description: "postgres []byte", dialect: DialectPostgres, value: []byte{0xff, 0xff}, wantString: `'\xffff'`, }, { description: "[]byte", value: []byte{0xff, 0xff}, wantString: `x'ffff'`, }, { description: "string", value: "' OR ''test' = '; DROP TABLE users; -- ", wantString: `''' OR ''''test'' = ''; DROP TABLE users; -- '`, }, { description: "time.Time", value: time.Unix(0, 0).UTC(), wantString: `'1970-01-01 00:00:00'`, }, { description: "time.Time (SQLServer)", dialect: DialectSQLServer, value: time.Unix(0, 0).UTC(), wantString: `'1970-01-01 00:00:00+00:00'`, }, { description: "int", value: int(0), wantString: `0`, }, { description: "int8", value: int8(8), wantString: `8`, }, { description: "int16", value: int16(16), 
wantString: `16`, }, { description: "int32", value: int32(32), wantString: `32`, }, { description: "int64", value: int64(64), wantString: `64`, }, { description: "uint", value: uint(0), wantString: `0`, }, { description: "uint8", value: uint8(8), wantString: `8`, }, { description: "uint16", value: uint16(16), wantString: `16`, }, { description: "uint32", value: uint32(32), wantString: `32`, }, { description: "uint64", value: uint64(64), wantString: `64`, }, { description: "float32", value: float32(32.32), wantString: `32.31999969482422`, }, { description: "float64", value: float64(64.6464), wantString: `64.6464`, }, { description: "sql.NamedArg", value: sql.Named("test", 7), wantString: `7`, }, { description: "sql.NullBool NULL", value: sql.NullBool{}, wantString: `NULL`, }, { description: "sql.NullBool true", value: sql.NullBool{Valid: true, Bool: true}, wantString: `TRUE`, }, { description: "sql.NullBool false", value: sql.NullBool{Valid: true, Bool: false}, wantString: `FALSE`, }, { description: "sqlserver sql.NullBool NULL", dialect: DialectSQLServer, value: sql.NullBool{}, wantString: `NULL`, }, { description: "sqlserver sql.NullBool true", dialect: DialectSQLServer, value: sql.NullBool{Valid: true, Bool: true}, wantString: `1`, }, { description: "sqlserver sql.NullBool false", dialect: DialectSQLServer, value: sql.NullBool{Valid: true, Bool: false}, wantString: `0`, }, { description: "sql.NullFloat64 NULL", value: sql.NullFloat64{}, wantString: `NULL`, }, { description: "sql.NullFloat64", value: sql.NullFloat64{Valid: true, Float64: 3.0}, wantString: `3`, }, { description: "sql.NullInt64Field NULL", value: sql.NullInt64{}, wantString: `NULL`, }, { description: "sql.NullInt64Field", value: sql.NullInt64{Valid: true, Int64: 5}, wantString: `5`, }, { description: "sql.NullInt32 NULL", value: sql.NullInt32{}, wantString: `NULL`, }, { description: "sql.NullInt32", value: sql.NullInt32{Valid: true, Int32: 7}, wantString: `7`, }, { description: "sql.NullStringField 
NULL", value: sql.NullString{}, wantString: `NULL`, }, { description: "sql.NullStringField", value: sql.NullString{Valid: true, String: "pp"}, wantString: `'pp'`, }, { description: "sql.NullTimeField NULL", value: sql.NullTime{}, wantString: `NULL`, }, { description: "sql.NullTime", value: sql.NullTime{ Valid: true, Time: time.Unix(0, 0).UTC(), }, wantString: `'1970-01-01 00:00:00'`, }, { description: "sql.NullTime (Postgres)", dialect: DialectPostgres, value: sql.NullTime{ Valid: true, Time: time.Unix(0, 0).UTC(), }, wantString: `'1970-01-01 00:00:00+00:00'`, }, { description: "int64 Valuer", value: driverValuer{int64(3)}, wantString: `3`, }, { description: "float64 Valuer", value: driverValuer{64.6464}, wantString: `64.6464`, }, { description: "bool Valuer 1", value: driverValuer{true}, wantString: `TRUE`, }, { description: "bool Valuer 0", value: driverValuer{false}, wantString: `FALSE`, }, { description: "bytes Valuer", value: driverValuer{[]byte{0xab, 0xba}}, wantString: `x'abba'`, }, { description: "string Valuer", value: driverValuer{`'' ha '; DROP TABLE users; --`}, wantString: `''''' ha ''; DROP TABLE users; --'`, }, { description: "time.Time Valuer", value: driverValuer{time.Unix(0, 0).UTC()}, wantString: `'1970-01-01 00:00:00'`, }, { description: "time.Time Valuer (Postgres)", dialect: DialectPostgres, value: driverValuer{time.Unix(0, 0).UTC()}, wantString: `'1970-01-01 00:00:00+00:00'`, }, { description: "time.Time Valuer (Postgres)", dialect: DialectPostgres, value: driverValuer{time.Unix(22, 330000000).In(singaporeLocation)}, wantString: `'1970-01-01 07:30:22.33+07:30'`, }, { description: "string Valuer ptr", value: &driverValuer{`'' ha '; DROP TABLE users; --`}, wantString: `''''' ha ''; DROP TABLE users; --'`, }, { description: "int ptr", value: func() *int { num := 33 return &num }(), wantString: `33`, }, { description: "nil int ptr", value: func() *int { var num *int return num }(), wantString: `NULL`, }, { description: "string ptr", value: func() 
*string { str := "test string" return &str }(), wantString: `'test string'`, }, { description: "nil string ptr", value: func() *string { var str *string return str }(), wantString: `NULL`, }, { description: "sql.NullInt64 ptr", value: &sql.NullInt64{ Valid: true, Int64: 33, }, wantString: `33`, }, { description: "sql.NullString ptr", value: &sql.NullString{ Valid: true, String: "test string", }, wantString: `'test string'`, }, { description: "mysql string", dialect: DialectMySQL, value: "the quick brown fox", wantString: `'the quick brown fox'`, }, { description: "mysql string newlines in middle", dialect: DialectMySQL, value: "the quick\nbrown\r\nfox", wantString: `CONCAT('the quick', CHAR(10), 'brown', CHAR(13), CHAR(10), 'fox')`, }, { description: "mysql string newlines at end", dialect: DialectMySQL, value: "\nthe quick brown fox\r\n", wantString: `CONCAT(CHAR(10), 'the quick brown fox', CHAR(13), CHAR(10))`, }, { description: "postgres string", dialect: DialectPostgres, value: "the quick brown fox", wantString: `'the quick brown fox'`, }, { description: "postgres string newlines in middle", dialect: DialectPostgres, value: "the quick\nbrown\r\nfox", wantString: `'the quick' || CHR(10) || 'brown' || CHR(13) || CHR(10) || 'fox'`, }, { description: "postgres string newlines at end", dialect: DialectPostgres, value: "\nthe quick brown fox\r\n", wantString: `CHR(10) || 'the quick brown fox' || CHR(13) || CHR(10)`, }, { description: "sql.NullString with newlines", dialect: DialectPostgres, value: sql.NullString{ Valid: true, String: "\rthe quick\nbrown fox\r\n", }, wantString: `CHR(13) || 'the quick' || CHR(10) || 'brown fox' || CHR(13) || CHR(10)`, }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() gotString, err := Sprint(tt.dialect, tt.value) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(gotString, tt.wantString); diff != "" { t.Error(testutil.Callers(), diff) } }) } } type tmptable string 
// Compile-time check that tmptable satisfies the Table interface.
var _ Table = (*tmptable)(nil)

// WriteSQL writes the table name, splitting it on the first '.' into an
// optional schema and a name and quoting each part with QuoteIdentifier.
func (t tmptable) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	schema, name := "", string(t)
	if i := strings.IndexByte(name, '.'); i >= 0 {
		schema, name = name[:i], name[i+1:]
	}
	if schema != "" {
		buf.WriteString(QuoteIdentifier(dialect, schema) + ".")
	}
	buf.WriteString(QuoteIdentifier(dialect, name))
	return nil
}

// GetAlias always returns an empty alias.
func (t tmptable) GetAlias() string { return "" }

// IsTable marks tmptable as a Table.
func (t tmptable) IsTable() {}

// tmpfield is a string-backed Field of the form "name" or "table.name".
type tmpfield string

// Compile-time check that tmpfield satisfies the Field interface.
var _ Field = (*tmpfield)(nil)

// WriteSQL writes the field, splitting it on the first '.' into an optional
// table qualifier and a name and quoting each part with QuoteIdentifier.
func (f tmpfield) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	table, name := "", string(f)
	if i := strings.IndexByte(name, '.'); i >= 0 {
		table, name = name[:i], name[i+1:]
	}
	if table != "" {
		buf.WriteString(QuoteIdentifier(dialect, table) + ".")
	}
	buf.WriteString(QuoteIdentifier(dialect, name))
	return nil
}

// WithPrefix replaces any existing "table." qualifier with prefix. An empty
// prefix returns the bare, unqualified field.
func (f tmpfield) WithPrefix(prefix string) Field {
	body := f
	if i := strings.IndexByte(string(f), '.'); i >= 0 {
		body = f[i+1:]
	}
	if prefix == "" {
		return body
	}
	return tmpfield(prefix + "." + string(body))
}

// GetAlias always returns an empty alias.
func (f tmpfield) GetAlias() string { return "" }

// IsField marks tmpfield as a Field.
func (f tmpfield) IsField() {}

// FaultySQLError is the error type returned by FaultySQL.WriteSQL.
type FaultySQLError struct{}

func (e FaultySQLError) Error() string { return "sql broke" }

// ErrFaultySQL is the sentinel error returned by FaultySQL.WriteSQL.
var ErrFaultySQL error = FaultySQLError{}

// Compile-time check that FaultySQL can stand in for a Query, Table, Field,
// Predicate and Assignment.
var _ interface {
	Query
	Table
	Field
	Predicate
	Assignment
} = (*FaultySQL)(nil)

// FaultySQL is a value whose WriteSQL always fails with ErrFaultySQL.
// Presumably used by tests to exercise error propagation through callers.
type FaultySQL struct{}

// WriteSQL always returns ErrFaultySQL and writes nothing.
func (q FaultySQL) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	return ErrFaultySQL
}

func (q FaultySQL) SetFetchableFields([]Field) (Query, bool) { return nil, false }

func (q FaultySQL) GetFetchableFields() ([]Field, bool) { return nil, false }

func (q FaultySQL) GetAlias() string { return "" }

func (q FaultySQL) GetDialect() string { return "" }

func (q FaultySQL) IsBoolean() {}

func (q FaultySQL) IsTable() {}

func (q FaultySQL) IsField() {}

func (q FaultySQL) IsAssignment() {}

// driverValuer is a driver.Valuer that returns the wrapped value unchanged.
type driverValuer struct{ value any }

func (v driverValuer) Value() (driver.Value, error) { return v.value, nil }

// dialectValuer returns mysqlValuer for DialectMySQL and valuer for every
// other dialect.
type dialectValuer struct {
	mysqlValuer driver.Valuer
	valuer      driver.Valuer
}

func (v dialectValuer) DialectValuer(dialect string) (driver.Valuer, error) {
	if dialect == DialectMySQL {
		return v.mysqlValuer, nil
	}
	return v.valuer, nil
}

================================================
FILE: go.mod
================================================
module github.com/bokwoon95/sq

go 1.19

require (
	github.com/denisenkom/go-mssqldb v0.12.3
	github.com/go-sql-driver/mysql v1.7.1
	github.com/google/go-cmp v0.5.9
	github.com/google/uuid v1.3.0
	github.com/lib/pq v1.10.9
	github.com/mattn/go-sqlite3 v1.14.16
)

require (
	github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
	github.com/golang-sql/sqlexp v0.1.0 // indirect
	golang.org/x/crypto v0.9.0 // indirect
)

================================================
FILE: go.sum
================================================
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0/go.mod h1:h6H6c8enJmmocHUbLiiGY6sx7f9i+X3m1CHdd5c6Rdw=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0/go.mod h1:HcM1YX14R7CJcghJGOYCgdezslRSVzqwLf/q+4Y2r/0= github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0/go.mod h1:yqy467j36fJxcRV2TzfVZ1pCb5vxm4BtZPUdYWe/Xo8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/denisenkom/go-mssqldb v0.12.3 h1:pBSGx9Tq67pBOTLmxNuirNTeB8Vjmf886Kx+8Y+8shw= github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDrorD1Vrm1KEz5hxDo= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/mattn/go-sqlite3 v1.14.16/go.mod 
h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20210610132358-84b48f89b13b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

================================================
FILE: insert_query.go
================================================
package sq

import (
	"bytes"
	"context"
	"fmt"
)

// InsertQuery represents an SQL INSERT query.
type InsertQuery struct {
	// Dialect is reported by GetDialect and decides whether
	// SetFetchableFields/GetFetchableFields use ReturningFields. WriteSQL
	// renders with its own dialect argument.
	Dialect string
	// ColumnMapper, when non-nil, is called by WriteSQL with a fresh Column.
	// The insert columns and row values it collects replace InsertColumns
	// and RowValues.
	ColumnMapper func(*Column)
	// WITH (rejected by WriteSQL when rendering for mysql)
	CTEs []CTE
	// INSERT INTO
	// InsertIgnore renders INSERT IGNORE INTO. WriteSQL returns an error
	// for any dialect other than mysql.
	InsertIgnore bool
	// InsertTable is required; WriteSQL returns an error if it is nil.
	InsertTable   Table
	InsertColumns []Field
	// VALUES
	// RowValues holds one entry per inserted row. When empty, SelectQuery is
	// used instead; WriteSQL requires one of the two.
	RowValues []RowValue
	// RowAlias is written verbatim as "AS <RowAlias>" after VALUES (mysql
	// only).
	RowAlias string
	// SELECT
	SelectQuery Query
	// ON CONFLICT (ON DUPLICATE KEY UPDATE for mysql)
	Conflict ConflictClause
	// RETURNING (written as an OUTPUT clause with INSERTED.-prefixed fields
	// for sqlserver)
	ReturningFields []Field
}

// Compile-time check that InsertQuery satisfies the Query interface.
var _ Query = (*InsertQuery)(nil)

// WriteSQL implements the SQLWriter interface.
func (q InsertQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) (err error) { if q.ColumnMapper != nil { col := &Column{ dialect: q.Dialect, isUpdate: false, } defer mapperFunctionPanicked(&err) q.ColumnMapper(col) if err != nil { return err } q.InsertColumns, q.RowValues = col.insertColumns, col.rowValues } // WITH if len(q.CTEs) > 0 { if dialect == DialectMySQL { return fmt.Errorf("mysql does not support CTEs with INSERT") } err = writeCTEs(ctx, dialect, buf, args, params, q.CTEs) if err != nil { return fmt.Errorf("WITH: %w", err) } } // INSERT INTO if q.InsertIgnore { if dialect != DialectMySQL { return fmt.Errorf("%s does not support INSERT IGNORE", dialect) } buf.WriteString("INSERT IGNORE INTO ") } else { buf.WriteString("INSERT INTO ") } if q.InsertTable == nil { return fmt.Errorf("no table provided to INSERT") } err = q.InsertTable.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("INSERT INTO: %w", err) } if alias := getAlias(q.InsertTable); alias != "" { if dialect == DialectMySQL || dialect == DialectSQLServer { return fmt.Errorf("%s does not allow an alias for the INSERT table", dialect) } buf.WriteString(" AS " + QuoteIdentifier(dialect, alias)) } // Columns if len(q.InsertColumns) > 0 { buf.WriteString(" (") err = writeFieldsWithPrefix(ctx, dialect, buf, args, params, q.InsertColumns, "", false) if err != nil { return fmt.Errorf("INSERT INTO: %w", err) } buf.WriteString(")") } // OUTPUT if len(q.ReturningFields) > 0 && dialect == DialectSQLServer { buf.WriteString(" OUTPUT ") for i, field := range q.ReturningFields { if i > 0 { buf.WriteString(", ") } err = WriteValue(ctx, dialect, buf, args, params, withPrefix(field, "INSERTED")) if err != nil { return err } if alias := getAlias(field); alias != "" { buf.WriteString(" AS " + QuoteIdentifier(dialect, alias)) } } } // VALUES if len(q.RowValues) > 0 { buf.WriteString(" VALUES ") err = 
RowValues(q.RowValues).WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("VALUES: %w", err) } if q.RowAlias != "" { if dialect != DialectMySQL { return fmt.Errorf("%s does not support row aliases", dialect) } buf.WriteString(" AS " + q.RowAlias) } } else if q.SelectQuery != nil { // SELECT buf.WriteString(" ") err = q.SelectQuery.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("SELECT: %w", err) } } else { return fmt.Errorf("InsertQuery missing RowValues and SelectQuery (either one is required)") } // ON CONFLICT err = q.Conflict.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return err } // RETURNING if len(q.ReturningFields) > 0 && dialect != DialectSQLServer { if dialect != DialectPostgres && dialect != DialectSQLite && dialect != DialectMySQL { return fmt.Errorf("%s INSERT does not support RETURNING", dialect) } buf.WriteString(" RETURNING ") err = writeFields(ctx, dialect, buf, args, params, q.ReturningFields, true) if err != nil { return fmt.Errorf("RETURNING: %w", err) } } return nil } // InsertInto creates a new InsertQuery. func InsertInto(table Table) InsertQuery { return InsertQuery{InsertTable: table} } // Columns sets the InsertColumns field of the InsertQuery. func (q InsertQuery) Columns(fields ...Field) InsertQuery { q.InsertColumns = fields return q } // Values sets the RowValues field of the InsertQuery. func (q InsertQuery) Values(values ...any) InsertQuery { q.RowValues = append(q.RowValues, values) return q } // ColumnValues sets the ColumnMapper field of the InsertQuery. func (q InsertQuery) ColumnValues(colmapper func(*Column)) InsertQuery { q.ColumnMapper = colmapper return q } // Select sets the SelectQuery field of the InsertQuery. func (q InsertQuery) Select(query Query) InsertQuery { q.SelectQuery = query return q } // ConflictClause represents an SQL conflict clause e.g. ON CONFLICT DO // NOTHING/DO UPDATE or ON DUPLICATE KEY UPDATE. 
type ConflictClause struct { ConstraintName string Fields []Field Predicate Predicate DoNothing bool Resolution []Assignment ResolutionPredicate Predicate } // WriteSQL implements the SQLWriter interface. func (c ConflictClause) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { var err error if c.ConstraintName == "" && len(c.Fields) == 0 && len(c.Resolution) == 0 && !c.DoNothing { return nil } if dialect != DialectSQLite && dialect != DialectPostgres && dialect != DialectMySQL { return nil } if dialect == DialectMySQL { if len(c.Resolution) > 0 { buf.WriteString(" ON DUPLICATE KEY UPDATE ") err = Assignments(c.Resolution).WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("ON DUPLICATE KEY UPDATE: %w", err) } } return nil } buf.WriteString(" ON CONFLICT") if c.ConstraintName != "" { buf.WriteString(" ON CONSTRAINT " + QuoteIdentifier(dialect, c.ConstraintName)) } else if len(c.Fields) > 0 { buf.WriteString(" (") err = writeFieldsWithPrefix(ctx, dialect, buf, args, params, c.Fields, "", false) if err != nil { return fmt.Errorf("ON CONFLICT: %w", err) } buf.WriteString(")") if c.Predicate != nil { buf.WriteString(" WHERE ") switch predicate := c.Predicate.(type) { case VariadicPredicate: predicate.Toplevel = true err = predicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("ON CONFLICT ... WHERE: %w", err) } default: err = c.Predicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("ON CONFLICT ... 
WHERE: %w", err) } } } } if len(c.Resolution) == 0 || c.DoNothing { buf.WriteString(" DO NOTHING") return nil } buf.WriteString(" DO UPDATE SET ") err = Assignments(c.Resolution).WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("DO UPDATE SET: %w", err) } if c.ResolutionPredicate != nil { buf.WriteString(" WHERE ") switch predicate := c.ResolutionPredicate.(type) { case VariadicPredicate: predicate.Toplevel = true err = predicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("DO UPDATE SET ... WHERE: %w", err) } default: err = c.ResolutionPredicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("DO UPDATE SET ... WHERE: %w", err) } } } return nil } // SetFetchableFields implements the Query interface. func (q InsertQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { switch q.Dialect { case DialectPostgres, DialectSQLite: if len(q.ReturningFields) == 0 { q.ReturningFields = fields return q, true } return q, false default: return q, false } } // GetFetchableFields returns the fetchable fields of the query. func (q InsertQuery) GetFetchableFields() []Field { switch q.Dialect { case DialectPostgres, DialectSQLite: return q.ReturningFields default: return nil } } // GetDialect implements the Query interface. func (q InsertQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the query. func (q InsertQuery) SetDialect(dialect string) InsertQuery { q.Dialect = dialect return q } // SQLiteInsertQuery represents an SQLite INSERT query. type SQLiteInsertQuery InsertQuery var _ Query = (*SQLiteInsertQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q SQLiteInsertQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return InsertQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // InsertInto creates a new SQLiteInsertQuery. 
func (b sqliteQueryBuilder) InsertInto(table Table) SQLiteInsertQuery { return SQLiteInsertQuery{ Dialect: DialectSQLite, CTEs: b.ctes, InsertTable: table, } } // Columns sets the InsertColumns field of the SQLiteInsertQuery. func (q SQLiteInsertQuery) Columns(fields ...Field) SQLiteInsertQuery { q.InsertColumns = fields return q } // Values sets the RowValues field of the SQLiteInsertQuery. func (q SQLiteInsertQuery) Values(values ...any) SQLiteInsertQuery { q.RowValues = append(q.RowValues, values) return q } // ColumnValues sets the ColumnMapper field of the SQLiteInsertQuery. func (q SQLiteInsertQuery) ColumnValues(colmapper func(*Column)) SQLiteInsertQuery { q.ColumnMapper = colmapper return q } // Select sets the SelectQuery field of the SQLiteInsertQuery. func (q SQLiteInsertQuery) Select(query Query) SQLiteInsertQuery { q.SelectQuery = query return q } type sqliteInsertConflict struct{ q *SQLiteInsertQuery } // OnConflict starts the ON CONFLICT clause of the SQLiteInsertQuery. func (q SQLiteInsertQuery) OnConflict(fields ...Field) sqliteInsertConflict { q.Conflict.Fields = fields return sqliteInsertConflict{q: &q} } // Where adds predicates to the ON CONFLICT clause of the SQLiteInsertQuery. func (c sqliteInsertConflict) Where(predicates ...Predicate) sqliteInsertConflict { c.q.Conflict.Predicate = appendPredicates(c.q.Conflict.Predicate, predicates) return c } // DoNothing resolves the ON CONFLICT clause of the SQLiteInsertQuery with DO // NOTHING. func (c sqliteInsertConflict) DoNothing() SQLiteInsertQuery { c.q.Conflict.DoNothing = true return *c.q } // DoUpdateSet resolves the ON CONFLICT CLAUSE of the SQLiteInsertQuery with DO UPDATE SET. func (c sqliteInsertConflict) DoUpdateSet(assignments ...Assignment) SQLiteInsertQuery { c.q.Conflict.Resolution = assignments return *c.q } // Where adds predicates to the DO UPDATE SET clause of the SQLiteInsertQuery. 
func (q SQLiteInsertQuery) Where(predicates ...Predicate) SQLiteInsertQuery { q.Conflict.ResolutionPredicate = appendPredicates(q.Conflict.ResolutionPredicate, predicates) return q } // Returning adds fields to the RETURNING clause of the SQLiteInsertQuery. func (q SQLiteInsertQuery) Returning(fields ...Field) SQLiteInsertQuery { q.ReturningFields = append(q.ReturningFields, fields...) return q } // SetFetchableFields implements the Query interface. func (q SQLiteInsertQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { return InsertQuery(q).SetFetchableFields(fields) } // GetFetchableFields returns the fetchable fields of the query. func (q SQLiteInsertQuery) GetFetchableFields() []Field { return InsertQuery(q).GetFetchableFields() } // GetDialect implements the Query interface. func (q SQLiteInsertQuery) GetDialect() string { return q.Dialect } // SetDialect returns the dialect of the query. func (q SQLiteInsertQuery) SetDialect(dialect string) SQLiteInsertQuery { q.Dialect = dialect return q } // PostgresInsertQuery represents a Postgres INSERT query. type PostgresInsertQuery InsertQuery var _ Query = (*PostgresInsertQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q PostgresInsertQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return InsertQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // InsertInto creates a new PostgresInsertQuery. func (b postgresQueryBuilder) InsertInto(table Table) PostgresInsertQuery { return PostgresInsertQuery{ Dialect: DialectPostgres, CTEs: b.ctes, InsertTable: table, } } // Columns sets the InsertColumns field of the PostgresInsertQuery. func (q PostgresInsertQuery) Columns(fields ...Field) PostgresInsertQuery { q.InsertColumns = fields return q } // Values sets the RowValues field of the PostgresInsertQuery. 
func (q PostgresInsertQuery) Values(values ...any) PostgresInsertQuery { q.RowValues = append(q.RowValues, values) return q } // ColumnValues sets the ColumnMapper field of the PostgresInsertQuery. func (q PostgresInsertQuery) ColumnValues(colmapper func(*Column)) PostgresInsertQuery { q.ColumnMapper = colmapper return q } // Select sets the SelectQuery field of the PostgresInsertQuery. func (q PostgresInsertQuery) Select(query Query) PostgresInsertQuery { q.SelectQuery = query return q } type postgresInsertConflict struct{ q *PostgresInsertQuery } // OnConflict starts the ON CONFLICT clause of the PostgresInsertQuery. func (q PostgresInsertQuery) OnConflict(fields ...Field) postgresInsertConflict { q.Conflict.Fields = fields return postgresInsertConflict{q: &q} } // OnConflict starts the ON CONFLICT clause of the PostgresInsertQuery. func (q PostgresInsertQuery) OnConflictOnConstraint(constraintName string) postgresInsertConflict { q.Conflict.ConstraintName = constraintName return postgresInsertConflict{q: &q} } // Where adds predicates to the ON CONFLICT clause of the PostgresInsertQuery. func (c postgresInsertConflict) Where(predicates ...Predicate) postgresInsertConflict { c.q.Conflict.Predicate = appendPredicates(c.q.Conflict.Predicate, predicates) return c } // DoNothing resolves the ON CONFLICT clause of the PostgresInsertQuery with DO // NOTHING. func (c postgresInsertConflict) DoNothing() PostgresInsertQuery { c.q.Conflict.DoNothing = true return *c.q } // DoUpdateSet resolves the ON CONFLICT CLAUSE of the PostgresInsertQuery with DO UPDATE SET. func (c postgresInsertConflict) DoUpdateSet(assignments ...Assignment) PostgresInsertQuery { c.q.Conflict.Resolution = assignments return *c.q } // Where adds predicates to the DO UPDATE SET clause of the PostgresInsertQuery. 
func (q PostgresInsertQuery) Where(predicates ...Predicate) PostgresInsertQuery { q.Conflict.ResolutionPredicate = appendPredicates(q.Conflict.ResolutionPredicate, predicates) return q } // Returning adds fields to the RETURNING clause of the PostgresInsertQuery. func (q PostgresInsertQuery) Returning(fields ...Field) PostgresInsertQuery { q.ReturningFields = append(q.ReturningFields, fields...) return q } // SetFetchableFields implements the Query interface. func (q PostgresInsertQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { return InsertQuery(q).SetFetchableFields(fields) } // GetFetchableFields returns the fetchable fields of the query. func (q PostgresInsertQuery) GetFetchableFields() []Field { return InsertQuery(q).GetFetchableFields() } // GetDialect implements the Query interface. func (q PostgresInsertQuery) GetDialect() string { return q.Dialect } // SetDialect returns the dialect of the query. func (q PostgresInsertQuery) SetDialect(dialect string) PostgresInsertQuery { q.Dialect = dialect return q } // MySQLInsertQuery represents a MySQL INSERT query. type MySQLInsertQuery InsertQuery var _ Query = (*MySQLInsertQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q MySQLInsertQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return InsertQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // InsertInto creates a new MySQLInsertQuery. func (b mysqlQueryBuilder) InsertInto(table Table) MySQLInsertQuery { return MySQLInsertQuery{ Dialect: DialectMySQL, CTEs: b.ctes, InsertTable: table, } } // InsertInto creates a new MySQLInsertQuery. func (b mysqlQueryBuilder) InsertIgnoreInto(table Table) MySQLInsertQuery { return MySQLInsertQuery{ Dialect: DialectMySQL, CTEs: b.ctes, InsertTable: table, InsertIgnore: true, } } // Columns sets the InsertColumns field of the MySQLInsertQuery. 
func (q MySQLInsertQuery) Columns(fields ...Field) MySQLInsertQuery { q.InsertColumns = fields return q } // Values sets the RowValues field of the MySQLInsertQuery. func (q MySQLInsertQuery) Values(values ...any) MySQLInsertQuery { q.RowValues = append(q.RowValues, values) return q } // As sets the RowAlias field of the MySQLInsertQuery. func (q MySQLInsertQuery) As(rowAlias string) MySQLInsertQuery { q.RowAlias = rowAlias return q } // ColumnValues sets the ColumnMapper field of the MySQLInsertQuery. func (q MySQLInsertQuery) ColumnValues(colmapper func(*Column)) MySQLInsertQuery { q.ColumnMapper = colmapper return q } // Select sets the SelectQuery field of the MySQLInsertQuery. func (q MySQLInsertQuery) Select(query Query) MySQLInsertQuery { q.SelectQuery = query return q } // OnDuplicateKeyUpdate sets the ON DUPLICATE KEY UPDATE clause of the // MySQLInsertQuery. func (q MySQLInsertQuery) OnDuplicateKeyUpdate(assignments ...Assignment) MySQLInsertQuery { q.Conflict.Resolution = assignments return q } // Returning adds fields to the RETURNING clause of the MySQLInsertQuery. func (q MySQLInsertQuery) Returning(fields ...Field) MySQLInsertQuery { q.ReturningFields = append(q.ReturningFields, fields...) return q } // SetFetchableFields implements the Query interface. func (q MySQLInsertQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { return InsertQuery(q).SetFetchableFields(fields) } // GetFetchableFields returns the fetchable fields of the query. func (q MySQLInsertQuery) GetFetchableFields() []Field { return InsertQuery(q).GetFetchableFields() } // GetDialect implements the Query interface. func (q MySQLInsertQuery) GetDialect() string { return q.Dialect } // SetDialect returns the dialect of the query. func (q MySQLInsertQuery) SetDialect(dialect string) MySQLInsertQuery { q.Dialect = dialect return q } // SQLServerInsertQuery represents an SQL Server INSERT query. 
type SQLServerInsertQuery InsertQuery var _ Query = (*SQLServerInsertQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q SQLServerInsertQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return InsertQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // InsertInto creates a new SQLServerInsertQuery. func (b sqlserverQueryBuilder) InsertInto(table Table) SQLServerInsertQuery { return SQLServerInsertQuery{ Dialect: DialectSQLServer, CTEs: b.ctes, InsertTable: table, } } // Columns sets the InsertColumns field of the SQLServerInsertQuery. func (q SQLServerInsertQuery) Columns(fields ...Field) SQLServerInsertQuery { q.InsertColumns = fields return q } // Values sets the RowValues field of the SQLServerInsertQuery. func (q SQLServerInsertQuery) Values(values ...any) SQLServerInsertQuery { q.RowValues = append(q.RowValues, values) return q } // ColumnValues sets the ColumnMapper field of the SQLServerInsertQuery. func (q SQLServerInsertQuery) ColumnValues(colmapper func(*Column)) SQLServerInsertQuery { q.ColumnMapper = colmapper return q } // Select sets the SelectQuery field of the SQLServerInsertQuery. func (q SQLServerInsertQuery) Select(query Query) SQLServerInsertQuery { q.SelectQuery = query return q } // SetFetchableFields implements the Query interface. func (q SQLServerInsertQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { return InsertQuery(q).SetFetchableFields(fields) } // GetFetchableFields returns the fetchable fields of the query. func (q SQLServerInsertQuery) GetFetchableFields() []Field { return InsertQuery(q).GetFetchableFields() } // GetDialect implements the Query interface. func (q SQLServerInsertQuery) GetDialect() string { return q.Dialect } // SetDialect returns the dialect of the query. 
func (q SQLServerInsertQuery) SetDialect(dialect string) SQLServerInsertQuery { q.Dialect = dialect return q } ================================================ FILE: insert_query_test.go ================================================ package sq import ( "testing" "github.com/bokwoon95/sq/internal/testutil" ) func TestSQLiteInsertQuery(t *testing.T) { type ACTOR struct { TableStruct ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField } a := New[ACTOR]("a") t.Run("basic", func(t *testing.T) { t.Parallel() q1 := SQLite.InsertInto(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum") if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } q1 = q1.SetDialect(DialectSQLite) fields := q1.GetFetchableFields() if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" { t.Error(testutil.Callers(), diff) } _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } q1.ReturningFields = q1.ReturningFields[:0] _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) if !ok { t.Fatal(testutil.Callers(), "field should have been set") } }) t.Run("Columns Values", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLite. With(NewCTE("cte", nil, Queryf("SELECT 1"))). InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Values("bob", "the builder"). Values("alice", "in wonderland") tt.wantQuery = "WITH cte AS (SELECT 1)" + " INSERT INTO actor AS a (first_name, last_name)" + " VALUES ($1, $2), ($3, $4)" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) t.Run("ColumnValues", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLite. With(NewCTE("cte", nil, Queryf("SELECT 1"))). InsertInto(a). 
ColumnValues(func(col *Column) { // bob col.SetString(a.FIRST_NAME, "bob") col.SetString(a.LAST_NAME, "the builder") // alice col.SetString(a.FIRST_NAME, "alice") col.SetString(a.LAST_NAME, "in wonderland") }) tt.wantQuery = "WITH cte AS (SELECT 1)" + " INSERT INTO actor AS a (first_name, last_name)" + " VALUES ($1, $2), ($3, $4)" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) t.Run("Insert Returning", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLite. With(NewCTE("cte", nil, Queryf("SELECT 1"))). InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Select(SQLite.Select(a.FIRST_NAME, a.LAST_NAME).From(a)). Returning(a.ACTOR_ID) tt.wantQuery = "WITH cte AS (SELECT 1)" + " INSERT INTO actor AS a (first_name, last_name)" + " SELECT a.first_name, a.last_name FROM actor AS a" + " RETURNING a.actor_id" tt.assert(t) }) t.Run("OnConflict DoNothing", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLite. With(NewCTE("cte", nil, Queryf("SELECT 1"))). InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Values("bob", "the builder"). Values("alice", "in wonderland"). OnConflict(a.FIRST_NAME, a.LAST_NAME). Where(And(a.FIRST_NAME.IsNotNull(), a.LAST_NAME.IsNotNull())). DoNothing() tt.wantQuery = "WITH cte AS (SELECT 1)" + " INSERT INTO actor AS a (first_name, last_name)" + " VALUES ($1, $2), ($3, $4)" + " ON CONFLICT (first_name, last_name)" + " WHERE a.first_name IS NOT NULL AND a.last_name IS NOT NULL" + " DO NOTHING" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) t.Run("OnConflict DoUpdateSet", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLite. With(NewCTE("cte", nil, Queryf("SELECT 1"))). InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Values("bob", "the builder"). Values("alice", "in wonderland"). OnConflict(a.FIRST_NAME, a.LAST_NAME). 
DoUpdateSet( a.FIRST_NAME.Set(a.FIRST_NAME.WithPrefix("EXCLUDED")), a.LAST_NAME.Set(a.LAST_NAME.WithPrefix("EXCLUDED")), ). Where(And(a.FIRST_NAME.IsNotNull(), a.LAST_NAME.IsNotNull())) tt.wantQuery = "WITH cte AS (SELECT 1)" + " INSERT INTO actor AS a (first_name, last_name)" + " VALUES ($1, $2), ($3, $4)" + " ON CONFLICT (first_name, last_name)" + " DO UPDATE SET first_name = EXCLUDED.first_name, last_name = EXCLUDED.last_name" + " WHERE a.first_name IS NOT NULL AND a.last_name IS NOT NULL" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) } func TestPostgresInsertQuery(t *testing.T) { type ACTOR struct { TableStruct ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField } a := New[ACTOR]("a") t.Run("basic", func(t *testing.T) { t.Parallel() q1 := Postgres.InsertInto(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum") if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } q1 = q1.SetDialect(DialectPostgres) fields := q1.GetFetchableFields() if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" { t.Error(testutil.Callers(), diff) } _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } q1.ReturningFields = q1.ReturningFields[:0] _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) if !ok { t.Fatal(testutil.Callers(), "field should have been set") } }) t.Run("Columns Values", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = Postgres. With(NewCTE("cte", nil, Queryf("SELECT 1"))). InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Values("bob", "the builder"). 
Values("alice", "in wonderland") tt.wantQuery = "WITH cte AS (SELECT 1)" + " INSERT INTO actor AS a (first_name, last_name)" + " VALUES ($1, $2), ($3, $4)" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) t.Run("ColumnValues", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = Postgres. With(NewCTE("cte", nil, Queryf("SELECT 1"))). InsertInto(a). ColumnValues(func(col *Column) { // bob col.SetString(a.FIRST_NAME, "bob") col.SetString(a.LAST_NAME, "the builder") // alice col.SetString(a.FIRST_NAME, "alice") col.SetString(a.LAST_NAME, "in wonderland") }) tt.wantQuery = "WITH cte AS (SELECT 1)" + " INSERT INTO actor AS a (first_name, last_name)" + " VALUES ($1, $2), ($3, $4)" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) t.Run("Insert Returning", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = Postgres. With(NewCTE("cte", nil, Queryf("SELECT 1"))). InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Select(Postgres.Select(a.FIRST_NAME, a.LAST_NAME).From(a)). Returning(a.ACTOR_ID) tt.wantQuery = "WITH cte AS (SELECT 1)" + " INSERT INTO actor AS a (first_name, last_name)" + " SELECT a.first_name, a.last_name FROM actor AS a" + " RETURNING a.actor_id" tt.assert(t) }) t.Run("OnConflict DoNothing", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = Postgres. With(NewCTE("cte", nil, Queryf("SELECT 1"))). InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Values("bob", "the builder"). Values("alice", "in wonderland"). OnConflict(a.FIRST_NAME, a.LAST_NAME). Where(And(a.FIRST_NAME.IsNotNull(), a.LAST_NAME.IsNotNull())). 
DoNothing() tt.wantQuery = "WITH cte AS (SELECT 1)" + " INSERT INTO actor AS a (first_name, last_name)" + " VALUES ($1, $2), ($3, $4)" + " ON CONFLICT (first_name, last_name)" + " WHERE a.first_name IS NOT NULL AND a.last_name IS NOT NULL" + " DO NOTHING" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) t.Run("OnConflictOnConstraint DoUpdateSet", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = Postgres. With(NewCTE("cte", nil, Queryf("SELECT 1"))). InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Values("bob", "the builder"). Values("alice", "in wonderland"). OnConflictOnConstraint("actor_first_name_last_name_key"). DoUpdateSet( a.FIRST_NAME.Set(a.FIRST_NAME.WithPrefix("EXCLUDED")), a.LAST_NAME.Set(a.LAST_NAME.WithPrefix("EXCLUDED")), ). Where(And(a.FIRST_NAME.IsNotNull(), a.LAST_NAME.IsNotNull())) tt.wantQuery = "WITH cte AS (SELECT 1)" + " INSERT INTO actor AS a (first_name, last_name)" + " VALUES ($1, $2), ($3, $4)" + " ON CONFLICT ON CONSTRAINT actor_first_name_last_name_key" + " DO UPDATE SET first_name = EXCLUDED.first_name, last_name = EXCLUDED.last_name" + " WHERE a.first_name IS NOT NULL AND a.last_name IS NOT NULL" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) } func TestMySQLInsertQuery(t *testing.T) { type ACTOR struct { TableStruct ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField } a := New[ACTOR]("") t.Run("basic", func(t *testing.T) { t.Parallel() q1 := MySQL.InsertInto(a).SetDialect("lorem ipsum") if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } q1 = q1.SetDialect(DialectMySQL) fields := q1.GetFetchableFields() if len(fields) != 0 { t.Error(testutil.Callers(), "expected 0 fields but got %v", fields) } _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } q1.ReturningFields = q1.ReturningFields[:0] 
_, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } }) t.Run("Columns Values", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = MySQL. InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Values("bob", "the builder"). Values("alice", "in wonderland") tt.wantQuery = "INSERT INTO actor (first_name, last_name)" + " VALUES (?, ?), (?, ?)" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) t.Run("ColumnValues", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = MySQL. InsertInto(a). ColumnValues(func(col *Column) { // bob col.SetString(a.FIRST_NAME, "bob") col.SetString(a.LAST_NAME, "the builder") // alice col.SetString(a.FIRST_NAME, "alice") col.SetString(a.LAST_NAME, "in wonderland") }) tt.wantQuery = "INSERT INTO actor (first_name, last_name)" + " VALUES (?, ?), (?, ?)" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) t.Run("Insert Returning", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = MySQL. InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Select(SQLite.Select(a.FIRST_NAME, a.LAST_NAME).From(a)). Returning(a.ACTOR_ID) tt.wantQuery = "INSERT INTO actor (first_name, last_name)" + " SELECT actor.first_name, actor.last_name FROM actor" + " RETURNING actor.actor_id" tt.assert(t) }) t.Run("Select InsertIgnore", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = MySQL. InsertIgnoreInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Select(MySQL.Select(a.FIRST_NAME, a.LAST_NAME).From(a)) tt.wantQuery = "INSERT IGNORE INTO actor (first_name, last_name)" + " SELECT actor.first_name, actor.last_name FROM actor" tt.assert(t) }) t.Run("OnDuplicateKey DoNothing", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = MySQL. InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Values("bob", "the builder"). Values("alice", "in wonderland"). 
OnDuplicateKeyUpdate(a.FIRST_NAME.Set(a.FIRST_NAME), a.LAST_NAME.Set(a.LAST_NAME)) tt.wantQuery = "INSERT INTO actor (first_name, last_name)" + " VALUES (?, ?), (?, ?)" + " ON DUPLICATE KEY UPDATE actor.first_name = actor.first_name, actor.last_name = actor.last_name" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) t.Run("OnDuplicateKey", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = MySQL. InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Values("bob", "the builder"). Values("alice", "in wonderland"). As("new"). OnDuplicateKeyUpdate( a.FIRST_NAME.Set(a.FIRST_NAME.WithPrefix("new")), a.LAST_NAME.Set(a.LAST_NAME.WithPrefix("new")), ) tt.wantQuery = "INSERT INTO actor (first_name, last_name)" + " VALUES (?, ?), (?, ?) AS new" + " ON DUPLICATE KEY UPDATE actor.first_name = new.first_name, actor.last_name = new.last_name" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) } func TestSQLServerInsertQuery(t *testing.T) { type ACTOR struct { TableStruct ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField } a := New[ACTOR]("") t.Run("basic", func(t *testing.T) { t.Parallel() q1 := SQLServer.InsertInto(a).SetDialect("lorem ipsum") if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } q1 = q1.SetDialect(DialectSQLServer) fields := q1.GetFetchableFields() if len(fields) != 0 { t.Error(testutil.Callers(), "expected 0 fields but got %v", fields) } _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } q1.ReturningFields = q1.ReturningFields[:0] _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } }) t.Run("Columns Values", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLServer. With(NewCTE("cte", nil, Queryf("SELECT 1"))). InsertInto(a). 
Columns(a.FIRST_NAME, a.LAST_NAME). Values("bob", "the builder"). Values("alice", "in wonderland") tt.wantQuery = "WITH cte AS (SELECT 1)" + " INSERT INTO actor (first_name, last_name)" + " VALUES (@p1, @p2), (@p3, @p4)" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) t.Run("ColumnValues", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLServer. With(NewCTE("cte", nil, Queryf("SELECT 1"))). InsertInto(a). ColumnValues(func(col *Column) { // bob col.SetString(a.FIRST_NAME, "bob") col.SetString(a.LAST_NAME, "the builder") // alice col.SetString(a.FIRST_NAME, "alice") col.SetString(a.LAST_NAME, "in wonderland") }) tt.wantQuery = "WITH cte AS (SELECT 1)" + " INSERT INTO actor (first_name, last_name)" + " VALUES (@p1, @p2), (@p3, @p4)" tt.wantArgs = []any{"bob", "the builder", "alice", "in wonderland"} tt.assert(t) }) t.Run("Select", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLServer. With(NewCTE("cte", nil, Queryf("SELECT 1"))). InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). 
Select(SQLServer.Select(a.FIRST_NAME, a.LAST_NAME).From(a)) tt.wantQuery = "WITH cte AS (SELECT 1)" + " INSERT INTO actor (first_name, last_name)" + " SELECT actor.first_name, actor.last_name FROM actor" tt.assert(t) }) } func TestInsertQuery(t *testing.T) { t.Run("basic", func(t *testing.T) { t.Parallel() q1 := InsertQuery{InsertTable: Expr("tbl"), Dialect: "lorem ipsum"} if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } }) f1, f2, f3 := Expr("f1"), Expr("f2"), Expr("f3") colmapper := func(col *Column) { col.Set(f1, 1) col.Set(f2, 2) col.Set(f3, 3) } notOKTests := []TestTable{{ description: "mysql does not support CTEs with INSERT", item: InsertQuery{ Dialect: DialectMySQL, CTEs: []CTE{NewCTE("cte", nil, Queryf("SELECT 1"))}, InsertTable: Expr("tbl"), ColumnMapper: colmapper, }, }, { description: "dialect does not support INSERT IGNORE", item: InsertQuery{ Dialect: DialectPostgres, InsertTable: Expr("tbl"), InsertIgnore: true, ColumnMapper: colmapper, }, }, { description: "nil IntoTable not allowed", item: InsertQuery{ InsertTable: nil, ColumnMapper: colmapper, }, }, { description: "dialect does not support IntoTable alias", item: InsertQuery{ Dialect: DialectMySQL, InsertTable: Expr("tbl").As("t"), ColumnMapper: colmapper, }, }, { description: "nil Field in InsertColumns not allowed", item: InsertQuery{ Dialect: DialectMySQL, InsertTable: Expr("tbl"), InsertColumns: Fields{nil}, RowValues: RowValues{{1}}, }, }, { description: "dialect does not support row alias", item: InsertQuery{ Dialect: DialectPostgres, InsertTable: Expr("tbl"), ColumnMapper: colmapper, RowAlias: "new", }, }, { description: "missing both Values and Select not allowed (either one is required)", item: InsertQuery{ InsertTable: Expr("tbl"), }, }, { description: "missing both Values and Select not allowed (either one is required)", item: InsertQuery{ InsertTable: Expr("tbl"), }, }, { description: "nil Field in ConflictFields not allowed", 
item: InsertQuery{ Dialect: DialectPostgres, InsertTable: Expr("tbl"), ColumnMapper: colmapper, Conflict: ConflictClause{Fields: Fields{nil}}, }, }} for _, tt := range notOKTests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assertNotOK(t) }) } errTests := []TestTable{{ description: "ColumnMapper err", item: InsertQuery{ InsertTable: Expr("tbl"), ColumnMapper: func(*Column) { panic(ErrFaultySQL) }, }, }, { description: "CTEs err", item: InsertQuery{ CTEs: []CTE{NewCTE("cte", nil, FaultySQL{})}, InsertTable: Expr("tbl"), ColumnMapper: colmapper, }, }, { description: "IntoTable err", item: InsertQuery{ InsertTable: FaultySQL{}, ColumnMapper: colmapper, }, }, { description: "RowValues err", item: InsertQuery{ InsertTable: Expr("tbl"), InsertColumns: Fields{f1, f2, f3}, RowValues: RowValues{{1, 2, FaultySQL{}}}, }, }, { description: "SelectQuery err", item: InsertQuery{ InsertTable: Expr("tbl"), InsertColumns: Fields{f1, f2, f3}, SelectQuery: Queryf("SELECT 1, 2, {}", FaultySQL{}), }, }, { description: "ConflictPredicate VariadicPredicate err", item: InsertQuery{ Dialect: DialectPostgres, InsertTable: Expr("tbl"), ColumnMapper: colmapper, Conflict: ConflictClause{ Fields: Fields{f1, f2}, Predicate: And(FaultySQL{}), }, }, }, { description: "ConflictPredicate err", item: InsertQuery{ Dialect: DialectPostgres, InsertTable: Expr("tbl"), ColumnMapper: colmapper, Conflict: ConflictClause{ Fields: Fields{f1, f2}, Predicate: FaultySQL{}, }, }, }, { description: "Resolution err", item: InsertQuery{ Dialect: DialectPostgres, InsertTable: Expr("tbl"), ColumnMapper: colmapper, Conflict: ConflictClause{ Fields: Fields{f1, f2}, Resolution: Assignments{FaultySQL{}}, }, }, }, { description: "ResolutionPredicate VariadicPredicate err", item: InsertQuery{ Dialect: DialectPostgres, InsertTable: Expr("tbl"), ColumnMapper: colmapper, Conflict: ConflictClause{ Fields: Fields{f1, f2}, Resolution: Assignments{FaultySQL{}}, ResolutionPredicate: And(FaultySQL{}), }, }, 
}, { description: "ResolutionPredicate err", item: InsertQuery{ Dialect: DialectPostgres, InsertTable: Expr("tbl"), ColumnMapper: colmapper, Conflict: ConflictClause{ Fields: Fields{f1, f2}, Resolution: Assignments{Set(f1, f1)}, ResolutionPredicate: FaultySQL{}, }, }, }, { description: "mysql Resolution err", item: InsertQuery{ Dialect: DialectMySQL, InsertTable: Expr("tbl"), ColumnMapper: colmapper, Conflict: ConflictClause{ Resolution: Assignments{Set(f1, FaultySQL{})}, }, }, }, { description: "ReturningFields err", item: InsertQuery{ Dialect: DialectPostgres, InsertTable: Expr("tbl"), ColumnMapper: colmapper, ReturningFields: Fields{FaultySQL{}}, }, }} for _, tt := range errTests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assertErr(t, ErrFaultySQL) }) } } ================================================ FILE: integration_test.go ================================================ package sq import ( "database/sql" "fmt" "net/url" "strings" "testing" "time" "github.com/bokwoon95/sq/internal/testutil" _ "github.com/denisenkom/go-mssqldb" _ "github.com/go-sql-driver/mysql" "github.com/google/uuid" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) type Color int const ( ColorInvalid Color = iota ColorRed ColorGreen ColorBlue ) var colorNames = [...]string{ ColorInvalid: "", ColorRed: "red", ColorGreen: "green", ColorBlue: "blue", } func (c Color) Enumerate() []string { return colorNames[:] } type Direction string const ( DirectionInvalid = Direction("") DirectionNorth = Direction("north") DirectionSouth = Direction("south") DirectionEast = Direction("east") DirectionWest = Direction("west") ) func (d Direction) Enumerate() []string { return []string{ string(DirectionInvalid), string(DirectionNorth), string(DirectionSouth), string(DirectionEast), string(DirectionWest), } } func TestRow(t *testing.T) { type TestTable struct { dialect string driver string dsn string teardown string setup string } tests := []TestTable{{ dialect: 
DialectSQLite, driver: "sqlite3", dsn: "file:/TestRow/sqlite?vfs=memdb&_foreign_keys=true", teardown: "DROP TABLE IF EXISTS table00;", setup: "CREATE TABLE table00 (" + "\n uuid UUID" + "\n ,data JSON" + "\n ,color TEXT" + "\n ,direction TEXT" + "\n ,weekday TEXT" + "\n ,text_array JSON" + "\n ,int_array JSON" + "\n ,int64_array JSON" + "\n ,int32_array JSON" + "\n ,float64_array JSON" + "\n ,float32_array JSON" + "\n ,bool_array JSON" + "\n ,bytes BLOB" + "\n ,is_active BOOLEAN" + "\n ,price REAL" + "\n ,score BIGINT" + "\n ,name TEXT" + "\n ,updated_at DATETIME" + "\n);", }, { dialect: DialectPostgres, driver: "postgres", dsn: *postgresDSN, teardown: "DROP TABLE IF EXISTS table00;" + "\nDROP TYPE IF EXISTS direction;" + "\nDROP TYPE IF EXISTS color;" + "\nDROP TYPE IF EXISTS weekday;", setup: "CREATE TYPE color AS ENUM ('red', 'green', 'blue');" + "\nCREATE TYPE direction AS ENUM ('north', 'south', 'east', 'west');" + "\nCREATE TYPE weekday AS ENUM ('Sunday', 'Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday');" + "\nCREATE TABLE table00 (" + "\n uuid UUID" + "\n ,data JSONB" + "\n ,color color" + "\n ,direction direction" + "\n ,weekday weekday" + "\n ,text_array TEXT[]" + "\n ,int_array BIGINT[]" + "\n ,int64_array BIGINT[]" + "\n ,int32_array INT[]" + "\n ,float64_array DOUBLE PRECISION[]" + "\n ,float32_array REAL[]" + "\n ,bool_array BOOLEAN[]" + "\n ,is_active BOOLEAN" + "\n ,bytes BYTEA" + "\n ,price DOUBLE PRECISION" + "\n ,score BIGINT" + "\n ,name TEXT" + "\n ,updated_at TIMESTAMPTZ" + "\n);", }, { dialect: DialectMySQL, driver: "mysql", dsn: *mysqlDSN, teardown: "DROP TABLE IF EXISTS table00;", setup: "CREATE TABLE table00 (" + "\n uuid BINARY(16)" + "\n ,data JSON" + "\n ,color VARCHAR(255)" + "\n ,direction VARCHAR(255)" + "\n ,weekday VARCHAR(255)" + "\n ,text_array JSON" + "\n ,int_array JSON" + "\n ,int64_array JSON" + "\n ,int32_array JSON" + "\n ,float64_array JSON" + "\n ,float32_array JSON" + "\n ,bool_array JSON" + "\n 
,is_active BOOLEAN" + "\n ,bytes LONGBLOB" + "\n ,price DOUBLE PRECISION" + "\n ,score BIGINT" + "\n ,name TEXT" + "\n ,updated_at DATETIME" + "\n);", }, { dialect: DialectSQLServer, driver: "sqlserver", dsn: *sqlserverDSN, teardown: "DROP TABLE IF EXISTS table00;", setup: "CREATE TABLE table00 (" + "\n uuid BINARY(16)" + "\n ,data NVARCHAR(MAX)" + "\n ,color NVARCHAR(255)" + "\n ,direction NVARCHAR(255)" + "\n ,weekday NVARCHAR(255)" + "\n ,text_array NVARCHAR(MAX)" + "\n ,int_array NVARCHAR(MAX)" + "\n ,int64_array NVARCHAR(MAX)" + "\n ,int32_array NVARCHAR(MAX)" + "\n ,float64_array NVARCHAR(MAX)" + "\n ,float32_array NVARCHAR(MAX)" + "\n ,bool_array NVARCHAR(MAX)" + "\n ,is_active BIT" + "\n ,bytes VARBINARY(MAX)" + "\n ,price DOUBLE PRECISION" + "\n ,score BIGINT" + "\n ,name NVARCHAR(255)" + "\n ,updated_at DATETIME" + "\n);", }} var TABLE00 = New[struct { TableStruct `sq:"table00"` UUID UUIDField DATA JSONField COLOR EnumField DIRECTION EnumField WEEKDAY EnumField TEXT_ARRAY ArrayField INT_ARRAY ArrayField INT64_ARRAY ArrayField INT32_ARRAY ArrayField FLOAT64_ARRAY ArrayField FLOAT32_ARRAY ArrayField BOOL_ARRAY ArrayField BYTES BinaryField IS_ACTIVE BooleanField PRICE NumberField SCORE NumberField NAME StringField UPDATED_AT TimeField }]("") type Table00 struct { uuid uuid.UUID data any color Color direction Direction weekday Weekday textArray []string intArray []int int64Array []int64 int32Array []int32 float64Array []float64 float32Array []float32 boolArray []bool bytes []byte isActive bool price float64 score int64 name string updatedAt time.Time } var table00Values = []Table00{{ uuid: uuid.UUID([16]byte{15: 1}), data: map[string]any{"lorem ipsum": "dolor sit amet"}, color: ColorRed, direction: DirectionNorth, weekday: Monday, textArray: []string{"one", "two", "three"}, intArray: []int{1, 2, 3}, int64Array: []int64{1, 2, 3}, int32Array: []int32{1, 2, 3}, float64Array: []float64{1, 2, 3}, float32Array: []float32{1, 2, 3}, boolArray: []bool{true, false, 
false}, bytes: []byte{1, 2, 3}, isActive: true, price: 123, score: 123, name: "one two three", updatedAt: time.Unix(123, 0).UTC(), }, { uuid: uuid.UUID([16]byte{15: 2}), data: map[string]any{"lorem ipsum": "dolor sit amet"}, color: ColorGreen, direction: DirectionSouth, weekday: Tuesday, textArray: []string{"four", "five", "six"}, intArray: []int{4, 5, 6}, int64Array: []int64{4, 5, 6}, int32Array: []int32{4, 5, 6}, float64Array: []float64{4, 5, 6}, float32Array: []float32{4, 5, 6}, boolArray: []bool{false, true, false}, bytes: []byte{4, 5, 6}, isActive: true, price: 456, score: 456, name: "four five six", updatedAt: time.Unix(456, 0).UTC(), }, { uuid: uuid.UUID([16]byte{15: 3}), data: map[string]any{"lorem ipsum": "dolor sit amet"}, color: ColorBlue, direction: DirectionEast, weekday: Wednesday, textArray: []string{"seven", "eight", "nine"}, intArray: []int{7, 8, 9}, float64Array: []float64{7, 8, 9}, boolArray: []bool{false, false, true}, bytes: []byte{7, 8, 9}, isActive: true, price: 789, score: 789, name: "seven eight nine", updatedAt: time.Unix(789, 0).UTC(), }} for _, tt := range tests { tt := tt t.Run(tt.dialect, func(t *testing.T) { if tt.dsn == "" { return } t.Parallel() dsn := preprocessDSN(tt.dialect, tt.dsn) db, err := sql.Open(tt.driver, dsn) if err != nil { t.Fatal(testutil.Callers(), err) } _, err = db.Exec(tt.teardown) if err != nil { t.Fatal(testutil.Callers(), err) } _, err = db.Exec(tt.setup) if err != nil { t.Fatal(testutil.Callers(), err) } defer func() { db.Exec(tt.teardown) }() // Insert the data. result, err := Exec(Log(db), InsertInto(TABLE00). 
ColumnValues(func(col *Column) { for _, value := range table00Values { col.SetUUID(TABLE00.UUID, value.uuid) col.SetJSON(TABLE00.DATA, value.data) col.SetEnum(TABLE00.COLOR, value.color) col.SetEnum(TABLE00.DIRECTION, value.direction) col.SetEnum(TABLE00.WEEKDAY, value.weekday) col.SetArray(TABLE00.TEXT_ARRAY, value.textArray) col.SetArray(TABLE00.INT_ARRAY, value.intArray) col.SetArray(TABLE00.INT64_ARRAY, value.int64Array) col.SetArray(TABLE00.INT32_ARRAY, value.int32Array) col.SetArray(TABLE00.FLOAT64_ARRAY, value.float64Array) col.SetArray(TABLE00.FLOAT32_ARRAY, value.float32Array) col.SetArray(TABLE00.BOOL_ARRAY, value.boolArray) col.SetBytes(TABLE00.BYTES, value.bytes) col.SetBool(TABLE00.IS_ACTIVE, value.isActive) col.SetFloat64(TABLE00.PRICE, value.price) col.SetInt64(TABLE00.SCORE, value.score) col.SetString(TABLE00.NAME, value.name) col.SetTime(TABLE00.UPDATED_AT, value.updatedAt) } }). SetDialect(tt.dialect), ) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(result.RowsAffected, int64(len(table00Values))); diff != "" { t.Error(testutil.Callers(), diff) } // Fetch the data. values, err := FetchAll(VerboseLog(db), From(TABLE00). OrderBy(TABLE00.UUID). 
SetDialect(tt.dialect), func(row *Row) Table00 { var value Table00 row.UUIDField(&value.uuid, TABLE00.UUID) row.JSONField(&value.data, TABLE00.DATA) row.EnumField(&value.color, TABLE00.COLOR) row.EnumField(&value.direction, TABLE00.DIRECTION) row.EnumField(&value.weekday, TABLE00.WEEKDAY) row.ArrayField(&value.textArray, TABLE00.TEXT_ARRAY) row.ArrayField(&value.intArray, TABLE00.INT_ARRAY) row.ArrayField(&value.int64Array, TABLE00.INT64_ARRAY) row.ArrayField(&value.int32Array, TABLE00.INT32_ARRAY) row.ArrayField(&value.float64Array, TABLE00.FLOAT64_ARRAY) row.ArrayField(&value.float32Array, TABLE00.FLOAT32_ARRAY) row.ArrayField(&value.boolArray, TABLE00.BOOL_ARRAY) value.bytes = row.BytesField(TABLE00.BYTES) value.isActive = row.BoolField(TABLE00.IS_ACTIVE) value.price = row.Float64Field(TABLE00.PRICE) value.score = row.Int64Field(TABLE00.SCORE) value.name = row.StringField(TABLE00.NAME) value.updatedAt = row.TimeField(TABLE00.UPDATED_AT) // make sure Columns, ColumnTypes and Values are all // callable inside the rowmapper even for dynamic queries. fmt.Println(row.Columns()) fmt.Println(row.ColumnTypes()) fmt.Println(row.Values()) return value }, ) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(values, table00Values); diff != "" { t.Error(testutil.Callers(), diff) } exists, err := FetchExists(Log(db), SelectOne(). From(TABLE00). Where(TABLE00.UUID.EqUUID(table00Values[0].uuid)). SetDialect(tt.dialect), ) if err != nil { t.Fatal(testutil.Callers(), err) } if !exists { t.Errorf(testutil.Callers()+" expected row with uuid = %q to exist, got false", table00Values[0].uuid.String()) } // SQLServer driver *still* doesn't support NULL UUIDs 🙄, skip // NULL testing for SQL Server. // https://github.com/denisenkom/go-mssqldb/issues/196 if tt.dialect == "sqlserver" { return } // Insert NULLs. _, err = Exec(Log(db), InsertInto(TABLE00). 
ColumnValues(func(col *Column) { col.Set(TABLE00.UUID, nil) col.Set(TABLE00.DATA, nil) col.Set(TABLE00.COLOR, nil) col.Set(TABLE00.DIRECTION, nil) col.Set(TABLE00.WEEKDAY, nil) col.Set(TABLE00.TEXT_ARRAY, nil) col.Set(TABLE00.INT_ARRAY, nil) col.Set(TABLE00.INT64_ARRAY, nil) col.Set(TABLE00.INT32_ARRAY, nil) col.Set(TABLE00.FLOAT64_ARRAY, nil) col.Set(TABLE00.FLOAT32_ARRAY, nil) col.Set(TABLE00.BOOL_ARRAY, nil) col.Set(TABLE00.BYTES, nil) col.Set(TABLE00.IS_ACTIVE, nil) col.Set(TABLE00.PRICE, nil) col.Set(TABLE00.SCORE, nil) col.Set(TABLE00.NAME, nil) col.Set(TABLE00.UPDATED_AT, nil) }). SetDialect(tt.dialect), ) if err != nil { t.Fatal(testutil.Callers(), err) } // Fetch NULLs. _, err = FetchAll(VerboseLog(db), From(TABLE00). Where(TABLE00.UUID.IsNull()). OrderBy(TABLE00.UUID). SetDialect(tt.dialect), func(row *Row) Table00 { var value Table00 row.UUIDField(&value.uuid, TABLE00.UUID) row.JSONField(&value.data, TABLE00.DATA) row.EnumField(&value.color, TABLE00.COLOR) row.EnumField(&value.direction, TABLE00.DIRECTION) row.EnumField(&value.weekday, TABLE00.WEEKDAY) row.ArrayField(&value.textArray, TABLE00.TEXT_ARRAY) row.ArrayField(&value.intArray, TABLE00.INT_ARRAY) row.ArrayField(&value.int64Array, TABLE00.INT64_ARRAY) row.ArrayField(&value.int32Array, TABLE00.INT32_ARRAY) row.ArrayField(&value.float64Array, TABLE00.FLOAT64_ARRAY) row.ArrayField(&value.float32Array, TABLE00.FLOAT32_ARRAY) row.ArrayField(&value.boolArray, TABLE00.BOOL_ARRAY) value.bytes = row.BytesField(TABLE00.BYTES) value.isActive = row.BoolField(TABLE00.IS_ACTIVE) value.price = row.Float64Field(TABLE00.PRICE) value.score = row.Int64Field(TABLE00.SCORE) value.name = row.StringField(TABLE00.NAME) value.updatedAt = row.TimeField(TABLE00.UPDATED_AT) return value }, ) if err != nil { t.Fatal(testutil.Callers(), err) } }) } } func TestRowScan(t *testing.T) { table01Values := [][]any{ {nil, nil, nil, nil, nil, nil}, {123, int64(123), float64(123), "abc", true, time.Unix(123, 0).UTC()}, {456, int64(456), 
float64(456), "def", true, time.Unix(456, 0).UTC()}, {789, int64(789), float64(789), "ghi", true, time.Unix(789, 0).UTC()}, } type TestTable struct { dialect string driver string dsn string teardown string setup string } tests := []TestTable{{ dialect: DialectSQLite, driver: "sqlite3", dsn: "file:/TestRowScan/sqlite?vfs=memdb&_foreign_keys=true", teardown: "DROP TABLE IF EXISTS table01;", setup: "CREATE TABLE table01 (" + "\n id INT" + "\n ,score BIGINT" + "\n ,price REAL" + "\n ,name TEXT" + "\n ,is_active BOOLEAN" + "\n ,updated_at DATETIME" + "\n);", }, { dialect: DialectPostgres, driver: "postgres", dsn: *postgresDSN, teardown: "DROP TABLE IF EXISTS table01;", setup: "CREATE TABLE table01 (" + "\n id INT" + "\n ,score BIGINT" + "\n ,price DOUBLE PRECISION" + "\n ,name TEXT" + "\n ,is_active BOOLEAN" + "\n ,updated_at TIMESTAMPTZ" + "\n);", }, { dialect: DialectMySQL, driver: "mysql", dsn: *mysqlDSN, teardown: "DROP TABLE IF EXISTS table01;", setup: "CREATE TABLE table01 (" + "\n id INT" + "\n ,score BIGINT" + "\n ,price DOUBLE PRECISION" + "\n ,name VARCHAR(255)" + "\n ,is_active BOOLEAN" + "\n ,updated_at DATETIME" + "\n);", }, { dialect: DialectSQLServer, driver: "sqlserver", dsn: *sqlserverDSN, teardown: "DROP TABLE IF EXISTS table01;", setup: "CREATE TABLE table01 (" + "\n id INT" + "\n ,score BIGINT" + "\n ,price DOUBLE PRECISION" + "\n ,name NVARCHAR(255)" + "\n ,is_active BIT" + "\n ,updated_at DATETIME2" + "\n);", }} for _, tt := range tests { tt := tt t.Run(tt.dialect, func(t *testing.T) { if tt.dsn == "" { return } t.Parallel() dsn := preprocessDSN(tt.dialect, tt.dsn) db, err := sql.Open(tt.driver, dsn) if err != nil { t.Fatal(testutil.Callers(), err) } _, err = db.Exec(tt.teardown) if err != nil { t.Fatal(testutil.Callers(), err) } _, err = db.Exec(tt.setup) if err != nil { t.Fatal(testutil.Callers(), err) } defer func() { db.Exec(tt.teardown) }() // Insert values. 
result, err := Exec(Log(db), InsertQuery{ Dialect: tt.dialect, InsertTable: Expr("table01"), ColumnMapper: func(col *Column) { for _, value := range table01Values { col.Set(Expr("id"), value[0]) col.Set(Expr("score"), value[1]) col.Set(Expr("price"), value[2]) col.Set(Expr("name"), value[3]) col.Set(Expr("is_active"), value[4]) col.Set(Expr("updated_at"), value[5]) } }, }) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(result.RowsAffected, int64(len(table01Values))); diff != "" { t.Error(testutil.Callers(), diff) } t.Run("dynamic SQL query", func(t *testing.T) { gotValues, err := FetchAll(db, Queryf("SELECT {*} FROM table01 WHERE id IS NOT NULL ORDER BY id").SetDialect(tt.dialect), func(row *Row) []any { var id int var score1 int64 var score2 int32 var price float64 var name string var isActive bool var updatedAt time.Time row.Scan(&id, "id") row.Scan(&score1, "score") row.Scan(&score2, "score") if diff := testutil.Diff(score1, int64(score2)); diff != "" { panic(fmt.Errorf(testutil.Callers() + diff)) } row.Scan(&price, "price") row.Scan(&name, "name") row.Scan(&isActive, "is_active") row.Scan(&updatedAt, "updated_at") return []any{id, score1, price, name, isActive, updatedAt} }, ) if err != nil { t.Fatal(testutil.Callers(), err) } wantValues := [][]any{ {123, int64(123), float64(123), "abc", true, time.Unix(123, 0).UTC()}, {456, int64(456), float64(456), "def", true, time.Unix(456, 0).UTC()}, {789, int64(789), float64(789), "ghi", true, time.Unix(789, 0).UTC()}, } if diff := testutil.Diff(gotValues, wantValues); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("dynamic SQL query (null values)", func(t *testing.T) { gotValue, err := FetchOne(db, Queryf("SELECT {*} FROM table01 WHERE id IS NULL").SetDialect(tt.dialect), func(row *Row) []any { var id int var score1 int64 var score2 int32 var price float64 var name string var isActive bool var updatedAt time.Time row.Scan(&id, "id") row.Scan(&score1, "score") row.Scan(&score2, 
"score") if diff := testutil.Diff(score1, int64(score2)); diff != "" { panic(fmt.Errorf(testutil.Callers() + diff)) } row.Scan(&price, "price") row.Scan(&name, "name") row.Scan(&isActive, "is_active") row.Scan(&updatedAt, "updated_at") return []any{id, score1, price, name, isActive, updatedAt} }, ) if err != nil { t.Fatal(testutil.Callers(), err) } wantValue := []any{int(0), int64(0), float64(0), "", false, time.Time{}} if diff := testutil.Diff(gotValue, wantValue); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("dynamic SQL query (null values) (using sql.Null structs)", func(t *testing.T) { gotValue, err := FetchOne(db, Queryf("SELECT {*} FROM table01 WHERE id IS NULL").SetDialect(tt.dialect), func(row *Row) []any { var id sql.NullInt64 var score1 sql.NullInt64 var score2 sql.NullInt32 var price sql.NullFloat64 var name sql.NullString var isActive sql.NullBool var updatedAt sql.NullTime row.Scan(&id, "id") row.Scan(&score1, "score") row.Scan(&score2, "score") if diff := testutil.Diff(score1.Int64, int64(score2.Int32)); diff != "" { panic(fmt.Errorf(testutil.Callers() + diff)) } row.Scan(&price, "price") row.Scan(&name, "name") row.Scan(&isActive, "is_active") row.Scan(&updatedAt, "updated_at") return []any{int(id.Int64), score1.Int64, price.Float64, name.String, isActive.Bool, updatedAt.Time} }, ) if err != nil { t.Fatal(testutil.Callers(), err) } wantValue := []any{int(0), int64(0), float64(0), "", false, time.Time{}} if diff := testutil.Diff(gotValue, wantValue); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("static SQL query", func(t *testing.T) { // Raw SQL query with. 
gotValues, err := FetchAll(Log(db), Queryf("SELECT id, score, price, name, is_active, updated_at FROM table01 WHERE id IS NOT NULL ORDER BY id").SetDialect(tt.dialect), func(row *Row) []any { return []any{ row.Int("id"), row.Int64("score"), row.Float64("price"), row.String("name"), row.Bool("is_active"), row.Time("updated_at"), } }, ) if err != nil { t.Fatal(testutil.Callers(), err) } wantValues := [][]any{ {123, int64(123), float64(123), "abc", true, time.Unix(123, 0).UTC()}, {456, int64(456), float64(456), "def", true, time.Unix(456, 0).UTC()}, {789, int64(789), float64(789), "ghi", true, time.Unix(789, 0).UTC()}, } if diff := testutil.Diff(gotValues, wantValues); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("static SQL query (raw Values)", func(t *testing.T) { gotValues, err := FetchAll(db, Queryf("SELECT id, score, price, name, is_active, updated_at FROM table01 WHERE id IS NOT NULL ORDER BY id").SetDialect(tt.dialect), func(row *Row) []any { columns := row.Columns() columnTypes := row.ColumnTypes() values := row.Values() if len(columns) != len(columnTypes) || len(columnTypes) != len(values) { panic(fmt.Errorf(testutil.Callers()+" length of columns/columnTypes/values don't match: %v %v %v", columns, columnTypes, values)) } return values }, ) if err != nil { t.Fatal(testutil.Callers(), err) } // We need to tweak wantValues depending on the dialect because // we are at the mercy of whatever that dialect's database // driver decides to return. 
var wantValues [][]any switch tt.dialect { case DialectSQLite, DialectPostgres, DialectSQLServer: wantValues = [][]any{ {int64(123), int64(123), float64(123), "abc", true, time.Unix(123, 0).UTC()}, {int64(456), int64(456), float64(456), "def", true, time.Unix(456, 0).UTC()}, {int64(789), int64(789), float64(789), "ghi", true, time.Unix(789, 0).UTC()}, } case DialectMySQL: wantValues = [][]any{ {[]byte("123"), []byte("123"), []byte("123"), []byte("abc"), []byte("1"), time.Unix(123, 0).UTC()}, {[]byte("456"), []byte("456"), []byte("456"), []byte("def"), []byte("1"), time.Unix(456, 0).UTC()}, {[]byte("789"), []byte("789"), []byte("789"), []byte("ghi"), []byte("1"), time.Unix(789, 0).UTC()}, } } if diff := testutil.Diff(gotValues, wantValues); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("static SQL query (null values)", func(t *testing.T) { gotValue, err := FetchOne(db, Queryf("SELECT id, score, price, name, is_active, updated_at FROM table01 WHERE id IS NULL").SetDialect(tt.dialect), func(row *Row) []any { columns := row.Columns() columnTypes := row.ColumnTypes() values := row.Values() if len(columns) != len(columnTypes) || len(columnTypes) != len(values) { panic(fmt.Errorf(testutil.Callers()+" length of columns/columnTypes/values don't match: %v %v %v", columns, columnTypes, values)) } return values }, ) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(gotValue, []any{nil, nil, nil, nil, nil, nil}); diff != "" { t.Error(testutil.Callers(), diff) } }) t.Run("static SQL query (null values) (using sql.Null structs)", func(t *testing.T) { gotValue, err := FetchOne(db, Queryf("SELECT id, score, price, name, is_active, updated_at FROM table01 WHERE id IS NULL").SetDialect(tt.dialect), func(row *Row) []any { return []any{ row.NullInt64("score"), row.NullFloat64("price"), row.NullString("name"), row.NullBool("is_active"), row.NullTime("updated_at"), } }, ) if err != nil { t.Fatal(testutil.Callers(), err) } wantValues := 
[]any{sql.NullInt64{}, sql.NullFloat64{}, sql.NullString{}, sql.NullBool{}, sql.NullTime{}} if diff := testutil.Diff(gotValue, wantValues); diff != "" { t.Error(testutil.Callers(), diff) } }) }) } } func preprocessDSN(dialect string, dsn string) string { switch dialect { case DialectPostgres: before, after, _ := strings.Cut(dsn, "?") q, err := url.ParseQuery(after) if err != nil { return dsn } if !q.Has("sslmode") { q.Set("sslmode", "disable") } if !q.Has("binary_parameters") { q.Set("binary_parameters", "yes") } return before + "?" + q.Encode() case DialectMySQL: before, after, _ := strings.Cut(strings.TrimPrefix(dsn, "mysql://"), "?") q, err := url.ParseQuery(after) if err != nil { return dsn } if !q.Has("allowAllFiles") { q.Set("allowAllFiles", "true") } if !q.Has("multiStatements") { q.Set("multiStatements", "true") } if !q.Has("parseTime") { q.Set("parseTime", "true") } return before + "?" + q.Encode() default: return dsn } } ================================================ FILE: internal/googleuuid/googleuuid.go ================================================ // Copyright (c) 2009,2014 Google Inc. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. 
// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. package googleuuid import ( "bytes" "encoding/hex" "errors" "fmt" "strings" ) // ParseBytes decodes b into a UUID or returns an error. Both the UUID form of // xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and // urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx are decoded. 
func ParseBytes(b []byte) (uuid [16]byte, err error) { switch len(b) { case 36: // xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx case 36 + 9: // urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx if !bytes.Equal(bytes.ToLower(b[:9]), []byte("urn:uuid:")) { return uuid, fmt.Errorf("invalid urn prefix: %q", b[:9]) } b = b[9:] case 36 + 2: // {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx} b = b[1:] case 32: // xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx var ok bool for i := 0; i < 32; i += 2 { uuid[i/2], ok = xtob(b[i], b[i+1]) if !ok { return uuid, errors.New("invalid UUID format") } } return uuid, nil default: return uuid, fmt.Errorf("invalid UUID length: %d", len(b)) } // s is now at least 36 bytes long // it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx if b[8] != '-' || b[13] != '-' || b[18] != '-' || b[23] != '-' { return uuid, errors.New("invalid UUID format") } for i, x := range [16]int{ 0, 2, 4, 6, 9, 11, 14, 16, 19, 21, 24, 26, 28, 30, 32, 34, } { v, ok := xtob(b[x], b[x+1]) if !ok { return uuid, errors.New("invalid UUID format") } uuid[i] = v } return uuid, nil } func Parse(s string) (uuid [16]byte, err error) { if len(s) != 36 { if len(s) != 36+9 { return uuid, fmt.Errorf("invalid UUID length: %d", len(s)) } if strings.ToLower(s[:9]) != "urn:uuid:" { return uuid, fmt.Errorf("invalid urn prefix: %q", s[:9]) } s = s[9:] } if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' { return uuid, errors.New("invalid UUID format") } for i, x := range [16]int{ 0, 2, 4, 6, 9, 11, 14, 16, 19, 21, 24, 26, 28, 30, 32, 34, } { v, ok := xtob(s[x], s[x+1]) if !ok { return uuid, errors.New("invalid UUID format") } uuid[i] = v } return uuid, nil } // xvalues returns the value of a byte as a hexadecimal digit or 255. 
//
// Rather than spelling out all 256 entries, the table is filled in once at
// package initialization: every entry starts at 255 (not a hex digit) and
// only the 22 hex-digit characters are overwritten with their values.
var xvalues = func() (table [256]byte) {
	for i := range table {
		table[i] = 255
	}
	for c := byte('0'); c <= '9'; c++ {
		table[c] = c - '0'
	}
	for c := byte('a'); c <= 'f'; c++ {
		// Lowercase and uppercase letters share the same digit value.
		table[c] = c - 'a' + 10
		table[c-'a'+'A'] = c - 'a' + 10
	}
	return table
}()

// xtob converts the hex digits x1 (high nibble) and x2 (low nibble) into a
// single byte. The boolean result is false if either character is not a hex
// digit, in which case the byte value is meaningless.
func xtob(x1, x2 byte) (byte, bool) { b1 := xvalues[x1] b2 := xvalues[x2] return (b1 << 4) | b2, b1 != 255 && b2 != 255 } // var buf [36]byte; encodeHex(buf[:], [16]byte{...}); string(buf[:]) func EncodeHex(dst []byte, uuid [16]byte) { hex.Encode(dst[:], uuid[:4]) dst[8] = '-' hex.Encode(dst[9:13], uuid[4:6]) dst[13] = '-' hex.Encode(dst[14:18], uuid[6:8]) dst[18] = '-' hex.Encode(dst[19:23], uuid[8:10]) dst[23] = '-' hex.Encode(dst[24:], uuid[10:]) } ================================================ FILE: internal/pqarray/pqarray.go ================================================ // Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany // // Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
package pqarray

import (
	"bytes"
	"database/sql"
	"database/sql/driver"
	"encoding/hex"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"time"
)

// Reflected types computed once at package initialization for use by the
// reflection-based array code (typeSQLScanner is consulted by
// GenericArray.evaluateDestination).
var (
	typeByteSlice    = reflect.TypeOf([]byte{})
	typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
	typeSQLScanner   = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
)

// Array wraps a, an array or slice of any dimension, in a value that
// implements both driver.Valuer and sql.Scanner.
//
// []bool, []float64, []float32, []int64, []int32, []string and [][]byte, and
// pointers to them, are wrapped in the matching specialized type (BoolArray,
// Float64Array, ...); anything else falls back to GenericArray. For example:
//
//	db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pqarray.Array([]int{235, 401}))
//
//	var x []sql.NullInt64
//	db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pqarray.Array(&x))
//
// Scanning multi-dimensional arrays is not supported. Arrays where the lower
// bound is not one (such as `[0:0]={1}') are not supported.
func Array(a interface{}) interface {
	driver.Valuer
	sql.Scanner
} {
	switch v := a.(type) {
	// Slice values: take the address of this case's local copy and let the
	// pointer cases below choose the wrapper type.
	case []bool:
		return Array(&v)
	case []float64:
		return Array(&v)
	case []float32:
		return Array(&v)
	case []int64:
		return Array(&v)
	case []int32:
		return Array(&v)
	case []string:
		return Array(&v)
	case [][]byte:
		return Array(&v)
	// Pointers to slices: reinterpret directly as the named slice type.
	case *[]bool:
		return (*BoolArray)(v)
	case *[]float64:
		return (*Float64Array)(v)
	case *[]float32:
		return (*Float32Array)(v)
	case *[]int64:
		return (*Int64Array)(v)
	case *[]int32:
		return (*Int32Array)(v)
	case *[]string:
		return (*StringArray)(v)
	case *[][]byte:
		return (*ByteaArray)(v)
	}
	return GenericArray{a}
}

// ArrayDelimiter may optionally be implemented by a driver.Valuer or
// sql.Scanner element type to replace the "," delimiter GenericArray uses
// for that type.
type ArrayDelimiter interface {
	// ArrayDelimiter returns the delimiter character(s) for this element's type.
	ArrayDelimiter() string
}

// BoolArray represents a one-dimensional array of the PostgreSQL boolean type.
type BoolArray []bool

// Scan implements the sql.Scanner interface.
func (a *BoolArray) Scan(src interface{}) error { switch src := src.(type) { case []byte: return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) case nil: *a = nil return nil } return fmt.Errorf("pq: cannot convert %T to BoolArray", src) } func (a *BoolArray) scanBytes(src []byte) error { elems, err := scanLinearArray(src, []byte{','}, "BoolArray") if err != nil { return err } if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(BoolArray, len(elems)) for i, v := range elems { if len(v) != 1 { return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) } switch v[0] { case 't': b[i] = true case 'f': b[i] = false default: return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) } } *a = b } return nil } // Value implements the driver.Valuer interface. func (a BoolArray) Value() (driver.Value, error) { if a == nil { return nil, nil } if n := len(a); n > 0 { // There will be exactly two curly brackets, N bytes of values, // and N-1 bytes of delimiters. b := make([]byte, 1+2*n) for i := 0; i < n; i++ { b[2*i] = ',' if a[i] { b[1+2*i] = 't' } else { b[1+2*i] = 'f' } } b[0] = '{' b[2*n] = '}' return string(b), nil } return "{}", nil } // ByteaArray represents a one-dimensional array of the PostgreSQL bytea type. type ByteaArray [][]byte // Scan implements the sql.Scanner interface. 
func (a *ByteaArray) Scan(src interface{}) error { switch src := src.(type) { case []byte: return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) case nil: *a = nil return nil } return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) } func (a *ByteaArray) scanBytes(src []byte) error { elems, err := scanLinearArray(src, []byte{','}, "ByteaArray") if err != nil { return err } if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(ByteaArray, len(elems)) for i, v := range elems { b[i], err = parseBytea(v) if err != nil { return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) } } *a = b } return nil } // Value implements the driver.Valuer interface. It uses the "hex" format which // is only supported on PostgreSQL 9.0 or newer. func (a ByteaArray) Value() (driver.Value, error) { if a == nil { return nil, nil } if n := len(a); n > 0 { // There will be at least two curly brackets, 2*N bytes of quotes, // 3*N bytes of hex formatting, and N-1 bytes of delimiters. size := 1 + 6*n for _, x := range a { size += hex.EncodedLen(len(x)) } b := make([]byte, size) for i, s := 0, b; i < n; i++ { o := copy(s, `,"\\x`) o += hex.Encode(s[o:], a[i]) s[o] = '"' s = s[o+1:] } b[0] = '{' b[size-1] = '}' return string(b), nil } return "{}", nil } // Float64Array represents a one-dimensional array of the PostgreSQL double // precision type. type Float64Array []float64 // Scan implements the sql.Scanner interface. 
func (a *Float64Array) Scan(src interface{}) error { switch src := src.(type) { case []byte: return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) case nil: *a = nil return nil } return fmt.Errorf("pq: cannot convert %T to Float64Array", src) } func (a *Float64Array) scanBytes(src []byte) error { elems, err := scanLinearArray(src, []byte{','}, "Float64Array") if err != nil { return err } if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(Float64Array, len(elems)) for i, v := range elems { if b[i], err = strconv.ParseFloat(string(v), 64); err != nil { return fmt.Errorf("pq: parsing array element index %d: %v", i, err) } } *a = b } return nil } // Value implements the driver.Valuer interface. func (a Float64Array) Value() (driver.Value, error) { if a == nil { return nil, nil } if n := len(a); n > 0 { // There will be at least two curly brackets, N bytes of values, // and N-1 bytes of delimiters. b := make([]byte, 1, 1+2*n) b[0] = '{' b = strconv.AppendFloat(b, a[0], 'f', -1, 64) for i := 1; i < n; i++ { b = append(b, ',') b = strconv.AppendFloat(b, a[i], 'f', -1, 64) } return string(append(b, '}')), nil } return "{}", nil } // Float32Array represents a one-dimensional array of the PostgreSQL double // precision type. type Float32Array []float32 // Scan implements the sql.Scanner interface. 
func (a *Float32Array) Scan(src interface{}) error { switch src := src.(type) { case []byte: return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) case nil: *a = nil return nil } return fmt.Errorf("pq: cannot convert %T to Float32Array", src) } func (a *Float32Array) scanBytes(src []byte) error { elems, err := scanLinearArray(src, []byte{','}, "Float32Array") if err != nil { return err } if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(Float32Array, len(elems)) for i, v := range elems { var x float64 if x, err = strconv.ParseFloat(string(v), 32); err != nil { return fmt.Errorf("pq: parsing array element index %d: %v", i, err) } b[i] = float32(x) } *a = b } return nil } // Value implements the driver.Valuer interface. func (a Float32Array) Value() (driver.Value, error) { if a == nil { return nil, nil } if n := len(a); n > 0 { // There will be at least two curly brackets, N bytes of values, // and N-1 bytes of delimiters. b := make([]byte, 1, 1+2*n) b[0] = '{' b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32) for i := 1; i < n; i++ { b = append(b, ',') b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32) } return string(append(b, '}')), nil } return "{}", nil } // GenericArray implements the driver.Valuer and sql.Scanner interfaces for // an array or slice of any dimension. type GenericArray struct{ A interface{} } func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { var assign func([]byte, reflect.Value) error var del = "," // TODO calculate the assign function for other types // TODO repeat this section on the element type of arrays or slices (multidimensional) { if reflect.PtrTo(rt).Implements(typeSQLScanner) { // dest is always addressable because it is an element of a slice. 
assign = func(src []byte, dest reflect.Value) (err error) { ss := dest.Addr().Interface().(sql.Scanner) if src == nil { err = ss.Scan(nil) } else { err = ss.Scan(src) } return } goto FoundType } assign = func([]byte, reflect.Value) error { return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt) } } FoundType: if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok { del = ad.ArrayDelimiter() } return rt, assign, del } // Scan implements the sql.Scanner interface. func (a GenericArray) Scan(src interface{}) error { dpv := reflect.ValueOf(a.A) switch { case dpv.Kind() != reflect.Ptr: return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) case dpv.IsNil(): return fmt.Errorf("pq: destination %T is nil", a.A) } dv := dpv.Elem() switch dv.Kind() { case reflect.Slice: case reflect.Array: default: return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) } switch src := src.(type) { case []byte: return a.scanBytes(src, dv) case string: return a.scanBytes([]byte(src), dv) case nil: if dv.Kind() == reflect.Slice { dv.Set(reflect.Zero(dv.Type())) return nil } } return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) } func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { dtype, assign, del := a.evaluateDestination(dv.Type().Elem()) dims, elems, err := parseArray(src, []byte(del)) if err != nil { return err } // TODO allow multidimensional if len(dims) > 1 { return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented", strings.Replace(fmt.Sprint(dims), " ", "][", -1)) } // Treat a zero-dimensional array like an array with a single dimension of zero. 
if len(dims) == 0 { dims = append(dims, 0) } for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() { switch rt.Kind() { case reflect.Slice: case reflect.Array: if rt.Len() != dims[i] { return fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type()) } default: // TODO handle multidimensional } } values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) for i, e := range elems { if err := assign(e, values.Index(i)); err != nil { return fmt.Errorf("pq: parsing array element index %d: %v", i, err) } } // TODO handle multidimensional switch dv.Kind() { case reflect.Slice: dv.Set(values.Slice(0, dims[0])) case reflect.Array: for i := 0; i < dims[0]; i++ { dv.Index(i).Set(values.Index(i)) } } return nil } // Value implements the driver.Valuer interface. func (a GenericArray) Value() (driver.Value, error) { if a.A == nil { return nil, nil } rv := reflect.ValueOf(a.A) switch rv.Kind() { case reflect.Slice: if rv.IsNil() { return nil, nil } case reflect.Array: default: return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A) } if n := rv.Len(); n > 0 { // There will be at least two curly brackets, N bytes of values, // and N-1 bytes of delimiters. b := make([]byte, 0, 1+2*n) b, _, err := appendArray(b, rv, n) return string(b), err } return "{}", nil } // Int64Array represents a one-dimensional array of the PostgreSQL integer types. type Int64Array []int64 // Scan implements the sql.Scanner interface. 
func (a *Int64Array) Scan(src interface{}) error { switch src := src.(type) { case []byte: return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) case nil: *a = nil return nil } return fmt.Errorf("pq: cannot convert %T to Int64Array", src) } func (a *Int64Array) scanBytes(src []byte) error { elems, err := scanLinearArray(src, []byte{','}, "Int64Array") if err != nil { return err } if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(Int64Array, len(elems)) for i, v := range elems { if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil { return fmt.Errorf("pq: parsing array element index %d: %v", i, err) } } *a = b } return nil } // Value implements the driver.Valuer interface. func (a Int64Array) Value() (driver.Value, error) { if a == nil { return nil, nil } if n := len(a); n > 0 { // There will be at least two curly brackets, N bytes of values, // and N-1 bytes of delimiters. b := make([]byte, 1, 1+2*n) b[0] = '{' b = strconv.AppendInt(b, a[0], 10) for i := 1; i < n; i++ { b = append(b, ',') b = strconv.AppendInt(b, a[i], 10) } return string(append(b, '}')), nil } return "{}", nil } // Int32Array represents a one-dimensional array of the PostgreSQL integer types. type Int32Array []int32 // Scan implements the sql.Scanner interface. 
func (a *Int32Array) Scan(src interface{}) error { switch src := src.(type) { case []byte: return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) case nil: *a = nil return nil } return fmt.Errorf("pq: cannot convert %T to Int32Array", src) } func (a *Int32Array) scanBytes(src []byte) error { elems, err := scanLinearArray(src, []byte{','}, "Int32Array") if err != nil { return err } if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(Int32Array, len(elems)) for i, v := range elems { x, err := strconv.ParseInt(string(v), 10, 32) if err != nil { return fmt.Errorf("pq: parsing array element index %d: %v", i, err) } b[i] = int32(x) } *a = b } return nil } // Value implements the driver.Valuer interface. func (a Int32Array) Value() (driver.Value, error) { if a == nil { return nil, nil } if n := len(a); n > 0 { // There will be at least two curly brackets, N bytes of values, // and N-1 bytes of delimiters. b := make([]byte, 1, 1+2*n) b[0] = '{' b = strconv.AppendInt(b, int64(a[0]), 10) for i := 1; i < n; i++ { b = append(b, ',') b = strconv.AppendInt(b, int64(a[i]), 10) } return string(append(b, '}')), nil } return "{}", nil } // StringArray represents a one-dimensional array of the PostgreSQL character types. type StringArray []string // Scan implements the sql.Scanner interface. 
func (a *StringArray) Scan(src interface{}) error { switch src := src.(type) { case []byte: return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) case nil: *a = nil return nil } return fmt.Errorf("pq: cannot convert %T to StringArray", src) } func (a *StringArray) scanBytes(src []byte) error { elems, err := scanLinearArray(src, []byte{','}, "StringArray") if err != nil { return err } if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(StringArray, len(elems)) for i, v := range elems { if b[i] = string(v); v == nil { return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i) } } *a = b } return nil } // Value implements the driver.Valuer interface. func (a StringArray) Value() (driver.Value, error) { if a == nil { return nil, nil } if n := len(a); n > 0 { // There will be at least two curly brackets, 2*N bytes of quotes, // and N-1 bytes of delimiters. b := make([]byte, 1, 1+3*n) b[0] = '{' b = appendArrayQuotedBytes(b, []byte(a[0])) for i := 1; i < n; i++ { b = append(b, ',') b = appendArrayQuotedBytes(b, []byte(a[i])) } return string(append(b, '}')), nil } return "{}", nil } // appendArray appends rv to the buffer, returning the extended buffer and // the delimiter used between elements. // // It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice. func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { var del string var err error b = append(b, '{') if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil { return b, del, err } for i := 1; i < n; i++ { b = append(b, del...) if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil { return b, del, err } } return append(b, '}'), del, nil } // appendArrayElement appends rv to the buffer, returning the extended buffer // and the delimiter to use before the next element. 
// // When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted // using driver.DefaultParameterConverter and the resulting []byte or string // is double-quoted. // // See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { if k := rv.Kind(); k == reflect.Array || k == reflect.Slice { if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) { if n := rv.Len(); n > 0 { return appendArray(b, rv, n) } return b, "", nil } } var del = "," var err error var iv interface{} = rv.Interface() if ad, ok := iv.(ArrayDelimiter); ok { del = ad.ArrayDelimiter() } if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil { return b, del, err } switch v := iv.(type) { case nil: return append(b, "NULL"...), del, nil case []byte: return appendArrayQuotedBytes(b, v), del, nil case string: return appendArrayQuotedBytes(b, []byte(v)), del, nil } b, err = appendValue(b, iv) return b, del, err } func appendArrayQuotedBytes(b, v []byte) []byte { b = append(b, '"') for { i := bytes.IndexAny(v, `"\`) if i < 0 { b = append(b, v...) break } if i > 0 { b = append(b, v[:i]...) } b = append(b, '\\', v[i]) v = v[i+1:] } return append(b, '"') } func appendValue(b []byte, v driver.Value) ([]byte, error) { return append(b, encode(nil, v, 0)...), nil } // parseArray extracts the dimensions and elements of an array represented in // text format. Only representations emitted by the backend are supported. // Notably, whitespace around brackets and delimiters is significant, and NULL // is case-sensitive. 
//
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
//
// Results: dims has one entry per nesting level (len(dims) equals the number
// of leading '{' bytes) and dims[k] is the length of the most recently parsed
// sub-array at level k; elems holds every leaf element in order. Quoted
// elements are unescaped and always non-nil; the unquoted literal NULL yields
// a nil element. An empty array ("{}", "{{}}", ...) yields nil dims and a
// non-nil, empty elems. del may be more than one byte long.
func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) {
	var depth, i int

	if len(src) < 1 || src[0] != '{' {
		return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0)
	}

	// Consume the run of leading '{' bytes; its length fixes the number of
	// dimensions. A '}' right after them means the array is empty.
Open:
	for i < len(src) {
		switch src[i] {
		case '{':
			depth++
			i++
		case '}':
			elems = make([][]byte, 0)
			goto Close
		default:
			break Open
		}
	}
	// Here i == depth == number of leading '{' bytes.
	dims = make([]int, i)

	// Parse one element, or open a nested sub-array. A '{' deeper than the
	// depth established above stops parsing here; the delimiter loop below
	// then rejects it as unexpected.
Element:
	for i < len(src) {
		switch src[i] {
		case '{':
			if depth == len(dims) {
				break Element
			}
			depth++
			dims[depth-1] = 0
			i++
		case '"':
			// Quoted element: a backslash escapes the next byte, the closing
			// quote ends the element.
			var elem = []byte{}
			var escape bool
			for i++; i < len(src); i++ {
				if escape {
					elem = append(elem, src[i])
					escape = false
				} else {
					switch src[i] {
					default:
						elem = append(elem, src[i])
					case '\\':
						escape = true
					case '"':
						elems = append(elems, elem)
						i++
						break Element
					}
				}
			}
		default:
			// Unquoted element: runs until the delimiter or '}'. An empty
			// element is an error; the exact text NULL becomes a nil element.
			for start := i; i < len(src); i++ {
				if bytes.HasPrefix(src[i:], del) || src[i] == '}' {
					elem := src[start:i]
					if len(elem) == 0 {
						return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
					}
					if bytes.Equal(elem, []byte("NULL")) {
						elem = nil
					}
					elems = append(elems, elem)
					break Element
				}
			}
		}
	}

	// After an element: a delimiter counts it at the current level and jumps
	// back for the next element; each '}' counts the finished element or
	// sub-array at its level and closes that level. Anything else is an error.
	for i < len(src) {
		if bytes.HasPrefix(src[i:], del) && depth > 0 {
			dims[depth-1]++
			i += len(del)
			goto Element
		} else if src[i] == '}' && depth > 0 {
			dims[depth-1]++
			depth--
			i++
		} else {
			return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
		}
	}

	// Only closing braces may remain. This loop does real work only when it is
	// reached via the empty-array goto; otherwise i == len(src) already.
Close:
	for i < len(src) {
		if src[i] == '}' && depth > 0 {
			depth--
			i++
		} else {
			return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
		}
	}
	if depth > 0 {
		err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i)
	}
	// Loose consistency check: the total number of leaf elements must be a
	// multiple of every dimension length.
	if err == nil {
		for _, d := range dims {
			if (len(elems) % d) != 0 {
				err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions")
			}
		}
	}
	return
}

// scanLinearArray runs parseArray and rejects results with more than one
// dimension. typ names the Go destination type and appears only in the
// error message.
func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) {
	dims, elems, err := parseArray(src, del)
	if err != nil {
		return nil, err
	}
	if len(dims) > 1 {
		return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ)
	}
	return elems, err
}

// parseBytea decodes a bytea value in text form. Input that starts with
// `\x` is treated as the hex output format; anything else is treated as the
// escape format, where `\\` is one backslash and `\` followed by three octal
// digits is one byte.
func parseBytea(s []byte) (result []byte, err error) {
	if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
		// bytea_output = hex
		s = s[2:] // trim off leading "\\x"
		result = make([]byte, hex.DecodedLen(len(s)))
		_, err := hex.Decode(result, s)
		if err != nil {
			return nil, err
		}
	} else {
		// bytea_output = escape
		for len(s) > 0 {
			if s[0] == '\\' {
				// escaped '\\'
				if len(s) >= 2 && s[1] == '\\' {
					result = append(result, '\\')
					s = s[2:]
					continue
				}

				// '\\' followed by an octal number
				if len(s) < 4 {
					return nil, fmt.Errorf("invalid bytea sequence %v", s)
				}
				r, err := strconv.ParseUint(string(s[1:4]), 8, 8)
				if err != nil {
					return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
				}
				result = append(result, byte(r))
				s = s[4:]
			} else {
				// We hit an unescaped, raw byte. Try to read in as many as
				// possible in one go.
				i := bytes.IndexByte(s, '\\')
				if i == -1 {
					result = append(result, s...)
					break
				}
				result = append(result, s[:i]...)
				s = s[i:]
			}
		}
	}

	return result, nil
}

// encode renders a driver value as its text form. When pgtypOid is T_bytea,
// []byte and string values are bytea-encoded, which dereferences
// parameterStatus, so callers must pass a non-nil parameterStatus in that
// case. Any type not listed in the switch causes a panic.
func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid Oid) []byte {
	switch v := x.(type) {
	case int64:
		return strconv.AppendInt(nil, v, 10)
	case float64:
		// 'f' format with the minimal number of digits that round-trips.
		return strconv.AppendFloat(nil, v, 'f', -1, 64)
	case []byte:
		if pgtypOid == T_bytea {
			return encodeBytea(parameterStatus.serverVersion, v)
		}
		return v
	case string:
		if pgtypOid == T_bytea {
			return encodeBytea(parameterStatus.serverVersion, []byte(v))
		}
		return []byte(v)
	case bool:
		return strconv.AppendBool(nil, v)
	case time.Time:
		return formatTs(v)
	default:
		panic(fmt.Errorf("pq: %s", fmt.Sprintf("encode: unknown type for %T", v)))
	}
}

// parameterStatus holds session details that influence encoding.
type parameterStatus struct {
	// server version in the same format as server_version_num, or 0 if
	// unavailable
	serverVersion int

	// the current location based on the TimeZone value of the session, if
	// available
	currentLocation *time.Location
}

// encodeBytea encodes v for a bytea parameter. For serverVersion >= 90000 the
// hex format (`\x` followed by hex digits) is used. Otherwise the escape
// format is used: backslashes are doubled, bytes outside 0x20..0x7e become a
// backslash plus three octal digits, and all other bytes are copied as-is.
func encodeBytea(serverVersion int, v []byte) (result []byte) {
	if serverVersion >= 90000 {
		// Use the hex format if we know that the server supports it
		result = make([]byte, 2+hex.EncodedLen(len(v)))
		result[0] = '\\'
		result[1] = 'x'
		hex.Encode(result[2:], v)
	} else {
		// .. or resort to "escape"
		for _, b := range v {
			if b == '\\' {
				result = append(result, '\\', '\\')
			} else if b < 0x20 || b > 0x7e {
				result = append(result, []byte(fmt.Sprintf("\\%03o", b))...)
			} else {
				result = append(result, b)
			}
		}
	}

	return result
}

// infinityTsEnabled gates formatTs's substitution of "-infinity"/"infinity"
// for times at or beyond infinityTsNegative/infinityTsPositive.
var infinityTsEnabled = false
var infinityTsNegative time.Time
var infinityTsPositive time.Time

// formatTs formats t into a format postgres understands.
func formatTs(t time.Time) []byte {
	if infinityTsEnabled {
		// t <= -infinity : ! (t > -infinity)
		if !t.After(infinityTsNegative) {
			return []byte("-infinity")
		}
		// t >= infinity : ! (!t < infinity)
		if !t.Before(infinityTsPositive) {
			return []byte("infinity")
		}
	}
	return FormatTimestamp(t)
}

// FormatTimestamp formats t into Postgres' text format for timestamps.
func FormatTimestamp(t time.Time) []byte {
	// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
	// minus sign preferred by Go.
	// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
	bc := false
	if t.Year() <= 0 {
		// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
		t = t.AddDate((-t.Year())*2+1, 0, 0)
		bc = true
	}
	// Date, time, fractional seconds with trailing zeros dropped, then the
	// zone as "Z" (UTC) or +hh:mm / -hh:mm.
	b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00"))

	// The layout above has no field for the seconds part of the zone offset,
	// so any sub-minute remainder is appended by hand as ":ss".
	_, offset := t.Zone()
	offset %= 60
	if offset != 0 {
		// RFC3339Nano already printed the minus sign
		// (Go's % keeps the sign of the dividend, hence the absolute value.)
		if offset < 0 {
			offset = -offset
		}

		b = append(b, ':')
		if offset < 10 {
			b = append(b, '0')
		}
		b = strconv.AppendInt(b, int64(offset), 10)
	}

	if bc {
		b = append(b, " BC"...)
	}
	return b
}

// Oid is a Postgres Object ID.
type Oid uint32

// Oid values of built-in PostgreSQL types. Names with a double underscore
// (T__xxx) correspond to the underscore-prefixed names in TypeName, which is
// PostgreSQL's naming for array types.
const (
	T_bool Oid = 16
	T_bytea Oid = 17
	T_char Oid = 18
	T_name Oid = 19
	T_int8 Oid = 20
	T_int2 Oid = 21
	T_int2vector Oid = 22
	T_int4 Oid = 23
	T_regproc Oid = 24
	T_text Oid = 25
	T_oid Oid = 26
	T_tid Oid = 27
	T_xid Oid = 28
	T_cid Oid = 29
	T_oidvector Oid = 30
	T_pg_ddl_command Oid = 32
	T_pg_type Oid = 71
	T_pg_attribute Oid = 75
	T_pg_proc Oid = 81
	T_pg_class Oid = 83
	T_json Oid = 114
	T_xml Oid = 142
	T__xml Oid = 143
	T_pg_node_tree Oid = 194
	T__json Oid = 199
	T_smgr Oid = 210
	T_index_am_handler Oid = 325
	T_point Oid = 600
	T_lseg Oid = 601
	T_path Oid = 602
	T_box Oid = 603
	T_polygon Oid = 604
	T_line Oid = 628
	T__line Oid = 629
	T_cidr Oid = 650
	T__cidr Oid = 651
	T_float4 Oid = 700
	T_float8 Oid = 701
	T_abstime Oid = 702
	T_reltime Oid = 703
	T_tinterval Oid = 704
	T_unknown Oid = 705
	T_circle Oid = 718
	T__circle Oid = 719
	T_money Oid = 790
	T__money Oid = 791
	T_macaddr Oid = 829
	T_inet Oid = 869
	T__bool Oid = 1000
	T__bytea Oid = 1001
	T__char Oid = 1002
	T__name Oid = 1003
	T__int2 Oid = 1005
	T__int2vector Oid = 1006
	T__int4 Oid = 1007
	T__regproc Oid = 1008
	T__text Oid = 1009
	T__tid Oid = 1010
	T__xid Oid = 1011
	T__cid Oid = 1012
	T__oidvector Oid = 1013
	T__bpchar Oid = 1014
	T__varchar Oid = 1015
	T__int8 Oid = 1016
	T__point Oid = 1017
	T__lseg Oid = 1018
	T__path Oid = 1019
	T__box Oid = 1020
	T__float4 Oid = 1021
	T__float8 Oid = 1022
	T__abstime Oid = 1023
	T__reltime Oid = 1024
	T__tinterval Oid = 1025
	T__polygon Oid = 1027
	T__oid Oid = 1028
	T_aclitem Oid = 1033
	T__aclitem Oid = 1034
	T__macaddr Oid = 1040
	T__inet Oid = 1041
	T_bpchar Oid = 1042
	T_varchar Oid = 1043
	T_date Oid = 1082
	T_time Oid = 1083
	T_timestamp Oid = 1114
	T__timestamp Oid = 1115
	T__date Oid = 1182
	T__time Oid = 1183
	T_timestamptz Oid = 1184
	T__timestamptz Oid = 1185
	T_interval Oid = 1186
	T__interval Oid = 1187
	T__numeric Oid = 1231
	T_pg_database Oid = 1248
	T__cstring Oid = 1263
	T_timetz Oid = 1266
	T__timetz Oid = 1270
	T_bit Oid = 1560
	T__bit Oid = 1561
	T_varbit Oid = 1562
	T__varbit Oid = 1563
	T_numeric Oid = 1700
	T_refcursor Oid = 1790
	T__refcursor Oid = 2201
	T_regprocedure Oid = 2202
	T_regoper Oid = 2203
	T_regoperator Oid = 2204
	T_regclass Oid = 2205
	T_regtype Oid = 2206
	T__regprocedure Oid = 2207
	T__regoper Oid = 2208
	T__regoperator Oid = 2209
	T__regclass Oid = 2210
	T__regtype Oid = 2211
	T_record Oid = 2249
	T_cstring Oid = 2275
	T_any Oid = 2276
	T_anyarray Oid = 2277
	T_void Oid = 2278
	T_trigger Oid = 2279
	T_language_handler Oid = 2280
	T_internal Oid = 2281
	T_opaque Oid = 2282
	T_anyelement Oid = 2283
	T__record Oid = 2287
	T_anynonarray Oid = 2776
	T_pg_authid Oid = 2842
	T_pg_auth_members Oid = 2843
	T__txid_snapshot Oid = 2949
	T_uuid Oid = 2950
	T__uuid Oid = 2951
	T_txid_snapshot Oid = 2970
	T_fdw_handler Oid = 3115
	T_pg_lsn Oid = 3220
	T__pg_lsn Oid = 3221
	T_tsm_handler Oid = 3310
	T_anyenum Oid = 3500
	T_tsvector Oid = 3614
	T_tsquery Oid = 3615
	T_gtsvector Oid = 3642
	T__tsvector Oid = 3643
	T__gtsvector Oid = 3644
	T__tsquery Oid = 3645
	T_regconfig Oid = 3734
	T__regconfig Oid = 3735
	T_regdictionary Oid = 3769
	T__regdictionary Oid = 3770
	T_jsonb Oid = 3802
	T__jsonb Oid = 3807
	T_anyrange Oid = 3831
	T_event_trigger Oid = 3838
	T_int4range Oid = 3904
	T__int4range Oid = 3905
	T_numrange Oid = 3906
	T__numrange Oid = 3907
	T_tsrange Oid = 3908
	T__tsrange Oid = 3909
	T_tstzrange Oid = 3910
	T__tstzrange Oid = 3911
	T_daterange Oid = 3912
	T__daterange Oid = 3913
	T_int8range Oid = 3926
	T__int8range Oid = 3927
	T_pg_shseclabel Oid = 4066
	T_regnamespace Oid = 4089
	T__regnamespace Oid = 4090
	T_regrole Oid = 4096
	T__regrole Oid = 4097
)

// TypeName maps each Oid constant above to its upper-cased PostgreSQL type
// name (the T_ prefix removed).
var TypeName = map[Oid]string{
	T_bool: "BOOL",
	T_bytea: "BYTEA",
	T_char: "CHAR",
	T_name: "NAME",
	T_int8: "INT8",
	T_int2: "INT2",
	T_int2vector: "INT2VECTOR",
	T_int4: "INT4",
	T_regproc: "REGPROC",
	T_text: "TEXT",
	T_oid: "OID",
	T_tid: "TID",
	T_xid: "XID",
	T_cid: "CID",
	T_oidvector: "OIDVECTOR",
	T_pg_ddl_command: "PG_DDL_COMMAND",
	T_pg_type: "PG_TYPE",
	T_pg_attribute: "PG_ATTRIBUTE",
	T_pg_proc: "PG_PROC",
	T_pg_class: "PG_CLASS",
	T_json: "JSON",
	T_xml: "XML",
	T__xml: "_XML",
	T_pg_node_tree: "PG_NODE_TREE",
	T__json: "_JSON",
	T_smgr: "SMGR",
	T_index_am_handler: "INDEX_AM_HANDLER",
	T_point: "POINT",
	T_lseg: "LSEG",
	T_path: "PATH",
	T_box: "BOX",
	T_polygon: "POLYGON",
	T_line: "LINE",
	T__line: "_LINE",
	T_cidr: "CIDR",
	T__cidr: "_CIDR",
	T_float4: "FLOAT4",
	T_float8: "FLOAT8",
	T_abstime: "ABSTIME",
	T_reltime: "RELTIME",
	T_tinterval: "TINTERVAL",
	T_unknown: "UNKNOWN",
	T_circle: "CIRCLE",
	T__circle: "_CIRCLE",
	T_money: "MONEY",
	T__money: "_MONEY",
	T_macaddr: "MACADDR",
	T_inet: "INET",
	T__bool: "_BOOL",
	T__bytea: "_BYTEA",
	T__char: "_CHAR",
	T__name: "_NAME",
	T__int2: "_INT2",
	T__int2vector: "_INT2VECTOR",
	T__int4: "_INT4",
	T__regproc: "_REGPROC",
	T__text: "_TEXT",
	T__tid: "_TID",
	T__xid: "_XID",
	T__cid: "_CID",
	T__oidvector: "_OIDVECTOR",
	T__bpchar: "_BPCHAR",
	T__varchar: "_VARCHAR",
	T__int8: "_INT8",
	T__point: "_POINT",
	T__lseg: "_LSEG",
	T__path: "_PATH",
	T__box: "_BOX",
	T__float4: "_FLOAT4",
	T__float8: "_FLOAT8",
	T__abstime: "_ABSTIME",
	T__reltime: "_RELTIME",
	T__tinterval: "_TINTERVAL",
	T__polygon: "_POLYGON",
	T__oid: "_OID",
	T_aclitem: "ACLITEM",
	T__aclitem: "_ACLITEM",
	T__macaddr: "_MACADDR",
	T__inet: "_INET",
	T_bpchar: "BPCHAR",
	T_varchar: "VARCHAR",
	T_date: "DATE",
	T_time: "TIME",
	T_timestamp: "TIMESTAMP",
	T__timestamp: "_TIMESTAMP",
	T__date: "_DATE",
	T__time: "_TIME",
	T_timestamptz: "TIMESTAMPTZ",
	T__timestamptz: "_TIMESTAMPTZ",
	T_interval: "INTERVAL",
	T__interval: "_INTERVAL",
	T__numeric: "_NUMERIC",
	T_pg_database: "PG_DATABASE",
	T__cstring: "_CSTRING",
	T_timetz: "TIMETZ",
	T__timetz: "_TIMETZ",
	T_bit: "BIT",
	T__bit: "_BIT",
	T_varbit: "VARBIT",
	T__varbit: "_VARBIT",
	T_numeric: "NUMERIC",
	T_refcursor: "REFCURSOR",
	T__refcursor: "_REFCURSOR",
	T_regprocedure: "REGPROCEDURE",
	T_regoper: "REGOPER",
	T_regoperator: "REGOPERATOR",
	T_regclass: "REGCLASS",
	T_regtype: "REGTYPE",
	T__regprocedure: "_REGPROCEDURE",
	T__regoper: "_REGOPER",
	T__regoperator: "_REGOPERATOR",
	T__regclass: "_REGCLASS",
	T__regtype: "_REGTYPE",
	T_record: "RECORD",
	T_cstring: "CSTRING",
	T_any: "ANY",
	T_anyarray: "ANYARRAY",
	T_void: "VOID",
	T_trigger: "TRIGGER",
	T_language_handler: "LANGUAGE_HANDLER",
	T_internal: "INTERNAL",
	T_opaque: "OPAQUE",
	T_anyelement: "ANYELEMENT",
	T__record: "_RECORD",
	T_anynonarray: "ANYNONARRAY",
	T_pg_authid: "PG_AUTHID",
	T_pg_auth_members: "PG_AUTH_MEMBERS",
	T__txid_snapshot: "_TXID_SNAPSHOT",
	T_uuid: "UUID",
	T__uuid: "_UUID",
	T_txid_snapshot: "TXID_SNAPSHOT",
	T_fdw_handler: "FDW_HANDLER",
	T_pg_lsn: "PG_LSN",
	T__pg_lsn: "_PG_LSN",
	T_tsm_handler: "TSM_HANDLER",
	T_anyenum: "ANYENUM",
	T_tsvector: "TSVECTOR",
	T_tsquery: "TSQUERY",
	T_gtsvector: "GTSVECTOR",
	T__tsvector: "_TSVECTOR",
	T__gtsvector: "_GTSVECTOR",
	T__tsquery: "_TSQUERY",
	T_regconfig: "REGCONFIG",
	T__regconfig: "_REGCONFIG",
	T_regdictionary: "REGDICTIONARY",
	T__regdictionary: "_REGDICTIONARY",
	T_jsonb: "JSONB",
	T__jsonb: "_JSONB",
	T_anyrange: "ANYRANGE",
	T_event_trigger: "EVENT_TRIGGER",
	T_int4range: "INT4RANGE",
	T__int4range: "_INT4RANGE",
	T_numrange: "NUMRANGE",
	T__numrange: "_NUMRANGE",
	T_tsrange: "TSRANGE",
	T__tsrange: "_TSRANGE",
	T_tstzrange: "TSTZRANGE",
	T__tstzrange: "_TSTZRANGE",
	T_daterange: "DATERANGE",
	T__daterange: "_DATERANGE",
	T_int8range: "INT8RANGE",
	T__int8range: "_INT8RANGE",
	T_pg_shseclabel: "PG_SHSECLABEL",
	T_regnamespace: "REGNAMESPACE",
	T__regnamespace: "_REGNAMESPACE",
	T_regrole: "REGROLE",
	T__regrole: "_REGROLE",
}

================================================
FILE: internal/testutil/testutil.go
================================================
package testutil

import (
	"path/filepath"
	"reflect"
	"runtime"
	"strconv"
	"strings"

	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
)

// Diff returns "" when got and want are equal according to cmp.Diff, and
// otherwise a "-got +want" header followed by the diff. Unexported fields of
// every type are compared (the Exporter returns true for all types), and
// empty and nil slices/maps are treated as equal (cmpopts.EquateEmpty).
func Diff[T any](got, want T) string {
	diff := cmp.Diff(
		got, want,
		cmp.Exporter(func(typ reflect.Type) bool { return true }),
		cmpopts.EquateEmpty(),
	)
	if diff != "" {
		return "\n-got +want\n" + diff
	}
	return ""
}

// Callers describes the chain of call sites leading to the function that
// called Callers. Entries are formatted as "file.go:line" (base name only),
// ordered outermost first and joined with " -> " inside "\n[...]". It
// returns "" when only a single call site remains after trimming.
func Callers() string {
	var pc [50]uintptr
	n := runtime.Callers(2, pc[:]) // skip runtime.Callers + Callers
	callsites := make([]string, 0, n)
	frames := runtime.CallersFrames(pc[:n])
	// The loop stops as soon as Next reports more == false, so the final frame
	// returned by Next is not recorded.
	for frame, more := frames.Next(); more; frame, more = frames.Next() {
		callsites = append(callsites, frame.File+":"+strconv.Itoa(frame.Line))
	}
	// NOTE(review): this panics if no frames were recorded above.
	callsites = callsites[:len(callsites)-1] // skip testing.tRunner
	if len(callsites) == 1 {
		return ""
	}
	var b strings.Builder
	b.WriteString("\n[")
	for i := len(callsites) - 1; i >= 0; i-- {
		if i < len(callsites)-1 {
			b.WriteString(" -> ")
		}
		b.WriteString(filepath.Base(callsites[i]))
	}
	b.WriteString("]")
	return b.String()
}

================================================
FILE: joins.go
================================================
package sq

import (
	"bytes"
	"context"
	"fmt"
)

// Join operators.
const (
	JoinInner = "JOIN"
	JoinLeft  = "LEFT JOIN"
	JoinRight = "RIGHT JOIN"
	JoinFull  = "FULL JOIN"
	JoinCross = "CROSS JOIN"
)

// JoinTable represents a join on a table.
type JoinTable struct {
	// JoinOperator is the join keyword(s), typically one of the Join*
	// constants.
	JoinOperator string

	// Table is the table (or subquery) being joined.
	Table Table

	// OnPredicate is the ON condition, if any.
	OnPredicate Predicate

	// UsingFields is the USING column list, if any.
	UsingFields []Field
}

// JoinUsing creates a new JoinTable with the USING operator.
func JoinUsing(table Table, fields ...Field) JoinTable {
	return JoinTable{JoinOperator: JoinInner, Table: table, UsingFields: fields}
}

// Join creates a new JoinTable with the JOIN operator.
func Join(table Table, predicates ...Predicate) JoinTable {
	return CustomJoin(JoinInner, table, predicates...)
}

// LeftJoin creates a new JoinTable with the LEFT JOIN operator.
func LeftJoin(table Table, predicates ...Predicate) JoinTable {
	return CustomJoin(JoinLeft, table, predicates...)
}

// FullJoin creates a new JoinTable with the FULL JOIN operator.
func FullJoin(table Table, predicates ...Predicate) JoinTable {
	return CustomJoin(JoinFull, table, predicates...)
}

// CrossJoin creates a new JoinTable with the CROSS JOIN operator. A cross
// join takes no predicates.
func CrossJoin(table Table) JoinTable {
	return CustomJoin(JoinCross, table)
}

// CustomJoin creates a new JoinTable with a custom join operator.
//
// With no predicates the join has no ON predicate. A single predicate is
// used as-is, and two or more are combined with And.
func CustomJoin(joinOperator string, table Table, predicates ...Predicate) JoinTable {
	join := JoinTable{JoinOperator: joinOperator, Table: table}
	if n := len(predicates); n == 1 {
		join.OnPredicate = predicates[0]
	} else if n > 1 {
		join.OnPredicate = And(predicates...)
	}
	return join
}

// WriteSQL implements the SQLWriter interface.
func (join JoinTable) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { if join.JoinOperator == "" { join.JoinOperator = JoinInner } variadicPredicate, isVariadic := join.OnPredicate.(VariadicPredicate) hasNoPredicate := join.OnPredicate == nil && len(variadicPredicate.Predicates) == 0 && len(join.UsingFields) == 0 if hasNoPredicate && (join.JoinOperator == JoinInner || join.JoinOperator == JoinLeft || join.JoinOperator == JoinRight || join.JoinOperator == JoinFull) && // exclude sqlite from this check because they allow join without predicate dialect != DialectSQLite { return fmt.Errorf("%s requires at least one predicate specified", join.JoinOperator) } if dialect == DialectSQLite && (join.JoinOperator == JoinRight || join.JoinOperator == JoinFull) { return fmt.Errorf("sqlite does not support %s", join.JoinOperator) } // JOIN buf.WriteString(string(join.JoinOperator) + " ") if join.Table == nil { return fmt.Errorf("joining on a nil table") } // _, isQuery := join.Table.(Query) if isQuery { buf.WriteString("(") } err := join.Table.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return err } if isQuery { buf.WriteString(")") } // AS if tableAlias := getAlias(join.Table); tableAlias != "" { buf.WriteString(" AS " + QuoteIdentifier(dialect, tableAlias) + quoteTableColumns(dialect, join.Table)) } else if isQuery && dialect != DialectSQLite { return fmt.Errorf("%s %s subquery must have alias", dialect, join.JoinOperator) } if isVariadic { // ON VariadicPredicate buf.WriteString(" ON ") variadicPredicate.Toplevel = true err = variadicPredicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return err } } else if join.OnPredicate != nil { // ON Predicate buf.WriteString(" ON ") err = join.OnPredicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return err } } else if len(join.UsingFields) > 0 { // USING Fields buf.WriteString(" USING (") err = writeFieldsWithPrefix(ctx, 
dialect, buf, args, params, join.UsingFields, "", false) if err != nil { return err } buf.WriteString(")") } return nil } func writeJoinTables(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, joinTables []JoinTable) error { var err error for i, joinTable := range joinTables { if i > 0 { buf.WriteString(" ") } err = joinTable.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("join #%d: %w", i+1, err) } } return nil } ================================================ FILE: joins_test.go ================================================ package sq import "testing" func TestJoinTables(t *testing.T) { type ACTOR struct { TableStruct ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField } a := New[ACTOR]("a") tests := []TestTable{{ description: "JoinUsing", item: JoinUsing(a, a.FIRST_NAME, a.LAST_NAME), wantQuery: "JOIN actor AS a USING (first_name, last_name)", }, { description: "Join without operator", item: CustomJoin("", a, a.ACTOR_ID.Eq(a.ACTOR_ID), a.FIRST_NAME.Ne(a.LAST_NAME)), wantQuery: "JOIN actor AS a ON a.actor_id = a.actor_id AND a.first_name <> a.last_name", }, { description: "Join", item: Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)), wantQuery: "JOIN actor AS a ON a.actor_id = a.actor_id", }, { description: "LeftJoin", item: LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)), wantQuery: "LEFT JOIN actor AS a ON a.actor_id = a.actor_id", }, { description: "Right Join", item: JoinTable{JoinOperator: JoinRight, Table: a, OnPredicate: a.ACTOR_ID.Eq(a.ACTOR_ID)}, wantQuery: "RIGHT JOIN actor AS a ON a.actor_id = a.actor_id", }, { description: "FullJoin", item: FullJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)), wantQuery: "FULL JOIN actor AS a ON a.actor_id = a.actor_id", }, { description: "CrossJoin", item: CrossJoin(a), wantQuery: "CROSS JOIN actor AS a", }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } notOKTests := 
[]TestTable{{ description: "full join has no predicate", item: FullJoin(a), }, { description: "sqlite does not support full join", dialect: DialectSQLite, item: FullJoin(a, Expr("TRUE")), }, { description: "table is nil", item: Join(nil, Expr("TRUE")), }, { description: "UsingField returns err", item: JoinUsing(a, nil), }} for _, tt := range notOKTests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assertNotOK(t) }) } errTests := []TestTable{{ description: "table err", item: Join(FaultySQL{}, a.ACTOR_ID.Eq(a.ACTOR_ID)), }, { description: "VariadicPredicate err", item: Join(a, And(FaultySQL{})), }, { description: "predicate err", item: Join(a, FaultySQL{}), }} for _, tt := range errTests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assertErr(t, ErrFaultySQL) }) } } ================================================ FILE: logger.go ================================================ package sq import ( "bytes" "context" "database/sql" "fmt" "io" "log" "os" "path/filepath" "strconv" "strings" "sync/atomic" "time" ) // QueryStats represents the statistics from running a query. type QueryStats struct { // Dialect of the query. Dialect string // Query string. Query string // Args slice provided with the query string. Args []any // Params maps param names back to arguments in the args slice (by index). Params map[string][]int // Err is the error from running the query. Err error // RowCount from running the query. Not valid for Exec(). RowCount sql.NullInt64 // RowsAffected by running the query. Not valid for // FetchOne/FetchAll/FetchCursor. RowsAffected sql.NullInt64 // LastInsertId of the query. LastInsertId sql.NullInt64 // Exists is the result of FetchExists(). Exists sql.NullBool // When the query started at. StartedAt time.Time // Time taken by the query. TimeTaken time.Duration // The caller file where the query was invoked. CallerFile string // The line in the caller file that invoked the query. 
CallerLine int // The name of the function where the query was invoked. CallerFunction string // The results from running the query (if it was provided). Results string } // LogSettings are the various log settings taken into account when producing // the QueryStats. type LogSettings struct { // Dispatch logging asynchronously (logs may arrive out of order which can be confusing, but it won't block function calls). LogAsynchronously bool // Include time taken by the query. IncludeTime bool // Include caller (filename and line number). IncludeCaller bool // Include fetched results. IncludeResults int } // SqLogger represents a logger for the sq package. type SqLogger interface { // SqLogSettings should populate a LogSettings struct, which influences // what is added into the QueryStats. SqLogSettings(context.Context, *LogSettings) // SqLogQuery logs a query when for the given QueryStats. SqLogQuery(context.Context, QueryStats) } type sqLogger struct { logger *log.Logger config LoggerConfig } // LoggerConfig is the config used for the sq logger. type LoggerConfig struct { // Dispatch logging asynchronously (logs may arrive out of order which can be confusing, but it won't block function calls). LogAsynchronously bool // Show time taken by the query. ShowTimeTaken bool // Show caller (filename and line number). ShowCaller bool // Show fetched results. ShowResults int // If true, logs are shown as plaintext (no color). NoColor bool // Verbose query interpolation, which shows the query before and after // interpolating query arguments. The logged query is interpolated by // default, InterpolateVerbose only controls whether the query before // interpolation is shown. To disable query interpolation entirely, look at // HideArgs. InterpolateVerbose bool // Explicitly hides arguments when logging the query (only the query // placeholders will be shown). 
HideArgs bool } var _ SqLogger = (*sqLogger)(nil) var defaultLogger = NewLogger(os.Stdout, "", log.LstdFlags, LoggerConfig{ ShowTimeTaken: true, ShowCaller: true, }) var verboseLogger = NewLogger(os.Stdout, "", log.LstdFlags, LoggerConfig{ ShowTimeTaken: true, ShowCaller: true, ShowResults: 5, InterpolateVerbose: true, }) // NewLogger returns a new SqLogger. func NewLogger(w io.Writer, prefix string, flag int, config LoggerConfig) SqLogger { return &sqLogger{ logger: log.New(w, prefix, flag), config: config, } } // SqLogSettings implements the SqLogger interface. func (l *sqLogger) SqLogSettings(ctx context.Context, settings *LogSettings) { settings.LogAsynchronously = l.config.LogAsynchronously settings.IncludeTime = l.config.ShowTimeTaken settings.IncludeCaller = l.config.ShowCaller settings.IncludeResults = l.config.ShowResults } // SqLogQuery implements the SqLogger interface. func (l *sqLogger) SqLogQuery(ctx context.Context, queryStats QueryStats) { var reset, red, green, blue, purple string envNoColor, _ := strconv.ParseBool(os.Getenv("NO_COLOR")) if !l.config.NoColor && !envNoColor { reset = colorReset red = colorRed green = colorGreen blue = colorBlue purple = colorPurple } buf := bufpool.Get().(*bytes.Buffer) buf.Reset() defer bufpool.Put(buf) if queryStats.Err == nil { buf.WriteString(green + "[OK]" + reset) } else { buf.WriteString(red + "[FAIL]" + reset) } if l.config.HideArgs { buf.WriteString(" " + queryStats.Query + ";") } else if !l.config.InterpolateVerbose { if queryStats.Err != nil { buf.WriteString(" " + queryStats.Query + ";") if len(queryStats.Args) > 0 { buf.WriteString(" [") } for i := 0; i < len(queryStats.Args); i++ { if i > 0 { buf.WriteString(", ") } buf.WriteString(fmt.Sprintf("%#v", queryStats.Args[i])) } if len(queryStats.Args) > 0 { buf.WriteString("]") } } else { query, err := Sprintf(queryStats.Dialect, queryStats.Query, queryStats.Args) if err != nil { query += " " + err.Error() } buf.WriteString(" " + query + ";") } } if 
queryStats.Err != nil { errStr := queryStats.Err.Error() if i := strings.IndexByte(errStr, '\n'); i < 0 { buf.WriteString(blue + " err" + reset + "={" + queryStats.Err.Error() + "}") } } if l.config.ShowTimeTaken { buf.WriteString(blue + " timeTaken" + reset + "=" + queryStats.TimeTaken.String()) } if queryStats.RowCount.Valid { buf.WriteString(blue + " rowCount" + reset + "=" + strconv.FormatInt(queryStats.RowCount.Int64, 10)) } if queryStats.RowsAffected.Valid { buf.WriteString(blue + " rowsAffected" + reset + "=" + strconv.FormatInt(queryStats.RowsAffected.Int64, 10)) } if queryStats.LastInsertId.Valid { buf.WriteString(blue + " lastInsertId" + reset + "=" + strconv.FormatInt(queryStats.LastInsertId.Int64, 10)) } if queryStats.Exists.Valid { buf.WriteString(blue + " exists" + reset + "=" + strconv.FormatBool(queryStats.Exists.Bool)) } if l.config.ShowCaller { buf.WriteString(blue + " caller" + reset + "=" + queryStats.CallerFile + ":" + strconv.Itoa(queryStats.CallerLine) + ":" + filepath.Base(queryStats.CallerFunction)) } if !l.config.HideArgs && l.config.InterpolateVerbose { buf.WriteString("\n" + purple + "----[ Executing query ]----" + reset) buf.WriteString("\n" + queryStats.Query + "; " + fmt.Sprintf("%#v", queryStats.Args)) buf.WriteString("\n" + purple + "----[ with bind values ]----" + reset) query, err := Sprintf(queryStats.Dialect, queryStats.Query, queryStats.Args) query += ";" if err != nil { query += " " + err.Error() } buf.WriteString("\n" + query) } if l.config.ShowResults > 0 && queryStats.Err == nil { buf.WriteString("\n" + purple + "----[ Fetched result ]----" + reset) buf.WriteString(queryStats.Results) if queryStats.RowCount.Int64 > int64(l.config.ShowResults) { buf.WriteString("\n...\n(Fetched " + strconv.FormatInt(queryStats.RowCount.Int64, 10) + " rows)") } } if buf.Len() > 0 { l.logger.Println(buf.String()) } } // Log wraps a DB and adds logging to it. 
func Log(db DB) interface {
	DB
	SqLogger
} {
	// Embedding both interfaces in an anonymous struct exposes db's methods
	// alongside defaultLogger's SqLogSettings/SqLogQuery.
	return struct {
		DB
		SqLogger
	}{DB: db, SqLogger: defaultLogger}
}

// VerboseLog wraps a DB and adds verbose logging to it.
func VerboseLog(db DB) interface {
	DB
	SqLogger
} {
	// Same shape as Log, but backed by verboseLogger.
	return struct {
		DB
		SqLogger
	}{DB: db, SqLogger: verboseLogger}
}

// defaultLogSettings holds the func(context.Context, *LogSettings) stored by
// SetDefaultLogSettings.
var defaultLogSettings atomic.Value

// SetDefaultLogSettings sets the function to configure the default
// LogSettings. This value is not used unless SetDefaultLogQuery is also
// configured.
func SetDefaultLogSettings(logSettings func(context.Context, *LogSettings)) {
	defaultLogSettings.Store(logSettings)
}

// defaultLogQuery holds the func(context.Context, QueryStats) stored by
// SetDefaultLogQuery.
var defaultLogQuery atomic.Value

// SetDefaultLogQuery sets the default logging function to call for all
// queries (if a logger is not explicitly passed in).
func SetDefaultLogQuery(logQuery func(context.Context, QueryStats)) {
	defaultLogQuery.Store(logQuery)
}

// sqLogStruct adapts a pair of plain functions to the SqLogger interface. A
// nil function makes the corresponding method a no-op.
type sqLogStruct struct {
	logSettings func(context.Context, *LogSettings)
	logQuery    func(context.Context, QueryStats)
}

var _ SqLogger = (*sqLogStruct)(nil)

// SqLogSettings calls l.logSettings if it is set.
func (l *sqLogStruct) SqLogSettings(ctx context.Context, logSettings *LogSettings) {
	if l.logSettings == nil {
		return
	}
	l.logSettings(ctx, logSettings)
}

// SqLogQuery calls l.logQuery if it is set.
func (l *sqLogStruct) SqLogQuery(ctx context.Context, queryStats QueryStats) {
	if l.logQuery == nil {
		return
	}
	l.logQuery(ctx, queryStats)
}

// ANSI escape sequences used to colorize log output.
// NOTE(review): colorGray uses the same code as colorWhite (97, bright
// white); gray/bright-black is usually 90. Confirm the intent before
// changing, since other files may depend on the current value.
const (
	colorReset  = "\x1b[0m"
	colorRed    = "\x1b[91m"
	colorGreen  = "\x1b[92m"
	colorYellow = "\x1b[93m"
	colorBlue   = "\x1b[94m"
	colorPurple = "\x1b[95m"
	colorCyan   = "\x1b[96m"
	colorGray   = "\x1b[97m"
	colorWhite  = "\x1b[97m"
)

================================================
FILE: logger_test.go
================================================
package sq

import (
	"bytes"
	"context"
	"database/sql"
	"fmt"
	"log"
	"testing"
	"time"

	"github.com/bokwoon95/sq/internal/testutil"
)

// TestLogger checks the settings exposed by Log/VerboseLog and the exact
// text SqLogQuery produces for various QueryStats/LoggerConfig combinations.
func TestLogger(t *testing.T) {
	type TT struct {
		description string
		ctx         context.Context
		stats       QueryStats
		config      LoggerConfig
		wantOutput  string
	}

	// assert logs tt.stats through an sqLogger that writes to an in-memory
	// buffer with no prefix and no flags, then compares the output.
	assert := func(t *testing.T, tt TT) {
		if tt.ctx == nil {
			tt.ctx = context.Background()
		}
		buf := &bytes.Buffer{}
		logger := sqLogger{
			logger: log.New(buf, "", 0),
			config: tt.config,
		}
		logger.SqLogQuery(tt.ctx, tt.stats)
		if diff := testutil.Diff(buf.String(), tt.wantOutput); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	}

	t.Run("Log VerboseLog", func(t *testing.T) {
		t.Parallel()
		var logSettings LogSettings
		Log(nil).SqLogSettings(context.Background(), &logSettings)
		diff := testutil.Diff(logSettings, LogSettings{
			LogAsynchronously: false,
			IncludeTime:       true,
			IncludeCaller:     true,
			IncludeResults:    0,
		})
		if diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		VerboseLog(nil).SqLogSettings(context.Background(), &logSettings)
		diff = testutil.Diff(logSettings, LogSettings{
			LogAsynchronously: false,
			IncludeTime:       true,
			IncludeCaller:     true,
			IncludeResults:    5,
		})
		if diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})

	t.Run("no color", func(t *testing.T) {
		var tt TT
		tt.config.NoColor = true
		tt.stats.Query = "SELECT 1"
		tt.wantOutput = "[OK] SELECT 1;\n"
		assert(t, tt)
	})

	// The remaining cases use the default (colored) output.
	tests := []TT{{
		description: "err",
		stats: QueryStats{
			Query: "SELECT 1",
			Err:   fmt.Errorf("lorem ipsum"),
		},
		wantOutput: "\x1b[91m[FAIL]\x1b[0m SELECT 1;\x1b[94m err\x1b[0m={lorem ipsum}\n",
	}, {
		description: "HideArgs",
		config:      LoggerConfig{HideArgs: true},
		stats: QueryStats{
			Query: "SELECT ?",
			Args:  []any{1},
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT ?;\n",
	}, {
		description: "RowCount",
		stats: QueryStats{
			Query:    "SELECT 1",
			RowCount: sql.NullInt64{Valid: true, Int64: 3},
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;\x1b[94m rowCount\x1b[0m=3\n",
	}, {
		description: "RowsAffected",
		stats: QueryStats{
			Query:        "SELECT 1",
			RowsAffected: sql.NullInt64{Valid: true, Int64: 5},
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;\x1b[94m rowsAffected\x1b[0m=5\n",
	}, {
		description: "LastInsertId",
		stats: QueryStats{
			Query:        "SELECT 1",
			LastInsertId: sql.NullInt64{Valid: true, Int64: 7},
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;\x1b[94m lastInsertId\x1b[0m=7\n",
	}, {
		description: "Exists",
		stats: QueryStats{
			Query:  "SELECT EXISTS (SELECT 1)",
			Exists: sql.NullBool{Valid: true, Bool: true},
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT EXISTS (SELECT 1);\x1b[94m exists\x1b[0m=true\n",
	}, {
		description: "ShowCaller",
		config:      LoggerConfig{ShowCaller: true},
		stats: QueryStats{
			Query:          "SELECT 1",
			CallerFile:     "file.go",
			CallerLine:     22,
			CallerFunction: "someFunc",
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;\x1b[94m caller\x1b[0m=file.go:22:someFunc\n",
	}, {
		// With InterpolateVerbose the query is not on the first line; it
		// appears in the "Executing query" / "with bind values" sections.
		description: "Verbose",
		config:      LoggerConfig{InterpolateVerbose: true, ShowTimeTaken: true},
		stats: QueryStats{
			Query:     "SELECT ?, ?",
			Args:      []any{1, "bob"},
			TimeTaken: 300 * time.Millisecond,
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m\x1b[94m timeTaken\x1b[0m=300ms" +
			"\n\x1b[95m----[ Executing query ]----\x1b[0m" +
			"\nSELECT ?, ?; []interface {}{1, \"bob\"}" +
			"\n\x1b[95m----[ with bind values ]----\x1b[0m" +
			"\nSELECT 1, 'bob';\n",
	}, {
		description: "ShowResults",
		config:      LoggerConfig{ShowResults: 1},
		stats: QueryStats{
			Query:   "SELECT 1",
			Results: "\nlorem ipsum dolor sit amet",
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;" +
			"\n\x1b[95m----[ Fetched result ]----\x1b[0m" +
			"\nlorem ipsum dolor sit amet\n",
	}}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			assert(t, tt)
		})
	}
}

================================================
FILE: misc.go
================================================
package sq

import (
	"bytes"
	"context"
	"fmt"
	"strings"
)

// ValueExpression represents an SQL value that is passed in as an argument to
// a prepared query.
type ValueExpression struct {
	value any
	alias string
}

var _ interface {
	Field
	Predicate
	Any
} = (*ValueExpression)(nil)

// Value returns a new ValueExpression.
func Value(value any) ValueExpression {
	return ValueExpression{value: value}
}

// WriteSQL implements the SQLWriter interface.
func (e ValueExpression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	return WriteValue(ctx, dialect, buf, args, params, e.value)
}

// As returns a copy of the ValueExpression that carries the given alias.
func (e ValueExpression) As(alias string) ValueExpression {
	return ValueExpression{value: e.value, alias: alias}
}

// In returns an 'expr IN (val)' Predicate built from the wrapped value.
func (e ValueExpression) In(val any) Predicate {
	return In(e.value, val)
}

// Eq returns an 'expr = val' Predicate built from the wrapped value.
func (e ValueExpression) Eq(val any) Predicate {
	return Eq(e.value, val)
}

// Ne returns an 'expr <> val' Predicate built from the wrapped value.
func (e ValueExpression) Ne(val any) Predicate {
	return Ne(e.value, val)
}

// Lt returns an 'expr < val' Predicate built from the wrapped value.
func (e ValueExpression) Lt(val any) Predicate {
	return Lt(e.value, val)
}

// Le returns an 'expr <= val' Predicate built from the wrapped value.
func (e ValueExpression) Le(val any) Predicate {
	return Le(e.value, val)
}

// Gt returns an 'expr > val' Predicate built from the wrapped value.
func (e ValueExpression) Gt(val any) Predicate {
	return Gt(e.value, val)
}

// Ge returns an 'expr >= val' Predicate built from the wrapped value.
func (e ValueExpression) Ge(val any) Predicate {
	return Ge(e.value, val)
}

// GetAlias reports the alias set with As, or "" if none was set.
func (e ValueExpression) GetAlias() string {
	return e.alias
}

// IsField marks ValueExpression as a Field.
func (e ValueExpression) IsField() {}

// IsArray marks ValueExpression as an Array.
func (e ValueExpression) IsArray() {}

// IsBinary marks ValueExpression as a Binary.
func (e ValueExpression) IsBinary() {}

// IsBoolean marks ValueExpression as a Boolean.
func (e ValueExpression) IsBoolean() {}

// IsEnum marks ValueExpression as an Enum.
func (e ValueExpression) IsEnum() {}

// IsJSON marks ValueExpression as a JSON.
func (e ValueExpression) IsJSON() {}

// IsNumber marks ValueExpression as a Number.
func (e ValueExpression) IsNumber() {}

// IsString marks ValueExpression as a String.
func (e ValueExpression) IsString() {}

// IsTime implements the Time interface.
func (e ValueExpression) IsTime() {}

// IsUUID marks ValueExpression as a UUID.
func (e ValueExpression) IsUUID() {}

// LiteralValue is an SQL value written directly into the query text instead
// of being sent as a bind argument. Because the value becomes part of the
// SQL itself, this can expose the query to SQL injection: only use it for
// trusted values such as literals and constants.
type LiteralValue struct {
	value any
	alias string
}

// Compile-time check of the interfaces LiteralValue claims to implement.
var _ interface {
	Field
	Predicate
	Binary
	Boolean
	Number
	String
	Time
} = (*LiteralValue)(nil)

// Literal wraps value in a LiteralValue.
func Literal(value any) LiteralValue {
	return LiteralValue{value: value}
}

// WriteSQL implements the SQLWriter interface. The value is rendered inline
// with Sprint for the given dialect; args and params are left untouched.
func (v LiteralValue) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	s, err := Sprint(dialect, v.value)
	if err == nil {
		buf.WriteString(s)
	}
	return err
}

// As returns a copy of the LiteralValue that carries the given alias.
func (v LiteralValue) As(alias string) LiteralValue {
	return LiteralValue{value: v.value, alias: alias}
}

// In returns a 'literal IN (val)' Predicate.
func (v LiteralValue) In(val any) Predicate {
	return In(v, val)
}

// Eq returns a 'literal = val' Predicate.
func (v LiteralValue) Eq(val any) Predicate {
	return Eq(v, val)
}

// Ne returns a 'literal <> val' Predicate.
func (v LiteralValue) Ne(val any) Predicate {
	return Ne(v, val)
}

// Lt returns a 'literal < val' Predicate.
func (v LiteralValue) Lt(val any) Predicate {
	return Lt(v, val)
}

// Le returns a 'literal <= val' Predicate.
func (v LiteralValue) Le(val any) Predicate {
	return Le(v, val)
}

// Gt returns a 'literal > val' Predicate.
func (v LiteralValue) Gt(val any) Predicate {
	return Gt(v, val)
}

// Ge returns a 'literal >= val' Predicate.
func (v LiteralValue) Ge(val any) Predicate {
	return Ge(v, val)
}

// GetAlias reports the alias set with As, or "" if none was set.
func (v LiteralValue) GetAlias() string {
	return v.alias
}

// IsField implements the Field interface.
func (v LiteralValue) IsField() {}

// IsBinary marks LiteralValue as a Binary.
func (v LiteralValue) IsBinary() {}

// IsBoolean marks LiteralValue as a Boolean.
func (v LiteralValue) IsBoolean() {}

// IsNumber marks LiteralValue as a Number.
func (v LiteralValue) IsNumber() {}

// IsString marks LiteralValue as a String.
func (v LiteralValue) IsString() {}

// IsTime marks LiteralValue as a Time.
func (v LiteralValue) IsTime() {}

// DialectExpression is an SQL expression whose rendering depends on the
// dialect. The Result of the first case whose Dialect matches is written;
// Default is written when no case matches.
type DialectExpression struct {
	Default any
	Cases   DialectCases
}

// DialectCases is a slice of DialectCase.
type DialectCases = []DialectCase

// DialectCase pairs a dialect with the result a DialectExpression renders
// for it.
type DialectCase struct {
	Dialect string
	Result  any
}

var _ interface {
	Table
	Field
	Predicate
	Any
} = (*DialectExpression)(nil)

// DialectValue returns a new DialectExpression that uses value as its
// default.
func DialectValue(value any) DialectExpression {
	return DialectExpression{Default: value}
}

// DialectExpr returns a new DialectExpression whose default is the
// expression Expr(format, values...).
func DialectExpr(format string, values ...any) DialectExpression {
	return DialectExpression{Default: Expr(format, values...)}
}

// WriteSQL implements the SQLWriter interface.
func (e DialectExpression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	chosen := e.Default
	for _, c := range e.Cases {
		if c.Dialect == dialect {
			chosen = c.Result
			break
		}
	}
	return WriteValue(ctx, dialect, buf, args, params, chosen)
}

// DialectValue adds a new dialect-value pair to the DialectExpression.
func (e DialectExpression) DialectValue(dialect string, value any) DialectExpression { e.Cases = append(e.Cases, DialectCase{Dialect: dialect, Result: value}) return e } // DialectExpr adds a new dialect-expression pair to the DialectExpression. func (e DialectExpression) DialectExpr(dialect string, format string, values ...any) DialectExpression { e.Cases = append(e.Cases, DialectCase{Dialect: dialect, Result: Expr(format, values...)}) return e } // IsTable implements the Table interface. func (e DialectExpression) IsTable() {} // IsField implements the Field interface. func (e DialectExpression) IsField() {} // IsArray implements the Array interface. func (e DialectExpression) IsArray() {} // IsBinary implements the Binary interface. func (e DialectExpression) IsBinary() {} // IsBoolean implements the Boolean interface. func (e DialectExpression) IsBoolean() {} // IsEnum implements the Enum interface. func (e DialectExpression) IsEnum() {} // IsJSON implements the JSON interface. func (e DialectExpression) IsJSON() {} // IsNumber implements the Number interface. func (e DialectExpression) IsNumber() {} // IsString implements the String interface. func (e DialectExpression) IsString() {} // IsTime implements the Time interface. func (e DialectExpression) IsTime() {} // IsUUID implements the UUID interface. func (e DialectExpression) IsUUID() {} // CaseExpression represents an SQL CASE expression. type CaseExpression struct { alias string Cases PredicateCases Default any } // PredicateCases is a slice of PredicateCases. type PredicateCases = []PredicateCase // PredicateCase holds the result to be used for a given predicate in a // CaseExpression. type PredicateCase struct { Predicate Predicate Result any } var _ interface { Field Any } = (*CaseExpression)(nil) // CaseWhen returns a new CaseExpression. 
func CaseWhen(predicate Predicate, result any) CaseExpression {
	return CaseExpression{
		Cases: PredicateCases{{Predicate: predicate, Result: result}},
	}
}

// WriteSQL implements the SQLWriter interface. It renders
// CASE WHEN p1 THEN r1 [WHEN p2 THEN r2 ...] [ELSE default] END, omitting
// the ELSE branch when Default is nil. It returns an error for an expression
// with no cases, and wraps any error from rendering a predicate, result or
// default with the position at which it occurred.
func (e CaseExpression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	// Validate before writing so an empty expression does not leave a
	// dangling "CASE" in buf.
	if len(e.Cases) == 0 {
		return fmt.Errorf("CaseExpression empty")
	}
	buf.WriteString("CASE")
	for i, c := range e.Cases {
		buf.WriteString(" WHEN ")
		if err := WriteValue(ctx, dialect, buf, args, params, c.Predicate); err != nil {
			return fmt.Errorf("CASE #%d WHEN: %w", i+1, err)
		}
		buf.WriteString(" THEN ")
		if err := WriteValue(ctx, dialect, buf, args, params, c.Result); err != nil {
			return fmt.Errorf("CASE #%d THEN: %w", i+1, err)
		}
	}
	if e.Default != nil {
		buf.WriteString(" ELSE ")
		if err := WriteValue(ctx, dialect, buf, args, params, e.Default); err != nil {
			return fmt.Errorf("CASE ELSE: %w", err)
		}
	}
	buf.WriteString(" END")
	return nil
}

// When adds a new predicate-result pair to the CaseExpression.
func (e CaseExpression) When(predicate Predicate, result any) CaseExpression {
	// Clip the capacity before appending so the returned expression never
	// shares a backing array with e; otherwise sibling expressions built
	// from the same base could overwrite each other's cases.
	e.Cases = append(e.Cases[:len(e.Cases):len(e.Cases)], PredicateCase{Predicate: predicate, Result: result})
	return e
}

// Else sets the fallback result of the CaseExpression.
func (e CaseExpression) Else(fallback any) CaseExpression {
	e.Default = fallback
	return e
}

// As returns a new CaseExpression with the given alias.
func (e CaseExpression) As(alias string) CaseExpression {
	e.alias = alias
	return e
}

// GetAlias returns the alias of the CaseExpression.
func (e CaseExpression) GetAlias() string { return e.alias }

// IsField implements the Field interface.
func (e CaseExpression) IsField() {}

// IsArray implements the Array interface.
func (e CaseExpression) IsArray() {}

// IsBinary implements the Binary interface.
func (e CaseExpression) IsBinary() {}

// IsBoolean implements the Boolean interface.
func (e CaseExpression) IsBoolean() {}

// IsEnum marks CaseExpression as an Enum.
func (e CaseExpression) IsEnum() {}

// IsJSON marks CaseExpression as a JSON.
func (e CaseExpression) IsJSON() {}

// IsNumber marks CaseExpression as a Number.
func (e CaseExpression) IsNumber() {}

// IsString marks CaseExpression as a String.
func (e CaseExpression) IsString() {}

// IsTime marks CaseExpression as a Time.
func (e CaseExpression) IsTime() {}

// IsUUID marks CaseExpression as a UUID.
func (e CaseExpression) IsUUID() {}

// SimpleCaseExpression represents an SQL simple CASE expression, which
// compares Expression against each case's Value:
// CASE expr WHEN value THEN result ... [ELSE default] END.
type SimpleCaseExpression struct {
	alias      string
	Expression any
	Cases      SimpleCases
	Default    any
}

// SimpleCases is a slice of SimpleCase.
type SimpleCases = []SimpleCase

// SimpleCase pairs a value to compare against with the result a
// SimpleCaseExpression yields when it matches.
type SimpleCase struct {
	Value  any
	Result any
}

var _ interface {
	Field
	Any
} = (*SimpleCaseExpression)(nil)

// Case starts a new SimpleCaseExpression that switches on expression; add
// branches with When and a fallback with Else.
func Case(expression any) SimpleCaseExpression {
	return SimpleCaseExpression{Expression: expression}
}

// WriteSQL implements the SQLWriter interface.
func (e SimpleCaseExpression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { buf.WriteString("CASE ") if len(e.Cases) == 0 { return fmt.Errorf("SimpleCaseExpression empty") } var err error err = WriteValue(ctx, dialect, buf, args, params, e.Expression) if err != nil { return fmt.Errorf("CASE: %w", err) } for i, Case := range e.Cases { buf.WriteString(" WHEN ") err = WriteValue(ctx, dialect, buf, args, params, Case.Value) if err != nil { return fmt.Errorf("CASE #%d WHEN: %w", i+1, err) } buf.WriteString(" THEN ") err = WriteValue(ctx, dialect, buf, args, params, Case.Result) if err != nil { return fmt.Errorf("CASE #%d THEN: %w", i+1, err) } } if e.Default != nil { buf.WriteString(" ELSE ") err = WriteValue(ctx, dialect, buf, args, params, e.Default) if err != nil { return fmt.Errorf("CASE ELSE: %w", err) } } buf.WriteString(" END") return nil } // When adds a new value-result pair to the SimpleCaseExpression. func (e SimpleCaseExpression) When(value any, result any) SimpleCaseExpression { e.Cases = append(e.Cases, SimpleCase{Value: value, Result: result}) return e } // Else sets the fallback result of the SimpleCaseExpression. func (e SimpleCaseExpression) Else(fallback any) SimpleCaseExpression { e.Default = fallback return e } // As returns a new SimpleCaseExpression with the given alias. func (e SimpleCaseExpression) As(alias string) SimpleCaseExpression { e.alias = alias return e } // GetAlias returns the alias of the SimpleCaseExpression. func (e SimpleCaseExpression) GetAlias() string { return e.alias } // IsField implements the Field interface. func (e SimpleCaseExpression) IsField() {} // IsArray implements the Array interface. func (e SimpleCaseExpression) IsArray() {} // IsBinary implements the Binary interface. func (e SimpleCaseExpression) IsBinary() {} // IsBoolean implements the Boolean interface. func (e SimpleCaseExpression) IsBoolean() {} // IsEnum implements the Enum interface. 
func (e SimpleCaseExpression) IsEnum() {}

// IsJSON marks SimpleCaseExpression as a JSON.
func (e SimpleCaseExpression) IsJSON() {}

// IsNumber marks SimpleCaseExpression as a Number.
func (e SimpleCaseExpression) IsNumber() {}

// IsString marks SimpleCaseExpression as a String.
func (e SimpleCaseExpression) IsString() {}

// IsTime marks SimpleCaseExpression as a Time.
func (e SimpleCaseExpression) IsTime() {}

// IsUUID marks SimpleCaseExpression as a UUID.
func (e SimpleCaseExpression) IsUUID() {}

// Count returns the SQL aggregate expression COUNT(field).
func Count(field Field) Expression {
	return Expr("COUNT({})", field)
}

// CountStar returns the SQL aggregate expression COUNT(*).
func CountStar() Expression {
	return Expr("COUNT(*)")
}

// Sum returns the SQL aggregate expression SUM(num).
func Sum(num Number) Expression {
	return Expr("SUM({})", num)
}

// Avg returns the SQL aggregate expression AVG(num).
func Avg(num Number) Expression {
	return Expr("AVG({})", num)
}

// Min returns the SQL aggregate expression MIN(field).
func Min(field Field) Expression {
	return Expr("MIN({})", field)
}

// Max returns the SQL aggregate expression MAX(field).
func Max(field Field) Expression {
	return Expr("MAX({})", field)
}

// SelectValues is a table literal made of one SELECT statement per row,
// UNION-ed together, e.g.
//
//	(SELECT 1 AS a, 2 AS b, 3 AS c
//	UNION ALL
//	SELECT 4, 5, 6
//	UNION ALL
//	SELECT 7, 8, 9) AS tbl
//
// WriteSQL renders only the SELECT ... UNION ALL ... part. Column names, if
// any, are attached as aliases to the first row only.
type SelectValues struct {
	Alias     string
	Columns   []string
	RowValues [][]any
}

var _ interface {
	Query
	Table
} = (*SelectValues)(nil)

// WriteSQL implements the SQLWriter interface.
func (vs SelectValues) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { var err error for i, rowvalue := range vs.RowValues { if i > 0 { buf.WriteString(" UNION ALL ") } if len(vs.Columns) > 0 && len(rowvalue) != len(vs.Columns) { return fmt.Errorf("rowvalue #%d: got %d values, want %d values (%s)", i+1, len(rowvalue), len(vs.Columns), strings.Join(vs.Columns, ", ")) } buf.WriteString("SELECT ") for j, value := range rowvalue { if j > 0 { buf.WriteString(", ") } err = WriteValue(ctx, dialect, buf, args, params, value) if err != nil { return fmt.Errorf("rowvalue #%d value #%d: %w", i+1, j+1, err) } if i == 0 && j < len(vs.Columns) { buf.WriteString(" AS " + QuoteIdentifier(dialect, vs.Columns[j])) } } } return nil } // Field returns a new field qualified by the SelectValues' alias. func (vs SelectValues) Field(name string) AnyField { return NewAnyField(name, TableStruct{alias: vs.Alias}) } // SetFetchableFields implements the Query interface. It always returns false // as the second result. func (vs SelectValues) SetFetchableFields([]Field) (query Query, ok bool) { return vs, false } // GetDialect implements the Query interface. It always returns an empty // string. func (vs SelectValues) GetDialect() string { return "" } // GetAlias returns the alias of the SelectValues. func (vs SelectValues) GetAlias() string { return vs.Alias } // IsTable implements the Table interface. func (vs SelectValues) IsTable() {} // TableValues represents a table literal created by the VALUES clause e.g. // // (VALUES // // (1, 2, 3), // (4, 5, 6), // (7, 8, 9)) AS tbl (a, b, c) type TableValues struct { Alias string Columns []string RowValues [][]any } var _ interface { Query Table } = (*TableValues)(nil) // WriteSQL implements the SQLWriter interface. 
func (vs TableValues) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { if len(vs.RowValues) == 0 { return nil } var err error buf.WriteString("VALUES ") for i, rowvalue := range vs.RowValues { if len(vs.Columns) > 0 && len(vs.Columns) != len(rowvalue) { return fmt.Errorf("rowvalue #%d: got %d values, want %d values (%s)", i+1, len(rowvalue), len(vs.Columns), strings.Join(vs.Columns, ", ")) } if i > 0 { buf.WriteString(", ") } if dialect == DialectMySQL { buf.WriteString("ROW(") } else { buf.WriteString("(") } for j, value := range rowvalue { if j > 0 { buf.WriteString(", ") } err = WriteValue(ctx, dialect, buf, args, params, value) if err != nil { return fmt.Errorf("rowvalue #%d value #%d: %w", i+1, j+1, err) } } buf.WriteString(")") } return nil } // Field returns a new field qualified by the TableValues' alias. func (vs TableValues) Field(name string) AnyField { return NewAnyField(name, TableStruct{alias: vs.Alias}) } // SetFetchableFields implements the Query interface. It always returns false // as the second result. func (vs TableValues) SetFetchableFields([]Field) (query Query, ok bool) { return vs, false } // GetDialect implements the Query interface. It always returns an empty // string. func (vs TableValues) GetDialect() string { return "" } // GetAlias returns the alias of the TableValues. func (vs TableValues) GetAlias() string { return vs.Alias } // GetColumns returns the names of the columns in the TableValues. func (vs TableValues) GetColumns() []string { return vs.Columns } // IsTable implements the Table interface. 
func (vs TableValues) IsTable() {}

================================================
FILE: misc_test.go
================================================
package sq

import (
	"bytes"
	"context"
	"testing"
	"time"

	"github.com/bokwoon95/sq/internal/testutil"
)

// TestValueExpression checks that ValueExpression renders its value as a
// bind parameter, on its own and on the left-hand side of each comparison.
func TestValueExpression(t *testing.T) {
	t.Run("alias", func(t *testing.T) {
		t.Parallel()
		expr := Value(1).As("num")
		if diff := testutil.Diff(expr.GetAlias(), "num"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})

	tests := []TestTable{{
		description: "basic",
		item:        Value(Param("xyz", 42)),
		wantQuery:   "?",
		wantArgs:    []any{42},
		wantParams:  map[string][]int{"xyz": {0}},
	}, {
		description: "In",
		item:        Value(1).In([]int{18, 21, 32}),
		wantQuery:   "? IN (?, ?, ?)",
		wantArgs:    []any{1, 18, 21, 32},
	}, {
		description: "Eq",
		item:        Value(1).Eq(34),
		wantQuery:   "? = ?",
		wantArgs:    []any{1, 34},
	}, {
		description: "Ne",
		item:        Value(1).Ne(34),
		wantQuery:   "? <> ?",
		wantArgs:    []any{1, 34},
	}, {
		description: "Lt",
		item:        Value(1).Lt(34),
		wantQuery:   "? < ?",
		wantArgs:    []any{1, 34},
	}, {
		description: "Le",
		item:        Value(1).Le(34),
		wantQuery:   "? <= ?",
		wantArgs:    []any{1, 34},
	}, {
		description: "Gt",
		item:        Value(1).Gt(34),
		wantQuery:   "? > ?",
		wantArgs:    []any{1, 34},
	}, {
		description: "Ge",
		item:        Value(1).Ge(34),
		wantQuery:   "? >= ?",
		wantArgs:    []any{1, 34},
	}}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assert(t)
		})
	}
}

// TestLiteralExpression checks that LiteralValue is interpolated inline (no
// bind args) for several value types and comparisons.
func TestLiteralExpression(t *testing.T) {
	t.Run("alias", func(t *testing.T) {
		t.Parallel()
		expr := Literal(1).As("num")
		if diff := testutil.Diff(expr.GetAlias(), "num"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})

	tests := []TestTable{{
		description: "binary",
		item:        Literal([]byte{0xab, 0xcd, 0xef}),
		wantQuery:   "x'abcdef'",
	}, {
		description: "time",
		item:        Literal(time.Unix(0, 0).UTC()),
		wantQuery:   "'1970-01-01 00:00:00'",
	}, {
		description: "In",
		item:        Literal(1).In([]any{Literal(18), Literal(21), Literal(32)}),
		wantQuery:   "1 IN (18, 21, 32)",
	}, {
		description: "Eq",
		item:        Literal(true).Eq(Literal(false)),
		wantQuery:   "TRUE = FALSE",
	}, {
		description: "Ne",
		item:        Literal("one").Ne(Literal("thirty four")),
		wantQuery:   "'one' <> 'thirty four'",
	}, {
		description: "Lt",
		item:        Literal(1).Lt(Literal(34)),
		wantQuery:   "1 < 34",
	}, {
		description: "Le",
		item:        Literal(1).Le(Literal(34)),
		wantQuery:   "1 <= 34",
	}, {
		description: "Gt",
		item:        Literal(1).Gt(Literal(34)),
		wantQuery:   "1 > 34",
	}, {
		description: "Ge",
		item:        Literal(1).Ge(Literal(34)),
		wantQuery:   "1 >= 34",
	}}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assert(t)
		})
	}
}

// TestDialectExpression checks that each dialect picks its own case and
// that the default is used when no dialect is set.
func TestDialectExpression(t *testing.T) {
	t.Parallel()
	expr := DialectValue(Expr("default")).
		DialectValue(DialectSQLite, Expr("sqlite")).
		DialectValue(DialectPostgres, Expr("postgres")).
		DialectValue(DialectMySQL, Expr("mysql")).
		DialectExpr(DialectSQLServer, "{}", Expr("sqlserver"))
	var tt TestTable
	tt.item = expr
	// default
	tt.wantQuery = "default"
	tt.assert(t)
	// sqlite
	tt.dialect = DialectSQLite
	tt.wantQuery = "sqlite"
	tt.assert(t)
	// postgres
	tt.dialect = DialectPostgres
	tt.wantQuery = "postgres"
	tt.assert(t)
	// mysql
	tt.dialect = DialectMySQL
	tt.wantQuery = "mysql"
	tt.assert(t)
	// sqlserver
	tt.dialect = DialectSQLServer
	tt.wantQuery = "sqlserver"
	tt.assert(t)
}

// TestCaseExpressions covers rendering, aliasing, the empty-expression
// error, and error propagation for CaseExpression and SimpleCaseExpression.
func TestCaseExpressions(t *testing.T) {
	t.Run("name and alias", func(t *testing.T) {
		t.Parallel()
		// CaseExpression
		caseExpr := CaseWhen(Value(true), 1).As("result_a")
		if diff := testutil.Diff(caseExpr.GetAlias(), "result_a"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// SimpleCaseExpression
		simpleCaseExpr := Case(1).When(1, 2).As("result_b")
		if diff := testutil.Diff(simpleCaseExpr.GetAlias(), "result_b"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})

	t.Run("CaseExpression", func(t *testing.T) {
		t.Parallel()
		TestTable{
			item:      CaseWhen(Expr("x = y"), 1).When(Expr("a = b"), 2).Else(3),
			wantQuery: "CASE WHEN x = y THEN ? WHEN a = b THEN ? ELSE ? END",
			wantArgs:  []any{1, 2, 3},
		}.assert(t)
	})

	t.Run("SimpleCaseExpression", func(t *testing.T) {
		t.Parallel()
		TestTable{
			item:      Case(Expr("a")).When(1, 2).When(3, 4).Else(5),
			wantQuery: "CASE a WHEN ? THEN ? WHEN ? THEN ? ELSE ? END",
			wantArgs:  []any{1, 2, 3, 4, 5},
		}.assert(t)
	})

	t.Run("CaseExpression cannot be empty", func(t *testing.T) {
		t.Parallel()
		TestTable{item: CaseExpression{}}.assertNotOK(t)
	})

	t.Run("SimpleCaseExpression cannot be empty", func(t *testing.T) {
		t.Parallel()
		TestTable{item: SimpleCaseExpression{}}.assertNotOK(t)
	})

	// Each case injects FaultySQL at a different position and expects
	// ErrFaultySQL to be propagated.
	errTests := []TestTable{{
		description: "CASE WHEN predicate err",
		item:        CaseWhen(FaultySQL{}, 1),
	}, {
		description: "CASE WHEN result err",
		item:        CaseWhen(Value(true), FaultySQL{}),
	}, {
		description: "CASE WHEN fallback err",
		item:        CaseWhen(Value(true), 1).Else(FaultySQL{}),
	}, {
		description: "CASE expression err",
		item:        Case(FaultySQL{}).When(1, 2),
	}, {
		description: "CASE value err",
		item:        Case(1).When(FaultySQL{}, 2),
	}, {
		description: "CASE result err",
		item:        Case(1).When(2, FaultySQL{}),
	}, {
		description: "CASE fallback err",
		item:        Case(1).When(2, 3).Else(FaultySQL{}),
	}}

	for _, tt := range errTests {
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assertErr(t, ErrFaultySQL)
		})
	}
}

// TestSelectValues checks SelectValues' accessors and its UNION ALL
// rendering, with and without column aliases.
func TestSelectValues(t *testing.T) {
	// Local TestTable shadows the package-level helper; WriteSQL is called
	// directly below.
	type TestTable struct {
		description string
		dialect     string
		item        SelectValues
		wantQuery   string
		wantArgs    []any
	}

	t.Run("dialect alias and fields", func(t *testing.T) {
		selectValues := SelectValues{
			Alias: "aaa",
		}
		if diff := testutil.Diff(selectValues.GetAlias(), "aaa"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		if diff := testutil.Diff(selectValues.GetDialect(), ""); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		_, ok := selectValues.SetFetchableFields(nil)
		if diff := testutil.Diff(ok, false); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		gotField, _, _ := ToSQL("", selectValues.Field("bbb"), nil)
		if diff := testutil.Diff(gotField, "aaa.bbb"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})

	tests := []TestTable{{
		description: "empty",
		item:        SelectValues{},
		wantQuery:   "",
		wantArgs:    nil,
	}, {
		description: "no columns",
		item: SelectValues{
			RowValues: [][]any{
				{1, 2, 3},
				{4, 5, 6},
				{7, 8, 9},
			},
		},
		wantQuery: "SELECT ?, ?, ?" +
			" UNION ALL SELECT ?, ?, ?" +
			" UNION ALL SELECT ?, ?, ?",
		wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9},
	}, {
		// Column aliases appear only on the first SELECT.
		description: "postgres",
		dialect:     DialectPostgres,
		item: SelectValues{
			Columns: []string{"a", "b", "c"},
			RowValues: [][]any{
				{1, 2, 3},
				{4, 5, 6},
				{7, 8, 9},
			},
		},
		wantQuery: "SELECT $1 AS a, $2 AS b, $3 AS c" +
			" UNION ALL SELECT $4, $5, $6" +
			" UNION ALL SELECT $7, $8, $9",
		wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9},
	}}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			var buf bytes.Buffer
			var gotArgs []any
			err := tt.item.WriteSQL(context.Background(), tt.dialect, &buf, &gotArgs, nil)
			if err != nil {
				t.Fatal(testutil.Callers(), err)
			}
			gotQuery := buf.String()
			if diff := testutil.Diff(gotQuery, tt.wantQuery); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
			if diff := testutil.Diff(gotArgs, tt.wantArgs); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
		})
	}
}

// TestTableValues checks TableValues' accessors and its VALUES rendering,
// including the ROW(...) form used for MySQL.
func TestTableValues(t *testing.T) {
	// Local TestTable shadows the package-level helper; WriteSQL is called
	// directly below.
	type TestTable struct {
		description string
		dialect     string
		item        TableValues
		wantQuery   string
		wantArgs    []any
	}

	t.Run("dialect alias columns and fields", func(t *testing.T) {
		tableValues := TableValues{
			Alias:   "aaa",
			Columns: []string{"a", "b", "c"},
		}
		if diff := testutil.Diff(tableValues.GetAlias(), "aaa"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		if diff := testutil.Diff(tableValues.GetDialect(), ""); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		_, ok := tableValues.SetFetchableFields(nil)
		if diff := testutil.Diff(ok, false); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		gotColumns := tableValues.GetColumns()
		wantColumns := []string{"a", "b", "c"}
		if diff := testutil.Diff(gotColumns, wantColumns); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		gotField, _, _ := ToSQL("", tableValues.Field("bbb"), nil)
		wantField := "aaa.bbb"
		if diff := testutil.Diff(gotField, wantField); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})

	tests := []TestTable{{
		description: "empty",
		item:        TableValues{},
		wantQuery:   "",
		wantArgs:    nil,
	}, {
		description: "no columns",
		item: TableValues{
			RowValues: [][]any{
				{1, 2, 3},
				{4, 5, 6},
				{7, 8, 9},
			},
		},
		wantQuery: "VALUES (?, ?, ?)" +
			", (?, ?, ?)" +
			", (?, ?, ?)",
		wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9},
	}, {
		description: "postgres",
		dialect:     DialectPostgres,
		item: TableValues{
			Columns: []string{"a", "b", "c"},
			RowValues: [][]any{
				{1, 2, 3},
				{4, 5, 6},
				{7, 8, 9},
			},
		},
		wantQuery: "VALUES ($1, $2, $3)" +
			", ($4, $5, $6)" +
			", ($7, $8, $9)",
		wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9},
	}, {
		description: "mysql",
		dialect:     DialectMySQL,
		item: TableValues{
			Columns: []string{"a", "b", "c"},
			RowValues: [][]any{
				{1, 2, 3},
				{4, 5, 6},
				{7, 8, 9},
			},
		},
		wantQuery: "VALUES ROW(?, ?, ?)" +
			", ROW(?, ?, ?)" +
			", ROW(?, ?, ?)",
		wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9},
	}}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			var buf bytes.Buffer
			var gotArgs []any
			err := tt.item.WriteSQL(context.Background(), tt.dialect, &buf, &gotArgs, nil)
			if err != nil {
				t.Fatal(testutil.Callers(), err)
			}
			gotQuery := buf.String()
			if diff := testutil.Diff(gotQuery, tt.wantQuery); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
			if diff := testutil.Diff(gotArgs, tt.wantArgs); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
		})
	}
}

================================================
FILE: row_column.go
================================================
package sq

import (
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"fmt"
	"path/filepath"
	"reflect"
	"runtime"
	"strconv"
	"strings"
	"time"

	"github.com/bokwoon95/sq/internal/googleuuid"
	"github.com/bokwoon95/sq/internal/pqarray"
)

// Row represents the state of a row after a call to rows.Next().
type Row struct { dialect string sqlRows *sql.Rows runningIndex int fields []Field scanDest []any queryIsStatic bool columns []string columnTypes []*sql.ColumnType values []any columnIndex map[string]int } // Column returns the names of the columns returned by the query. This method // can only be called in a rowmapper if it is paired with a raw SQL query e.g. // Queryf("SELECT * FROM my_table"). Otherwise, an error will be returned. func (row *Row) Columns() []string { if row.queryIsStatic { return row.columns } if row.sqlRows == nil { return nil } columns, err := row.sqlRows.Columns() if err != nil { panic(fmt.Errorf(callsite(1)+"sqlRows.Columns: %w", err)) } return columns } // ColumnTypes returns the column types returned by the query. This method can // only be called in a rowmapper if it is paired with a raw SQL query e.g. // Queryf("SELECT * FROM my_table"). Otherwise, an error will be returned. func (row *Row) ColumnTypes() []*sql.ColumnType { if row.queryIsStatic { return row.columnTypes } if row.sqlRows == nil { return nil } columnTypes, err := row.sqlRows.ColumnTypes() if err != nil { panic(fmt.Errorf(callsite(1)+"sqlRows.ColumnTypes: %w", err)) } return columnTypes } // Values returns the values of the current row. This method can only be called // in a rowmapper if it is paired with a raw SQL query e.g. Queryf("SELECT * // FROM my_table"). Otherwise, an error will be returned. func (row *Row) Values() []any { if row.queryIsStatic { values := make([]any, len(row.values)) copy(values, row.values) return values } if row.sqlRows == nil { return nil } columns, err := row.sqlRows.Columns() if err != nil { panic(fmt.Errorf(callsite(1)+"sqlRows.Columns: %w", err)) } values := make([]any, len(columns)) scanDest := make([]any, len(columns)) for i := range values { scanDest[i] = &values[i] } err = row.sqlRows.Scan(scanDest...) if err != nil { panic(fmt.Errorf(callsite(1)+"sqlRows.Scan: %w", err)) } return values } // Value returns the value of the expression. 
It is intended for use cases // where you only know the name of the column but not its type to scan into. // The underlying type of the value is determined by the database driver you // are using. func (row *Row) Value(format string, values ...any) any { if row.queryIsStatic { index, ok := row.columnIndex[format] if !ok { panic(fmt.Errorf(callsite(1)+"column %s is not present in query (available columns: %s)", format, strings.Join(row.columns, ", "))) } return row.values[index] } if row.sqlRows == nil { var value any row.fields = append(row.fields, Expr(format, values...)) row.scanDest = append(row.scanDest, &value) return nil } defer func() { row.runningIndex++ }() scanDest := row.scanDest[row.runningIndex].(*any) return *scanDest } // Scan scans the expression into destPtr. func (row *Row) Scan(destPtr any, format string, values ...any) { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call Scan for static queries")) } row.scan(destPtr, Expr(format, values...), 1) } // ScanField scans the field into destPtr. 
func (row *Row) ScanField(destPtr any, field Field) { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call ScanField for static queries")) } row.scan(destPtr, field, 1) } func (row *Row) scan(destPtr any, field Field, skip int) { if row.sqlRows == nil { row.fields = append(row.fields, field) switch destPtr.(type) { case *bool, *sql.NullBool: row.scanDest = append(row.scanDest, &sql.NullBool{}) case *float64, *sql.NullFloat64: row.scanDest = append(row.scanDest, &sql.NullFloat64{}) case *int32, *sql.NullInt32: row.scanDest = append(row.scanDest, &sql.NullInt32{}) case *int, *int64, *sql.NullInt64: row.scanDest = append(row.scanDest, &sql.NullInt64{}) case *string, *sql.NullString: row.scanDest = append(row.scanDest, &sql.NullString{}) case *time.Time, *sql.NullTime: row.scanDest = append(row.scanDest, &sql.NullTime{}) default: if reflect.TypeOf(destPtr).Kind() != reflect.Ptr { panic(fmt.Errorf(callsite(skip+1)+"cannot pass in non pointer value (%#v) as destPtr", destPtr)) } row.scanDest = append(row.scanDest, destPtr) } return } defer func() { row.runningIndex++ }() switch destPtr := destPtr.(type) { case *bool: scanDest := row.scanDest[row.runningIndex].(*sql.NullBool) *destPtr = scanDest.Bool case *sql.NullBool: scanDest := row.scanDest[row.runningIndex].(*sql.NullBool) *destPtr = *scanDest case *float64: scanDest := row.scanDest[row.runningIndex].(*sql.NullFloat64) *destPtr = scanDest.Float64 case *sql.NullFloat64: scanDest := row.scanDest[row.runningIndex].(*sql.NullFloat64) *destPtr = *scanDest case *int: scanDest := row.scanDest[row.runningIndex].(*sql.NullInt64) *destPtr = int(scanDest.Int64) case *int32: scanDest := row.scanDest[row.runningIndex].(*sql.NullInt32) *destPtr = scanDest.Int32 case *sql.NullInt32: scanDest := row.scanDest[row.runningIndex].(*sql.NullInt32) *destPtr = *scanDest case *int64: scanDest := row.scanDest[row.runningIndex].(*sql.NullInt64) *destPtr = scanDest.Int64 case *sql.NullInt64: scanDest := 
row.scanDest[row.runningIndex].(*sql.NullInt64) *destPtr = *scanDest case *string: scanDest := row.scanDest[row.runningIndex].(*sql.NullString) *destPtr = scanDest.String case *sql.NullString: scanDest := row.scanDest[row.runningIndex].(*sql.NullString) *destPtr = *scanDest case *time.Time: scanDest := row.scanDest[row.runningIndex].(*sql.NullTime) *destPtr = scanDest.Time case *sql.NullTime: scanDest := row.scanDest[row.runningIndex].(*sql.NullTime) *destPtr = *scanDest default: destValue := reflect.ValueOf(destPtr).Elem() srcValue := reflect.ValueOf(row.scanDest[row.runningIndex]).Elem() destValue.Set(srcValue) } } // Array scans the array expression into destPtr. The destPtr must be a pointer // to a []string, []int, []int64, []int32, []float64, []float32 or []bool. func (row *Row) Array(destPtr any, format string, values ...any) { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call Array for static queries")) } row.array(destPtr, Expr(format, values...), 1) } // ArrayField scans the array field into destPtr. The destPtr must be a pointer // to a []string, []int, []int64, []int32, []float64, []float32 or []bool. 
func (row *Row) ArrayField(destPtr any, field Array) { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call ArrayField for static queries")) } row.array(destPtr, field, 1) } func (row *Row) array(destPtr any, field Array, skip int) { if row.sqlRows == nil { if reflect.TypeOf(destPtr).Kind() != reflect.Ptr { panic(fmt.Errorf(callsite(skip+1)+"cannot pass in non pointer value (%#v) as destPtr", destPtr)) } if row.dialect == DialectPostgres { switch destPtr.(type) { case *[]string, *[]int, *[]int64, *[]int32, *[]float64, *[]float32, *[]bool: break default: panic(fmt.Errorf(callsite(skip+1)+"destptr (%T) must be either a pointer to a []string, []int, []int64, []int32, []float64, []float32 or []bool", destPtr)) } } row.fields = append(row.fields, field) row.scanDest = append(row.scanDest, &nullBytes{ dialect: row.dialect, displayType: displayTypeString, }) return } defer func() { row.runningIndex++ }() scanDest := row.scanDest[row.runningIndex].(*nullBytes) if !scanDest.valid { return } if row.dialect != DialectPostgres { err := json.Unmarshal(scanDest.bytes, destPtr) if err != nil { panic(fmt.Errorf(callsite(skip+1)+"unmarshaling json %q into %T: %w", string(scanDest.bytes), destPtr, err)) } return } switch destPtr := destPtr.(type) { case *[]string: var array pqarray.StringArray err := array.Scan(scanDest.bytes) if err != nil { panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to string array: %w", string(scanDest.bytes), err)) } *destPtr = array case *[]int: var array pqarray.Int64Array err := array.Scan(scanDest.bytes) if err != nil { panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to int64 array: %w", string(scanDest.bytes), err)) } *destPtr = (*destPtr)[:cap(*destPtr)] if len(*destPtr) < len(array) { *destPtr = make([]int, len(array)) } *destPtr = (*destPtr)[:len(array)] for i, num := range array { (*destPtr)[i] = int(num) } case *[]int64: var array pqarray.Int64Array err := array.Scan(scanDest.bytes) if err != nil { 
panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to int64 array: %w", string(scanDest.bytes), err)) } *destPtr = array case *[]int32: var array pqarray.Int32Array err := array.Scan(scanDest.bytes) if err != nil { panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to int32 array: %w", string(scanDest.bytes), err)) } *destPtr = array case *[]float64: var array pqarray.Float64Array err := array.Scan(scanDest.bytes) if err != nil { panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to float64 array: %w", string(scanDest.bytes), err)) } *destPtr = array case *[]float32: var array pqarray.Float32Array err := array.Scan(scanDest.bytes) if err != nil { panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to float32 array: %w", string(scanDest.bytes), err)) } *destPtr = array case *[]bool: var array pqarray.BoolArray err := array.Scan(scanDest.bytes) if err != nil { panic(fmt.Errorf(callsite(skip+1)+"unable to convert %q to bool array: %w", string(scanDest.bytes), err)) } *destPtr = array default: panic(fmt.Errorf(callsite(skip+1)+"destptr (%T) must be either a pointer to a []string, []int, []int64, []int32, []float64, []float32 or []bool", destPtr)) } } // Bytes returns the []byte value of the expression. 
func (row *Row) Bytes(format string, values ...any) []byte { if row.queryIsStatic { index, ok := row.columnIndex[format] if !ok { panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", "))) } value := row.values[index] switch value := value.(type) { case int64: panic(fmt.Errorf(callsite(1)+"%d is int64, not []byte", value)) case float64: panic(fmt.Errorf(callsite(1)+"%d is float64, not []byte", value)) case bool: panic(fmt.Errorf(callsite(1)+"%v is bool, not []byte", value)) case []byte: return value case string: return []byte(value) case time.Time: panic(fmt.Errorf(callsite(1)+"%v is time.Time, not []byte", value)) case nil: return nil default: panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not []byte", value)) } } if row.sqlRows == nil { row.fields = append(row.fields, Expr(format, values...)) row.scanDest = append(row.scanDest, &nullBytes{ dialect: row.dialect, }) return nil } defer func() { row.runningIndex++ }() scanDest := row.scanDest[row.runningIndex].(*nullBytes) var b []byte if scanDest.valid { b = make([]byte, len(scanDest.bytes)) copy(b, scanDest.bytes) } return b } // BytesField returns the []byte value of the field. func (row *Row) BytesField(field Binary) []byte { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call BytesField for static queries")) } if row.sqlRows == nil { row.fields = append(row.fields, field) row.scanDest = append(row.scanDest, &nullBytes{ dialect: row.dialect, }) return nil } defer func() { row.runningIndex++ }() scanDest := row.scanDest[row.runningIndex].(*nullBytes) var b []byte if scanDest.valid { b = make([]byte, len(scanDest.bytes)) copy(b, scanDest.bytes) } return b } // == Bool == // // Bool returns the bool value of the expression. 
func (row *Row) Bool(format string, values ...any) bool { if row.queryIsStatic { index, ok := row.columnIndex[format] if !ok { panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", "))) } value := row.values[index] switch value := value.(type) { case int64: if value == 1 { return true } if value == 0 { return false } panic(fmt.Errorf(callsite(1)+"%d is int64, not bool", value)) case float64: panic(fmt.Errorf(callsite(1)+"%d is float64, not bool", value)) case bool: return value case []byte: // Special case: go-mysql-driver returns everything as []byte. if string(value) == "1" { return true } if string(value) == "0" { return false } panic(fmt.Errorf(callsite(1)+"%#v is []byte, not bool", value)) case string: panic(fmt.Errorf(callsite(1)+"%q is string, not bool", value)) case time.Time: panic(fmt.Errorf(callsite(1)+"%v is time.Time, not bool", value)) case nil: return false default: panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not bool", value)) } } return row.NullBoolField(Expr(format, values...)).Bool } // BoolField returns the bool value of the field. func (row *Row) BoolField(field Boolean) bool { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call BoolField for static queries")) } return row.NullBoolField(field).Bool } // NullBool returns the sql.NullBool value of the expression. 
func (row *Row) NullBool(format string, values ...any) sql.NullBool { if row.queryIsStatic { index, ok := row.columnIndex[format] if !ok { panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", "))) } value := row.values[index] switch value := value.(type) { case int64: if value == 1 { return sql.NullBool{Bool: true, Valid: true} } if value == 0 { return sql.NullBool{Bool: false, Valid: true} } panic(fmt.Errorf(callsite(1)+"%d is int64, not bool", value)) case float64: panic(fmt.Errorf(callsite(1)+"%d is float64, not bool", value)) case bool: return sql.NullBool{Bool: value, Valid: true} case []byte: // Special case: go-mysql-driver returns everything as []byte. if string(value) == "1" { return sql.NullBool{Bool: true, Valid: true} } if string(value) == "0" { return sql.NullBool{Bool: false, Valid: true} } panic(fmt.Errorf(callsite(1)+"%d is []byte, not bool", value)) case string: panic(fmt.Errorf(callsite(1)+"%q is string, not bool", value)) case time.Time: panic(fmt.Errorf(callsite(1)+"%v is time.Time, not bool", value)) case nil: return sql.NullBool{} default: panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not bool", value)) } } return row.NullBoolField(Expr(format, values...)) } // NullBoolField returns the sql.NullBool value of the field. func (row *Row) NullBoolField(field Boolean) sql.NullBool { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call NullBoolField for static queries")) } if row.sqlRows == nil { row.fields = append(row.fields, field) row.scanDest = append(row.scanDest, &sql.NullBool{}) return sql.NullBool{} } defer func() { row.runningIndex++ }() scanDest := row.scanDest[row.runningIndex].(*sql.NullBool) return *scanDest } // Enum scans the enum expression into destPtr. 
func (row *Row) Enum(destPtr Enumeration, format string, values ...any) { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call Enum for static queries")) } row.enum(destPtr, Expr(format, values...), 1) } // EnumField scans the enum field into destPtr. func (row *Row) EnumField(destPtr Enumeration, field Enum) { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call EnumField for static queries")) } row.enum(destPtr, field, 1) } func (row *Row) enum(destPtr Enumeration, field Enum, skip int) { if row.sqlRows == nil { destType := reflect.TypeOf(destPtr) if destType.Kind() != reflect.Ptr { panic(fmt.Errorf(callsite(skip+1)+"cannot pass in non pointer value (%#v) as destPtr", destPtr)) } row.fields = append(row.fields, field) switch destType.Elem().Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.String: row.scanDest = append(row.scanDest, &sql.NullString{}) default: panic(fmt.Errorf(callsite(skip+1)+"underlying type of %[1]v is neither an integer or string (%[1]T)", destPtr)) } return } defer func() { row.runningIndex++ }() scanDest := row.scanDest[row.runningIndex].(*sql.NullString) names := destPtr.Enumerate() enumIndex := 0 destValue := reflect.ValueOf(destPtr).Elem() if scanDest.Valid { enumIndex = getEnumIndex(scanDest.String, names, destValue.Type()) } if enumIndex < 0 { panic(fmt.Errorf(callsite(skip+1)+"%q is not a valid %T", scanDest.String, destPtr)) } switch destValue.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: destValue.SetInt(int64(enumIndex)) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: destValue.SetUint(uint64(enumIndex)) case reflect.String: destValue.SetString(scanDest.String) } } // Float64 returns the float64 value of the expression. 
func (row *Row) Float64(format string, values ...any) float64 { if row.queryIsStatic { index, ok := row.columnIndex[format] if !ok { panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", "))) } value := row.values[index] switch value := value.(type) { case int64: return float64(value) case float64: return value case bool: panic(fmt.Errorf(callsite(1)+"%v is bool, not float64", value)) case []byte: // Special case: go-mysql-driver returns everything as []byte. n, err := strconv.ParseFloat(string(value), 64) if err != nil { panic(fmt.Errorf(callsite(1)+"%d is []byte, not float64", value)) } return n case string: panic(fmt.Errorf(callsite(1)+"%q is string, not float64", value)) case time.Time: panic(fmt.Errorf(callsite(1)+"%v is time.Time, not float64", value)) case nil: return 0 default: panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not float64", value)) } } return row.NullFloat64Field(Expr(format, values...)).Float64 } // Float64Field returns the float64 value of the field. func (row *Row) Float64Field(field Number) float64 { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call Float64Field for static queries")) } return row.NullFloat64Field(field).Float64 } // NullFloat64 returns the sql.NullFloat64 valye of the expression. func (row *Row) NullFloat64(format string, values ...any) sql.NullFloat64 { if row.queryIsStatic { index, ok := row.columnIndex[format] if !ok { panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", "))) } value := row.values[index] switch value := value.(type) { case int64: return sql.NullFloat64{Float64: float64(value), Valid: true} case float64: return sql.NullFloat64{Float64: value, Valid: true} case bool: panic(fmt.Errorf(callsite(1)+"%v is bool, not float64", value)) case []byte: // Special case: go-mysql-driver returns everything as []byte. 
n, err := strconv.ParseFloat(string(value), 64) if err != nil { panic(fmt.Errorf(callsite(1)+"%d is []byte, not float64", value)) } return sql.NullFloat64{Float64: n, Valid: true} case string: panic(fmt.Errorf(callsite(1)+"%q is string, not float64", value)) case time.Time: panic(fmt.Errorf(callsite(1)+"%v is time.Time, not float64", value)) case nil: return sql.NullFloat64{} default: panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not float64", value)) } } return row.NullFloat64Field(Expr(format, values...)) } // NullFloat64Field returns the sql.NullFloat64 value of the field. func (row *Row) NullFloat64Field(field Number) sql.NullFloat64 { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call NullFloat64Field for static queries")) } if row.sqlRows == nil { row.fields = append(row.fields, field) row.scanDest = append(row.scanDest, &sql.NullFloat64{}) return sql.NullFloat64{} } defer func() { row.runningIndex++ }() scanDest := row.scanDest[row.runningIndex].(*sql.NullFloat64) return *scanDest } // Int returns the int value of the expression. func (row *Row) Int(format string, values ...any) int { if row.queryIsStatic { index, ok := row.columnIndex[format] if !ok { panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", "))) } value := row.values[index] switch value := value.(type) { case int64: return int(value) case float64: return int(value) case bool: panic(fmt.Errorf(callsite(1)+"%v is bool, not int", value)) case []byte: // Special case: go-mysql-driver returns everything as []byte. 
n, err := strconv.Atoi(string(value)) if err != nil { panic(fmt.Errorf(callsite(1)+"%d is []byte, not int", value)) } return n case string: panic(fmt.Errorf(callsite(1)+"%q is string, not int", value)) case time.Time: panic(fmt.Errorf(callsite(1)+"%v is time.Time, not int", value)) case nil: return 0 default: panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not int", value)) } } return int(row.NullInt64Field(Expr(format, values...)).Int64) } // IntField returns the int value of the field. func (row *Row) IntField(field Number) int { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call IntField for static queries")) } return int(row.NullInt64Field(field).Int64) } // Int64 returns the int64 value of the expression. func (row *Row) Int64(format string, values ...any) int64 { if row.queryIsStatic { index, ok := row.columnIndex[format] if !ok { panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", "))) } value := row.values[index] switch value := value.(type) { case int64: return int64(value) case float64: return int64(value) case bool: panic(fmt.Errorf(callsite(1)+"%v is bool, not int64", value)) case []byte: // Special case: go-mysql-driver returns everything as []byte. n, err := strconv.ParseInt(string(value), 10, 64) if err != nil { panic(fmt.Errorf(callsite(1)+"%d is []byte, not int64", value)) } return n case string: panic(fmt.Errorf(callsite(1)+"%q is string, not int64", value)) case time.Time: panic(fmt.Errorf(callsite(1)+"%v is time.Time, not int64", value)) case nil: return 0 default: panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not int64", value)) } } return row.NullInt64Field(Expr(format, values...)).Int64 } // Int64Field returns the int64 value of the field. 
func (row *Row) Int64Field(field Number) int64 { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call Int64Field for static queries")) } return row.NullInt64Field(field).Int64 } // NullInt64 returns the sql.NullInt64 value of the expression. func (row *Row) NullInt64(format string, values ...any) sql.NullInt64 { if row.queryIsStatic { index, ok := row.columnIndex[format] if !ok { panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", "))) } value := row.values[index] switch value := value.(type) { case int64: return sql.NullInt64{Int64: value, Valid: true} case float64: return sql.NullInt64{Int64: int64(value), Valid: true} case bool: panic(fmt.Errorf(callsite(1)+"%v is bool, not int64", value)) case []byte: // Special case: go-mysql-driver returns everything as []byte. n, err := strconv.ParseInt(string(value), 10, 64) if err != nil { panic(fmt.Errorf(callsite(1)+"%d is []byte, not int64", value)) } return sql.NullInt64{Int64: n, Valid: true} case string: panic(fmt.Errorf(callsite(1)+"%q is string, not int64", value)) case time.Time: panic(fmt.Errorf(callsite(1)+"%v is time.Time, not int64", value)) case nil: return sql.NullInt64{} default: panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not int64", value)) } } return row.NullInt64Field(Expr(format, values...)) } // NullInt64Field returns the sql.NullInt64 value of the field. func (row *Row) NullInt64Field(field Number) sql.NullInt64 { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call NullInt64Field for static queries")) } if row.sqlRows == nil { row.fields = append(row.fields, field) row.scanDest = append(row.scanDest, &sql.NullInt64{}) return sql.NullInt64{} } defer func() { row.runningIndex++ }() scanDest := row.scanDest[row.runningIndex].(*sql.NullInt64) return *scanDest } // JSON scans the JSON expression into destPtr. 
func (row *Row) JSON(destPtr any, format string, values ...any) { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call JSON for static queries")) } row.json(destPtr, Expr(format, values...), 1) } // JSONField scans the JSON field into destPtr. func (row *Row) JSONField(destPtr any, field JSON) { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call JSONField for static queries")) } row.json(destPtr, field, 1) } func (row *Row) json(destPtr any, field JSON, skip int) { if row.sqlRows == nil { if reflect.TypeOf(destPtr).Kind() != reflect.Ptr { panic(fmt.Errorf(callsite(skip+1)+"cannot pass in non pointer value (%#v) as destPtr", destPtr)) } row.fields = append(row.fields, field) row.scanDest = append(row.scanDest, &nullBytes{ dialect: row.dialect, displayType: displayTypeString, }) return } defer func() { row.runningIndex++ }() scanDest := row.scanDest[row.runningIndex].(*nullBytes) if scanDest.valid { err := json.Unmarshal(scanDest.bytes, destPtr) if err != nil { _, file, line, _ := runtime.Caller(skip + 1) panic(fmt.Errorf(callsite(skip+1)+"unmarshaling json %q into %T: %w", file, line, string(scanDest.bytes), destPtr, err)) } } } // String returns the string value of the expression. 
func (row *Row) String(format string, values ...any) string { if row.queryIsStatic { index, ok := row.columnIndex[format] if !ok { panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", "))) } value := row.values[index] switch value := value.(type) { case int64: panic(fmt.Errorf(callsite(1)+"%d is int64, not string", value)) case float64: panic(fmt.Errorf(callsite(1)+"%d is float64, not string", value)) case bool: panic(fmt.Errorf(callsite(1)+"%v is bool, not string", value)) case []byte: return string(value) case string: return value case time.Time: panic(fmt.Errorf(callsite(1)+"%v is time.Time, not string", value)) case nil: return "" default: panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not string", value)) } } return row.NullStringField(Expr(format, values...)).String } // String returns the string value of the field. func (row *Row) StringField(field String) string { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call StringField for static queries")) } return row.NullStringField(field).String } // NullString returns the sql.NullString value of the expression. 
func (row *Row) NullString(format string, values ...any) sql.NullString { if row.queryIsStatic { index, ok := row.columnIndex[format] if !ok { panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", "))) } value := row.values[index] switch value := value.(type) { case int64: panic(fmt.Errorf(callsite(1)+"%d is int64, not string", value)) case float64: panic(fmt.Errorf(callsite(1)+"%d is float64, not string", value)) case bool: panic(fmt.Errorf(callsite(1)+"%v is bool, not string", value)) case []byte: return sql.NullString{String: string(value), Valid: true} case string: return sql.NullString{String: value, Valid: true} case time.Time: panic(fmt.Errorf(callsite(1)+"%v is time.Time, not string", value)) case nil: return sql.NullString{} default: panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not string", value)) } } return row.NullStringField(Expr(format, values...)) } // NullStringField returns the sql.NullString value of the field. func (row *Row) NullStringField(field String) sql.NullString { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call NullStringField for static queries")) } if row.sqlRows == nil { row.fields = append(row.fields, field) row.scanDest = append(row.scanDest, &sql.NullString{}) return sql.NullString{} } defer func() { row.runningIndex++ }() scanDest := row.scanDest[row.runningIndex].(*sql.NullString) return *scanDest } // https://github.com/mattn/go-sqlite3/blob/4396a38886da660e403409e35ef4a37906bf0975/sqlite3.go#L209 var sqliteTimestampFormats = []string{ "2006-01-02 15:04:05.999999999-07:00", "2006-01-02T15:04:05.999999999-07:00", "2006-01-02 15:04:05.999999999", "2006-01-02T15:04:05.999999999", "2006-01-02 15:04:05", "2006-01-02T15:04:05", "2006-01-02 15:04", "2006-01-02T15:04", "2006-01-02", } // Time returns the time.Time value of the expression. 
func (row *Row) Time(format string, values ...any) time.Time { if row.queryIsStatic { index, ok := row.columnIndex[format] if !ok { panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", "))) } value := row.values[index] switch value := value.(type) { case int64: panic(fmt.Errorf(callsite(1)+"%d is int64, not time.Time", value)) case float64: panic(fmt.Errorf(callsite(1)+"%d is float64, not time.Time", value)) case bool: panic(fmt.Errorf(callsite(1)+"%v is bool, not time.Time", value)) case []byte: // Special case: go-mysql-driver returns everything as []byte. s := strings.TrimSuffix(string(value), "Z") for _, format := range sqliteTimestampFormats { if t, err := time.ParseInLocation(format, s, time.UTC); err == nil { return t } } panic(fmt.Errorf(callsite(1)+"%d is []byte, not time.Time", value)) case string: panic(fmt.Errorf(callsite(1)+"%q is string, not time.Time", value)) case time.Time: return value case nil: return time.Time{} default: panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not time.Time", value)) } } return row.NullTimeField(Expr(format, values...)).Time } // Time returns the time.Time value of the field. func (row *Row) TimeField(field Time) time.Time { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call TimeField for static queries")) } return row.NullTimeField(field).Time } // NullTime returns the sql.NullTime value of the expression. 
func (row *Row) NullTime(format string, values ...any) sql.NullTime { if row.queryIsStatic { index, ok := row.columnIndex[format] if !ok { panic(fmt.Errorf(callsite(1)+"column %s does not exist (available columns: %s)", format, strings.Join(row.columns, ", "))) } value := row.values[index] switch value := value.(type) { case int64: panic(fmt.Errorf(callsite(1)+"%d is int64, not time.Time", value)) case float64: panic(fmt.Errorf(callsite(1)+"%d is float64, not time.Time", value)) case bool: panic(fmt.Errorf(callsite(1)+"%v is bool, not time.Time", value)) case []byte: // Special case: go-mysql-driver returns everything as []byte. s := strings.TrimSuffix(string(value), "Z") for _, format := range sqliteTimestampFormats { if t, err := time.ParseInLocation(format, s, time.UTC); err == nil { return sql.NullTime{Time: t, Valid: true} } } panic(fmt.Errorf(callsite(1)+"%d is []byte, not time.Time", value)) case string: panic(fmt.Errorf(callsite(1)+"%q is string, not time.Time", value)) case time.Time: return sql.NullTime{Time: value, Valid: true} case nil: return sql.NullTime{} default: panic(fmt.Errorf(callsite(1)+"%[1]v is %[1]T, not time.Time", value)) } } return row.NullTimeField(Expr(format, values...)) } // NullTimeField returns the sql.NullTime value of the field. func (row *Row) NullTimeField(field Time) sql.NullTime { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call NullTimeField for static queries")) } if row.sqlRows == nil { row.fields = append(row.fields, field) row.scanDest = append(row.scanDest, &sql.NullTime{}) return sql.NullTime{} } defer func() { row.runningIndex++ }() scanDest := row.scanDest[row.runningIndex].(*sql.NullTime) return *scanDest } // UUID scans the UUID expression into destPtr. 
func (row *Row) UUID(destPtr any, format string, values ...any) { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call UUID for static queries")) } row.uuid(destPtr, Expr(format, values...), 1) } // UUIDField scans the UUID field into destPtr. func (row *Row) UUIDField(destPtr any, field UUID) { if row.queryIsStatic { panic(fmt.Errorf(callsite(1) + "cannot call UUIDField for static queries")) } row.uuid(destPtr, field, 1) } func (row *Row) uuid(destPtr any, field UUID, skip int) { if row.sqlRows == nil { if _, ok := destPtr.(*[16]byte); !ok { if reflect.TypeOf(destPtr).Kind() != reflect.Ptr { panic(fmt.Errorf(callsite(skip+1)+"cannot pass in non pointer value (%#v) as destPtr", destPtr)) } destValue := reflect.ValueOf(destPtr).Elem() if destValue.Kind() != reflect.Array || destValue.Len() != 16 || destValue.Type().Elem().Kind() != reflect.Uint8 { panic(fmt.Errorf(callsite(skip+1)+"%T is not a pointer to a [16]byte", destPtr)) } } row.fields = append(row.fields, field) row.scanDest = append(row.scanDest, &nullBytes{ dialect: row.dialect, displayType: displayTypeUUID, }) return } defer func() { row.runningIndex++ }() scanDest := row.scanDest[row.runningIndex].(*nullBytes) var err error var uuid [16]byte if len(scanDest.bytes) == 16 { copy(uuid[:], scanDest.bytes) } else if len(scanDest.bytes) > 0 { uuid, err = googleuuid.ParseBytes(scanDest.bytes) if err != nil { panic(fmt.Errorf(callsite(skip+1)+"parsing %q as UUID string: %w", string(scanDest.bytes), err)) } } if destArrayPtr, ok := destPtr.(*[16]byte); ok { copy((*destArrayPtr)[:], uuid[:]) return } destValue := reflect.ValueOf(destPtr).Elem() for i := 0; i < 16; i++ { destValue.Index(i).Set(reflect.ValueOf(uuid[i])) } } // Column keeps track of what the values mapped to what Field in an // InsertQuery or SelectQuery. 
type Column struct { dialect string // determines if UPDATE or INSERT isUpdate bool // UPDATE assignments Assignments // INSERT rowStarted bool rowEnded bool firstField string insertColumns Fields rowValues RowValues } // Set maps the value to the Field. func (col *Column) Set(field Field, value any) { if field == nil { panic(fmt.Errorf(callsite(1) + "setting a nil field")) } // UPDATE mode if col.isUpdate { col.assignments = append(col.assignments, Set(field, value)) return } // INSERT mode name := toString(col.dialect, field) if name == "" { panic(fmt.Errorf(callsite(1) + "field name is empty")) } if !col.rowStarted { col.rowStarted = true col.firstField = name col.insertColumns = append(col.insertColumns, field) col.rowValues = append(col.rowValues, RowValue{value}) return } if col.rowStarted && name == col.firstField { if !col.rowEnded { col.rowEnded = true } // Start a new RowValue col.rowValues = append(col.rowValues, RowValue{value}) return } if !col.rowEnded { col.insertColumns = append(col.insertColumns, field) } // Append to last RowValue last := len(col.rowValues) - 1 col.rowValues[last] = append(col.rowValues[last], value) } // SetBytes maps the []byte value to the field. func (col *Column) SetBytes(field Binary, value []byte) { col.Set(field, value) } // SetBool maps the bool value to the field. func (col *Column) SetBool(field Boolean, value bool) { col.Set(field, value) } // SetFloat64 maps the float64 value to the field. func (col *Column) SetFloat64(field Number, value float64) { col.Set(field, value) } // SetInt maps the int value to the field. func (col *Column) SetInt(field Number, value int) { col.Set(field, value) } // SetInt64 maps the int64 value to the field. func (col *Column) SetInt64(field Number, value int64) { col.Set(field, value) } // SetString maps the string value to the field. func (col *Column) SetString(field String, value string) { col.Set(field, value) } // SetTime maps the time.Time value to the field. 
func (col *Column) SetTime(field Time, value time.Time) {
	col.Set(field, value)
}

// SetArray maps the array value to the field. The value should be []string,
// []int, []int64, []int32, []float64, []float32 or []bool.
func (col *Column) SetArray(field Array, value any) {
	wrapped := ArrayValue(value)
	col.Set(field, wrapped)
}

// SetEnum maps the enum value to the field.
func (col *Column) SetEnum(field Enum, value Enumeration) {
	wrapped := EnumValue(value)
	col.Set(field, wrapped)
}

// SetJSON maps the JSON value to the field. The value should be convertible
// to JSON using json.Marshal.
func (col *Column) SetJSON(field JSON, value any) {
	wrapped := JSONValue(value)
	col.Set(field, wrapped)
}

// SetUUID maps the UUID value to the field. The value's type or underlying
// type should be [16]byte.
func (col *Column) SetUUID(field UUID, value any) {
	wrapped := UUIDValue(value)
	col.Set(field, wrapped)
}

// callsite returns a "file.go:line: " prefix for the caller that is skip
// frames above the function calling callsite, or "" if it cannot be found.
func callsite(skip int) string {
	// The extra +1 skips callsite's own frame.
	if _, file, line, ok := runtime.Caller(skip + 1); ok {
		return filepath.Base(file) + ":" + strconv.Itoa(line) + ": "
	}
	return ""
}

// displayType controls how a nullBytes value is rendered by Value.
type displayType int8

const (
	displayTypeBinary displayType = iota // raw bytes
	displayTypeString                    // bytes rendered as a string
	displayTypeUUID                      // bytes rendered as a UUID
)

// nullBytes is used in place of scanning into *[]byte. We use *nullBytes
// instead of *[]byte because of the displayType field, which determines how to
// render the value to the user. This is important for logging the query
// results, because UUIDs/JSON/Arrays are all scanned into bytes but we don't
// want to display them as bytes (we need to convert them to UUID/JSON/Array
// strings instead).
type nullBytes struct { bytes []byte dialect string displayType displayType valid bool } func (n *nullBytes) Scan(value any) error { if value == nil { n.bytes, n.valid = nil, false return nil } n.valid = true switch value := value.(type) { case string: n.bytes = []byte(value) case []byte: n.bytes = value default: return fmt.Errorf("unable to convert %#v to bytes", value) } return nil } func (n *nullBytes) Value() (driver.Value, error) { if !n.valid { return nil, nil } switch n.displayType { case displayTypeString: return string(n.bytes), nil case displayTypeUUID: if n.dialect != "postgres" { return n.bytes, nil } var uuid [16]byte var buf [36]byte copy(uuid[:], n.bytes) googleuuid.EncodeHex(buf[:], uuid) return string(buf[:]), nil default: return n.bytes, nil } } ================================================ FILE: select_query.go ================================================ package sq import ( "bytes" "context" "fmt" ) // SelectQuery represents an SQL SELECT query. type SelectQuery struct { Dialect string // WITH CTEs []CTE // SELECT Distinct bool SelectFields []Field DistinctOnFields []Field // TOP LimitTop any LimitTopPercent any // FROM FromTable Table // JOIN JoinTables []JoinTable // WHERE WherePredicate Predicate // GROUP BY GroupByFields []Field // HAVING HavingPredicate Predicate // WINDOW NamedWindows []NamedWindow // ORDER BY OrderByFields []Field // LIMIT LimitRows any // OFFSET OffsetRows any // FETCH NEXT FetchNextRows any FetchWithTies bool // FOR UPDATE | FOR SHARE LockClause string LockValues []any // AS Alias string Columns []string } var _ interface { Query Table Field Any } = (*SelectQuery)(nil) // WriteSQL implements the SQLWriter interface. 
func (q SelectQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { var err error if len(q.SelectFields) == 0 { return fmt.Errorf("SELECT: no fields provided") } // Table Policies var policies []Predicate policies, err = appendPolicy(ctx, dialect, policies, q.FromTable) if err != nil { return fmt.Errorf("FROM %s Policy: %w", toString(q.Dialect, q.FromTable), err) } for _, joinTable := range q.JoinTables { policies, err = appendPolicy(ctx, dialect, policies, joinTable.Table) if err != nil { return fmt.Errorf("%s %s Policy: %w", joinTable.JoinOperator, joinTable.Table, err) } } if len(policies) > 0 { if q.WherePredicate != nil { policies = append(policies, q.WherePredicate) } q.WherePredicate = And(policies...) } // WITH if len(q.CTEs) > 0 { err = writeCTEs(ctx, dialect, buf, args, params, q.CTEs) if err != nil { return fmt.Errorf("WITH: %w", err) } } // SELECT buf.WriteString("SELECT ") if q.LimitTop != nil || q.LimitTopPercent != nil { // TOP if dialect != DialectSQLServer { return fmt.Errorf("%s does not support SELECT TOP n", dialect) } if len(q.OrderByFields) == 0 { return fmt.Errorf("sqlserver does not support TOP without ORDER BY") } err = writeTop(ctx, dialect, buf, args, params, q.LimitTop, q.LimitTopPercent, q.FetchWithTies) if err != nil { return err } } if len(q.DistinctOnFields) > 0 { if dialect != DialectPostgres { return fmt.Errorf("%s does not support SELECT DISTINCT ON", dialect) } if q.Distinct { return fmt.Errorf("postgres SELECT cannot be DISTINCT and DISTINCT ON at the same time") } buf.WriteString("DISTINCT ON (") err = writeFields(ctx, dialect, buf, args, params, q.DistinctOnFields, false) if err != nil { return fmt.Errorf("DISTINCT ON: %w", err) } buf.WriteString(") ") } else if q.Distinct { buf.WriteString("DISTINCT ") } err = writeFields(ctx, dialect, buf, args, params, q.SelectFields, true) if err != nil { return fmt.Errorf("SELECT: %w", err) } // FROM if q.FromTable != nil { 
buf.WriteString(" FROM ") _, isQuery := q.FromTable.(Query) if isQuery { buf.WriteString("(") } err = q.FromTable.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("FROM: %w", err) } if isQuery { buf.WriteString(")") } if alias := getAlias(q.FromTable); alias != "" { buf.WriteString(" AS " + QuoteIdentifier(dialect, alias) + quoteTableColumns(dialect, q.FromTable)) } else if isQuery && dialect != DialectSQLite { return fmt.Errorf("%s FROM subquery must have alias", dialect) } } // JOIN if len(q.JoinTables) > 0 { if q.FromTable == nil { return fmt.Errorf("can't JOIN without a FROM table") } buf.WriteString(" ") err = writeJoinTables(ctx, dialect, buf, args, params, q.JoinTables) if err != nil { return fmt.Errorf("JOIN: %w", err) } } // WHERE if q.WherePredicate != nil { buf.WriteString(" WHERE ") switch predicate := q.WherePredicate.(type) { case VariadicPredicate: predicate.Toplevel = true err = predicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("WHERE: %w", err) } default: err = q.WherePredicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("WHERE: %w", err) } } } // GROUP BY if len(q.GroupByFields) > 0 { buf.WriteString(" GROUP BY ") err = writeFields(ctx, dialect, buf, args, params, q.GroupByFields, false) if err != nil { return fmt.Errorf("GROUP BY: %w", err) } } // HAVING if q.HavingPredicate != nil { buf.WriteString(" HAVING ") switch predicate := q.HavingPredicate.(type) { case VariadicPredicate: predicate.Toplevel = true err = predicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("HAVING: %w", err) } default: err = q.HavingPredicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("HAVING: %w", err) } } } // WINDOW if len(q.NamedWindows) > 0 { buf.WriteString(" WINDOW ") err = NamedWindows(q.NamedWindows).WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("WINDOW: %w", err) } } // 
ORDER BY if len(q.OrderByFields) > 0 { buf.WriteString(" ORDER BY ") err = writeFields(ctx, dialect, buf, args, params, q.OrderByFields, false) if err != nil { return fmt.Errorf("ORDER BY: %w", err) } } // LIMIT if q.LimitRows != nil { if dialect == DialectSQLServer { return fmt.Errorf("sqlserver does not support LIMIT") } buf.WriteString(" LIMIT ") err = WriteValue(ctx, dialect, buf, args, params, q.LimitRows) if err != nil { return fmt.Errorf("LIMIT: %w", err) } } // OFFSET if q.OffsetRows != nil { if dialect == DialectSQLServer { if len(q.OrderByFields) == 0 { return fmt.Errorf("sqlserver does not support OFFSET without ORDER BY") } if q.LimitTop != nil || q.LimitTopPercent != nil { return fmt.Errorf("sqlserver does not support OFFSET with TOP") } } buf.WriteString(" OFFSET ") err = WriteValue(ctx, dialect, buf, args, params, q.OffsetRows) if err != nil { return fmt.Errorf("OFFSET: %w", err) } if dialect == DialectSQLServer { buf.WriteString(" ROWS") } } // FETCH NEXT if q.FetchNextRows != nil { switch dialect { case DialectPostgres: if q.LimitRows != nil { return fmt.Errorf("postgres does not allow FETCH NEXT with LIMIT") } case DialectSQLServer: if q.LimitTop != nil || q.LimitTopPercent != nil { return fmt.Errorf("sqlserver does not allow FETCH NEXT with TOP") } default: return fmt.Errorf("%s does not support FETCH NEXT", dialect) } buf.WriteString(" FETCH NEXT ") err = WriteValue(ctx, dialect, buf, args, params, q.FetchNextRows) if err != nil { return fmt.Errorf("FETCH NEXT: %w", err) } buf.WriteString(" ROWS ") if q.FetchWithTies { if dialect == DialectSQLServer { return fmt.Errorf("sqlserver WITH TIES only works with TOP") } if len(q.OrderByFields) == 0 { return fmt.Errorf("%s WITH TIES cannot be used without ORDER BY", dialect) } buf.WriteString("WITH TIES") } else { buf.WriteString("ONLY") } } // FOR UPDATE | FOR SHARE if q.LockClause != "" { buf.WriteString(" ") err = Writef(ctx, dialect, buf, args, params, q.LockClause, q.LockValues) if err != nil { 
return err } } return nil } // Select creates a new SelectQuery. func Select(fields ...Field) SelectQuery { return SelectQuery{SelectFields: fields} } // SelectDistinct creates a new SelectQuery. func SelectDistinct(fields ...Field) SelectQuery { return SelectQuery{ SelectFields: fields, Distinct: true, } } // SelectOne creates a new SelectQuery. func SelectOne() SelectQuery { return SelectQuery{SelectFields: Fields{Expr("1")}} } // From creates a new SelectQuery. func From(table Table) SelectQuery { return SelectQuery{FromTable: table} } // Select appends to the SelectFields in the SelectQuery. func (q SelectQuery) Select(fields ...Field) SelectQuery { q.SelectFields = append(q.SelectFields, fields...) return q } // SelectDistinct sets the SelectFields in the SelectQuery. func (q SelectQuery) SelectDistinct(fields ...Field) SelectQuery { q.SelectFields = fields q.Distinct = true return q } // SelectOne sets the SelectQuery to SELECT 1. func (q SelectQuery) SelectOne(fields ...Field) SelectQuery { q.SelectFields = Fields{Expr("1")} return q } // From sets the FromTable field in the SelectQuery. func (q SelectQuery) From(table Table) SelectQuery { q.FromTable = table return q } // Join joins a new Table to the SelectQuery. func (q SelectQuery) Join(table Table, predicates ...Predicate) SelectQuery { q.JoinTables = append(q.JoinTables, Join(table, predicates...)) return q } // LeftJoin left joins a new Table to the SelectQuery. func (q SelectQuery) LeftJoin(table Table, predicates ...Predicate) SelectQuery { q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) return q } // CrossJoin cross joins a new Table to the SelectQuery. func (q SelectQuery) CrossJoin(table Table) SelectQuery { q.JoinTables = append(q.JoinTables, CrossJoin(table)) return q } // CustomJoin joins a new Table to the SelectQuery with a custom join operator. 
func (q SelectQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) SelectQuery { q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) return q } // JoinUsing joins a new Table to the SelectQuery with the USING operator. func (q SelectQuery) JoinUsing(table Table, fields ...Field) SelectQuery { q.JoinTables = append(q.JoinTables, JoinUsing(table, fields...)) return q } // Where appends to the WherePredicate field in the SelectQuery. func (q SelectQuery) Where(predicates ...Predicate) SelectQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // GroupBy appends to the GroupByFields field in the SelectQuery. func (q SelectQuery) GroupBy(fields ...Field) SelectQuery { q.GroupByFields = append(q.GroupByFields, fields...) return q } // Having appends to the HavingPredicate field in the SelectQuery. func (q SelectQuery) Having(predicates ...Predicate) SelectQuery { q.HavingPredicate = appendPredicates(q.HavingPredicate, predicates) return q } // OrderBy appends to the OrderByFields field in the SelectQuery. func (q SelectQuery) OrderBy(fields ...Field) SelectQuery { q.OrderByFields = append(q.OrderByFields, fields...) return q } // Limit sets the LimitRows field in the SelectQuery. func (q SelectQuery) Limit(limit any) SelectQuery { q.LimitRows = limit return q } // Offset sets the OffsetRows field in the SelectQuery. func (q SelectQuery) Offset(offset any) SelectQuery { q.OffsetRows = offset return q } // As returns a new SelectQuery with the table alias (and optionally column // aliases). func (q SelectQuery) As(alias string, columns ...string) SelectQuery { q.Alias = alias q.Columns = columns return q } // Field returns a new field qualified by the SelectQuery's alias. func (q SelectQuery) Field(name string) AnyField { return NewAnyField(name, TableStruct{alias: q.Alias}) } // SetFetchableFields implements the Query interface. 
func (q SelectQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { if len(q.SelectFields) == 0 { q.SelectFields = fields return q, true } return q, false } // GetFetchableFields returns the fetchable fields of the query. func (q SelectQuery) GetFetchableFields() []Field { return q.SelectFields } // GetDialect implements the Query interface. func (q SelectQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the query. func (q SelectQuery) SetDialect(dialect string) SelectQuery { q.Dialect = dialect return q } // GetAlias returns the alias of the SelectQuery. func (q SelectQuery) GetAlias() string { return q.Alias } // GetColumns returns the column aliases of the SelectQuery. func (q SelectQuery) GetColumns() []string { return q.Columns } // IsTable implements the Table interface. func (q SelectQuery) IsTable() {} // IsField implements the Field interface. func (q SelectQuery) IsField() {} // IsArray implements the Array interface. func (q SelectQuery) IsArray() {} // IsBinary implements the Binary interface. func (q SelectQuery) IsBinary() {} // IsBoolean implements the Boolean interface. func (q SelectQuery) IsBoolean() {} // IsEnum implements the Enum interface. func (q SelectQuery) IsEnum() {} // IsJSON implements the JSON interface. func (q SelectQuery) IsJSON() {} // IsNumber implements the Number interface. func (q SelectQuery) IsNumber() {} // IsString implements the String interface. func (q SelectQuery) IsString() {} // IsTime implements the Time interface. func (q SelectQuery) IsTime() {} // IsUUID implements the UUID interface. func (q SelectQuery) IsUUID() {} // SQLiteSelectQuery represents an SQLite SELECT query. type SQLiteSelectQuery SelectQuery var _ interface { Query Table Field Any } = (*SQLiteSelectQuery)(nil) // WriteSQL implements the SQLWriter interface. 
func (q SQLiteSelectQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	// SQLiteSelectQuery shares SelectQuery's layout, so delegate rendering.
	base := SelectQuery(q)
	return base.WriteSQL(ctx, dialect, buf, args, params)
}

// Select creates a new SQLiteSelectQuery.
func (b sqliteQueryBuilder) Select(fields ...Field) SQLiteSelectQuery {
	var q SQLiteSelectQuery
	q.Dialect = DialectSQLite
	q.CTEs = b.ctes
	q.SelectFields = fields
	return q
}

// SelectDistinct creates a new SQLiteSelectQuery.
func (b sqliteQueryBuilder) SelectDistinct(fields ...Field) SQLiteSelectQuery {
	var q SQLiteSelectQuery
	q.Dialect = DialectSQLite
	q.CTEs = b.ctes
	q.SelectFields = fields
	q.Distinct = true
	return q
}

// SelectOne creates a new SQLiteSelectQuery.
func (b sqliteQueryBuilder) SelectOne() SQLiteSelectQuery {
	var q SQLiteSelectQuery
	q.Dialect = DialectSQLite
	q.CTEs = b.ctes
	q.SelectFields = Fields{Expr("1")}
	return q
}

// From creates a new SQLiteSelectQuery.
func (b sqliteQueryBuilder) From(table Table) SQLiteSelectQuery {
	var q SQLiteSelectQuery
	q.Dialect = DialectSQLite
	q.CTEs = b.ctes
	q.FromTable = table
	return q
}

// Select adds fields to the SQLiteSelectQuery's SelectFields.
func (q SQLiteSelectQuery) Select(fields ...Field) SQLiteSelectQuery {
	out := q
	out.SelectFields = append(out.SelectFields, fields...)
	return out
}

// SelectDistinct replaces the SQLiteSelectQuery's SelectFields and marks it
// DISTINCT.
func (q SQLiteSelectQuery) SelectDistinct(fields ...Field) SQLiteSelectQuery {
	out := q
	out.SelectFields, out.Distinct = fields, true
	return out
}

// SelectOne sets the SQLiteSelectQuery to SELECT 1; fields is ignored.
func (q SQLiteSelectQuery) SelectOne(fields ...Field) SQLiteSelectQuery {
	out := q
	out.SelectFields = Fields{Expr("1")}
	return out
}

// From returns a copy of the SQLiteSelectQuery with FromTable set.
func (q SQLiteSelectQuery) From(table Table) SQLiteSelectQuery {
	out := q
	out.FromTable = table
	return out
}

// Join joins a new Table to the SQLiteSelectQuery.
func (q SQLiteSelectQuery) Join(table Table, predicates ...Predicate) SQLiteSelectQuery { q.JoinTables = append(q.JoinTables, Join(table, predicates...)) return q } // LeftJoin left joins a new Table to the SQLiteSelectQuery. func (q SQLiteSelectQuery) LeftJoin(table Table, predicates ...Predicate) SQLiteSelectQuery { q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) return q } // CrossJoin cross joins a new Table to the SQLiteSelectQuery. func (q SQLiteSelectQuery) CrossJoin(table Table) SQLiteSelectQuery { q.JoinTables = append(q.JoinTables, CrossJoin(table)) return q } // CustomJoin joins a new Table to the SQLiteSelectQuery with a custom join // operator. func (q SQLiteSelectQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) SQLiteSelectQuery { q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) return q } // JoinUsing joins a new Table to the SQLiteSelectQuery with the USING operator. func (q SQLiteSelectQuery) JoinUsing(table Table, fields ...Field) SQLiteSelectQuery { q.JoinTables = append(q.JoinTables, JoinUsing(table, fields...)) return q } // Where appends to the WherePredicate field in the SQLiteSelectQuery. func (q SQLiteSelectQuery) Where(predicates ...Predicate) SQLiteSelectQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // GroupBy appends to the GroupByFields field in the SQLiteSelectQuery. func (q SQLiteSelectQuery) GroupBy(fields ...Field) SQLiteSelectQuery { q.GroupByFields = append(q.GroupByFields, fields...) return q } // Having appends to the HavingPredicate field in the SQLiteSelectQuery. func (q SQLiteSelectQuery) Having(predicates ...Predicate) SQLiteSelectQuery { q.HavingPredicate = appendPredicates(q.HavingPredicate, predicates) return q } // OrderBy appends to the OrderByFields field in the SQLiteSelectQuery. 
func (q SQLiteSelectQuery) OrderBy(fields ...Field) SQLiteSelectQuery { q.OrderByFields = append(q.OrderByFields, fields...) return q } // Limit sets the LimitRows field in the SQLiteSelectQuery. func (q SQLiteSelectQuery) Limit(limit any) SQLiteSelectQuery { q.LimitRows = limit return q } // Offset sets the OffsetRows field in the SQLiteSelectQuery. func (q SQLiteSelectQuery) Offset(offset any) SQLiteSelectQuery { q.OffsetRows = offset return q } // As returns a new SQLiteSelectQuery with the table alias (and optionally // column aliases). func (q SQLiteSelectQuery) As(alias string, columns ...string) SQLiteSelectQuery { q.Alias = alias q.Columns = columns return q } // Field returns a new field qualified by the SQLiteSelectQuery's alias. func (q SQLiteSelectQuery) Field(name string) AnyField { return NewAnyField(name, TableStruct{alias: q.Alias}) } // SetFetchableFields implements the Query interface. func (q SQLiteSelectQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { if len(q.SelectFields) == 0 { q.SelectFields = fields return q, true } return q, false } // GetFetchableFields returns the fetchable fields of the query. func (q SQLiteSelectQuery) GetFetchableFields() []Field { return q.SelectFields } // GetDialect implements the Query interface. func (q SQLiteSelectQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the query. func (q SQLiteSelectQuery) SetDialect(dialect string) SQLiteSelectQuery { q.Dialect = dialect return q } // GetAlias returns the alias of the SQLiteSelectQuery. func (q SQLiteSelectQuery) GetAlias() string { return q.Alias } // IsTable implements the Table interface. func (q SQLiteSelectQuery) IsTable() {} // IsField implements the Field interface. func (q SQLiteSelectQuery) IsField() {} // IsArray implements the Array interface. func (q SQLiteSelectQuery) IsArray() {} // IsBinary implements the Binary interface. 
func (q SQLiteSelectQuery) IsBinary() {} // IsBoolean implements the Boolean interface. func (q SQLiteSelectQuery) IsBoolean() {} // IsEnum implements the Enum interface. func (q SQLiteSelectQuery) IsEnum() {} // IsJSON implements the JSON interface. func (q SQLiteSelectQuery) IsJSON() {} // IsNumber implements the Number interface. func (q SQLiteSelectQuery) IsNumber() {} // IsString implements the String interface. func (q SQLiteSelectQuery) IsString() {} // IsTime implements the Time interface. func (q SQLiteSelectQuery) IsTime() {} // IsUUID implements the UUID interface. func (q SQLiteSelectQuery) IsUUID() {} // PostgresSelectQuery represents a Postgres SELECT query. type PostgresSelectQuery SelectQuery var _ interface { Query Table Field Any } = (*PostgresSelectQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q PostgresSelectQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return SelectQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // Select creates a new PostgresSelectQuery. func (b postgresQueryBuilder) Select(fields ...Field) PostgresSelectQuery { q := PostgresSelectQuery{ CTEs: b.ctes, SelectFields: fields, } if q.Dialect == "" { q.Dialect = DialectPostgres } return q } // SelectDistinct creates a new PostgresSelectQuery. func (b postgresQueryBuilder) SelectDistinct(fields ...Field) PostgresSelectQuery { q := PostgresSelectQuery{ CTEs: b.ctes, SelectFields: fields, Distinct: true, } if q.Dialect == "" { q.Dialect = DialectPostgres } return q } // SelectOne creates a new PostgresSelectQuery. func (b postgresQueryBuilder) SelectOne() PostgresSelectQuery { q := PostgresSelectQuery{ CTEs: b.ctes, SelectFields: Fields{Expr("1")}, } if q.Dialect == "" { q.Dialect = DialectPostgres } return q } // From creates a new PostgresSelectQuery. 
func (b postgresQueryBuilder) From(table Table) PostgresSelectQuery { q := PostgresSelectQuery{ CTEs: b.ctes, FromTable: table, } if q.Dialect == "" { q.Dialect = DialectPostgres } return q } // Select appends to the SelectFields in the PostgresSelectQuery. func (q PostgresSelectQuery) Select(fields ...Field) PostgresSelectQuery { q.SelectFields = append(q.SelectFields, fields...) return q } // SelectDistinct sets the SelectFields in the PostgresSelectQuery. func (q PostgresSelectQuery) SelectDistinct(fields ...Field) PostgresSelectQuery { q.SelectFields = fields q.Distinct = true return q } // DistinctOn sets the DistinctOnFields in the PostgresSelectQuery. func (q PostgresSelectQuery) DistinctOn(fields ...Field) PostgresSelectQuery { q.DistinctOnFields = fields return q } // SelectOne sets the PostgresSelectQuery to SELECT 1. func (q PostgresSelectQuery) SelectOne(fields ...Field) PostgresSelectQuery { q.SelectFields = Fields{Expr("1")} return q } // From sets the FromTable field in the PostgresSelectQuery. func (q PostgresSelectQuery) From(table Table) PostgresSelectQuery { q.FromTable = table return q } // Join joins a new Table to the PostgresSelectQuery. func (q PostgresSelectQuery) Join(table Table, predicates ...Predicate) PostgresSelectQuery { q.JoinTables = append(q.JoinTables, Join(table, predicates...)) return q } // LeftJoin left joins a new Table to the PostgresSelectQuery. func (q PostgresSelectQuery) LeftJoin(table Table, predicates ...Predicate) PostgresSelectQuery { q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) return q } // FullJoin full joins a new Table to the PostgresSelectQuery. func (q PostgresSelectQuery) FullJoin(table Table, predicates ...Predicate) PostgresSelectQuery { q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...)) return q } // CrossJoin cross joins a new Table to the PostgresSelectQuery. 
func (q PostgresSelectQuery) CrossJoin(table Table) PostgresSelectQuery { q.JoinTables = append(q.JoinTables, CrossJoin(table)) return q } // CustomJoin joins a new Table to the PostgresSelectQuery with a custom join // operator. func (q PostgresSelectQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) PostgresSelectQuery { q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) return q } // JoinUsing joins a new Table to the PostgresSelectQuery with the USING operator. func (q PostgresSelectQuery) JoinUsing(table Table, fields ...Field) PostgresSelectQuery { q.JoinTables = append(q.JoinTables, JoinUsing(table, fields...)) return q } // Where appends to the WherePredicate field in the PostgresSelectQuery. func (q PostgresSelectQuery) Where(predicates ...Predicate) PostgresSelectQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // GroupBy appends to the GroupByFields field in the PostgresSelectQuery. func (q PostgresSelectQuery) GroupBy(fields ...Field) PostgresSelectQuery { q.GroupByFields = append(q.GroupByFields, fields...) return q } // Having appends to the HavingPredicate field in the PostgresSelectQuery. func (q PostgresSelectQuery) Having(predicates ...Predicate) PostgresSelectQuery { q.HavingPredicate = appendPredicates(q.HavingPredicate, predicates) return q } // OrderBy appends to the OrderByFields field in the PostgresSelectQuery. func (q PostgresSelectQuery) OrderBy(fields ...Field) PostgresSelectQuery { q.OrderByFields = append(q.OrderByFields, fields...) return q } // Limit sets the LimitRows field in the PostgresSelectQuery. func (q PostgresSelectQuery) Limit(limit any) PostgresSelectQuery { q.LimitRows = limit return q } // Offset sets the OffsetRows field in the PostgresSelectQuery. func (q PostgresSelectQuery) Offset(offset any) PostgresSelectQuery { q.OffsetRows = offset return q } // FetchNext sets the FetchNextRows field in the PostgresSelectQuery. 
func (q PostgresSelectQuery) FetchNext(n any) PostgresSelectQuery { q.FetchNextRows = n return q } // WithTies enables the FetchWithTies field in the PostgresSelectQuery. func (q PostgresSelectQuery) WithTies() PostgresSelectQuery { q.FetchWithTies = true return q } // LockRows sets the lock clause of the PostgresSelectQuery. func (q PostgresSelectQuery) LockRows(lockClause string, lockValues ...any) PostgresSelectQuery { q.LockClause = lockClause q.LockValues = lockValues return q } // As returns a new PostgresSelectQuery with the table alias (and optionally // column aliases). func (q PostgresSelectQuery) As(alias string, columns ...string) PostgresSelectQuery { q.Alias = alias q.Columns = columns return q } // Field returns a new field qualified by the PostgresSelectQuery's alias. func (q PostgresSelectQuery) Field(name string) AnyField { return NewAnyField(name, TableStruct{alias: q.Alias}) } // SetFetchableFields implements the Query interface. func (q PostgresSelectQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { if len(q.SelectFields) == 0 { q.SelectFields = fields return q, true } return q, false } // GetFetchableFields returns the fetchable fields of the query. func (q PostgresSelectQuery) GetFetchableFields() []Field { return q.SelectFields } // GetDialect implements the Query interface. func (q PostgresSelectQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the query. func (q PostgresSelectQuery) SetDialect(dialect string) PostgresSelectQuery { q.Dialect = dialect return q } // GetAlias returns the alias of the PostgresSelectQuery. func (q PostgresSelectQuery) GetAlias() string { return q.Alias } // IsTable implements the Table interface. func (q PostgresSelectQuery) IsTable() {} // IsField implements the Field interface. func (q PostgresSelectQuery) IsField() {} // IsArray implements the Array interface. func (q PostgresSelectQuery) IsArray() {} // IsBinary implements the Binary interface. 
func (q PostgresSelectQuery) IsBinary() {} // IsBoolean implements the Boolean interface. func (q PostgresSelectQuery) IsBoolean() {} // IsEnum implements the Enum interface. func (q PostgresSelectQuery) IsEnum() {} // IsJSON implements the JSON interface. func (q PostgresSelectQuery) IsJSON() {} // IsNumber implements the Number interface. func (q PostgresSelectQuery) IsNumber() {} // IsString implements the String interface. func (q PostgresSelectQuery) IsString() {} // IsTime implements the Time interface. func (q PostgresSelectQuery) IsTime() {} // IsUUID implements the UUID interface. func (q PostgresSelectQuery) IsUUID() {} // MySQLSelectQuery represents a MySQL SELECT query. type MySQLSelectQuery SelectQuery var _ interface { Query Table Field Any } = (*MySQLSelectQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q MySQLSelectQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return SelectQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // Select creates a new MySQLSelectQuery. func (b mysqlQueryBuilder) Select(fields ...Field) MySQLSelectQuery { q := MySQLSelectQuery{ CTEs: b.ctes, SelectFields: fields, } if q.Dialect == "" { q.Dialect = DialectMySQL } return q } // SelectDistinct creates a new MySQLSelectQuery. func (b mysqlQueryBuilder) SelectDistinct(fields ...Field) MySQLSelectQuery { q := MySQLSelectQuery{ CTEs: b.ctes, SelectFields: fields, Distinct: true, } if q.Dialect == "" { q.Dialect = DialectMySQL } return q } // SelectOne creates a new MySQLSelectQuery. func (b mysqlQueryBuilder) SelectOne() MySQLSelectQuery { q := MySQLSelectQuery{ CTEs: b.ctes, SelectFields: Fields{Expr("1")}, } if q.Dialect == "" { q.Dialect = DialectMySQL } return q } // From creates a new MySQLSelectQuery. 
func (b mysqlQueryBuilder) From(table Table) MySQLSelectQuery { q := MySQLSelectQuery{ CTEs: b.ctes, FromTable: table, } if q.Dialect == "" { q.Dialect = DialectMySQL } return q } // Select appends to the SelectFields in the MySQLSelectQuery. func (q MySQLSelectQuery) Select(fields ...Field) MySQLSelectQuery { q.SelectFields = append(q.SelectFields, fields...) return q } // SelectDistinct sets the SelectFields in the MySQLSelectQuery. func (q MySQLSelectQuery) SelectDistinct(fields ...Field) MySQLSelectQuery { q.SelectFields = fields q.Distinct = true return q } // SelectOne sets the MySQLSelectQuery to SELECT 1. func (q MySQLSelectQuery) SelectOne(fields ...Field) MySQLSelectQuery { q.SelectFields = Fields{Expr("1")} return q } // From sets the FromTable field in the MySQLSelectQuery. func (q MySQLSelectQuery) From(table Table) MySQLSelectQuery { q.FromTable = table return q } // Join joins a new Table to the MySQLSelectQuery. func (q MySQLSelectQuery) Join(table Table, predicates ...Predicate) MySQLSelectQuery { q.JoinTables = append(q.JoinTables, Join(table, predicates...)) return q } // LeftJoin left joins a new Table to the MySQLSelectQuery. func (q MySQLSelectQuery) LeftJoin(table Table, predicates ...Predicate) MySQLSelectQuery { q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) return q } // FullJoin full joins a new Table to the MySQLSelectQuery. func (q MySQLSelectQuery) FullJoin(table Table, predicates ...Predicate) MySQLSelectQuery { q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...)) return q } // CrossJoin cross joins a new Table to the MySQLSelectQuery. func (q MySQLSelectQuery) CrossJoin(table Table) MySQLSelectQuery { q.JoinTables = append(q.JoinTables, CrossJoin(table)) return q } // CustomJoin joins a new Table to the MySQLSelectQuery with a custom join // operator. 
func (q MySQLSelectQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) MySQLSelectQuery { q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) return q } // JoinUsing joins a new Table to the MySQLSelectQuery with the USING operator. func (q MySQLSelectQuery) JoinUsing(table Table, fields ...Field) MySQLSelectQuery { q.JoinTables = append(q.JoinTables, JoinUsing(table, fields...)) return q } // Where appends to the WherePredicate field in the MySQLSelectQuery. func (q MySQLSelectQuery) Where(predicates ...Predicate) MySQLSelectQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // GroupBy appends to the GroupByFields field in the MySQLSelectQuery. func (q MySQLSelectQuery) GroupBy(fields ...Field) MySQLSelectQuery { q.GroupByFields = append(q.GroupByFields, fields...) return q } // Having appends to the HavingPredicate field in the MySQLSelectQuery. func (q MySQLSelectQuery) Having(predicates ...Predicate) MySQLSelectQuery { q.HavingPredicate = appendPredicates(q.HavingPredicate, predicates) return q } // OrderBy appends to the OrderByFields field in the MySQLSelectQuery. func (q MySQLSelectQuery) OrderBy(fields ...Field) MySQLSelectQuery { q.OrderByFields = append(q.OrderByFields, fields...) return q } // Limit sets the LimitRows field in the MySQLSelectQuery. func (q MySQLSelectQuery) Limit(limit any) MySQLSelectQuery { q.LimitRows = limit return q } // Offset sets the OffsetRows field in the MySQLSelectQuery. func (q MySQLSelectQuery) Offset(offset any) MySQLSelectQuery { q.OffsetRows = offset return q } // LockRows sets the lock clause of the MySQLSelectQuery. func (q MySQLSelectQuery) LockRows(lockClause string, lockValues ...any) MySQLSelectQuery { q.LockClause = lockClause q.LockValues = lockValues return q } // As returns a new MySQLSelectQuery with the table alias (and optionally // column aliases). 
func (q MySQLSelectQuery) As(alias string, columns ...string) MySQLSelectQuery { q.Alias = alias q.Columns = columns return q } // Field returns a new field qualified by the MySQLSelectQuery's alias. func (q MySQLSelectQuery) Field(name string) AnyField { return NewAnyField(name, TableStruct{alias: q.Alias}) } // SetFetchableFields implements the Query interface. func (q MySQLSelectQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { if len(q.SelectFields) == 0 { q.SelectFields = fields return q, true } return q, false } // GetFetchableFields returns the fetchable fields of the query. func (q MySQLSelectQuery) GetFetchableFields() []Field { return q.SelectFields } // GetDialect implements the Query interface. func (q MySQLSelectQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the query. func (q MySQLSelectQuery) SetDialect(dialect string) MySQLSelectQuery { q.Dialect = dialect return q } // GetAlias returns the alias of the MySQLSelectQuery. func (q MySQLSelectQuery) GetAlias() string { return q.Alias } // IsTable implements the Table interface. func (q MySQLSelectQuery) IsTable() {} // IsField implements the Field interface. func (q MySQLSelectQuery) IsField() {} // IsArray implements the Array interface. func (q MySQLSelectQuery) IsArray() {} // IsBinary implements the Binary interface. func (q MySQLSelectQuery) IsBinary() {} // IsBoolean implements the Boolean interface. func (q MySQLSelectQuery) IsBoolean() {} // IsEnum implements the Enum interface. func (q MySQLSelectQuery) IsEnum() {} // IsJSON implements the JSON interface. func (q MySQLSelectQuery) IsJSON() {} // IsNumber implements the Number interface. func (q MySQLSelectQuery) IsNumber() {} // IsString implements the String interface. func (q MySQLSelectQuery) IsString() {} // IsTime implements the Time interface. func (q MySQLSelectQuery) IsTime() {} // IsUUID implements the UUID interface. 
func (q MySQLSelectQuery) IsUUID() {} // SQLServerSelectQuery represents an SQL Server SELECT query. type SQLServerSelectQuery SelectQuery var _ interface { Query Table Field Any } = (*SQLServerSelectQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q SQLServerSelectQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return SelectQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // Select creates a new SQLServerSelectQuery. func (b sqlserverQueryBuilder) Select(fields ...Field) SQLServerSelectQuery { q := SQLServerSelectQuery{ CTEs: b.ctes, SelectFields: fields, } if q.Dialect == "" { q.Dialect = DialectSQLServer } return q } // SelectDistinct creates a new SQLServerSelectQuery. func (b sqlserverQueryBuilder) SelectDistinct(fields ...Field) SQLServerSelectQuery { q := SQLServerSelectQuery{ CTEs: b.ctes, SelectFields: fields, Distinct: true, } if q.Dialect == "" { q.Dialect = DialectSQLServer } return q } // SelectOne creates a new SQLServerSelectQuery. func (b sqlserverQueryBuilder) SelectOne() SQLServerSelectQuery { q := SQLServerSelectQuery{ CTEs: b.ctes, SelectFields: Fields{Expr("1")}, } if q.Dialect == "" { q.Dialect = DialectSQLServer } return q } // From creates a new SQLServerSelectQuery. func (b sqlserverQueryBuilder) From(table Table) SQLServerSelectQuery { q := SQLServerSelectQuery{ CTEs: b.ctes, FromTable: table, } if q.Dialect == "" { q.Dialect = DialectSQLServer } return q } // Select appends to the SelectFields in the SQLServerSelectQuery. func (q SQLServerSelectQuery) Select(fields ...Field) SQLServerSelectQuery { q.SelectFields = append(q.SelectFields, fields...) return q } // SelectDistinct sets the SelectFields in the SQLServerSelectQuery. func (q SQLServerSelectQuery) SelectDistinct(fields ...Field) SQLServerSelectQuery { q.SelectFields = fields q.Distinct = true return q } // SelectOne sets the SQLServerSelectQuery to SELECT 1. 
func (q SQLServerSelectQuery) SelectOne(fields ...Field) SQLServerSelectQuery { q.SelectFields = Fields{Expr("1")} return q } // Top sets the LimitTop field of the SQLServerSelectQuery. func (q SQLServerSelectQuery) Top(limit any) SQLServerSelectQuery { q.LimitTop = limit return q } // Top sets the LimitTopPercent field of the SQLServerSelectQuery. func (q SQLServerSelectQuery) TopPercent(percentLimit any) SQLServerSelectQuery { q.LimitTopPercent = percentLimit return q } // From sets the FromTable field in the SQLServerSelectQuery. func (q SQLServerSelectQuery) From(table Table) SQLServerSelectQuery { q.FromTable = table return q } // Join joins a new Table to the SQLServerSelectQuery. func (q SQLServerSelectQuery) Join(table Table, predicates ...Predicate) SQLServerSelectQuery { q.JoinTables = append(q.JoinTables, Join(table, predicates...)) return q } // LeftJoin left joins a new Table to the SQLServerSelectQuery. func (q SQLServerSelectQuery) LeftJoin(table Table, predicates ...Predicate) SQLServerSelectQuery { q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) return q } // FullJoin full joins a new Table to the SQLServerSelectQuery. func (q SQLServerSelectQuery) FullJoin(table Table, predicates ...Predicate) SQLServerSelectQuery { q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...)) return q } // CrossJoin cross joins a new Table to the SQLServerSelectQuery. func (q SQLServerSelectQuery) CrossJoin(table Table) SQLServerSelectQuery { q.JoinTables = append(q.JoinTables, CrossJoin(table)) return q } // CustomJoin joins a new Table to the SQLServerSelectQuery with a custom join // operator. func (q SQLServerSelectQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) SQLServerSelectQuery { q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) return q } // Where appends to the WherePredicate field in the SQLServerSelectQuery. 
func (q SQLServerSelectQuery) Where(predicates ...Predicate) SQLServerSelectQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // GroupBy appends to the GroupByFields field in the SQLServerSelectQuery. func (q SQLServerSelectQuery) GroupBy(fields ...Field) SQLServerSelectQuery { q.GroupByFields = append(q.GroupByFields, fields...) return q } // Having appends to the HavingPredicate field in the SQLServerSelectQuery. func (q SQLServerSelectQuery) Having(predicates ...Predicate) SQLServerSelectQuery { q.HavingPredicate = appendPredicates(q.HavingPredicate, predicates) return q } // OrderBy appends to the OrderByFields field in the SQLServerSelectQuery. func (q SQLServerSelectQuery) OrderBy(fields ...Field) SQLServerSelectQuery { q.OrderByFields = append(q.OrderByFields, fields...) return q } // Offset sets the OffsetRows field in the SQLServerSelectQuery. func (q SQLServerSelectQuery) Offset(offset any) SQLServerSelectQuery { q.OffsetRows = offset return q } // FetchNext sets the FetchNextRows field in the SQLServerSelectQuery. func (q SQLServerSelectQuery) FetchNext(n any) SQLServerSelectQuery { q.FetchNextRows = n return q } // WithTies enables the FetchWithTies field in the SQLServerSelectQuery. func (q SQLServerSelectQuery) WithTies() SQLServerSelectQuery { q.FetchWithTies = true return q } // As returns a new SQLServerSelectQuery with the table alias (and optionally // column aliases). func (q SQLServerSelectQuery) As(alias string, columns ...string) SQLServerSelectQuery { q.Alias = alias q.Columns = columns return q } // Field returns a new field qualified by the SQLServerSelectQuery's alias. func (q SQLServerSelectQuery) Field(name string) AnyField { return NewAnyField(name, TableStruct{alias: q.Alias}) } // SetFetchableFields implements the Query interface. 
func (q SQLServerSelectQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { if len(q.SelectFields) == 0 { q.SelectFields = fields return q, true } return q, false } // GetFetchableFields returns the fetchable fields of the query. func (q SQLServerSelectQuery) GetFetchableFields() []Field { return q.SelectFields } // GetDialect implements the Query interface. func (q SQLServerSelectQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the query. func (q SQLServerSelectQuery) SetDialect(dialect string) SQLServerSelectQuery { q.Dialect = dialect return q } // GetAlias returns the alias of the SQLServerSelectQuery. func (q SQLServerSelectQuery) GetAlias() string { return q.Alias } // IsTable implements the Table interface. func (q SQLServerSelectQuery) IsTable() {} // IsField implements the Field interface. func (q SQLServerSelectQuery) IsField() {} // IsArray implements the Array interface. func (q SQLServerSelectQuery) IsArray() {} // IsBinary implements the Binary interface. func (q SQLServerSelectQuery) IsBinary() {} // IsBoolean implements the Boolean interface. func (q SQLServerSelectQuery) IsBoolean() {} // IsEnum implements the Enum interface. func (q SQLServerSelectQuery) IsEnum() {} // IsJSON implements the JSON interface. func (q SQLServerSelectQuery) IsJSON() {} // IsNumber implements the Number interface. func (q SQLServerSelectQuery) IsNumber() {} // IsString implements the String interface. func (q SQLServerSelectQuery) IsString() {} // IsTime implements the Time interface. func (q SQLServerSelectQuery) IsTime() {} // IsUUID implements the UUID interface. 
func (q SQLServerSelectQuery) IsUUID() {}

================================================
FILE: select_query_test.go
================================================
package sq

import (
	"testing"

	"github.com/bokwoon95/sq/internal/testutil"
)

// TestSQLiteSelectQuery checks the accessors of the SQLite select builder and
// the SQL (plus args) it renders.
func TestSQLiteSelectQuery(t *testing.T) {
	type ACTOR struct {
		TableStruct
		ACTOR_ID    NumberField
		FIRST_NAME  StringField
		LAST_NAME   StringField
		LAST_UPDATE TimeField
	}
	a := New[ACTOR]("a")

	t.Run("basic", func(t *testing.T) {
		t.Parallel()
		q1 := SQLite.SelectOne().From(a).SetDialect("lorem ipsum").As("q1")
		if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		if diff := testutil.Diff(q1.GetAlias(), "q1"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// SelectOne() already populated SelectFields, so SetFetchableFields
		// is expected to refuse to overwrite them.
		_, ok := q1.SetFetchableFields([]Field{a.ACTOR_ID})
		if ok {
			t.Fatal(testutil.Callers(), "field should not have been set")
		}
		// With SelectFields emptied, the same call is expected to succeed.
		q1.SelectFields = q1.SelectFields[:0]
		_, ok = q1.SetFetchableFields([]Field{a.ACTOR_ID})
		if !ok {
			t.Fatal(testutil.Callers(), "field should have been set")
		}
	})

	t.Run("subquery, cte and joins", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		subquery := SQLite.Select(Expr("*")).From(a).As("subquery")
		cte := NewCTE("cte", nil, SQLite.From(a).Select(Expr("*")))
		// From is called twice; the expected SQL has a single FROM clause.
		tt.item = SQLite.
			With(cte).
			From(a).From(a).
			Join(subquery, Eq(subquery.Field("actor_id"), a.ACTOR_ID)).
			LeftJoin(cte, Eq(cte.Field("actor_id"), a.ACTOR_ID)).
			CrossJoin(a).
			CustomJoin(",", a).
			JoinUsing(a, a.FIRST_NAME, a.LAST_NAME).
			SelectOne()
		tt.wantQuery = "WITH cte AS (SELECT * FROM actor AS a)" +
			" SELECT 1" +
			" FROM actor AS a" +
			" JOIN (SELECT * FROM actor AS a) AS subquery ON subquery.actor_id = a.actor_id" +
			" LEFT JOIN cte ON cte.actor_id = a.actor_id" +
			" CROSS JOIN actor AS a" +
			" , actor AS a" +
			" JOIN actor AS a USING (first_name, last_name)"
		tt.assert(t)
	})

	t.Run("all", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		// SelectDistinct is called twice; the expected SQL lists the columns
		// only once, so the second call must replace the select list.
		tt.item = SQLite.
			SelectDistinct(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
			SelectDistinct(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
			From(a).
			Where(a.ACTOR_ID.GtInt(5)).
			GroupBy(a.FIRST_NAME).
			Having(a.FIRST_NAME.IsNotNull()).
			OrderBy(a.LAST_NAME).
			Limit(10).
			Offset(20)
		tt.wantQuery = "SELECT DISTINCT a.actor_id, a.first_name, a.last_name" +
			" FROM actor AS a" +
			" WHERE a.actor_id > $1" +
			" GROUP BY a.first_name" +
			" HAVING a.first_name IS NOT NULL" +
			" ORDER BY a.last_name" +
			" LIMIT $2" +
			" OFFSET $3"
		tt.wantArgs = []any{5, 10, 20}
		tt.assert(t)
	})
}

// TestPostgresSelectQuery checks the accessors of the Postgres select builder
// and the SQL it renders, including DISTINCT ON, FETCH NEXT ... WITH TIES and
// row locking.
func TestPostgresSelectQuery(t *testing.T) {
	type ACTOR struct {
		TableStruct
		ACTOR_ID    NumberField
		FIRST_NAME  StringField
		LAST_NAME   StringField
		LAST_UPDATE TimeField
	}
	a := New[ACTOR]("a")

	t.Run("basic", func(t *testing.T) {
		t.Parallel()
		q1 := Postgres.SelectOne().From(a).SetDialect("lorem ipsum").As("q1")
		if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		if diff := testutil.Diff(q1.GetAlias(), "q1"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// SelectOne() already populated SelectFields, so SetFetchableFields
		// is expected to refuse to overwrite them.
		_, ok := q1.SetFetchableFields([]Field{a.ACTOR_ID})
		if ok {
			t.Fatal(testutil.Callers(), "field should not have been set")
		}
		q1.SelectFields = q1.SelectFields[:0]
		_, ok = q1.SetFetchableFields([]Field{a.ACTOR_ID})
		if !ok {
			t.Fatal(testutil.Callers(), "field should have been set")
		}
	})

	t.Run("subquery, cte and joins", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		subquery := Postgres.Select(Expr("*")).From(a).As("subquery")
		cte := NewCTE("cte", nil, Postgres.From(a).Select(Expr("*")))
		tt.item = Postgres.
			With(cte).
			From(a).From(a).
			Join(subquery, Eq(subquery.Field("actor_id"), a.ACTOR_ID)).
			LeftJoin(cte, Eq(cte.Field("actor_id"), a.ACTOR_ID)).
			FullJoin(a, Expr("1 = 1")).
			CrossJoin(a).
			CustomJoin(",", a).
			JoinUsing(a, a.FIRST_NAME, a.LAST_NAME).
			SelectOne()
		tt.wantQuery = "WITH cte AS (SELECT * FROM actor AS a)" +
			" SELECT 1" +
			" FROM actor AS a" +
			" JOIN (SELECT * FROM actor AS a) AS subquery ON subquery.actor_id = a.actor_id" +
			" LEFT JOIN cte ON cte.actor_id = a.actor_id" +
			" FULL JOIN actor AS a ON 1 = 1" +
			" CROSS JOIN actor AS a" +
			" , actor AS a" +
			" JOIN actor AS a USING (first_name, last_name)"
		tt.assert(t)
	})

	t.Run("all", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = Postgres.
			SelectDistinct(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
			SelectDistinct(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
			From(a).
			Where(a.ACTOR_ID.GtInt(5)).
			GroupBy(a.FIRST_NAME).
			Having(a.FIRST_NAME.IsNotNull()).
			OrderBy(a.LAST_NAME).
			Limit(10).
			Offset(20)
		tt.wantQuery = "SELECT DISTINCT a.actor_id, a.first_name, a.last_name" +
			" FROM actor AS a" +
			" WHERE a.actor_id > $1" +
			" GROUP BY a.first_name" +
			" HAVING a.first_name IS NOT NULL" +
			" ORDER BY a.last_name" +
			" LIMIT $2" +
			" OFFSET $3"
		tt.wantArgs = []any{5, 10, 20}
		tt.assert(t)
	})

	t.Run("DistinctOn", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = Postgres.
			Select(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
			DistinctOn(a.FIRST_NAME, a.LAST_NAME).
			From(a)
		tt.wantQuery = "SELECT DISTINCT ON (a.first_name, a.last_name)" +
			" a.actor_id, a.first_name, a.last_name" +
			" FROM actor AS a"
		tt.assert(t)
	})

	t.Run("FetchNext, WithTies, LockRows", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = Postgres.
			Select(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
			From(a).
			OrderBy(a.ACTOR_ID).
			Offset(10).
			FetchNext(20).WithTies().
			LockRows("FOR UPDATE")
		tt.wantQuery = "SELECT a.actor_id, a.first_name, a.last_name" +
			" FROM actor AS a" +
			" ORDER BY a.actor_id" +
			" OFFSET $1" +
			" FETCH NEXT $2 ROWS WITH TIES" +
			" FOR UPDATE"
		tt.wantArgs = []any{10, 20}
		tt.assert(t)
	})
}

// TestMySQLSelectQuery checks the accessors of the MySQL select builder and
// the SQL it renders with ? placeholders.
func TestMySQLSelectQuery(t *testing.T) {
	type ACTOR struct {
		TableStruct
		ACTOR_ID    NumberField
		FIRST_NAME  StringField
		LAST_NAME   StringField
		LAST_UPDATE TimeField
	}
	a := New[ACTOR]("a")

	t.Run("basic", func(t *testing.T) {
		t.Parallel()
		q1 := MySQL.SelectOne().From(a).SetDialect("lorem ipsum").As("q1")
		if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		if diff := testutil.Diff(q1.GetAlias(), "q1"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// SelectOne() already populated SelectFields, so SetFetchableFields
		// is expected to refuse to overwrite them.
		_, ok := q1.SetFetchableFields([]Field{a.ACTOR_ID})
		if ok {
			t.Fatal(testutil.Callers(), "field should not have been set")
		}
		q1.SelectFields = q1.SelectFields[:0]
		_, ok = q1.SetFetchableFields([]Field{a.ACTOR_ID})
		if !ok {
			t.Fatal(testutil.Callers(), "field should have been set")
		}
	})

	t.Run("subquery, cte and joins", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		subquery := MySQL.Select(Expr("*")).From(a).As("subquery")
		cte := NewCTE("cte", nil, MySQL.From(a).Select(Expr("*")))
		tt.item = MySQL.
			With(cte).
			From(a).From(a).
			Join(subquery, Eq(subquery.Field("actor_id"), a.ACTOR_ID)).
			LeftJoin(cte, Eq(cte.Field("actor_id"), a.ACTOR_ID)).
			FullJoin(a, Expr("1 = 1")).
			CrossJoin(a).
			CustomJoin(",", a).
			JoinUsing(a, a.FIRST_NAME, a.LAST_NAME).
			SelectOne()
		tt.wantQuery = "WITH cte AS (SELECT * FROM actor AS a)" +
			" SELECT 1" +
			" FROM actor AS a" +
			" JOIN (SELECT * FROM actor AS a) AS subquery ON subquery.actor_id = a.actor_id" +
			" LEFT JOIN cte ON cte.actor_id = a.actor_id" +
			" FULL JOIN actor AS a ON 1 = 1" +
			" CROSS JOIN actor AS a" +
			" , actor AS a" +
			" JOIN actor AS a USING (first_name, last_name)"
		tt.assert(t)
	})

	t.Run("all", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = MySQL.
			SelectDistinct(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
			SelectDistinct(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
			From(a).
			Where(a.ACTOR_ID.GtInt(5)).
			GroupBy(a.FIRST_NAME).
			Having(a.FIRST_NAME.IsNotNull()).
			OrderBy(a.LAST_NAME).
			Limit(10).
			Offset(20)
		tt.wantQuery = "SELECT DISTINCT a.actor_id, a.first_name, a.last_name" +
			" FROM actor AS a" +
			" WHERE a.actor_id > ?" +
			" GROUP BY a.first_name" +
			" HAVING a.first_name IS NOT NULL" +
			" ORDER BY a.last_name" +
			" LIMIT ?" +
			" OFFSET ?"
		tt.wantArgs = []any{5, 10, 20}
		tt.assert(t)
	})

	t.Run("LockRows", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = MySQL.
			Select(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
			From(a).
			OrderBy(a.ACTOR_ID).
			Offset(10).
			LockRows("FOR UPDATE")
		tt.wantQuery = "SELECT a.actor_id, a.first_name, a.last_name" +
			" FROM actor AS a" +
			" ORDER BY a.actor_id" +
			" OFFSET ?" +
			" FOR UPDATE"
		tt.wantArgs = []any{10}
		tt.assert(t)
	})
}

// TestSQLServerSelectQuery checks the accessors of the SQL Server select
// builder and the SQL it renders with @pN placeholders, including TOP,
// TOP ... PERCENT and OFFSET ... FETCH NEXT.
func TestSQLServerSelectQuery(t *testing.T) {
	type ACTOR struct {
		TableStruct
		ACTOR_ID    NumberField
		FIRST_NAME  StringField
		LAST_NAME   StringField
		LAST_UPDATE TimeField
	}
	a := New[ACTOR]("a")

	t.Run("basic", func(t *testing.T) {
		t.Parallel()
		q1 := SQLServer.SelectOne().From(a).SetDialect("lorem ipsum").As("q1")
		if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		if diff := testutil.Diff(q1.GetAlias(), "q1"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// SelectOne() already populated SelectFields, so SetFetchableFields
		// is expected to refuse to overwrite them.
		_, ok := q1.SetFetchableFields([]Field{a.ACTOR_ID})
		if ok {
			t.Fatal(testutil.Callers(), "field should not have been set")
		}
		q1.SelectFields = q1.SelectFields[:0]
		_, ok = q1.SetFetchableFields([]Field{a.ACTOR_ID})
		if !ok {
			t.Fatal(testutil.Callers(), "field should have been set")
		}
	})

	t.Run("subquery, cte and joins", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		subquery := SQLServer.Select(Expr("*")).From(a).As("subquery")
		cte := NewCTE("cte", nil, SQLServer.From(a).Select(Expr("*")))
		tt.item = SQLServer.
			With(cte).
			From(a).From(a).
			Join(subquery, Eq(subquery.Field("actor_id"), a.ACTOR_ID)).
			LeftJoin(cte, Eq(cte.Field("actor_id"), a.ACTOR_ID)).
			FullJoin(a, Expr("1 = 1")).
			CrossJoin(a).
			CustomJoin(",", a).
			SelectOne()
		tt.wantQuery = "WITH cte AS (SELECT * FROM actor AS a)" +
			" SELECT 1" +
			" FROM actor AS a" +
			" JOIN (SELECT * FROM actor AS a) AS subquery ON subquery.actor_id = a.actor_id" +
			" LEFT JOIN cte ON cte.actor_id = a.actor_id" +
			" FULL JOIN actor AS a ON 1 = 1" +
			" CROSS JOIN actor AS a" +
			" , actor AS a"
		tt.assert(t)
	})

	t.Run("all", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = SQLServer.
			SelectDistinct(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
			SelectDistinct(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
			From(a).
			Where(a.ACTOR_ID.GtInt(5)).
			GroupBy(a.FIRST_NAME).
			Having(a.FIRST_NAME.IsNotNull()).
			OrderBy(a.LAST_NAME).
			Offset(10).
			FetchNext(20)
		tt.wantQuery = "SELECT DISTINCT a.actor_id, a.first_name, a.last_name" +
			" FROM actor AS a" +
			" WHERE a.actor_id > @p1" +
			" GROUP BY a.first_name" +
			" HAVING a.first_name IS NOT NULL" +
			" ORDER BY a.last_name" +
			" OFFSET @p2 ROWS" +
			" FETCH NEXT @p3 ROWS ONLY"
		tt.wantArgs = []any{5, 10, 20}
		tt.assert(t)
	})

	t.Run("TopPercent", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = SQLServer.
			Select(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
			TopPercent(5).
			From(a).
			OrderBy(a.ACTOR_ID)
		tt.wantQuery = "SELECT TOP (@p1) PERCENT a.actor_id, a.first_name, a.last_name" +
			" FROM actor AS a" +
			" ORDER BY a.actor_id"
		tt.wantArgs = []any{5}
		tt.assert(t)
	})

	t.Run("Top, WithTies", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		// Row-lock hints are supplied through a raw Expr in the FROM clause.
		tt.item = SQLServer.
			Select(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
			Top(10).WithTies().
			From(Expr("{} AS a WITH (UPDLOCK, ROWLOCK)", a)).
			OrderBy(a.ACTOR_ID)
		tt.wantQuery = "SELECT TOP (@p1) WITH TIES a.actor_id, a.first_name, a.last_name" +
			" FROM actor AS a WITH (UPDLOCK, ROWLOCK)" +
			" ORDER BY a.actor_id"
		tt.wantArgs = []any{10}
		tt.assert(t)
	})
}

// TestSelectQuery exercises the dialect-agnostic SelectQuery directly: its
// accessors, policy tables, named windows, the combinations expected to be
// rejected (notOKTests), and error propagation from faulty parts (errTests).
func TestSelectQuery(t *testing.T) {
	t.Run("basic", func(t *testing.T) {
		t.Parallel()
		q1 := SelectQuery{FromTable: Expr("tbl"), Dialect: "lorem ipsum", Alias: "q1"}
		if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		if diff := testutil.Diff(q1.GetAlias(), "q1"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		_, ok := q1.SetFetchableFields([]Field{Expr("f1")})
		if !ok {
			t.Fatal(testutil.Callers(), "not ok")
		}
	})

	t.Run("PolicyTable", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		// The policy predicate from the FROM table is expected to be ANDed
		// with the query's own WHERE predicate.
		tt.item = SelectQuery{
			SelectFields:   Fields{Expr("1")},
			FromTable:      policyTableStub{policy: And(Expr("1 = 1"), Expr("2 = 2"))},
			WherePredicate: Expr("3 = 3"),
		}
		tt.wantQuery = "SELECT 1 FROM policy_table_stub WHERE (1 = 1 AND 2 = 2) AND 3 = 3"
		tt.assert(t)
	})

	t.Run("Where Having Window", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		w1 := NamedWindow{Name: "w1", Definition: PartitionBy(Expr("f1"))}
		w2 := NamedWindow{Name: "w2", Definition: OrderBy(Expr("f2"))}
		w3 := NamedWindow{Name: "w3", Definition: OrderBy(Expr("f3")).Frame("ROWS UNBOUNDED PRECEDING")}
		tt.item = SelectQuery{
			SelectFields:    Fields{CountStarOver(w1), CountStarOver(w2), CountStarOver(w3)},
			FromTable:       Expr("tbl"),
			WherePredicate:  Expr("1 = 1"),
			GroupByFields:   Fields{Expr("f2")},
			HavingPredicate: Expr("2 = 2"),
			NamedWindows:    NamedWindows{w1, w2, w3},
		}
		tt.wantQuery = "SELECT COUNT(*) OVER w1, COUNT(*) OVER w2, COUNT(*) OVER w3" +
			" FROM tbl" +
			" WHERE 1 = 1" +
			" GROUP BY f2" +
			" HAVING 2 = 2" +
			" WINDOW w1 AS (PARTITION BY f1)" +
			", w2 AS (ORDER BY f2)" +
			", w3 AS (ORDER BY f3 ROWS UNBOUNDED PRECEDING)"
		tt.assert(t)
	})

	// Each of these queries is expected to be rejected (checked via
	// assertNotOK); the description states the rule being exercised.
	notOKTests := []TestTable{{
		description: "no fields provided not allowed",
		item:        SelectQuery{},
	}, {
		description: "dialect does not support TOP",
		item: SelectQuery{
			Dialect:      DialectSQLite,
			SelectFields: Fields{Expr("f1")},
			LimitTop:     5,
		},
	}, {
		description: "sqlserver does not allow TOP without ORDER BY",
		item: SelectQuery{
			Dialect:      DialectSQLServer,
			SelectFields: Fields{Expr("f1")},
			LimitTop:     5,
		},
	}, {
		description: "dialect does not support DISTINCT ON",
		item: SelectQuery{
			Dialect:          DialectSQLite,
			SelectFields:     Fields{Expr("f1")},
			DistinctOnFields: Fields{Expr("f2")},
		},
	}, {
		description: "postgres does not allow both DISTINCT and DISTINCT ON",
		item: SelectQuery{
			Dialect:          DialectPostgres,
			SelectFields:     Fields{Expr("f1")},
			Distinct:         true,
			DistinctOnFields: Fields{Expr("f2")},
		},
	}, {
		description: "postgres does not allow subquery no alias",
		item: SelectQuery{
			Dialect:      DialectPostgres,
			SelectFields: Fields{Expr("f1")},
			FromTable:    SelectQuery{SelectFields: Fields{Expr("f1")}},
		},
	}, {
		description: "JOIN without FROM not allowed",
		item: SelectQuery{
			SelectFields: Fields{Expr("f1")},
			JoinTables: []JoinTable{
				Join(Expr("tbl"), Expr("1 = 1")),
			},
		},
	}, {
		description: "sqlserver does not support LIMIT",
		item: SelectQuery{
			Dialect:      DialectSQLServer,
			SelectFields: Fields{Expr("f1")},
			FromTable:    Expr("tbl"),
			LimitRows:    5,
		},
	}, {
		description: "sqlserver does not allow OFFSET without ORDER BY",
		item: SelectQuery{
			Dialect:      DialectSQLServer,
			SelectFields: Fields{Expr("f1")},
			FromTable:    Expr("tbl"),
			OffsetRows:   5,
		},
	}, {
		description: "sqlserver does not support OFFSET with TOP",
		item: SelectQuery{
			Dialect:       DialectSQLServer,
			SelectFields:  Fields{Expr("f1")},
			LimitTop:      5,
			FromTable:     Expr("tbl"),
			OrderByFields: Fields{Expr("f1")},
			OffsetRows:    10,
		},
	}, {
		description: "postgres does not allow FETCH NEXT with LIMIT",
		item: SelectQuery{
			Dialect:       DialectPostgres,
			SelectFields:  Fields{Expr("f1")},
			FromTable:     Expr("tbl"),
			OrderByFields: Fields{Expr("f1")},
			LimitRows:     5,
			OffsetRows:    10,
			FetchNextRows: 20,
		},
	}, {
		description: "sqlserver does not allow FETCH NEXT with TOP",
		item: SelectQuery{
			Dialect:       DialectSQLServer,
			SelectFields:  Fields{Expr("f1")},
			LimitTop:      5,
			FromTable:     Expr("tbl"),
			OrderByFields: Fields{Expr("f1")},
			FetchNextRows: 10,
		},
	}, {
		description: "dialect does not support FETCH NEXT",
		item: SelectQuery{
			Dialect:       DialectSQLite,
			SelectFields:  Fields{Expr("f1")},
			FromTable:     Expr("tbl"),
			OrderByFields: Fields{Expr("f1")},
			FetchNextRows: 20,
		},
	}, {
		description: "sqlserver does not support FETCH NEXT with WITH TIES",
		item: SelectQuery{
			Dialect:       DialectSQLServer,
			SelectFields:  Fields{Expr("f1")},
			FromTable:     Expr("tbl"),
			OrderByFields: Fields{Expr("f1")},
			FetchNextRows: 20,
			FetchWithTies: true,
		},
	}, {
		description: "postgres does not allow WITH TIES without ORDER BY",
		item: SelectQuery{
			Dialect:       DialectPostgres,
			SelectFields:  Fields{Expr("f1")},
			FromTable:     Expr("tbl"),
			FetchNextRows: 20,
			FetchWithTies: true,
		},
	}}

	for _, tt := range notOKTests {
		// Per-iteration copy for the parallel subtest closure (required before
		// Go 1.22 loop-variable semantics).
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assertNotOK(t)
		})
	}

	// Each of these queries contains a FaultySQL value or a policy table that
	// returns ErrFaultySQL; rendering is expected to surface ErrFaultySQL.
	errTests := []TestTable{{
		description: "FromTable Policy err",
		item: SelectQuery{
			SelectFields: Fields{Expr("f1")},
			FromTable:    policyTableStub{err: ErrFaultySQL},
		},
	}, {
		description: "JoinTables Policy err",
		item: SelectQuery{
			SelectFields: Fields{Expr("f1")},
			FromTable:    Expr("tbl"),
			JoinTables: []JoinTable{
				Join(policyTableStub{err: ErrFaultySQL}, Expr("1 = 1")),
			},
		},
	}, {
		description: "CTEs err",
		item: SelectQuery{
			SelectFields: Fields{Expr("f1")},
			CTEs: []CTE{
				NewCTE("cte", nil, Queryf("{}", FaultySQL{})),
			},
			FromTable: Expr("tbl"),
		},
	}, {
		description: "sqlserver LimitTop err",
		item: SelectQuery{
			Dialect:       DialectSQLServer,
			SelectFields:  Fields{Expr("f1")},
			LimitTop:      FaultySQL{},
			FromTable:     Expr("tbl"),
			OrderByFields: Fields{Expr("f1")},
		},
	}, {
		description: "postgres DistinctOnFields err",
		item: SelectQuery{
			Dialect:          DialectPostgres,
			SelectFields:     Fields{Expr("f1")},
			DistinctOnFields: Fields{FaultySQL{}},
			FromTable:        Expr("tbl"),
		},
	}, {
		description: "SelectFields err",
		item: SelectQuery{
			SelectFields: Fields{FaultySQL{}},
			FromTable:    Expr("tbl"),
		},
	}, {
		description: "FromTable err",
		item: SelectQuery{
			SelectFields: Fields{Expr("f1")},
			FromTable:    FaultySQL{},
		},
	}, {
		description: "JoinTables err",
		item: SelectQuery{
			SelectFields: Fields{Expr("f1")},
			FromTable:    Expr("tbl"),
			JoinTables: []JoinTable{
				Join(Expr("tbl"), FaultySQL{}),
			},
		},
	}, {
		description: "WherePredicate VariadicPredicate err",
		item: SelectQuery{
			SelectFields:   Fields{Expr("f1")},
			FromTable:      Expr("tbl"),
			WherePredicate: And(FaultySQL{}),
		},
	}, {
		description: "WherePredicate err",
		item: SelectQuery{
			SelectFields:   Fields{Expr("f1")},
			FromTable:      Expr("tbl"),
			WherePredicate: FaultySQL{},
		},
	}, {
		description: "GroupBy err",
		item: SelectQuery{
			SelectFields:  Fields{Expr("f1")},
			FromTable:     Expr("tbl"),
			GroupByFields: Fields{FaultySQL{}},
		},
	}, {
		description: "HavingPredicate VariadicPredicate err",
		item: SelectQuery{
			SelectFields:    Fields{Expr("f1")},
			FromTable:       Expr("tbl"),
			GroupByFields:   Fields{Expr("f1")},
			HavingPredicate: And(FaultySQL{}),
		},
	}, {
		description: "HavingPredicate err",
		item: SelectQuery{
			SelectFields:    Fields{Expr("f1")},
			FromTable:       Expr("tbl"),
			GroupByFields:   Fields{Expr("f1")},
			HavingPredicate: FaultySQL{},
		},
	}, {
		description: "NamedWindows err",
		item: SelectQuery{
			SelectFields: Fields{Expr("f1")},
			FromTable:    Expr("tbl"),
			NamedWindows: NamedWindows{{
				Name:       "w",
				Definition: OrderBy(FaultySQL{}),
			}},
		},
	}, {
		description: "OrderByFields err",
		item: SelectQuery{
			SelectFields:  Fields{Expr("f1")},
			FromTable:     Expr("tbl"),
			OrderByFields: Fields{FaultySQL{}},
		},
	}, {
		description: "LimitRows err",
		item: SelectQuery{
			SelectFields: Fields{Expr("f1")},
			FromTable:    Expr("tbl"),
			LimitRows:    FaultySQL{},
		},
	}, {
		description: "OffsetRows err",
		item: SelectQuery{
			SelectFields: Fields{Expr("f1")},
			FromTable:    Expr("tbl"),
			OffsetRows:   FaultySQL{},
		},
	}, {
		description: "FetchNext err",
		item: SelectQuery{
			Dialect:       DialectPostgres,
			SelectFields:  Fields{Expr("f1")},
			FromTable:     Expr("tbl"),
			FetchNextRows: FaultySQL{},
		},
	}, {
		description: "LockClause err",
		item: SelectQuery{
			Dialect:      DialectPostgres,
			SelectFields: Fields{Expr("f1")},
			FromTable:    Expr("tbl"),
			LockClause:   "FOR UPDATE OF {}",
			LockValues:   []any{FaultySQL{}},
		},
	}}

	for _, tt := range errTests {
		// Per-iteration copy for the parallel subtest closure (required before
		// Go 1.22 loop-variable semantics).
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assertErr(t, ErrFaultySQL)
		})
	}
}

================================================
FILE: sq.go
================================================
package sq

import (
	"bytes"
	"context"
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"fmt"
	"reflect"
	"strings"
	"sync"

	"github.com/bokwoon95/sq/internal/googleuuid"
	"github.com/bokwoon95/sq/internal/pqarray"
)

// bufpool holds reusable *bytes.Buffer values; when the pool is empty, New
// allocates a fresh empty buffer.
var bufpool = &sync.Pool{
	New: func() any { return &bytes.Buffer{} },
}

// Dialects supported.
const (
	DialectSQLite    = "sqlite"
	DialectPostgres  = "postgres"
	DialectMySQL     = "mysql"
	DialectSQLServer = "sqlserver"
)

// SQLWriter is anything that can be converted to SQL.
type SQLWriter interface {
	// WriteSQL writes the SQL representation of the SQLWriter into the query
	// string (*bytes.Buffer) and args slice (*[]any).
	//
	// The params map is used to hold the mappings between named parameters in
	// the query to the corresponding index in the args slice and is used for
	// rebinding args by their parameter name. The params map may be nil, check
	// first before writing to it.
	WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error
}

// DB is a database/sql abstraction that can query the database. *sql.Conn,
// *sql.DB and *sql.Tx all implement DB.
type DB interface {
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}

// Result is the result of an Exec command.
//
// NOTE(review): the fields presumably mirror the sql.Result methods of the
// same names; confirm against the code that populates Result.
type Result struct {
	LastInsertId int64
	RowsAffected int64
}

// Query is either SELECT, INSERT, UPDATE or DELETE.
type Query interface {
	SQLWriter

	// SetFetchableFields should return a query with its fetchable fields set
	// to the given fields. If not applicable, it should return false as the
	// second return value.
	SetFetchableFields([]Field) (query Query, ok bool)

	// GetDialect reports the SQL dialect of the query.
	// NOTE(review): presumably one of the Dialect* constants, or "" when no
	// dialect was set — confirm against the implementations.
	GetDialect() string
}

// Table is anything you can Select from or Join.
type Table interface {
	SQLWriter

	// IsTable marks the type as a Table. TableStruct's implementation is an
	// empty method, i.e. it carries no behavior of its own.
	IsTable()
}

// PolicyTable is a table that produces a policy (i.e. a predicate) to be
// enforced whenever it is invoked in a query. This is equivalent to Postgres'
// Row Level Security (RLS) feature but works application-side. Only SELECT,
// UPDATE and DELETE queries are affected.
type PolicyTable interface {
	Table

	// Policy returns the predicate to enforce for the given dialect, or an
	// error if the policy cannot be produced.
	Policy(ctx context.Context, dialect string) (Predicate, error)
}

// Window is a window used in SQL window functions.
type Window interface {
	SQLWriter

	// IsWindow appears to be a marker method identifying the type as a
	// Window (NOTE(review): confirm implementations are no-ops).
	IsWindow()
}

// Field is either a table column or some SQL expression.
type Field interface {
	SQLWriter

	// IsField appears to be a marker method identifying the type as a Field.
	IsField()
}

// Predicate is an SQL expression that evaluates to true or false. Its method
// set is exactly that of Boolean.
type Predicate interface {
	Boolean
}

// Assignment is an SQL assignment 'field = value'.
type Assignment interface {
	SQLWriter

	// IsAssignment appears to be a marker method identifying the type as an
	// Assignment.
	IsAssignment()
}

// Any is a catch-all interface that covers every field type. Because it
// embeds every field-type interface, a type satisfies Any only if it
// implements all of their methods (IsArray, IsBinary, IsBoolean, IsEnum,
// IsJSON, IsNumber, IsString, IsTime and IsUUID).
type Any interface {
	Array
	Binary
	Boolean
	Enum
	JSON
	Number
	String
	Time
	UUID
}

// Enumeration represents a Go enum.
type Enumeration interface {
	// Enumerate returns the names of all valid enum values.
	//
	// If the enum is backed by a string, each string in the slice is the
	// corresponding enum's string value.
	//
	// If the enum is backed by an int, each int index in the slice is the
	// corresponding enum's int value and the string is the enum's name. Enums
	// with empty string names are considered invalid, unless it is the very
	// first enum (at index 0).
	Enumerate() []string
}

// Array is a Field of array type.
type Array interface {
	Field
	IsArray()
}

// Binary is a Field of binary type.
type Binary interface {
	Field
	IsBinary()
}

// Boolean is a Field of boolean type.
type Boolean interface { Field IsBoolean() } // Enum is a Field of enum type. type Enum interface { Field IsEnum() } // JSON is a Field of json type. type JSON interface { Field IsJSON() } // Number is a Field of numeric type. type Number interface { Field IsNumber() } // String is a Field of string type. type String interface { Field IsString() } // Time is a Field of time type. type Time interface { Field IsTime() } // UUID is a Field of uuid type. type UUID interface { Field IsUUID() } // DialectValuer is any type that will yield a different driver.Valuer // depending on the SQL dialect. type DialectValuer interface { DialectValuer(dialect string) (driver.Valuer, error) } // TableStruct is meant to be embedded in table structs to make them implement // the Table interface. type TableStruct struct { schema string name string alias string } // ViewStruct is just an alias for TableStruct. type ViewStruct = TableStruct var _ Table = (*TableStruct)(nil) // NewTableStruct creates a new TableStruct. func NewTableStruct(schema, name, alias string) TableStruct { return TableStruct{schema: schema, name: name, alias: alias} } // WriteSQL implements the SQLWriter interface. func (ts TableStruct) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { if ts.schema != "" { buf.WriteString(QuoteIdentifier(dialect, ts.schema) + ".") } buf.WriteString(QuoteIdentifier(dialect, ts.name)) return nil } // GetAlias returns the alias of the TableStruct. func (ts TableStruct) GetAlias() string { return ts.alias } // IsTable implements the Table interface. 
func (ts TableStruct) IsTable() {} func withPrefix(w SQLWriter, prefix string) SQLWriter { if field, ok := w.(interface { SQLWriter WithPrefix(string) Field }); ok { return field.WithPrefix(prefix) } return w } func getAlias(w SQLWriter) string { if w, ok := w.(interface{ GetAlias() string }); ok { return w.GetAlias() } return "" } func toString(dialect string, w SQLWriter) string { buf := bufpool.Get().(*bytes.Buffer) buf.Reset() defer bufpool.Put(buf) var args []any _ = w.WriteSQL(context.Background(), dialect, buf, &args, nil) return buf.String() } func writeFieldsWithPrefix(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, fields []Field, prefix string, includeAlias bool) error { var err error var alias string for i, field := range fields { if field == nil { return fmt.Errorf("field #%d is nil", i+1) } if i > 0 { buf.WriteString(", ") } err = withPrefix(field, prefix).WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("field #%d: %w", i+1, err) } if includeAlias { if alias = getAlias(field); alias != "" { buf.WriteString(" AS " + QuoteIdentifier(dialect, alias)) } } } return nil } func writeFields(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, fields []Field, includeAlias bool) error { var err error var alias string for i, field := range fields { if field == nil { return fmt.Errorf("field #%d is nil", i+1) } if i > 0 { buf.WriteString(", ") } _, isQuery := field.(Query) if isQuery { buf.WriteString("(") } err = field.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("field #%d: %w", i+1, err) } if isQuery { buf.WriteString(")") } if includeAlias { if alias = getAlias(field); alias != "" { buf.WriteString(" AS " + QuoteIdentifier(dialect, alias)) } } } return nil } // mapperFunctionPanicked recovers from any panics. 
// // The function is called as such so that it shows up as // "sq.mapperFunctionPanicked" in panic stack trace, giving the user a // descriptive clue of what went wrong (i.e. their mapper function panicked). func mapperFunctionPanicked(err *error) { if r := recover(); r != nil { switch r := r.(type) { case error: *err = r default: *err = fmt.Errorf(fmt.Sprint(r)) } } } // ArrayValue takes in a []string, []int, []int64, []int32, []float64, // []float32 or []bool and returns a driver.Valuer for that type. For Postgres, // it serializes into a Postgres array. Otherwise, it serializes into a JSON // array. func ArrayValue(value any) driver.Valuer { return &arrayValue{value: value} } type arrayValue struct { dialect string value any } // Value implements the driver.Valuer interface. func (v *arrayValue) Value() (driver.Value, error) { switch v.value.(type) { case []string, []int, []int64, []int32, []float64, []float32, []bool: break default: return nil, fmt.Errorf("value %#v is not a []string, []int, []int32, []float64, []float32 or []bool", v.value) } if v.dialect != DialectPostgres { var b strings.Builder err := json.NewEncoder(&b).Encode(v.value) if err != nil { return nil, err } return strings.TrimSpace(b.String()), nil } if ints, ok := v.value.([]int); ok { bigints := make([]int64, len(ints)) for i, num := range ints { bigints[i] = int64(num) } v.value = bigints } return pqarray.Array(v.value).Value() } // DialectValuer implements the DialectValuer interface. func (v *arrayValue) DialectValuer(dialect string) (driver.Valuer, error) { v.dialect = dialect return v, nil } // EnumValue takes in an Enumeration and returns a driver.Valuer which // serializes the enum into a string and additionally checks if the enum is // valid. func EnumValue(value Enumeration) driver.Valuer { return &enumValue{value: value} } type enumValue struct { value Enumeration } // Value implements the driver.Valuer interface. 
func (v *enumValue) Value() (driver.Value, error) { value := reflect.ValueOf(v.value) names := v.value.Enumerate() switch value.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: i := int(value.Int()) if i < 0 || i >= len(names) { return nil, fmt.Errorf("%d is not a valid %T", i, v.value) } name := names[i] if name == "" && i != 0 { return nil, fmt.Errorf("%d is not a valid %T", i, v.value) } return name, nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: i := int(value.Uint()) if i < 0 || i >= len(names) { return nil, fmt.Errorf("%d is not a valid %T", i, v.value) } name := names[i] if name == "" && i != 0 { return nil, fmt.Errorf("%d is not a valid %T", i, v.value) } return name, nil case reflect.String: typ := value.Type() name := value.String() if getEnumIndex(name, names, typ) < 0 { return nil, fmt.Errorf("%q is not a valid %T", name, v.value) } return name, nil default: return nil, fmt.Errorf("underlying type of %[1]v is neither an integer nor string (%[1]T)", v.value) } } var ( enumIndexMu sync.RWMutex enumIndex = make(map[reflect.Type]map[string]int) ) // getEnumIndex returns the index of the enum within the names slice. func getEnumIndex(name string, names []string, typ reflect.Type) int { if len(names) <= 4 { for idx := range names { if names[idx] == name { return idx } } return -1 } var nameIndex map[string]int enumIndexMu.RLock() nameIndex = enumIndex[typ] enumIndexMu.RUnlock() if nameIndex != nil { idx, ok := nameIndex[name] if !ok { return -1 } return idx } idx := -1 nameIndex = make(map[string]int) for i := range names { if names[i] == name { idx = i } nameIndex[names[i]] = i } enumIndexMu.Lock() enumIndex[typ] = nameIndex enumIndexMu.Unlock() return idx } // JSONValue takes in an interface{} and returns a driver.Valuer which runs the // value through json.Marshal before submitting it to the database. 
func JSONValue(value any) driver.Valuer { return &jsonValue{value: value} } type jsonValue struct { value any } // Value implements the driver.Valuer interface. func (v *jsonValue) Value() (driver.Value, error) { var b strings.Builder err := json.NewEncoder(&b).Encode(v.value) return strings.TrimSpace(b.String()), err } // UUIDValue takes in a type whose underlying type must be a [16]byte and // returns a driver.Valuer. func UUIDValue(value any) driver.Valuer { return &uuidValue{value: value} } type uuidValue struct { dialect string value any } // Value implements the driver.Valuer interface. func (v *uuidValue) Value() (driver.Value, error) { if v.value == nil { return nil, nil } uuid, ok := v.value.([16]byte) if !ok { value := reflect.ValueOf(v.value) typ := value.Type() if value.Kind() != reflect.Array || value.Len() != 16 || typ.Elem().Kind() != reflect.Uint8 { return nil, fmt.Errorf("%[1]v %[1]T is not [16]byte", v.value) } for i := 0; i < value.Len(); i++ { uuid[i] = value.Index(i).Interface().(byte) } } if v.dialect != DialectPostgres { return uuid[:], nil } var buf [36]byte googleuuid.EncodeHex(buf[:], uuid) return string(buf[:]), nil } // DialectValuer implements the DialectValuer interface. 
//
// The dialect is recorded on the receiver itself (v is mutated, not copied)
// and v is returned, so the subsequent Value call sees that dialect.
func (v *uuidValue) DialectValuer(dialect string) (driver.Valuer, error) {
	v.dialect = dialect
	return v, nil
}

// preprocessValue converts an argument into a value suitable for the
// database driver, for the given dialect.
//
// If value implements DialectValuer, it is first replaced by the
// driver.Valuer returned for dialect. A nil Valuer therefore becomes nil,
// i.e. NULL. The result is then handled by the first matching case:
//   - nil: passed through as nil.
//   - Enumeration: converted to its name via enumValue (validated). This case
//     comes before driver.Valuer, so a type implementing both is treated as
//     an enum.
//   - exactly [16]byte: converted via uuidValue (a string on Postgres, bytes
//     otherwise). Named [16]byte types do not match this case.
//   - driver.Valuer: its Value method is called.
//   - anything else: returned unchanged.
func preprocessValue(dialect string, value any) (any, error) {
	if dialectValuer, ok := value.(DialectValuer); ok {
		driverValuer, err := dialectValuer.DialectValuer(dialect)
		if err != nil {
			return nil, fmt.Errorf("calling DialectValuer on %#v: %w", dialectValuer, err)
		}
		value = driverValuer
	}
	switch value := value.(type) {
	case nil:
		return nil, nil
	case Enumeration:
		driverValue, err := (&enumValue{value: value}).Value()
		if err != nil {
			return nil, fmt.Errorf("converting %#v to string: %w", value, err)
		}
		return driverValue, nil
	case [16]byte:
		driverValue, err := (&uuidValue{dialect: dialect, value: value}).Value()
		if err != nil {
			// The error wording reflects the target representation:
			// string on Postgres, bytes elsewhere.
			if dialect == DialectPostgres {
				return nil, fmt.Errorf("converting %#v to string: %w", value, err)
			}
			return nil, fmt.Errorf("converting %#v to bytes: %w", value, err)
		}
		return driverValue, nil
	case driver.Valuer:
		driverValue, err := value.Value()
		if err != nil {
			return nil, fmt.Errorf("calling Value on %#v: %w", value, err)
		}
		return driverValue, nil
	}
	return value, nil
}
================================================ FILE: sq.md ================================================ # sq (Structured Query)
code example of a select query using sq, to give viewers a quick idea of what the library is about
## Introduction to sq #introduction GitHub link: [github.com/bokwoon95/sq](https://github.com/bokwoon95/sq) sq is a type-safe data mapper and query builder for Go. It is not an ORM, but aims to be as convenient as an ORM while retaining the flexibility of a query builder/raw SQL. Notable features: - Works across SQLite, Postgres, MySQL and SQL Server. [[more info](#set-query-dialect)] - Each dialect has its own query builder, allowing the full use of dialect-specific features. [[more info](#dialect-specific-features)] - Declarative schema migrations. [[more info](#declarative-schema)] - Supports arrays, enums, JSON and UUID. [[more info](#arrays-enums-json-uuid)] - Query logging. [[more info](#logging)] ## Installation #installation This package only supports Go 1.18 and above because it uses generics for data mapping. ```shell $ go get github.com/bokwoon95/sq $ go install -tags=fts5 github.com/bokwoon95/sqddl@latest ``` ## Quickstart #quickstart Connect to the database. ```go db, err := sql.Open("postgres", "postgres://username:password@localhost:5432/sakila?sslmode=disable") ``` Define your model struct(s). ```go type Actor struct { ActorID int FirstName string LastName string LastUpdate time.Time } ``` Use one of the three functions below to run your query. - **FetchAll(db, query, rowmapper) ([]T, error)** - Fetch all results from a query. - Equivalent to [sql.Query](https://pkg.go.dev/database/sql#DB.Query). - **FetchOne(db, query, rowmapper) (T, error)** - Fetch one result from a query. - Returns sql.ErrNoRows if no results. - Equivalent to [sql.QueryRow](https://pkg.go.dev/database/sql#DB.QueryRow). - **Exec(db, query) (sq.Result, error)** - Executes a query. - Returns the rows affected (and the last insert ID, if it is supported by the dialect). - Equivalent to [sql.Exec](https://pkg.go.dev/database/sql#DB.Exec).
### Select example #rawsql-select #### Fetch all #rawsql-fetch-all ```sql SELECT actor_id, first_name, last_name FROM actor WHERE first_name = 'DAN' ``` ```go actors, err := sq.FetchAll(db, sq. Queryf("SELECT {*} FROM actor WHERE first_name = {}", "DAN"). SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { return Actor{ ActorID: row.Int("actor_id"), FirstName: row.String("first_name"), LastName: row.String("last_name"), } }, ) ``` #### Fetch one #rawsql-fetch-one ```sql SELECT actor_id, first_name, last_name FROM actor WHERE actor_id = 18 ``` ```go actor, err := sq.FetchOne(db, sq. Queryf("SELECT {*} FROM actor WHERE actor_id = {}", 18). SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { return Actor{ ActorID: row.Int("actor_id"), FirstName: row.String("first_name"), LastName: row.String("last_name"), } }, ) ``` #### Fetch cursor #rawsql-fetch-cursor ```sql SELECT actor_id, first_name, last_name FROM actor WHERE first_name = 'DAN' ``` ```go cursor, err := sq.FetchCursor(db, sq. Queryf("SELECT {*} FROM actor WHERE first_name = {}", "DAN"). SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { return Actor{ ActorID: row.Int("actor_id"), FirstName: row.String("first_name"), LastName: row.String("last_name"), } }, ) if err != nil { } defer cursor.Close() var actors []Actor for cursor.Next() { actor, err := cursor.Result() if err != nil { } actors = append(actors, actor) } ``` #### Fetch exists #rawsql-fetch-exists ```sql SELECT EXISTS (SELECT 1 FROM actor WHERE actor_id = 18) ``` ```go exists, err := sq.FetchExists(db, sq. Queryf("SELECT 1 FROM actor WHERE actor_id = {}", 18). SetDialect(sq.DialectPostgres), ) ``` ### Insert example #rawsql-insert #### Insert one #rawsql-insert-one ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (18, 'DAN', 'TORN') ``` ```go _, err := sq.Exec(db, sq. Queryf("INSERT INTO actor (actor_id, first_name, last_name) VALUES {}", sq.RowValue{ 18, "DAN", "TORN", }). 
SetDialect(sq.DialectPostgres), ) ``` #### Insert many #rawsql-insert-many ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (18, 'DAN', 'TORN'), (56, 'DAN', 'HARRIS'), (116, 'DAN', 'STREEP') ``` ```go _, err := sq.Exec(db, sq. Queryf("INSERT INTO actor (actor_id, first_name, last_name) VALUES {}", sq.RowValues{ {18, "DAN", "TORN"}, {56, "DAN", "HARRIS"}, {116, "DAN", "STREEP"}, }). SetDialect(sq.DialectPostgres), ) ``` ### Update example #rawsql-update ```sql UPDATE actor SET first_name = 'DAN', last_name = 'TORN' WHERE actor_id = 18 ``` ```go _, err := sq.Exec(db, sq. Queryf("UPDATE actor SET first_name = {}, last_name = {} WHERE actor_id = {}", "DAN", "TORN", 18, ). SetDialect(sq.DialectPostgres), ) ``` ### Delete example #rawsql-delete ```sql DELETE FROM actor WHERE actor_id = 56 ``` ```go _, err := sq.Exec(db, sq. Queryf("DELETE FROM actor WHERE actor_id = {}", 56). SetDialect(sq.DialectPostgres), ) ``` ## How the rowmapper works #rowmapper The [FetchAll/FetchOne/FetchCursor examples in the quickstart](#rawsql-select) use a rowmapper function both to indicate which fields should be selected and to encode how each row should be procedurally mapped back to a model struct. ```go // The rowmapper function signature should match func(*sq.Row) T. func(row *sq.Row) Actor { return Actor{ ActorID: row.Int("actor_id"), FirstName: row.String("first_name"), LastName: row.String("last_name"), } } ``` To go into greater detail, the rowmapper is first called in "passive mode" where the `sq.Row` records the fields needed by the SELECT query. Those fields are then injected back into the SELECT query ([via the `{*}` insertion point](#rawsql-select)) and the query is run for real. Then the rowmapper is called in "active mode" where each `sq.Row` method call actually returns a value from the underlying row. The `Actor` result returned by each rowmapper call is then appended to a slice.
All this is done generically, so the rowmapper can yield any variable of type `T` and a slice `[]T` will be returned at the end. **The order in which you call the `sq.Row` methods must be deterministic and must not change between rowmapper invocations**. For example, don't put a `row.Int()` call inside an if-block: the fields recorded during the passive-mode call would then no longer line up with the values read during the active-mode calls. ### Static vs dynamic queries #static-vs-dynamic-queries The [query examples in the quickstart](#rawsql-select) showcase dynamic queries, i.e. queries whose SELECT-ed fields are dynamically determined by the rowmapper. You can also write static queries, where the columns you SELECT are hardcoded into the query and the rowmapper references those fields by alias/name. ```go actors, err := sq.FetchAll(db, sq. Queryf("SELECT actor_id, first_name, last_name AS lname FROM actor WHERE first_name = {}", "DAN"). SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { fmt.Printf("%#v\n", row.Columns()) // []string{"actor_id", "first_name", "lname"} fmt.Printf("%#v\n", row.Values()) // []any{18, "DAN", "TORN"} return Actor{ ActorID: row.Int("actor_id"), FirstName: row.String("first_name"), LastName: row.String("lname"), } }, ) ``` ### Handling errors #rowmapper-handling-errors If you do any computation in a rowmapper that returns an error, you can panic() with it and the error will be propagated as the error return value of FetchAll/FetchOne/FetchCursor. That said, try to avoid doing anything that returns an error inside the rowmapper. ```go func(row *sq.Row) Film { var film Film film.FilmID = row.Int("film_id") film.Title = row.String("title") film.Description = row.String("description") // Pull raw bytes from the DB and unmarshal as JSON. b := row.Bytes("special_features") err := json.Unmarshal(b, &film.SpecialFeatures) if err != nil { panic(err) } // Alternatively you can use row.JSON(), which doesn't // require you to do error handling.
row.JSON(&film.SpecialFeatures, "special_features") return film } ``` ### Available methods #sq-row-methods ```go // These methods are straightforward and return the type associated with their // name. // // NULL values are automatically converted to a zero value: 0 for numbers, the // empty string for strings, a nil slice for []byte, etc. Use the NullXXX // method variants if capturing NULL is meaningful to you. var _ []byte = row.Bytes("field_name") var _ bool = row.Bool("field_name") var _ float64 = row.Float64("field_name") var _ int = row.Int("field_name") var _ int64 = row.Int64("field_name") var _ string = row.String("field_name") var _ time.Time = row.Time("field_name") // The sql.NullXXX variants. var _ sql.NullBool = row.NullBool("field_name") var _ sql.NullFloat64 = row.NullFloat64("field_name") var _ sql.NullInt64 = row.NullInt64("field_name") var _ sql.NullString = row.NullString("field_name") var _ sql.NullTime = row.NullTime("field_name") // row.Scan scans the value of field_name into a destination pointer. If the // pointer type implements sql.Scanner, this is where to use it. row.Scan(dest, "field_name") // row.Array scans the value of field_name into a destination slice pointer. Only // *[]bool, *[]int64, *[]int32, *[]float64, *[]float32 and *[]string are // supported. On Postgres this value must be an array, while for other dialects // this value must be a JSON array. row.Array(sliceDest, "field_name") // row.JSON scans the value of field_name into a destination pointer that // json.Unmarshal can unmarshal JSON into. The value must be JSON. row.JSON(jsonDest, "field_name") // row.UUID scans the value of field_name into a destination pointer whose // underlying type must be [16]byte. The value can be BINARY(16) or a UUID string. row.UUID(uuidDest, "field_name") ``` There are also `Field` method variants that accept an `sq.Field` instead of a `string` name.
This is relevant if you are [using the query builder](#querybuilder) instead of [raw SQL](#rawsql-select). ```go var _ []byte = row.BytesField(tbl.FIELD_NAME) var _ bool = row.BoolField(tbl.FIELD_NAME) var _ float64 = row.Float64Field(tbl.FIELD_NAME) var _ int = row.IntField(tbl.FIELD_NAME) var _ int64 = row.Int64Field(tbl.FIELD_NAME) var _ string = row.StringField(tbl.FIELD_NAME) var _ time.Time = row.TimeField(tbl.FIELD_NAME) var _ sql.NullBool = row.NullBoolField(tbl.FIELD_NAME) var _ sql.NullFloat64 = row.NullFloat64Field(tbl.FIELD_NAME) var _ sql.NullInt64 = row.NullInt64Field(tbl.FIELD_NAME) var _ sql.NullString = row.NullStringField(tbl.FIELD_NAME) var _ sql.NullTime = row.NullTimeField(tbl.FIELD_NAME) row.ScanField(dest, tbl.FIELD_NAME) row.ArrayField(sliceDest, tbl.FIELD_NAME) row.JSONField(jsonDest, tbl.FIELD_NAME) row.UUIDField(uuidDest, tbl.FIELD_NAME) ``` ## Setting the dialect of a query #set-query-dialect Each [sample query in the quickstart](#rawsql-select) has its dialect set to Postgres. ```go sq.Queryf("SELECT {*} FROM actor WHERE first_name = {}", "DAN").SetDialect(sq.DialectPostgres) ``` This is to generate a Postgres-compatible query, where each curly brace `{}` placeholder is replaced with a Postgres dollar placeholder (e.g. $1, $2, $3). This is the same case for the [query builder](#querybuilder-select). You can choose one of four possible dialects: ```go const ( DialectSQLite = "sqlite" // placeholders are $1, $2, $3 DialectPostgres = "postgres" // placeholders are $1, $2, $3 DialectMySQL = "mysql" // placeholders are ?, ?, ? DialectSQLServer = "sqlserver" // placeholders are @p1, @p2, @p3 ) ``` Each dialect that you pick will use the corresponding placeholder type when generating the query. [Ordinal placeholders (`{1}`, `{2}`, `{3}`) and named placeholders (`{foo}`, `{bar}`, `{baz}`)](#ordinal-named-placeholders) are also supported. 
You can use the **sq.SQLite**, **sq.Postgres**, **sq.MySQL** and **sq.SQLServer** package-level variables as shorthand for setting the dialect (in order to type less). ```go sq.SQLite.Queryf(query) // sq.Queryf(query).SetDialect(sq.DialectSQLite) sq.Postgres.Queryf(query) // sq.Queryf(query).SetDialect(sq.DialectPostgres) sq.MySQL.Queryf(query) // sq.Queryf(query).SetDialect(sq.DialectMySQL) sq.SQLServer.Queryf(query) // sq.Queryf(query).SetDialect(sq.DialectSQLServer) ``` ### Setting the query dialect globally #set-query-dialect-globally To set the default dialect globally, set the value of sq.DefaultDialect. This value is used when no dialect is provided (i.e. an empty string). ```go func init() { // Sets the default dialect of all queries to Postgres (unless a dialect is // explicitly provided). // // NOTE: You can't use a pointer to sq.DialectPostgres directly because it is // a constant which cannot be addressed. dialect := sq.DialectPostgres sq.DefaultDialect.Store(&dialect) } ``` ## sq's query templating syntax #templating-syntax sq.Queryf (and sq.Expr) use a Printf-style templating syntax where the format string uses curly brace `{}` placeholders. Here is a basic example for Queryf: > **Note**: All examples below interpolate their arguments into the SQL query for illustrative purposes, but in actuality the [proper prepared statement placeholders](#set-query-dialect) will be generated. 
```go sq.Queryf("SELECT first_name FROM actor WHERE actor_id = {}", 18) ``` ```sql SELECT first_name FROM actor WHERE actor_id = 18 ``` sq.Queryf has an Append() method which allows for a basic level of query building: ```go var ( name = "bob" email = "bob@email.com" age = 27 ) q := sq.Queryf("SELECT name, email FROM tbl WHERE 1 = 1") // https://stackoverflow.com/questions/1264681/what-is-the-purpose-of-using-where-1-1-in-sql-statements if name != "" { q = q.Append("AND name = {}", name) } if email != "" { q = q.Append("AND email = {}", email) } if age != 0 { q = q.Append("AND age = {}", age) } ``` ```sql SELECT name, email FROM tbl WHERE 1 = 1 AND name = 'bob' AND email = 'bob@email.com' AND age = 27 ``` Unlike with SQL prepared statements, the curly brace `{}` placeholders are allowed to change the structure of a query (i.e. it can appear anywhere inside a query): ```go sq.Queryf( "SELECT {} FROM {} WHERE first_name = {}", sq.Fields{ sq.Expr("actor_id"), sq.Expr("last_name"), }, sq.Expr("actor"), "DAN", ) ``` ```sql SELECT actor_id, last_name FROM actor WHERE first_name = 'DAN' ``` ### Escaping the curly brace #escaping-curly-brace If you wish to actually use curly braces `{}` inside the format string (which is very rare), you must escape the opening curly brace by doubling it up like this: `{{}`. ```go sq.Queryf("SELECT '{{}', '{{abcd}'") ``` ```sql SELECT '{}', '{abcd}' ``` ### Value expansion #value-expansion Each value passed to the query preprocessor is evaluated based on the following cases in the order shown: 1. If the value [implements the `SQLWriter` interface](#sqlwriter), its `WriteSQL` method is called. 2. Else if the value is a slice, the slice is expanded into a comma separated list. - Each item in this list is further evaluated recursively following the same logic. - byte slices (`[]byte`) are the exception, they are treated as a unit and do not undergo slice expansion. 3. 
Otherwise, a dialect-appropriate placeholder is appended to the query string and the value itself is appended to the args. Here is an example of the three different cases in action. ```go sq.Queryf( "SELECT {} FROM actor WHERE actor_id IN ({}) AND first_name = {}", // case 1 sq.Expr("jsonb_build_object({})", []any{ // case 2 sq.Literal("first_name"), // case 1 sq.Expr("first_name"), // case 1 sq.Literal("last_name"), // case 1 sq.Expr("last_name"), // case 1 }), // case 2 []int{18, 56, 116}, // case 3 "DAN", ).SetDialect(sq.DialectPostgres) ``` ```sql SELECT jsonb_build_object('first_name', first_name, 'last_name', last_name) FROM actor WHERE actor_id IN ($1, $2, $3) AND first_name = $4 -- args: 18, 56, 116, 'DAN' ``` ### Ordinal and Named placeholders #ordinal-named-placeholders The templating syntax supports 3 types of placeholders: 1. Anonymous placeholders `{}`. 2. Ordinal placeholders `{1}`, `{2}`, `{3}`. - Ordinal placeholders use 1-based indexing. 3. Named placeholders `{foo}`, `{bar}`, `{baz}`. - Named placeholders in the format string must have a corresponding `sql.Named` value. - Placeholder names must consist only of Unicode letters, digits `0-9` or underscores `_`. It is possible for an anonymous placeholder, an ordinal placeholder and a named placeholder to refer to the same value. ```go sq.Queryf("SELECT {}, {2}, {}, {name}", "Marco", sql.Named("name", "Polo")) // └─────────────┘ // All refer to 'Polo' ``` ```sql SELECT 'Marco', 'Polo', 'Polo', 'Polo' ``` #### Anonymous parameter example #anonymous-params ```go sq.SQLite.Queryf( "SELECT {}, {}, {}", "foo", "bar", "foo") // SQLite sq.Postgres.Queryf( "SELECT {}, {}, {}", "foo", "bar", "foo") // Postgres sq.MySQL.Queryf( "SELECT {}, {}, {}", "foo", "bar", "foo") // MySQL sq.SQLServer.Queryf("SELECT {}, {}, {}", "foo", "bar", "foo") // SQLServer ``` ```sql SELECT $1, $2, $3 -- SQLite, Args: 'foo', 'bar', 'foo' SELECT $1, $2, $3 -- Postgres, Args: 'foo', 'bar', 'foo' SELECT ?, ?, ?
-- MySQL, Args: 'foo', 'bar', 'foo' SELECT @p1, @p2, @p3 -- SQLServer, Args: 'foo', 'bar', 'foo' ``` #### Ordinal parameter example #ordinal-params ```go sq.SQLite.Queryf( "SELECT {1}, {2}, {1}", "foo", "bar") // SQLite sq.Postgres.Queryf( "SELECT {1}, {2}, {1}", "foo", "bar") // Postgres sq.MySQL.Queryf( "SELECT {1}, {2}, {1}", "foo", "bar") // MySQL sq.SQLServer.Queryf("SELECT {1}, {2}, {1}", "foo", "bar") // SQLServer ``` ```sql SELECT $1, $2, $1 -- SQLite, Args: 'foo', 'bar' SELECT $1, $2, $1 -- Postgres, Args: 'foo', 'bar' SELECT ?, ?, ? -- MySQL, Args: 'foo', 'bar', 'foo' SELECT @p1, @p2, @p1 -- SQLServer, Args: 'foo', 'bar' ``` #### Named parameter example #named-params ```go // SQLite sq.SQLite.Queryf("SELECT {one}, {two}, {one}", sql.Named("one", "foo"), sql.Named("two", "bar"), ) // Postgres sq.Postgres.Queryf("SELECT {one}, {two}, {one}", sql.Named("one", "foo"), sql.Named("two", "bar"), ) // MySQL sq.MySQL.Queryf("SELECT {one}, {two}, {one}", sql.Named("one", "foo"), sql.Named("two", "bar"), ) // SQLServer sq.SQLServer.Queryf("SELECT {one}, {two}, {one}", sql.Named("one", "foo"), sql.Named("two", "bar"), ) ``` ```sql SELECT $one, $two, $one -- SQLite, Args: one: 'foo', two: 'bar' SELECT $1, $2, $1 -- Postgres, Args: 'foo', 'bar' SELECT ?, ?, ? -- MySQL, Args: 'foo', 'bar', 'foo' SELECT @one, @two, @one -- SQLServer, Args: one: 'foo', two: 'bar' ``` ### SQLWriter example #sqlwriter An SQLWriter represents anything that can render itself as SQL. It is the first thing taken into consideration during [value expansion](#value-expansion). 
Here is the definition of the SQLWriter interface: ```go type SQLWriter interface { WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error } ``` As an example, we will create a custom SQLWriter component that renders the string `str` repeated `num` times, separated by `delim`, where `str`, `num` and `delim` are parameters: ```go sq.Queryf("SELECT {}", multiplier{str: "lorem ipsum", num: 5, delim: " "}) ``` ```sql SELECT lorem ipsum lorem ipsum lorem ipsum lorem ipsum lorem ipsum ``` This is the implementation of `multiplier`: ```go type multiplier struct { str string num int delim string } func (m multiplier) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { for i := 0; i < m.num; i++ { if i > 0 { buf.WriteString(m.delim) } buf.WriteString(m.str) } return nil } ``` ```go sq.Queryf("SELECT {}", multiplier{str: "foo", num: 3, delim: " AND "}) sq.Queryf("SELECT {}", multiplier{str: "lorem ipsum", num: 4, delim: ", "}) sq.Queryf("SELECT {}", multiplier{str: "🎉", num: 6, delim: ""}) ``` ```sql SELECT foo AND foo AND foo SELECT lorem ipsum, lorem ipsum, lorem ipsum, lorem ipsum SELECT 🎉🎉🎉🎉🎉🎉 ``` ## Using the query builder #querybuilder ### Table structs #table-structs To use the query builder, you need to first define your table struct(s). ```go type ACTOR struct { sq.TableStruct // A table struct is marked by embedding sq.TableStruct as the first field. ACTOR_ID sq.NumberField FIRST_NAME sq.StringField LAST_NAME sq.StringField LAST_UPDATE sq.TimeField } ``` You can then instantiate the table using [sq.New()](https://pkg.go.dev/github.com/bokwoon95/sq#New) and use it to create predicates and participate in a query.
```go a := sq.New[ACTOR]("a") // actor AS a a.ACTOR_ID.EqInt(18) // a.actor_id = 18 a.LAST_UPDATE.IsNotNull() // a.last_update IS NOT NULL sq.Select(a.FIRST_NAME, a.LAST_NAME).From(a).Where(a.ACTOR_ID.In([]int{18, 56, 116})) // SELECT a.first_name, a.last_name FROM actor AS a WHERE a.actor_id IN (18, 56, 116) ``` #### Model structs #model-structs In general, there should be two types of structs that you use with the query builder. One is the table struct, which represents an instance of an SQL table. The other is a model struct, which represents an instance of a domain model (in this example, an actor). ```go // Table struct (represents your SQL table). type ACTOR struct { sq.TableStruct `sq:"Actor"` ACTOR_ID sq.NumberField `sq:"ActorID"` FIRST_NAME sq.StringField `sq:"FirstName"` LAST_NAME sq.StringField `sq:"LastName"` LAST_UPDATE sq.TimeField `sq:"LastUpdate"` } // Model struct (represents an instance of an actor). type Actor struct { ActorID int FirstName string LastName string LastUpdate time.Time } // Note the different casing of ACTOR vs Actor. a := sq.New[ACTOR]("a") actors, err := sq.FetchAll(db, sq. From(a). Where(a.FIRST_NAME.EqString("DAN")). SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), } }, ) ``` ### Available Field types #field-types There are 10 available field types that you can use in your [table structs](#table-structs). 
- **NumberField** (`int`, `int64`, INT, BIGINT, NUMERIC, etc) - **StringField** (`string`, TEXT, VARCHAR, etc) - **TimeField** (`time.Time`, DATE, DATETIME, TIMESTAMP, etc) - **BooleanField** (`bool`, BOOLEAN, TINYINT, BIT, etc) - **BinaryField** (`[]byte`, BYTEA, BINARY, etc) - **ArrayField** - Represents a primitive slice type in Go (`[]string`, `[]int64`, `[]int32`, `[]float64`, `[]float32`, `[]bool`) - In Postgres, this is a native array (TEXT[], INT[], BIGINT[], NUMERIC[], BOOLEAN[]) - In other databases, this is a JSON array. - **EnumField** - Represents an "enum" type in Go (`iota`, `string`, take your pick) - In Postgres, this is a native enum type (CREATE TYPE AS ENUM) - In other databases, this is a plain string. - Your Go enum type must [implement the `Enumeration` interface](#enums). - **JSONField** - Represents a Go type that works with `json.Marshal` and `json.Unmarshal`. - In Postgres, this is the JSONB or JSON type. - In MySQL, this is the JSON type. - In other databases, this is a plain string. - **UUIDField** - Represents any type whose underlying type is [16]byte in Go. - In Postgres, this is a UUID. - In other databases, this is a BINARY(16). - **AnyField** - A catch-all field type that can substitute as any of the 9 other field types. - Use this to represent types like `TSVECTOR` that don't have a corresponding representation. ### Field name to column name translation #field-name-translation The table name and column names are derived by lowercasing the struct name and struct field names. So a struct `ACTOR` will be translated to a table called `actor`, and a field `ACTOR_ID` will be translated to a column called `actor_id`. If that is not what you want, you can specify the desired name inside an `sq` struct tag. 
```go type ACTOR struct { sq.TableStruct `sq:"Actor"` ACTOR_ID sq.NumberField `sq:"ActorID"` FIRST_NAME sq.StringField `sq:"FirstName"` LAST_NAME sq.StringField `sq:"LastName"` LAST_UPDATE sq.TimeField `sq:"LastUpdate"` } a := sq.New[ACTOR]("") // "Actor" a.ACTOR_ID // "Actor"."ActorID" a.FIRST_NAME // "Actor"."FirstName" ``` ### Aliasing a table struct #alias-table-struct sq.New() takes in an alias string as an argument and returns a table with that alias. Leave the alias string blank if you don't want the table to have an alias. ```go a1 := sq.New[ACTOR]("a") // actor AS a a1.ACTOR_ID // a.actor_id a2 := sq.New[ACTOR]("") // actor a2.ACTOR_ID // actor.actor_id ``` ### Table structs as a declarative schema #declarative-schema #### Generating migrations #generating-migrations > Example: here is [the table struct representation](https://github.com/bokwoon95/sqddl/blob/main/ddl/testdata/tables.go.txt) of the [sakila database schema](https://www.jooq.org/sakila). Your [table structs](#table-structs) serve as a declarative schema for your tables. The [sqddl tool](https://bokwoon.neocities.org/sqddl.html) is able to parse Go files containing table structs and [generate the necessary migrations](https://bokwoon.neocities.org/sqddl.html#generate) needed to reach that desired schema. The generated migrations can then be [applied using the same sqddl tool](https://bokwoon.neocities.org/sqddl.html#migrate). ```shell # Generate migrations needed to go from $DATABASE_URL to tables/tables.go and write into ./migrations dir $ sqddl generate -src "$DATABASE_URL" -dest tables/tables.go -output-dir ./migrations # Apply the pending migrations in ./migrations dir against the database $DATABASE_URL $ sqddl migrate -db "$DATABASE_URL" -dir ./migrations ``` For more information on how to express "CREATE TABLE" DDL using table structs, please check out the [sqddl documentation](https://bokwoon.neocities.org/sqddl.html#table-structs).
#### Generating table structs #generating-table-structs The reverse is also possible, you can [generate table structs from an existing database](https://bokwoon.neocities.org/sqddl.html#tables). If you have an existing database this is the recommended way to get started, rather than creating the table structs manually to match the database. ```shell # Generate table structs from $DATABASE_URL and write into tables/tables.go $ sqddl tables -db "$DATABASE_URL" -file tables/tables.go ``` Once you have your table structs, you can edit your table structs and [generate migrations](#generating-migrations) from them. Note that migration generation only covers [a subset of possible DDL operations](#) so it's possible that you will have to write some migrations by hand. ### Select example #querybuilder-select #### Fetch all #querybuilder-fetch-all ```sql SELECT a.actor_id, a.first_name, a.last_name FROM actor AS a WHERE a.first_name = 'DAN' ``` ```go a := sq.New[ACTOR]("a") actors, err := sq.FetchAll(db, sq. From(a). Where(a.FIRST_NAME.EqString("DAN")). SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), } }, ) ``` #### Fetch one #querybuilder-fetch-one ```sql SELECT a.actor_id, a.first_name, a.last_name FROM actor AS a WHERE a.actor_id = 18 ``` ```go a := sq.New[ACTOR]("a") actor, err := sq.FetchOne(db, sq. From(a). Where(a.ACTOR_ID.EqInt(18)). SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), } }, ) ``` #### Fetch cursor #querybuilder-fetch-cursor ```sql SELECT a.actor_id, a.first_name, a.last_name FROM actor AS a WHERE a.first_name = 'DAN' ``` ```go a := sq.New[ACTOR]("a") cursor, err := sq.FetchCursor(db, sq. From(a). Where(a.FIRST_NAME.EqString("DAN")). 
SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), } }, ) if err != nil { } defer cursor.Close() var actors []Actor for cursor.Next() { actor, err := cursor.Result() if err != nil { } actors = append(actors, actor) } ``` #### Fetch exists #querybuilder-fetch-exists ```sql SELECT EXISTS (SELECT 1 FROM actor AS a WHERE a.actor_id = 18) ``` ```go a := sq.New[ACTOR]("a") exists, err := sq.FetchExists(db, sq. SelectOne(). From(a). Where(a.ACTOR_ID.EqInt(18)). SetDialect(sq.DialectPostgres), ) ``` #### Fetch distinct #querybuilder-fetch-distinct ```sql SELECT DISTINCT a.first_name FROM actor AS a ``` ```go a := sq.New[ACTOR]("a") firstNames, err := sq.FetchAll(db, sq. SelectDistinct(). From(a). SetDialect(sq.DialectPostgres), func(row *sq.Row) string { return row.String(a.FIRST_NAME) }, ) ``` ### Insert example #querybuilder-insert #### Insert one #querybuilder-insert-one ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (18, 'DAN', 'TORN') ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq. InsertInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). Values(18, "DAN", "TORN"). SetDialect(sq.DialectPostgres), ) ``` #### Insert many #querybuilder-insert-many ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (18, 'DAN', 'TORN'), (56, 'DAN', 'HARRIS'), (116, 'DAN', 'STREEP') ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq. InsertInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). Values(18, "DAN", "TORN"). Values(56, "DAN", "HARRIS"). Values(116, "DAN", "STREEP"). SetDialect(sq.DialectPostgres), ) ``` #### Insert from Select #querybuilder-insert-from-select ```sql INSERT INTO actor (actor_id, first_name, last_name) SELECT actor.actor_id, actor.first_name, actor.last_name FROM actor WHERE actor.last_update IS NOT NULL ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq.
InsertInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). Select(sq. Select(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). From(a). Where(a.LAST_UPDATE.IsNotNull()), ). SetDialect(sq.DialectPostgres), ) ``` #### Insert one (column mapper) #querybuilder-insert-one-columnmapper ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (18, 'DAN', 'TORN') ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq. InsertInto(a). ColumnValues(func(col *sq.Column) { col.SetInt(a.ACTOR_ID, 18) col.SetString(a.FIRST_NAME, "DAN") col.SetString(a.LAST_NAME, "TORN") }). SetDialect(sq.DialectPostgres), ) ``` #### Insert many (column mapper) #querybuilder-insert-many-columnmapper ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (18, 'DAN', 'TORN'), (56, 'DAN', 'HARRIS'), (116, 'DAN', 'STREEP') ``` ```go actors := []Actor{ {ActorID: 18, FirstName: "DAN", LastName: "TORN"}, {ActorID: 56, FirstName: "DAN", LastName: "HARRIS"}, {ActorID: 116, FirstName: "DAN", LastName: "STREEP"}, } a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq. InsertInto(a). ColumnValues(func(col *sq.Column) { for _, actor := range actors { col.SetInt(a.ACTOR_ID, actor.ActorID) col.SetString(a.FIRST_NAME, actor.FirstName) col.SetString(a.LAST_NAME, actor.LastName) } }). SetDialect(sq.DialectPostgres), ) ``` #### How does the Insert column mapper work? #insert-columnmapper The Insert column mapper works by having the `sq.Column` note down the very first field passed to it. Every time `sq.Column` sees that field again, it will treat it as starting a new row value. ```go a := sq.New[ACTOR]("") q := sq. InsertInto(a).
ColumnValues(func(col *sq.Column) { col.SetInt(a.ACTOR_ID, 1) // every a.ACTOR_ID will mark the start of a new row value col.SetString(a.FIRST_NAME, "PENELOPE") col.SetString(a.LAST_NAME, "GUINESS") col.SetInt(a.ACTOR_ID, 2) col.SetString(a.FIRST_NAME, "NICK") col.SetString(a.LAST_NAME, "WAHLBERG") col.SetInt(a.ACTOR_ID, 3) col.SetString(a.FIRST_NAME, "ED") col.SetString(a.LAST_NAME, "CHASE") }). SetDialect(sq.DialectPostgres) ``` ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') ``` ### Update example #querybuilder-update ```sql UPDATE actor SET first_name = 'DAN', last_name = 'TORN' WHERE actor.actor_id = 18 ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq. Update(a). Set( a.FIRST_NAME.SetString("DAN"), a.LAST_NAME.SetString("TORN"), ). Where(a.ACTOR_ID.EqInt(18)). SetDialect(sq.DialectPostgres), ) ``` #### Update (column mapper) #update-columnmapper ```sql UPDATE actor SET first_name = 'DAN', last_name = 'TORN' WHERE actor.actor_id = 18 ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq. Update(a). SetFunc(func(col *sq.Column) { col.SetString(a.FIRST_NAME, "DAN") col.SetString(a.LAST_NAME, "TORN") }). Where(a.ACTOR_ID.EqInt(18)). SetDialect(sq.DialectPostgres), ) ``` ### Delete example #querybuilder-delete-example ```sql DELETE FROM actor WHERE actor.actor_id = 56 ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq. DeleteFrom(a). Where(a.ACTOR_ID.EqInt(56)). SetDialect(sq.DialectPostgres), ) ``` ### Combining predicates (AND and OR) #combining-predicates `Where()` accepts more than one predicate. By default, those predicates are `AND`-ed together. ```go a := sq.New[ACTOR]("a") query := sq. Select(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). From(a).
Where( a.FIRST_NAME.EqString("BOB"), a.LAST_NAME.EqString("THE BUILDER"), a.LAST_UPDATE.IsNotNull(), ) ``` ```sql SELECT a.actor_id, a.first_name, a.last_name FROM actor AS a WHERE a.first_name = 'BOB' AND a.last_name = 'THE BUILDER' AND a.last_update IS NOT NULL ``` If you need to `OR` those predicates together, wrap them in `sq.Or()`. ```go a := sq.New[ACTOR]("a") query := sq. Select(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). From(a). Where(sq.Or( // <-- sq.Or a.FIRST_NAME.EqString("BOB"), a.LAST_NAME.EqString("THE BUILDER"), a.LAST_UPDATE.IsNotNull(), )) ``` ```sql SELECT a.actor_id, a.first_name, a.last_name FROM actor AS a WHERE a.first_name = 'BOB' OR a.last_name = 'THE BUILDER' OR a.last_update IS NOT NULL ``` ### Using expressions in the query builder #expr If you need to do SQL math or call an SQL function, you need to use sq.Expr() to create an expression. [The same query templating syntax](#templating-syntax) in sq.Queryf() can be used here. ```sql SELECT a.first_name || ' ' || a.last_name AS fullname FROM actor AS a WHERE a.actor_id IN (18, 56, 116) ``` ```go a := sq.New[ACTOR]("a") q := sq. Select(sq.Expr("{} || ' ' || {}", a.FIRST_NAME, a.LAST_NAME).As("fullname")). From(a). Where(a.ACTOR_ID.In([]int{18, 56, 116})). SetDialect(sq.DialectPostgres) ``` sq.Expr() satisfies the `Any` interface and can be used wherever a `Number`, `String`, `Time`, `Boolean`, `Binary`, `Array`, `Enum`, `JSON` or `UUID` interface is expected. #### Dialect expressions #dialect-expr Sometimes a query may be the same across different dialects save for some dialect-specific function call or expression, which changes for each dialect. In those cases you can use sq.DialectExpr() to use different expressions depending on the dialect. ```sql -- The 3 queries below are nearly identical except for the name of their JSON -- aggregation function.
-- SQLite SELECT json_group_array(a.last_name) FROM actor AS a WHERE a.first_name = 'DAN' -- Postgres SELECT json_agg(a.last_name) FROM actor AS a WHERE a.first_name = 'DAN' -- MySQL SELECT json_arrayagg(a.last_name) FROM actor AS a WHERE a.first_name = 'DAN' ``` ```go a := sq.New[ACTOR]("a") q := sq. Select( sq.DialectExpr("json_group_array({})", a.LAST_NAME). // default case DialectExpr(sq.DialectPostgres, "json_agg({})", a.LAST_NAME). // if dialect == sq.DialectPostgres DialectExpr(sq.DialectMySQL, "json_arrayagg({})", a.LAST_NAME), // if dialect == sq.DialectMySQL ). From(a). Where(a.FIRST_NAME.EqString("DAN")). SetDialect(dialect) ``` Similar to sq.Expr(), sq.DialectExpr() can be used wherever a `Number`, `String`, `Time`, `Boolean`, `Binary`, `Array`, `Enum`, `JSON` or `UUID` interface is expected. ## How do I use dialect-specific features? #dialect-specific-features There are dialect-specific query builders for each dialect that are accessible through the four package-level variables: - **sq.SQLite** - **sq.Postgres** - **sq.MySQL** - **sq.SQLServer** Do note that you can also use the dialect-agnostic query builder (as shown in the [query builder examples](#querybuilder-select)) if you're not using any dialect-specific features. Doing so will make your queries more portable, as you can just [toggle the dialect on the query](#set-query-dialect) and have it work across multiple databases without effort. ### SQLite-specific features #sqlite-specific-features #### RETURNING #sqlite-returning ```sql INSERT INTO actor (first_name, last_name) VALUES ('PENELOPE', 'GUINESS'), ('NICK', 'WAHLBERG'), ('ED', 'CHASE') RETURNING actor.actor_id, actor.first_name, actor.last_name ``` ```go a := sq.New[ACTOR]("") actors, err := sq.FetchAll(db, sq.SQLite. InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Values("PENELOPE", "GUINESS"). Values("NICK", "WAHLBERG"). 
Values("ED", "CHASE"), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), } }, ) ``` #### LastInsertId #sqlite-last-insert-id ```sql INSERT INTO actor (first_name, last_name) VALUES ('PENELOPE', 'GUINESS'); SELECT last_insert_rowid(); ``` ```go a := sq.New[ACTOR]("") result, err := sq.Exec(db, sq.SQLite. InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Values("PENELOPE", "GUINESS"), ) if err != nil { } fmt.Println(result.LastInsertId) // int64 ``` #### Insert ignore duplicates #sqlite-insert-ignore-duplicates ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') ON CONFLICT DO NOTHING ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq.SQLite. InsertInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). Values(1, "PENELOPE", "GUINESS"). Values(2, "NICK", "WAHLBERG"). Values(3, "ED", "CHASE"). OnConflict().DoNothing(), ) ``` #### Upsert #sqlite-upsert ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') ON CONFLICT (actor_id) DO UPDATE SET first_name = EXCLUDED.first_name, last_name = EXCLUDED.last_name ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq.SQLite. InsertInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). Values(1, "PENELOPE", "GUINESS"). Values(2, "NICK", "WAHLBERG"). Values(3, "ED", "CHASE"). 
OnConflict(a.ACTOR_ID).DoUpdateSet( a.FIRST_NAME.Set(a.FIRST_NAME.WithPrefix("EXCLUDED")), a.LAST_NAME.Set(a.LAST_NAME.WithPrefix("EXCLUDED")), ), ) ``` #### Update with Join #sqlite-update-with-join ```sql UPDATE actor SET last_name = 'DINO' FROM film_actor JOIN film ON film.film_id = film_actor.film_id WHERE film_actor.actor_id = actor.actor_id AND film.title = 'ACADEMY DINOSAUR' ``` ```go a, fa, f := sq.New[ACTOR](""), sq.New[FILM_ACTOR](""), sq.New[FILM]("") _, err := sq.Exec(db, sq.SQLite. Update(a). Set(a.LAST_NAME.SetString("DINO")). From(fa). Join(f, f.FILM_ID.Eq(fa.FILM_ID)). Where( fa.ACTOR_ID.Eq(a.ACTOR_ID), f.TITLE.EqString("ACADEMY DINOSAUR"), ), ) ``` #### Delete with Join #sqlite-delete-with-join This is not technically an SQLite-specific feature as it uses a plain subquery to achieve a Delete with Join. Other databases have their own dialect-specific way of doing this, but this method works across every database and as such I prefer it over the others. ```sql DELETE FROM actor WHERE EXISTS ( SELECT 1 FROM film_actor JOIN film ON film.film_id = film_actor.film_id WHERE film_actor.actor_id = actor.actor_id AND film.title = 'ACADEMY DINOSAUR' ) ``` ```go a, fa, f := sq.New[ACTOR](""), sq.New[FILM_ACTOR](""), sq.New[FILM]("") _, err := sq.Exec(db, sq.SQLite. DeleteFrom(a). Where(sq.Exists(sq. SelectOne(). From(fa). Join(f, f.FILM_ID.Eq(fa.FILM_ID)).
Where( fa.ACTOR_ID.Eq(a.ACTOR_ID), f.TITLE.EqString("ACADEMY DINOSAUR"), ), )), ) ``` #### Bulk Update #sqlite-bulk-update ```sql UPDATE actor SET first_name = tmp.first_name, last_name = tmp.last_name FROM ( SELECT 1 AS actor_id, 'PENELOPE' AS first_name, 'GUINESS' AS last_name UNION ALL SELECT 2, 'NICK', 'WAHLBERG' UNION ALL SELECT 3, 'ED', 'CHASE' ) AS tmp WHERE tmp.actor_id = actor.actor_id ``` ```go a := sq.New[ACTOR]("") tmp := sq.SelectValues{ Alias: "tmp", Columns: []string{"actor_id", "first_name", "last_name"}, RowValues: [][]any{ {1, "PENELOPE", "GUINESS"}, {2, "NICK", "WAHLBERG"}, {3, "ED", "CHASE"}, }, } _, err := sq.Exec(db, sq.SQLite. Update(a). Set( a.FIRST_NAME.Set(tmp.Field("first_name")), a.LAST_NAME.Set(tmp.Field("last_name")), ). From(tmp). Where(tmp.Field("actor_id").Eq(a.ACTOR_ID)), ) ``` ### Postgres-specific features #postgres-specific-features #### DISTINCT ON #postgres-distinct-on ```sql SELECT DISTINCT ON (a.first_name) a.first_name, a.last_name FROM actor AS a ORDER BY a.first_name ``` ```go a := sq.New[ACTOR]("a") actors, err := sq.FetchAll(db, sq.Postgres. From(a). DistinctOn(a.FIRST_NAME). OrderBy(a.FIRST_NAME), func(row *sq.Row) Actor { return Actor{ FirstName: row.String(a.FIRST_NAME), LastName: row.String(a.LAST_NAME), } }, ) ``` #### FETCH NEXT, WITH TIES #postgres-fetch-next-with-ties ```sql SELECT a.first_name FROM actor AS a OFFSET 5 FETCH NEXT 10 ROWS WITH TIES ``` ```go a := sq.New[ACTOR]("a") firstNames, err := sq.FetchAll(db, sq.Postgres. From(a). Offset(5). FetchNext(10).WithTies(), func(row *sq.Row) string { return row.String(a.FIRST_NAME) }, ) ``` #### FOR UPDATE, FOR SHARE #postgres-for-update-for-share **For Update** ```sql SELECT a.actor_id, a.first_name, a.last_name FROM actor AS a WHERE a.first_name = 'DAN' FOR UPDATE SKIP LOCKED ``` ```go actors, err := sq.FetchAll(db, sq.Postgres. From(a). Where(a.FIRST_NAME.EqString("DAN")). 
LockRows("FOR UPDATE SKIP LOCKED"), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), } }, ) ``` **For Share** ```sql SELECT a.actor_id, a.first_name, a.last_name FROM actor AS a WHERE a.first_name = 'DAN' FOR SHARE ``` ```go actors, err := sq.FetchAll(db, sq.Postgres. From(a). Where(a.FIRST_NAME.EqString("DAN")). LockRows("FOR SHARE"), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), } }, ) ``` #### RETURNING #postgres-returning ```sql INSERT INTO actor (first_name, last_name) VALUES ('PENELOPE', 'GUINESS'), ('NICK', 'WAHLBERG'), ('ED', 'CHASE') RETURNING actor.actor_id, actor.first_name, actor.last_name ``` ```go a := sq.New[ACTOR]("") actors, err := sq.FetchAll(db, sq.Postgres. InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Values("PENELOPE", "GUINESS"). Values("NICK", "WAHLBERG"). Values("ED", "CHASE"), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), } }, ) ``` #### Insert ignore duplicates #postgres-insert-ignore-duplicates ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') ON CONFLICT DO NOTHING ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq.Postgres. InsertInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). Values(1, "PENELOPE", "GUINESS"). Values(2, "NICK", "WAHLBERG"). Values(3, "ED", "CHASE"). 
OnConflict().DoNothing(), ) ``` #### Upsert #postgres-upsert ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') ON CONFLICT (actor_id) DO UPDATE SET first_name = EXCLUDED.first_name, last_name = EXCLUDED.last_name ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq.Postgres. InsertInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). Values(1, "PENELOPE", "GUINESS"). Values(2, "NICK", "WAHLBERG"). Values(3, "ED", "CHASE"). OnConflict(a.ACTOR_ID).DoUpdateSet( a.FIRST_NAME.Set(a.FIRST_NAME.WithPrefix("EXCLUDED")), a.LAST_NAME.Set(a.LAST_NAME.WithPrefix("EXCLUDED")), ), ) ``` #### Update with Join #postgres-update-with-join ```sql UPDATE actor SET last_name = 'DINO' FROM film_actor JOIN film ON film.film_id = film_actor.film_id WHERE film_actor.actor_id = actor.actor_id AND film.title = 'ACADEMY DINOSAUR' ``` ```go a, fa, f := sq.New[ACTOR](""), sq.New[FILM_ACTOR](""), sq.New[FILM]("") _, err := sq.Exec(db, sq.Postgres. Update(a). Set(a.LAST_NAME.SetString("DINO")). From(fa). Join(f, f.FILM_ID.Eq(fa.FILM_ID)). Where( fa.ACTOR_ID.Eq(a.ACTOR_ID), f.TITLE.EqString("ACADEMY DINOSAUR"), ), ) ``` #### Delete with Join #postgres-delete-with-join ```sql DELETE FROM actor USING film_actor JOIN film ON film.film_id = film_actor.film_id WHERE film_actor.actor_id = actor.actor_id AND film.title = 'ACADEMY DINOSAUR' ``` ```go a, fa, f := sq.New[ACTOR](""), sq.New[FILM_ACTOR](""), sq.New[FILM]("") _, err := sq.Exec(db, sq.Postgres. DeleteFrom(a). Using(fa). Join(f, f.FILM_ID.Eq(fa.FILM_ID)).
Where( fa.ACTOR_ID.Eq(a.ACTOR_ID), f.TITLE.EqString("ACADEMY DINOSAUR"), ), ) ``` #### Bulk Update #postgres-bulk-update ```sql UPDATE actor SET first_name = tmp.first_name, last_name = tmp.last_name FROM (VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') ) AS tmp (actor_id, first_name, last_name) WHERE tmp.actor_id = actor.actor_id ``` ```go a := sq.New[ACTOR]("") tmp := sq.TableValues{ Alias: "tmp", Columns: []string{"actor_id", "first_name", "last_name"}, RowValues: [][]any{ {1, "PENELOPE", "GUINESS"}, {2, "NICK", "WAHLBERG"}, {3, "ED", "CHASE"}, }, } _, err := sq.Exec(db, sq.Postgres. Update(a). Set( a.FIRST_NAME.Set(tmp.Field("first_name")), a.LAST_NAME.Set(tmp.Field("last_name")), ). From(tmp). Where(tmp.Field("actor_id").Eq(a.ACTOR_ID)), ) ``` ### MySQL-specific features #mysql-specific-features #### FOR UPDATE, FOR SHARE #mysql-for-update-for-share **For Update** ```sql SELECT a.actor_id, a.first_name, a.last_name FROM actor AS a WHERE a.first_name = 'DAN' FOR UPDATE SKIP LOCKED ``` ```go actors, err := sq.FetchAll(db, sq.MySQL. From(a). Where(a.FIRST_NAME.EqString("DAN")). LockRows("FOR UPDATE SKIP LOCKED"), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), } }, ) ``` **For Share** ```sql SELECT a.actor_id, a.first_name, a.last_name FROM actor AS a WHERE a.first_name = 'DAN' FOR SHARE ``` ```go actors, err := sq.FetchAll(db, sq.MySQL. From(a). Where(a.FIRST_NAME.EqString("DAN")). LockRows("FOR SHARE"), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), } }, ) ``` #### LastInsertId #mysql-last-insert-id ```sql INSERT INTO actor (first_name, last_name) VALUES ('PENELOPE', 'GUINESS'); SELECT last_insert_id(); ``` ```go a := sq.New[ACTOR]("") result, err := sq.Exec(db, sq.MySQL. InsertInto(a). 
Columns(a.FIRST_NAME, a.LAST_NAME). Values("PENELOPE", "GUINESS"), ) if err != nil { } fmt.Println(result.LastInsertId) // int64 ``` #### Insert ignore duplicates #mysql-insert-ignore-duplicates **ON DUPLICATE KEY UPDATE field = field** MySQL lacks ON DUPLICATE KEY DO NOTHING but assigning a field to itself is the closest thing we can get. If a field is assigned to itself, MySQL doesn't actually trigger an update (making it do nothing). ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') ON DUPLICATE KEY UPDATE actor.actor_id = actor.actor_id ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq.MySQL. InsertInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). Values(1, "PENELOPE", "GUINESS"). Values(2, "NICK", "WAHLBERG"). Values(3, "ED", "CHASE"). OnDuplicateKeyUpdate( a.ACTOR_ID.Set(a.ACTOR_ID), ), ) ``` **INSERT IGNORE** INSERT IGNORE will ignore all kinds of errors (such as foreign key violations) so use only if you really, really don't care if an INSERT fails. ```sql INSERT IGNORE INTO actor (actor_id, first_name, last_name) VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq.MySQL. InsertIgnoreInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). Values(1, "PENELOPE", "GUINESS"). Values(2, "NICK", "WAHLBERG"). Values(3, "ED", "CHASE"), ) ``` #### Upsert #mysql-upsert **Row Alias (MySQL 8.0+ onwards)** ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') AS new ON DUPLICATE KEY UPDATE actor.first_name = new.first_name, actor.last_name = new.last_name ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq.MySQL. InsertInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). Values(1, "PENELOPE", "GUINESS"). Values(2, "NICK", "WAHLBERG"). Values(3, "ED", "CHASE"). As("new"). 
OnDuplicateKeyUpdate( a.FIRST_NAME.Set(a.FIRST_NAME.WithPrefix("new")), a.LAST_NAME.Set(a.LAST_NAME.WithPrefix("new")), ), ) ``` **VALUES()** ```sql INSERT INTO actor (actor_id, first_name, last_name) VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') ON DUPLICATE KEY UPDATE actor.first_name = VALUES(first_name), actor.last_name = VALUES(last_name) ``` ```go a := sq.New[ACTOR]("") _, err := sq.Exec(db, sq.MySQL. InsertInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). Values(1, "PENELOPE", "GUINESS"). Values(2, "NICK", "WAHLBERG"). Values(3, "ED", "CHASE"). OnDuplicateKeyUpdate( a.FIRST_NAME.Setf("VALUES({})", a.FIRST_NAME.WithPrefix("")), a.LAST_NAME.Setf("VALUES({})", a.LAST_NAME.WithPrefix("")), ), ) ``` #### Update with Join #mysql-update-with-join ```sql UPDATE actor JOIN film_actor ON film_actor.actor_id = actor.actor_id JOIN film ON film.film_id = film_actor.film_id SET actor.last_name = 'DINO' WHERE film.title = 'ACADEMY DINOSAUR' ``` ```go a, fa, f := sq.New[ACTOR](""), sq.New[FILM_ACTOR](""), sq.New[FILM]("") _, err := sq.Exec(db, sq.MySQL. Update(a). Join(fa, fa.ACTOR_ID.Eq(a.ACTOR_ID)). Join(f, f.FILM_ID.Eq(fa.FILM_ID)). Set(a.LAST_NAME.SetString("DINO")). Where(f.TITLE.EqString("ACADEMY DINOSAUR")), ) ``` #### Delete with Join #mysql-delete-with-join ```sql DELETE actor FROM actor JOIN film_actor ON film_actor.actor_id = actor.actor_id JOIN film ON film.film_id = film_actor.film_id WHERE film.title = 'ACADEMY DINOSAUR' ``` ```go a, fa, f := sq.New[ACTOR](""), sq.New[FILM_ACTOR](""), sq.New[FILM]("") _, err := sq.Exec(db, sq.MySQL. Delete(a). From(a). Join(fa, fa.ACTOR_ID.Eq(a.ACTOR_ID)). Join(f, f.FILM_ID.Eq(fa.FILM_ID)).
Where(f.TITLE.EqString("ACADEMY DINOSAUR")), ) ``` #### Bulk Update #mysql-bulk-update ```sql UPDATE actor JOIN (VALUES ROW(1, 'PENELOPE', 'GUINESS'), ROW(2, 'NICK', 'WAHLBERG'), ROW(3, 'ED', 'CHASE') ) AS tmp (actor_id, first_name, last_name) ON tmp.actor_id = actor.actor_id SET first_name = tmp.first_name, last_name = tmp.last_name ``` ```go a := sq.New[ACTOR]("") tmp := sq.TableValues{ Alias: "tmp", Columns: []string{"actor_id", "first_name", "last_name"}, RowValues: [][]any{ {1, "PENELOPE", "GUINESS"}, {2, "NICK", "WAHLBERG"}, {3, "ED", "CHASE"}, }, } _, err := sq.Exec(db, sq.MySQL. Update(a). Join(tmp, tmp.Field("actor_id").Eq(a.ACTOR_ID)). Set( a.FIRST_NAME.Set(tmp.Field("first_name")), a.LAST_NAME.Set(tmp.Field("last_name")), ), ) ``` ### SQLServer-specific features #sqlserver-specific-features #### TOP, WITH TIES #sqlserver-top-with-ties ```sql SELECT TOP 10 WITH TIES a.first_name FROM actor AS a ``` ```go a := sq.New[ACTOR]("a") firstNames, err := sq.FetchAll(db, sq.SQLServer. From(a). Top(10).WithTies(), func(row *sq.Row) string { return row.String(a.FIRST_NAME) }, ) ``` #### OUTPUT #sqlserver-output ```sql INSERT INTO actor (first_name, last_name) OUTPUT INSERTED.actor_id, INSERTED.first_name, INSERTED.last_name VALUES ('PENELOPE', 'GUINESS'), ('NICK', 'WAHLBERG'), ('ED', 'CHASE') ``` ```go a := sq.New[ACTOR]("") actors, err := sq.FetchAll(db, sq.SQLServer. InsertInto(a). Columns(a.FIRST_NAME, a.LAST_NAME). Values("PENELOPE", "GUINESS"). Values("NICK", "WAHLBERG"). Values("ED", "CHASE"), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), } }, ) ``` **INSERTED.\* vs DELETED.\*** - For Insert queries, OUTPUT fields use the INSERTED.\* prefix. - For Delete queries, OUTPUT fields use the DELETED.\* prefix. - For Update queries, OUTPUT fields use the INSERTED.\* prefix.
Technically both INSERTED.\* and DELETED.\* fields are supported for Update queries, but sq only supports INSERTED.\* because that is how RETURNING behaves in SQLite and Postgres. #### Insert ignore duplicates #sqlserver-insert-ignore-duplicates This is not technically an SQL Server-specific feature, because SQL Server does not support ignoring duplicates at all. You have to employ a workaround using INSERT with SELECT ([https://stackoverflow.com/a/10703792](https://stackoverflow.com/a/10703792)). I'm including the workaround here for completeness' sake. ```sql -- Insert rows that don't exist. INSERT INTO actor (actor_id, first_name, last_name) SELECT actor_id, first_name, last_name FROM ( VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') ) AS rowvalues (actor_id, first_name, last_name) WHERE NOT EXISTS ( SELECT 1 FROM actor WHERE actor.actor_id = rowvalues.actor_id ) ``` ```go a := sq.New[ACTOR]("") // Insert rows that don't exist. _, err := sq.Exec(db, sq.SQLServer. InsertInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). Select(sq.Queryf("SELECT actor_id, first_name, last_name "+ "FROM (VALUES {}) AS rowvalues (actor_id, first_name, last_name) "+ "WHERE NOT EXISTS (SELECT 1 FROM actor WHERE actor.actor_id = rowvalues.actor_id)", sq.RowValues{ {1, "PENELOPE", "GUINESS"}, {2, "NICK", "WAHLBERG"}, {3, "ED", "CHASE"}, }, )), ) ``` #### Upsert #sqlserver-upsert This is not technically an SQL Server-specific feature, because SQL Server has no native upsert syntax. You have to employ a 2-step workaround using an UPDATE with JOIN + an INSERT with SELECT ([https://sqlperformance.com/2020/09/locking/upsert-anti-pattern](https://sqlperformance.com/2020/09/locking/upsert-anti-pattern)). I'm including the workaround here for completeness' sake. Avoid using MERGE for upserting.
- [https://www.mssqltips.com/sqlservertip/3074/use-caution-with-sql-servers-merge-statement/](https://www.mssqltips.com/sqlservertip/3074/use-caution-with-sql-servers-merge-statement/) - [https://michaeljswart.com/2021/08/what-to-avoid-if-you-want-to-use-merge/](https://michaeljswart.com/2021/08/what-to-avoid-if-you-want-to-use-merge/) ```sql -- Update rows that exist. UPDATE actor SET first_name = rowvalues.first_name, last_name = rowvalues.last_name FROM actor JOIN (VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') ) AS rowvalues (actor_id, first_name, last_name) ON rowvalues.actor_id = actor.actor_id; -- Insert rows that don't exist. INSERT INTO actor (actor_id, first_name, last_name) SELECT actor_id, first_name, last_name FROM (VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') ) AS rowvalues (actor_id, first_name, last_name) WHERE NOT EXISTS ( SELECT 1 FROM actor WHERE actor.actor_id = rowvalues.actor_id ); ``` ```go a := sq.New[ACTOR]("") // Update rows that exist. _, err := sq.Exec(db, sq.SQLServer. Update(a). Set( a.FIRST_NAME.Setf("rowvalues.first_name"), a.LAST_NAME.Setf("rowvalues.last_name"), ). From(a). Join(sq. Queryf("VALUES {}", sq.RowValues{ {1, "PENELOPE", "GUINESS"}, {2, "NICK", "WAHLBERG"}, {3, "ED", "CHASE"}, }). As("rowvalues (actor_id, first_name, last_name)"), sq.Expr("rowvalues.actor_id").Eq(a.ACTOR_ID), ), ) // Insert rows that don't exist. _, err = sq.Exec(db, sq.SQLServer. InsertInto(a). Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
Select(sq.Queryf("SELECT actor_id, first_name, last_name "+ "FROM (VALUES {}) AS rowvalues (actor_id, first_name, last_name) "+ "WHERE NOT EXISTS (SELECT 1 FROM actor WHERE actor.actor_id = rowvalues.actor_id)", sq.RowValues{ {1, "PENELOPE", "GUINESS"}, {2, "NICK", "WAHLBERG"}, {3, "ED", "CHASE"}, }, )), ) ``` #### Update with Join #sqlserver-update-with-join ```sql UPDATE actor SET last_name = 'DINO' FROM actor JOIN film_actor ON film_actor.actor_id = actor.actor_id JOIN film ON film.film_id = film_actor.film_id WHERE film.title = 'ACADEMY DINOSAUR' ``` ```go a, fa, f := sq.New[ACTOR](""), sq.New[FILM_ACTOR](""), sq.New[FILM]("") _, err := sq.Exec(db, sq.SQLServer. Update(a). Set(a.LAST_NAME.SetString("DINO")). From(a). Join(fa, fa.ACTOR_ID.Eq(a.ACTOR_ID)). Join(f, f.FILM_ID.Eq(fa.FILM_ID)). Where(f.TITLE.EqString("ACADEMY DINOSAUR")), ) ``` #### Delete with Join #sqlserver-delete-with-join ```sql DELETE actor FROM actor JOIN film_actor ON film_actor.actor_id = actor.actor_id JOIN film ON film.film_id = film_actor.film_id WHERE film.title = 'ACADEMY DINOSAUR' ``` ```go a, fa, f := sq.New[ACTOR](""), sq.New[FILM_ACTOR](""), sq.New[FILM]("") _, err := sq.Exec(db, sq.SQLServer. Delete(a). From(a). Join(fa, fa.ACTOR_ID.Eq(a.ACTOR_ID)). Join(f, f.FILM_ID.Eq(fa.FILM_ID)). Where(f.TITLE.EqString("ACADEMY DINOSAUR")), ) ``` #### Bulk Update #sqlserver-bulk-update ```sql UPDATE actor SET first_name = tmp.first_name, last_name = tmp.last_name FROM actor JOIN (VALUES (1, 'PENELOPE', 'GUINESS'), (2, 'NICK', 'WAHLBERG'), (3, 'ED', 'CHASE') ) AS tmp (actor_id, first_name, last_name) ON tmp.actor_id = actor.actor_id ``` ```go a := sq.New[ACTOR]("") tmp := sq.TableValues{ Alias: "tmp", Columns: []string{"actor_id", "first_name", "last_name"}, RowValues: [][]any{ {1, "PENELOPE", "GUINESS"}, {2, "NICK", "WAHLBERG"}, {3, "ED", "CHASE"}, }, } _, err := sq.Exec(db, sq.SQLServer. Update(a). Set( a.FIRST_NAME.Set(tmp.Field("first_name")), a.LAST_NAME.Set(tmp.Field("last_name")), ).
From(a). Join(tmp, tmp.Field("actor_id").Eq(a.ACTOR_ID)), ) ``` ## Working with arrays, enums, JSON and UUID #arrays-enums-json-uuid ### Arrays #arrays Slices of primitive types (`[]string`, `[]int64`, `[]int32`, `[]float64`, `[]float32`, `[]bool`) can be saved into the database. For Postgres, it will be saved as an ARRAY (TEXT[], INT[], BIGINT[], NUMERIC[] or BOOLEAN[]). For other databases, it will be saved as a JSON array. **Writing arrays** ```go // Raw SQL _, err := sq.Exec(db, sq. Queryf("INSERT INTO posts (title, body, tags) VALUES {}", sq.RowValue{ "Hello World!", "This is my first blog post.", sq.ArrayValue([]string{"introduction", "hello-world", "meta"}), }). SetDialect(sq.DialectPostgres), ) // Query Builder p := sq.New[POSTS]("") _, err := sq.Exec(db, sq. InsertInto(p). ColumnValues(func(col *sq.Column) { col.SetString(p.TITLE, "Hello World!") col.SetString(p.BODY, "This is my first blog post.") col.SetArray(p.TAGS, []string{"introduction", "hello-world", "meta"}) }). SetDialect(sq.DialectPostgres), ) ``` **Reading arrays** ```go // Raw SQL posts, err := sq.FetchAll(db, sq. Queryf("SELECT {*} FROM posts WHERE post_id IN ({})", []int{1, 2, 3}). SetDialect(sq.DialectPostgres), func(row *sq.Row) Post { var post Post post.Title = row.String("title") post.Body = row.String("body") row.Array(&post.Tags, "tags") return post }, ) // Query Builder p := sq.New[POSTS]("") posts, err := sq.FetchAll(db, sq. From(p). Where(p.POST_ID.In([]int{1, 2, 3})). 
SetDialect(sq.DialectPostgres), func(row *sq.Row) Post { var post Post post.Title = row.StringField(p.TITLE) post.Body = row.StringField(p.BODY) row.ArrayField(&post.Tags, p.TAGS) return post }, ) ``` ### Enums #enums A Go type is considered an enum if it implements the `Enumeration` interface: ```go type Enumeration interface{ Enumerate() []string } ``` As an example, this is how an int-based enum and a string-based enum would be implemented: ```go type Color int const ( ColorInvalid Color = iota ColorRed ColorGreen ColorBlue ) var colorNames = [...]string{ ColorInvalid: "", ColorRed: "red", ColorGreen: "green", ColorBlue: "blue", } func (c Color) Enumerate() []string { return colorNames[:] } ``` ```go type Direction string const ( DirectionInvalid = Direction("") DirectionNorth = Direction("north") DirectionSouth = Direction("south") DirectionEast = Direction("east") DirectionWest = Direction("west") ) func (d Direction) Enumerate() []string { return []string{ string(DirectionInvalid), string(DirectionNorth), string(DirectionSouth), string(DirectionEast), string(DirectionWest), } } ``` By implementing the `Enumeration` interface, you automatically get enum type validation when writing enums to and reading enums from the database. - If you try to write an enum value to the database that isn't present in the `Enumerate()` slice, it will be flagged as an error. - If the database returns an enum value that isn't present in the `Enumerate()` slice, it will be flagged as an error. **Writing enums** ```go // Raw SQL _, err := sq.Exec(db, sq. Queryf("INSERT INTO fruits (name, color) VALUES {}", sq.RowValue{ "apple", sq.EnumValue(ColorRed), }). SetDialect(sq.DialectPostgres), ) // Query Builder f := sq.New[FRUITS]("") _, err := sq.Exec(db, sq. InsertInto(f). ColumnValues(func(col *sq.Column) { col.SetString(f.NAME, "apple") col.SetEnum(f.COLOR, ColorRed) }). SetDialect(sq.DialectPostgres), ) ``` **Reading enums** ```go // Raw SQL fruits, err := sq.FetchAll(db, sq. 
Queryf("SELECT {*} FROM fruits WHERE fruit_id IN ({})", []int{1, 2, 3}). SetDialect(sq.DialectPostgres), func(row *sq.Row) Fruit { var fruit Fruit fruit.Name = row.String("name") row.Enum(&fruit.Color, "color") return fruit }, ) // Query Builder f := sq.New[FRUITS]("") posts, err := sq.FetchAll(db, sq. From(f). Where(f.FRUIT_ID.In([]int{1, 2, 3})). SetDialect(sq.DialectPostgres), func(row *sq.Row) Fruit { var fruit Fruit fruit.Name = row.StringField(f.NAME) row.EnumField(&fruit.Color, f.COLOR) return fruit }, ) ``` ### JSON #json Any Go type that works with `json.Marshal` and `json.Unmarshal` can be saved into the database. For Postgres, it will be saved as JSONB. For MySQL, it will be saved as JSON. For other databases, it will be saved as a JSON string. **Writing JSON** ```go // Raw SQL _, err := sq.Exec(db, sq. Queryf("INSERT INTO products (name, price, attributes) VALUES {}", sq.RowValue{ "Sleeping Bag", 89.99, sq.JSONValue(map[string]any{ "Length (cm)": 220, "Width (cm)": 150, "Weight (kg)": 2.96, "Color": "Lake Blue", "Fill Material": "190T Pongee", "Outer Material": "Polyester", }), }). SetDialect(sq.DialectPostgres), ) // Query Builder p := sq.New[PRODUCTS]("") _, err := sq.Exec(db, sq. InsertInto(p). ColumnValues(func(col *sq.Column) { col.SetString(p.NAME, "Sleeping Bag") col.SetFloat64(p.PRICE, 89.99) col.SetJSON(p.ATTRIBUTES, map[string]any{ "Length (cm)": 220, "Width (cm)": 150, "Weight (kg)": 2.96, "Color": "Lake Blue", "Fill Material": "190T Pongee", "Outer Material": "Polyester", }) }). SetDialect(sq.DialectPostgres), ) ``` **Reading JSON** ```go // Raw SQL products, err := sq.FetchAll(db, sq. Queryf("SELECT {*} FROM products WHERE product_id IN ({})", []int{1, 2, 3}). 
SetDialect(sq.DialectPostgres), func(row *sq.Row) Product { var product Product product.Name = row.String("name") product.Price = row.Float64("price") row.JSON(&product.Attributes, "attributes") return product }, ) // Query Builder p := sq.New[PRODUCTS]("") posts, err := sq.FetchAll(db, sq. From(p). Where(p.PRODUCT_ID.In([]int{1, 2, 3})). SetDialect(sq.DialectPostgres), func(row *sq.Row) Product { var product Product product.Name = row.StringField(p.NAME) product.Price = row.Float64Field(p.PRICE) row.JSONField(&product.Attributes, p.ATTRIBUTES) return product }, ) ``` ### UUID #uuid Any Go type whose underlying type is `[16]byte` can be saved as a UUID into the database. For Postgres, it will be saved as UUID. For other databases, it will be saved as a BINARY(16). It is likely that the Go UUID library you are using already implements sql.Scanner and driver.Valuer (e.g. [github.com/google/uuid](https://github.com/google/uuid)). You can choose to rely on their built-in SQL behaviour: - Instead of wrapping the uuid in sq.UUIDValue(), just use the uuid directly. - Instead of calling col.SetUUID(), just call col.Set(). - Instead of calling row.UUID()/row.UUIDField(), just call row.Scan()/row.ScanField(). The main benefit of using this library's built-in UUID helpers is to have UUID reading/writing work identically across database dialects: for Postgres, if you want to save a UUID you must give it a UUID string. For other databases, if you want to save a UUID as a BINARY(16) you must give it raw UUID bytes. Using this library's UUID helpers means you don't have to manually account for this UUID string/bytes disparity between Postgres and the other DBs. **Writing UUID** ```go userID, err := uuid.Parse("d619cde3-7661-4b6e-928e-4d5b239a18a9") if err != nil { } // Raw SQL _, err = sq.Exec(db, sq. Queryf("INSERT INTO users (user_id, name, email) VALUES {}", sq.RowValue{ sq.UUIDValue(userID), "John Doe", "john_doe@email.com", }). 
SetDialect(sq.DialectPostgres), ) // Query Builder u := sq.New[USERS]("") _, err = sq.Exec(db, sq. InsertInto(u). ColumnValues(func(col *sq.Column) { col.SetUUID(u.USER_ID, userID) col.SetString(u.NAME, "John Doe") col.SetString(u.EMAIL, "john_doe@email.com") }). SetDialect(sq.DialectPostgres), ) ``` **Reading UUID** ```go // Raw SQL users, err := sq.FetchAll(db, sq. Queryf("SELECT {*} FROM users WHERE email IS NOT NULL"). SetDialect(sq.DialectPostgres), func(row *sq.Row) User { var user User row.UUID(&user.UserID, "user_id") user.Name = row.String("name") user.Email = row.String("email") return user }, ) // Query Builder u := sq.New[USERS]("") users, err = sq.FetchAll(db, sq. From(u). Where(u.EMAIL.IsNotNull()). SetDialect(sq.DialectPostgres), func(row *sq.Row) User { var user User row.UUIDField(&user.UserID, u.USER_ID) user.Name = row.StringField(u.NAME) user.Email = row.StringField(u.EMAIL) return user }, ) ``` ## Logging #logging Queries can be logged by wrapping the database with `sq.Log()` or `sq.VerboseLog()`. **sq.Log()** ```go // With logging ↓ wrap the db firstName, err := sq.FetchOne(sq.Log(db), sq. Queryf("SELECT {*} FROM actor WHERE last_name IN ({})", []string{"AKROYD", "ALLEN", "WILLIAMS"}), func(row *sq.Row) string { return row.String("first_name") }, ) ``` ```shell 2022/02/06 15:34:36 [OK] SELECT first_name FROM actor WHERE last_name IN (?, ?, ?) | timeTaken=9.834µs rowCount=9 caller=/Users/bokwoon/Documents/sq/fetch_exec_test.go:74:sq.TestFetchExec ``` **sq.VerboseLog()** ```go // With verbose logging ↓ wrap the db firstName, err := sq.FetchOne(sq.VerboseLog(db), sq. Queryf("SELECT {*} FROM actor WHERE last_name IN ({})", []string{"AKROYD", "ALLEN", "WILLIAMS"}), func(row *sq.Row) string { return row.String("first_name") }, ) ``` ```shell 2022/02/06 15:34:36 [OK] timeTaken=9.834µs rowCount=9 caller=/Users/bokwoon/Documents/sq/fetch_exec_test.go:74:sq.TestFetchExec ----[ Executing query ]---- SELECT first_name FROM actor WHERE last_name IN (?, ?, ?) 
[]interface {}{"AKROYD", "ALLEN", "WILLIAMS"} ----[ with bind values ]---- SELECT first_name FROM actor WHERE last_name IN ('AKROYD', 'ALLEN', 'WILLIAMS') ----[ Fetched result ]---- ----[ Row 1 ]---- first_name: 'CHRISTIAN' ----[ Row 2 ]---- first_name: 'SEAN' ----[ Row 3 ]---- first_name: 'KIRSTEN' ----[ Row 4 ]---- first_name: 'CUBA' ----[ Row 5 ]---- first_name: 'MORGAN' ... (Fetched 9 rows) ``` ### Logging without manual sq.Log() wrapping #logging-without-manual-wrapping To log every query without manually wrapping the database in sq.Log(), set the global logger using SetDefaultLogQuery(). It takes in a callback function which is called every time a query is run (if no logger was explicitly provided to FetchOne, FetchAll, Exec, etc.). ```go func init() { logger := sq.NewLogger(os.Stdout, "", log.LstdFlags, sq.LoggerConfig{ ShowTimeTaken: true, HideArgs: true, }) sq.SetDefaultLogQuery(func(ctx context.Context, queryStats sq.QueryStats) { // You can choose to only log queries if they encountered an error. // if queryStats.Err == nil { // return // } logger.SqLogQuery(ctx, queryStats) }) } ``` ### Custom logger #custom-logger A custom logger can also be used by creating a [custom DB type that implements the `SqLogger` interface](#logging-without-manual-wrapping). The logging information is passed in as a `QueryStats` struct, which you can feed into the structured logger of your choice. ```go // QueryStats represents the statistics from running a query. type QueryStats struct { // Dialect of the query. Dialect string // Query string. Query string // Args slice provided with the query string. Args []any // Params maps param names back to arguments in the args slice (by index). Params map[string][]int // Err is the error from running the query. Err error // RowCount from running the query. Not valid for Exec(). RowCount sql.NullInt64 // RowsAffected by running the query. Not valid for // FetchOne/FetchAll/FetchCursor. RowsAffected sql.NullInt64 // LastInsertId of the query. 
LastInsertId sql.NullInt64 // Exists is the result of FetchExists(). Exists sql.NullBool // When the query started at. StartedAt time.Time // Time taken by the query. TimeTaken time.Duration // The caller file where the query was invoked. CallerFile string // The line in the caller file that invoked the query. CallerLine int // The name of the function where the query was invoked. CallerFunction string // The results from running the query (if it was provided). Results string } ``` As an example, we will create a custom database logger that outputs JSON and only logs if the query took longer than 1 second. ```go type MyDB struct { *sql.DB } func (myDB MyDB) SqLogSettings(ctx context.Context, settings *sq.LogSettings) { settings.LogAsynchronously = false // Should the logging be dispatched in a separate goroutine? settings.IncludeTime = true // Should timeTaken be included in the QueryStats? settings.IncludeCaller = true // Should caller info be included in the QueryStats? settings.IncludeResults = 0 // The first how many rows of results should be included? Leave 0 to not include any results. } func (myDB MyDB) SqLogQuery(ctx context.Context, stats sq.QueryStats) { if stats.TimeTaken < time.Second { return } output := map[string]any{ "query": stats.Query, "args": stats.Args, "caller": stats.CallerFile + ":" + strconv.Itoa(stats.CallerLine), "timeTaken": stats.TimeTaken.String(), } b, err := json.MarshalIndent(output, "", " ") if err != nil { log.Println(err.Error()) return } log.Println("TOO SLOW! " + string(b)) } ``` ```shell 2022/02/06 15:34:36 TOO SLOW! { "args": [ 1 ], "caller": "/Users/bokwoon/Documents/sq/fetch_exec_test.go:74", "query": "SELECT actor_id, first_name, last_name FROM actor WHERE actor_id = ?", "timeTaken": "1.534s" } ``` ## Working with transactions #transactions Fetch() and Exec() both accept an sq.DB interface, which represents something that can query the database. ```go // *sql.Conn, *sql.DB and *sql.Tx all implement DB. 
type DB interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) } ``` To use an \*sql.Tx (or an \*sql.Conn), you can pass it in like a normal \*sql.DB. ```go tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable}) if err != nil { return err } // Good practice: defer tx.Rollback() first. If tx.Commit() succeeds, the deferred rollback becomes a no-op. defer tx.Rollback() // do operation 1 _, err = sq.Exec(tx, q1) if err != nil { return err } // do operation 2 _, err = sq.Exec(tx, q2) if err != nil { return err } // do operation 3 _, err = sq.Exec(tx, q3) if err != nil { return err } // If all goes well, commit. If anything wrong happened before reaching here we // just bail and let defer tx.Rollback() kick in err = tx.Commit() if err != nil { return err } // if we reach here, success ``` ## Compiling queries #compiling-queries The cost of query building can be amortized by compiling queries down into a query string and args slice. Compiled queries can be reused by supplying a different set of parameters each time you execute them. They can be executed safely in parallel. ```go // Compile the query. compiledQuery, err := sq.CompileFetch(sq. Queryf("SELECT {*} FROM actor WHERE first_name = {first_name} AND last_name = {last_name}", sql.Named("first_name", nil), // first_name is a rebindable param, with default value nil sql.Named("last_name", nil), // last_name is a rebindable param, with default value nil ). SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { return Actor{ ActorID: row.Int("actor_id"), FirstName: row.String("first_name"), LastName: row.String("last_name"), } }, ) if err != nil { } // Obtain the query string and args slice back from the CompiledFetch. // The params map and rowmapper function are also available. 
query, args, params, rowmapper := compiledQuery.GetSQL() // Execute the compiled query with the default values. actor, err := compiledQuery.FetchOne(db, nil) if err != nil { } // Execute the compiled query with values first_name = "DAN", last_name = "TORN". actor, err = compiledQuery.FetchOne(db, sq.Params{ "first_name": "DAN", "last_name": "TORN", }) if err != nil { } ``` ### Rebindable params #rebindable-params Only [named parameters](#ordinal-named-placeholders) can be rebound in a compiled query, which means they must be provided during the query building phase. ```go // WRONG: actor_id cannot be rebound. compiledQuery, err := sq.CompileFetch( sq.Queryf("SELECT {*} FROM actor WHERE actor_id = {}", 1), func(row *sq.Row) Actor { return Actor{ FirstName: row.String("first_name"), LastName: row.String("last_name"), } }, ) if err != nil { } // ERROR: named parameter {actorID} not provided actor, err := compiledQuery.FetchOne(db, sq.Params{"actorID": 2}) if err != nil { } // CORRECT: actor_id can be rebound (using "actorID"). compiledQuery, err := sq.CompileFetch( sq.Queryf("SELECT {*} FROM actor WHERE actor_id = {actorID}", sql.Named("actorID", 1)), func(row *sq.Row) Actor { return Actor{ FirstName: row.String("first_name"), LastName: row.String("last_name"), } }, ) if err != nil { } actor, err := compiledQuery.FetchOne(db, sq.Params{"actorID": 2}) if err != nil { } ``` Most of the time you should use [sql.Named()](https://pkg.go.dev/database/sql#Named), but if you need to conform to various interfaces like [String](https://pkg.go.dev/github.com/bokwoon95/sq#String) or [Number](https://pkg.go.dev/github.com/bokwoon95/sq#Number) you can use the typed versions [sq.StringParam()](https://pkg.go.dev/github.com/bokwoon95/sq#StringParam) or [sq.IntParam()](https://pkg.go.dev/github.com/bokwoon95/sq#IntParam).
Parameter Description
sql.Named(name string, value any) database/sql's named parameter type
sq.Param(name string, value any) same as sql.Named, but satisfies the Field interface
sq.BinaryParam(name string, b []byte) same as sql.Named, but satisfies the Binary interface
sq.BooleanParam(name string, b bool) same as sql.Named, but satisfies the Boolean interface
sq.IntParam(name string, num int) same as sql.Named, but satisfies the Number interface
sq.Int64Param(name string, num int64) same as sql.Named, but satisfies the Number interface
sq.Float64Param(name string, num float64) same as sql.Named, but satisfies the Number interface
sq.StringParam(name string, s string) same as sql.Named, but satisfies the String interface
sq.TimeParam(name string, t time.Time) same as sql.Named, but satisfies the Time interface
### CompiledFetch example #compiled-fetch ```go type ACTOR struct { sq.TableStruct ACTOR_ID sq.NumberField FIRST_NAME sq.StringField LAST_NAME sq.StringField LAST_UPDATE sq.TimeField } a := sq.New[ACTOR]("") compiledQuery, err := sq.CompileFetch(sq. From(a). Where(a.ACTOR_ID.Eq(sq.IntParam("actor_id", 0))). // actor_id is a rebindable param, with default value 0 SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), } }, ) if err != nil { } actor, err := compiledQuery.FetchOne(db, sq.Params{"actor_id": 1}) fmt.Println(actor) // {ActorID: 1, FirstName: "PENELOPE", LastName: "GUINESS"} actor, err = compiledQuery.FetchOne(db, sq.Params{"actor_id": 2}) fmt.Println(actor) // {ActorID: 2, FirstName: "NICK", LastName: "WAHLBERG"} actor, err = compiledQuery.FetchOne(db, sq.Params{"actor_id": 3}) fmt.Println(actor) // {ActorID: 3, FirstName: "ED", LastName: "CHASE"} ``` ### CompiledExec example #compiled-exec ```go type ACTOR struct { sq.TableStruct ACTOR_ID sq.NumberField FIRST_NAME sq.StringField LAST_NAME sq.StringField LAST_UPDATE sq.TimeField } a := sq.New[ACTOR]("") compiledQuery, err := sq.CompileExec(sq. InsertInto(a). ColumnValues(func(col *sq.Column) { col.Set(a.ACTOR_ID, sql.Named("actor_id", nil)) // actor_id is a rebindable param, with default value nil col.Set(a.FIRST_NAME, sql.Named("first_name", nil)) // first_name is a rebindable param, with default value nil col.Set(a.LAST_NAME, sql.Named("last_name", nil)) // last_name is a rebindable param, with default value nil }). 
SetDialect(sq.DialectPostgres), ) if err != nil { } _, err = compiledQuery.Exec(db, sq.Params{ "actor_id": 1, "first_name": "PENELOPE", "last_name": "GUINESS", }) // INSERT INTO actor (actor_id, first_name, last_name) VALUES (1, 'PENELOPE', 'GUINESS') _, err = compiledQuery.Exec(db, sq.Params{ "actor_id": 2, "first_name": "NICK", "last_name": "WAHLBERG", }) // INSERT INTO actor (actor_id, first_name, last_name) VALUES (2, 'NICK', 'WAHLBERG') _, err = compiledQuery.Exec(db, sq.Params{ "actor_id": 3, "first_name": "ED", "last_name": "CHASE", }) // INSERT INTO actor (actor_id, first_name, last_name) VALUES (3, 'ED', 'CHASE') ``` ### Preparing queries #preparing-queries [Compiled queries](#compiling-queries) can be further prepared by binding them to a database connection (creating a prepared statement). ```go type ACTOR struct { sq.TableStruct ACTOR_ID sq.NumberField FIRST_NAME sq.StringField LAST_NAME sq.StringField LAST_UPDATE sq.TimeField } // Compile the query. a := sq.New[ACTOR]("") compiledQuery, err := sq.CompileFetch(sq. From(a). Where(a.ACTOR_ID.Eq(sq.IntParam("actor_id", 0))). SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { return Actor{ ActorID: row.IntField(a.ACTOR_ID), FirstName: row.StringField(a.FIRST_NAME), LastName: row.StringField(a.LAST_NAME), } }, ) if err != nil { } // Prepare the compiled query. preparedQuery, err := compiledQuery.Prepare(db) if err != nil { } // Use the prepared query with default values. actor, err := preparedQuery.FetchOne(nil) if err != nil { } // Use the prepared query with values actor_id = 1. actor, err = preparedQuery.FetchOne(sq.Params{"actor_id": 1}) if err != nil { } ``` Alternatively, you can prepare a query directly with PrepareFetch. ```go // Prepare the query. preparedQuery, err := sq.PrepareFetch(db, sq. 
Queryf("SELECT {*} FROM actor WHERE first_name = {first_name} AND last_name = {last_name}", sql.Named("first_name", nil), // first_name is a rebindable param, with default value nil sql.Named("last_name", nil), // last_name is a rebindable param, with default value nil ). SetDialect(sq.DialectPostgres), func(row *sq.Row) Actor { return Actor{ ActorID: row.Int("actor_id"), FirstName: row.String("first_name"), LastName: row.String("last_name"), } }, ) if err != nil { } // Obtain a CompiledFetch from the PreparedFetch. This is useful if you need to // re-prepare the query on another DB connection. compiledQuery := preparedQuery.GetCompiled() // Example: preparedQuery, err = compiledQuery.Prepare(db2) // Execute the prepared query with the default values. actor, err := preparedQuery.FetchOne(nil) if err != nil { } // Execute the prepared query with values first_name = "DAN", last_name = "TORN". actor, err = preparedQuery.FetchOne(sq.Params{ "first_name": "DAN", "last_name": "TORN", }) if err != nil { } ``` ## Application-side Row Level Security #appliction-side-row-level-security You can define policies on your table structs so that whenever a table is used in a query, an additional predicate is added to that query. This roughly emulates Postgres' Row Level Security, except it works completely application-side and supports every database (not just Postgres). Since table policies are baked directly into the query string, they play well with database/sql's connection pooling because you don't have to set session-level variables (which would force you to use an \*sql.Tx or \*sql.Conn). That also means they play well with an external connection pooler like PgBouncer, because again no session-level variables are required. The main downside is that policies can easily be bypassed if you reference the table directly with raw SQL instead of using the query builder. 
### A PolicyTable example #policytable To define a table policy, a table struct must implement the `PolicyTable` interface. ```go type PolicyTable interface { Table Policy(ctx context.Context, dialect string) (Predicate, error) } ``` The context is the same context that was passed to **sq.FetchAllContext**, **sq.FetchOneContext** or **sq.ExecContext**. As an example, we will define a table `employees` that stores employees for multiple tenants (indicated by the `tenant_id`). Any SELECT, UPDATE or DELETE query that hits the `employees` table must have a `tenant_id` predicate added to it. **Before** ```sql SELECT name FROM employees; UPDATE employees SET name = $1 WHERE employee_id = $2; DELETE FROM employees WHERE employee_id = $1; ``` **After** ```sql SELECT name FROM employees WHERE tenant_id = $1; UPDATE employees SET name = $1 WHERE tenant_id = $2 AND employee_id = $3; DELETE FROM employees WHERE tenant_id = $1 AND employee_id = $2; ``` Here is how to define the policy on the employees table. ```go type EMPLOYEES struct { sq.TableStruct TENANT_ID sq.NumberField EMPLOYEE_ID sq.NumberField NAME sq.StringField } func (tbl EMPLOYEES) Policy(ctx context.Context, dialect string) (sq.Predicate, error) { tenantID, ok := ctx.Value("tenantID").(int) if !ok { return nil, errors.New("tenantID not provided") } return tbl.TENANT_ID.EqInt(tenantID), nil } ``` Note that if the `tenantID` cannot be retrieved from the context, `(EMPLOYEES).Policy()` returns an error. This means that any query using the `EMPLOYEES` table struct will always require the `tenantID` to be in the context, or else query building will fail. You may choose to omit this check by simply returning a `nil` Predicate. `nil` Predicates do not get added to the query. Here is how to use the employees table.
```go // get tenantID from somewhere and put it into the context ctx := context.WithValue(context.Background(), "tenantID", 1) e := sq.New[EMPLOYEES]("") // Query 1 names, err := sq.FetchAllContext(ctx, db, sq.From(e), func(row *sq.Row) string { return row.StringField(e.NAME) }, ) // SELECT employees.name FROM employees WHERE employees.tenant_id = 1 // Query 2 _, err = sq.ExecContext(ctx, db, sq. Update(e). Set(e.NAME.SetString("BOB")). Where(e.EMPLOYEE_ID.EqInt(18)), ) // UPDATE employees SET name = 'BOB' WHERE employees.tenant_id = 1 AND employees.employee_id = 18 // Query 3 _, err = sq.ExecContext(ctx, db, sq. DeleteFrom(e). Where(e.EMPLOYEE_ID.EqInt(18)), ) // DELETE FROM employees WHERE employees.tenant_id = 1 AND employees.employee_id = 18 ``` ## SQL examples #sql-examples ### IN #in #### In slice #in-slice ```sql a.actor_id IN (1, 2, 3) ``` ```go a := sq.New[ACTOR]("a") a.ACTOR_ID.In([]int{1, 2, 3}) ``` #### In RowValues #in-rowvalues ```sql a.first_name IN ('PENELOPE', 'NICK', 'ED') (a.first_name, a.last_name) IN (('PENELOPE', 'GUINESS'), ('NICK', 'WAHLBERG'), ('ED', 'CHASE')) ``` ```go a := sq.New[ACTOR]("a") a.FIRST_NAME.In(sq.RowValue{"PENELOPE", "NICK", "ED"}) sq.RowValue{a.FIRST_NAME, a.LAST_NAME}.In(sq.RowValues{ {"PENELOPE", "GUINESS"}, {"NICK", "WAHLBERG"}, {"ED", "CHASE"}, }) ``` #### In Subquery #in-subquery ```sql (actor.first_name, actor.last_name) IN ( SELECT a.first_name, a.last_name FROM actor AS a WHERE a.actor_id <= 3 ) ``` ```go actor, a := sq.New[ACTOR](""), sq.New[ACTOR]("a") sq.RowValue{actor.FIRST_NAME, actor.LAST_NAME}.In(sq. Select(a.FIRST_NAME, a.LAST_NAME). From(a). Where(a.ACTOR_ID.Le(3)), ) ``` ### CASE #case #### Predicate Case #predicate-case ```sql CASE WHEN f.length <= 60 THEN 'short' WHEN f.length > 60 AND f.length <= 120 THEN 'medium' ELSE 'long' END AS length_type ``` ```go f := sq.New[FILM]("f") sq.CaseWhen(f.LENGTH.LeInt(60), "short"). CaseWhen(sq.And(f.LENGTH.GtInt(60), f.LENGTH.LeInt(120)), "medium"). Else("long"). 
As("length_type") ``` #### Simple Case #simple-case ```sql CASE f.rating WHEN 'G' THEN 'family' WHEN 'PG' THEN 'teens' WHEN 'PG-13' THEN 'teens' WHEN 'R' THEN 'adults' WHEN 'NC-17' THEN 'adults' ELSE 'unknown' END AS audience ``` ```go f := sq.New[FILM]("f") sq.Case(f.RATING). When("G", "family"). When("PG", "teens"). When("PG-13", "teens"). When("R", "adults"). When("NC-17", "adults"). Else("unknown"). As("audience") ``` ### EXISTS #exists #### Where Exists #where-exists ```sql SELECT c.customer_id, c.first_name, c.last_name FROM customers AS c WHERE EXISTS ( SELECT 1 FROM orders AS o WHERE o.customer_id = c.customer_id GROUP BY o.customer_id HAVING COUNT(*) > 2 ) ORDER BY c.first_name, c.last_name ``` ```go c, o := sq.New[CUSTOMERS]("c"), sq.New[ORDERS]("o") customers, err := sq.FetchAll(db, sq. From(c). Where(sq.Exists(sq. SelectOne(). From(o). Where(o.CUSTOMER_ID.Eq(c.CUSTOMER_ID)). GroupBy(o.CUSTOMER_ID). Having(sq.Expr("COUNT(*) > 2")), )). OrderBy(c.FIRST_NAME, c.LAST_NAME), func(row *sq.Row) Customer { return Customer{ CustomerID: row.IntField(c.CUSTOMER_ID), FirstName: row.StringField(c.FIRST_NAME), LastName: row.StringField(c.LAST_NAME), } }, ) ``` #### Where Not Exists #where-not-exists ```sql SELECT p.product_id, p.product_name FROM products AS p WHERE NOT EXISTS ( SELECT 1 FROM order_details AS od WHERE p.product_id = od.product_id ) ``` ```go p, od := sq.New[PRODUCTS]("p"), sq.New[ORDER_DETAILS]("od") products, err := sq.FetchAll(db, sq. From(p). Where(sq.NotExists(sq. SelectOne(). From(od). Where(p.PRODUCT_ID.Eq(od.PRODUCT_ID)), )), func(row *sq.Row) Product { return Product{ ProductID: row.IntField(p.PRODUCT_ID), ProductName: row.StringField(p.PRODUCT_NAME), } }, ) ``` ### Subqueries #subqueries A Subquery is a SelectQuery nested inside another SelectQuery.
**Using SelectQuery as Field** ```sql SELECT city.city, (SELECT country.country FROM country WHERE country.country_id = city.country_id) AS country FROM city WHERE city.city = 'Vancouver' ``` ```go city, country := sq.New[CITY](""), sq.New[COUNTRY]("") results, err := sq.FetchAll(db, sq. From(city). Where(city.CITY.EqString("Vancouver")). SetDialect(sq.DialectPostgres), func(row *sq.Row) Result { return Result{ City: row.StringField(city.CITY), Country: row.StringField(sq. Select(country.COUNTRY). From(country). Where(country.COUNTRY_ID.Eq(city.COUNTRY_ID)). As("country"), ), } }, ) ``` **Using SelectQuery as Table** ```sql SELECT film.title, film_stats.actor_count FROM film JOIN ( SELECT film_actor.film_id, COUNT(*) AS actor_count FROM film_actor GROUP BY film_actor.film_id ) AS film_stats ON film_stats.film_id = film.film_id ``` ```go film, film_actor := sq.New[FILM](""), sq.New[FILM_ACTOR]("") // create the subquery film_stats := sq.Postgres. Select( film_actor.FILM_ID, sq.CountStar().As("actor_count"), ). From(film_actor). GroupBy(film_actor.FILM_ID). As("film_stats") // use the subquery results, err := sq.FetchAll(db, sq. From(film). Join(film_stats, film_stats.Field("film_id").Eq(film.FILM_ID)), func(row *sq.Row) Result { return Result{ Title: row.String(film.TITLE), ActorCount: row.Int(film_stats.Field("actor_count")), } }, ) ``` ### WITH (Common Table Expressions) #common-table-expressions Common Table Expressions (CTEs) are an alternative to [subqueries](#subqueries). ```sql WITH film_stats AS ( SELECT film_id, COUNT(*) AS actor_count FROM film_actor GROUP BY film_id ) SELECT film.title, film_stats.actor_count FROM film JOIN film_stats ON film_stats.film_id = film.film_id ``` ```go film, film_actor := sq.New[FILM](""), sq.New[FILM_ACTOR]("") // create the CTE film_stats := sq.NewCTE("film_stats", nil, sq.Postgres. Select( film_actor.FILM_ID, sq.CountStar().As("actor_count"), ). From(film_actor).
GroupBy(film_actor.FILM_ID), ) // use the CTE results, err := sq.FetchAll(db, sq.Postgres. With(film_stats). From(film). Join(film_stats, film_stats.Field("film_id").Eq(film.FILM_ID)), func(row *sq.Row) Result { return Result{ Title: row.String(film.TITLE), ActorCount: row.Int(film_stats.Field("actor_count")), } }, ) ``` **Recursive Common Table Expressions** ```sql WITH RECURSIVE counter (n) AS ( SELECT 1 UNION ALL SELECT counter.n + 1 FROM counter WHERE counter.n + 1 <= 100 ) SELECT counter.n FROM counter; ``` ```go counter := sq.NewRecursiveCTE("counter", []string{"n"}, sq.UnionAll( sq.Queryf("SELECT 1"), sq.Queryf("SELECT counter.n + 1 FROM counter WHERE counter.n + 1 <= {}", 100), )) sq.Postgres.With(counter).Select(counter.Field("n")).From(counter) ``` ### Aggregate functions #aggregate-functions sq provides some built-in aggregate functions. They return an `sq.Expression` and so can [pretty much be used everywhere](#expr). ```go func Count(field Field) Expression func CountStar() Expression func Sum(num Number) Expression func Avg(num Number) Expression func Min(field Field) Expression func Max(field Field) Expression ``` ### Window functions #window-functions sq provides some built-in window functions. They return an `sq.Expression` and so can [pretty much be used everywhere](#expr).
```go func CountOver(field Field, window Window) Expression func CountStarOver(window Window) Expression func SumOver(num Number, window Window) Expression func AvgOver(num Number, window Window) Expression func MinOver(field Field, window Window) Expression func MaxOver(field Field, window Window) Expression func RowNumberOver(window Window) Expression func RankOver(window Window) Expression func DenseRankOver(window Window) Expression func CumeDistOver(window Window) Expression func FirstValueOver(window Window) Expression func LastValueOver(window Window) Expression ``` **Missing window functions** The `LeadOver`, `LagOver` and `NtileOver` window functions do not have a representative Go function because they can be overloaded (they have multiple signatures) while Go functions cannot. If you need them, use an `sq.Expr()` as a stand-in. ```sql LEAD(a.actor_id) OVER (PARTITION BY a.first_name) LEAD(a.actor_id, 2) OVER (PARTITION BY a.first_name) LEAD(a.actor_id, 2, 5) OVER (PARTITION BY a.first_name) ``` ```go a := sq.New[ACTOR]("a") sq.Expr("LEAD({}) OVER (PARTITION BY {})", a.ACTOR_ID, a.FIRST_NAME) sq.Expr("LEAD({}, {}) OVER (PARTITION BY {})", a.ACTOR_ID, 2, a.FIRST_NAME) sq.Expr("LEAD({}, {}, {}) OVER (PARTITION BY {})", a.ACTOR_ID, 2, 5, a.FIRST_NAME) ``` **Using window functions** To use a window function, you must create a window using `sq.PartitionBy()`, `sq.OrderBy()` or `sq.BaseWindow()`. You can also pass in `nil` to represent the empty window. ```sql -- Example 1 SELECT COUNT(*) OVER () -- Example 2 SELECT SUM(a.actor_id) OVER (PARTITION BY a.first_name) -- Example 3 SELECT AVG(a.actor_id) OVER ( PARTITION BY a.first_name, a.last_name ORDER BY a.LAST_UPDATE DESC RANGE BETWEEN 5 PRECEDING AND 10 FOLLOWING ) ``` ```go a := sq.New[ACTOR]("a") // Example 1 sq.Postgres.Select(sq.CountStarOver(nil)) // Example 2 sq.Postgres.Select(sq.SumOver(a.ACTOR_ID, sq.PartitionBy(a.FIRST_NAME))) // Example 3 sq.Postgres.Select(sq.AvgOver(a.ACTOR_ID, sq. 
PartitionBy(a.FIRST_NAME, a.LAST_NAME). OrderBy(a.LAST_UPDATE.Desc()). Frame("RANGE BETWEEN 5 PRECEDING AND 10 FOLLOWING"), )) ``` SQLite, Postgres and MySQL support named windows as part of the SELECT query. This allows you to reuse a window definition without having to specify it over and over. ```sql SELECT SUM(a.actor_id) OVER w1, MIN(a.actor_id) OVER w2, AVG(a.actor_id) OVER (w1 ORDER BY a.last_update) FROM actor AS a WINDOW w1 AS (PARTITION BY a.first_name), w2 AS (PARTITION BY a.last_name) ``` ```go a := sq.New[ACTOR]("a") w1 := sq.NamedWindow{Name: "w1", Definition: sq.PartitionBy(a.FIRST_NAME)} w2 := sq.NamedWindow{Name: "w2", Definition: sq.PartitionBy(a.LAST_NAME)} sq.Postgres. Select( sq.SumOver(a.ACTOR_ID, w1), sq.MinOver(a.ACTOR_ID, w2), sq.AvgOver(a.ACTOR_ID, sq.BaseWindow(w1).OrderBy(a.LAST_UPDATE)), ). From(a). Window(w1, w2) ``` ### UNION, INTERSECT, EXCEPT #union-intersect-except **Union** ```sql SELECT t1.field FROM t1 UNION SELECT t2.field FROM t2 UNION SELECT t3.field FROM t3 ``` ```go sq.Union( sq.Select(t1.FIELD).From(t1), sq.Select(t2.FIELD).From(t2), sq.Select(t3.FIELD).From(t3), ) ``` **Intersect** ```sql SELECT t1.field FROM t1 INTERSECT SELECT t2.field FROM t2 INTERSECT SELECT t3.field FROM t3 ``` ```go sq.Intersect( sq.Select(t1.FIELD).From(t1), sq.Select(t2.FIELD).From(t2), sq.Select(t3.FIELD).From(t3), ) ``` **Except** ```sql SELECT t1.field FROM t1 EXCEPT SELECT t2.field FROM t2 EXCEPT SELECT t3.field FROM t3 ``` ```go sq.Except( sq.Select(t1.FIELD).From(t1), sq.Select(t2.FIELD).From(t2), sq.Select(t3.FIELD).From(t3), ) ``` ### ORDER BY #orderby ```sql SELECT a.first_name FROM actor AS a ORDER BY a.actor_id DESC SELECT a.last_name FROM actor AS a ORDER BY a.actor_id ASC NULLS FIRST ``` ```go a := sq.New[ACTOR]("a") sq.Select(a.FIRST_NAME).From(a).OrderBy(a.ACTOR_ID.Desc()) sq.Select(a.LAST_NAME).From(a).OrderBy(a.ACTOR_ID.Asc().NullsFirst()) ``` ## Migrating from go-structured-query #migrating-from-go-structured-query If you
are migrating to this library from [go-structured-query](https://github.com/bokwoon95/go-structured-query), here are the main changes ([index](#toc-migrating-from-go-structured-query)): ### Table structs embed sq.TableStruct instead of sq.TableInfo. Struct constructors are no longer needed. #table-struct-changes **go-structured-query (old)** ```go // Code generated by 'sqgen-postgres tables'; DO NOT EDIT. import sq "github.com/bokwoon95/go-structured-query/postgres" type TABLE_ACTOR struct { *sq.TableInfo ACTOR_ID sq.NumberField FIRST_NAME sq.StringField LAST_NAME sq.StringField LAST_UPDATE sq.TimeField } func ACTOR() TABLE_ACTOR { tbl := TABLE_ACTOR{TableInfo: &sq.TableInfo{ Schema: "", Name: "actor", }} tbl.ACTOR_ID = sq.NewNumberField("actor_id", tbl.TableInfo) tbl.FIRST_NAME = sq.NewStringField("first_name", tbl.TableInfo) tbl.LAST_NAME = sq.NewStringField("last_name", tbl.TableInfo) tbl.LAST_UPDATE = sq.NewTimeField("last_update", tbl.TableInfo) return tbl } // Instantiate table. a := ACTOR() // Instantiate table with alias. a := ACTOR().As("a") ``` **sq (new)** ```go import "github.com/bokwoon95/sq" type ACTOR struct { sq.TableStruct ACTOR_ID sq.NumberField FIRST_NAME sq.StringField LAST_NAME sq.StringField LAST_UPDATE sq.TimeField } // Instantiate table. a := sq.New[ACTOR]("") // Instantiate table with alias. a := sq.New[ACTOR]("a") ``` ### The CLI command is now called `sqddl`, not `sqgen-xxx`. Code generation is now optional. #code-generation-changes The new `sqddl` command replaces the old `sqgen-postgres` and `sqgen-mysql` commands. **go-structured-query (old)** ```shell # Install the sqgen-postgres command. $ go install github.com/bokwoon95/go-structured-query/cmd/sqgen-postgres@latest # Introspect the database and generate a file called 'tables.go' with package # name 'tables' inside the 'tables' directory.
$ sqgen-postgres tables \ --database 'postgres://user:pass@localhost:5432/sakila?sslmode=disable' \ --pkg tables \ --directory ./tables \ --file tables.go ``` **sq (new)** ```shell # Install the sqddl command. $ go install -tags=fts5 github.com/bokwoon95/sqddl@latest # Introspect the database and generate a file called 'tables.go' with package # name 'tables' inside the 'tables' directory. # # The dialect (sqlite, postgres, mysql or sqlserver) is inferred from the # database URL. Refer to https://bokwoon.neocities.org/sqddl#tables for more # flags. $ sqddl tables \ -db 'postgres://user:pass@localhost:5432/sakila?sslmode=disable' \ -pkg tables -file ./tables/tables.go ``` **Code generation is now optional** You no longer have to generate the table struct code every time your database schema changes. It is possible to define your table structs as the source of truth and generate migrations from them, using the `sqddl` tool. Read the documentation at [https://bokwoon.neocities.org/sqddl#table-structs](https://bokwoon.neocities.org/sqddl#table-structs) for more information.
As an example, here is an ACTOR table struct, which corresponds to the following CREATE TABLE statement: ```go // tables/tables.go import "github.com/bokwoon95/sq" type ACTOR struct { sq.TableStruct ACTOR_ID sq.NumberField `ddl:"notnull primarykey identity"` FIRST_NAME sq.StringField `ddl:"type=VARCHAR(45) notnull"` LAST_NAME sq.StringField `ddl:"type=VARCHAR(45) notnull index"` LAST_UPDATE sq.TimeField `ddl:"notnull default=CURRENT_TIMESTAMP"` } ``` ```sql CREATE TABLE actor ( actor_id INT NOT NULL PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, first_name VARCHAR(45) NOT NULL, last_name VARCHAR(45) NOT NULL, last_update TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX actor_last_name_idx ON actor (last_name); ``` Generate migrations for the ACTOR struct using `sqddl generate`: ```shell $ sqddl generate \ -src 'postgres://user:pass@localhost:5432/sakila' \ -dest tables/tables.go \ -output-dir ./migrations ``` More information: [https://bokwoon.neocities.org/sqddl#generate](https://bokwoon.neocities.org/sqddl#generate). ### Fetching #fetch-changes Fetching results has been lifted from a method (SelectRowx, Selectx) into a function (FetchOne, FetchAll). The accumulator function is no longer needed as a slice is generically returned based on the mapper function. **go-structured-query (old)** ```go import sq "github.com/bokwoon95/go-structured-query/postgres" // Fetch one. err := sq.From(tbl).Where(condition).SelectRowx(mapper).Fetch(db) // Fetch all. err := sq.From(tbl).Where(condition).Selectx(mapper, accumulator).Fetch(db) ``` **sq (new)** ```go import "github.com/bokwoon95/sq" // Fetch one. result, err := sq.FetchOne(db, sq.Postgres.From(tbl).Where(condition), mapper) // Fetch all. results, err := sq.FetchAll(db, sq.Postgres.From(tbl).Where(condition), mapper) ``` ### Exec-ing #exec-changes Similar to Fetch, Exec has been lifted from a method into a function.
It is no longer necessary to pass in an ExecFlag indicating if you want the lastInsertId or rowsAffected; the relevant values are populated automatically depending on the dialect. **go-structured-query (old)** ```go import sq "github.com/bokwoon95/go-structured-query/mysql" lastInsertId, rowsAffected, err := sq.InsertInto(tbl).Values(values...).Exec(db, sq.ElastInsertID|sq.ErowsAffected) ``` **sq (new)** ```go import "github.com/bokwoon95/sq" res, err := sq.Exec(db, sq.MySQL.InsertInto(tbl).Values(values...)) res.LastInsertId // int64 (valid because MySQL supports LastInsertId) res.RowsAffected // int64 ``` ### Expr() replaces Fieldf() and Predicatef() #expr-changes [Fieldf() and Predicatef()](https://bokwoon95.github.io/sq/basics/sql-escape-hatch.html#fieldf-and-predicatef) were used to define Fields and Predicates containing an arbitrary SQL expression. They have been replaced by [Expr()](#expr), which does double duty. The placeholder symbol has been changed from `?` to `{}`. Assuming we want to replicate the query below: ```sql SELECT a.first_name || ' ' || a.last_name FROM actor AS a WHERE a.last_update + INTERVAL '1 hour' < CURRENT_TIMESTAMP ``` **go-structured-query (old)** ```go import sq "github.com/bokwoon95/go-structured-query/postgres" a := ACTOR().As("a") field := sq.Fieldf("? || ' ' || ?", a.FIRST_NAME, a.LAST_NAME) predicate := sq.Predicatef("? + INTERVAL '1 hour' < CURRENT_TIMESTAMP", a.LAST_UPDATE) sq.Select(field).From(a).Where(predicate) ``` **sq (new)** ```go import "github.com/bokwoon95/sq" a := sq.New[ACTOR]("a") field := sq.Expr("{} || ' ' || {}", a.FIRST_NAME, a.LAST_NAME) predicate := sq.Expr("{} + INTERVAL '1 hour' < CURRENT_TIMESTAMP", a.LAST_UPDATE) sq.Postgres.Select(field).From(a).Where(predicate) ``` ### Logging #logging-changes Logging has been lifted out of a method [.WithDefaultLog()](https://bokwoon95.github.io/sq/basics/logging.html) and into a function [sq.Log()](#logging) (which should wrap the database object).
**go-structured-query (old)** ```go import sq "github.com/bokwoon95/go-structured-query/postgres" err := sq.WithDefaultLog().From(tbl).Where(predicate).Selectx(mapper, accumulator).Fetch(db) ``` **sq (new)** ```go import "github.com/bokwoon95/sq" results, err := sq.FetchAll(sq.Log(db), sq.Postgres.From(tbl).Where(predicate)) ``` ### Generating type-safe wrappers of PL/pgSQL functions has been removed #generated-plpgsql-function-changes [go-structured-query supported code-generating wrappers for plpgsql functions](https://bokwoon95.github.io/sq/sql-expressions/plpgsql.html), that feature is not present in sq because I was unsatisfied with the design. As a workaround you should use sq.Expr() to invoke functions instead. **go-structured-query (old)** ```go import sq "github.com/bokwoon95/go-structured-query/postgres" sq.Select(ADD_NUMS(1, 2)) // ADD_NUMS is code-generated. ``` **sq (new)** ```go import "github.com/bokwoon95/sq" sq.Postgres.Select(sq.Expr("add_nums({}, {})", 1, 2)) ``` ================================================ FILE: sq_test.go ================================================ package sq import ( "database/sql" "testing" "github.com/bokwoon95/sq/internal/testutil" "github.com/google/uuid" ) type Weekday uint const ( WeekdayInvalid Weekday = iota Sunday Monday Tuesday Wednesday Thursday Friday Saturday ) func (d Weekday) Enumerate() []string { return []string{ WeekdayInvalid: "", Sunday: "Sunday", Monday: "Monday", Tuesday: "Tuesday", Wednesday: "Wednesday", Thursday: "Thursday", Friday: "Friday", Saturday: "Saturday", } } func Test_preprocessValue(t *testing.T) { type TestTable struct { description string dialect string input any wantOutput any } tests := []TestTable{{ description: "empty", input: nil, wantOutput: nil, }, { description: "driver.Valuer", input: uuid.MustParse("a4f952f1-4c45-4e63-bd4e-159ca33c8e20"), wantOutput: "a4f952f1-4c45-4e63-bd4e-159ca33c8e20", }, { description: "Postgres DialectValuer", dialect: DialectPostgres, input: 
UUIDValue(uuid.MustParse("a4f952f1-4c45-4e63-bd4e-159ca33c8e20")), wantOutput: "a4f952f1-4c45-4e63-bd4e-159ca33c8e20", }, { description: "MySQL DialectValuer", dialect: DialectMySQL, input: UUIDValue(uuid.MustParse("a4f952f1-4c45-4e63-bd4e-159ca33c8e20")), wantOutput: []byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20}, }, { description: "Postgres [16]byte", dialect: DialectPostgres, input: [16]byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20}, wantOutput: "a4f952f1-4c45-4e63-bd4e-159ca33c8e20", }, { description: "MySQL [16]byte", dialect: DialectMySQL, input: [16]byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20}, wantOutput: []byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20}, }, { description: "Enumeration", input: Monday, wantOutput: "Monday", }, { description: "int", input: 42, wantOutput: 42, }, { description: "sql.NullString", input: sql.NullString{ Valid: false, String: "lorem ipsum dolor sit amet", }, wantOutput: nil, }} for _, tt := range tests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() gotOutput, err := preprocessValue(tt.dialect, tt.input) if err != nil { t.Fatal(testutil.Callers(), err) } if diff := testutil.Diff(gotOutput, tt.wantOutput); diff != "" { t.Error(testutil.Callers(), diff) } }) } } ================================================ FILE: update_query.go ================================================ package sq import ( "bytes" "context" "fmt" ) // UpdateQuery represents an SQL UPDATE query. 
type UpdateQuery struct { Dialect string ColumnMapper func(*Column) // WITH CTEs []CTE // UPDATE UpdateTable Table // FROM FromTable Table JoinTables []JoinTable // SET Assignments []Assignment // WHERE WherePredicate Predicate // ORDER BY OrderByFields []Field // LIMIT LimitRows any // RETURNING ReturningFields []Field } var _ Query = (*UpdateQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q UpdateQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) (err error) { if q.ColumnMapper != nil { col := &Column{ dialect: q.Dialect, isUpdate: true, } defer mapperFunctionPanicked(&err) q.ColumnMapper(col) if err != nil { return err } q.Assignments = col.assignments } // Table Policies var policies []Predicate policies, err = appendPolicy(ctx, dialect, policies, q.UpdateTable) if err != nil { return fmt.Errorf("UPDATE %s Policy: %w", toString(q.Dialect, q.UpdateTable), err) } policies, err = appendPolicy(ctx, dialect, policies, q.FromTable) if err != nil { return fmt.Errorf("FROM %s Policy: %w", toString(q.Dialect, q.FromTable), err) } for _, joinTable := range q.JoinTables { policies, err = appendPolicy(ctx, dialect, policies, joinTable.Table) if err != nil { return fmt.Errorf("%s %s Policy: %w", joinTable.JoinOperator, joinTable.Table, err) } } if len(policies) > 0 { if q.WherePredicate != nil { policies = append(policies, q.WherePredicate) } q.WherePredicate = And(policies...) 
} // WITH if len(q.CTEs) > 0 { err = writeCTEs(ctx, dialect, buf, args, params, q.CTEs) if err != nil { return fmt.Errorf("WITH: %w", err) } } // UPDATE buf.WriteString("UPDATE ") if q.UpdateTable == nil { return fmt.Errorf("no table provided to UPDATE") } err = q.UpdateTable.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("UPDATE: %w", err) } if dialect != DialectSQLServer { if alias := getAlias(q.UpdateTable); alias != "" { buf.WriteString(" AS " + QuoteIdentifier(dialect, alias)) } } if len(q.Assignments) == 0 { return fmt.Errorf("no fields to update") } // SET (not mysql) if dialect != DialectMySQL { buf.WriteString(" SET ") err = Assignments(q.Assignments).WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("SET: %w", err) } } // OUTPUT if len(q.ReturningFields) > 0 && dialect == DialectSQLServer { buf.WriteString(" OUTPUT ") err = writeFieldsWithPrefix(ctx, dialect, buf, args, params, q.ReturningFields, "INSERTED", true) if err != nil { return err } } // FROM if q.FromTable != nil { if dialect == DialectMySQL { return fmt.Errorf("mysql UPDATE does not support FROM") } buf.WriteString(" FROM ") err = q.FromTable.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("FROM: %w", err) } if alias := getAlias(q.FromTable); alias != "" { buf.WriteString(" AS " + QuoteIdentifier(dialect, alias) + quoteTableColumns(dialect, q.FromTable)) } } // JOIN if len(q.JoinTables) > 0 { if q.FromTable == nil && dialect != DialectMySQL { return fmt.Errorf("%s can't JOIN without a FROM table", dialect) } buf.WriteString(" ") err = writeJoinTables(ctx, dialect, buf, args, params, q.JoinTables) if err != nil { return fmt.Errorf("JOIN: %w", err) } } // SET (mysql) if dialect == DialectMySQL { buf.WriteString(" SET ") err = Assignments(q.Assignments).WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("SET: %w", err) } } // WHERE if q.WherePredicate != nil { buf.WriteString(" WHERE ") 
switch predicate := q.WherePredicate.(type) { case VariadicPredicate: predicate.Toplevel = true err = predicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("WHERE: %w", err) } default: err = q.WherePredicate.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("WHERE: %w", err) } } } // ORDER BY if len(q.OrderByFields) > 0 { if dialect != DialectMySQL { return fmt.Errorf("%s UPDATE does not support ORDER BY", dialect) } buf.WriteString(" ORDER BY ") err = writeFields(ctx, dialect, buf, args, params, q.OrderByFields, false) if err != nil { return fmt.Errorf("ORDER BY: %w", err) } } // LIMIT if q.LimitRows != nil { if dialect != DialectMySQL { return fmt.Errorf("%s UPDATE does not support LIMIT", dialect) } buf.WriteString(" LIMIT ") err = WriteValue(ctx, dialect, buf, args, params, q.LimitRows) if err != nil { return fmt.Errorf("LIMIT: %w", err) } } // RETURNING if len(q.ReturningFields) > 0 && dialect != DialectSQLServer { if dialect != DialectPostgres && dialect != DialectSQLite { return fmt.Errorf("%s UPDATE does not support RETURNING", dialect) } buf.WriteString(" RETURNING ") err = writeFields(ctx, dialect, buf, args, params, q.ReturningFields, true) if err != nil { return fmt.Errorf("RETURNING: %w", err) } } return nil } // Update returns a new UpdateQuery. func Update(table Table) UpdateQuery { return UpdateQuery{UpdateTable: table} } // Set sets the Assignments field of the UpdateQuery. func (q UpdateQuery) Set(assignments ...Assignment) UpdateQuery { q.Assignments = append(q.Assignments, assignments...) return q } // SetFunc sets the ColumnMapper field of the UpdateQuery. func (q UpdateQuery) SetFunc(colmapper func(*Column)) UpdateQuery { q.ColumnMapper = colmapper return q } // Where appends to the WherePredicate field of the UpdateQuery. 
func (q UpdateQuery) Where(predicates ...Predicate) UpdateQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // SetFetchableFields implements the Query interface. func (q UpdateQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { switch q.Dialect { case DialectPostgres, DialectSQLite: if len(q.ReturningFields) == 0 { q.ReturningFields = fields return q, true } return q, false default: return q, false } } // GetFetchableFields returns the fetchable fields of the query. func (q UpdateQuery) GetFetchableFields() []Field { switch q.Dialect { case DialectPostgres, DialectSQLite: return q.ReturningFields default: return nil } } // GetDialect implements the Query interface. func (q UpdateQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the query. func (q UpdateQuery) SetDialect(dialect string) UpdateQuery { q.Dialect = dialect return q } // SQLiteUpdateQuery represents an SQLite UPDATE query. type SQLiteUpdateQuery UpdateQuery var _ Query = (*SQLiteUpdateQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q SQLiteUpdateQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return UpdateQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // Update returns a new SQLiteUpdateQuery. func (b sqliteQueryBuilder) Update(table Table) SQLiteUpdateQuery { return SQLiteUpdateQuery{ Dialect: DialectSQLite, CTEs: b.ctes, UpdateTable: table, } } // Set sets the Assignments field of the SQLiteUpdateQuery. func (q SQLiteUpdateQuery) Set(assignments ...Assignment) SQLiteUpdateQuery { q.Assignments = append(q.Assignments, assignments...) return q } // SetFunc sets the ColumnMapper of the SQLiteUpdateQuery. func (q SQLiteUpdateQuery) SetFunc(colmapper func(*Column)) SQLiteUpdateQuery { q.ColumnMapper = colmapper return q } // From sets the FromTable field of the SQLiteUpdateQuery. 
func (q SQLiteUpdateQuery) From(table Table) SQLiteUpdateQuery { q.FromTable = table return q } // Join joins a new Table to the SQLiteUpdateQuery. func (q SQLiteUpdateQuery) Join(table Table, predicates ...Predicate) SQLiteUpdateQuery { q.JoinTables = append(q.JoinTables, Join(table, predicates...)) return q } // LeftJoin left joins a new Table to the SQLiteUpdateQuery. func (q SQLiteUpdateQuery) LeftJoin(table Table, predicates ...Predicate) SQLiteUpdateQuery { q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) return q } // CrossJoin cross joins a new Table to the SQLiteUpdateQuery. func (q SQLiteUpdateQuery) CrossJoin(table Table) SQLiteUpdateQuery { q.JoinTables = append(q.JoinTables, CrossJoin(table)) return q } // CustomJoin joins a new Table to the SQLiteUpdateQuery with a custom join // operator. func (q SQLiteUpdateQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) SQLiteUpdateQuery { q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) return q } // JoinUsing joins a new Table to the SQLiteUpdateQuery with the USING operator. func (q SQLiteUpdateQuery) JoinUsing(table Table, fields ...Field) SQLiteUpdateQuery { q.JoinTables = append(q.JoinTables, JoinUsing(table, fields...)) return q } // Where appends to the WherePredicate field of the SQLiteUpdateQuery. func (q SQLiteUpdateQuery) Where(predicates ...Predicate) SQLiteUpdateQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // Returning sets the ReturningFields field of the SQLiteUpdateQuery. func (q SQLiteUpdateQuery) Returning(fields ...Field) SQLiteUpdateQuery { q.ReturningFields = append(q.ReturningFields, fields...) return q } // SetFetchableFields implements the Query interface. func (q SQLiteUpdateQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { return UpdateQuery(q).SetFetchableFields(fields) } // GetFetchableFields returns the fetchable fields of the SQLiteUpdateQuery. 
func (q SQLiteUpdateQuery) GetFetchableFields() []Field { return UpdateQuery(q).GetFetchableFields() } // GetDialect implements the Query interface. func (q SQLiteUpdateQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the SQLiteUpdateQuery. func (q SQLiteUpdateQuery) SetDialect(dialect string) SQLiteUpdateQuery { q.Dialect = dialect return q } // PostgresUpdateQuery represents a Postgres UPDATE query. type PostgresUpdateQuery UpdateQuery var _ Query = (*PostgresUpdateQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q PostgresUpdateQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return UpdateQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // Update returns a new PostgresUpdateQuery. func (b postgresQueryBuilder) Update(table Table) PostgresUpdateQuery { return PostgresUpdateQuery{ Dialect: DialectPostgres, CTEs: b.ctes, UpdateTable: table, } } // Set sets the Assignments field of the PostgresUpdateQuery. func (q PostgresUpdateQuery) Set(assignments ...Assignment) PostgresUpdateQuery { q.Assignments = append(q.Assignments, assignments...) return q } // SetFunc sets the ColumnMapper of the PostgresUpdateQuery. func (q PostgresUpdateQuery) SetFunc(colmapper func(*Column)) PostgresUpdateQuery { q.ColumnMapper = colmapper return q } // From sets the FromTable field of the PostgresUpdateQuery. func (q PostgresUpdateQuery) From(table Table) PostgresUpdateQuery { q.FromTable = table return q } // Join joins a new Table to the PostgresUpdateQuery. func (q PostgresUpdateQuery) Join(table Table, predicates ...Predicate) PostgresUpdateQuery { q.JoinTables = append(q.JoinTables, Join(table, predicates...)) return q } // LeftJoin left joins a new Table to the PostgresUpdateQuery. 
func (q PostgresUpdateQuery) LeftJoin(table Table, predicates ...Predicate) PostgresUpdateQuery { q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) return q } // FullJoin full joins a new Table to the PostgresUpdateQuery. func (q PostgresUpdateQuery) FullJoin(table Table, predicates ...Predicate) PostgresUpdateQuery { q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...)) return q } // CrossJoin cross joins a new Table to the PostgresUpdateQuery. func (q PostgresUpdateQuery) CrossJoin(table Table) PostgresUpdateQuery { q.JoinTables = append(q.JoinTables, CrossJoin(table)) return q } // CustomJoin joins a new Table to the PostgresUpdateQuery with a custom join // operator. func (q PostgresUpdateQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) PostgresUpdateQuery { q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) return q } // JoinUsing joins a new Table to the PostgresUpdateQuery with the USING operator. func (q PostgresUpdateQuery) JoinUsing(table Table, fields ...Field) PostgresUpdateQuery { q.JoinTables = append(q.JoinTables, JoinUsing(table, fields...)) return q } // Where appends to the WherePredicate field of the PostgresUpdateQuery. func (q PostgresUpdateQuery) Where(predicates ...Predicate) PostgresUpdateQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // Returning sets the ReturningFields field of the PostgresUpdateQuery. func (q PostgresUpdateQuery) Returning(fields ...Field) PostgresUpdateQuery { q.ReturningFields = append(q.ReturningFields, fields...) return q } // SetFetchableFields implements the Query interface. func (q PostgresUpdateQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { return UpdateQuery(q).SetFetchableFields(fields) } // GetFetchableFields returns the fetchable fields of the PostgresUpdateQuery. 
func (q PostgresUpdateQuery) GetFetchableFields() []Field { return UpdateQuery(q).GetFetchableFields() } // GetDialect implements the Query interface. func (q PostgresUpdateQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the PostgresUpdateQuery. func (q PostgresUpdateQuery) SetDialect(dialect string) PostgresUpdateQuery { q.Dialect = dialect return q } // MySQLUpdateQuery represents a MySQL UPDATE query. type MySQLUpdateQuery UpdateQuery var _ Query = (*MySQLUpdateQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q MySQLUpdateQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return UpdateQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // Update returns a new MySQLUpdateQuery. func (b mysqlQueryBuilder) Update(table Table) MySQLUpdateQuery { q := MySQLUpdateQuery{ Dialect: DialectMySQL, CTEs: b.ctes, UpdateTable: table, } return q } // Join joins a new Table to the MySQLUpdateQuery. func (q MySQLUpdateQuery) Join(table Table, predicates ...Predicate) MySQLUpdateQuery { q.JoinTables = append(q.JoinTables, Join(table, predicates...)) return q } // LeftJoin left joins a new Table to the MySQLUpdateQuery. func (q MySQLUpdateQuery) LeftJoin(table Table, predicates ...Predicate) MySQLUpdateQuery { q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) return q } // FullJoin full joins a new Table to the MySQLUpdateQuery. func (q MySQLUpdateQuery) FullJoin(table Table, predicates ...Predicate) MySQLUpdateQuery { q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...)) return q } // CrossJoin cross joins a new Table to the MySQLUpdateQuery. func (q MySQLUpdateQuery) CrossJoin(table Table) MySQLUpdateQuery { q.JoinTables = append(q.JoinTables, CrossJoin(table)) return q } // CustomJoin joins a new Table to the MySQLUpdateQuery with a custom join // operator. 
func (q MySQLUpdateQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) MySQLUpdateQuery { q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) return q } // JoinUsing joins a new Table to the MySQLUpdateQuery with the USING operator. func (q MySQLUpdateQuery) JoinUsing(table Table, fields ...Field) MySQLUpdateQuery { q.JoinTables = append(q.JoinTables, JoinUsing(table, fields...)) return q } // Set sets the Assignments field of the MySQLUpdateQuery. func (q MySQLUpdateQuery) Set(assignments ...Assignment) MySQLUpdateQuery { q.Assignments = append(q.Assignments, assignments...) return q } // SetFunc sets the ColumnMapper of the MySQLUpdateQuery. func (q MySQLUpdateQuery) SetFunc(colmapper func(*Column)) MySQLUpdateQuery { q.ColumnMapper = colmapper return q } // Where appends to the WherePredicate field of the MySQLUpdateQuery. func (q MySQLUpdateQuery) Where(predicates ...Predicate) MySQLUpdateQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // OrderBy sets the OrderByFields of the MySQLUpdateQuery. func (q MySQLUpdateQuery) OrderBy(fields ...Field) MySQLUpdateQuery { q.OrderByFields = append(q.OrderByFields, fields...) return q } // Limit sets the LimitRows field of the MySQLUpdateQuery. func (q MySQLUpdateQuery) Limit(limit any) MySQLUpdateQuery { q.LimitRows = limit return q } // SetFetchableFields implements the Query interface. func (q MySQLUpdateQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { return UpdateQuery(q).SetFetchableFields(fields) } // GetFetchableFields returns the fetchable fields of the MySQLUpdateQuery. func (q MySQLUpdateQuery) GetFetchableFields() []Field { return UpdateQuery(q).GetFetchableFields() } // GetDialect implements the Query interface. func (q MySQLUpdateQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the MySQLUpdateQuery. 
func (q MySQLUpdateQuery) SetDialect(dialect string) MySQLUpdateQuery { q.Dialect = dialect return q } // SQLServerUpdateQuery represents an SQL Server UPDATE query. type SQLServerUpdateQuery UpdateQuery var _ Query = (*SQLServerUpdateQuery)(nil) // WriteSQL implements the SQLWriter interface. func (q SQLServerUpdateQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { return UpdateQuery(q).WriteSQL(ctx, dialect, buf, args, params) } // Update returns a new SQLServerUpdateQuery. func (b sqlserverQueryBuilder) Update(table Table) SQLServerUpdateQuery { return SQLServerUpdateQuery{ Dialect: DialectSQLServer, CTEs: b.ctes, UpdateTable: table, } } // Set sets the Assignments field of the SQLServerUpdateQuery. func (q SQLServerUpdateQuery) Set(assignments ...Assignment) SQLServerUpdateQuery { q.Assignments = append(q.Assignments, assignments...) return q } // SetFunc sets the ColumnMapper of the SQLServerUpdateQuery. func (q SQLServerUpdateQuery) SetFunc(colmapper func(*Column)) SQLServerUpdateQuery { q.ColumnMapper = colmapper return q } // From sets the FromTable field of the SQLServerUpdateQuery. func (q SQLServerUpdateQuery) From(table Table) SQLServerUpdateQuery { q.FromTable = table return q } // Join joins a new Table to the SQLServerUpdateQuery. func (q SQLServerUpdateQuery) Join(table Table, predicates ...Predicate) SQLServerUpdateQuery { q.JoinTables = append(q.JoinTables, Join(table, predicates...)) return q } // LeftJoin left joins a new Table to the SQLServerUpdateQuery. func (q SQLServerUpdateQuery) LeftJoin(table Table, predicates ...Predicate) SQLServerUpdateQuery { q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) return q } // FullJoin full joins a new Table to the SQLServerUpdateQuery. 
func (q SQLServerUpdateQuery) FullJoin(table Table, predicates ...Predicate) SQLServerUpdateQuery { q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...)) return q } // CrossJoin cross joins a new Table to the SQLServerUpdateQuery. func (q SQLServerUpdateQuery) CrossJoin(table Table) SQLServerUpdateQuery { q.JoinTables = append(q.JoinTables, CrossJoin(table)) return q } // CustomJoin joins a new Table to the SQLServerUpdateQuery with a custom join // operator. func (q SQLServerUpdateQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) SQLServerUpdateQuery { q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) return q } // Where appends to the WherePredicate field of the SQLServerUpdateQuery. func (q SQLServerUpdateQuery) Where(predicates ...Predicate) SQLServerUpdateQuery { q.WherePredicate = appendPredicates(q.WherePredicate, predicates) return q } // SetFetchableFields implements the Query interface. func (q SQLServerUpdateQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { return UpdateQuery(q).SetFetchableFields(fields) } // GetFetchableFields returns the fetchable fields of the SQLServerUpdateQuery. func (q SQLServerUpdateQuery) GetFetchableFields() []Field { return UpdateQuery(q).GetFetchableFields() } // GetDialect implements the Query interface. func (q SQLServerUpdateQuery) GetDialect() string { return q.Dialect } // SetDialect sets the dialect of the SQLServerUpdateQuery. 
func (q SQLServerUpdateQuery) SetDialect(dialect string) SQLServerUpdateQuery { q.Dialect = dialect return q } ================================================ FILE: update_query_test.go ================================================ package sq import ( "testing" "github.com/bokwoon95/sq/internal/testutil" ) func TestSQLiteUpdateQuery(t *testing.T) { type ACTOR struct { TableStruct ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField } a := New[ACTOR]("a") t.Run("basic", func(t *testing.T) { t.Parallel() q1 := SQLite.Update(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum") if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } q1 = q1.SetDialect(DialectSQLite) fields := q1.GetFetchableFields() if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" { t.Error(testutil.Callers(), diff) } _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } q1.ReturningFields = q1.ReturningFields[:0] _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) if !ok { t.Fatal(testutil.Callers(), "field should have been set") } }) t.Run("Set", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLite. With(NewCTE("cte", nil, Queryf("SELECT 1"))). Update(a). Set( a.FIRST_NAME.SetString("bob"), a.LAST_NAME.SetString("the builder"), ). Where(a.ACTOR_ID.EqInt(1), a.LAST_UPDATE.IsNotNull()). Returning(a.ACTOR_ID) tt.wantQuery = "WITH cte AS (SELECT 1)" + " UPDATE actor AS a" + " SET first_name = $1, last_name = $2" + " WHERE a.actor_id = $3 AND a.last_update IS NOT NULL" + " RETURNING a.actor_id" tt.wantArgs = []any{"bob", "the builder", 1} tt.assert(t) }) t.Run("SetFunc", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLite. With(NewCTE("cte", nil, Queryf("SELECT 1"))). Update(a). SetFunc(func(col *Column) { col.SetString(a.FIRST_NAME, "bob") col.SetString(a.LAST_NAME, "the builder") }). 
Where(a.ACTOR_ID.EqInt(1)) tt.wantQuery = "WITH cte AS (SELECT 1)" + " UPDATE actor AS a" + " SET first_name = $1, last_name = $2" + " WHERE a.actor_id = $3" tt.wantArgs = []any{"bob", "the builder", 1} tt.assert(t) }) t.Run("UPDATE with JOIN", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLite. Update(a). Set( a.FIRST_NAME.SetString("bob"), a.LAST_NAME.SetString("the builder"), ). From(a). Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). CrossJoin(a). CustomJoin(",", a). JoinUsing(a, a.FIRST_NAME, a.LAST_NAME). Where(a.ACTOR_ID.EqInt(1)) tt.wantQuery = "UPDATE actor AS a" + " SET first_name = $1, last_name = $2" + " FROM actor AS a" + " JOIN actor AS a ON a.actor_id = a.actor_id" + " LEFT JOIN actor AS a ON a.actor_id = a.actor_id" + " CROSS JOIN actor AS a" + " , actor AS a" + " JOIN actor AS a USING (first_name, last_name)" + " WHERE a.actor_id = $3" tt.wantArgs = []any{"bob", "the builder", 1} tt.assert(t) }) } func TestPostgresUpdateQuery(t *testing.T) { type ACTOR struct { TableStruct ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField } a := New[ACTOR]("a") t.Run("basic", func(t *testing.T) { t.Parallel() q1 := Postgres.Update(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum") if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } q1 = q1.SetDialect(DialectPostgres) fields := q1.GetFetchableFields() if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" { t.Error(testutil.Callers(), diff) } _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } q1.ReturningFields = q1.ReturningFields[:0] _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) if !ok { t.Fatal(testutil.Callers(), "field should have been set") } }) t.Run("Set", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = Postgres. With(NewCTE("cte", nil, Queryf("SELECT 1"))). 
Update(a). Set( a.FIRST_NAME.SetString("bob"), a.LAST_NAME.SetString("the builder"), ). Where(a.ACTOR_ID.EqInt(1), a.LAST_UPDATE.IsNotNull()). Returning(a.ACTOR_ID) tt.wantQuery = "WITH cte AS (SELECT 1)" + " UPDATE actor AS a" + " SET first_name = $1, last_name = $2" + " WHERE a.actor_id = $3 AND a.last_update IS NOT NULL" + " RETURNING a.actor_id" tt.wantArgs = []any{"bob", "the builder", 1} tt.assert(t) }) t.Run("SetFunc", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = Postgres. With(NewCTE("cte", nil, Queryf("SELECT 1"))). Update(a). SetFunc(func(col *Column) { col.SetString(a.FIRST_NAME, "bob") col.SetString(a.LAST_NAME, "the builder") }). Where(a.ACTOR_ID.EqInt(1)) tt.wantQuery = "WITH cte AS (SELECT 1)" + " UPDATE actor AS a" + " SET first_name = $1, last_name = $2" + " WHERE a.actor_id = $3" tt.wantArgs = []any{"bob", "the builder", 1} tt.assert(t) }) t.Run("UPDATE with JOIN", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = Postgres. Update(a). Set( a.FIRST_NAME.SetString("bob"), a.LAST_NAME.SetString("the builder"), ). From(a). Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). FullJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). CrossJoin(a). CustomJoin(",", a). JoinUsing(a, a.FIRST_NAME, a.LAST_NAME). 
Where(a.ACTOR_ID.EqInt(1)) tt.wantQuery = "UPDATE actor AS a" + " SET first_name = $1, last_name = $2" + " FROM actor AS a" + " JOIN actor AS a ON a.actor_id = a.actor_id" + " LEFT JOIN actor AS a ON a.actor_id = a.actor_id" + " FULL JOIN actor AS a ON a.actor_id = a.actor_id" + " CROSS JOIN actor AS a" + " , actor AS a" + " JOIN actor AS a USING (first_name, last_name)" + " WHERE a.actor_id = $3" tt.wantArgs = []any{"bob", "the builder", 1} tt.assert(t) }) } func TestMySQLUpdateQuery(t *testing.T) { type ACTOR struct { TableStruct ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField } a := New[ACTOR]("a") t.Run("basic", func(t *testing.T) { t.Parallel() q1 := MySQL.Update(a).SetDialect("lorem ipsum") if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } q1 = q1.SetDialect(DialectMySQL) fields := q1.GetFetchableFields() if len(fields) != 0 { t.Error(testutil.Callers(), "expected 0 fields but got %v", fields) } _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } q1.ReturningFields = q1.ReturningFields[:0] _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } }) t.Run("Set", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = MySQL. With(NewCTE("cte", nil, Queryf("SELECT 1"))). Update(a). Set( a.FIRST_NAME.SetString("bob"), a.LAST_NAME.SetString("the builder"), ). Where(a.ACTOR_ID.EqInt(1)) tt.wantQuery = "WITH cte AS (SELECT 1)" + " UPDATE actor AS a" + " SET a.first_name = ?, a.last_name = ?" + " WHERE a.actor_id = ?" tt.wantArgs = []any{"bob", "the builder", 1} tt.assert(t) }) t.Run("SetFunc", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = MySQL. With(NewCTE("cte", nil, Queryf("SELECT 1"))). Update(a). SetFunc(func(col *Column) { col.SetString(a.FIRST_NAME, "bob") col.SetString(a.LAST_NAME, "the builder") }). 
Where(a.ACTOR_ID.EqInt(1)) tt.wantQuery = "WITH cte AS (SELECT 1)" + " UPDATE actor AS a" + " SET a.first_name = ?, a.last_name = ?" + " WHERE a.actor_id = ?" tt.wantArgs = []any{"bob", "the builder", 1} tt.assert(t) }) t.Run("UPDATE with JOIN, ORDER BY, LIMIT", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = MySQL. Update(a). Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). FullJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). CrossJoin(a). CustomJoin(",", a). JoinUsing(a, a.FIRST_NAME, a.LAST_NAME). Set( a.FIRST_NAME.SetString("bob"), a.LAST_NAME.SetString("the builder"), ). Where(a.ACTOR_ID.EqInt(1)). OrderBy(a.ACTOR_ID). Limit(5) tt.wantQuery = "UPDATE actor AS a" + " JOIN actor AS a ON a.actor_id = a.actor_id" + " LEFT JOIN actor AS a ON a.actor_id = a.actor_id" + " FULL JOIN actor AS a ON a.actor_id = a.actor_id" + " CROSS JOIN actor AS a" + " , actor AS a" + " JOIN actor AS a USING (first_name, last_name)" + " SET a.first_name = ?, a.last_name = ?" + " WHERE a.actor_id = ?" + " ORDER BY a.actor_id" + " LIMIT ?" 
tt.wantArgs = []any{"bob", "the builder", 1, 5} tt.assert(t) }) } func TestSQLServerUpdateQuery(t *testing.T) { type ACTOR struct { TableStruct ACTOR_ID NumberField FIRST_NAME StringField LAST_NAME StringField LAST_UPDATE TimeField } a := New[ACTOR]("") t.Run("basic", func(t *testing.T) { t.Parallel() q1 := SQLServer.Update(a).SetDialect("lorem ipsum") if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } q1 = q1.SetDialect(DialectSQLServer) fields := q1.GetFetchableFields() if len(fields) != 0 { t.Error(testutil.Callers(), "expected 0 fields but got %v", fields) } _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } q1.ReturningFields = q1.ReturningFields[:0] _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) if ok { t.Fatal(testutil.Callers(), "field should not have been set") } }) t.Run("Set", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLServer. With(NewCTE("cte", nil, Queryf("SELECT 1"))). Update(a). Set( a.FIRST_NAME.SetString("bob"), a.LAST_NAME.SetString("the builder"), ). Where(a.ACTOR_ID.EqInt(1)) tt.wantQuery = "WITH cte AS (SELECT 1)" + " UPDATE actor" + " SET first_name = @p1, last_name = @p2" + " WHERE actor.actor_id = @p3" tt.wantArgs = []any{"bob", "the builder", 1} tt.assert(t) }) t.Run("SetFunc", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLServer. With(NewCTE("cte", nil, Queryf("SELECT 1"))). Update(a). SetFunc(func(col *Column) { col.SetString(a.FIRST_NAME, "bob") col.SetString(a.LAST_NAME, "the builder") }). Where(a.ACTOR_ID.EqInt(1)) tt.wantQuery = "WITH cte AS (SELECT 1)" + " UPDATE actor" + " SET first_name = @p1, last_name = @p2" + " WHERE actor.actor_id = @p3" tt.wantArgs = []any{"bob", "the builder", 1} tt.assert(t) }) t.Run("UPDATE with JOIN", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = SQLServer. Update(a). 
Set( a.FIRST_NAME.SetString("bob"), a.LAST_NAME.SetString("the builder"), ). From(a). Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). FullJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). CrossJoin(a). CustomJoin(",", a). Where(a.ACTOR_ID.EqInt(1)) tt.wantQuery = "UPDATE actor" + " SET first_name = @p1, last_name = @p2" + " FROM actor" + " JOIN actor ON actor.actor_id = actor.actor_id" + " LEFT JOIN actor ON actor.actor_id = actor.actor_id" + " FULL JOIN actor ON actor.actor_id = actor.actor_id" + " CROSS JOIN actor" + " , actor" + " WHERE actor.actor_id = @p3" tt.wantArgs = []any{"bob", "the builder", 1} tt.assert(t) }) } func TestUpdateQuery(t *testing.T) { t.Run("basic", func(t *testing.T) { t.Parallel() q1 := UpdateQuery{UpdateTable: Expr("tbl"), Dialect: "lorem ipsum"} if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { t.Error(testutil.Callers(), diff) } }) f1, f2, f3 := Expr("f1"), Expr("f2"), Expr("f3") colmapper := func(col *Column) { col.Set(f1, 1) col.Set(f2, 2) col.Set(f3, 3) } t.Run("PolicyTable", func(t *testing.T) { t.Parallel() var tt TestTable tt.item = UpdateQuery{ UpdateTable: policyTableStub{policy: And(Expr("1 = 1"), Expr("2 = 2"))}, ColumnMapper: colmapper, WherePredicate: Expr("3 = 3"), } tt.wantQuery = "UPDATE policy_table_stub SET f1 = ?, f2 = ?, f3 = ? 
WHERE (1 = 1 AND 2 = 2) AND 3 = 3" tt.wantArgs = []any{1, 2, 3} tt.assert(t) }) notOKTests := []TestTable{{ description: "nil UpdateTable not allowed", item: UpdateQuery{ UpdateTable: nil, ColumnMapper: colmapper, }, }, { description: "empty Assignments not allowed", item: UpdateQuery{ UpdateTable: Expr("tbl"), Assignments: nil, }, }, { description: "mysql does not support FROM", item: UpdateQuery{ Dialect: DialectMySQL, UpdateTable: Expr("tbl"), FromTable: Expr("tbl"), ColumnMapper: colmapper, }, }, { description: "dialect does not allow JOIN without FROM", item: UpdateQuery{ Dialect: DialectPostgres, UpdateTable: Expr("tbl"), FromTable: nil, JoinTables: []JoinTable{ Join(Expr("tbl"), Expr("1 = 1")), }, ColumnMapper: colmapper, }, }, { description: "dialect does not support ORDER BY", item: UpdateQuery{ Dialect: DialectPostgres, UpdateTable: Expr("tbl"), ColumnMapper: colmapper, OrderByFields: Fields{f1}, }, }, { description: "dialect does not support LIMIT", item: UpdateQuery{ Dialect: DialectPostgres, UpdateTable: Expr("tbl"), ColumnMapper: colmapper, LimitRows: 5, }, }, { description: "dialect does not support RETURNING", item: UpdateQuery{ Dialect: DialectMySQL, UpdateTable: Expr("tbl"), ColumnMapper: colmapper, ReturningFields: Fields{f1, f2, f3}, }, }} for _, tt := range notOKTests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assertNotOK(t) }) } errTests := []TestTable{{ description: "ColumnMapper err", item: UpdateQuery{ UpdateTable: Expr("tbl"), ColumnMapper: func(*Column) { panic(ErrFaultySQL) }, }, }, { description: "UpdateTable Policy err", item: UpdateQuery{ UpdateTable: policyTableStub{err: ErrFaultySQL}, ColumnMapper: colmapper, }, }, { description: "FromTable Policy err", item: UpdateQuery{ UpdateTable: Expr("tbl"), FromTable: policyTableStub{err: ErrFaultySQL}, ColumnMapper: colmapper, }, }, { description: "JoinTables Policy err", item: UpdateQuery{ UpdateTable: Expr("tbl"), ColumnMapper: colmapper, FromTable: Expr("tbl"), 
JoinTables: []JoinTable{ Join(policyTableStub{err: ErrFaultySQL}, Expr("1 = 1")), }, }, }, { description: "CTEs err", item: UpdateQuery{ CTEs: []CTE{NewCTE("cte", nil, Queryf("SELECT {}", FaultySQL{}))}, UpdateTable: Expr("tbl"), ColumnMapper: colmapper, }, }, { description: "UpdateTable err", item: UpdateQuery{ UpdateTable: FaultySQL{}, ColumnMapper: colmapper, }, }, { description: "not mysql Assignments err", item: UpdateQuery{ Dialect: DialectPostgres, UpdateTable: Expr("tbl"), Assignments: []Assignment{FaultySQL{}}, }, }, { description: "FromTable err", item: UpdateQuery{ Dialect: DialectPostgres, UpdateTable: Expr("tbl"), ColumnMapper: colmapper, FromTable: FaultySQL{}, }, }, { description: "JoinTables err", item: UpdateQuery{ Dialect: DialectPostgres, UpdateTable: Expr("tbl"), ColumnMapper: colmapper, FromTable: Expr("tbl"), JoinTables: []JoinTable{ Join(FaultySQL{}, Expr("1 = 1")), }, }, }, { description: "mysql Assignments err", item: UpdateQuery{ Dialect: DialectMySQL, UpdateTable: Expr("tbl"), Assignments: []Assignment{FaultySQL{}}, }, }, { description: "WherePredicate Variadic err", item: UpdateQuery{ UpdateTable: Expr("tbl"), ColumnMapper: colmapper, WherePredicate: And(FaultySQL{}), }, }, { description: "WherePredicate err", item: UpdateQuery{ UpdateTable: Expr("tbl"), ColumnMapper: colmapper, WherePredicate: FaultySQL{}, }, }, { description: "OrderByFields err", item: UpdateQuery{ Dialect: DialectMySQL, UpdateTable: Expr("tbl"), ColumnMapper: colmapper, OrderByFields: Fields{FaultySQL{}}, }, }, { description: "LimitRows err", item: UpdateQuery{ Dialect: DialectMySQL, UpdateTable: Expr("tbl"), ColumnMapper: colmapper, OrderByFields: Fields{f1}, LimitRows: FaultySQL{}, }, }, { description: "ReturningFields err", item: UpdateQuery{ Dialect: DialectPostgres, UpdateTable: Expr("tbl"), ColumnMapper: colmapper, ReturningFields: Fields{FaultySQL{}}, }, }} for _, tt := range errTests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() 
tt.assertErr(t, ErrFaultySQL) }) } } ================================================ FILE: window.go ================================================ package sq import ( "bytes" "context" "fmt" ) // NamedWindow represents an SQL named window. type NamedWindow struct { Name string Definition Window } var _ Window = (*NamedWindow)(nil) // WriteSQL implements the SQLWriter interface. func (w NamedWindow) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { buf.WriteString(w.Name) return nil } // IsWindow implements the Window interface. func (w NamedWindow) IsWindow() {} // WindowDefinition represents an SQL window definition. type WindowDefinition struct { BaseWindowName string PartitionByFields []Field OrderByFields []Field FrameSpec string FrameValues []any } var _ Window = (*WindowDefinition)(nil) // BaseWindow creates a new WindowDefinition based off an existing NamedWindow. func BaseWindow(w NamedWindow) WindowDefinition { return WindowDefinition{BaseWindowName: w.Name} } // PartitionBy returns a new WindowDefinition with the PARTITION BY clause. func PartitionBy(fields ...Field) WindowDefinition { return WindowDefinition{PartitionByFields: fields} } // PartitionBy returns a new WindowDefinition with the ORDER BY clause. func OrderBy(fields ...Field) WindowDefinition { return WindowDefinition{OrderByFields: fields} } // WriteSQL implements the SQLWriter interface. 
func (w WindowDefinition) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { var err error var written bool buf.WriteString("(") if w.BaseWindowName != "" { buf.WriteString(w.BaseWindowName + " ") } if len(w.PartitionByFields) > 0 { written = true buf.WriteString("PARTITION BY ") err = writeFields(ctx, dialect, buf, args, params, w.PartitionByFields, false) if err != nil { return fmt.Errorf("Window PARTITION BY: %w", err) } } if len(w.OrderByFields) > 0 { if written { buf.WriteString(" ") } written = true buf.WriteString("ORDER BY ") err = writeFields(ctx, dialect, buf, args, params, w.OrderByFields, false) if err != nil { return fmt.Errorf("Window ORDER BY: %w", err) } } if w.FrameSpec != "" { if written { buf.WriteString(" ") } written = true err = Writef(ctx, dialect, buf, args, params, w.FrameSpec, w.FrameValues) if err != nil { return fmt.Errorf("Window FRAME: %w", err) } } buf.WriteString(")") return nil } // PartitionBy returns a new WindowDefinition with the PARTITION BY clause. func (w WindowDefinition) PartitionBy(fields ...Field) WindowDefinition { w.PartitionByFields = fields return w } // OrderBy returns a new WindowDefinition with the ORDER BY clause. func (w WindowDefinition) OrderBy(fields ...Field) WindowDefinition { w.OrderByFields = fields return w } // Frame returns a new WindowDefinition with the frame specification set. func (w WindowDefinition) Frame(frameSpec string, frameValues ...any) WindowDefinition { w.FrameSpec = frameSpec w.FrameValues = frameValues return w } // IsWindow implements the Window interface. func (w WindowDefinition) IsWindow() {} // NamedWindows represents a slice of NamedWindows. type NamedWindows []NamedWindow // WriteSQL imeplements the SQLWriter interface. 
func (ws NamedWindows) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { var err error for i, window := range ws { if i > 0 { buf.WriteString(", ") } buf.WriteString(window.Name + " AS ") err = window.Definition.WriteSQL(ctx, dialect, buf, args, params) if err != nil { return fmt.Errorf("window #%d: %w", i+1, err) } } return nil } // CountOver represents the COUNT() OVER () window function. func CountOver(field Field, window Window) Expression { if window == nil { return Expr("COUNT({}) OVER ()", field) } return Expr("COUNT({}) OVER {}", field, window) } // CountStarOver represents the COUNT(*) OVER () window function. func CountStarOver(window Window) Expression { if window == nil { return Expr("COUNT(*) OVER ()") } return Expr("COUNT(*) OVER {}", window) } // SumOver represents the SUM() OVER () window function. func SumOver(num Number, window Window) Expression { if window == nil { return Expr("SUM({}) OVER ()", num) } return Expr("SUM({}) OVER {}", num, window) } // AvgOver represents the AVG() OVER () window function. func AvgOver(num Number, window Window) Expression { if window == nil { return Expr("AVG({}) OVER ()", num) } return Expr("AVG({}) OVER {}", num, window) } // MinOver represents the MIN() OVER () window function. func MinOver(field Field, window Window) Expression { if window == nil { return Expr("MIN({}) OVER ()", field) } return Expr("MIN({}) OVER {}", field, window) } // MaxOver represents the MAX() OVER () window function. func MaxOver(field Field, window Window) Expression { if window == nil { return Expr("MAX({}) OVER ()", field) } return Expr("MAX({}) OVER {}", field, window) } // RowNumberOver represents the ROW_NUMBER() OVER () window function. func RowNumberOver(window Window) Expression { if window == nil { return Expr("ROW_NUMBER() OVER ()") } return Expr("ROW_NUMBER() OVER {}", window) } // RankOver represents the RANK() OVER () window function. 
func RankOver(window Window) Expression { if window == nil { return Expr("RANK() OVER ()") } return Expr("RANK() OVER {}", window) } // DenseRankOver represents the DENSE_RANK() OVER () window function. func DenseRankOver(window Window) Expression { if window == nil { return Expr("DENSE_RANK() OVER ()") } return Expr("DENSE_RANK() OVER {}", window) } // CumeDistOver represents the CUME_DIST() OVER () window function. func CumeDistOver(window Window) Expression { if window == nil { return Expr("CUME_DIST() OVER ()") } return Expr("CUME_DIST() OVER {}", window) } // FirstValueOver represents the FIRST_VALUE() OVER () window function. func FirstValueOver(field Field, window Window) Expression { if window == nil { return Expr("FIRST_VALUE({}) OVER ()", field) } return Expr("FIRST_VALUE({}) OVER {}", field, window) } // LastValueOver represents the LAST_VALUE() OVER () window // function. func LastValueOver(field Field, window Window) Expression { if window == nil { return Expr("LAST_VALUE({}) OVER ()", field) } return Expr("LAST_VALUE({}) OVER {}", field, window) } ================================================ FILE: window_test.go ================================================ package sq import "testing" func TestWindow(t *testing.T) { t.Run("basic", func(t *testing.T) { t.Parallel() f1, f2, f3 := Expr("f1"), Expr("f2"), Expr("f3") TestTable{ item: PartitionBy(f1).OrderBy(f2, f3).Frame("RANGE UNBOUNDED PRECEDING"), wantQuery: "(PARTITION BY f1 ORDER BY f2, f3 RANGE UNBOUNDED PRECEDING)", }.assert(t) TestTable{ item: OrderBy(f1).PartitionBy(f2, f3).Frame("ROWS {} PRECEDING", 5), wantQuery: "(PARTITION BY f2, f3 ORDER BY f1 ROWS ? 
PRECEDING)", wantArgs: []any{5}, }.assert(t) }) errTests := []TestTable{{ description: "PartitionBy err", item: PartitionBy(FaultySQL{}), }, { description: "OrderBy err", item: OrderBy(FaultySQL{}), }, { description: "Frame err", item: OrderBy(Expr("f")).Frame("ROWS {} PRECEDING", FaultySQL{}), }, { description: "NamedWindows err", item: NamedWindows{{ Name: "w", Definition: OrderBy(Expr("f")).Frame("ROWS {} PRECEDING", FaultySQL{}), }}, }} for _, tt := range errTests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assertErr(t, ErrFaultySQL) }) } funcTests := []TestTable{{ description: "CountOver", item: CountOver(Expr("f1"), WindowDefinition{}), wantQuery: "COUNT(f1) OVER ()", }, { description: "CountOver nil", item: CountOver(Expr("f1"), nil), wantQuery: "COUNT(f1) OVER ()", }, { description: "CountStarOver", item: CountStarOver(WindowDefinition{}), wantQuery: "COUNT(*) OVER ()", }, { description: "SumOver", item: SumOver(Expr("f1"), PartitionBy(Expr("f2"))), wantQuery: "SUM(f1) OVER (PARTITION BY f2)", }, { description: "AvgOver", item: AvgOver(Expr("f1"), PartitionBy(Expr("f2"))), wantQuery: "AVG(f1) OVER (PARTITION BY f2)", }, { description: "MinOver", item: MinOver(Expr("f1"), PartitionBy(Expr("f2"))), wantQuery: "MIN(f1) OVER (PARTITION BY f2)", }, { description: "MaxOver", item: MaxOver(Expr("f1"), PartitionBy(Expr("f2"))), wantQuery: "MAX(f1) OVER (PARTITION BY f2)", }, { description: "RowNumberOver", item: RowNumberOver(PartitionBy(Expr("f1"))), wantQuery: "ROW_NUMBER() OVER (PARTITION BY f1)", }, { description: "RankOver", item: RankOver(PartitionBy(Expr("f1"))), wantQuery: "RANK() OVER (PARTITION BY f1)", }, { description: "DenseRankOver", item: DenseRankOver(PartitionBy(Expr("f1"))), wantQuery: "DENSE_RANK() OVER (PARTITION BY f1)", }, { description: "CumeDistOver", item: CumeDistOver(PartitionBy(Expr("f1"))), wantQuery: "CUME_DIST() OVER (PARTITION BY f1)", }, { description: "FirstValueOver", item: FirstValueOver(Expr("f1"), 
PartitionBy(Expr("f2"))), wantQuery: "FIRST_VALUE(f1) OVER (PARTITION BY f2)", }, { description: "LastValueOver", item: LastValueOver(Expr("f1"), PartitionBy(Expr("f2"))), wantQuery: "LAST_VALUE(f1) OVER (PARTITION BY f2)", }, { description: "NamedWindow", item: CountStarOver(NamedWindow{Name: "w"}), wantQuery: "COUNT(*) OVER w", }, func() TestTable { var tt TestTable tt.description = "BaseWindow" w := NamedWindow{Name: "w", Definition: PartitionBy(Expr("f1"))} tt.item = CountStarOver(BaseWindow(w).Frame("ROWS UNBOUNDED PRECEDING")) tt.wantQuery = "COUNT(*) OVER (w ROWS UNBOUNDED PRECEDING)" return tt }(), func() TestTable { var tt TestTable tt.description = "NamedWindows" w1 := NamedWindow{Name: "w1", Definition: PartitionBy(Expr("f1"))} w2 := NamedWindow{Name: "w2", Definition: OrderBy(Expr("f2"))} w3 := NamedWindow{Name: "w3", Definition: OrderBy(Expr("f3")).Frame("ROWS UNBOUNDED PRECEDING")} tt.item = NamedWindows{w1, w2, w3} tt.wantQuery = "w1 AS (PARTITION BY f1)" + ", w2 AS (ORDER BY f2)" + ", w3 AS (ORDER BY f3 ROWS UNBOUNDED PRECEDING)" return tt }()} for _, tt := range funcTests { tt := tt t.Run(tt.description, func(t *testing.T) { t.Parallel() tt.assert(t) }) } }