Repository: drone/sqlgen Branch: master Commit: 39ee509bb576 Files: 30 Total size: 58.2 KB Directory structure: gitextract_ds8o5qgz/ ├── .gitignore ├── LICENSE ├── README.md ├── bench/ │ ├── type.go │ ├── type_sql.go │ └── type_test.go ├── demo/ │ ├── hook.go │ ├── hook_sql.go │ ├── issue.go │ ├── issue_sql.go │ ├── user.go │ └── user_sql.go ├── fmt.go ├── gen.go ├── gen_funcs.go ├── gen_schema.go ├── parse/ │ ├── const.go │ ├── node.go │ ├── parse.go │ ├── tag.go │ └── tag_test.go ├── schema/ │ ├── base.go │ ├── dialect.go │ ├── dialect_mysql.go │ ├── dialect_postgres.go │ ├── dialect_sqlite.go │ ├── helper.go │ └── schema.go ├── tmpl.go └── util.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ sqlgen *.sqlite *.txt _docs ================================================ FILE: LICENSE ================================================ Copyright (c) 2015, drone.io All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: README.md ================================================ **sqlgen** generates SQL statements and database helper functions from your Go structs. It can be used in place of a simple ORM or hand-written SQL. See the [demo](https://github.com/drone/sqlgen/tree/master/demo) directory for examples. ### Install Install or upgrade with this command: ``` go get -u github.com/drone/sqlgen ``` ### Usage ``` Usage of sqlgen: -type string type to generate; required -file string input file name; required -o string output file name -pkg string output package name -db string sql dialect; sqlite, postgres, mysql -schema generate sql schema and queries; default true -funcs generate sql helper functions; default true ``` ### Tutorial First, let's start with a simple `User` struct in `user.go`: ```Go type User struct { ID int64 Login string Email string } ``` We can run the following command: ``` sqlgen -file user.go -type User -pkg demo ``` The tool outputs the following generated code: ```Go func ScanUser(row *sql.Row) (*User, error) { var v0 int64 var v1 string var v2 string err := row.Scan( &v0, &v1, &v2, ) if err != nil { return nil, err } v := &User{} v.ID = v0 v.Login = v1 v.Email = v2 return v, nil } const CreateUserStmt = ` CREATE TABLE IF NOT EXISTS users ( user_id INTEGER ,user_login TEXT ,user_email TEXT ); ` const SelectUserStmt = ` SELECT user_id ,user_login ,user_email FROM users ` const SelectUserRangeStmt 
= ` SELECT user_id ,user_login ,user_email FROM users LIMIT ? OFFSET ? ` // more functions and sql statements not displayed ``` This is a great start, but what if we want to specify primary keys, column sizes and more? This may be achieved by annotating your code using Go tags. For example, we can tag the `ID` field to indicate that it is an auto-incrementing primary key: ```diff type User struct { - ID int64 + ID int64 `sql:"pk: true, auto: true"` Login string Email string } ``` This information allows the tool to generate smarter SQL statements: ```diff CREATE TABLE IF NOT EXISTS users ( -user_id INTEGER +user_id INTEGER PRIMARY KEY AUTOINCREMENT ,user_login TEXT ,user_email TEXT ); ``` The tool also generates SQL statements to select, update and delete rows by primary key: ```Go const SelectUserPkeyStmt = ` SELECT user_id ,user_login ,user_email FROM users WHERE user_id=? ` const UpdateUserPkeyStmt = ` UPDATE users SET user_id=? ,user_login=? ,user_email=? WHERE user_id=? ` const DeleteUserPkeyStmt = ` DELETE FROM users WHERE user_id=? ` ``` We can take this one step further and annotate indexes. In our example, we probably want to make sure the `user_login` field has a unique index: ```diff type User struct { ID int64 `sql:"pk: true, auto: true"` - Login string + Login string `sql:"unique: user_login"` Email string } ``` This information instructs the tool to generate the following: ```Go const CreateUserLoginStmt = ` CREATE UNIQUE INDEX IF NOT EXISTS user_login ON users (user_login) ` ``` The tool also assumes that we probably intend to fetch data from the database using this index. The tool will therefore automatically generate the following queries: ```Go const SelectUserLoginStmt = ` SELECT user_id ,user_login ,user_email FROM users WHERE user_login=? ` const UpdateUserLoginStmt = ` UPDATE users SET user_id=? ,user_login=? ,user_email=? WHERE user_login=? ` const DeleteUserLoginStmt = ` DELETE FROM users WHERE user_login=?
` ``` ### Nesting Nested Go structures can be flattened into a single database table. As an example, we have a `User` and `Address` with a one-to-one relationship. In some cases, we may prefer to denormalize our data and store it in a single table, avoiding unnecessary joins. ```diff type User struct { ID int64 `sql:"pk: true"` Login string Email string + Addr *Address } type Address struct { City string State string Zip string `sql:"index: user_zip"` } ``` The above relationship is flattened into a single table (see below). When the data is retrieved from the database, the nested structure is restored. ```sql CREATE TABLE IF NOT EXISTS users ( user_id INTEGER PRIMARY KEY AUTO_INCREMENT ,user_login TEXT ,user_email TEXT ,user_addr_city TEXT ,user_addr_state TEXT ,user_addr_zip TEXT ); ``` ### JSON Encoding Some types in your struct, such as `[]string`, may not have native equivalents in your database. These values can be marshaled and stored as JSON in the database. ```diff type User struct { ID int64 `sql:"pk: true"` Login string Email string + Label []string `sql:"encode: json"` } ``` ### Dialects You may specify one of the following SQL dialects when generating your code: `postgres`, `mysql` and `sqlite`. The default value is `sqlite`. ``` sqlgen -file user.go -type User -pkg demo -db postgres ``` ### Go Generate Example use with `go:generate`: ```Go package demo //go:generate sqlgen -file user.go -type User -pkg demo -o user_sql.go type User struct { ID int64 `sql:"pk: true, auto: true"` Login string `sql:"unique: user_login"` Email string `sql:"size: 1024"` Avatar string } ``` ### Benchmarks This tool demonstrates performance gains, albeit small, over lightweight ORM packages such as `sqlx` and `meddler`. Over time I plan to expand the benchmarks to include additional ORM packages. To run the project benchmarks: ``` go get ./... go generate ./...
go build cd bench go test -bench=Bench ``` Example selecting a single row: ``` BenchmarkMeddlerRow-4 30000 42773 ns/op BenchmarkSqlxRow-4 30000 41554 ns/op BenchmarkSqlgenRow-4 50000 39664 ns/op ``` Selecting multiple rows: ``` BenchmarkMeddlerRows-4 2000 1025218 ns/op BenchmarkSqlxRows-4 2000 807213 ns/op BenchmarkSqlgenRows-4 2000 700673 ns/op ``` #### Credits This tool was inspired by [scaneo](https://github.com/variadico/scaneo). ================================================ FILE: bench/type.go ================================================ package bench //go:generate ../sqlgen -file type.go -type User -pkg bench -o type_sql.go type User struct { ID int64 `sql:"pk: true, auto: true" meddler:"user_id,pk" db:"user_id"` Name string `sql:"unique: user_name" meddler:"user_name" db:"user_name"` Pass string `sql:"" meddler:"user_pass" db:"user_pass"` Email string `sql:"unique: user_email" meddler:"user_email" db:"user_email"` Active bool `sql:"index: user_active" meddler:"user_active" db:"user_active"` Created int64 `sql:"" meddler:"user_created" db:"user_created"` Updated int64 `sql:"" meddler:"user_updated" db:"user_updated"` } ================================================ FILE: bench/type_sql.go ================================================ package bench // THIS FILE WAS AUTO-GENERATED. DO NOT MODIFY.
import ( "database/sql" ) func ScanUser(row *sql.Row) (*User, error) { var v0 int64 var v1 string var v2 string var v3 string var v4 bool var v5 int64 var v6 int64 err := row.Scan( &v0, &v1, &v2, &v3, &v4, &v5, &v6, ) if err != nil { return nil, err } v := &User{} v.ID = v0 v.Name = v1 v.Pass = v2 v.Email = v3 v.Active = v4 v.Created = v5 v.Updated = v6 return v, nil } func ScanUsers(rows *sql.Rows) ([]*User, error) { var err error var vv []*User var v0 int64 var v1 string var v2 string var v3 string var v4 bool var v5 int64 var v6 int64 for rows.Next() { err = rows.Scan( &v0, &v1, &v2, &v3, &v4, &v5, &v6, ) if err != nil { return vv, err } v := &User{} v.ID = v0 v.Name = v1 v.Pass = v2 v.Email = v3 v.Active = v4 v.Created = v5 v.Updated = v6 vv = append(vv, v) } return vv, rows.Err() } func SliceUser(v *User) []interface{} { var v0 int64 var v1 string var v2 string var v3 string var v4 bool var v5 int64 var v6 int64 v0 = v.ID v1 = v.Name v2 = v.Pass v3 = v.Email v4 = v.Active v5 = v.Created v6 = v.Updated return []interface{}{ v0, v1, v2, v3, v4, v5, v6, } } func SelectUser(db *sql.DB, query string, args ...interface{}) (*User, error) { row := db.QueryRow(query, args...) return ScanUser(row) } func SelectUsers(db *sql.DB, query string, args ...interface{}) ([]*User, error) { rows, err := db.Query(query, args...) if err != nil { return nil, err } defer rows.Close() return ScanUsers(rows) } func InsertUser(db *sql.DB, query string, v *User) error { res, err := db.Exec(query, SliceUser(v)[1:]...) if err != nil { return err } v.ID, err = res.LastInsertId() return err } func UpdateUser(db *sql.DB, query string, v *User) error { args := SliceUser(v)[1:] args = append(args, v.ID) _, err := db.Exec(query, args...) 
return err } const CreateUserStmt = ` CREATE TABLE IF NOT EXISTS users ( user_id INTEGER PRIMARY KEY AUTOINCREMENT ,user_name TEXT ,user_pass TEXT ,user_email TEXT ,user_active BOOLEAN ,user_created INTEGER ,user_updated INTEGER ); ` const InsertUserStmt = ` INSERT INTO users ( user_name ,user_pass ,user_email ,user_active ,user_created ,user_updated ) VALUES (?,?,?,?,?,?) ` const SelectUserStmt = ` SELECT user_id ,user_name ,user_pass ,user_email ,user_active ,user_created ,user_updated FROM users ` const SelectUserRangeStmt = ` SELECT user_id ,user_name ,user_pass ,user_email ,user_active ,user_created ,user_updated FROM users LIMIT ? OFFSET ? ` const SelectUserCountStmt = ` SELECT count(1) FROM users ` const SelectUserPkeyStmt = ` SELECT user_id ,user_name ,user_pass ,user_email ,user_active ,user_created ,user_updated FROM users WHERE user_id=? ` const UpdateUserPkeyStmt = ` UPDATE users SET user_id=? ,user_name=? ,user_pass=? ,user_email=? ,user_active=? ,user_created=? ,user_updated=? WHERE user_id=? ` const DeleteUserPkeyStmt = ` DELETE FROM users WHERE user_id=? ` const CreateUserNameStmt = ` CREATE UNIQUE INDEX IF NOT EXISTS user_name ON users (user_name) ` const SelectUserNameStmt = ` SELECT user_id ,user_name ,user_pass ,user_email ,user_active ,user_created ,user_updated FROM users WHERE user_name=? ` const UpdateUserNameStmt = ` UPDATE users SET user_id=? ,user_name=? ,user_pass=? ,user_email=? ,user_active=? ,user_created=? ,user_updated=? WHERE user_name=? ` const DeleteUserNameStmt = ` DELETE FROM users WHERE user_name=? ` const CreateUserEmailStmt = ` CREATE UNIQUE INDEX IF NOT EXISTS user_email ON users (user_email) ` const SelectUserEmailStmt = ` SELECT user_id ,user_name ,user_pass ,user_email ,user_active ,user_created ,user_updated FROM users WHERE user_email=? ` const UpdateUserEmailStmt = ` UPDATE users SET user_id=? ,user_name=? ,user_pass=? ,user_email=? ,user_active=? ,user_created=? ,user_updated=? WHERE user_email=? 
` const DeleteUserEmailStmt = ` DELETE FROM users WHERE user_email=? ` ================================================ FILE: bench/type_test.go ================================================ package bench import ( "database/sql" "testing" "time" "github.com/Pallinder/go-randomdata" "github.com/jmoiron/sqlx" _ "github.com/mattn/go-sqlite3" "github.com/russross/meddler" ) var db *sql.DB var dbx *sqlx.DB func init() { var err error db, err = sql.Open("sqlite3", ":memory:") if err != nil { panic(err) } db.Exec("DROP TABLE users;") dbx = sqlx.NewDb(db, "sqlite3") ddl := []string{CreateUserStmt} for _, stmt := range ddl { _, err = db.Exec(stmt) if err != nil { panic(err) } } for i := 0; i < 100; i++ { user := &User{} user.Name = randomdata.FullName(randomdata.RandomGender) user.Email = randomdata.Email() user.Pass = "pa55word" user.Created = time.Now().Unix() user.Updated = time.Now().Unix() err := InsertUser(db, InsertUserStmt, user) if err != nil { panic(err) } } } var result *User var results []*User func BenchmarkMeddlerRow(b *testing.B) { var user *User var err error for n := 0; n < b.N; n++ { user = &User{} err = meddler.QueryRow(db, user, SelectUserPkeyStmt, 1) if err != nil { panic(err) } } result = user } func BenchmarkMeddlerRows(b *testing.B) { var users []*User var err error for n := 0; n < b.N; n++ { err = meddler.QueryAll(db, &users, SelectUserStmt) if err != nil { panic(err) } } results = users } func BenchmarkSqlxRow(b *testing.B) { var user *User var err error for n := 0; n < b.N; n++ { user = &User{} err = dbx.Get(user, SelectUserPkeyStmt, 1) if err != nil { panic(err) } } result = user } func BenchmarkSqlxRows(b *testing.B) { var users []*User var err error for n := 0; n < b.N; n++ { err = dbx.Select(&users, SelectUserStmt) if err != nil { panic(err) } } results = users } func BenchmarkSqlgenRow(b *testing.B) { var user *User var err error for n := 0; n < b.N; n++ { user, err = SelectUser(db, SelectUserPkeyStmt, 1) if err != nil { panic(err) } } 
result = user } func BenchmarkSqlgenRows(b *testing.B) { var users []*User var err error for n := 0; n < b.N; n++ { users, err = SelectUsers(db, SelectUserStmt) if err != nil { panic(err) } } results = users } ================================================ FILE: demo/hook.go ================================================ package demo //go:generate ../sqlgen -file hook.go -type Hook -pkg demo -o hook_sql.go -db mysql type Hook struct { ID int64 `sql:"pk: true, auto: true"` Sha string After string Before string Created bool Deleted bool Forced bool HeadCommit *Commit `sql:"name: head"` } type Commit struct { ID string Message string Timestamp string Author *Author Committer *Author } type Author struct { Name string Email string Username string } ================================================ FILE: demo/hook_sql.go ================================================ package demo // THIS FILE WAS AUTO-GENERATED. DO NOT MODIFY. import ( "database/sql" ) func ScanHook(row *sql.Row) (*Hook, error) { var v0 int64 var v1 string var v2 string var v3 string var v4 bool var v5 bool var v6 bool var v7 string var v8 string var v9 string var v10 string var v11 string var v12 string var v13 string var v14 string var v15 string err := row.Scan( &v0, &v1, &v2, &v3, &v4, &v5, &v6, &v7, &v8, &v9, &v10, &v11, &v12, &v13, &v14, &v15, ) if err != nil { return nil, err } v := &Hook{} v.ID = v0 v.Sha = v1 v.After = v2 v.Before = v3 v.Created = v4 v.Deleted = v5 v.Forced = v6 v.HeadCommit = &Commit{} v.HeadCommit.ID = v7 v.HeadCommit.Message = v8 v.HeadCommit.Timestamp = v9 v.HeadCommit.Author = &Author{} v.HeadCommit.Author.Name = v10 v.HeadCommit.Author.Email = v11 v.HeadCommit.Author.Username = v12 v.HeadCommit.Committer = &Author{} v.HeadCommit.Committer.Name = v13 v.HeadCommit.Committer.Email = v14 v.HeadCommit.Committer.Username = v15 return v, nil } func ScanHooks(rows *sql.Rows) ([]*Hook, error) { var err error var vv []*Hook var v0 int64 var v1 string var v2 string var v3 
string var v4 bool var v5 bool var v6 bool var v7 string var v8 string var v9 string var v10 string var v11 string var v12 string var v13 string var v14 string var v15 string for rows.Next() { err = rows.Scan( &v0, &v1, &v2, &v3, &v4, &v5, &v6, &v7, &v8, &v9, &v10, &v11, &v12, &v13, &v14, &v15, ) if err != nil { return vv, err } v := &Hook{} v.ID = v0 v.Sha = v1 v.After = v2 v.Before = v3 v.Created = v4 v.Deleted = v5 v.Forced = v6 v.HeadCommit = &Commit{} v.HeadCommit.ID = v7 v.HeadCommit.Message = v8 v.HeadCommit.Timestamp = v9 v.HeadCommit.Author = &Author{} v.HeadCommit.Author.Name = v10 v.HeadCommit.Author.Email = v11 v.HeadCommit.Author.Username = v12 v.HeadCommit.Committer = &Author{} v.HeadCommit.Committer.Name = v13 v.HeadCommit.Committer.Email = v14 v.HeadCommit.Committer.Username = v15 vv = append(vv, v) } return vv, rows.Err() } func SliceHook(v *Hook) []interface{} { var v0 int64 var v1 string var v2 string var v3 string var v4 bool var v5 bool var v6 bool var v7 string var v8 string var v9 string var v10 string var v11 string var v12 string var v13 string var v14 string var v15 string v0 = v.ID v1 = v.Sha v2 = v.After v3 = v.Before v4 = v.Created v5 = v.Deleted v6 = v.Forced if v.HeadCommit != nil { v7 = v.HeadCommit.ID v8 = v.HeadCommit.Message v9 = v.HeadCommit.Timestamp if v.HeadCommit.Author != nil { v10 = v.HeadCommit.Author.Name v11 = v.HeadCommit.Author.Email v12 = v.HeadCommit.Author.Username } } if v.HeadCommit.Committer != nil { v13 = v.HeadCommit.Committer.Name v14 = v.HeadCommit.Committer.Email v15 = v.HeadCommit.Committer.Username } return []interface{}{ v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, } } func SelectHook(db *sql.DB, query string, args ...interface{}) (*Hook, error) { row := db.QueryRow(query, args...) return ScanHook(row) } func SelectHooks(db *sql.DB, query string, args ...interface{}) ([]*Hook, error) { rows, err := db.Query(query, args...) 
if err != nil { return nil, err } defer rows.Close() return ScanHooks(rows) } func InsertHook(db *sql.DB, query string, v *Hook) error { res, err := db.Exec(query, SliceHook(v)[1:]...) if err != nil { return err } v.ID, err = res.LastInsertId() return err } func UpdateHook(db *sql.DB, query string, v *Hook) error { args := SliceHook(v)[1:] args = append(args, v.ID) _, err := db.Exec(query, args...) return err } const CreateHookStmt = ` CREATE TABLE IF NOT EXISTS hooks ( hook_id INTEGER PRIMARY KEY AUTO_INCREMENT ,hook_sha VARCHAR(512) ,hook_after VARCHAR(512) ,hook_before VARCHAR(512) ,hook_created BOOLEAN ,hook_deleted BOOLEAN ,hook_forced BOOLEAN ,hook_head_id VARCHAR(512) ,hook_head_message VARCHAR(512) ,hook_head_timestamp VARCHAR(512) ,hook_head_author_name VARCHAR(512) ,hook_head_author_email VARCHAR(512) ,hook_head_author_username VARCHAR(512) ,hook_head_committer_name VARCHAR(512) ,hook_head_committer_email VARCHAR(512) ,hook_head_committer_username VARCHAR(512) ); ` const InsertHookStmt = ` INSERT INTO hooks ( hook_sha ,hook_after ,hook_before ,hook_created ,hook_deleted ,hook_forced ,hook_head_id ,hook_head_message ,hook_head_timestamp ,hook_head_author_name ,hook_head_author_email ,hook_head_author_username ,hook_head_committer_name ,hook_head_committer_email ,hook_head_committer_username ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) 
` const SelectHookStmt = ` SELECT hook_id ,hook_sha ,hook_after ,hook_before ,hook_created ,hook_deleted ,hook_forced ,hook_head_id ,hook_head_message ,hook_head_timestamp ,hook_head_author_name ,hook_head_author_email ,hook_head_author_username ,hook_head_committer_name ,hook_head_committer_email ,hook_head_committer_username FROM hooks ` const SelectHookRangeStmt = ` SELECT hook_id ,hook_sha ,hook_after ,hook_before ,hook_created ,hook_deleted ,hook_forced ,hook_head_id ,hook_head_message ,hook_head_timestamp ,hook_head_author_name ,hook_head_author_email ,hook_head_author_username ,hook_head_committer_name ,hook_head_committer_email ,hook_head_committer_username FROM hooks LIMIT ? OFFSET ? ` const SelectHookCountStmt = ` SELECT count(1) FROM hooks ` const SelectHookPkeyStmt = ` SELECT hook_id ,hook_sha ,hook_after ,hook_before ,hook_created ,hook_deleted ,hook_forced ,hook_head_id ,hook_head_message ,hook_head_timestamp ,hook_head_author_name ,hook_head_author_email ,hook_head_author_username ,hook_head_committer_name ,hook_head_committer_email ,hook_head_committer_username FROM hooks WHERE hook_id=? ` const UpdateHookPkeyStmt = ` UPDATE hooks SET hook_id=? ,hook_sha=? ,hook_after=? ,hook_before=? ,hook_created=? ,hook_deleted=? ,hook_forced=? ,hook_head_id=? ,hook_head_message=? ,hook_head_timestamp=? ,hook_head_author_name=? ,hook_head_author_email=? ,hook_head_author_username=? ,hook_head_committer_name=? ,hook_head_committer_email=? ,hook_head_committer_username=? WHERE hook_id=? ` const DeleteHookPkeyStmt = ` DELETE FROM hooks WHERE hook_id=? 
` ================================================ FILE: demo/issue.go ================================================ package demo //go:generate ../sqlgen -file issue.go -type Issue -pkg demo -o issue_sql.go -db postgres type Issue struct { ID int64 `sql:"pk: true, auto: true"` Number int Title string `sql:"size: 512"` Body string `sql:"size: 2048"` Assignee string `sql:"index: issue_assignee"` State string `sql:"size: 50"` Labels []string `sql:"encode: json"` locked bool `sql:"-"` } ================================================ FILE: demo/issue_sql.go ================================================ package demo // THIS FILE WAS AUTO-GENERATED. DO NOT MODIFY. import ( "database/sql" "encoding/json" ) func ScanIssue(row *sql.Row) (*Issue, error) { var v0 int64 var v1 int var v2 string var v3 string var v4 string var v5 string var v6 []byte err := row.Scan( &v0, &v1, &v2, &v3, &v4, &v5, &v6, ) if err != nil { return nil, err } v := &Issue{} v.ID = v0 v.Number = v1 v.Title = v2 v.Body = v3 v.Assignee = v4 v.State = v5 json.Unmarshal(v6, &v.Labels) return v, nil } func ScanIssues(rows *sql.Rows) ([]*Issue, error) { var err error var vv []*Issue var v0 int64 var v1 int var v2 string var v3 string var v4 string var v5 string var v6 []byte for rows.Next() { err = rows.Scan( &v0, &v1, &v2, &v3, &v4, &v5, &v6, ) if err != nil { return vv, err } v := &Issue{} v.ID = v0 v.Number = v1 v.Title = v2 v.Body = v3 v.Assignee = v4 v.State = v5 json.Unmarshal(v6, &v.Labels) vv = append(vv, v) } return vv, rows.Err() } func SliceIssue(v *Issue) []interface{} { var v0 int64 var v1 int var v2 string var v3 string var v4 string var v5 string var v6 []byte v0 = v.ID v1 = v.Number v2 = v.Title v3 = v.Body v4 = v.Assignee v5 = v.State v6, _ = json.Marshal(&v.Labels) return []interface{}{ v0, v1, v2, v3, v4, v5, v6, } } func SelectIssue(db *sql.DB, query string, args ...interface{}) (*Issue, error) { row := db.QueryRow(query, args...) 
return ScanIssue(row) } func SelectIssues(db *sql.DB, query string, args ...interface{}) ([]*Issue, error) { rows, err := db.Query(query, args...) if err != nil { return nil, err } defer rows.Close() return ScanIssues(rows) } func InsertIssue(db *sql.DB, query string, v *Issue) error { res, err := db.Exec(query, SliceIssue(v)[1:]...) if err != nil { return err } v.ID, err = res.LastInsertId() return err } func UpdateIssue(db *sql.DB, query string, v *Issue) error { args := SliceIssue(v)[1:] args = append(args, v.ID) _, err := db.Exec(query, args...) return err } const CreateIssueStmt = ` CREATE TABLE IF NOT EXISTS issues ( issue_id SERIAL PRIMARY KEY ,issue_number INTEGER ,issue_title VARCHAR(512) ,issue_body VARCHAR(2048) ,issue_assignee VARCHAR(512) ,issue_state VARCHAR(50) ,issue_labels BYTEA ); ` const InsertIssueStmt = ` INSERT INTO issues ( issue_number ,issue_title ,issue_body ,issue_assignee ,issue_state ,issue_labels ) VALUES ($1,$2,$3,$4,$5,$6) ` const SelectIssueStmt = ` SELECT issue_id ,issue_number ,issue_title ,issue_body ,issue_assignee ,issue_state ,issue_labels FROM issues ` const SelectIssueRangeStmt = ` SELECT issue_id ,issue_number ,issue_title ,issue_body ,issue_assignee ,issue_state ,issue_labels FROM issues LIMIT $1 OFFSET $2 ` const SelectIssueCountStmt = ` SELECT count(1) FROM issues ` const SelectIssuePkeyStmt = ` SELECT issue_id ,issue_number ,issue_title ,issue_body ,issue_assignee ,issue_state ,issue_labels FROM issues WHERE issue_id=$1 ` const UpdateIssuePkeyStmt = ` UPDATE issues SET issue_id=$1 ,issue_number=$2 ,issue_title=$3 ,issue_body=$4 ,issue_assignee=$5 ,issue_state=$6 ,issue_labels=$7 WHERE issue_id=$8 ` const DeleteIssuePkeyStmt = ` DELETE FROM issues WHERE issue_id=$1 ` ================================================ FILE: demo/user.go ================================================ package demo //go:generate ../sqlgen -file user.go -type User -pkg demo -o user_sql.go type User struct { ID int64 `sql:"pk: true, auto: 
true"` Login string `sql:"unique: user_login"` Email string `sql:"unique: user_email"` Avatar string Active bool Admin bool // oauth token and secret token string secret string // randomly generated hash used to sign user // session and application tokens. hash string } ================================================ FILE: demo/user_sql.go ================================================ package demo // THIS FILE WAS AUTO-GENERATED. DO NOT MODIFY. import ( "database/sql" ) func ScanUser(row *sql.Row) (*User, error) { var v0 int64 var v1 string var v2 string var v3 string var v4 bool var v5 bool var v6 string var v7 string var v8 string err := row.Scan( &v0, &v1, &v2, &v3, &v4, &v5, &v6, &v7, &v8, ) if err != nil { return nil, err } v := &User{} v.ID = v0 v.Login = v1 v.Email = v2 v.Avatar = v3 v.Active = v4 v.Admin = v5 v.token = v6 v.secret = v7 v.hash = v8 return v, nil } func ScanUsers(rows *sql.Rows) ([]*User, error) { var err error var vv []*User var v0 int64 var v1 string var v2 string var v3 string var v4 bool var v5 bool var v6 string var v7 string var v8 string for rows.Next() { err = rows.Scan( &v0, &v1, &v2, &v3, &v4, &v5, &v6, &v7, &v8, ) if err != nil { return vv, err } v := &User{} v.ID = v0 v.Login = v1 v.Email = v2 v.Avatar = v3 v.Active = v4 v.Admin = v5 v.token = v6 v.secret = v7 v.hash = v8 vv = append(vv, v) } return vv, rows.Err() } func SliceUser(v *User) []interface{} { var v0 int64 var v1 string var v2 string var v3 string var v4 bool var v5 bool var v6 string var v7 string var v8 string v0 = v.ID v1 = v.Login v2 = v.Email v3 = v.Avatar v4 = v.Active v5 = v.Admin v6 = v.token v7 = v.secret v8 = v.hash return []interface{}{ v0, v1, v2, v3, v4, v5, v6, v7, v8, } } func SelectUser(db *sql.DB, query string, args ...interface{}) (*User, error) { row := db.QueryRow(query, args...) return ScanUser(row) } func SelectUsers(db *sql.DB, query string, args ...interface{}) ([]*User, error) { rows, err := db.Query(query, args...) 
if err != nil { return nil, err } defer rows.Close() return ScanUsers(rows) } func InsertUser(db *sql.DB, query string, v *User) error { res, err := db.Exec(query, SliceUser(v)[1:]...) if err != nil { return err } v.ID, err = res.LastInsertId() return err } func UpdateUser(db *sql.DB, query string, v *User) error { args := SliceUser(v)[1:] args = append(args, v.ID) _, err := db.Exec(query, args...) return err } const CreateUserStmt = ` CREATE TABLE IF NOT EXISTS users ( user_id INTEGER PRIMARY KEY AUTOINCREMENT ,user_login TEXT ,user_email TEXT ,user_avatar TEXT ,user_active BOOLEAN ,user_admin BOOLEAN ,user_token TEXT ,user_secret TEXT ,user_hash TEXT ); ` const InsertUserStmt = ` INSERT INTO users ( user_login ,user_email ,user_avatar ,user_active ,user_admin ,user_token ,user_secret ,user_hash ) VALUES (?,?,?,?,?,?,?,?) ` const SelectUserStmt = ` SELECT user_id ,user_login ,user_email ,user_avatar ,user_active ,user_admin ,user_token ,user_secret ,user_hash FROM users ` const SelectUserRangeStmt = ` SELECT user_id ,user_login ,user_email ,user_avatar ,user_active ,user_admin ,user_token ,user_secret ,user_hash FROM users LIMIT ? OFFSET ? ` const SelectUserCountStmt = ` SELECT count(1) FROM users ` const SelectUserPkeyStmt = ` SELECT user_id ,user_login ,user_email ,user_avatar ,user_active ,user_admin ,user_token ,user_secret ,user_hash FROM users WHERE user_id=? ` const UpdateUserPkeyStmt = ` UPDATE users SET user_id=? ,user_login=? ,user_email=? ,user_avatar=? ,user_active=? ,user_admin=? ,user_token=? ,user_secret=? ,user_hash=? WHERE user_id=? ` const DeleteUserPkeyStmt = ` DELETE FROM users WHERE user_id=? ` const CreateUserLoginStmt = ` CREATE UNIQUE INDEX IF NOT EXISTS user_login ON users (user_login) ` const SelectUserLoginStmt = ` SELECT user_id ,user_login ,user_email ,user_avatar ,user_active ,user_admin ,user_token ,user_secret ,user_hash FROM users WHERE user_login=? ` const UpdateUserLoginStmt = ` UPDATE users SET user_id=? ,user_login=? 
,user_email=? ,user_avatar=? ,user_active=? ,user_admin=? ,user_token=? ,user_secret=? ,user_hash=? WHERE user_login=? ` const DeleteUserLoginStmt = ` DELETE FROM users WHERE user_login=? ` const CreateUserEmailStmt = ` CREATE UNIQUE INDEX IF NOT EXISTS user_email ON users (user_email) ` const SelectUserEmailStmt = ` SELECT user_id ,user_login ,user_email ,user_avatar ,user_active ,user_admin ,user_token ,user_secret ,user_hash FROM users WHERE user_email=? ` const UpdateUserEmailStmt = ` UPDATE users SET user_id=? ,user_login=? ,user_email=? ,user_avatar=? ,user_active=? ,user_admin=? ,user_token=? ,user_secret=? ,user_hash=? WHERE user_email=? ` const DeleteUserEmailStmt = ` DELETE FROM users WHERE user_email=? ` ================================================ FILE: fmt.go ================================================ package main import ( "bytes" "io" "os" "os/exec" ) // format formats a template using gofmt. func format(in io.Reader) (io.Reader, error) { var out bytes.Buffer gofmt := exec.Command("gofmt", "-s") gofmt.Stdin = in gofmt.Stdout = &out gofmt.Stderr = os.Stderr err := gofmt.Run() return &out, err } ================================================ FILE: gen.go ================================================ package main import ( "bytes" "flag" "fmt" "io" "os" "github.com/drone/sqlgen/parse" "github.com/drone/sqlgen/schema" ) var ( input = flag.String("file", "", "input file name; required") output = flag.String("o", "", "output file name; required") pkgName = flag.String("pkg", "main", "output package name; required") typeName = flag.String("type", "", "type to generate; required") database = flag.String("db", "sqlite", "sql dialect; required") genSchema = flag.Bool("schema", true, "generate sql schema and queries") genFuncs = flag.Bool("funcs", true, "generate sql helper functions") extraFuncs = flag.Bool("extras", true, "generate extra sql helper functions") ) func main() { flag.Parse() // parses the syntax tree into something a bit // easier 
to work with. tree, err := parse.Parse(*input, *typeName) if err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) os.Exit(1) } // if the code is generated in a different folder // that the struct we need to import the struct if tree.Pkg != *pkgName && *pkgName != "main" { // TODO } // load the Tree into a schema Object table := schema.Load(tree) dialect := schema.New(schema.Dialects[*database]) var buf bytes.Buffer if *genFuncs { writePackage(&buf, *pkgName) writeImports(&buf, tree, "database/sql") writeRowFunc(&buf, tree) writeRowsFunc(&buf, tree) writeSliceFunc(&buf, tree) if *extraFuncs { writeSelectRow(&buf, tree) writeSelectRows(&buf, tree) writeInsertFunc(&buf, tree) writeUpdateFunc(&buf, tree) } } else { writePackage(&buf, *pkgName) } // write the sql functions if *genSchema { writeSchema(&buf, dialect, table) } // formats the generated file using gofmt pretty, err := format(&buf) if err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) return } // create output source for file. defaults to // stdout but may be file. 
var out io.WriteCloser = os.Stdout if *output != "" { out, err = os.Create(*output) if err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) return } defer out.Close() } io.Copy(out, pretty) } ================================================ FILE: gen_funcs.go ================================================ package main import ( "bytes" "fmt" "io" "strings" "github.com/acsellers/inflections" "github.com/drone/sqlgen/parse" ) func writeImports(w io.Writer, tree *parse.Node, pkgs ...string) { var pmap = map[string]struct{}{} // add default packages for _, pkg := range pkgs { pmap[pkg] = struct{}{} } // check each edge node to see if it is // encoded, which might require us to import // other packages for _, node := range tree.Edges() { if node.Tags == nil || len(node.Tags.Encode) == 0 { continue } switch node.Tags.Encode { case "json": pmap["encoding/json"] = struct{}{} // case "gzip": // pmap["compress/gzip"] = struct{}{} // case "snappy": // pmap["github.com/golang/snappy"] = struct{}{} } } if len(pmap) == 0 { return } // write the import block, including each // encoder package that was specified. fmt.Fprintln(w, "\nimport (") for pkg, _ := range pmap { fmt.Fprintf(w, "\t%q\n", pkg) } fmt.Fprintln(w, ")") } func writeSliceFunc(w io.Writer, tree *parse.Node) { var buf1, buf2, buf3 bytes.Buffer var i, depth int var parent = tree for _, node := range tree.Edges() { if node.Tags.Skip { continue } // temporary variable declaration switch node.Kind { case parse.Map, parse.Slice: fmt.Fprintf(&buf1, "var v%d %s\n", i, "[]byte") default: fmt.Fprintf(&buf1, "var v%d %s\n", i, node.Type) } // variable scanning fmt.Fprintf(&buf3, "v%d,\n", i) // variable setting path := node.Path()[1:] // if the parent is a ptr struct we // need to create a new if parent != node.Parent && node.Parent.Kind == parse.Ptr { // if node.Parent != nil && node.Parent.Parent != parent { // fmt.Fprintln(&buf2, "}\n") // depth-- // } // seriously ... this works? 
if node.Parent != nil && node.Parent.Parent != parent { for _, p := range path { if p == parent || depth == 0 { break } fmt.Fprintln(&buf2, "}\n") depth-- } } depth++ fmt.Fprintf(&buf2, "if v.%s != nil {\n", join(path[:len(path)-1], ".")) } switch node.Kind { case parse.Map, parse.Slice, parse.Struct, parse.Ptr: fmt.Fprintf(&buf2, "v%d, _ = json.Marshal(&v.%s)\n", i, join(path, ".")) default: fmt.Fprintf(&buf2, "v%d=v.%s\n", i, join(path, ".")) } parent = node.Parent i++ } for depth != 0 { depth-- fmt.Fprintln(&buf2, "}\n") } fmt.Fprintf(w, sSliceRow, tree.Type, tree.Type, buf1.String(), buf2.String(), buf3.String(), ) } func writeRowFunc(w io.Writer, tree *parse.Node) { var buf1, buf2, buf3 bytes.Buffer var i int var parent = tree for _, node := range tree.Edges() { if node.Tags.Skip { continue } // temporary variable declaration switch node.Kind { case parse.Map, parse.Slice: fmt.Fprintf(&buf1, "var v%d %s\n", i, "[]byte") default: fmt.Fprintf(&buf1, "var v%d %s\n", i, node.Type) } // variable scanning fmt.Fprintf(&buf2, "&v%d,\n", i) // variable setting path := node.Path()[1:] // if the parent is a ptr struct we // need to create a new if parent != node.Parent && node.Parent.Kind == parse.Ptr { fmt.Fprintf(&buf3, "v.%s=&%s{}\n", join(path[:len(path)-1], "."), node.Parent.Type) } switch node.Kind { case parse.Map, parse.Slice, parse.Struct, parse.Ptr: fmt.Fprintf(&buf3, "json.Unmarshal(v%d, &v.%s)\n", i, join(path, ".")) default: fmt.Fprintf(&buf3, "v.%s=v%d\n", join(path, "."), i) } parent = node.Parent i++ } fmt.Fprintf(w, sScanRow, tree.Type, tree.Type, buf1.String(), buf2.String(), tree.Type, buf3.String(), ) } func writeRowsFunc(w io.Writer, tree *parse.Node) { var buf1, buf2, buf3 bytes.Buffer var i int var parent = tree for _, node := range tree.Edges() { if node.Tags.Skip { continue } // temporary variable declaration switch node.Kind { case parse.Map, parse.Slice: fmt.Fprintf(&buf1, "var v%d %s\n", i, "[]byte") default: fmt.Fprintf(&buf1, "var v%d %s\n", 
i, node.Type) } // variable scanning fmt.Fprintf(&buf2, "&v%d,\n", i) // variable setting path := node.Path()[1:] // if the parent is a ptr struct we // need to create a new if parent != node.Parent && node.Parent.Kind == parse.Ptr { fmt.Fprintf(&buf3, "v.%s=&%s{}\n", join(path[:len(path)-1], "."), node.Parent.Type) } switch node.Kind { case parse.Map, parse.Slice, parse.Struct, parse.Ptr: fmt.Fprintf(&buf3, "json.Unmarshal(v%d, &v.%s)\n", i, join(path, ".")) default: fmt.Fprintf(&buf3, "v.%s=v%d\n", join(path, "."), i) } parent = node.Parent i++ } fmt.Fprintf(w, sScanRows, inflections.Pluralize(tree.Type), tree.Type, tree.Type, buf1.String(), buf2.String(), tree.Type, buf3.String(), ) } func writeSelectRow(w io.Writer, tree *parse.Node) { fmt.Fprintf(w, sSelectRow, tree.Type, tree.Type, tree.Type) } func writeSelectRows(w io.Writer, tree *parse.Node) { plural := inflections.Pluralize(tree.Type) fmt.Fprintf(w, sSelectRows, plural, tree.Type, plural) } func writeInsertFunc(w io.Writer, tree *parse.Node) { // TODO this assumes I'm using the ID field. // we should not make that assumption fmt.Fprintf(w, sInsert, tree.Type, tree.Type, tree.Type) } func writeUpdateFunc(w io.Writer, tree *parse.Node) { fmt.Fprintf(w, sUpdate, tree.Type, tree.Type, tree.Type) } // join is a helper function that joins nodes // together by name using the seperator. func join(nodes []*parse.Node, sep string) string { var parts []string for _, node := range nodes { parts = append(parts, node.Name) } return strings.Join(parts, sep) } ================================================ FILE: gen_schema.go ================================================ package main import ( "fmt" "io" "strings" "bitbucket.org/pkg/inflect" "github.com/drone/sqlgen/schema" ) // writeSchema writes SQL statements to CREATE, INSERT, // UPDATE and DELETE values from Table t. 
func writeSchema(w io.Writer, d schema.Dialect, t *schema.Table) { writeConst(w, d.Table(t), "create", inflect.Singularize(t.Name), "stmt", ) writeConst(w, d.Insert(t), "insert", inflect.Singularize(t.Name), "stmt", ) writeConst(w, d.Select(t, nil), "select", inflect.Singularize(t.Name), "stmt", ) writeConst(w, d.SelectRange(t, nil), "select", inflect.Singularize(t.Name), "range", "stmt", ) writeConst(w, d.SelectCount(t, nil), "select", inflect.Singularize(t.Name), "count", "stmt", ) if len(t.Primary) != 0 { writeConst(w, d.Select(t, t.Primary), "select", inflect.Singularize(t.Name), "pkey", "stmt", ) writeConst(w, d.Update(t, t.Primary), "update", inflect.Singularize(t.Name), "pkey", "stmt", ) writeConst(w, d.Delete(t, t.Primary), "delete", inflect.Singularize(t.Name), "pkey", "stmt", ) } for _, ix := range t.Index { writeConst(w, d.Index(t, ix), "create", ix.Name, "stmt", ) writeConst(w, d.Select(t, ix.Fields), "select", ix.Name, "stmt", ) if !ix.Unique { writeConst(w, d.SelectRange(t, ix.Fields), "select", ix.Name, "range", "stmt", ) writeConst(w, d.SelectCount(t, ix.Fields), "select", ix.Name, "count", "stmt", ) } else { writeConst(w, d.Update(t, ix.Fields), "update", ix.Name, "stmt", ) writeConst(w, d.Delete(t, ix.Fields), "delete", ix.Name, "stmt", ) } } } // WritePackage writes the Go package header to // writer w with the given package name. func writePackage(w io.Writer, name string) { fmt.Fprintf(w, sPackage, name) } // writeConst is a helper function that writes the // body string to a Go const variable. func writeConst(w io.Writer, body string, label ...string) { // create a snake case variable name from // the specified labels. Then convert the // variable name to a quoted, camel case string. 
name := strings.Join(label, "_") name = inflect.Typeify(name) // quote the body using multi-line quotes body = fmt.Sprintf(sQuote, body) fmt.Fprintf(w, sConst, name, body) } ================================================ FILE: parse/const.go ================================================ package parse const ( Invalid = iota Bool Int Int8 Int16 Int32 Int64 Uint Uint8 Uint16 Uint32 Uint64 Float32 Float64 Complex64 Complex128 Interface Bytes Map Ptr String Slice Struct ) var Types = map[string]uint8{ "bool": Bool, "int": Int, "int8": Int8, "int16": Int16, "int32": Int32, "int64": Int64, "uint": Uint, "uint8": Uint8, "uint16": Uint16, "uint32": Uint32, "uint64": Uint64, "float32": Float32, "float64": Float64, "complex64": Complex64, "complex128": Complex128, "interface{}": Interface, "[]byte": Bytes, "string": String, } ================================================ FILE: parse/node.go ================================================ package parse type Node struct { Pkg string // source code package. Name string // source code name. Kind uint8 // source code kind. Type string // source code type. Tags *Tag Parent *Node Nodes []*Node } func (n *Node) append(node *Node) { node.Parent = n n.Nodes = append(n.Nodes, node) } // Walk traverses the node tree, invoking the callback // function for each node that is traversed. func (n *Node) Walk(fn func(*Node)) { for _, node := range n.Nodes { fn(node) node.Walk(fn) } } // WalkRev traverses the tree in reverse order, invoking // the callback function for each parent node until // the root node is reached. func (n *Node) WalkRev(fn func(*Node)) { if n.Parent != nil { n.Parent.WalkRev(fn) } fn(n) // this was previously inside the if block } // Edges returns a flattened list of all edge // nodes in the Tree. func (n *Node) Edges() []*Node { var nodes []*Node n.Walk(func(node *Node) { if len(node.Nodes) == 0 { nodes = append(nodes, node) } }) return nodes } // Path returns the absolute path of the node // in the Tree. 
func (n *Node) Path() []*Node { var nodes []*Node n.WalkRev(func(node *Node) { nodes = append(nodes, node) }) return nodes } ================================================ FILE: parse/parse.go ================================================ package parse import ( "errors" "fmt" "go/ast" "go/parser" "go/token" ) var ( ErrTypeNotFound = errors.New("Cannot find type in the source code.") ErrTypeInvalid = errors.New("Cannot convert type to a SQL type.") ) func Parse(path, name string) (*Node, error) { var fset = token.NewFileSet() var file, err = parser.ParseFile(fset, path, nil, parser.ParseComments) if err != nil { return nil, err } for _, decl := range file.Decls { gen, ok := decl.(*ast.GenDecl) if !ok { continue } spec, ok := gen.Specs[0].(*ast.TypeSpec) if !ok { continue } if spec.Name.String() != name { continue } var node = new(Node) node.Name = spec.Name.String() node.Type = spec.Name.String() node.Pkg = file.Name.Name err = buildNodes(node, spec) return node, err } return nil, ErrTypeNotFound } func buildNodes(parent *Node, spec *ast.TypeSpec) error { ident, ok := spec.Type.(*ast.StructType) if !ok { return ErrTypeInvalid } for _, field := range ident.Fields.List { var tag string if field.Tag != nil { tag = field.Tag.Value } buildNode(parent, field.Type, field.Names[0].Name, tag) } return nil } func buildNode(parent *Node, expr ast.Expr, name, tag string) error { var err error switch ident := expr.(type) { case *ast.Ident: if ident.Obj == nil { node := &Node{ Name: name, Type: ident.Name, Kind: Types[ident.Name], } node.Tags, err = parseTag(tag) if err != nil { return err } parent.append(node) return nil } spec, ok := ident.Obj.Decl.(*ast.TypeSpec) if !ok { goto invalidType } node := &Node{ Name: name, Type: ident.Name, Kind: Struct, } node.Tags, err = parseTag(tag) if err != nil { return err } parent.append(node) return buildNodes(node, spec) case *ast.ArrayType: if ident.Len != nil { goto invalidType } node := &Node{ Name: name, Kind: Slice, Type: 
fmt.Sprintf("[]%s", ident.Elt), } node.Tags, err = parseTag(tag) if err != nil { return err } if node.Type == "[]byte" { node.Kind = Bytes } parent.append(node) return nil case *ast.MapType: type_ := fmt.Sprintf("map[%s]%s", ident.Key, ident.Value) node := &Node{Name: name, Type: type_, Kind: Map} node.Tags, err = parseTag(tag) if err != nil { return err } parent.append(node) return nil case *ast.StarExpr: innerIdent, ok := ident.X.(*ast.Ident) if !ok { goto invalidType } if innerIdent.Obj == nil || innerIdent.Obj.Decl == nil { goto invalidType } spec, ok := innerIdent.Obj.Decl.(*ast.TypeSpec) if !ok { goto invalidType } node := &Node{Name: name, Type: innerIdent.Name, Kind: Ptr} node.Tags, err = parseTag(tag) if err != nil { return err } if node.Tags.Skip { return nil } parent.append(node) return buildNodes(node, spec) } invalidType: return fmt.Errorf("%s is not a valid type", name) } ================================================ FILE: parse/tag.go ================================================ package parse import ( "fmt" "reflect" "strings" "gopkg.in/yaml.v2" ) const ( EncodeGzip = "gzip" EncodeJson = "json" ) // Tag stores the parsed data from the tag string in // a struct field. type Tag struct { Name string `yaml:"name"` Type string `yaml:"type"` Primary bool `yaml:"pk"` Auto bool `yaml:"auto"` Index string `yaml:"index"` Unique string `yaml:"unique"` Size int `yaml:"size"` Skip bool `yaml:"skip"` Encode string `yaml:"encode"` } // parseTag parses a tag string from the struct // field and unmarshals into a Tag struct. func parseTag(raw string) (*Tag, error) { var tag = new(Tag) raw = strings.Replace(raw, "`", "", -1) raw = reflect.StructTag(raw).Get("sql") // if the tag indicates the field should // be skipped we can exit right away. if strings.TrimSpace(raw) == "-" { tag.Skip = true return tag, nil } // otherwise wrap the string in curly braces // so that we can use the Yaml parser. 
raw = fmt.Sprintf("{ %s }", raw) // unmarshals the Yaml formatted string into // the Tag structure. var err = yaml.Unmarshal([]byte(raw), tag) return tag, err } ================================================ FILE: parse/tag_test.go ================================================ package parse import ( "reflect" "testing" ) var tagTests = []struct { raw string tag *Tag }{ { `sql:"-"`, &Tag{Skip: true}, }, { `sql:"pk: true, auto: true"`, &Tag{Primary: true, Auto: true}, }, { `sql:"name: foo"`, &Tag{Name: "foo"}, }, { `sql:"type: varchar"`, &Tag{Type: "varchar"}, }, { `sql:"size: 2048"`, &Tag{Size: 2048}, }, { `sql:"index: fake_index"`, &Tag{Index: "fake_index"}, }, { `sql:"unique: fake_unique_index"`, &Tag{Unique: "fake_unique_index"}, }, } func TestParseTag(t *testing.T) { for _, test := range tagTests { var want = test.tag var got, err = parseTag(test.raw) if err != nil { t.Errorf("Got Error parsing Tag %s. %s", test.raw, err) } if !reflect.DeepEqual(got, want) { t.Errorf("Wanted Tag %+v, got Tag %+v", want, got) } } } ================================================ FILE: schema/base.go ================================================ package schema import ( "bytes" "fmt" "io" "strings" "text/tabwriter" ) type base struct { Dialect Dialect } // Table returns a SQL statement to create the table. func (b *base) Table(t *Table) string { // use a large default buffer size of so that // the tabbing doesn't get prematurely flushed // resulting in un-even lines. var byt = make([]byte, 0, 100000) var buf = bytes.NewBuffer(byt) // use a tab writer to evenly space the column // names and column types. var tab = tabwriter.NewWriter(buf, 0, 8, 1, ' ', 0) b.columnw(tab, t.Fields, false, false, true) // flush the tab writer to write to the buffer tab.Flush() return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s\n);", t.Name, buf.String()) } // Index returns a SQL statement to create the index. 
func (b *base) Index(table *Table, index *Index) string { var obj = "INDEX" if index.Unique { obj = "UNIQUE INDEX" } return fmt.Sprintf("CREATE %s IF NOT EXISTS %s ON %s (%s)", obj, index.Name, table.Name, b.columns(index.Fields, true, false, false)) } func (b *base) Insert(t *Table) string { var fields []*Field var params []string var i int for _, field := range t.Fields { if !field.Auto { fields = append(fields, field) params = append(params, b.Dialect.Param(i)) i++ } } return fmt.Sprintf("INSERT INTO %s (%s\n) VALUES (%s)", t.Name, b.columns(fields, false, false, false), strings.Join(params, ",")) } func (b *base) Update(t *Table, fields []*Field) string { return fmt.Sprintf("UPDATE %s SET %s %s", t.Name, b.columns(t.Fields, false, true, false), b.clause(fields, len(t.Fields))) } func (b *base) Delete(t *Table, fields []*Field) string { return fmt.Sprintf("DELETE FROM %s %s", t.Name, b.clause(fields, 0)) } func (b *base) Select(t *Table, fields []*Field) string { return fmt.Sprintf("SELECT %s\nFROM %s %s", b.columns(t.Fields, false, false, false), t.Name, b.clause(fields, 0)) } func (b *base) SelectRange(t *Table, fields []*Field) string { return fmt.Sprintf("SELECT %s\nFROM %s %s\nLIMIT %s OFFSET %s", b.columns(t.Fields, false, false, false), t.Name, b.clause(fields, 0), b.Dialect.Param(len(fields)), b.Dialect.Param(len(fields)+1)) } func (b *base) SelectCount(t *Table, fields []*Field) string { return fmt.Sprintf("SELECT count(1)\nFROM %s %s", t.Name, b.clause(fields, 0)) } // Param returns the parameters symbol used in prepared // sql statements. func (b *base) Param(i int) string { return "?" } // Column returns a SQL type for the given field. 
// // For Mysql and Postgres see: // https://github.com/eaigner/hood/blob/master/mysql.go#L35 func (b *base) Column(f *Field) string { switch f.Type { case INTEGER: return "INTEGER" case BOOLEAN: return "BOOLEAN" case BLOB: return "BLOB" case VARCHAR: return "TEXT" default: return "TEXT" } } // Token returns the SQL string for the requested token. func (b *base) Token(v int) (_ string) { switch v { case AUTO_INCREMENT: return "AUTOINCREMENT" case PRIMARY_KEY: return "PRIMARY KEY" default: return } } // helper function to generate a block of columns. You // can optionally generate in inline list of columns, // include an assignment operator, and include column // definitions. func (b *base) columns(fields []*Field, inline, assign, def bool) string { var buf bytes.Buffer b.columnw(&buf, fields, inline, assign, def) return buf.String() } // helper function to write a block of columns to w. func (b *base) columnw(w io.Writer, fields []*Field, inline, assign, def bool) { for i, field := range fields { if !inline { io.WriteString(w, "\n") } switch { case i == 0 && !inline: io.WriteString(w, " ") case i != 0: io.WriteString(w, ",") } io.WriteString(w, field.Name) if assign { io.WriteString(w, "=") io.WriteString(w, b.Dialect.Param(i)) } if !def { continue } io.WriteString(w, "\t") io.WriteString(w, b.Dialect.Column(field)) if field.Primary { io.WriteString(w, " ") io.WriteString(w, b.Dialect.Token(PRIMARY_KEY)) } if field.Auto { io.WriteString(w, " ") io.WriteString(w, b.Dialect.Token(AUTO_INCREMENT)) } } } // helper function to generate the Where clause // section of a SQL statement func (b *base) clause(fields []*Field, pos int) string { var buf bytes.Buffer var i int for _, field := range fields { buf.WriteString("\n") switch { case i == 0: buf.WriteString("WHERE") default: buf.WriteString("AND") } buf.WriteString(" ") buf.WriteString(field.Name) buf.WriteString("=") buf.WriteString(b.Dialect.Param(i + pos)) i++ } return buf.String() } 
================================================ FILE: schema/dialect.go ================================================ package schema const ( SQLITE int = iota POSTGRES MYSQL ) var Dialects = map[string]int{ "sqlite": SQLITE, "postgres": POSTGRES, "mysql": MYSQL, } type Dialect interface { Table(*Table) string Index(*Table, *Index) string Column(*Field) string Insert(*Table) string Update(*Table, []*Field) string Delete(*Table, []*Field) string Select(*Table, []*Field) string SelectCount(*Table, []*Field) string SelectRange(*Table, []*Field) string Param(int) string Token(int) string } func New(dialect int) Dialect { switch dialect { case POSTGRES: return newPosgres() case MYSQL: return newMysql() default: return newSqlite() } } ================================================ FILE: schema/dialect_mysql.go ================================================ package schema import ( "fmt" ) type mysql struct { base } func newMysql() Dialect { d := &mysql{} d.base.Dialect = d return d } func (d *mysql) Column(f *Field) (_ string) { switch f.Type { case INTEGER: return "INTEGER" case BOOLEAN: return "BOOLEAN" case BLOB: return "MEDIUMBLOB" case VARCHAR: // assigns an arbitrary size if // none is provided. size := f.Size if size == 0 { size = 512 } return fmt.Sprintf("VARCHAR(%d)", size) default: return } } func (d *mysql) Token(v int) (_ string) { switch v { case AUTO_INCREMENT: return "AUTO_INCREMENT" case PRIMARY_KEY: return "PRIMARY KEY" default: return } } ================================================ FILE: schema/dialect_postgres.go ================================================ package schema import ( "fmt" ) type posgres struct { base } func newPosgres() Dialect { d := &posgres{} d.base.Dialect = d return d } func (d *posgres) Column(f *Field) (_ string) { // posgres uses a special column type // to autoincrementing keys. 
if f.Auto { return "SERIAL" } switch f.Type { case INTEGER: return "INTEGER" case BOOLEAN: return "BOOLEAN" case BLOB: return "BYTEA" case VARCHAR: // assigns an arbitrary size if // none is provided. size := f.Size if size == 0 { size = 512 } return fmt.Sprintf("VARCHAR(%d)", size) default: return } } func (d *posgres) Token(v int) (_ string) { switch v { case AUTO_INCREMENT: // postgres does not support the // auto-increment keyword. return case PRIMARY_KEY: return "PRIMARY KEY" default: return } } func (d *posgres) Param(i int) string { return fmt.Sprintf("$%d", i+1) } ================================================ FILE: schema/dialect_sqlite.go ================================================ package schema type sqlite struct { base } func newSqlite() Dialect { d := &sqlite{} d.base.Dialect = d return d } ================================================ FILE: schema/helper.go ================================================ package schema import ( "strings" "github.com/acsellers/inflections" "github.com/drone/sqlgen/parse" ) func Load(tree *parse.Node) *Table { table := new(Table) // local map of indexes, used for quick // lookups and de-duping. indexs := map[string]*Index{} // pluralizes the table name and then // formats in snake case. table.Name = inflections.Underscore(tree.Type) table.Name = inflections.Pluralize(table.Name) // each edge node in the tree is a column // in the table. Convert each edge node to // a Field structure. 
for _, node := range tree.Edges() { field := new(Field) // Lookup the SQL column type // TODO: move this to a function t, ok := parse.Types[node.Type] if ok { tt, ok := types[t] if !ok { tt = BLOB } field.Type = tt } else { field.Type = BLOB } // substitute tag variables if node.Tags != nil { if node.Tags.Skip { continue } // default ID and int64 to primary key // with auto-increment if node.Name == "ID" && node.Kind == parse.Int64 { node.Tags.Primary = true node.Tags.Auto = true } field.Auto = node.Tags.Auto field.Primary = node.Tags.Primary field.Size = node.Tags.Size if node.Tags.Primary { table.Primary = append(table.Primary, field) } if node.Tags.Index != "" { index, ok := indexs[node.Tags.Index] if !ok { index = new(Index) index.Name = node.Tags.Index indexs[index.Name] = index table.Index = append(table.Index, index) } index.Fields = append(index.Fields, field) } if node.Tags.Unique != "" { index, ok := indexs[node.Tags.Index] if !ok { index = new(Index) index.Name = node.Tags.Unique index.Unique = true indexs[index.Name] = index table.Index = append(table.Index, index) } index.Fields = append(index.Fields, field) } if node.Tags.Type != "" { t, ok := sqlTypes[node.Tags.Type] if ok { field.Type = t } } } // get the full path name path := node.Path() var parts []string for _, part := range path { if part.Tags != nil && part.Tags.Name != "" { parts = append(parts, part.Tags.Name) continue } parts = append(parts, part.Name) } field.Name = strings.Join(parts, "_") field.Name = inflections.Underscore(field.Name) table.Fields = append(table.Fields, field) } return table } // convert Go types to SQL types. 
var types = map[uint8]int{ parse.Bool: BOOLEAN, parse.Int: INTEGER, parse.Int8: INTEGER, parse.Int16: INTEGER, parse.Int32: INTEGER, parse.Int64: INTEGER, parse.Uint: INTEGER, parse.Uint8: INTEGER, parse.Uint16: INTEGER, parse.Uint32: INTEGER, parse.Uint64: INTEGER, parse.Float32: INTEGER, parse.Float64: INTEGER, parse.Complex64: INTEGER, parse.Complex128: INTEGER, parse.Interface: BLOB, parse.Bytes: BLOB, parse.String: VARCHAR, parse.Map: BLOB, parse.Slice: BLOB, } var sqlTypes = map[string]int{ "text": VARCHAR, "varchar": VARCHAR, "varchar2": VARCHAR, "number": INTEGER, "integer": INTEGER, "int": INTEGER, "blob": BLOB, "bytea": BLOB, } ================================================ FILE: schema/schema.go ================================================ package schema // List of basic types const ( INTEGER int = iota VARCHAR BOOLEAN REAL BLOB ) // List of vendor-specific keywords const ( AUTO_INCREMENT = iota PRIMARY_KEY ) type Table struct { Name string Fields []*Field Index []*Index Primary []*Field } type Field struct { Name string Type int Primary bool Auto bool Size int } type Index struct { Name string Unique bool Fields []*Field } ================================================ FILE: tmpl.go ================================================ package main // template to create a constant variable. var sConst = ` const %s = %s ` // template to wrap a string in multi-line quotes. var sQuote = "`\n%s\n`" // template to declare the package name. var sPackage = ` package %s // THIS FILE WAS AUTO-GENERATED. DO NOT MODIFY. ` // template to delcare the package imports. var sImport = ` import ( %s ) ` // function template to scan a single row. const sScanRow = ` func Scan%s(row *sql.Row) (*%s, error) { %s err := row.Scan( %s ) if err != nil { return nil, err } v := &%s{} %s return v, nil } ` // function template to scan multiple rows. 
const sScanRows = ` func Scan%s(rows *sql.Rows) ([]*%s, error) { var err error var vv []*%s %s for rows.Next() { err = rows.Scan( %s ) if err != nil { return vv, err } v := &%s{} %s vv = append(vv, v) } return vv, rows.Err() } `

// function template returning a value's column values as a []interface{}.
// The insert and update templates below pass this slice, minus its first
// element, as query arguments.
const sSliceRow = ` func Slice%s(v *%s) []interface{} { %s %s return []interface{}{ %s } } `

// function template to run a query and scan a single row.
const sSelectRow = ` func Select%s(db *sql.DB, query string, args ...interface{}) (*%s, error) { row := db.QueryRow(query, args...) return Scan%s(row) } `

// function template to select multiple rows.
const sSelectRows = ` func Select%s(db *sql.DB, query string, args ...interface{}) ([]*%s, error) { rows, err := db.Query(query, args...) if err != nil { return nil, err } defer rows.Close() return Scan%s(rows) } `

// function template to insert a single row.
//
// The first column value is left out of the arguments ([1:]) and v.ID is
// then set from LastInsertId. This presumably assumes the struct's first
// column is an auto-generated int64 ID field.
const sInsert = ` func Insert%s(db *sql.DB, query string, v *%s) error { res, err := db.Exec(query, Slice%s(v)[1:]...) if err != nil { return err } v.ID, err = res.LastInsertId() return err } `

// function template to update a single row.
//
// The arguments are every column value except the first, followed by v.ID.
// NOTE(review): the generated Update*Stmt constants SET every column (the
// ID included) before their WHERE clause. They therefore take one more
// argument than this passes. Confirm which query this is meant for.
const sUpdate = ` func Update%s(db *sql.DB, query string, v *%s) error { args := Slice%s(v)[1:] args = append(args, v.ID) _, err := db.Exec(query, args...) return err } `

================================================
FILE: util.go
================================================
package main