What can be accomplished in 7 days? A gin-like web framework? A distributed cache like groupcache? Or a simple Python interpreter? This repo aims to give you the answer.
## Web Framework - Gee
[Gee](https://geektutu.com/post/gee.html) is a [gin](https://github.com/gin-gonic/gin)-like framework
- Day 1 - http.Handler Interface Basic [Code](gee-web/day1-http-base)
- Day 2 - Design a Flexible Context [Code](gee-web/day2-context)
- Day 3 - Router with Trie-Tree Algorithm [Code](gee-web/day3-router)
- Day 4 - Group Control [Code](gee-web/day4-group)
- Day 5 - Middleware Mechanism [Code](gee-web/day5-middleware)
- Day 6 - Embedded Template Support [Code](gee-web/day6-template)
- Day 7 - Panic Recover & Make it Robust [Code](gee-web/day7-panic-recover)
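As a rough illustration of the Day 1 idea, the sketch below shows how a type can plug into `net/http` by implementing `ServeHTTP`, with a static route table keyed by method and path. The names `Engine`, `GET`, `serve`, and `newDemoEngine` are chosen for this example and are assumptions, not the Gee source.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// Engine is a hypothetical, minimal stand-in for a framework core type.
// It satisfies http.Handler by implementing ServeHTTP.
type Engine struct {
	router map[string]http.HandlerFunc // key: "METHOD-path"
}

// NewEngine returns an Engine with an empty route table.
func NewEngine() *Engine {
	return &Engine{router: make(map[string]http.HandlerFunc)}
}

// GET registers a handler for GET requests on an exact path.
func (e *Engine) GET(path string, h http.HandlerFunc) {
	e.router["GET-"+path] = h
}

// ServeHTTP dispatches on method and path and replies 404 when nothing matches.
func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if h, ok := e.router[r.Method+"-"+r.URL.Path]; ok {
		h(w, r)
		return
	}
	http.Error(w, "404 NOT FOUND: "+r.URL.Path, http.StatusNotFound)
}

// serve runs one in-memory request through the engine and
// returns the response status code and body.
func serve(e *Engine, method, path string) (int, string) {
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, httptest.NewRequest(method, path, nil))
	return rec.Code, rec.Body.String()
}

// newDemoEngine builds an engine with a single /hello route.
func newDemoEngine() *Engine {
	e := NewEngine()
	e.GET("/hello", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "hello")
	})
	return e
}

func main() {
	e := newDemoEngine()
	for _, path := range []string{"/hello", "/missing"} {
		code, body := serve(e, "GET", path)
		fmt.Printf("GET %s -> %d %q\n", path, code, body)
	}
}
```

Because `Engine` is an `http.Handler`, it could also be passed directly to `http.ListenAndServe`; the in-memory recorder is used here only to keep the example self-contained.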
## Distributed Cache - GeeCache
[GeeCache](https://geektutu.com/post/geecache.html) is a [groupcache](https://github.com/golang/groupcache)-like distributed cache
- Day 1 - LRU (Least Recently Used) Caching Strategy [Code](gee-cache/day1-lru)
- Day 2 - Single Machine Concurrent Cache [Code](gee-cache/day2-single-node)
- Day 3 - Launch an HTTP Server [Code](gee-cache/day3-http-server)
- Day 4 - Consistent Hash Algorithm [Code](gee-cache/day4-consistent-hash)
- Day 5 - Communication between Distributed Nodes [Code](gee-cache/day5-multi-nodes)
- Day 6 - Cache Breakdown & Single Flight [Code](gee-cache/day6-single-flight)
- Day 7 - Use Protobuf as RPC Data Exchange Type [Code](gee-cache/day7-proto-buf)
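The consistent hash step (Day 4) can be sketched as a sorted ring of virtual nodes. This is a minimal illustration under stated assumptions: the names `Ring`, `NewRing`, and `digitHash` are hypothetical, and the toy hash exists only to make the placement easy to follow by hand.

```go
package main

import (
	"fmt"
	"hash/crc32"
	"sort"
	"strconv"
)

// Hash maps bytes to a position on the ring.
type Hash func(data []byte) uint32

// Ring is a hypothetical consistent-hash ring with virtual nodes.
type Ring struct {
	hash     Hash
	replicas int            // virtual nodes per real node
	keys     []int          // sorted virtual-node hashes
	owners   map[int]string // virtual-node hash -> real node
}

// NewRing creates a ring; a nil hash falls back to CRC-32 (IEEE).
func NewRing(replicas int, fn Hash) *Ring {
	if fn == nil {
		fn = crc32.ChecksumIEEE
	}
	return &Ring{hash: fn, replicas: replicas, owners: make(map[int]string)}
}

// Add places `replicas` virtual nodes on the ring for each real node.
func (r *Ring) Add(nodes ...string) {
	for _, n := range nodes {
		for i := 0; i < r.replicas; i++ {
			h := int(r.hash([]byte(strconv.Itoa(i) + n)))
			r.keys = append(r.keys, h)
			r.owners[h] = n
		}
	}
	sort.Ints(r.keys)
}

// Get returns the node owning the first virtual node at or after the
// key's hash, wrapping around to the start of the ring if needed.
func (r *Ring) Get(key string) string {
	if len(r.keys) == 0 {
		return ""
	}
	h := int(r.hash([]byte(key)))
	idx := sort.Search(len(r.keys), func(i int) bool { return r.keys[i] >= h })
	return r.owners[r.keys[idx%len(r.keys)]]
}

// digitHash is a toy hash for the demo: it parses the bytes as a decimal number.
func digitHash(data []byte) uint32 {
	n, _ := strconv.Atoi(string(data))
	return uint32(n)
}

func main() {
	r := NewRing(3, digitHash)
	r.Add("6", "4", "2") // virtual nodes: 2, 4, 6, 12, 14, 16, 22, 24, 26
	for _, k := range []string{"2", "11", "23", "27"} {
		fmt.Printf("%s -> %s\n", k, r.Get(k))
	}
}
```

With the toy hash, key `27` wraps around to virtual node `2`; adding a node `"8"` (virtual nodes 8, 18, 28) would move only that key, which is the property that makes consistent hashing attractive for a cache cluster.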
## Object Relational Mapping - GeeORM
[GeeORM](https://geektutu.com/post/geeorm.html) is a [gorm](https://github.com/jinzhu/gorm)-like and [xorm](https://github.com/go-xorm/xorm)-like object relational mapping library
Xorm's design is easier to understand than gorm-v1's, so the main design follows xorm, while some implementation details follow gorm-v1.
- Day 1 - database/sql Basic | [Code](gee-orm/day1-database-sql)
- Day 2 - Object Schema Mapping | [Code](gee-orm/day2-reflect-schema)
- Day 3 - Insert and Query | [Code](gee-orm/day3-save-query)
- Day 4 - Chain, Delete and Update | [Code](gee-orm/day4-chain-operation)
- Day 5 - Support Hooks | [Code](gee-orm/day5-hooks)
- Day 6 - Support Transaction | [Code](gee-orm/day6-transaction)
- Day 7 - Migrate Database | [Code](gee-orm/day7-migrate)
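To make the object schema mapping step (Day 2) more concrete, here is a small sketch that uses `reflect` to derive a `CREATE TABLE` statement from a struct. The `geeorm` tag key, the `User` model, and the helper names `sqlType` and `createTableSQL` are assumptions made for this example; the type mapping is illustrative and not exhaustive.

```go
package main

import (
	"fmt"
	"reflect"
	"strings"
)

// User is a sample model; the `geeorm` tag key is assumed for illustration.
type User struct {
	Name string `geeorm:"PRIMARY KEY"`
	Age  int
}

// sqlType maps a few Go kinds to SQLite-style column types.
func sqlType(t reflect.Type) string {
	switch t.Kind() {
	case reflect.Bool:
		return "bool"
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
		return "integer"
	case reflect.Int64, reflect.Uint64:
		return "bigint"
	case reflect.Float32, reflect.Float64:
		return "real"
	case reflect.String:
		return "text"
	}
	return "blob"
}

// createTableSQL derives a CREATE TABLE statement from a struct value
// or a pointer to a struct, using field names as column names.
func createTableSQL(model interface{}) string {
	t := reflect.Indirect(reflect.ValueOf(model)).Type()
	cols := make([]string, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.PkgPath != "" { // skip unexported fields
			continue
		}
		col := f.Name + " " + sqlType(f.Type)
		if tag, ok := f.Tag.Lookup("geeorm"); ok {
			col += " " + tag
		}
		cols = append(cols, col)
	}
	return fmt.Sprintf("CREATE TABLE %s (%s);", t.Name(), strings.Join(cols, ", "))
}

func main() {
	fmt.Println(createTableSQL(&User{}))
	// prints: CREATE TABLE User (Name text PRIMARY KEY, Age integer);
}
```

Parsing the struct once and caching the result is a natural next step, since reflection on every query would be wasteful.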
## RPC Framework - GeeRPC
[GeeRPC](https://geektutu.com/post/geerpc.html) is a [net/rpc](https://github.com/golang/go/tree/master/src/net/rpc)-like RPC framework
Based on the Go standard library `net/rpc`, GeeRPC implements additional features, e.g. protocol exchange, service registration and discovery, and load balancing.
- Day 1 - Server Message Codec | [Code](gee-rpc/day1-codec)
- Day 2 - Concurrent Client | [Code](gee-rpc/day2-client)
- Day 3 - Service Register | [Code](gee-rpc/day3-service)
- Day 4 - Timeout Processing | [Code](gee-rpc/day4-timeout)
- Day 5 - Support HTTP Protocol | [Code](gee-rpc/day5-http-debug)
- Day 6 - Load Balance | [Code](gee-rpc/day6-load-balance)
- Day 7 - Discovery and Registry | [Code](gee-rpc/day7-registry)
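One simple way to picture the load balance step (Day 6) is a round-robin selector over a list of server addresses. The sketch below is only an illustration under assumptions: `RoundRobin`, `Update`, and `Get` are hypothetical names, and the addresses are placeholders.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// RoundRobin is a hypothetical discovery helper that hands out
// server addresses in turn. It is safe for concurrent use.
type RoundRobin struct {
	mu      sync.Mutex
	servers []string
	next    int
}

// NewRoundRobin copies the given addresses into a new selector.
func NewRoundRobin(servers ...string) *RoundRobin {
	return &RoundRobin{servers: append([]string(nil), servers...)}
}

// Update replaces the server list, e.g. after a registry refresh.
func (r *RoundRobin) Update(servers []string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.servers = append([]string(nil), servers...)
}

// Get returns the next server address, wrapping around at the end of the list.
func (r *RoundRobin) Get() (string, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if len(r.servers) == 0 {
		return "", errors.New("rpc discovery: no available servers")
	}
	s := r.servers[r.next%len(r.servers)]
	r.next = (r.next + 1) % len(r.servers)
	return s, nil
}

func main() {
	rr := NewRoundRobin("tcp@127.0.0.1:9001", "tcp@127.0.0.1:9002")
	for i := 0; i < 4; i++ {
		addr, _ := rr.Get()
		fmt.Println(addr)
	}
}
```

A client could call `Get` before each RPC to spread calls across servers; a random strategy would follow the same shape with a different index choice.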
## Golang WebAssembly Demo
- Demo 1 - Hello World [Code](demo-wasm/hello-world)
- Demo 2 - Register Functions [Code](demo-wasm/register-functions)
- Demo 3 - Manipulate DOM [Code](demo-wasm/manipulate-dom)
- Demo 4 - Callback [Code](demo-wasm/callback)
================================================
FILE: demo-wasm/.gitignore
================================================
*.wasm
static
================================================
FILE: demo-wasm/callback/Makefile
================================================
all: static/main.wasm static/wasm_exec.js
ifeq (, $(shell which goexec))
	go get -u github.com/shurcooL/goexec
endif
	goexec 'http.ListenAndServe(`:9999`, http.FileServer(http.Dir(`.`)))'
static/wasm_exec.js:
	cp "$(shell go env GOROOT)/misc/wasm/wasm_exec.js" static
static/main.wasm: main.go
	GO111MODULE=auto GOOS=js GOARCH=wasm go build -o static/main.wasm .
================================================
FILE: demo-wasm/callback/index.html
================================================
================================================
FILE: demo-wasm/callback/main.go
================================================
// main.go
package main
import (
"syscall/js"
"time"
)
func fib(i int) int {
if i == 0 || i == 1 {
return 1
}
return fib(i-1) + fib(i-2)
}
func fibFunc(this js.Value, args []js.Value) interface{} {
callback := args[len(args)-1]
go func() {
time.Sleep(3 * time.Second)
v := fib(args[0].Int())
callback.Invoke(v)
}()
js.Global().Get("ans").Set("innerHTML", "Waiting 3s...")
return nil
}
func main() {
done := make(chan int, 0)
js.Global().Set("fibFunc", js.FuncOf(fibFunc))
<-done
}
================================================
FILE: demo-wasm/hello-world/Makefile
================================================
all: static/main.wasm static/wasm_exec.js
ifeq (, $(shell which goexec))
	go get -u github.com/shurcooL/goexec
endif
	goexec 'http.ListenAndServe(`:9999`, http.FileServer(http.Dir(`.`)))'
static/wasm_exec.js:
	cp "$(shell go env GOROOT)/misc/wasm/wasm_exec.js" static
static/main.wasm: main.go
	GO111MODULE=auto GOOS=js GOARCH=wasm go build -o static/main.wasm .
================================================
FILE: demo-wasm/hello-world/index.html
================================================
================================================
FILE: demo-wasm/hello-world/main.go
================================================
// main.go
package main
import "syscall/js"
func main() {
alert := js.Global().Get("alert")
alert.Invoke("Hello World!")
}
================================================
FILE: demo-wasm/manipulate-dom/Makefile
================================================
all: static/main.wasm static/wasm_exec.js
ifeq (, $(shell which goexec))
	go get -u github.com/shurcooL/goexec
endif
	goexec 'http.ListenAndServe(`:9999`, http.FileServer(http.Dir(`.`)))'
static/wasm_exec.js:
	cp "$(shell go env GOROOT)/misc/wasm/wasm_exec.js" static
static/main.wasm: main.go
	GO111MODULE=auto GOOS=js GOARCH=wasm go build -o static/main.wasm .
================================================
FILE: demo-wasm/manipulate-dom/index.html
================================================
1
================================================
FILE: demo-wasm/manipulate-dom/main.go
================================================
package main
import (
"strconv"
"syscall/js"
)
func fib(i int) int {
if i == 0 || i == 1 {
return 1
}
return fib(i-1) + fib(i-2)
}
var (
document = js.Global().Get("document")
numEle = document.Call("getElementById", "num")
ansEle = document.Call("getElementById", "ans")
btnEle = js.Global().Get("btn")
)
func fibFunc(this js.Value, args []js.Value) interface{} {
v := numEle.Get("value")
if num, err := strconv.Atoi(v.String()); err == nil {
ansEle.Set("innerHTML", js.ValueOf(fib(num)))
}
return nil
}
func main() {
done := make(chan int, 0)
btnEle.Call("addEventListener", "click", js.FuncOf(fibFunc))
<-done
}
================================================
FILE: demo-wasm/register-functions/Makefile
================================================
all: static/main.wasm static/wasm_exec.js
ifeq (, $(shell which goexec))
	go get -u github.com/shurcooL/goexec
endif
	goexec 'http.ListenAndServe(`:9999`, http.FileServer(http.Dir(`.`)))'
static/wasm_exec.js:
	cp "$(shell go env GOROOT)/misc/wasm/wasm_exec.js" static
static/main.wasm: main.go
	GO111MODULE=auto GOOS=js GOARCH=wasm go build -o static/main.wasm .
================================================
FILE: demo-wasm/register-functions/index.html
================================================
1
================================================
FILE: demo-wasm/register-functions/main.go
================================================
// main.go
package main
import "syscall/js"
func fib(i int) int {
if i == 0 || i == 1 {
return 1
}
return fib(i-1) + fib(i-2)
}
func fibFunc(this js.Value, args []js.Value) interface{} {
return js.ValueOf(fib(args[0].Int()))
}
func main() {
done := make(chan int, 0)
js.Global().Set("fibFunc", js.FuncOf(fibFunc))
<-done
}
================================================
FILE: gee-bolt/day1-pages/go.mod
================================================
module geebolt
go 1.13
================================================
FILE: gee-bolt/day1-pages/meta.go
================================================
package geebolt
import (
"errors"
"hash/fnv"
"unsafe"
)
// magic is a marker value indicating that a file is a gee-bolt DB
const magic uint32 = 0xED0CDAED
type meta struct {
magic uint32
pageSize uint32
pgid uint64
checksum uint64
}
func (m *meta) sum64() uint64 {
var h = fnv.New64a()
_, _ = h.Write((*[unsafe.Offsetof(meta{}.checksum)]byte)(unsafe.Pointer(m))[:])
return h.Sum64()
}
func (m *meta) validate() error {
if m.magic != magic {
return errors.New("invalid magic number")
}
if m.checksum != m.sum64() {
return errors.New("invalid checksum")
}
return nil
}
================================================
FILE: gee-bolt/day1-pages/page.go
================================================
package geebolt
import (
"fmt"
"reflect"
"unsafe"
)
const pageHeaderSize = unsafe.Sizeof(page{})
const branchPageElementSize = unsafe.Sizeof(branchPageElement{})
const leafPageElementSize = unsafe.Sizeof(leafPageElement{})
const maxKeysPerPage = 1024
const (
branchPageFlag uint16 = iota
leafPageFlag
metaPageFlag
freelistPageFlag
)
type page struct {
id uint64
flags uint16
count uint16
overflow uint32
}
type leafPageElement struct {
pos uint32
ksize uint32
vsize uint32
}
type branchPageElement struct {
pos uint32
ksize uint32
pgid uint64
}
func (p *page) typ() string {
switch p.flags {
case branchPageFlag:
return "branch"
case leafPageFlag:
return "leaf"
case metaPageFlag:
return "meta"
case freelistPageFlag:
return "freelist"
}
return fmt.Sprintf("unknown<%02x>", p.flags)
}
func (p *page) meta() *meta {
return (*meta)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + pageHeaderSize))
}
func (p *page) dataPtr() unsafe.Pointer {
return unsafe.Pointer(&reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(p)) + pageHeaderSize,
Len: int(p.count),
Cap: int(p.count),
})
}
func (p *page) leafPageElement(index uint16) *leafPageElement {
off := pageHeaderSize + uintptr(index)*leafPageElementSize
return (*leafPageElement)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + off))
}
func (p *page) leafPageElements() []leafPageElement {
if p.count == 0 {
return nil
}
return *(*[]leafPageElement)(p.dataPtr())
}
func (p *page) branchPageElement(index uint16) *branchPageElement {
off := pageHeaderSize + uintptr(index)*branchPageElementSize
return (*branchPageElement)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + off))
}
func (p *page) branchPageElements() []branchPageElement {
if p.count == 0 {
return nil
}
return *(*[]branchPageElement)(p.dataPtr())
}
================================================
FILE: gee-bolt/day2-mmap/db.go
================================================
package geebolt
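import (
	"os"
	"syscall"
)
type DB struct {
	data []byte
	file *os.File
}
const maxMapSize = 1 << 31
// mmap maps the first sz bytes of the file read-only into db.data.
// The arguments are an assumption (Unix-like systems only); the original call was left empty.
func (db *DB) mmap(sz int) error {
	b, err := syscall.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED)
	if err != nil {
		return err
	}
	db.data = b
	return nil
}
func Open(path string) {
}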
================================================
FILE: gee-bolt/day2-mmap/go.mod
================================================
module geebolt
go 1.13
================================================
FILE: gee-bolt/day3-tree/go.mod
================================================
module geebolt
go 1.13
================================================
FILE: gee-bolt/day3-tree/meta.go
================================================
package geebolt
import (
"errors"
"hash/fnv"
"unsafe"
)
// magic is a marker value indicating that a file is a gee-bolt DB
const magic uint32 = 0xED0CDAED
type meta struct {
magic uint32
pageSize uint32
pgid uint64
checksum uint64
}
func (m *meta) sum64() uint64 {
var h = fnv.New64a()
_, _ = h.Write((*[unsafe.Offsetof(meta{}.checksum)]byte)(unsafe.Pointer(m))[:])
return h.Sum64()
}
func (m *meta) validate() error {
if m.magic != magic {
return errors.New("invalid magic number")
}
if m.checksum != m.sum64() {
return errors.New("invalid checksum")
}
return nil
}
================================================
FILE: gee-bolt/day3-tree/node.go
================================================
package geebolt
import (
"bytes"
"sort"
)
type kv struct {
key []byte
value []byte
}
type node struct {
isLeaf bool
key []byte
parent *node
children []*node
kvs []kv
}
func (n *node) root() *node {
if n.parent == nil {
return n
}
return n.parent.root()
}
func (n *node) index(key []byte) (index int, exact bool) {
index = sort.Search(len(n.kvs), func(i int) bool {
return bytes.Compare(n.kvs[i].key, key) != -1
})
exact = len(n.kvs) > 0 && index < len(n.kvs) && bytes.Equal(n.kvs[index].key, key)
return
}
func (n *node) put(oldKey, newKey, value []byte) {
index, exact := n.index(oldKey)
if !exact {
n.kvs = append(n.kvs, kv{})
copy(n.kvs[index+1:], n.kvs[index:])
}
kv := &n.kvs[index]
kv.key = newKey
kv.value = value
}
func (n *node) del(key []byte) {
index, exact := n.index(key)
if exact {
n.kvs = append(n.kvs[:index], n.kvs[index+1:]...)
}
}
================================================
FILE: gee-bolt/day3-tree/page.go
================================================
package geebolt
import (
"fmt"
"reflect"
"unsafe"
)
const pageHeaderSize = unsafe.Sizeof(page{})
const branchPageElementSize = unsafe.Sizeof(branchPageElement{})
const leafPageElementSize = unsafe.Sizeof(leafPageElement{})
const maxKeysPerPage = 1024
const (
branchPageFlag uint16 = iota
leafPageFlag
metaPageFlag
freelistPageFlag
)
type page struct {
id uint64
flags uint16
count uint16
overflow uint32
}
type leafPageElement struct {
pos uint32
ksize uint32
vsize uint32
}
type branchPageElement struct {
pos uint32
ksize uint32
pgid uint64
}
func (p *page) typ() string {
switch p.flags {
case branchPageFlag:
return "branch"
case leafPageFlag:
return "leaf"
case metaPageFlag:
return "meta"
case freelistPageFlag:
return "freelist"
}
return fmt.Sprintf("unknown<%02x>", p.flags)
}
func (p *page) meta() *meta {
return (*meta)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + pageHeaderSize))
}
func (p *page) dataPtr() unsafe.Pointer {
return unsafe.Pointer(&reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(p)) + pageHeaderSize,
Len: int(p.count),
Cap: int(p.count),
})
}
func (p *page) leafPageElement(index uint16) *leafPageElement {
off := pageHeaderSize + uintptr(index)*leafPageElementSize
return (*leafPageElement)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + off))
}
func (p *page) leafPageElements() []leafPageElement {
if p.count == 0 {
return nil
}
return *(*[]leafPageElement)(p.dataPtr())
}
func (p *page) branchPageElement(index uint16) *branchPageElement {
off := pageHeaderSize + uintptr(index)*branchPageElementSize
return (*branchPageElement)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + off))
}
func (p *page) branchPageElements() []branchPageElement {
if p.count == 0 {
return nil
}
return *(*[]branchPageElement)(p.dataPtr())
}
================================================
FILE: gee-cache/day1-lru/geecache/go.mod
================================================
module geecache
go 1.13
================================================
FILE: gee-cache/day1-lru/geecache/lru/lru.go
================================================
package lru
import "container/list"
// Cache is an LRU cache. It is not safe for concurrent access.
type Cache struct {
maxBytes int64
nbytes int64
ll *list.List
cache map[string]*list.Element
// optional and executed when an entry is purged.
OnEvicted func(key string, value Value)
}
type entry struct {
key string
value Value
}
// Value uses Len to report how many bytes it takes
type Value interface {
Len() int
}
// New is the Constructor of Cache
func New(maxBytes int64, onEvicted func(string, Value)) *Cache {
return &Cache{
maxBytes: maxBytes,
ll: list.New(),
cache: make(map[string]*list.Element),
OnEvicted: onEvicted,
}
}
// Add adds a value to the cache.
func (c *Cache) Add(key string, value Value) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
c.nbytes += int64(value.Len()) - int64(kv.value.Len())
kv.value = value
} else {
ele := c.ll.PushFront(&entry{key, value})
c.cache[key] = ele
c.nbytes += int64(len(key)) + int64(value.Len())
}
for c.maxBytes != 0 && c.maxBytes < c.nbytes {
c.RemoveOldest()
}
}
// Get looks up a key's value
func (c *Cache) Get(key string) (value Value, ok bool) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
return kv.value, true
}
return
}
// RemoveOldest removes the oldest item
func (c *Cache) RemoveOldest() {
ele := c.ll.Back()
if ele != nil {
c.ll.Remove(ele)
kv := ele.Value.(*entry)
delete(c.cache, kv.key)
c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len())
if c.OnEvicted != nil {
c.OnEvicted(kv.key, kv.value)
}
}
}
// Len returns the number of cache entries
func (c *Cache) Len() int {
return c.ll.Len()
}
================================================
FILE: gee-cache/day1-lru/geecache/lru/lru_test.go
================================================
package lru
import (
"reflect"
"testing"
)
type String string
func (d String) Len() int {
return len(d)
}
func TestGet(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key1", String("1234"))
if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" {
t.Fatalf("cache hit key1=1234 failed")
}
if _, ok := lru.Get("key2"); ok {
t.Fatalf("cache miss key2 failed")
}
}
func TestRemoveoldest(t *testing.T) {
k1, k2, k3 := "key1", "key2", "k3"
v1, v2, v3 := "value1", "value2", "v3"
cap := len(k1 + k2 + v1 + v2)
lru := New(int64(cap), nil)
lru.Add(k1, String(v1))
lru.Add(k2, String(v2))
lru.Add(k3, String(v3))
if _, ok := lru.Get("key1"); ok || lru.Len() != 2 {
t.Fatalf("Removeoldest key1 failed")
}
}
func TestOnEvicted(t *testing.T) {
keys := make([]string, 0)
callback := func(key string, value Value) {
keys = append(keys, key)
}
lru := New(int64(10), callback)
lru.Add("key1", String("123456"))
lru.Add("k2", String("k2"))
lru.Add("k3", String("k3"))
lru.Add("k4", String("k4"))
expect := []string{"key1", "k2"}
if !reflect.DeepEqual(expect, keys) {
t.Fatalf("Call OnEvicted failed, expect keys equals to %s", expect)
}
}
func TestAdd(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key", String("1"))
lru.Add("key", String("111"))
if lru.nbytes != int64(len("key")+len("111")) {
t.Fatal("expected 6 but got", lru.nbytes)
}
}
================================================
FILE: gee-cache/day2-single-node/geecache/byteview.go
================================================
package geecache
// A ByteView holds an immutable view of bytes.
type ByteView struct {
b []byte
}
// Len returns the view's length
func (v ByteView) Len() int {
return len(v.b)
}
// ByteSlice returns a copy of the data as a byte slice.
func (v ByteView) ByteSlice() []byte {
return cloneBytes(v.b)
}
// String returns the data as a string, making a copy if necessary.
func (v ByteView) String() string {
return string(v.b)
}
func cloneBytes(b []byte) []byte {
c := make([]byte, len(b))
copy(c, b)
return c
}
================================================
FILE: gee-cache/day2-single-node/geecache/cache.go
================================================
package geecache
import (
"geecache/lru"
"sync"
)
type cache struct {
mu sync.Mutex
lru *lru.Cache
cacheBytes int64
}
func (c *cache) add(key string, value ByteView) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
c.lru = lru.New(c.cacheBytes, nil)
}
c.lru.Add(key, value)
}
func (c *cache) get(key string) (value ByteView, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
return
}
if v, ok := c.lru.Get(key); ok {
return v.(ByteView), ok
}
return
}
================================================
FILE: gee-cache/day2-single-node/geecache/geecache.go
================================================
package geecache
import (
"fmt"
"log"
"sync"
)
// A Group is a cache namespace together with the data loaded into it.
type Group struct {
name string
getter Getter
mainCache cache
}
// A Getter loads data for a key.
type Getter interface {
Get(key string) ([]byte, error)
}
// A GetterFunc implements Getter with a function.
type GetterFunc func(key string) ([]byte, error)
// Get implements Getter interface function
func (f GetterFunc) Get(key string) ([]byte, error) {
return f(key)
}
var (
mu sync.RWMutex
groups = make(map[string]*Group)
)
// NewGroup creates a new instance of Group
func NewGroup(name string, cacheBytes int64, getter Getter) *Group {
if getter == nil {
panic("nil Getter")
}
mu.Lock()
defer mu.Unlock()
g := &Group{
name: name,
getter: getter,
mainCache: cache{cacheBytes: cacheBytes},
}
groups[name] = g
return g
}
// GetGroup returns the named group previously created with NewGroup, or
// nil if there's no such group.
func GetGroup(name string) *Group {
mu.RLock()
g := groups[name]
mu.RUnlock()
return g
}
// Get value for a key from cache
func (g *Group) Get(key string) (ByteView, error) {
if key == "" {
return ByteView{}, fmt.Errorf("key is required")
}
if v, ok := g.mainCache.get(key); ok {
log.Println("[GeeCache] hit")
return v, nil
}
return g.load(key)
}
func (g *Group) load(key string) (value ByteView, err error) {
return g.getLocally(key)
}
func (g *Group) getLocally(key string) (ByteView, error) {
bytes, err := g.getter.Get(key)
if err != nil {
return ByteView{}, err
}
value := ByteView{b: cloneBytes(bytes)}
g.populateCache(key, value)
return value, nil
}
func (g *Group) populateCache(key string, value ByteView) {
g.mainCache.add(key, value)
}
================================================
FILE: gee-cache/day2-single-node/geecache/geecache_test.go
================================================
package geecache
import (
"fmt"
"log"
"reflect"
"testing"
)
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
func TestGetter(t *testing.T) {
var f Getter = GetterFunc(func(key string) ([]byte, error) {
return []byte(key), nil
})
expect := []byte("key")
if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) {
t.Fatal("callback failed")
}
}
func TestGet(t *testing.T) {
loadCounts := make(map[string]int, len(db))
gee := NewGroup("scores", 2<<10, GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
if _, ok := loadCounts[key]; !ok {
loadCounts[key] = 0
}
loadCounts[key]++
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
for k, v := range db {
if view, err := gee.Get(k); err != nil || view.String() != v {
t.Fatalf("failed to get value of %s", k)
}
if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 {
t.Fatalf("cache %s miss", k)
}
}
if view, err := gee.Get("unknown"); err == nil {
t.Fatalf("the value of unknown should be empty, but got %s", view)
}
}
func TestGetGroup(t *testing.T) {
groupName := "scores"
NewGroup(groupName, 2<<10, GetterFunc(
func(key string) (bytes []byte, err error) { return }))
if group := GetGroup(groupName); group == nil || group.name != groupName {
t.Fatalf("group %s not exist", groupName)
}
if group := GetGroup(groupName + "111"); group != nil {
t.Fatalf("expect nil, but %s got", group.name)
}
}
================================================
FILE: gee-cache/day2-single-node/geecache/go.mod
================================================
module geecache
go 1.13
================================================
FILE: gee-cache/day2-single-node/geecache/lru/lru.go
================================================
package lru
import "container/list"
// Cache is an LRU cache. It is not safe for concurrent access.
type Cache struct {
maxBytes int64
nbytes int64
ll *list.List
cache map[string]*list.Element
// optional and executed when an entry is purged.
OnEvicted func(key string, value Value)
}
type entry struct {
key string
value Value
}
// Value uses Len to report how many bytes it takes
type Value interface {
Len() int
}
// New is the Constructor of Cache
func New(maxBytes int64, onEvicted func(string, Value)) *Cache {
return &Cache{
maxBytes: maxBytes,
ll: list.New(),
cache: make(map[string]*list.Element),
OnEvicted: onEvicted,
}
}
// Add adds a value to the cache.
func (c *Cache) Add(key string, value Value) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
c.nbytes += int64(value.Len()) - int64(kv.value.Len())
kv.value = value
} else {
ele := c.ll.PushFront(&entry{key, value})
c.cache[key] = ele
c.nbytes += int64(len(key)) + int64(value.Len())
}
for c.maxBytes != 0 && c.maxBytes < c.nbytes {
c.RemoveOldest()
}
}
// Get looks up a key's value
func (c *Cache) Get(key string) (value Value, ok bool) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
return kv.value, true
}
return
}
// RemoveOldest removes the oldest item
func (c *Cache) RemoveOldest() {
ele := c.ll.Back()
if ele != nil {
c.ll.Remove(ele)
kv := ele.Value.(*entry)
delete(c.cache, kv.key)
c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len())
if c.OnEvicted != nil {
c.OnEvicted(kv.key, kv.value)
}
}
}
// Len returns the number of cache entries
func (c *Cache) Len() int {
return c.ll.Len()
}
================================================
FILE: gee-cache/day2-single-node/geecache/lru/lru_test.go
================================================
package lru
import (
"reflect"
"testing"
)
type String string
func (d String) Len() int {
return len(d)
}
func TestGet(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key1", String("1234"))
if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" {
t.Fatalf("cache hit key1=1234 failed")
}
if _, ok := lru.Get("key2"); ok {
t.Fatalf("cache miss key2 failed")
}
}
func TestRemoveoldest(t *testing.T) {
k1, k2, k3 := "key1", "key2", "k3"
v1, v2, v3 := "value1", "value2", "v3"
cap := len(k1 + k2 + v1 + v2)
lru := New(int64(cap), nil)
lru.Add(k1, String(v1))
lru.Add(k2, String(v2))
lru.Add(k3, String(v3))
if _, ok := lru.Get("key1"); ok || lru.Len() != 2 {
t.Fatalf("Removeoldest key1 failed")
}
}
func TestOnEvicted(t *testing.T) {
keys := make([]string, 0)
callback := func(key string, value Value) {
keys = append(keys, key)
}
lru := New(int64(10), callback)
lru.Add("key1", String("123456"))
lru.Add("k2", String("k2"))
lru.Add("k3", String("k3"))
lru.Add("k4", String("k4"))
expect := []string{"key1", "k2"}
if !reflect.DeepEqual(expect, keys) {
t.Fatalf("Call OnEvicted failed, expect keys equals to %s", expect)
}
}
func TestAdd(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key", String("1"))
lru.Add("key", String("111"))
if lru.nbytes != int64(len("key")+len("111")) {
t.Fatal("expected 6 but got", lru.nbytes)
}
}
================================================
FILE: gee-cache/day3-http-server/geecache/byteview.go
================================================
package geecache
// A ByteView holds an immutable view of bytes.
type ByteView struct {
b []byte
}
// Len returns the view's length
func (v ByteView) Len() int {
return len(v.b)
}
// ByteSlice returns a copy of the data as a byte slice.
func (v ByteView) ByteSlice() []byte {
return cloneBytes(v.b)
}
// String returns the data as a string, making a copy if necessary.
func (v ByteView) String() string {
return string(v.b)
}
func cloneBytes(b []byte) []byte {
c := make([]byte, len(b))
copy(c, b)
return c
}
================================================
FILE: gee-cache/day3-http-server/geecache/cache.go
================================================
package geecache
import (
"geecache/lru"
"sync"
)
type cache struct {
mu sync.Mutex
lru *lru.Cache
cacheBytes int64
}
func (c *cache) add(key string, value ByteView) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
c.lru = lru.New(c.cacheBytes, nil)
}
c.lru.Add(key, value)
}
func (c *cache) get(key string) (value ByteView, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
return
}
if v, ok := c.lru.Get(key); ok {
return v.(ByteView), ok
}
return
}
================================================
FILE: gee-cache/day3-http-server/geecache/geecache.go
================================================
package geecache
import (
"fmt"
"log"
"sync"
)
// A Group is a cache namespace together with the data loaded into it.
type Group struct {
name string
getter Getter
mainCache cache
}
// A Getter loads data for a key.
type Getter interface {
Get(key string) ([]byte, error)
}
// A GetterFunc implements Getter with a function.
type GetterFunc func(key string) ([]byte, error)
// Get implements Getter interface function
func (f GetterFunc) Get(key string) ([]byte, error) {
return f(key)
}
var (
mu sync.RWMutex
groups = make(map[string]*Group)
)
// NewGroup creates a new instance of Group
func NewGroup(name string, cacheBytes int64, getter Getter) *Group {
if getter == nil {
panic("nil Getter")
}
mu.Lock()
defer mu.Unlock()
g := &Group{
name: name,
getter: getter,
mainCache: cache{cacheBytes: cacheBytes},
}
groups[name] = g
return g
}
// GetGroup returns the named group previously created with NewGroup, or
// nil if there's no such group.
func GetGroup(name string) *Group {
mu.RLock()
g := groups[name]
mu.RUnlock()
return g
}
// Get value for a key from cache
func (g *Group) Get(key string) (ByteView, error) {
if key == "" {
return ByteView{}, fmt.Errorf("key is required")
}
if v, ok := g.mainCache.get(key); ok {
log.Println("[GeeCache] hit")
return v, nil
}
return g.load(key)
}
func (g *Group) load(key string) (value ByteView, err error) {
return g.getLocally(key)
}
func (g *Group) getLocally(key string) (ByteView, error) {
bytes, err := g.getter.Get(key)
if err != nil {
return ByteView{}, err
}
value := ByteView{b: cloneBytes(bytes)}
g.populateCache(key, value)
return value, nil
}
func (g *Group) populateCache(key string, value ByteView) {
g.mainCache.add(key, value)
}
================================================
FILE: gee-cache/day3-http-server/geecache/geecache_test.go
================================================
package geecache
import (
"fmt"
"log"
"reflect"
"testing"
)
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
func TestGetter(t *testing.T) {
var f Getter = GetterFunc(func(key string) ([]byte, error) {
return []byte(key), nil
})
expect := []byte("key")
if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) {
t.Fatal("callback failed")
}
}
func TestGet(t *testing.T) {
loadCounts := make(map[string]int, len(db))
gee := NewGroup("scores", 2<<10, GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
if _, ok := loadCounts[key]; !ok {
loadCounts[key] = 0
}
loadCounts[key]++
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
for k, v := range db {
if view, err := gee.Get(k); err != nil || view.String() != v {
t.Fatalf("failed to get value of %s", k)
}
if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 {
t.Fatalf("cache %s miss", k)
}
}
if view, err := gee.Get("unknown"); err == nil {
t.Fatalf("the value of unknown should be empty, but got %s", view)
}
}
func TestGetGroup(t *testing.T) {
groupName := "scores"
NewGroup(groupName, 2<<10, GetterFunc(
func(key string) (bytes []byte, err error) { return }))
if group := GetGroup(groupName); group == nil || group.name != groupName {
t.Fatalf("group %s not exist", groupName)
}
if group := GetGroup(groupName + "111"); group != nil {
t.Fatalf("expect nil, but %s got", group.name)
}
}
================================================
FILE: gee-cache/day3-http-server/geecache/go.mod
================================================
module geecache
go 1.13
================================================
FILE: gee-cache/day3-http-server/geecache/http.go
================================================
package geecache
import (
"fmt"
"log"
"net/http"
"strings"
)
const defaultBasePath = "/_geecache/"
// HTTPPool implements PeerPicker for a pool of HTTP peers.
type HTTPPool struct {
// this peer's base URL, e.g. "https://example.net:8000"
self string
basePath string
}
// NewHTTPPool initializes an HTTP pool of peers.
func NewHTTPPool(self string) *HTTPPool {
return &HTTPPool{
self: self,
basePath: defaultBasePath,
}
}
// Log prints a message prefixed with the server name.
func (p *HTTPPool) Log(format string, v ...interface{}) {
log.Printf("[Server %s] %s", p.self, fmt.Sprintf(format, v...))
}
// ServeHTTP handles all HTTP requests.
func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, p.basePath) {
panic("HTTPPool serving unexpected path: " + r.URL.Path)
}
p.Log("%s %s", r.Method, r.URL.Path)
// /<basepath>/<groupname>/<key> required
parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2)
if len(parts) != 2 {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
groupName := parts[0]
key := parts[1]
group := GetGroup(groupName)
if group == nil {
http.Error(w, "no such group: "+groupName, http.StatusNotFound)
return
}
view, err := group.Get(key)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
w.Write(view.ByteSlice())
}
================================================
FILE: gee-cache/day3-http-server/geecache/lru/lru.go
================================================
package lru
import "container/list"
// Cache is an LRU cache. It is not safe for concurrent access.
type Cache struct {
maxBytes int64
nbytes int64
ll *list.List
cache map[string]*list.Element
// optional and executed when an entry is purged.
OnEvicted func(key string, value Value)
}
type entry struct {
key string
value Value
}
// Value uses Len to report how many bytes it takes.
type Value interface {
Len() int
}
// New is the Constructor of Cache
func New(maxBytes int64, onEvicted func(string, Value)) *Cache {
return &Cache{
maxBytes: maxBytes,
ll: list.New(),
cache: make(map[string]*list.Element),
OnEvicted: onEvicted,
}
}
// Add adds a value to the cache.
func (c *Cache) Add(key string, value Value) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
c.nbytes += int64(value.Len()) - int64(kv.value.Len())
kv.value = value
} else {
ele := c.ll.PushFront(&entry{key, value})
c.cache[key] = ele
c.nbytes += int64(len(key)) + int64(value.Len())
}
for c.maxBytes != 0 && c.maxBytes < c.nbytes {
c.RemoveOldest()
}
}
// Get looks up a key's value.
func (c *Cache) Get(key string) (value Value, ok bool) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
return kv.value, true
}
return
}
// RemoveOldest removes the oldest item
func (c *Cache) RemoveOldest() {
ele := c.ll.Back()
if ele != nil {
c.ll.Remove(ele)
kv := ele.Value.(*entry)
delete(c.cache, kv.key)
c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len())
if c.OnEvicted != nil {
c.OnEvicted(kv.key, kv.value)
}
}
}
// Len returns the number of cache entries.
func (c *Cache) Len() int {
return c.ll.Len()
}
================================================
FILE: gee-cache/day3-http-server/geecache/lru/lru_test.go
================================================
package lru
import (
"reflect"
"testing"
)
type String string
func (d String) Len() int {
return len(d)
}
func TestGet(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key1", String("1234"))
if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" {
t.Fatalf("cache hit key1=1234 failed")
}
if _, ok := lru.Get("key2"); ok {
t.Fatalf("cache miss key2 failed")
}
}
func TestRemoveoldest(t *testing.T) {
k1, k2, k3 := "key1", "key2", "k3"
v1, v2, v3 := "value1", "value2", "v3"
cap := len(k1 + k2 + v1 + v2)
lru := New(int64(cap), nil)
lru.Add(k1, String(v1))
lru.Add(k2, String(v2))
lru.Add(k3, String(v3))
if _, ok := lru.Get("key1"); ok || lru.Len() != 2 {
t.Fatalf("Removeoldest key1 failed")
}
}
func TestOnEvicted(t *testing.T) {
keys := make([]string, 0)
callback := func(key string, value Value) {
keys = append(keys, key)
}
lru := New(int64(10), callback)
lru.Add("key1", String("123456"))
lru.Add("k2", String("k2"))
lru.Add("k3", String("k3"))
lru.Add("k4", String("k4"))
expect := []string{"key1", "k2"}
if !reflect.DeepEqual(expect, keys) {
t.Fatalf("Call OnEvicted failed, expected keys to equal %s", expect)
}
}
func TestAdd(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key", String("1"))
lru.Add("key", String("111"))
if lru.nbytes != int64(len("key")+len("111")) {
t.Fatal("expected 6 but got", lru.nbytes)
}
}
================================================
FILE: gee-cache/day3-http-server/go.mod
================================================
module example
go 1.13
require geecache v0.0.0
replace geecache => ./geecache
================================================
FILE: gee-cache/day3-http-server/main.go
================================================
package main
/*
$ curl http://localhost:9999/_geecache/scores/Tom
630
$ curl http://localhost:9999/_geecache/scores/kkk
kkk not exist
*/
import (
"fmt"
"geecache"
"log"
"net/http"
)
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
func main() {
geecache.NewGroup("scores", 2<<10, geecache.GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
addr := "localhost:9999"
peers := geecache.NewHTTPPool(addr)
log.Println("geecache is running at", addr)
log.Fatal(http.ListenAndServe(addr, peers))
}
================================================
FILE: gee-cache/day4-consistent-hash/geecache/byteview.go
================================================
package geecache
// A ByteView holds an immutable view of bytes.
type ByteView struct {
b []byte
}
// Len returns the view's length
func (v ByteView) Len() int {
return len(v.b)
}
// ByteSlice returns a copy of the data as a byte slice.
func (v ByteView) ByteSlice() []byte {
return cloneBytes(v.b)
}
// String returns the data as a string, making a copy if necessary.
func (v ByteView) String() string {
return string(v.b)
}
func cloneBytes(b []byte) []byte {
c := make([]byte, len(b))
copy(c, b)
return c
}
================================================
FILE: gee-cache/day4-consistent-hash/geecache/cache.go
================================================
package geecache
import (
"geecache/lru"
"sync"
)
type cache struct {
mu sync.Mutex
lru *lru.Cache
cacheBytes int64
}
func (c *cache) add(key string, value ByteView) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
c.lru = lru.New(c.cacheBytes, nil)
}
c.lru.Add(key, value)
}
func (c *cache) get(key string) (value ByteView, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
return
}
if v, ok := c.lru.Get(key); ok {
return v.(ByteView), ok
}
return
}
================================================
FILE: gee-cache/day4-consistent-hash/geecache/consistenthash/consistenthash.go
================================================
package consistenthash
import (
"hash/crc32"
"sort"
"strconv"
)
// Hash maps bytes to uint32
type Hash func(data []byte) uint32
// Map contains all hashed keys.
type Map struct {
hash Hash
replicas int
keys []int // Sorted
hashMap map[int]string
}
// New creates a Map instance
func New(replicas int, fn Hash) *Map {
m := &Map{
replicas: replicas,
hash: fn,
hashMap: make(map[int]string),
}
if m.hash == nil {
m.hash = crc32.ChecksumIEEE
}
return m
}
// Add adds some keys to the hash.
func (m *Map) Add(keys ...string) {
for _, key := range keys {
for i := 0; i < m.replicas; i++ {
hash := int(m.hash([]byte(strconv.Itoa(i) + key)))
m.keys = append(m.keys, hash)
m.hashMap[hash] = key
}
}
sort.Ints(m.keys)
}
// Get gets the closest item in the hash to the provided key.
func (m *Map) Get(key string) string {
if len(m.keys) == 0 {
return ""
}
hash := int(m.hash([]byte(key)))
// Binary search for appropriate replica.
idx := sort.Search(len(m.keys), func(i int) bool {
return m.keys[i] >= hash
})
return m.hashMap[m.keys[idx%len(m.keys)]]
}
================================================
FILE: gee-cache/day4-consistent-hash/geecache/consistenthash/consistenthash_test.go
================================================
package consistenthash
import (
"strconv"
"testing"
)
func TestHashing(t *testing.T) {
hash := New(3, func(key []byte) uint32 {
i, _ := strconv.Atoi(string(key))
return uint32(i)
})
// Given the above hash function, this will give replicas with "hashes":
// 2, 4, 6, 12, 14, 16, 22, 24, 26
hash.Add("6", "4", "2")
testCases := map[string]string{
"2": "2",
"11": "2",
"23": "4",
"27": "2",
}
for k, v := range testCases {
if hash.Get(k) != v {
t.Errorf("Asking for %s, should have yielded %s", k, v)
}
}
// Adds 8, 18, 28
hash.Add("8")
// 27 should now map to 8.
testCases["27"] = "8"
for k, v := range testCases {
if hash.Get(k) != v {
t.Errorf("Asking for %s, should have yielded %s", k, v)
}
}
}
================================================
FILE: gee-cache/day4-consistent-hash/geecache/geecache.go
================================================
package geecache
import (
"fmt"
"log"
"sync"
)
// A Group is a cache namespace and its associated data, loaded on demand by a Getter.
type Group struct {
name string
getter Getter
mainCache cache
}
// A Getter loads data for a key.
type Getter interface {
Get(key string) ([]byte, error)
}
// A GetterFunc implements Getter with a function.
type GetterFunc func(key string) ([]byte, error)
// Get implements Getter interface function
func (f GetterFunc) Get(key string) ([]byte, error) {
return f(key)
}
var (
mu sync.RWMutex
groups = make(map[string]*Group)
)
// NewGroup creates a new instance of Group.
func NewGroup(name string, cacheBytes int64, getter Getter) *Group {
if getter == nil {
panic("nil Getter")
}
mu.Lock()
defer mu.Unlock()
g := &Group{
name: name,
getter: getter,
mainCache: cache{cacheBytes: cacheBytes},
}
groups[name] = g
return g
}
// GetGroup returns the named group previously created with NewGroup, or
// nil if there's no such group.
func GetGroup(name string) *Group {
mu.RLock()
g := groups[name]
mu.RUnlock()
return g
}
// Get returns the value for key from the cache, loading it through the Getter on a miss.
func (g *Group) Get(key string) (ByteView, error) {
if key == "" {
return ByteView{}, fmt.Errorf("key is required")
}
if v, ok := g.mainCache.get(key); ok {
log.Println("[GeeCache] hit")
return v, nil
}
return g.load(key)
}
func (g *Group) load(key string) (value ByteView, err error) {
return g.getLocally(key)
}
func (g *Group) getLocally(key string) (ByteView, error) {
bytes, err := g.getter.Get(key)
if err != nil {
return ByteView{}, err
}
value := ByteView{b: cloneBytes(bytes)}
g.populateCache(key, value)
return value, nil
}
func (g *Group) populateCache(key string, value ByteView) {
g.mainCache.add(key, value)
}
================================================
FILE: gee-cache/day4-consistent-hash/geecache/geecache_test.go
================================================
package geecache
import (
"fmt"
"log"
"reflect"
"testing"
)
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
func TestGetter(t *testing.T) {
var f Getter = GetterFunc(func(key string) ([]byte, error) {
return []byte(key), nil
})
expect := []byte("key")
if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) {
t.Fatal("callback failed")
}
}
func TestGet(t *testing.T) {
loadCounts := make(map[string]int, len(db))
gee := NewGroup("scores", 2<<10, GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
if _, ok := loadCounts[key]; !ok {
loadCounts[key] = 0
}
loadCounts[key]++
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
for k, v := range db {
if view, err := gee.Get(k); err != nil || view.String() != v {
t.Fatalf("failed to get value of %s", k)
}
if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 {
t.Fatalf("cache %s miss", k)
}
}
if view, err := gee.Get("unknown"); err == nil {
t.Fatalf("the value of unknown should be empty, but got %s", view)
}
}
func TestGetGroup(t *testing.T) {
groupName := "scores"
NewGroup(groupName, 2<<10, GetterFunc(
func(key string) (bytes []byte, err error) { return }))
if group := GetGroup(groupName); group == nil || group.name != groupName {
t.Fatalf("group %s not exist", groupName)
}
if group := GetGroup(groupName + "111"); group != nil {
t.Fatalf("expect nil, but %s got", group.name)
}
}
================================================
FILE: gee-cache/day4-consistent-hash/geecache/go.mod
================================================
module geecache
go 1.13
================================================
FILE: gee-cache/day4-consistent-hash/geecache/http.go
================================================
package geecache
import (
"fmt"
"log"
"net/http"
"strings"
)
const defaultBasePath = "/_geecache/"
// HTTPPool serves cached values to a pool of HTTP peers.
type HTTPPool struct {
// this peer's base URL, e.g. "https://example.net:8000"
self string
basePath string
}
// NewHTTPPool initializes an HTTP pool of peers.
func NewHTTPPool(self string) *HTTPPool {
return &HTTPPool{
self: self,
basePath: defaultBasePath,
}
}
// Log prints a message prefixed with the server name.
func (p *HTTPPool) Log(format string, v ...interface{}) {
log.Printf("[Server %s] %s", p.self, fmt.Sprintf(format, v...))
}
// ServeHTTP handles all HTTP requests.
func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, p.basePath) {
panic("HTTPPool serving unexpected path: " + r.URL.Path)
}
p.Log("%s %s", r.Method, r.URL.Path)
// /<basepath>/<groupname>/<key> required
parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2)
if len(parts) != 2 {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
groupName := parts[0]
key := parts[1]
group := GetGroup(groupName)
if group == nil {
http.Error(w, "no such group: "+groupName, http.StatusNotFound)
return
}
view, err := group.Get(key)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
w.Write(view.ByteSlice())
}
================================================
FILE: gee-cache/day4-consistent-hash/geecache/lru/lru.go
================================================
package lru
import "container/list"
// Cache is an LRU cache. It is not safe for concurrent access.
type Cache struct {
maxBytes int64
nbytes int64
ll *list.List
cache map[string]*list.Element
// optional and executed when an entry is purged.
OnEvicted func(key string, value Value)
}
type entry struct {
key string
value Value
}
// Value uses Len to report how many bytes it takes.
type Value interface {
Len() int
}
// New is the Constructor of Cache
func New(maxBytes int64, onEvicted func(string, Value)) *Cache {
return &Cache{
maxBytes: maxBytes,
ll: list.New(),
cache: make(map[string]*list.Element),
OnEvicted: onEvicted,
}
}
// Add adds a value to the cache.
func (c *Cache) Add(key string, value Value) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
c.nbytes += int64(value.Len()) - int64(kv.value.Len())
kv.value = value
} else {
ele := c.ll.PushFront(&entry{key, value})
c.cache[key] = ele
c.nbytes += int64(len(key)) + int64(value.Len())
}
for c.maxBytes != 0 && c.maxBytes < c.nbytes {
c.RemoveOldest()
}
}
// Get looks up a key's value.
func (c *Cache) Get(key string) (value Value, ok bool) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
return kv.value, true
}
return
}
// RemoveOldest removes the oldest item
func (c *Cache) RemoveOldest() {
ele := c.ll.Back()
if ele != nil {
c.ll.Remove(ele)
kv := ele.Value.(*entry)
delete(c.cache, kv.key)
c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len())
if c.OnEvicted != nil {
c.OnEvicted(kv.key, kv.value)
}
}
}
// Len returns the number of cache entries.
func (c *Cache) Len() int {
return c.ll.Len()
}
================================================
FILE: gee-cache/day4-consistent-hash/geecache/lru/lru_test.go
================================================
package lru
import (
"reflect"
"testing"
)
type String string
func (d String) Len() int {
return len(d)
}
func TestGet(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key1", String("1234"))
if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" {
t.Fatalf("cache hit key1=1234 failed")
}
if _, ok := lru.Get("key2"); ok {
t.Fatalf("cache miss key2 failed")
}
}
func TestRemoveoldest(t *testing.T) {
k1, k2, k3 := "key1", "key2", "k3"
v1, v2, v3 := "value1", "value2", "v3"
cap := len(k1 + k2 + v1 + v2)
lru := New(int64(cap), nil)
lru.Add(k1, String(v1))
lru.Add(k2, String(v2))
lru.Add(k3, String(v3))
if _, ok := lru.Get("key1"); ok || lru.Len() != 2 {
t.Fatalf("Removeoldest key1 failed")
}
}
func TestOnEvicted(t *testing.T) {
keys := make([]string, 0)
callback := func(key string, value Value) {
keys = append(keys, key)
}
lru := New(int64(10), callback)
lru.Add("key1", String("123456"))
lru.Add("k2", String("k2"))
lru.Add("k3", String("k3"))
lru.Add("k4", String("k4"))
expect := []string{"key1", "k2"}
if !reflect.DeepEqual(expect, keys) {
t.Fatalf("Call OnEvicted failed, expected keys to equal %s", expect)
}
}
func TestAdd(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key", String("1"))
lru.Add("key", String("111"))
if lru.nbytes != int64(len("key")+len("111")) {
t.Fatal("expected 6 but got", lru.nbytes)
}
}
================================================
FILE: gee-cache/day4-consistent-hash/go.mod
================================================
module example
go 1.13
require geecache v0.0.0
replace geecache => ./geecache
================================================
FILE: gee-cache/day4-consistent-hash/main.go
================================================
package main
/*
$ curl http://localhost:9999/_geecache/scores/Tom
630
$ curl http://localhost:9999/_geecache/scores/kkk
kkk not exist
*/
import (
"fmt"
"geecache"
"log"
"net/http"
)
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
func main() {
geecache.NewGroup("scores", 2<<10, geecache.GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
addr := "localhost:9999"
peers := geecache.NewHTTPPool(addr)
log.Println("geecache is running at", addr)
log.Fatal(http.ListenAndServe(addr, peers))
}
================================================
FILE: gee-cache/day5-multi-nodes/geecache/byteview.go
================================================
package geecache
// A ByteView holds an immutable view of bytes.
type ByteView struct {
b []byte
}
// Len returns the view's length
func (v ByteView) Len() int {
return len(v.b)
}
// ByteSlice returns a copy of the data as a byte slice.
func (v ByteView) ByteSlice() []byte {
return cloneBytes(v.b)
}
// String returns the data as a string, making a copy if necessary.
func (v ByteView) String() string {
return string(v.b)
}
func cloneBytes(b []byte) []byte {
c := make([]byte, len(b))
copy(c, b)
return c
}
================================================
FILE: gee-cache/day5-multi-nodes/geecache/cache.go
================================================
package geecache
import (
"geecache/lru"
"sync"
)
type cache struct {
mu sync.Mutex
lru *lru.Cache
cacheBytes int64
}
func (c *cache) add(key string, value ByteView) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
c.lru = lru.New(c.cacheBytes, nil)
}
c.lru.Add(key, value)
}
func (c *cache) get(key string) (value ByteView, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
return
}
if v, ok := c.lru.Get(key); ok {
return v.(ByteView), ok
}
return
}
================================================
FILE: gee-cache/day5-multi-nodes/geecache/consistenthash/consistenthash.go
================================================
package consistenthash
import (
"hash/crc32"
"sort"
"strconv"
)
// Hash maps bytes to uint32
type Hash func(data []byte) uint32
// Map contains all hashed keys.
type Map struct {
hash Hash
replicas int
keys []int // Sorted
hashMap map[int]string
}
// New creates a Map instance
func New(replicas int, fn Hash) *Map {
m := &Map{
replicas: replicas,
hash: fn,
hashMap: make(map[int]string),
}
if m.hash == nil {
m.hash = crc32.ChecksumIEEE
}
return m
}
// Add adds some keys to the hash.
func (m *Map) Add(keys ...string) {
for _, key := range keys {
for i := 0; i < m.replicas; i++ {
hash := int(m.hash([]byte(strconv.Itoa(i) + key)))
m.keys = append(m.keys, hash)
m.hashMap[hash] = key
}
}
sort.Ints(m.keys)
}
// Get gets the closest item in the hash to the provided key.
func (m *Map) Get(key string) string {
if len(m.keys) == 0 {
return ""
}
hash := int(m.hash([]byte(key)))
// Binary search for appropriate replica.
idx := sort.Search(len(m.keys), func(i int) bool {
return m.keys[i] >= hash
})
return m.hashMap[m.keys[idx%len(m.keys)]]
}
================================================
FILE: gee-cache/day5-multi-nodes/geecache/consistenthash/consistenthash_test.go
================================================
package consistenthash
import (
"strconv"
"testing"
)
func TestHashing(t *testing.T) {
hash := New(3, func(key []byte) uint32 {
i, _ := strconv.Atoi(string(key))
return uint32(i)
})
// Given the above hash function, this will give replicas with "hashes":
// 2, 4, 6, 12, 14, 16, 22, 24, 26
hash.Add("6", "4", "2")
testCases := map[string]string{
"2": "2",
"11": "2",
"23": "4",
"27": "2",
}
for k, v := range testCases {
if hash.Get(k) != v {
t.Errorf("Asking for %s, should have yielded %s", k, v)
}
}
// Adds 8, 18, 28
hash.Add("8")
// 27 should now map to 8.
testCases["27"] = "8"
for k, v := range testCases {
if hash.Get(k) != v {
t.Errorf("Asking for %s, should have yielded %s", k, v)
}
}
}
================================================
FILE: gee-cache/day5-multi-nodes/geecache/geecache.go
================================================
package geecache
import (
"fmt"
"log"
"sync"
)
// A Group is a cache namespace and its associated data, loaded by a Getter and spread over a pool of peers.
type Group struct {
name string
getter Getter
mainCache cache
peers PeerPicker
}
// A Getter loads data for a key.
type Getter interface {
Get(key string) ([]byte, error)
}
// A GetterFunc implements Getter with a function.
type GetterFunc func(key string) ([]byte, error)
// Get implements Getter interface function
func (f GetterFunc) Get(key string) ([]byte, error) {
return f(key)
}
var (
mu sync.RWMutex
groups = make(map[string]*Group)
)
// NewGroup creates a new instance of Group.
func NewGroup(name string, cacheBytes int64, getter Getter) *Group {
if getter == nil {
panic("nil Getter")
}
mu.Lock()
defer mu.Unlock()
g := &Group{
name: name,
getter: getter,
mainCache: cache{cacheBytes: cacheBytes},
}
groups[name] = g
return g
}
// GetGroup returns the named group previously created with NewGroup, or
// nil if there's no such group.
func GetGroup(name string) *Group {
mu.RLock()
g := groups[name]
mu.RUnlock()
return g
}
// Get returns the value for key from the cache, loading it from a peer or the local Getter on a miss.
func (g *Group) Get(key string) (ByteView, error) {
if key == "" {
return ByteView{}, fmt.Errorf("key is required")
}
if v, ok := g.mainCache.get(key); ok {
log.Println("[GeeCache] hit")
return v, nil
}
return g.load(key)
}
// RegisterPeers registers a PeerPicker for choosing remote peer
func (g *Group) RegisterPeers(peers PeerPicker) {
if g.peers != nil {
panic("RegisterPeers called more than once")
}
g.peers = peers
}
func (g *Group) load(key string) (value ByteView, err error) {
if g.peers != nil {
if peer, ok := g.peers.PickPeer(key); ok {
if value, err = g.getFromPeer(peer, key); err == nil {
return value, nil
}
log.Println("[GeeCache] Failed to get from peer", err)
}
}
return g.getLocally(key)
}
func (g *Group) populateCache(key string, value ByteView) {
g.mainCache.add(key, value)
}
func (g *Group) getLocally(key string) (ByteView, error) {
bytes, err := g.getter.Get(key)
if err != nil {
return ByteView{}, err
}
value := ByteView{b: cloneBytes(bytes)}
g.populateCache(key, value)
return value, nil
}
func (g *Group) getFromPeer(peer PeerGetter, key string) (ByteView, error) {
bytes, err := peer.Get(g.name, key)
if err != nil {
return ByteView{}, err
}
return ByteView{b: bytes}, nil
}
================================================
FILE: gee-cache/day5-multi-nodes/geecache/geecache_test.go
================================================
package geecache
import (
"fmt"
"log"
"reflect"
"testing"
)
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
func TestGetter(t *testing.T) {
var f Getter = GetterFunc(func(key string) ([]byte, error) {
return []byte(key), nil
})
expect := []byte("key")
if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) {
t.Fatal("callback failed")
}
}
func TestGet(t *testing.T) {
loadCounts := make(map[string]int, len(db))
gee := NewGroup("scores", 2<<10, GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
if _, ok := loadCounts[key]; !ok {
loadCounts[key] = 0
}
loadCounts[key]++
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
for k, v := range db {
if view, err := gee.Get(k); err != nil || view.String() != v {
t.Fatalf("failed to get value of %s", k)
}
if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 {
t.Fatalf("cache %s miss", k)
}
}
if view, err := gee.Get("unknown"); err == nil {
t.Fatalf("the value of unknown should be empty, but got %s", view)
}
}
func TestGetGroup(t *testing.T) {
groupName := "scores"
NewGroup(groupName, 2<<10, GetterFunc(
func(key string) (bytes []byte, err error) { return }))
if group := GetGroup(groupName); group == nil || group.name != groupName {
t.Fatalf("group %s not exist", groupName)
}
if group := GetGroup(groupName + "111"); group != nil {
t.Fatalf("expect nil, but %s got", group.name)
}
}
================================================
FILE: gee-cache/day5-multi-nodes/geecache/go.mod
================================================
module geecache
go 1.13
================================================
FILE: gee-cache/day5-multi-nodes/geecache/http.go
================================================
package geecache
import (
"fmt"
"geecache/consistenthash"
"io/ioutil"
"log"
"net/http"
"net/url"
"strings"
"sync"
)
const (
defaultBasePath = "/_geecache/"
defaultReplicas = 50
)
// HTTPPool implements PeerPicker for a pool of HTTP peers.
type HTTPPool struct {
// this peer's base URL, e.g. "https://example.net:8000"
self string
basePath string
mu sync.Mutex // guards peers and httpGetters
peers *consistenthash.Map
httpGetters map[string]*httpGetter // keyed by e.g. "http://10.0.0.2:8008"
}
// NewHTTPPool initializes an HTTP pool of peers.
func NewHTTPPool(self string) *HTTPPool {
return &HTTPPool{
self: self,
basePath: defaultBasePath,
}
}
// Log prints a message prefixed with the server name.
func (p *HTTPPool) Log(format string, v ...interface{}) {
log.Printf("[Server %s] %s", p.self, fmt.Sprintf(format, v...))
}
// ServeHTTP handles all HTTP requests.
func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, p.basePath) {
panic("HTTPPool serving unexpected path: " + r.URL.Path)
}
p.Log("%s %s", r.Method, r.URL.Path)
// /<basepath>/<groupname>/<key> required
parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2)
if len(parts) != 2 {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
groupName := parts[0]
key := parts[1]
group := GetGroup(groupName)
if group == nil {
http.Error(w, "no such group: "+groupName, http.StatusNotFound)
return
}
view, err := group.Get(key)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
w.Write(view.ByteSlice())
}
// Set updates the pool's list of peers.
func (p *HTTPPool) Set(peers ...string) {
p.mu.Lock()
defer p.mu.Unlock()
p.peers = consistenthash.New(defaultReplicas, nil)
p.peers.Add(peers...)
p.httpGetters = make(map[string]*httpGetter, len(peers))
for _, peer := range peers {
p.httpGetters[peer] = &httpGetter{baseURL: peer + p.basePath}
}
}
// PickPeer picks a peer according to key
func (p *HTTPPool) PickPeer(key string) (PeerGetter, bool) {
p.mu.Lock()
defer p.mu.Unlock()
if peer := p.peers.Get(key); peer != "" && peer != p.self {
p.Log("Pick peer %s", peer)
return p.httpGetters[peer], true
}
return nil, false
}
var _ PeerPicker = (*HTTPPool)(nil)
type httpGetter struct {
baseURL string
}
func (h *httpGetter) Get(group string, key string) ([]byte, error) {
u := fmt.Sprintf(
"%v%v/%v",
h.baseURL,
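// PathEscape rather than QueryEscape: group and key are path segments,
// and QueryEscape would turn a space into "+", which the server does not decode.
url.PathEscape(group),
url.PathEscape(key),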
)
res, err := http.Get(u)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("server returned: %v", res.Status)
}
bytes, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("reading response body: %v", err)
}
return bytes, nil
}
var _ PeerGetter = (*httpGetter)(nil)
================================================
FILE: gee-cache/day5-multi-nodes/geecache/lru/lru.go
================================================
package lru
import "container/list"
// Cache is an LRU cache. It is not safe for concurrent access.
type Cache struct {
maxBytes int64
nbytes int64
ll *list.List
cache map[string]*list.Element
// optional and executed when an entry is purged.
OnEvicted func(key string, value Value)
}
type entry struct {
key string
value Value
}
// Value uses Len to report how many bytes it takes.
type Value interface {
Len() int
}
// New is the Constructor of Cache
func New(maxBytes int64, onEvicted func(string, Value)) *Cache {
return &Cache{
maxBytes: maxBytes,
ll: list.New(),
cache: make(map[string]*list.Element),
OnEvicted: onEvicted,
}
}
// Add adds a value to the cache.
func (c *Cache) Add(key string, value Value) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
c.nbytes += int64(value.Len()) - int64(kv.value.Len())
kv.value = value
} else {
ele := c.ll.PushFront(&entry{key, value})
c.cache[key] = ele
c.nbytes += int64(len(key)) + int64(value.Len())
}
for c.maxBytes != 0 && c.maxBytes < c.nbytes {
c.RemoveOldest()
}
}
// Get looks up a key's value.
func (c *Cache) Get(key string) (value Value, ok bool) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
return kv.value, true
}
return
}
// RemoveOldest removes the oldest item
func (c *Cache) RemoveOldest() {
ele := c.ll.Back()
if ele != nil {
c.ll.Remove(ele)
kv := ele.Value.(*entry)
delete(c.cache, kv.key)
c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len())
if c.OnEvicted != nil {
c.OnEvicted(kv.key, kv.value)
}
}
}
// Len returns the number of cache entries.
func (c *Cache) Len() int {
return c.ll.Len()
}
================================================
FILE: gee-cache/day5-multi-nodes/geecache/lru/lru_test.go
================================================
package lru
import (
"reflect"
"testing"
)
type String string
func (d String) Len() int {
return len(d)
}
func TestGet(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key1", String("1234"))
if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" {
t.Fatalf("cache hit key1=1234 failed")
}
if _, ok := lru.Get("key2"); ok {
t.Fatalf("cache miss key2 failed")
}
}
func TestRemoveoldest(t *testing.T) {
k1, k2, k3 := "key1", "key2", "k3"
v1, v2, v3 := "value1", "value2", "v3"
cap := len(k1 + k2 + v1 + v2)
lru := New(int64(cap), nil)
lru.Add(k1, String(v1))
lru.Add(k2, String(v2))
lru.Add(k3, String(v3))
if _, ok := lru.Get("key1"); ok || lru.Len() != 2 {
t.Fatalf("Removeoldest key1 failed")
}
}
func TestOnEvicted(t *testing.T) {
keys := make([]string, 0)
callback := func(key string, value Value) {
keys = append(keys, key)
}
lru := New(int64(10), callback)
lru.Add("key1", String("123456"))
lru.Add("k2", String("k2"))
lru.Add("k3", String("k3"))
lru.Add("k4", String("k4"))
expect := []string{"key1", "k2"}
if !reflect.DeepEqual(expect, keys) {
t.Fatalf("Call OnEvicted failed, expected keys to equal %s", expect)
}
}
func TestAdd(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key", String("1"))
lru.Add("key", String("111"))
if lru.nbytes != int64(len("key")+len("111")) {
t.Fatal("expected 6 but got", lru.nbytes)
}
}
================================================
FILE: gee-cache/day5-multi-nodes/geecache/peers.go
================================================
package geecache
// PeerPicker is the interface that must be implemented to locate
// the peer that owns a specific key.
type PeerPicker interface {
PickPeer(key string) (peer PeerGetter, ok bool)
}
// PeerGetter is the interface that must be implemented by a peer.
type PeerGetter interface {
Get(group string, key string) ([]byte, error)
}
================================================
FILE: gee-cache/day5-multi-nodes/go.mod
================================================
module example
go 1.13
require geecache v0.0.0
replace geecache => ./geecache
================================================
FILE: gee-cache/day5-multi-nodes/main.go
================================================
package main
/*
$ curl "http://localhost:9999/api?key=Tom"
630
$ curl "http://localhost:9999/api?key=kkk"
kkk not exist
*/
import (
"flag"
"fmt"
"geecache"
"log"
"net/http"
)
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
func createGroup() *geecache.Group {
return geecache.NewGroup("scores", 2<<10, geecache.GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
}
func startCacheServer(addr string, addrs []string, gee *geecache.Group) {
peers := geecache.NewHTTPPool(addr)
peers.Set(addrs...)
gee.RegisterPeers(peers)
log.Println("geecache is running at", addr)
log.Fatal(http.ListenAndServe(addr[7:], peers))
}
func startAPIServer(apiAddr string, gee *geecache.Group) {
http.Handle("/api", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
key := r.URL.Query().Get("key")
view, err := gee.Get(key)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
w.Write(view.ByteSlice())
}))
log.Println("frontend server is running at", apiAddr)
log.Fatal(http.ListenAndServe(apiAddr[7:], nil))
}
func main() {
var port int
var api bool
flag.IntVar(&port, "port", 8001, "Geecache server port")
flag.BoolVar(&api, "api", false, "Start an API server?")
flag.Parse()
apiAddr := "http://localhost:9999"
addrMap := map[int]string{
8001: "http://localhost:8001",
8002: "http://localhost:8002",
8003: "http://localhost:8003",
}
var addrs []string
for _, v := range addrMap {
addrs = append(addrs, v)
}
gee := createGroup()
if api {
go startAPIServer(apiAddr, gee)
}
startCacheServer(addrMap[port], addrs, gee)
}
================================================
FILE: gee-cache/day5-multi-nodes/run.sh
================================================
#!/bin/bash
trap "rm server;kill 0" EXIT
go build -o server
./server -port=8001 &
./server -port=8002 &
./server -port=8003 -api=1 &
sleep 2
echo ">>> start test"
curl "http://localhost:9999/api?key=Tom" &
curl "http://localhost:9999/api?key=Tom" &
curl "http://localhost:9999/api?key=Tom" &
wait
================================================
FILE: gee-cache/day6-single-flight/geecache/byteview.go
================================================
package geecache
// A ByteView holds an immutable view of bytes.
type ByteView struct {
b []byte
}
// Len returns the view's length
func (v ByteView) Len() int {
return len(v.b)
}
// ByteSlice returns a copy of the data as a byte slice.
func (v ByteView) ByteSlice() []byte {
return cloneBytes(v.b)
}
// String returns the data as a string, making a copy if necessary.
func (v ByteView) String() string {
return string(v.b)
}
func cloneBytes(b []byte) []byte {
c := make([]byte, len(b))
copy(c, b)
return c
}
================================================
FILE: gee-cache/day6-single-flight/geecache/cache.go
================================================
package geecache
import (
"geecache/lru"
"sync"
)
type cache struct {
mu sync.Mutex
lru *lru.Cache
cacheBytes int64
}
func (c *cache) add(key string, value ByteView) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
c.lru = lru.New(c.cacheBytes, nil)
}
c.lru.Add(key, value)
}
func (c *cache) get(key string) (value ByteView, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
return
}
if v, ok := c.lru.Get(key); ok {
return v.(ByteView), ok
}
return
}
================================================
FILE: gee-cache/day6-single-flight/geecache/consistenthash/consistenthash.go
================================================
package consistenthash
import (
"hash/crc32"
"sort"
"strconv"
)
// Hash maps bytes to uint32
type Hash func(data []byte) uint32
// Map contains all hashed keys
type Map struct {
hash Hash
replicas int
keys []int // Sorted
hashMap map[int]string
}
// New creates a Map instance
func New(replicas int, fn Hash) *Map {
m := &Map{
replicas: replicas,
hash: fn,
hashMap: make(map[int]string),
}
if m.hash == nil {
m.hash = crc32.ChecksumIEEE
}
return m
}
// Add adds some keys to the hash.
func (m *Map) Add(keys ...string) {
for _, key := range keys {
for i := 0; i < m.replicas; i++ {
hash := int(m.hash([]byte(strconv.Itoa(i) + key)))
m.keys = append(m.keys, hash)
m.hashMap[hash] = key
}
}
sort.Ints(m.keys)
}
// Get gets the closest item in the hash to the provided key.
func (m *Map) Get(key string) string {
if len(m.keys) == 0 {
return ""
}
hash := int(m.hash([]byte(key)))
// Binary search for appropriate replica.
idx := sort.Search(len(m.keys), func(i int) bool {
return m.keys[i] >= hash
})
return m.hashMap[m.keys[idx%len(m.keys)]]
}
================================================
FILE: gee-cache/day6-single-flight/geecache/consistenthash/consistenthash_test.go
================================================
package consistenthash
import (
"strconv"
"testing"
)
func TestHashing(t *testing.T) {
hash := New(3, func(key []byte) uint32 {
i, _ := strconv.Atoi(string(key))
return uint32(i)
})
// Given the above hash function, this will give replicas with "hashes":
// 2, 4, 6, 12, 14, 16, 22, 24, 26
hash.Add("6", "4", "2")
testCases := map[string]string{
"2": "2",
"11": "2",
"23": "4",
"27": "2",
}
for k, v := range testCases {
if hash.Get(k) != v {
t.Errorf("Asking for %s, should have yielded %s", k, v)
}
}
// Adds 8, 18, 28
hash.Add("8")
// 27 should now map to 8.
testCases["27"] = "8"
for k, v := range testCases {
if hash.Get(k) != v {
t.Errorf("Asking for %s, should have yielded %s", k, v)
}
}
}
================================================
FILE: gee-cache/day6-single-flight/geecache/geecache.go
================================================
package geecache
import (
"fmt"
"geecache/singleflight"
"log"
"sync"
)
// A Group is a cache namespace and its associated data, spread over one or more peers.
type Group struct {
name string
getter Getter
mainCache cache
peers PeerPicker
// use singleflight.Group to make sure that
// each key is only fetched once
loader *singleflight.Group
}
// A Getter loads data for a key.
type Getter interface {
Get(key string) ([]byte, error)
}
// A GetterFunc implements Getter with a function.
type GetterFunc func(key string) ([]byte, error)
// Get implements Getter interface function
func (f GetterFunc) Get(key string) ([]byte, error) {
return f(key)
}
var (
mu sync.RWMutex
groups = make(map[string]*Group)
)
// NewGroup creates a new instance of Group
func NewGroup(name string, cacheBytes int64, getter Getter) *Group {
if getter == nil {
panic("nil Getter")
}
mu.Lock()
defer mu.Unlock()
g := &Group{
name: name,
getter: getter,
mainCache: cache{cacheBytes: cacheBytes},
loader: &singleflight.Group{},
}
groups[name] = g
return g
}
// GetGroup returns the named group previously created with NewGroup, or
// nil if there's no such group.
func GetGroup(name string) *Group {
mu.RLock()
g := groups[name]
mu.RUnlock()
return g
}
// Get value for a key from cache
func (g *Group) Get(key string) (ByteView, error) {
if key == "" {
return ByteView{}, fmt.Errorf("key is required")
}
if v, ok := g.mainCache.get(key); ok {
log.Println("[GeeCache] hit")
return v, nil
}
return g.load(key)
}
// RegisterPeers registers a PeerPicker for choosing remote peer
func (g *Group) RegisterPeers(peers PeerPicker) {
if g.peers != nil {
panic("RegisterPeers called more than once")
}
g.peers = peers
}
func (g *Group) load(key string) (value ByteView, err error) {
// each key is only fetched once (either locally or remotely)
// regardless of the number of concurrent callers.
viewi, err := g.loader.Do(key, func() (interface{}, error) {
if g.peers != nil {
if peer, ok := g.peers.PickPeer(key); ok {
if value, err = g.getFromPeer(peer, key); err == nil {
return value, nil
}
log.Println("[GeeCache] Failed to get from peer", err)
}
}
return g.getLocally(key)
})
if err == nil {
return viewi.(ByteView), nil
}
return
}
func (g *Group) populateCache(key string, value ByteView) {
g.mainCache.add(key, value)
}
func (g *Group) getLocally(key string) (ByteView, error) {
bytes, err := g.getter.Get(key)
if err != nil {
return ByteView{}, err
}
value := ByteView{b: cloneBytes(bytes)}
g.populateCache(key, value)
return value, nil
}
func (g *Group) getFromPeer(peer PeerGetter, key string) (ByteView, error) {
bytes, err := peer.Get(g.name, key)
if err != nil {
return ByteView{}, err
}
return ByteView{b: bytes}, nil
}
================================================
FILE: gee-cache/day6-single-flight/geecache/geecache_test.go
================================================
package geecache
import (
"fmt"
"log"
"reflect"
"testing"
)
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
func TestGetter(t *testing.T) {
var f Getter = GetterFunc(func(key string) ([]byte, error) {
return []byte(key), nil
})
expect := []byte("key")
if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) {
t.Fatal("callback failed")
}
}
func TestGet(t *testing.T) {
loadCounts := make(map[string]int, len(db))
gee := NewGroup("scores", 2<<10, GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
if _, ok := loadCounts[key]; !ok {
loadCounts[key] = 0
}
loadCounts[key]++
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
for k, v := range db {
if view, err := gee.Get(k); err != nil || view.String() != v {
t.Fatalf("failed to get value of %s", k)
}
if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 {
t.Fatalf("cache %s miss", k)
}
}
if view, err := gee.Get("unknown"); err == nil {
t.Fatalf("the value of unknown should be empty, but got %s", view)
}
}
func TestGetGroup(t *testing.T) {
groupName := "scores"
NewGroup(groupName, 2<<10, GetterFunc(
func(key string) (bytes []byte, err error) { return }))
if group := GetGroup(groupName); group == nil || group.name != groupName {
t.Fatalf("group %s not exist", groupName)
}
if group := GetGroup(groupName + "111"); group != nil {
t.Fatalf("expect nil, but %s got", group.name)
}
}
================================================
FILE: gee-cache/day6-single-flight/geecache/go.mod
================================================
module geecache
go 1.13
================================================
FILE: gee-cache/day6-single-flight/geecache/http.go
================================================
package geecache
import (
"fmt"
"geecache/consistenthash"
"io/ioutil"
"log"
"net/http"
"net/url"
"strings"
"sync"
)
const (
defaultBasePath = "/_geecache/"
defaultReplicas = 50
)
// HTTPPool implements PeerPicker for a pool of HTTP peers.
type HTTPPool struct {
// this peer's base URL, e.g. "https://example.net:8000"
self string
basePath string
mu sync.Mutex // guards peers and httpGetters
peers *consistenthash.Map
httpGetters map[string]*httpGetter // keyed by e.g. "http://10.0.0.2:8008"
}
// NewHTTPPool initializes an HTTP pool of peers.
func NewHTTPPool(self string) *HTTPPool {
return &HTTPPool{
self: self,
basePath: defaultBasePath,
}
}
// Log info with server name
func (p *HTTPPool) Log(format string, v ...interface{}) {
log.Printf("[Server %s] %s", p.self, fmt.Sprintf(format, v...))
}
// ServeHTTP handles all HTTP requests
func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, p.basePath) {
panic("HTTPPool serving unexpected path: " + r.URL.Path)
}
p.Log("%s %s", r.Method, r.URL.Path)
// /<basepath>/<groupname>/<key> required
parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2)
if len(parts) != 2 {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
groupName := parts[0]
key := parts[1]
group := GetGroup(groupName)
if group == nil {
http.Error(w, "no such group: "+groupName, http.StatusNotFound)
return
}
view, err := group.Get(key)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
w.Write(view.ByteSlice())
}
// Set updates the pool's list of peers.
func (p *HTTPPool) Set(peers ...string) {
p.mu.Lock()
defer p.mu.Unlock()
p.peers = consistenthash.New(defaultReplicas, nil)
p.peers.Add(peers...)
p.httpGetters = make(map[string]*httpGetter, len(peers))
for _, peer := range peers {
p.httpGetters[peer] = &httpGetter{baseURL: peer + p.basePath}
}
}
// PickPeer picks a peer according to key
func (p *HTTPPool) PickPeer(key string) (PeerGetter, bool) {
p.mu.Lock()
defer p.mu.Unlock()
if peer := p.peers.Get(key); peer != "" && peer != p.self {
p.Log("Pick peer %s", peer)
return p.httpGetters[peer], true
}
return nil, false
}
var _ PeerPicker = (*HTTPPool)(nil)
type httpGetter struct {
baseURL string
}
func (h *httpGetter) Get(group string, key string) ([]byte, error) {
u := fmt.Sprintf(
"%v%v/%v",
h.baseURL,
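// PathEscape, not QueryEscape: these values are path segments, and
// QueryEscape would turn a space into "+", which the server does not decode.
url.PathEscape(group),
url.PathEscape(key),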
)
res, err := http.Get(u)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("server returned: %v", res.Status)
}
bytes, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("reading response body: %v", err)
}
return bytes, nil
}
var _ PeerGetter = (*httpGetter)(nil)
================================================
FILE: gee-cache/day6-single-flight/geecache/lru/lru.go
================================================
package lru
import "container/list"
// Cache is an LRU cache. It is not safe for concurrent access.
type Cache struct {
maxBytes int64
nbytes int64
ll *list.List
cache map[string]*list.Element
// optional and executed when an entry is purged.
OnEvicted func(key string, value Value)
}
type entry struct {
key string
value Value
}
// Value uses Len to report how many bytes it takes
type Value interface {
Len() int
}
// New is the Constructor of Cache
func New(maxBytes int64, onEvicted func(string, Value)) *Cache {
return &Cache{
maxBytes: maxBytes,
ll: list.New(),
cache: make(map[string]*list.Element),
OnEvicted: onEvicted,
}
}
// Add adds a value to the cache.
func (c *Cache) Add(key string, value Value) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
c.nbytes += int64(value.Len()) - int64(kv.value.Len())
kv.value = value
} else {
ele := c.ll.PushFront(&entry{key, value})
c.cache[key] = ele
c.nbytes += int64(len(key)) + int64(value.Len())
}
for c.maxBytes != 0 && c.maxBytes < c.nbytes {
c.RemoveOldest()
}
}
// Get looks up a key's value
func (c *Cache) Get(key string) (value Value, ok bool) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
return kv.value, true
}
return
}
// RemoveOldest removes the oldest item
func (c *Cache) RemoveOldest() {
ele := c.ll.Back()
if ele != nil {
c.ll.Remove(ele)
kv := ele.Value.(*entry)
delete(c.cache, kv.key)
c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len())
if c.OnEvicted != nil {
c.OnEvicted(kv.key, kv.value)
}
}
}
// Len returns the number of cache entries.
func (c *Cache) Len() int {
return c.ll.Len()
}
================================================
FILE: gee-cache/day6-single-flight/geecache/lru/lru_test.go
================================================
package lru
import (
"reflect"
"testing"
)
type String string
func (d String) Len() int {
return len(d)
}
func TestGet(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key1", String("1234"))
if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" {
t.Fatalf("cache hit key1=1234 failed")
}
if _, ok := lru.Get("key2"); ok {
t.Fatalf("cache miss key2 failed")
}
}
func TestRemoveoldest(t *testing.T) {
k1, k2, k3 := "key1", "key2", "k3"
v1, v2, v3 := "value1", "value2", "v3"
cap := len(k1 + k2 + v1 + v2)
lru := New(int64(cap), nil)
lru.Add(k1, String(v1))
lru.Add(k2, String(v2))
lru.Add(k3, String(v3))
if _, ok := lru.Get("key1"); ok || lru.Len() != 2 {
t.Fatalf("Removeoldest key1 failed")
}
}
func TestOnEvicted(t *testing.T) {
keys := make([]string, 0)
callback := func(key string, value Value) {
keys = append(keys, key)
}
lru := New(int64(10), callback)
lru.Add("key1", String("123456"))
lru.Add("k2", String("k2"))
lru.Add("k3", String("k3"))
lru.Add("k4", String("k4"))
expect := []string{"key1", "k2"}
if !reflect.DeepEqual(expect, keys) {
t.Fatalf("Call OnEvicted failed, expect keys equals to %s", expect)
}
}
func TestAdd(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key", String("1"))
lru.Add("key", String("111"))
if lru.nbytes != int64(len("key")+len("111")) {
t.Fatal("expected 6 but got", lru.nbytes)
}
}
================================================
FILE: gee-cache/day6-single-flight/geecache/peers.go
================================================
package geecache
// PeerPicker is the interface that must be implemented to locate
// the peer that owns a specific key.
type PeerPicker interface {
PickPeer(key string) (peer PeerGetter, ok bool)
}
// PeerGetter is the interface that must be implemented by a peer.
type PeerGetter interface {
Get(group string, key string) ([]byte, error)
}
================================================
FILE: gee-cache/day6-single-flight/geecache/singleflight/singleflight.go
================================================
package singleflight
import "sync"
// call is an in-flight or completed Do call
type call struct {
wg sync.WaitGroup
val interface{}
err error
}
// Group represents a class of work and forms a namespace in which
// units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[string]*call // lazily initialized
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
g.mu.Unlock()
c.wg.Wait()
return c.val, c.err
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
c.val, c.err = fn()
c.wg.Done()
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
return c.val, c.err
}
================================================
FILE: gee-cache/day6-single-flight/geecache/singleflight/singleflight_test.go
================================================
package singleflight
import (
"testing"
)
func TestDo(t *testing.T) {
var g Group
v, err := g.Do("key", func() (interface{}, error) {
return "bar", nil
})
if v != "bar" || err != nil {
t.Errorf("Do v = %v, error = %v", v, err)
}
}
================================================
FILE: gee-cache/day6-single-flight/go.mod
================================================
module example
go 1.13
require geecache v0.0.0
replace geecache => ./geecache
================================================
FILE: gee-cache/day6-single-flight/main.go
================================================
package main
/*
$ curl "http://localhost:9999/api?key=Tom"
630
$ curl "http://localhost:9999/api?key=kkk"
kkk not exist
*/
import (
"flag"
"fmt"
"geecache"
"log"
"net/http"
)
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
func createGroup() *geecache.Group {
return geecache.NewGroup("scores", 2<<10, geecache.GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
}
func startCacheServer(addr string, addrs []string, gee *geecache.Group) {
peers := geecache.NewHTTPPool(addr)
peers.Set(addrs...)
gee.RegisterPeers(peers)
log.Println("geecache is running at", addr)
log.Fatal(http.ListenAndServe(addr[7:], peers))
}
func startAPIServer(apiAddr string, gee *geecache.Group) {
http.Handle("/api", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
key := r.URL.Query().Get("key")
view, err := gee.Get(key)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
w.Write(view.ByteSlice())
}))
log.Println("frontend server is running at", apiAddr)
log.Fatal(http.ListenAndServe(apiAddr[7:], nil))
}
func main() {
var port int
var api bool
flag.IntVar(&port, "port", 8001, "Geecache server port")
flag.BoolVar(&api, "api", false, "Start an API server?")
flag.Parse()
apiAddr := "http://localhost:9999"
addrMap := map[int]string{
8001: "http://localhost:8001",
8002: "http://localhost:8002",
8003: "http://localhost:8003",
}
var addrs []string
for _, v := range addrMap {
addrs = append(addrs, v)
}
gee := createGroup()
if api {
go startAPIServer(apiAddr, gee)
}
startCacheServer(addrMap[port], addrs, gee)
}
================================================
FILE: gee-cache/day6-single-flight/run.sh
================================================
#!/bin/bash
trap "rm server;kill 0" EXIT
go build -o server
./server -port=8001 &
./server -port=8002 &
./server -port=8003 -api=1 &
sleep 2
echo ">>> start test"
curl "http://localhost:9999/api?key=Tom" &
curl "http://localhost:9999/api?key=Tom" &
curl "http://localhost:9999/api?key=Tom" &
curl "http://localhost:9999/api?key=Tom" &
curl "http://localhost:9999/api?key=Tom" &
curl "http://localhost:9999/api?key=Tom" &
curl "http://localhost:9999/api?key=Tom" &
wait
================================================
FILE: gee-cache/day7-proto-buf/geecache/byteview.go
================================================
package geecache
// A ByteView holds an immutable view of bytes.
type ByteView struct {
b []byte
}
// Len returns the view's length
func (v ByteView) Len() int {
return len(v.b)
}
// ByteSlice returns a copy of the data as a byte slice.
func (v ByteView) ByteSlice() []byte {
return cloneBytes(v.b)
}
// String returns the data as a string, making a copy if necessary.
func (v ByteView) String() string {
return string(v.b)
}
func cloneBytes(b []byte) []byte {
c := make([]byte, len(b))
copy(c, b)
return c
}
================================================
FILE: gee-cache/day7-proto-buf/geecache/cache.go
================================================
package geecache
import (
"geecache/lru"
"sync"
)
type cache struct {
mu sync.Mutex
lru *lru.Cache
cacheBytes int64
}
func (c *cache) add(key string, value ByteView) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
c.lru = lru.New(c.cacheBytes, nil)
}
c.lru.Add(key, value)
}
func (c *cache) get(key string) (value ByteView, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
return
}
if v, ok := c.lru.Get(key); ok {
return v.(ByteView), ok
}
return
}
================================================
FILE: gee-cache/day7-proto-buf/geecache/consistenthash/consistenthash.go
================================================
package consistenthash
import (
"hash/crc32"
"sort"
"strconv"
)
// Hash maps bytes to uint32
type Hash func(data []byte) uint32
// Map contains all hashed keys
type Map struct {
hash Hash
replicas int
keys []int // Sorted
hashMap map[int]string
}
// New creates a Map instance
func New(replicas int, fn Hash) *Map {
m := &Map{
replicas: replicas,
hash: fn,
hashMap: make(map[int]string),
}
if m.hash == nil {
m.hash = crc32.ChecksumIEEE
}
return m
}
// Add adds some keys to the hash.
func (m *Map) Add(keys ...string) {
for _, key := range keys {
for i := 0; i < m.replicas; i++ {
hash := int(m.hash([]byte(strconv.Itoa(i) + key)))
m.keys = append(m.keys, hash)
m.hashMap[hash] = key
}
}
sort.Ints(m.keys)
}
// Get gets the closest item in the hash to the provided key.
func (m *Map) Get(key string) string {
if len(m.keys) == 0 {
return ""
}
hash := int(m.hash([]byte(key)))
// Binary search for appropriate replica.
idx := sort.Search(len(m.keys), func(i int) bool {
return m.keys[i] >= hash
})
return m.hashMap[m.keys[idx%len(m.keys)]]
}
================================================
FILE: gee-cache/day7-proto-buf/geecache/consistenthash/consistenthash_test.go
================================================
package consistenthash
import (
"strconv"
"testing"
)
func TestHashing(t *testing.T) {
hash := New(3, func(key []byte) uint32 {
i, _ := strconv.Atoi(string(key))
return uint32(i)
})
// Given the above hash function, this will give replicas with "hashes":
// 2, 4, 6, 12, 14, 16, 22, 24, 26
hash.Add("6", "4", "2")
testCases := map[string]string{
"2": "2",
"11": "2",
"23": "4",
"27": "2",
}
for k, v := range testCases {
if hash.Get(k) != v {
t.Errorf("Asking for %s, should have yielded %s", k, v)
}
}
// Adds 8, 18, 28
hash.Add("8")
// 27 should now map to 8.
testCases["27"] = "8"
for k, v := range testCases {
if hash.Get(k) != v {
t.Errorf("Asking for %s, should have yielded %s", k, v)
}
}
}
================================================
FILE: gee-cache/day7-proto-buf/geecache/geecache.go
================================================
package geecache
import (
"fmt"
pb "geecache/geecachepb"
"geecache/singleflight"
"log"
"sync"
)
// A Group is a cache namespace and its associated data, spread over one or more peers.
type Group struct {
name string
getter Getter
mainCache cache
peers PeerPicker
// use singleflight.Group to make sure that
// each key is only fetched once
loader *singleflight.Group
}
// A Getter loads data for a key.
type Getter interface {
Get(key string) ([]byte, error)
}
// A GetterFunc implements Getter with a function.
type GetterFunc func(key string) ([]byte, error)
// Get implements Getter interface function
func (f GetterFunc) Get(key string) ([]byte, error) {
return f(key)
}
var (
mu sync.RWMutex
groups = make(map[string]*Group)
)
// NewGroup creates a new instance of Group
func NewGroup(name string, cacheBytes int64, getter Getter) *Group {
if getter == nil {
panic("nil Getter")
}
mu.Lock()
defer mu.Unlock()
g := &Group{
name: name,
getter: getter,
mainCache: cache{cacheBytes: cacheBytes},
loader: &singleflight.Group{},
}
groups[name] = g
return g
}
// GetGroup returns the named group previously created with NewGroup, or
// nil if there's no such group.
func GetGroup(name string) *Group {
mu.RLock()
g := groups[name]
mu.RUnlock()
return g
}
// Get value for a key from cache
func (g *Group) Get(key string) (ByteView, error) {
if key == "" {
return ByteView{}, fmt.Errorf("key is required")
}
if v, ok := g.mainCache.get(key); ok {
log.Println("[GeeCache] hit")
return v, nil
}
return g.load(key)
}
// RegisterPeers registers a PeerPicker for choosing remote peer
func (g *Group) RegisterPeers(peers PeerPicker) {
if g.peers != nil {
panic("RegisterPeers called more than once")
}
g.peers = peers
}
func (g *Group) load(key string) (value ByteView, err error) {
// each key is only fetched once (either locally or remotely)
// regardless of the number of concurrent callers.
viewi, err := g.loader.Do(key, func() (interface{}, error) {
if g.peers != nil {
if peer, ok := g.peers.PickPeer(key); ok {
if value, err = g.getFromPeer(peer, key); err == nil {
return value, nil
}
log.Println("[GeeCache] Failed to get from peer", err)
}
}
return g.getLocally(key)
})
if err == nil {
return viewi.(ByteView), nil
}
return
}
func (g *Group) populateCache(key string, value ByteView) {
g.mainCache.add(key, value)
}
func (g *Group) getLocally(key string) (ByteView, error) {
bytes, err := g.getter.Get(key)
if err != nil {
return ByteView{}, err
}
value := ByteView{b: cloneBytes(bytes)}
g.populateCache(key, value)
return value, nil
}
func (g *Group) getFromPeer(peer PeerGetter, key string) (ByteView, error) {
req := &pb.Request{
Group: g.name,
Key: key,
}
res := &pb.Response{}
err := peer.Get(req, res)
if err != nil {
return ByteView{}, err
}
return ByteView{b: res.Value}, nil
}
================================================
FILE: gee-cache/day7-proto-buf/geecache/geecache_test.go
================================================
package geecache
import (
"fmt"
"log"
"reflect"
"testing"
)
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
func TestGetter(t *testing.T) {
var f Getter = GetterFunc(func(key string) ([]byte, error) {
return []byte(key), nil
})
expect := []byte("key")
if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) {
t.Fatal("callback failed")
}
}
func TestGet(t *testing.T) {
loadCounts := make(map[string]int, len(db))
gee := NewGroup("scores", 2<<10, GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
if _, ok := loadCounts[key]; !ok {
loadCounts[key] = 0
}
loadCounts[key]++
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
for k, v := range db {
if view, err := gee.Get(k); err != nil || view.String() != v {
t.Fatalf("failed to get value of %s", k)
}
if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 {
t.Fatalf("cache %s miss", k)
}
}
if view, err := gee.Get("unknown"); err == nil {
t.Fatalf("the value of unknown should be empty, but got %s", view)
}
}
func TestGetGroup(t *testing.T) {
groupName := "scores"
NewGroup(groupName, 2<<10, GetterFunc(
func(key string) (bytes []byte, err error) { return }))
if group := GetGroup(groupName); group == nil || group.name != groupName {
t.Fatalf("group %s not exist", groupName)
}
if group := GetGroup(groupName + "111"); group != nil {
t.Fatalf("expect nil, but %s got", group.name)
}
}
================================================
FILE: gee-cache/day7-proto-buf/geecache/geecachepb/geecachepb.pb.go
================================================
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: geecachepb.proto
package geecachepb
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type Request struct {
Group string `protobuf:"bytes,1,opt,name=group,proto3" json:"group,omitempty"`
Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Request) Reset() { *m = Request{} }
func (m *Request) String() string { return proto.CompactTextString(m) }
func (*Request) ProtoMessage() {}
func (*Request) Descriptor() ([]byte, []int) {
return fileDescriptor_889d0a4ad37a0d42, []int{0}
}
func (m *Request) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Request.Unmarshal(m, b)
}
func (m *Request) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Request.Marshal(b, m, deterministic)
}
func (m *Request) XXX_Merge(src proto.Message) {
xxx_messageInfo_Request.Merge(m, src)
}
func (m *Request) XXX_Size() int {
return xxx_messageInfo_Request.Size(m)
}
func (m *Request) XXX_DiscardUnknown() {
xxx_messageInfo_Request.DiscardUnknown(m)
}
var xxx_messageInfo_Request proto.InternalMessageInfo
func (m *Request) GetGroup() string {
if m != nil {
return m.Group
}
return ""
}
func (m *Request) GetKey() string {
if m != nil {
return m.Key
}
return ""
}
type Response struct {
Value []byte `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Response) Reset() { *m = Response{} }
func (m *Response) String() string { return proto.CompactTextString(m) }
func (*Response) ProtoMessage() {}
func (*Response) Descriptor() ([]byte, []int) {
return fileDescriptor_889d0a4ad37a0d42, []int{1}
}
func (m *Response) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Response.Unmarshal(m, b)
}
func (m *Response) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Response.Marshal(b, m, deterministic)
}
func (m *Response) XXX_Merge(src proto.Message) {
xxx_messageInfo_Response.Merge(m, src)
}
func (m *Response) XXX_Size() int {
return xxx_messageInfo_Response.Size(m)
}
func (m *Response) XXX_DiscardUnknown() {
xxx_messageInfo_Response.DiscardUnknown(m)
}
var xxx_messageInfo_Response proto.InternalMessageInfo
func (m *Response) GetValue() []byte {
if m != nil {
return m.Value
}
return nil
}
func init() {
proto.RegisterType((*Request)(nil), "geecachepb.Request")
proto.RegisterType((*Response)(nil), "geecachepb.Response")
}
func init() { proto.RegisterFile("geecachepb.proto", fileDescriptor_889d0a4ad37a0d42) }
var fileDescriptor_889d0a4ad37a0d42 = []byte{
// 148 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x48, 0x4f, 0x4d, 0x4d,
0x4e, 0x4c, 0xce, 0x48, 0x2d, 0x48, 0xd2, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x42, 0x88,
0x28, 0x19, 0x72, 0xb1, 0x07, 0xa5, 0x16, 0x96, 0xa6, 0x16, 0x97, 0x08, 0x89, 0x70, 0xb1, 0xa6,
0x17, 0xe5, 0x97, 0x16, 0x48, 0x30, 0x2a, 0x30, 0x6a, 0x70, 0x06, 0x41, 0x38, 0x42, 0x02, 0x5c,
0xcc, 0xd9, 0xa9, 0x95, 0x12, 0x4c, 0x60, 0x31, 0x10, 0x53, 0x49, 0x81, 0x8b, 0x23, 0x28, 0xb5,
0xb8, 0x20, 0x3f, 0xaf, 0x38, 0x15, 0xa4, 0xa7, 0x2c, 0x31, 0xa7, 0x34, 0x15, 0xac, 0x87, 0x27,
0x08, 0xc2, 0x31, 0xb2, 0xe3, 0xe2, 0x72, 0x07, 0x69, 0x76, 0x06, 0x59, 0x22, 0x64, 0xc0, 0xc5,
0xec, 0x9e, 0x5a, 0x22, 0x24, 0xac, 0x87, 0xe4, 0x10, 0xa8, 0x9d, 0x52, 0x22, 0xa8, 0x82, 0x10,
0x53, 0x93, 0xd8, 0xc0, 0xee, 0x34, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x5c, 0xd5, 0xdd, 0x09,
0xbb, 0x00, 0x00, 0x00,
}
================================================
FILE: gee-cache/day7-proto-buf/geecache/geecachepb/geecachepb.proto
================================================
syntax = "proto3";
package geecachepb;
message Request {
string group = 1;
string key = 2;
}
message Response {
bytes value = 1;
}
service GroupCache {
rpc Get(Request) returns (Response);
}
================================================
FILE: gee-cache/day7-proto-buf/geecache/go.mod
================================================
module geecache
go 1.13
require github.com/golang/protobuf v1.3.3
================================================
FILE: gee-cache/day7-proto-buf/geecache/http.go
================================================
package geecache
import (
"fmt"
"geecache/consistenthash"
pb "geecache/geecachepb"
"io/ioutil"
"log"
"net/http"
"net/url"
"strings"
"sync"
"github.com/golang/protobuf/proto"
)
const (
defaultBasePath = "/_geecache/"
defaultReplicas = 50
)
// HTTPPool implements PeerPicker for a pool of HTTP peers.
type HTTPPool struct {
// this peer's base URL, e.g. "https://example.net:8000"
self string
basePath string
mu sync.Mutex // guards peers and httpGetters
peers *consistenthash.Map
httpGetters map[string]*httpGetter // keyed by e.g. "http://10.0.0.2:8008"
}
// NewHTTPPool initializes an HTTP pool of peers.
func NewHTTPPool(self string) *HTTPPool {
return &HTTPPool{
self: self,
basePath: defaultBasePath,
}
}
// Log info with server name
func (p *HTTPPool) Log(format string, v ...interface{}) {
log.Printf("[Server %s] %s", p.self, fmt.Sprintf(format, v...))
}
// ServeHTTP handle all http requests
func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, p.basePath) {
panic("HTTPPool serving unexpected path: " + r.URL.Path)
}
p.Log("%s %s", r.Method, r.URL.Path)
// /<basepath>/<groupname>/<key> required
parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2)
if len(parts) != 2 {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
groupName := parts[0]
key := parts[1]
group := GetGroup(groupName)
if group == nil {
http.Error(w, "no such group: "+groupName, http.StatusNotFound)
return
}
view, err := group.Get(key)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Write the value to the response body as a proto message.
body, err := proto.Marshal(&pb.Response{Value: view.ByteSlice()})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
w.Write(body)
}
// Set updates the pool's list of peers.
func (p *HTTPPool) Set(peers ...string) {
p.mu.Lock()
defer p.mu.Unlock()
p.peers = consistenthash.New(defaultReplicas, nil)
p.peers.Add(peers...)
p.httpGetters = make(map[string]*httpGetter, len(peers))
for _, peer := range peers {
p.httpGetters[peer] = &httpGetter{baseURL: peer + p.basePath}
}
}
// PickPeer picks a peer according to key
func (p *HTTPPool) PickPeer(key string) (PeerGetter, bool) {
p.mu.Lock()
defer p.mu.Unlock()
if peer := p.peers.Get(key); peer != "" && peer != p.self {
p.Log("Pick peer %s", peer)
return p.httpGetters[peer], true
}
return nil, false
}
var _ PeerPicker = (*HTTPPool)(nil)
type httpGetter struct {
baseURL string
}
func (h *httpGetter) Get(in *pb.Request, out *pb.Response) error {
u := fmt.Sprintf(
"%v%v/%v",
h.baseURL,
url.QueryEscape(in.GetGroup()),
url.QueryEscape(in.GetKey()),
)
res, err := http.Get(u)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return fmt.Errorf("server returned: %v", res.Status)
}
bytes, err := ioutil.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("reading response body: %v", err)
}
if err = proto.Unmarshal(bytes, out); err != nil {
return fmt.Errorf("decoding response body: %v", err)
}
return nil
}
var _ PeerGetter = (*httpGetter)(nil)
================================================
FILE: gee-cache/day7-proto-buf/geecache/lru/lru.go
================================================
package lru
import "container/list"
// Cache is a LRU cache. It is not safe for concurrent access.
type Cache struct {
maxBytes int64
nbytes int64
ll *list.List
cache map[string]*list.Element
// optional and executed when an entry is purged.
OnEvicted func(key string, value Value)
}
type entry struct {
key string
value Value
}
// Value uses Len to report how many bytes it takes
type Value interface {
Len() int
}
// New is the Constructor of Cache
func New(maxBytes int64, onEvicted func(string, Value)) *Cache {
return &Cache{
maxBytes: maxBytes,
ll: list.New(),
cache: make(map[string]*list.Element),
OnEvicted: onEvicted,
}
}
// Add adds a value to the cache.
func (c *Cache) Add(key string, value Value) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
c.nbytes += int64(value.Len()) - int64(kv.value.Len())
kv.value = value
} else {
ele := c.ll.PushFront(&entry{key, value})
c.cache[key] = ele
c.nbytes += int64(len(key)) + int64(value.Len())
}
for c.maxBytes != 0 && c.maxBytes < c.nbytes {
c.RemoveOldest()
}
}
// Get looks up a key's value
func (c *Cache) Get(key string) (value Value, ok bool) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
return kv.value, true
}
return
}
// RemoveOldest removes the oldest item
func (c *Cache) RemoveOldest() {
ele := c.ll.Back()
if ele != nil {
c.ll.Remove(ele)
kv := ele.Value.(*entry)
delete(c.cache, kv.key)
c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len())
if c.OnEvicted != nil {
c.OnEvicted(kv.key, kv.value)
}
}
}
// Len returns the number of cache entries
func (c *Cache) Len() int {
return c.ll.Len()
}
================================================
FILE: gee-cache/day7-proto-buf/geecache/lru/lru_test.go
================================================
package lru
import (
"reflect"
"testing"
)
type String string
func (d String) Len() int {
return len(d)
}
func TestGet(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key1", String("1234"))
if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" {
t.Fatalf("cache hit key1=1234 failed")
}
if _, ok := lru.Get("key2"); ok {
t.Fatalf("cache miss key2 failed")
}
}
func TestRemoveoldest(t *testing.T) {
k1, k2, k3 := "key1", "key2", "k3"
v1, v2, v3 := "value1", "value2", "v3"
cap := len(k1 + k2 + v1 + v2)
lru := New(int64(cap), nil)
lru.Add(k1, String(v1))
lru.Add(k2, String(v2))
lru.Add(k3, String(v3))
if _, ok := lru.Get("key1"); ok || lru.Len() != 2 {
t.Fatalf("Removeoldest key1 failed")
}
}
func TestOnEvicted(t *testing.T) {
keys := make([]string, 0)
callback := func(key string, value Value) {
keys = append(keys, key)
}
lru := New(int64(10), callback)
lru.Add("key1", String("123456"))
lru.Add("k2", String("k2"))
lru.Add("k3", String("k3"))
lru.Add("k4", String("k4"))
expect := []string{"key1", "k2"}
if !reflect.DeepEqual(expect, keys) {
t.Fatalf("Call OnEvicted failed, expect keys equals to %s", expect)
}
}
func TestAdd(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key", String("1"))
lru.Add("key", String("111"))
if lru.nbytes != int64(len("key")+len("111")) {
t.Fatal("expected 6 but got", lru.nbytes)
}
}
================================================
FILE: gee-cache/day7-proto-buf/geecache/peers.go
================================================
package geecache
import pb "geecache/geecachepb"
// PeerPicker is the interface that must be implemented to locate
// the peer that owns a specific key.
type PeerPicker interface {
PickPeer(key string) (peer PeerGetter, ok bool)
}
// PeerGetter is the interface that must be implemented by a peer.
type PeerGetter interface {
Get(in *pb.Request, out *pb.Response) error
}
================================================
FILE: gee-cache/day7-proto-buf/geecache/singleflight/singleflight.go
================================================
package singleflight
import "sync"
// call is an in-flight or completed Do call
type call struct {
wg sync.WaitGroup
val interface{}
err error
}
// Group represents a class of work and forms a namespace in which
// units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[string]*call // lazily initialized
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
g.mu.Unlock()
c.wg.Wait()
return c.val, c.err
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
c.val, c.err = fn()
c.wg.Done()
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
return c.val, c.err
}
================================================
FILE: gee-cache/day7-proto-buf/geecache/singleflight/singleflight_test.go
================================================
package singleflight
import (
"testing"
)
func TestDo(t *testing.T) {
var g Group
v, err := g.Do("key", func() (interface{}, error) {
return "bar", nil
})
if v != "bar" || err != nil {
t.Errorf("Do v = %v, error = %v", v, err)
}
}
================================================
FILE: gee-cache/day7-proto-buf/go.mod
================================================
module example
go 1.13
require geecache v0.0.0
replace geecache => ./geecache
================================================
FILE: gee-cache/day7-proto-buf/main.go
================================================
package main
/*
$ curl "http://localhost:9999/api?key=Tom"
630
$ curl "http://localhost:9999/api?key=kkk"
kkk not exist
*/
import (
"flag"
"fmt"
"geecache"
"log"
"net/http"
)
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
func createGroup() *geecache.Group {
return geecache.NewGroup("scores", 2<<10, geecache.GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
}
func startCacheServer(addr string, addrs []string, gee *geecache.Group) {
peers := geecache.NewHTTPPool(addr)
peers.Set(addrs...)
gee.RegisterPeers(peers)
log.Println("geecache is running at", addr)
log.Fatal(http.ListenAndServe(addr[7:], peers))
}
func startAPIServer(apiAddr string, gee *geecache.Group) {
http.Handle("/api", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
key := r.URL.Query().Get("key")
view, err := gee.Get(key)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
w.Write(view.ByteSlice())
}))
log.Println("frontend server is running at", apiAddr)
log.Fatal(http.ListenAndServe(apiAddr[7:], nil))
}
func main() {
var port int
var api bool
flag.IntVar(&port, "port", 8001, "Geecache server port")
flag.BoolVar(&api, "api", false, "Start an API server?")
flag.Parse()
apiAddr := "http://localhost:9999"
addrMap := map[int]string{
8001: "http://localhost:8001",
8002: "http://localhost:8002",
8003: "http://localhost:8003",
}
var addrs []string
for _, v := range addrMap {
addrs = append(addrs, v)
}
gee := createGroup()
if api {
go startAPIServer(apiAddr, gee)
}
startCacheServer(addrMap[port], addrs, gee)
}
================================================
FILE: gee-cache/day7-proto-buf/run.sh
================================================
#!/bin/bash
trap "rm server;kill 0" EXIT
go build -o server
./server -port=8001 &
./server -port=8002 &
./server -port=8003 -api=1 &
sleep 2
echo ">>> start test"
curl "http://localhost:9999/api?key=Tom" &
curl "http://localhost:9999/api?key=Tom" &
curl "http://localhost:9999/api?key=Tom" &
wait
================================================
FILE: gee-cache/doc/geecache-day1.md
================================================
---
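title: Build a Distributed Cache by Hand - GeeCache Day 1 LRU Cache Eviction Strategy
date: 2020-02-11 22:00:00
description: A 7-day tutorial on implementing the distributed cache GeeCache from scratch in Go (7 days implement golang distributed cache from scratch tutorial), building a distributed cache by hand with groupcache as the reference. This article introduces three common cache eviction algorithms, first in first out (FIFO), least frequently used (LFU) and least recently used (LRU), and implements the LRU algorithm together with its test code.
tags:
- Go
nav: From Scratch
categories:
- Distributed Cache - GeeCache
keywords:
- Go
- from scratch
- distributed cache
- LRU
- cache eviction
image: post/geecache-day1/lru_logo.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days of Go from Scratch series
book_title: Day1 LRU Cache Eviction Strategy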
---
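This is the first article in the series [Implement the Distributed Cache GeeCache from Scratch in Go in 7 Days](https://geektutu.com/post/geecache.html).
- Introduce three common cache eviction algorithms: FIFO, LFU and LRU
- Implement the LRU cache eviction algorithm in **about 80 lines of code**
## 1 Overview of the FIFO/LFU/LRU Algorithms
GeeCache keeps all cached data in memory. Memory is limited, so data cannot be added without bound. Suppose the cache may use at most N bytes of memory. If adding a record at some point pushes the usage above N, one or more records must be removed from the cache. Which ones? Ideally the "useless" ones, so how do we decide whether a record is "useful" or "useless"?
### 1.1 FIFO (First In First Out)
First in, first out: evict the oldest (earliest added) record in the cache. FIFO assumes that the earliest added record is less likely to be used again than a newly added one. The implementation is very simple: create a queue, append new records to the tail, and evict from the head whenever memory runs out. In many scenarios, however, some records are both the earliest added and the most frequently accessed, yet they are evicted simply because they have stayed too long. Such records are repeatedly added to the cache and evicted again, which lowers the cache hit rate.
### 1.2 LFU (Least Frequently Used)
Least frequently used: evict the record with the lowest access frequency. LFU assumes that data accessed many times in the past will also be accessed more often in the future. An LFU implementation maintains a queue ordered by access count. On every access the count is incremented and the queue is reordered, and eviction picks the record with the smallest count. LFU achieves a fairly high hit rate, but its drawbacks are just as clear. Keeping an access count for every record costs a lot of memory. In addition, when the access pattern changes, LFU needs a long time to adapt, which means it is strongly influenced by historical data. For example, a record that used to be extremely popular but is hardly accessed after a certain point still cannot be evicted for a long time, because its historical count is so high.
### 1.3 LRU (Least Recently Used)
Least recently used: compared with FIFO, which only considers time, and LFU, which only considers access frequency, LRU can be seen as a relatively balanced eviction algorithm. LRU assumes that data accessed recently is more likely to be accessed again in the future. The implementation is very simple: maintain a queue and move a record to the tail whenever it is accessed. The head of the queue is then the least recently used record, and that is the one to evict.
## 2 Implementing the LRU Algorithm
### 2.1 Core Data Structures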

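This diagram shows the two core data structures of the LRU algorithm:
- The green part is a dictionary (map) that stores the mapping from keys to values. Looking up the value for a key is `O(1)`, and inserting a record into the dictionary is also `O(1)`.
- The red part is a queue implemented with a doubly linked list. All values are stored in the doubly linked list, so moving a value to the tail of the queue when it is accessed is `O(1)`, and adding a record at the tail or deleting a record is `O(1)` as well.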
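Next we create a struct type `Cache` that contains the dictionary and the doubly linked list, which makes the later add, delete, look up and update operations easy to implement.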
[day1-lru/geecache/lru/lru.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day1-lru/geecache/lru)
```go
package lru
import "container/list"
// Cache is a LRU cache. It is not safe for concurrent access.
type Cache struct {
maxBytes int64
nbytes int64
ll *list.List
cache map[string]*list.Element
// optional and executed when an entry is purged.
OnEvicted func(key string, value Value)
}
type entry struct {
key string
value Value
}
// Value uses Len to report how many bytes it takes
type Value interface {
Len() int
}
```
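- Here we directly use `list.List`, the doubly linked list implemented by the Go standard library.
- The dictionary is defined as `map[string]*list.Element`: the key is a string and the value is a pointer to the corresponding node in the doubly linked list.
- `maxBytes` is the maximum memory the cache may use, `nbytes` is the memory currently in use, and `OnEvicted` is a callback invoked when a record is removed. It may be nil.
- The key-value pair `entry` is the data type stored in each list node. The benefit of keeping the key in the list as well is that, when the head node is evicted, the key is needed to delete the corresponding mapping from the dictionary.
- For generality, a value may be any type that implements the `Value` interface. The interface has a single method, `Len() int`, which returns the amount of memory the value occupies.
To make `Cache` easy to instantiate, implement the `New()` function: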
```go
// New is the Constructor of Cache
func New(maxBytes int64, onEvicted func(string, Value)) *Cache {
return &Cache{
maxBytes: maxBytes,
ll: list.New(),
cache: make(map[string]*list.Element),
OnEvicted: onEvicted,
}
}
```
### 2.2 Lookup
A lookup has two steps: first find the corresponding list node through the dictionary, then move that node to the tail of the queue.
```go
// Get looks up a key's value
func (c *Cache) Get(key string) (value Value, ok bool) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
return kv.value, true
}
return
}
```
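- If the list node for the key exists, move it to the tail of the queue and return the value found.
- `c.ll.MoveToFront(ele)` moves the node `ele` to the tail of the queue. (When a doubly linked list is used as a queue, head and tail are relative. Here the front of the list is taken to be the tail of the queue.)
### 2.3 Removal
Removal here is really cache eviction, that is, removing the least recently used node (the head of the queue).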
```go
// RemoveOldest removes the oldest item
func (c *Cache) RemoveOldest() {
ele := c.ll.Back()
if ele != nil {
c.ll.Remove(ele)
kv := ele.Value.(*entry)
delete(c.cache, kv.key)
c.nbytes -= int64(len(kv.key)) + int64(kv.value.Len())
if c.OnEvicted != nil {
c.OnEvicted(kv.key, kv.value)
}
}
}
```
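- `c.ll.Back()` fetches the head-of-queue node, which is then removed from the list.
- `delete(c.cache, kv.key)` deletes the node's mapping from the dictionary `c.cache`.
- Update the memory currently in use, `c.nbytes`.
- If the callback `OnEvicted` is not nil, call it.
### 2.4 Add/Update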
```go
// Add adds a value to the cache.
func (c *Cache) Add(key string, value Value) {
if ele, ok := c.cache[key]; ok {
c.ll.MoveToFront(ele)
kv := ele.Value.(*entry)
c.nbytes += int64(value.Len()) - int64(kv.value.Len())
kv.value = value
} else {
ele := c.ll.PushFront(&entry{key, value})
c.cache[key] = ele
c.nbytes += int64(len(key)) + int64(value.Len())
}
for c.maxBytes != 0 && c.maxBytes < c.nbytes {
c.RemoveOldest()
}
}
```
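- If the key exists, update the value of the corresponding node and move the node to the tail of the queue.
- Otherwise this is an insertion: first add the new node `&entry{key, value}` at the tail of the queue, then add the mapping from the key to the node in the dictionary.
- Update `c.nbytes`. While it exceeds the configured maximum `c.maxBytes`, remove the least recently used node (a `maxBytes` of 0 means no limit).
Finally, to make testing easier, we implement `Len()` to report how many entries have been added.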
```go
// Len returns the number of cache entries
func (c *Cache) Len() int {
return c.ll.Len()
}
```
## 3 Tests
For example, we can add a few entries and test the `Get` method:
[day1-lru/geecache/lru/lru_test.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day1-lru/geecache/lru)
```go
type String string
func (d String) Len() int {
return len(d)
}
func TestGet(t *testing.T) {
lru := New(int64(0), nil)
lru.Add("key1", String("1234"))
if v, ok := lru.Get("key1"); !ok || string(v.(String)) != "1234" {
t.Fatalf("cache hit key1=1234 failed")
}
if _, ok := lru.Get("key2"); ok {
t.Fatalf("cache miss key2 failed")
}
}
```
Test whether exceeding the configured memory limit triggers the removal of "useless" nodes:
```go
func TestRemoveoldest(t *testing.T) {
k1, k2, k3 := "key1", "key2", "k3"
v1, v2, v3 := "value1", "value2", "v3"
cap := len(k1 + k2 + v1 + v2)
lru := New(int64(cap), nil)
lru.Add(k1, String(v1))
lru.Add(k2, String(v2))
lru.Add(k3, String(v3))
if _, ok := lru.Get("key1"); ok || lru.Len() != 2 {
t.Fatalf("Removeoldest key1 failed")
}
}
```
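The capacity in `TestRemoveoldest` is chosen so that exactly the first two entries fit. The arithmetic can be sketched as follows, assuming each entry is charged `len(key) + value.Len()` as in `Add`. The helper `cost` is hypothetical and only mirrors that accounting.

```go
package main

import "fmt"

// cost is a hypothetical helper that mirrors how lru.Cache charges an
// entry: the length of the key plus the length of the value.
func cost(key, value string) int {
	return len(key) + len(value)
}

func main() {
	// Capacity used by the test: room for key1/value1 and key2/value2.
	capacity := cost("key1", "value1") + cost("key2", "value2")
	// Adding k3/v3 pushes the usage over the capacity.
	used := capacity + cost("k3", "v3")
	fmt.Println(capacity, used) // 20 24
	// Evicting the oldest entry, key1/value1, brings it back under.
	used -= cost("key1", "value1")
	fmt.Println(used) // 14
}
```

Because 24 exceeds 20, one eviction is needed, and after removing `key1` the remaining 14 bytes fit, which is why the test expects `Len()` to be 2.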
Test whether the callback is invoked:
```go
func TestOnEvicted(t *testing.T) {
keys := make([]string, 0)
callback := func(key string, value Value) {
keys = append(keys, key)
}
lru := New(int64(10), callback)
lru.Add("key1", String("123456"))
lru.Add("k2", String("k2"))
lru.Add("k3", String("k3"))
lru.Add("k4", String("k4"))
expect := []string{"key1", "k2"}
if !reflect.DeepEqual(expect, keys) {
t.Fatalf("Call OnEvicted failed, expect keys equals to %s", expect)
}
}
```
## Appendix: Recommended Reading
- [A Concise Go Tutorial](https://geektutu.com/post/quick-golang.html)
- [A Concise Guide to Go Test Unit Testing](https://geektutu.com/post/quick-go-test.html)
- [list official documentation - golang.org](https://golang.org/pkg/container/list/)
================================================
FILE: gee-cache/doc/geecache-day2.md
================================================
---
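title: Build a Distributed Cache by Hand - GeeCache Day 2 Single-Node Concurrent Cache
date: 2020-02-12 22:00:00
description: A 7-day tutorial on implementing the distributed cache GeeCache from scratch in Go (7 days implement golang distributed cache from scratch tutorial), building a distributed cache by hand with groupcache as the reference. This article introduces the sync.Mutex mutual exclusion lock and uses it to add concurrency control to the LRU cache. It also implements Group, the core data structure of GeeCache, which calls a callback to fetch the source data when a key is not cached.
tags:
- Go
nav: From Scratch
categories:
- Distributed Cache - GeeCache
keywords:
- Go
- from scratch
- distributed cache
- mutex
- sync.Mutex
image: post/geecache-day2/concurrent_cache_logo.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days of Go from Scratch series
book_title: Day2 Single-Node Concurrent Cache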
---

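This is the second article in [Implement the Distributed Cache GeeCache from Scratch in Go in 7 Days](https://geektutu.com/post/geecache.html).
- Introduce the `sync.Mutex` mutual exclusion lock and use it to add concurrency control to the LRU cache.
- Implement Group, the core data structure of GeeCache, which calls a callback to fetch the source data when a key is not cached, in **about 150 lines of code**
## 1 sync.Mutex
When multiple goroutines read and write the same variable at the same time, conflicts occur once concurrency is high. Ensuring that only one goroutine at a time can access the variable avoids the conflict. This is called `mutual exclusion`, and a mutex solves the problem.
> sync.Mutex is a mutual exclusion lock that can be locked and unlocked by different goroutines.
`sync.Mutex` is a mutex provided by the Go standard library. Once a goroutine holds the lock, other goroutines requesting it block in their call to `Lock()` until `Unlock()` is called and the lock is released.
Here is a simple example. Suppose 10 concurrent goroutines print the same number `100`. To avoid printing duplicates, we implement a function `printOnce(num int)` that uses a set to record the numbers already printed. If a number has been printed, it is not printed again.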
```go
var set = make(map[int]bool, 0)
func printOnce(num int) {
if _, exist := set[num]; !exist {
fmt.Println(num)
}
set[num] = true
}
func main() {
for i := 0; i < 10; i++ {
go printOnce(100)
}
time.Sleep(time.Second)
}
```
What happens when we run `go run .`?
```bash
$ go run .
100
100
```
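Sometimes it prints 2 times, sometimes 4 times, and sometimes the program even crashes, because accesses to the same data structure `set` conflict (the Go runtime reports unsynchronized map writes as `fatal error: concurrent map writes`). Next, wrap the conflicting part with the mutex's `Lock()` and `Unlock()` methods: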
```go
var m sync.Mutex
var set = make(map[int]bool, 0)
func printOnce(num int) {
m.Lock()
if _, exist := set[num]; !exist {
fmt.Println(num)
}
set[num] = true
m.Unlock()
}
func main() {
for i := 0; i < 10; i++ {
go printOnce(100)
}
time.Sleep(time.Second)
}
```
```bash
$ go run .
100
```
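The same number is now printed only once. When one goroutine calls `Lock()`, the other goroutines block until a call to `Unlock()` releases the lock. The wrapped code is therefore free of conflicts and mutually exclusive.
There is another way to write the `Unlock()` call that releases the lock: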
```go
func printOnce(num int) {
m.Lock()
defer m.Unlock()
if _, exist := set[num]; !exist {
fmt.Println(num)
}
set[num] = true
}
```
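The examples in this section wait with `time.Sleep`, which works only because the goroutines finish well within a second. A hedged variation, assuming we want to wait exactly until all goroutines have returned, uses `sync.WaitGroup` from the standard library. The counter `printed` and the helper `run` are hypothetical additions for this sketch.

```go
package main

import (
	"fmt"
	"sync"
)

var (
	m       sync.Mutex
	set     = make(map[int]bool)
	printed int // hypothetical counter of how many times we printed
)

// printOnce prints num the first time it is seen. The mutex guards
// both the set and the counter.
func printOnce(num int) {
	m.Lock()
	defer m.Unlock()
	if !set[num] {
		fmt.Println(num)
		printed++
	}
	set[num] = true
}

// run starts n goroutines and blocks until every one has finished.
func run(n int) {
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			printOnce(100)
		}()
	}
	wg.Wait()
}

func main() {
	run(10)
	fmt.Println("printed", printed, "time(s)") // printed 1 time(s)
}
```

Reading `printed` after `wg.Wait()` is safe because every goroutine has already returned by then.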
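## 2 Supporting Concurrent Reads and Writes
The previous article, [GeeCache Day 1](https://geektutu.com/post/geecache-day1.html), implemented the LRU cache eviction strategy. Next we wrap several LRU methods with `sync.Mutex` so that they support concurrent reads and writes. Before that, we define a read-only data structure, `ByteView`, to represent cached values. It is one of the main data structures in GeeCache.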
[day2-single-node/geecache/byteview.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day2-single-node/geecache)
```go
package geecache
// A ByteView holds an immutable view of bytes.
type ByteView struct {
b []byte
}
// Len returns the view's length
func (v ByteView) Len() int {
return len(v.b)
}
// ByteSlice returns a copy of the data as a byte slice.
func (v ByteView) ByteSlice() []byte {
return cloneBytes(v.b)
}
// String returns the data as a string, making a copy if necessary.
func (v ByteView) String() string {
return string(v.b)
}
func cloneBytes(b []byte) []byte {
c := make([]byte, len(b))
copy(c, b)
return c
}
```
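- ByteView has a single field, `b []byte`, which stores the actual cached value. A byte slice is chosen so that any kind of data can be stored, such as strings or images.
- It implements the `Len() int` method. In the implementation of lru.Cache, a cached object must implement the Value interface, that is, the `Len() int` method, which returns the memory it occupies.
- `b` is read-only. The `ByteSlice()` method returns a copy to prevent the cached value from being modified by outside code.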
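The effect of returning a copy can be checked with a short sketch. `ByteView` and `cloneBytes` are repeated here (without `Len`) only so that the example compiles on its own. The values in `main` are made up for illustration.

```go
package main

import "fmt"

// ByteView and cloneBytes are copied from byteview.go so that this
// sketch is self-contained.
type ByteView struct {
	b []byte
}

// ByteSlice returns a copy of the data as a byte slice.
func (v ByteView) ByteSlice() []byte {
	return cloneBytes(v.b)
}

// String returns the data as a string.
func (v ByteView) String() string {
	return string(v.b)
}

func cloneBytes(b []byte) []byte {
	c := make([]byte, len(b))
	copy(c, b)
	return c
}

func main() {
	v := ByteView{b: []byte("630")}
	out := v.ByteSlice()
	out[0] = '9' // mutate the caller's copy only

	fmt.Println(string(out)) // 930
	fmt.Println(v.String())  // 630, the cached bytes are unchanged
}
```

Had `ByteSlice()` returned `v.b` directly, the write to `out[0]` would have changed the cached value as well.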
Now we can add concurrency support to lru.Cache.
[day2-single-node/geecache/cache.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day2-single-node/geecache)
```go
package geecache
import (
"geecache/lru"
"sync"
)
type cache struct {
mu sync.Mutex
lru *lru.Cache
cacheBytes int64
}
func (c *cache) add(key string, value ByteView) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
c.lru = lru.New(c.cacheBytes, nil)
}
c.lru.Add(key, value)
}
func (c *cache) get(key string) (value ByteView, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
return
}
if v, ok := c.lru.Get(key); ok {
return v.(ByteView), ok
}
return
}
```
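- The implementation in `cache.go` is very simple: it instantiates the lru cache, wraps the get and add methods, and adds the mutex mu.
- The `add` method checks whether `c.lru` is nil and creates the instance only if it is. This is called lazy initialization: the creation of an object is deferred until the object is first used. It is mainly used to improve performance and reduce the program's memory requirements.
## 3 The Main Structure Group
Group is the most central data structure in GeeCache. It handles interaction with the user and controls the flow of storing and fetching cached values.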
```bash
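                                yes
receive key --> is it cached? -----> return cached value ⑴
                  |  no                              yes
                  |-----> fetch from a remote node? -----> talk to the remote node --> return cached value ⑵
                            |  no
                            |-----> call the `callback`, get the value and add it to the cache --> return cached value ⑶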
```
We will implement the main structure Group in `geecache.go`, so the initial code layout of GeeCache has now taken shape.
```bash
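geecache/
    |--lru/
        |--lru.go  // LRU cache eviction strategy
    |--byteview.go // abstraction and encapsulation of cached values
    |--cache.go    // concurrency control
    |--geecache.go // interacts with callers and controls the main flow of storing and fetching cached values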
```
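Next we implement flows ⑴ and ⑶. The remote interaction part comes later.
### 3.1 The Getter Callback
Consider this: when a key is not cached, the data should be fetched from the data source (a file, a database and so on) and added to the cache. Should GeeCache support configuring multiple kinds of data sources? It should not. First, there are too many kinds of data sources to implement one by one. Second, it would not extend well. How to fetch data from the source is for the user to decide, so we leave it to the user. We therefore design a callback that is called when a key is not cached to obtain the source data.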
[day2-single-node/geecache/geecache.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day2-single-node/geecache)
```go
// A Getter loads data for a key.
type Getter interface {
Get(key string) ([]byte, error)
}
// A GetterFunc implements Getter with a function.
type GetterFunc func(key string) ([]byte, error)
// Get implements Getter interface function
func (f GetterFunc) Get(key string) ([]byte, error) {
return f(key)
}
```
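- Define the interface Getter with the callback method `Get(key string) ([]byte, error)`. The parameter is the key and the return values are a []byte and an error.
- Define the function type GetterFunc and implement the Getter interface's `Get` method on it.
- A function type that implements an interface is called an interface-type function. It lets callers pass either a function or a struct that implements the interface as the argument.
> To learn about the use cases of interface-type functions, see [Use Cases of Interface-Type Functions in Go - 7days-golang Q & A](https://geektutu.com/post/7days-golang-q1.html)
We can write a test case to make sure the callback works.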
```go
func TestGetter(t *testing.T) {
var f Getter = GetterFunc(func(key string) ([]byte, error) {
return []byte(key), nil
})
expect := []byte("key")
if v, _ := f.Get("key"); !reflect.DeepEqual(v, expect) {
t.Errorf("callback failed")
}
}
```
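- In this test case, the GetterFunc type conversion turns an anonymous callback function into the interface value `f Getter`.
- Calling the interface method `f.Get(key string)` actually calls the anonymous callback function.
> Define a function type F, implement the method of interface A on it, and have that method call the function itself. This is a common Go technique for converting other functions (with the same parameter and return types as F) into interface A.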
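To illustrate why this is convenient, the sketch below passes both a plain function and a struct-like type to the same caller. `Getter` and `GetterFunc` are repeated from the tutorial so the example compiles alone, while `mapGetter` and `load` are hypothetical names introduced only for this illustration. The standard library uses the same pattern with `http.Handler` and `http.HandlerFunc`.

```go
package main

import "fmt"

// Getter and GetterFunc are repeated from geecache.go.
type Getter interface {
	Get(key string) ([]byte, error)
}

type GetterFunc func(key string) ([]byte, error)

func (f GetterFunc) Get(key string) ([]byte, error) {
	return f(key)
}

// mapGetter is a hypothetical data source backed by a map. It
// implements Getter directly, without going through GetterFunc.
type mapGetter map[string]string

func (m mapGetter) Get(key string) ([]byte, error) {
	if v, ok := m[key]; ok {
		return []byte(v), nil
	}
	return nil, fmt.Errorf("%s not exist", key)
}

// load is a hypothetical caller that accepts any Getter.
func load(g Getter, key string) string {
	b, err := g.Get(key)
	if err != nil {
		return "error: " + err.Error()
	}
	return string(b)
}

func main() {
	fromFunc := GetterFunc(func(key string) ([]byte, error) {
		return []byte("func:" + key), nil
	})
	fromStruct := mapGetter{"Tom": "630"}

	fmt.Println(load(fromFunc, "Tom"))   // func:Tom
	fmt.Println(load(fromStruct, "Tom")) // 630
	fmt.Println(load(fromStruct, "kkk")) // error: kkk not exist
}
```

A simple source can be supplied as a one-off function, while a source that needs state or several methods can be a full type, and the caller does not need to know which one it received.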
### 3.2 Definition of Group
Next comes the definition of the core data structure Group:
[day2-single-node/geecache/geecache.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day2-single-node/geecache)
```go
// A Group is a cache namespace and associated data loaded spread over
type Group struct {
name string
getter Getter
mainCache cache
}
var (
mu sync.RWMutex
groups = make(map[string]*Group)
)
// NewGroup create a new instance of Group
func NewGroup(name string, cacheBytes int64, getter Getter) *Group {
if getter == nil {
panic("nil Getter")
}
mu.Lock()
defer mu.Unlock()
g := &Group{
name: name,
getter: getter,
mainCache: cache{cacheBytes: cacheBytes},
}
groups[name] = g
return g
}
// GetGroup returns the named group previously created with NewGroup, or
// nil if there's no such group.
func GetGroup(name string) *Group {
mu.RLock()
g := groups[name]
mu.RUnlock()
return g
}
```
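- A Group can be regarded as a cache namespace. Each Group has a unique name `name`. For example, we can create three Groups: one named scores that caches student scores, one named info that caches student information, and one named courses that caches student courses.
- The second field is `getter Getter`, the callback that fetches the source data on a cache miss.
- The third field is `mainCache cache`, the concurrent cache implemented at the beginning.
- The constructor `NewGroup` instantiates a Group and stores it in the global variable `groups`.
- `GetGroup` returns the Group with a given name. It takes the read lock with `RLock()` because it does not write to any shared variable.
### 3.3 The Get Method of Group
Next comes `Get`, the most central method of GeeCache: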
```go
// Get value for a key from cache
func (g *Group) Get(key string) (ByteView, error) {
if key == "" {
return ByteView{}, fmt.Errorf("key is required")
}
if v, ok := g.mainCache.get(key); ok {
log.Println("[GeeCache] hit")
return v, nil
}
return g.load(key)
}
func (g *Group) load(key string) (value ByteView, err error) {
return g.getLocally(key)
}
func (g *Group) getLocally(key string) (ByteView, error) {
bytes, err := g.getter.Get(key)
if err != nil {
return ByteView{}, err
}
value := ByteView{b: cloneBytes(bytes)}
g.populateCache(key, value)
return value, nil
}
func (g *Group) populateCache(key string, value ByteView) {
g.mainCache.add(key, value)
}
```
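- The Get method implements flows ⑴ and ⑶ described above.
- Flow ⑴: look up the key in mainCache and return the cached value if it exists.
- Flow ⑶: if the key is not cached, call load. load calls getLocally (in the distributed scenario it will call getFromPeer to fetch from other nodes). getLocally calls the user callback `g.getter.Get()` to fetch the source data and adds it to mainCache through populateCache.
With that, the single-node concurrent cache of this chapter is complete.
## 4 Tests
We can test what this chapter implements either with test cases or with a main function. Let us use a test case to see how the single-node concurrent cache is used.
First, use a map to simulate a slow database.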
```go
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
```
Create a group instance and test the `Get` method:
```go
func TestGet(t *testing.T) {
loadCounts := make(map[string]int, len(db))
gee := NewGroup("scores", 2<<10, GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
if _, ok := loadCounts[key]; !ok {
loadCounts[key] = 0
}
loadCounts[key] += 1
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
for k, v := range db {
if view, err := gee.Get(k); err != nil || view.String() != v {
t.Fatalf("failed to get value of %s", k)
} // load from callback function
if _, err := gee.Get(k); err != nil || loadCounts[k] > 1 {
t.Fatalf("cache %s miss", k)
} // cache hit
}
if view, err := gee.Get("unknown"); err == nil {
t.Fatalf("the value of unknown should be empty, but %s got", view)
}
}
```
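- This test case mainly covers two situations:
- 1) When the cache is empty, the source data can be fetched through the callback.
- 2) When the value is already cached, it is read directly from the cache. To check this, `loadCounts` counts how many times the callback is called for each key. A count greater than 1 means the callback was called more than once and the value was not cached.
The test output is as follows: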
```bash
$ go test -run TestGet
2020/02/11 22:07:31 [SlowDB] search key Sam
2020/02/11 22:07:31 [GeeCache] hit
2020/02/11 22:07:31 [SlowDB] search key Tom
2020/02/11 22:07:31 [GeeCache] hit
2020/02/11 22:07:31 [SlowDB] search key Jack
2020/02/11 22:07:31 [GeeCache] hit
2020/02/11 22:07:31 [SlowDB] search key unknown
PASS
ok geecache 0.008s
```
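The output clearly shows that the callback is called while the cache is empty, and that the second access reads directly from the cache.
## Appendix: Recommended Reading
- [A Concise Go Tutorial - Concurrent Programming](https://geektutu.com/post/quick-golang.html#7-并发编程-goroutine)
- [A Concise Guide to Go Test Unit Testing](https://geektutu.com/post/quick-go-test.html)
- [sync official documentation - golang.org](https://golang.org/pkg/sync/)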
================================================
FILE: gee-cache/doc/geecache-day3.md
================================================
---
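title: Build a Distributed Cache by Hand - GeeCache Day 3 HTTP Server
date: 2020-02-12 23:00:00
description: A 7-day tutorial on implementing the distributed cache GeeCache from scratch in Go (7 days implement golang distributed cache from scratch tutorial), building a distributed cache by hand with groupcache as the reference. This article shows how to build an HTTP server with the standard library http, sets up an HTTP service for a single GeeCache node, and tests it.
tags:
- Go
nav: From Scratch
categories:
- Distributed Cache - GeeCache
keywords:
- Go
- from scratch
- distributed cache
- HTTP Server
image: post/geecache-day3/http_logo.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days of Go from Scratch series
book_title: Day3 HTTP Server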
---

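This is the third article in [Implement the Distributed Cache GeeCache from Scratch in Go in 7 Days](https://geektutu.com/post/geecache.html).
- Introduce how to build an HTTP server with the Go standard library `http`
- Implement a main function that starts the HTTP server and tests the API, in **about 60 lines of code**
## 1 The http Standard Library
Go provides the `http` standard library (package `net/http`), which makes it easy to build HTTP servers and clients. For example, we can implement a server that returns the string "Hello World!" no matter what request it receives.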
```go
package main
import (
"log"
"net/http"
)
type server int
func (h *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Println(r.URL.Path)
w.Write([]byte("Hello World!"))
}
func main() {
var s server
http.ListenAndServe("localhost:9999", &s)
}
```
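- Create a type server of any underlying type and implement the `ServeHTTP` method on it.
- Call `http.ListenAndServe` to start the HTTP service on port 9999, with `s server` as the object that handles requests.
Next we run `go run .` to start the service and test it with curl: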
```bash
$ curl http://localhost:9999
Hello World!
$ curl http://localhost:9999/abc
Hello World!
```
Log output of the Go program:
```bash
2020/02/11 22:56:32 /
2020/02/11 22:56:34 /abc
```
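> `http.ListenAndServe` takes 2 arguments. The first is the address the service listens on and the second is a Handler. Any object that implements the `ServeHTTP` method can serve as an HTTP Handler.
In the standard library, the http.Handler interface is defined as follows: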
```go
package http
type Handler interface {
ServeHTTP(w ResponseWriter, r *Request)
}
```
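Because a Handler is just a value with a `ServeHTTP` method, it can also be exercised in memory without opening a port. The sketch below is one possible way to do that, assuming the standard `net/http/httptest` package is acceptable for testing. The helper `call` is a hypothetical name, and `server` repeats the type from the example above without the logging.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// server repeats the handler type from the Hello World example.
type server int

func (h *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	w.Write([]byte("Hello World!"))
}

// call is a hypothetical helper. It sends one in-memory GET request to
// the handler and returns the status code and the response body.
func call(h http.Handler, path string) (int, string) {
	req := httptest.NewRequest(http.MethodGet, path, nil)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec.Code, rec.Body.String()
}

func main() {
	var s server
	code, body := call(&s, "/abc")
	fmt.Println(code, body) // 200 Hello World!
}
```

The recorder reports status 200 because the handler writes a body without setting another status first.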
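## 2 The GeeCache HTTP Server
A distributed cache needs communication between nodes, and building it on HTTP is a common and simple approach. If a node runs an HTTP service, other nodes can reach it. Today we set up the HTTP server for a single node.
To avoid coupling with the other parts, we put this code in a new file, `http.go`. The current code layout is as follows: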
```bash
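geecache/
    |--lru/
        |--lru.go  // LRU cache eviction strategy
    |--byteview.go // abstraction and encapsulation of cached values
    |--cache.go    // concurrency control
    |--geecache.go // interacts with callers and controls the main flow of storing and fetching cached values
    |--http.go     // makes the node reachable by other nodes (over HTTP)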
```
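First we create a struct `HTTPPool` as the core data structure for HTTP communication between nodes (it covers both server and client, and today we implement only the server).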
[day3-http-server/geecache/http.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day3-http-server/geecache)
```go
package geecache
import (
"fmt"
"log"
"net/http"
"strings"
)
const defaultBasePath = "/_geecache/"
// HTTPPool implements PeerPicker for a pool of HTTP peers.
type HTTPPool struct {
// this peer's base URL, e.g. "https://example.net:8000"
self string
basePath string
}
// NewHTTPPool initializes an HTTP pool of peers.
func NewHTTPPool(self string) *HTTPPool {
return &HTTPPool{
self: self,
basePath: defaultBasePath,
}
}
```
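- `HTTPPool` has only 2 fields. The first is `self`, which records this node's own address, including host name/IP and port.
- The second is `basePath`, the prefix of the inter-node communication address, `/_geecache/` by default. Requests that start with http://example.com/_geecache/ are therefore used for inter-node access. A host may also serve other services, so adding a path segment is a good habit. For example, most websites prefix their API endpoints with `/api`.
Next, implement the core `ServeHTTP` method.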
```go
// Log info with server name
func (p *HTTPPool) Log(format string, v ...interface{}) {
log.Printf("[Server %s] %s", p.self, fmt.Sprintf(format, v...))
}
// ServeHTTP handle all http requests
func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, p.basePath) {
panic("HTTPPool serving unexpected path: " + r.URL.Path)
}
p.Log("%s %s", r.Method, r.URL.Path)
// /<basepath>/<groupname>/<key> required
parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2)
if len(parts) != 2 {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
groupName := parts[0]
key := parts[1]
group := GetGroup(groupName)
if group == nil {
http.Error(w, "no such group: "+groupName, http.StatusNotFound)
return
}
view, err := group.Get(key)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
w.Write(view.ByteSlice())
}
```
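- The logic of `ServeHTTP` is simple. It first checks whether the request path starts with `basePath`; if not, the request is rejected (the code above panics).
- By convention the request path has the form `/<basepath>/<groupname>/<key>`. The group instance is looked up by groupname, and `group.Get(key)` fetches the cached data.
- Finally, `w.Write()` returns the cached value as the body of the HTTP response.
At this point the HTTP server is complete. Next we start the HTTP service on a single machine and test it with curl.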
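The path convention above can be tried in isolation. The following is a minimal sketch, assuming a hypothetical helper `parsePath` that is not part of GeeCache; it applies the same `strings.HasPrefix` and `strings.SplitN` steps as `ServeHTTP`, but returns an error instead of panicking on a wrong prefix.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// parsePath splits "/<basepath>/<groupname>/<key>" into group and key.
// It is a hypothetical helper for illustration only.
func parsePath(basePath, path string) (group, key string, err error) {
	if !strings.HasPrefix(path, basePath) {
		return "", "", errors.New("unexpected path: " + path)
	}
	// Split into at most 2 parts, so the key may itself contain "/".
	parts := strings.SplitN(path[len(basePath):], "/", 2)
	if len(parts) != 2 {
		return "", "", errors.New("bad request: " + path)
	}
	return parts[0], parts[1], nil
}

func main() {
	for _, p := range []string{
		"/_geecache/scores/Tom",
		"/_geecache/scores/a/b", // key containing "/"
		"/_geecache/scores",     // no key segment
		"/other/scores/Tom",     // wrong prefix
	} {
		g, k, err := parsePath("/_geecache/", p)
		fmt.Printf("%-24s group=%q key=%q err=%v\n", p, g, k, err)
	}
}
```

Limiting `SplitN` to 2 parts means everything after the group name is treated as the key, which is why a key such as `a/b` survives intact.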
## 3 Testing
Implement the main function, instantiate a group, and start the HTTP service.
[day3-http-server/main.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day3-http-server)
```go
package main
import (
"fmt"
"geecache"
"log"
"net/http"
)
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
func main() {
geecache.NewGroup("scores", 2<<10, geecache.GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
addr := "localhost:9999"
peers := geecache.NewHTTPPool(addr)
log.Println("geecache is running at", addr)
log.Fatal(http.ListenAndServe(addr, peers))
}
```
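- As before, a map simulates the data source `db`.
- Create a Group named `scores`. On a cache miss, the callback fetches the data from `db` and returns it.
- Use `http.ListenAndServe` to start the HTTP service on port 9999.
> Note:
> main.go and geecache/ sit in the same directory, but Go modules no longer support `import <relative path>`, so the relative path has to be declared in go.mod:
> require geecache v0.0.0
> replace geecache => ./geecache
Next, run the main function and do some simple tests with curl: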
```bash
$ curl http://localhost:9999/_geecache/scores/Tom
630
$ curl http://localhost:9999/_geecache/scores/kkk
kkk not exist
```
The GeeCache log output is as follows:
```bash
2020/02/11 23:28:39 geecache is running at localhost:9999
2020/02/11 23:29:08 [Server localhost:9999] GET /_geecache/scores/Tom
2020/02/11 23:29:08 [SlowDB] search key Tom
2020/02/11 23:29:16 [Server localhost:9999] GET /_geecache/scores/kkk
2020/02/11 23:29:16 [SlowDB] search key kkk
```
Communication between nodes needs an HTTP client as well as an HTTP server; that is the next step.
## Appendix: Recommended reading
- [A Concise Go Tutorial](https://geektutu.com/post/quick-golang.html)
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)
- [Go http.Handler Basics](https://geektutu.com/post/gee-day1.html)
- [Official http documentation - golang.org](https://golang.org/pkg/http)
================================================
FILE: gee-cache/doc/geecache-day4.md
================================================
---
title: Build a Distributed Cache by Hand - GeeCache Day 4 Consistent Hashing
date: 2020-02-16 20:00:00
description: A 7-day tutorial on implementing the distributed cache GeeCache from scratch in Go (golang), modeled on groupcache. This article covers the principle of consistent hashing, its implementation and test cases, why consistent hashing can avoid cache avalanches, and why virtual nodes solve the data skew problem.
tags:
- Go
nav: From Scratch
categories:
- Distributed Cache - GeeCache
keywords:
- Go
- from scratch
- consistent hashing
- consistent hash
image: post/geecache-day4/hash_logo.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days Go from Scratch Series
book_title: Day4 Consistent Hashing
---

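This is the fourth article in the series [7 Days to Implement the Distributed Cache GeeCache from Scratch in Go](https://geektutu.com/post/geecache.html).
- The principle of consistent hashing and why it is used.
- Implement consistent hashing and add test cases. **About 60 lines of code.**
## 1 Why use consistent hashing
Today we implement the consistent hashing algorithm, an important step in taking GeeCache from a single node to distributed nodes. You may well ask:
> What is consistent hashing? Why use it? What does it have to do with distribution?
### 1.1 Whom should I ask?
In a distributed cache, when a node receives a request for a value it has not cached, it faces a question: where should the data come from? From itself, or from node 1, 2, 3, 4...? Suppose there are 10 nodes including itself, and a node that receives a request picks a node at random and lets that node fetch the data from the data source.
Suppose node 1 is picked the first time; node 1 fetches the data from the data source and caches it. The second time, there is only a 1/10 chance of picking node 1 again and a 9/10 chance of picking another node, which means fetching from the data source once more, usually an expensive operation. This approach has two drawbacks: cache efficiency is low, and every node ends up storing the same data, wasting a lot of storage.
Is there a way to always pick the same node for a given key? A hash algorithm can do it. What about adding up the ASCII codes of the characters in the key and taking the remainder after dividing by 10? That works; it can be regarded as a custom hash algorithm.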

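As the figure above shows, a lookup for key `Tom` issued by any node at any time is always assigned to node 2, which effectively solves the problems above.
### 1.2 What if the number of nodes changes?
Simple hashing solves the cache-performance problem, but it does not account for changes in the number of nodes. Suppose one of the nodes is removed, leaving 9. Then `hash(key) % 10` becomes `hash(key) % 9`, which means the node for almost every cached value changes; in other words, almost all cached values become invalid. When a node receives such a request, it has to fetch the data from the data source again, which can easily trigger a `cache avalanche`.
> Cache avalanche: all cached entries become invalid at the same moment, causing a sudden surge of DB requests and pressure that leads to an avalanche. Common causes are a cache server going down or cached keys sharing the same expiration time.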
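To get a feel for how many keys are remapped under plain modulo placement, here is a minimal sketch. The helper `moved` and the `key-N` names are made up for illustration, and CRC32 is only an assumed stand-in for `hash(key)`.

```go
package main

import (
	"fmt"
	"hash/crc32"
)

// moved counts how many hash values land on a different node when the
// node count changes from oldN to newN under "hash % N" placement.
func moved(hashes []uint32, oldN, newN uint32) int {
	n := 0
	for _, h := range hashes {
		if h%oldN != h%newN {
			n++
		}
	}
	return n
}

func main() {
	// Hash 1000 illustrative keys with CRC32 (key names are made up).
	hashes := make([]uint32, 0, 1000)
	for i := 0; i < 1000; i++ {
		hashes = append(hashes, crc32.ChecksumIEEE([]byte(fmt.Sprintf("key-%d", i))))
	}
	fmt.Printf("%d of %d keys change node when going from 10 to 9 nodes\n",
		moved(hashes, 10, 9), len(hashes))
}
```

A key keeps its node only when `h % 10 == h % 9`, which holds for 9 out of every 90 consecutive hash values (0 through 8 in the first 90), so roughly 90% of keys are expected to move.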
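So how can this be solved? With consistent hashing.
## 2 How the algorithm works
### 2.1 Steps
The consistent hashing algorithm maps keys into a space of size 2^32 and joins the two ends of that range to form a ring.
- Compute the hash of each node/machine (usually from the node's name, number, and IP address) and place it on the ring.
- Compute the hash of the key and place it on the ring; the first node found going clockwise is the node/machine to pick.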

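There are three nodes on the ring: peer2, peer4, and peer6. `key11`, `key2`, and `key27` all map to peer2, and `key23` maps to peer4. If a new node/machine peer8 is now added at the position shown in the figure, only `key27` moves from peer2 to peer8; all other mappings stay the same.
In other words, when a node is added or removed, consistent hashing only needs to relocate the small portion of data near that node rather than all of the data, which solves the problem above.
### 2.2 The data skew problem
If there are too few server nodes, keys tend to skew. In the example above, peer2, peer4, and peer6 all sit on the upper half of the ring while the lower half is empty. Every key that maps to the lower half is assigned to peer2, so keys skew heavily toward peer2 and the load across cache nodes is unbalanced.
To solve this, virtual nodes are introduced: one real node corresponds to multiple virtual nodes.
Suppose 1 real node corresponds to 3 virtual nodes; then peer1's virtual nodes are peer1-1, peer1-2, and peer1-3 (usually implemented by appending a number), and the other nodes are handled the same way.
- Step 1: compute the hash of each virtual node and place it on the ring.
- Step 2: compute the hash of the key and go clockwise on the ring to find the virtual node to pick, for example peer2-1, which corresponds to the real node peer2.
Virtual nodes increase the number of nodes on the ring and solve the skew that occurs when there are few nodes. The cost is small: only a dictionary (map) that maintains the mapping between real and virtual nodes.
## 3 Implementation in Go
We create a new package `consistenthash` under the geecache directory to implement the consistent hashing algorithm.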
[day4-consistent-hash/geecache/consistenthash/consistenthash.go](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day4-consistent-hash/geecache/consistenthash)
```go
package consistenthash
import (
"hash/crc32"
"sort"
"strconv"
)
// Hash maps bytes to uint32
type Hash func(data []byte) uint32
// Map contains all hashed keys
type Map struct {
hash Hash
replicas int
keys []int // Sorted
hashMap map[int]string
}
// New creates a Map instance
func New(replicas int, fn Hash) *Map {
m := &Map{
replicas: replicas,
hash: fn,
hashMap: make(map[int]string),
}
if m.hash == nil {
m.hash = crc32.ChecksumIEEE
}
return m
}
```
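- Defines the function type `Hash`. It is injected as a dependency so that callers can swap in a custom hash function, which also makes testing easier. The default is `crc32.ChecksumIEEE`.
- `Map` is the main data structure of the consistent hashing algorithm, with 4 fields: the hash function `hash`; the virtual-node multiplier `replicas`; the hash ring `keys`; and `hashMap`, the mapping between virtual and real nodes, whose keys are virtual-node hashes and whose values are real-node names.
- The constructor `New()` lets the caller customize the virtual-node multiplier and the hash function.
Next, implement the `Add()` method, which adds real nodes/machines.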
```go
// Add adds some keys to the hash.
func (m *Map) Add(keys ...string) {
for _, key := range keys {
for i := 0; i < m.replicas; i++ {
hash := int(m.hash([]byte(strconv.Itoa(i) + key)))
m.keys = append(m.keys, hash)
m.hashMap[hash] = key
}
}
sort.Ints(m.keys)
}
```
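- `Add` accepts zero or more real-node names.
- For each real node `key`, it creates `m.replicas` virtual nodes. A virtual node is named `strconv.Itoa(i) + key`, that is, virtual nodes are told apart by a numeric prefix.
- `m.hash()` computes the hash of each virtual node, and `append(m.keys, hash)` adds it to the ring.
- The mapping from virtual node to real node is recorded in `hashMap`.
- As the last step, the hashes on the ring are sorted.
Finally, implement the `Get()` method, which selects a node.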
```go
// Get gets the closest item in the hash to the provided key.
func (m *Map) Get(key string) string {
if len(m.keys) == 0 {
return ""
}
hash := int(m.hash([]byte(key)))
// Binary search for appropriate replica.
idx := sort.Search(len(m.keys), func(i int) bool {
return m.keys[i] >= hash
})
return m.hashMap[m.keys[idx%len(m.keys)]]
}
```
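- Selecting a node is straightforward. Step 1: compute the hash of the key.
- Step 2: going clockwise, find the index `idx` of the first matching virtual node and read the corresponding hash from `m.keys`. If `idx == len(m.keys)`, `m.keys[0]` should be chosen; because `m.keys` is a ring, this case is handled by taking the remainder.
- Step 3: map the hash to the real node through `hashMap`.
With that, the consistent hashing algorithm is fully implemented.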
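The claim from section 2 (adding a node only moves keys to that new node) can be checked with a small experiment. The sketch below is self-contained: `ring` is a compact stand-in for the `Map` type above, fixed to CRC32, and `addAndCount`, the peer names, and the `key-N` sample keys are hypothetical and exist only for this illustration.

```go
package main

import (
	"fmt"
	"hash/crc32"
	"sort"
	"strconv"
)

// ring is a compact stand-in for the Map type above, fixed to CRC32.
type ring struct {
	replicas int
	keys     []int          // sorted virtual-node hashes
	owner    map[int]string // virtual-node hash -> real node
}

func newRing(replicas int) *ring {
	return &ring{replicas: replicas, owner: make(map[int]string)}
}

func (r *ring) add(nodes ...string) {
	for _, node := range nodes {
		for i := 0; i < r.replicas; i++ {
			h := int(crc32.ChecksumIEEE([]byte(strconv.Itoa(i) + node)))
			r.keys = append(r.keys, h)
			r.owner[h] = node
		}
	}
	sort.Ints(r.keys)
}

func (r *ring) get(key string) string {
	if len(r.keys) == 0 {
		return ""
	}
	h := int(crc32.ChecksumIEEE([]byte(key)))
	idx := sort.Search(len(r.keys), func(i int) bool { return r.keys[i] >= h })
	return r.owner[r.keys[idx%len(r.keys)]]
}

// addAndCount adds newNode and reports how many of n sample keys changed
// owner, and how many of those moved to newNode.
func addAndCount(r *ring, n int, newNode string) (moved, toNew int) {
	before := make([]string, n)
	for i := range before {
		before[i] = r.get("key-" + strconv.Itoa(i))
	}
	r.add(newNode)
	for i, old := range before {
		now := r.get("key-" + strconv.Itoa(i))
		if now != old {
			moved++
			if now == newNode {
				toNew++
			}
		}
	}
	return moved, toNew
}

func main() {
	r := newRing(50)
	r.add("peerA", "peerB", "peerC")
	moved, toNew := addAndCount(r, 1000, "peerD")
	fmt.Printf("moved=%d, moved to peerD=%d (of 1000 keys)\n", moved, toNew)
}
```

The two counts should always be equal: a key either keeps its old virtual node or is captured by one of the new node's virtual nodes. The exact number of moved keys depends on the hash function and the replica count.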
## 4 Testing
Finally, we need test cases to verify that the implementation is correct.
[day4-consistent-hash/geecache/consistenthash/consistenthash_test.go](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day4-consistent-hash/geecache/consistenthash)
```go
package consistenthash
import (
"strconv"
"testing"
)
func TestHashing(t *testing.T) {
hash := New(3, func(key []byte) uint32 {
i, _ := strconv.Atoi(string(key))
return uint32(i)
})
// Given the above hash function, this will give replicas with "hashes":
// 2, 4, 6, 12, 14, 16, 22, 24, 26
hash.Add("6", "4", "2")
testCases := map[string]string{
"2": "2",
"11": "2",
"23": "4",
"27": "2",
}
for k, v := range testCases {
if hash.Get(k) != v {
t.Errorf("Asking for %s, should have yielded %s", k, v)
}
}
// Adds 8, 18, 28
hash.Add("8")
// 27 should now map to 8.
testCases["27"] = "8"
for k, v := range testCases {
if hash.Get(k) != v {
t.Errorf("Asking for %s, should have yielded %s", k, v)
}
}
}
```
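To test, we need to know the exact hash of every key we pass in, which the default `crc32.ChecksumIEEE` obviously cannot give us. So the test uses a custom hash function that only handles numbers: given a string that represents a number, it returns that number.
- Initially there are three real nodes 2/4/6, whose virtual nodes have the hashes 02/12/22, 04/14/24, and 06/16/26.
- The test keys 2/11/23/27 therefore select the virtual nodes 02/12/24/02, that is, the real nodes 2/2/4/2.
- After adding real node 8, whose virtual nodes have the hashes 08/18/28, the virtual node for key 27 changes from `02` to `28`, that is, real node 8.
## Appendix: Recommended reading
- [A Concise Go Tutorial](https://geektutu.com/post/quick-golang.html)
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)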
================================================
FILE: gee-cache/doc/geecache-day5.md
================================================
---
title: Build a Distributed Cache by Hand - GeeCache Day 5 Distributed Nodes
date: 2020-02-16 21:30:00
description: A 7-day tutorial on implementing the distributed cache GeeCache from scratch in Go (golang), modeled on groupcache. This article adds peer registration and peer selection to GeeCache and implements an HTTP client that talks to the server on a remote node.
tags:
- Go
nav: From Scratch
categories:
- Distributed Cache - GeeCache
keywords:
- Go
- from scratch
- HTTP client
- distributed nodes
image: post/geecache-day5/dist_nodes_logo.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days Go from Scratch Series
book_title: Day5 Distributed Nodes
---

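This is the fifth article in the series [7 Days to Implement the Distributed Cache GeeCache from Scratch in Go](https://geektutu.com/post/geecache.html).
- Register peers and select a node with the consistent hashing algorithm.
- Implement an HTTP client that talks to the server on a remote node. **About 90 lines of code.**
## 1 Flow recap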
```bash
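                         yes
receive key --> cached? -----> return cached value ⑴
                   |  no                              yes
                   |-----> fetch from a remote node? -----> talk to the remote node --> return cached value ⑵
                               |  no
                               |-----> call the `callback`, get the value and add it to the cache --> return cached value ⑶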
```
We described the geecache flow in [GeeCache Day 2](https://geektutu.com/post/geecache-day2.html). Paths ⑴ and ⑶ are already implemented; today we implement path ⑵, fetching the cached value from a remote node.
Path ⑵ in more detail:
```bash
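consistent hash picks a node      yes                                                      yes
    |-----> is it a remote node? -----> HTTP client requests the remote node --> success? -----> server returns the value
                    |  no                                                           ↓  no
                    |----------------------------> fall back to handling on the local node.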
```
## 2 Abstracting PeerPicker
[day5-multi-nodes/geecache/peers.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day5-multi-nodes/geecache)
```go
package geecache
// PeerPicker is the interface that must be implemented to locate
// the peer that owns a specific key.
type PeerPicker interface {
PickPeer(key string) (peer PeerGetter, ok bool)
}
// PeerGetter is the interface that must be implemented by a peer.
type PeerGetter interface {
Get(group string, key string) ([]byte, error)
}
```
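- Two interfaces are abstracted here. The `PickPeer()` method of PeerPicker selects the PeerGetter for the node that owns the given key.
- The `Get()` method of PeerGetter looks up the cached value in the given group. PeerGetter corresponds to the HTTP client in the flow above.
## 3 Node selection and the HTTP client
In [GeeCache Day 3](https://geektutu.com/post/geecache-day3.html) we gave `HTTPPool` its server side. Communication needs a client as well as a server, so next we add the client side to `HTTPPool`.
First create the concrete HTTP client type `httpGetter`, which implements the PeerGetter interface.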
[day5-multi-nodes/geecache/http.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day5-multi-nodes/geecache)
```go
type httpGetter struct {
baseURL string
}
func (h *httpGetter) Get(group string, key string) ([]byte, error) {
u := fmt.Sprintf(
"%v%v/%v",
h.baseURL,
url.QueryEscape(group),
url.QueryEscape(key),
)
res, err := http.Get(u)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("server returned: %v", res.Status)
}
bytes, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("reading response body: %v", err)
}
return bytes, nil
}
var _ PeerGetter = (*httpGetter)(nil)
```
- `baseURL` is the address of the remote node to access, for example `http://example.com/_geecache/`.
- `http.Get()` fetches the response, whose body is read into a `[]byte`.
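The client logic can be tried without real peers. The following is a minimal sketch under stated assumptions: `fetch` is a standalone copy of the request logic in `httpGetter.Get`, `newFakePeer` is a hypothetical throwaway server built with `net/http/httptest` that stands in for a remote node, and `io.ReadAll` (Go 1.16+) replaces `ioutil.ReadAll`.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
)

// fetch mirrors the request built by httpGetter.Get:
// baseURL + escaped group + "/" + escaped key.
func fetch(baseURL, group, key string) ([]byte, error) {
	u := fmt.Sprintf("%v%v/%v", baseURL, url.QueryEscape(group), url.QueryEscape(key))
	res, err := http.Get(u)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("server returned: %v", res.Status)
	}
	return io.ReadAll(res.Body)
}

// newFakePeer starts a local test server that stands in for a remote
// node. It is a hypothetical helper, not part of GeeCache.
func newFakePeer(data map[string]string) *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		parts := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/_geecache/"), "/", 2)
		if len(parts) != 2 {
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		v, ok := data[parts[1]]
		if !ok {
			http.Error(w, parts[1]+" not exist", http.StatusInternalServerError)
			return
		}
		w.Write([]byte(v))
	}))
}

func main() {
	peer := newFakePeer(map[string]string{"Tom": "630"})
	defer peer.Close()

	v, err := fetch(peer.URL+"/_geecache/", "scores", "Tom")
	fmt.Println(string(v), err)

	_, err = fetch(peer.URL+"/_geecache/", "scores", "kkk")
	fmt.Println(err)
}
```

Any non-200 status is turned into an error, so the caller can fall back to another source when the peer fails.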
Step 2: add node selection to HTTPPool.
```go
const (
defaultBasePath = "/_geecache/"
defaultReplicas = 50
)
// HTTPPool implements PeerPicker for a pool of HTTP peers.
type HTTPPool struct {
// this peer's base URL, e.g. "https://example.net:8000"
self string
basePath string
mu sync.Mutex // guards peers and httpGetters
peers *consistenthash.Map
httpGetters map[string]*httpGetter // keyed by e.g. "http://10.0.0.2:8008"
}
```
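- New field `peers`, of the consistent hashing type `Map`, used to select a node for a given key.
- New field `httpGetters`, which maps each remote node to its httpGetter. There is one httpGetter per remote node, because an httpGetter is tied to the remote node's address `baseURL`.
Step 3: implement the PeerPicker interface.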
```go
// Set updates the pool's list of peers.
func (p *HTTPPool) Set(peers ...string) {
p.mu.Lock()
defer p.mu.Unlock()
p.peers = consistenthash.New(defaultReplicas, nil)
p.peers.Add(peers...)
p.httpGetters = make(map[string]*httpGetter, len(peers))
for _, peer := range peers {
p.httpGetters[peer] = &httpGetter{baseURL: peer + p.basePath}
}
}
// PickPeer picks a peer according to key
func (p *HTTPPool) PickPeer(key string) (PeerGetter, bool) {
p.mu.Lock()
defer p.mu.Unlock()
if peer := p.peers.Get(key); peer != "" && peer != p.self {
p.Log("Pick peer %s", peer)
return p.httpGetters[peer], true
}
return nil, false
}
var _ PeerPicker = (*HTTPPool)(nil)
```
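- `Set()` instantiates the consistent hashing algorithm and adds the given nodes.
- It also creates an HTTP client `httpGetter` for each node.
- `PickPeer()` wraps the `Get()` method of the consistent hashing algorithm: it selects a node for the given key and returns that node's HTTP client.
HTTPPool can now both serve HTTP requests and, for a given key, pick the HTTP client that fetches the cached value from the right remote node.
## 4 Implementing the main flow
Finally, the new functionality has to be wired into the main flow (geecache.go).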
[day5-multi-nodes/geecache/geecache.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day5-multi-nodes/geecache)
```go
// A Group is a cache namespace and associated data loaded spread over
type Group struct {
name string
getter Getter
mainCache cache
peers PeerPicker
}
// RegisterPeers registers a PeerPicker for choosing remote peer
func (g *Group) RegisterPeers(peers PeerPicker) {
if g.peers != nil {
panic("RegisterPeerPicker called more than once")
}
g.peers = peers
}
func (g *Group) load(key string) (value ByteView, err error) {
if g.peers != nil {
if peer, ok := g.peers.PickPeer(key); ok {
if value, err = g.getFromPeer(peer, key); err == nil {
return value, nil
}
log.Println("[GeeCache] Failed to get from peer", err)
}
}
return g.getLocally(key)
}
func (g *Group) getFromPeer(peer PeerGetter, key string) (ByteView, error) {
bytes, err := peer.Get(g.name, key)
if err != nil {
return ByteView{}, err
}
return ByteView{b: bytes}, nil
}
```
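- New method `RegisterPeers()` injects an HTTPPool, which implements the PeerPicker interface, into the Group.
- New method `getFromPeer()` uses an httpGetter, which implements the PeerGetter interface, to access the remote node and fetch the cached value.
- The `load` method now calls `PickPeer()` to select a node. If the node is not the local one, it calls `getFromPeer()` to fetch remotely. If the node is the local one, or the remote call fails, it falls back to `getLocally()`.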
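The fallback order in `load` can be illustrated with in-memory fakes. The sketch below is self-contained and makes several assumptions: the two interfaces are copied from peers.go, while `fakePeer`, `staticPicker`, and the free function `load` (which also returns where the value came from) are hypothetical stand-ins for HTTPPool, httpGetter, and `Group.load`.

```go
package main

import (
	"errors"
	"fmt"
)

// The two interfaces mirror peers.go.
type PeerGetter interface {
	Get(group string, key string) ([]byte, error)
}

type PeerPicker interface {
	PickPeer(key string) (peer PeerGetter, ok bool)
}

// fakePeer is a hypothetical in-memory peer; a nil data map simulates an
// unreachable node.
type fakePeer struct{ data map[string]string }

func (f fakePeer) Get(group, key string) ([]byte, error) {
	if f.data == nil {
		return nil, errors.New("peer unreachable")
	}
	if v, ok := f.data[key]; ok {
		return []byte(v), nil
	}
	return nil, errors.New(key + " not exist")
}

// staticPicker always picks the same peer; ok=false means "handle locally".
type staticPicker struct {
	peer PeerGetter
	ok   bool
}

func (s staticPicker) PickPeer(string) (PeerGetter, bool) { return s.peer, s.ok }

// load follows the same order as Group.load: remote peer first, then local.
// The first return value names the source that answered.
func load(peers PeerPicker, local func(string) ([]byte, error), key string) (string, []byte, error) {
	if peers != nil {
		if peer, ok := peers.PickPeer(key); ok {
			if v, err := peer.Get("scores", key); err == nil {
				return "peer", v, nil
			}
		}
	}
	v, err := local(key)
	return "local", v, err
}

func main() {
	local := func(key string) ([]byte, error) { return []byte("local-" + key), nil }
	healthy := staticPicker{peer: fakePeer{data: map[string]string{"Tom": "630"}}, ok: true}
	broken := staticPicker{peer: fakePeer{}, ok: true}

	src, v, err := load(healthy, local, "Tom")
	fmt.Println("healthy peer:", src, string(v), err)

	src, v, err = load(broken, local, "Tom")
	fmt.Println("broken peer: ", src, string(v), err)
}
```

Because `Group` depends only on the two interfaces, fakes like these can replace the HTTP layer in unit tests.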
## 5 Testing with the main function
[day5-multi-nodes/main.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day5-multi-nodes)
```go
var db = map[string]string{
"Tom": "630",
"Jack": "589",
"Sam": "567",
}
func createGroup() *geecache.Group {
return geecache.NewGroup("scores", 2<<10, geecache.GetterFunc(
func(key string) ([]byte, error) {
log.Println("[SlowDB] search key", key)
if v, ok := db[key]; ok {
return []byte(v), nil
}
return nil, fmt.Errorf("%s not exist", key)
}))
}
func startCacheServer(addr string, addrs []string, gee *geecache.Group) {
peers := geecache.NewHTTPPool(addr)
peers.Set(addrs...)
gee.RegisterPeers(peers)
log.Println("geecache is running at", addr)
log.Fatal(http.ListenAndServe(addr[7:], peers))
}
func startAPIServer(apiAddr string, gee *geecache.Group) {
http.Handle("/api", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
key := r.URL.Query().Get("key")
view, err := gee.Get(key)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
w.Write(view.ByteSlice())
}))
log.Println("fontend server is running at", apiAddr)
log.Fatal(http.ListenAndServe(apiAddr[7:], nil))
}
func main() {
var port int
var api bool
flag.IntVar(&port, "port", 8001, "Geecache server port")
flag.BoolVar(&api, "api", false, "Start a api server?")
flag.Parse()
apiAddr := "http://localhost:9999"
addrMap := map[int]string{
8001: "http://localhost:8001",
8002: "http://localhost:8002",
8003: "http://localhost:8003",
}
var addrs []string
for _, v := range addrMap {
addrs = append(addrs, v)
}
gee := createGroup()
if api {
go startAPIServer(apiAddr, gee)
}
startCacheServer(addrMap[port], []string(addrs), gee)
}
```
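The main function is fairly long, but its logic is simple.
- `startCacheServer()` starts a cache server: it creates an HTTPPool, adds the node information, registers it with `gee`, and starts the HTTP service (3 ports in total, 8001/8002/8003). Users are not aware of these servers.
- `startAPIServer()` starts an API service (port 9999) that users interact with directly.
- `main()` takes the 2 command-line arguments `port` and `api`, which control the port the cache server listens on and whether the API service is started as well.
For convenience, the start-up commands are wrapped in a `shell` script: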
```bash
#!/bin/bash
trap "rm server;kill 0" EXIT
go build -o server
./server -port=8001 &
./server -port=8002 &
./server -port=8003 -api=1 &
sleep 2
echo ">>> start test"
curl "http://localhost:9999/api?key=Tom" &
curl "http://localhost:9999/api?key=Tom" &
curl "http://localhost:9999/api?key=Tom" &
wait
```
- The `trap` command removes the temporary file and kills the child processes when the shell script exits.
```bash
$ ./run.sh
2020/02/16 21:17:43 geecache is running at http://localhost:8001
2020/02/16 21:17:43 geecache is running at http://localhost:8002
2020/02/16 21:17:43 geecache is running at http://localhost:8003
2020/02/16 21:17:43 fontend server is running at http://localhost:9999
>>> start test
2020/02/16 21:17:45 [Server http://localhost:8003] Pick peer http://localhost:8001
2020/02/16 21:17:45 [Server http://localhost:8003] Pick peer http://localhost:8001
2020/02/16 21:17:45 [Server http://localhost:8003] Pick peer http://localhost:8001
...
630630630
```
Now we can open a new shell and test:
```bash
$ curl "http://localhost:9999/api?key=Tom"
630
$ curl "http://localhost:9999/api?key=kkk"
kkk not exist
```
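During the test we sent 3 concurrent requests for `?key=Tom`. The log shows that node `8001` was picked all three times, thanks to consistent hashing. There is a problem, though: 3 requests were sent to `8001` at the same time. Imagine 100,000 concurrent requests for this data: 100,000 requests would hit `8001` at once, and if `8001` in turn sent 100,000 queries to the database, a cache breakdown could easily follow.
All three requests return the same result. For the same key, could we send only one request to `8001`? We solve this next time.
## Appendix: Recommended reading
- [A Concise Go Tutorial](https://geektutu.com/post/quick-golang.html)
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)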
================================================
FILE: gee-cache/doc/geecache-day6.md
================================================
---
title: Build a Distributed Cache by Hand - GeeCache Day 6 Preventing Cache Breakdown
date: 2020-02-16 23:00:00
description: A 7-day tutorial on implementing the distributed cache GeeCache from scratch in Go (golang), modeled on groupcache. This article introduces the concepts of cache avalanche, cache breakdown, and cache penetration, and uses singleflight to prevent cache breakdown, with implementation and tests.
tags:
- Go
nav: From Scratch
categories:
- Distributed Cache - GeeCache
keywords:
- Go
- from scratch
- HTTP client
- distributed nodes
image: post/geecache-day6/singleflight_logo.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days Go from Scratch Series
book_title: Day6 Preventing Cache Breakdown
---

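This is the sixth article in the series [7 Days to Implement the Distributed Cache GeeCache from Scratch in Go](https://geektutu.com/post/geecache.html).
- A brief introduction to cache avalanche, cache breakdown, and cache penetration.
- Use singleflight to prevent cache breakdown, with implementation and tests. **About 70 lines of code.**
## 1 Cache avalanche, cache breakdown, and cache penetration
[GeeCache Day 5](https://geektutu.com/post/geecache-day5.html) mentioned cache avalanche and cache breakdown. Here is a summary:
> **Cache avalanche**: all cached entries become invalid at the same moment, causing a sudden surge of DB requests and pressure that leads to an avalanche. It is usually caused by a cache server going down or by cached keys sharing the same expiration time.
> **Cache breakdown**: at the moment the cache entry for an existing key expires, a large number of requests arrive at once. They all break through to the DB, causing a sudden surge of DB requests and pressure.
> **Cache penetration**: a query for data that does not exist. Because it does not exist, it is never written to the cache, so every request goes to the DB. If traffic spikes, it penetrates to the DB and can bring it down.
## 2 Implementing singleflight
Remember the test result at the end of [GeeCache Day 5](https://geektutu.com/post/geecache-day5.html)?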
```bash
2020/02/16 21:17:45 [Server http://localhost:8003] Pick peer http://localhost:8001
2020/02/16 21:17:45 [Server http://localhost:8003] Pick peer http://localhost:8001
2020/02/16 21:17:45 [Server http://localhost:8003] Pick peer http://localhost:8001
```
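We sent N concurrent requests for `?key=Tom`, and node 8003 sent N requests to 8001 at the same time. If access to the database is not limited in any way, N requests will probably hit the database as well, which can easily lead to cache breakdown and penetration. Even if the database is protected, HTTP requests are expensive, and for the same key there is no need for 8003 to send three requests to 8001. So how can we send only one request to the remote node in this situation?
geecache implements a package named singleflight to solve this problem.
[day6-single-flight/geecache/singleflight/singleflight.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day6-single-flight/geecache/singleflight)
First create the `call` and `Group` types.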
```go
package singleflight
import "sync"
type call struct {
wg sync.WaitGroup
val interface{}
err error
}
type Group struct {
mu sync.Mutex // protects m
m map[string]*call
}
```
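- `call` represents a request that is in flight or has finished. A `sync.WaitGroup` makes concurrent callers wait for it instead of re-entering.
- `Group` is the main data structure of singleflight; it manages the requests (calls) for different keys.
Implement the `Do` method: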
```go
func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
g.mu.Unlock()
c.wg.Wait()
return c.val, c.err
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
c.val, c.err = fn()
c.wg.Done()
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
return c.val, c.err
}
```
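- `Do` takes 2 arguments: the first is `key`, the second is a function `fn`. For a given key, no matter how many times `Do` is called while a call for that key is in flight, `fn` is invoked only once; callers wait for `fn` to finish and then receive its return value or error.
`g.mu` is the lock that protects Group's field `m` from concurrent reads and writes. To make `Do` easier to follow, let us temporarily remove `g.mu`, along with the lazy initialization of `g.m`. Lazy initialization serves a simple purpose: using memory more efficiently.
The remaining logic is then clear: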
```go
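func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
if c, ok := g.m[key]; ok {
c.wg.Wait() // a request is in flight, so wait for it
return c.val, c.err // the request has finished, return its result
}
c := new(call)
c.wg.Add(1) // mark the request as in flight before starting it
g.m[key] = c // add to g.m to show that this key already has a request being handled
c.val, c.err = fn() // call fn to issue the request
c.wg.Done() // the request has finished
delete(g.m, key) // update g.m
return c.val, c.err // return the result
}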
```
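The concurrent goroutines do not need to pass messages to each other, which makes `sync.WaitGroup` a good fit.
- `wg.Add(1)` increments the counter by 1.
- `wg.Wait()` blocks until the counter drops back to zero.
- `wg.Done()` decrements the counter by 1.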
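To watch the deduplication happen, the `Group` above can be driven by a handful of goroutines. The sketch below is self-contained: `call`, `Group`, and `Do` are copied from the listing above, while `run`, the 50 ms sleep, and the value "630" are assumptions made up to simulate a slow lookup.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

type call struct {
	wg  sync.WaitGroup
	val interface{}
	err error
}

type Group struct {
	mu sync.Mutex // protects m
	m  map[string]*call
}

func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
	g.mu.Lock()
	if g.m == nil {
		g.m = make(map[string]*call)
	}
	if c, ok := g.m[key]; ok {
		g.mu.Unlock()
		c.wg.Wait()
		return c.val, c.err
	}
	c := new(call)
	c.wg.Add(1)
	g.m[key] = c
	g.mu.Unlock()

	c.val, c.err = fn()
	c.wg.Done()

	g.mu.Lock()
	delete(g.m, key)
	g.mu.Unlock()
	return c.val, c.err
}

// run fires n concurrent Do calls for one key and reports how many times
// fn actually ran, plus every caller's result.
func run(n int) (calls int32, results []string) {
	var g Group
	var wg sync.WaitGroup
	results = make([]string, n)
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			v, _ := g.Do("Tom", func() (interface{}, error) {
				atomic.AddInt32(&calls, 1)
				time.Sleep(50 * time.Millisecond) // simulate a slow lookup
				return "630", nil
			})
			results[i] = v.(string)
		}(i)
	}
	wg.Wait()
	return calls, results
}

func main() {
	calls, results := run(10)
	fmt.Printf("fn ran %d time(s) for %d callers: %v\n", calls, len(results), results)
}
```

Every caller receives the same value. How many times `fn` runs depends on scheduling: callers that arrive while a call is in flight share its result, so the count is usually far below the number of callers, but it is not guaranteed to be exactly 1.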
## 3 Using singleflight
[day6-single-flight/geecache/geecache.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day6-single-flight/geecache)
```go
type Group struct {
name string
getter Getter
mainCache cache
peers PeerPicker
// use singleflight.Group to make sure that
// each key is only fetched once
loader *singleflight.Group
}
func NewGroup(name string, cacheBytes int64, getter Getter) *Group {
// ...
g := &Group{
// ...
loader: &singleflight.Group{},
}
return g
}
func (g *Group) load(key string) (value ByteView, err error) {
// each key is only fetched once (either locally or remotely)
// regardless of the number of concurrent callers.
viewi, err := g.loader.Do(key, func() (interface{}, error) {
if g.peers != nil {
if peer, ok := g.peers.PickPeer(key); ok {
if value, err = g.getFromPeer(peer, key); err == nil {
return value, nil
}
log.Println("[GeeCache] Failed to get from peer", err)
}
}
return g.getLocally(key)
})
if err == nil {
return viewi.(ByteView), nil
}
return
}
```
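- Modify `Group` in `geecache.go`: add the field `loader` and update the constructor `NewGroup`.
- Modify the `load` function by wrapping the original load logic in `g.loader.Do`. This ensures that under concurrency the `load` process runs only once at a time for the same key.
## 4 Testing
Run `run.sh` to see the effect.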
```bash
$ ./run.sh
2020/02/16 22:36:00 [Server http://localhost:8003] Pick peer http://localhost:8001
2020/02/16 22:36:00 [Server http://localhost:8001] GET /_geecache/scores/Tom
2020/02/16 22:36:00 [SlowDB] search key Tom
630630630
```
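As you can see, three concurrent requests were sent to the API, but 8003 sent only one request to 8001, and that was enough.
If the concurrency is not high enough, you may still see three requests to 8001. In that case the three requests ran serially and singleflight's waiting mechanism was never triggered. Increase the concurrency and test again, that is, copy the `curl` command in `run.sh` N times.
## Appendix: Recommended reading
- [A Concise Go Tutorial#Concurrent programming](https://geektutu.com/post/quick-golang.html#7-%E5%B9%B6%E5%8F%91%E7%BC%96%E7%A8%8B-goroutine)
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)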
================================================
FILE: gee-cache/doc/geecache-day7.md
================================================
---
title: Build a Distributed Cache by Hand - GeeCache Day 7 Communicating with Protobuf
date: 2020-02-17 00:30:00
description: A 7-day tutorial on implementing the distributed cache GeeCache from scratch in Go (golang), modeled on groupcache. This article uses protobuf (protocol buffers) for communication between nodes, encoding messages to improve efficiency.
tags:
- Go
nav: From Scratch
categories:
- Distributed Cache - GeeCache
keywords:
- Go
- from scratch
- HTTP client
- distributed nodes
image: post/geecache-day7/protobuf_logo.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days Go from Scratch Series
book_title: Day7 Communicating with Protobuf
---

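This is the seventh article in the series [7 Days to Implement the Distributed Cache GeeCache from Scratch in Go](https://geektutu.com/post/geecache.html).
- Why use protobuf?
- Use protobuf for communication between nodes, encoding messages to improve efficiency. **About 50 lines of code.**
## 1 Why use protobuf
> protobuf, short for Protocol Buffers, is a data description language developed by Google. It is a lightweight, efficient format for storing structured data that is independent of language and platform, extensible, and serializable. protobuf stores data in binary form and takes up little space.
For installing and using protobuf, see [A Concise Go Protobuf Tutorial](https://geektutu.com/post/quick-go-protobuf.html); this article does not repeat it. protobuf is widely used for binary transport in remote procedure calls (RPC). The reason for using it is simple: better performance. Encoding with protobuf before sending and decoding on the receiving side can noticeably reduce the size of the transmitted data. In addition, protobuf is well suited to transmitting structured data and makes it easy to extend the fields being exchanged.
Using protobuf generally takes 2 steps:
- Following protobuf syntax, define the data structures in a `.proto` file and generate Go code with `protoc` (a `.proto` file is cross-platform and can also be used to generate source files for other languages such as C++ and Java).
- Reference the generated Go code in the project.
## 2 Communicating with protobuf
Create a new package `geecachepb` and define `geecachepb.proto`.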
[day7-proto-buf/geecache/geecachepb/geecachepb.proto - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day7-proto-buf/geecache/geecachepb)
```protobuf
syntax = "proto3";
package geecachepb;
message Request {
string group = 1;
string key = 2;
}
message Response {
bytes value = 1;
}
service GroupCache {
rpc Get(Request) returns (Response);
}
```
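- `Request` has 2 fields, `group` and `key`, which match the parameters required by the `/_geecache/<groupname>/<key>` endpoint we defined earlier.
- `Response` has 1 field, `value`, of type `bytes` (a byte slice in Go), consistent with the earlier design.
Generate `geecachepb.pb.go`: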
```bash
$ protoc --go_out=. *.proto
$ ls
geecachepb.pb.go geecachepb.proto
```
`geecachepb.pb.go` contains the following data types:
```go
type Request struct {
Group string `protobuf:"bytes,1,opt,name=group,proto3" json:"group,omitempty"`
Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"`
...
}
type Response struct {
Value []byte `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"`
}
```
Next, modify the `PeerGetter` interface in `peers.go` so that its parameters use the data types from `geecachepb.pb.go`.
[day7-proto-buf/geecache/peers.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day7-proto-buf/geecache)
```go
import pb "geecache/geecachepb"
type PeerGetter interface {
Get(in *pb.Request, out *pb.Response) error
}
```
Finally, update the places in `geecache.go` and `http.go` that use the `PeerGetter` interface.
[day7-proto-buf/geecache/geecache.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day7-proto-buf/geecache)
```go
import (
// ...
pb "geecache/geecachepb"
)
func (g *Group) getFromPeer(peer PeerGetter, key string) (ByteView, error) {
req := &pb.Request{
Group: g.name,
Key: key,
}
res := &pb.Response{}
err := peer.Get(req, res)
if err != nil {
return ByteView{}, err
}
return ByteView{b: res.Value}, nil
}
```
[day7-proto-buf/geecache/http.go - github](https://github.com/geektutu/7days-golang/tree/master/gee-cache/day7-proto-buf/geecache)
```go
import (
// ...
pb "geecache/geecachepb"
"github.com/golang/protobuf/proto"
)
func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// ...
// Write the value to the response body as a proto message.
body, err := proto.Marshal(&pb.Response{Value: view.ByteSlice()})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
w.Write(body)
}
func (h *httpGetter) Get(in *pb.Request, out *pb.Response) error {
u := fmt.Sprintf(
"%v%v/%v",
h.baseURL,
url.QueryEscape(in.GetGroup()),
url.QueryEscape(in.GetKey()),
)
res, err := http.Get(u)
// ...
if err = proto.Unmarshal(bytes, out); err != nil {
return fmt.Errorf("decoding response body: %v", err)
}
return nil
}
```
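- `ServeHTTP()` uses `proto.Marshal()` to encode the HTTP response.
- `Get()` uses `proto.Unmarshal()` to decode the HTTP response.
The payload carried over HTTP has now been replaced with protobuf. Run `run.sh` to check that GeeCache still works.
## Summary
With this article, the series on implementing the distributed cache GeeCache from scratch in Go in 7 days is complete. A quick recap. Day 1 implemented the LRU eviction algorithm to cope with limited resources. Day 2 implemented single-machine concurrency and gave users a callback for a custom data source. Day 3 implemented the HTTP server. Day 4 implemented consistent hashing to solve the problem of choosing a remote node. Day 5 created the HTTP client and implemented communication between multiple nodes. Day 6 implemented singleflight to prevent cache breakdown. Day 7 used the protobuf library to optimize communication between nodes. If you have read this far without writing any code, now is the time to start. Each day takes only about 100 lines.
## Appendix: Recommended reading
- [A Concise Go Tutorial](https://geektutu.com/post/quick-golang.html)
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)
- [A Concise Go Protobuf Tutorial](https://geektutu.com/post/quick-go-protobuf.html)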
================================================
FILE: gee-cache/doc/geecache.md
================================================
---
title: 7 Days to Implement the Distributed Cache GeeCache from Scratch in Go
date: 2020-02-08 01:00:00
description: A 7-day tutorial on implementing the distributed cache GeeCache from scratch in Go (golang), modeled on groupcache. Features include single-machine and distributed caching, the LRU (Least Recently Used) cache policy, cache breakdown prevention, consistent hashing, and protobuf communication.
tags:
- Go
nav: From Scratch
categories:
- Distributed Cache - GeeCache
keywords:
- Go
- distributed cache from scratch
- build a distributed cache by hand
image: post/geecache/geecache_sm.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days Go from Scratch Series
book_title: Day0 Preface
---

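## 1 About distributed caches
Store the result of an expensive operation the first time it is requested, and return the stored data directly when the same request comes again. That is probably how most people understand caching. Caches are everywhere in computer systems. When we visit a web page, the page and the static JS/CSS files it references are cached, depending on the policy, in the local browser or on a CDN server, so the page loads noticeably faster on the second visit. Or take the like count on Weibo: it is impossible to look up and count all like records in the database for every visit by every user. Database operations are slow and could hardly support that much traffic, so data such as likes is usually cached in a Redis cluster.
> In the business world, cash is king; in the architecture world, cache is king.
The simplest kind of cache is an in-memory key-value cache. Key-value pairs bring the dictionary (dict) type to mind, called map in Go. So why not just create a map and insert new data into it as it arrives? Is that not a key-value cache already? What is wrong with doing so?
1) What if memory runs out?
Then delete a few entries at random. But is random deletion better, or deletion in chronological order? Or is there a better eviction policy? Different data is accessed with different frequency; would it be better to delete rarely accessed data first? Access frequency may change over time, so deleting the least recently used data first may be an even better choice. We need to implement a sensible eviction policy.
2) What about conflicting concurrent writes?
Access to a cache is rarely serial. A map has no concurrency protection, so under concurrency, modifications (inserts, updates, and deletes) need a lock.
3) What if a single machine is not fast enough?
The resources of a single computer, such as compute and storage, are limited. As business volume and traffic grow, a single machine easily becomes a bottleneck. Using the resources of multiple computers and processing in parallel to improve performance requires the cache application to support distribution; this is called horizontal scaling (scale horizontally). Its counterpart is vertical scaling (scale vertically), which improves performance by adding compute, storage, bandwidth, and so on to a single node. Hardware cost and performance do not scale linearly, so in most cases a distributed system is the better choice.
4) ...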
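Point 2 above can be made concrete with a minimal sketch. The type `safeMap` and its methods are hypothetical and are not GeeCache's actual cache type; the sketch only shows the general idea of guarding a plain map with a `sync.Mutex`.

```go
package main

import (
	"fmt"
	"sync"
)

// safeMap is a hypothetical wrapper: a plain map guarded by a mutex.
type safeMap struct {
	mu sync.Mutex
	m  map[string]string
}

// Set stores a value, creating the map lazily on first use.
func (s *safeMap) Set(k, v string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.m == nil {
		s.m = make(map[string]string)
	}
	s.m[k] = v
}

// Get returns the value for k and whether it was present.
func (s *safeMap) Get(k string) (string, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	v, ok := s.m[k]
	return v, ok
}

// Len returns the number of stored entries.
func (s *safeMap) Len() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	return len(s.m)
}

func main() {
	var s safeMap
	var wg sync.WaitGroup
	// 100 goroutines write distinct keys concurrently; without the lock
	// the Go runtime may abort with "concurrent map writes".
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			s.Set(fmt.Sprintf("key-%d", i), "v")
		}(i)
	}
	wg.Wait()
	fmt.Println("entries:", s.Len())
}
```

A lock alone solves neither the memory limit nor the single-machine limit, which is why the other points call for an eviction policy and for distribution.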
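## 2 About GeeCache
Designing a distributed cache system means thinking about resource control, eviction policy, concurrency, communication between distributed nodes, and more. For different scenarios it also means trading off between features. For example, should cache updates be supported, or is a cached value assumed to be immutable until it is evicted? Different trade-offs lead to different implementations.
[groupcache](https://github.com/golang/groupcache) is a Go counterpart of memcached, intended to replace memcached in certain scenarios. The author of groupcache is also the author of memcached. Whether you want to understand single-machine caches or distributed caches, studying this library's implementation in depth is well worth it.
`GeeCache` largely follows the implementation of [groupcache](https://github.com/golang/groupcache). To keep the code at around 500 lines (groupcache is about 3000), some features were trimmed, but overall the implementation stays very close to groupcache. Supported features:
- Single-machine cache and HTTP-based distributed cache
- Least Recently Used (LRU) cache policy
- Go's locking mechanisms to prevent cache breakdown
- Consistent hashing to select nodes and balance load
- protobuf to optimize binary communication between nodes
- ...
`GeeCache` is implemented over 7 days. Each day's part can be run and tested on its own, like building blocks, and the features from all the days combine into the final distributed cache system. Each day's code is around 100 lines.
## 3 Contents
- Day 1: [LRU cache eviction policy](https://geektutu.com/post/geecache-day1.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day1-lru)
- Day 2: [Single-machine concurrent cache](https://geektutu.com/post/geecache-day2.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day2-single-node)
- Day 3: [HTTP server](https://geektutu.com/post/geecache-day3.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day3-http-server)
- Day 4: [Consistent hashing](https://geektutu.com/post/geecache-day4.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day4-consistent-hash)
- Day 5: [Distributed nodes](https://geektutu.com/post/geecache-day5.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day5-multi-nodes)
- Day 6: [Preventing cache breakdown](https://geektutu.com/post/geecache-day6.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day6-single-flight)
- Day 7: [Communicating with Protobuf](https://geektutu.com/post/geecache-day7.html) | [Code - Github](https://github.com/geektutu/7days-golang/blob/master/gee-cache/day7-proto-buf)
## Appendix: Recommended reading
- [A Concise Go Tutorial](https://geektutu.com/post/quick-golang.html)
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)
- [A Concise Go Protobuf Tutorial](https://geektutu.com/post/quick-go-protobuf.html)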
================================================
FILE: gee-orm/day1-database-sql/cmd_test/main.go
================================================
package main
import (
"fmt"
"geeorm"
_ "github.com/mattn/go-sqlite3"
)
func main() {
engine, _ := geeorm.NewEngine("sqlite3", "gee.db")
defer engine.Close()
s := engine.NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec()
count, _ := result.RowsAffected()
fmt.Printf("Exec success, %d affected\n", count)
}
================================================
FILE: gee-orm/day1-database-sql/geeorm.go
================================================
package geeorm
import (
"database/sql"
"geeorm/log"
"geeorm/session"
)
// Engine is the main struct of geeorm, manages all db sessions and transactions.
type Engine struct {
db *sql.DB
}
// NewEngine creates an instance of Engine.
// It connects to the database and pings it to check that it is alive.
func NewEngine(driver, source string) (e *Engine, err error) {
db, err := sql.Open(driver, source)
if err != nil {
log.Error(err)
return
}
// Send a ping to make sure the database connection is alive.
if err = db.Ping(); err != nil {
log.Error(err)
return
}
e = &Engine{db: db}
log.Info("Connect database success")
return
}
// Close database connection
func (engine *Engine) Close() {
if err := engine.db.Close(); err != nil {
log.Error("Failed to close database:", err)
return
}
log.Info("Close database success")
}
// NewSession creates a new session for next operations
func (engine *Engine) NewSession() *session.Session {
return session.New(engine.db)
}
================================================
FILE: gee-orm/day1-database-sql/geeorm_test.go
================================================
package geeorm
import (
_ "github.com/mattn/go-sqlite3"
"testing"
)
func OpenDB(t *testing.T) *Engine {
t.Helper()
engine, err := NewEngine("sqlite3", "gee.db")
if err != nil {
t.Fatal("failed to connect", err)
}
return engine
}
func TestNewEngine(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
}
================================================
FILE: gee-orm/day1-database-sql/go.mod
================================================
module geeorm
go 1.13
require github.com/mattn/go-sqlite3 v2.0.3+incompatible
================================================
FILE: gee-orm/day1-database-sql/log/log.go
================================================
package log
import (
"io/ioutil"
"log"
"os"
"sync"
)
var (
errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile)
infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile)
loggers = []*log.Logger{errorLog, infoLog}
mu sync.Mutex
)
// log methods
var (
Error = errorLog.Println
Errorf = errorLog.Printf
Info = infoLog.Println
Infof = infoLog.Printf
)
// log levels
const (
InfoLevel = iota
ErrorLevel
Disabled
)
// SetLevel controls log level
func SetLevel(level int) {
mu.Lock()
defer mu.Unlock()
for _, logger := range loggers {
logger.SetOutput(os.Stdout)
}
if ErrorLevel < level {
errorLog.SetOutput(ioutil.Discard)
}
if InfoLevel < level {
infoLog.SetOutput(ioutil.Discard)
}
}
================================================
FILE: gee-orm/day1-database-sql/log/log_test.go
================================================
package log
import (
"os"
"testing"
)
func TestSetLevel(t *testing.T) {
SetLevel(ErrorLevel)
if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout {
t.Fatal("failed to set log level")
}
SetLevel(Disabled)
if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout {
t.Fatal("failed to set log level")
}
}
================================================
FILE: gee-orm/day1-database-sql/session/raw.go
================================================
package session
import (
"database/sql"
"geeorm/log"
"strings"
)
// Session keeps a pointer to sql.DB and executes all kinds of
// database operations.
type Session struct {
db *sql.DB
sql strings.Builder
sqlVars []interface{}
}
// New creates an instance of Session
func New(db *sql.DB) *Session {
return &Session{db: db}
}
// Clear resets the state of a session
func (s *Session) Clear() {
s.sql.Reset()
s.sqlVars = nil
}
// DB returns *sql.DB
func (s *Session) DB() *sql.DB {
return s.db
}
// Exec raw sql with sqlVars
func (s *Session) Exec() (result sql.Result, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// QueryRow gets a record from db
func (s *Session) QueryRow() *sql.Row {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
return s.DB().QueryRow(s.sql.String(), s.sqlVars...)
}
// QueryRows gets a list of records from db
func (s *Session) QueryRows() (rows *sql.Rows, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// Raw appends sql and sqlVars
func (s *Session) Raw(sql string, values ...interface{}) *Session {
s.sql.WriteString(sql)
s.sql.WriteString(" ")
s.sqlVars = append(s.sqlVars, values...)
return s
}
================================================
FILE: gee-orm/day1-database-sql/session/raw_test.go
================================================
package session
import (
"database/sql"
"os"
"testing"
_ "github.com/mattn/go-sqlite3"
)
var TestDB *sql.DB
func TestMain(m *testing.M) {
TestDB, _ = sql.Open("sqlite3", "../gee.db")
code := m.Run()
_ = TestDB.Close()
os.Exit(code)
}
func NewSession() *Session {
return New(TestDB)
}
func TestSession_Exec(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec()
if count, err := result.RowsAffected(); err != nil || count != 2 {
t.Fatal("expect 2, but got", count)
}
}
func TestSession_QueryRows(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
row := s.Raw("SELECT count(*) FROM User").QueryRow()
var count int
if err := row.Scan(&count); err != nil || count != 0 {
t.Fatal("failed to query db", err)
}
}
================================================
FILE: gee-orm/day2-reflect-schema/dialect/dialect.go
================================================
package dialect
import "reflect"
var dialectsMap = map[string]Dialect{}
// Dialect is an interface that contains the methods a dialect has to implement
type Dialect interface {
DataTypeOf(typ reflect.Value) string
TableExistSQL(tableName string) (string, []interface{})
}
// RegisterDialect registers a dialect in the global map
func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
}
// GetDialect returns the dialect from the global map if it exists
func GetDialect(name string) (dialect Dialect, ok bool) {
dialect, ok = dialectsMap[name]
return
}
================================================
FILE: gee-orm/day2-reflect-schema/dialect/sqlite3.go
================================================
package dialect
import (
"fmt"
"reflect"
"time"
)
type sqlite3 struct{}
var _ Dialect = (*sqlite3)(nil)
func init() {
RegisterDialect("sqlite3", &sqlite3{})
}
// DataTypeOf returns the sqlite3 column type for the given Go value
func (s *sqlite3) DataTypeOf(typ reflect.Value) string {
switch typ.Kind() {
case reflect.Bool:
return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer"
case reflect.Int64, reflect.Uint64:
return "bigint"
case reflect.Float32, reflect.Float64:
return "real"
case reflect.String:
return "text"
case reflect.Array, reflect.Slice:
return "blob"
case reflect.Struct:
if _, ok := typ.Interface().(time.Time); ok {
return "datetime"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind()))
}
// TableExistSQL returns the SQL that checks whether the table exists in the database
func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
}
================================================
FILE: gee-orm/day2-reflect-schema/dialect/sqlite3_test.go
================================================
package dialect
import (
"reflect"
"testing"
)
func TestDataTypeOf(t *testing.T) {
dial := &sqlite3{}
cases := []struct {
Value interface{}
Type string
}{
{"Tom", "text"},
{123, "integer"},
{1.2, "real"},
{[]int{1, 2, 3}, "blob"},
}
for _, c := range cases {
if typ := dial.DataTypeOf(reflect.ValueOf(c.Value)); typ != c.Type {
t.Fatalf("expect %s, but got %s", c.Type, typ)
}
}
}
================================================
FILE: gee-orm/day2-reflect-schema/geeorm.go
================================================
package geeorm
import (
"database/sql"
"fmt"
"geeorm/dialect"
"geeorm/log"
"geeorm/session"
)
// Engine is the main struct of geeorm, manages all db sessions and transactions.
type Engine struct {
db *sql.DB
dialect dialect.Dialect
}
// NewEngine creates an instance of Engine.
// It connects to the database and pings it to check that it is alive.
func NewEngine(driver, source string) (e *Engine, err error) {
db, err := sql.Open(driver, source)
if err != nil {
log.Error(err)
return
}
// Send a ping to make sure the database connection is alive.
if err = db.Ping(); err != nil {
log.Error(err)
return
}
// Make sure the requested dialect is registered; report it as an error
// so callers never receive a nil Engine together with a nil error.
dial, ok := dialect.GetDialect(driver)
if !ok {
err = fmt.Errorf("dialect %s not found", driver)
log.Error(err)
return
}
e = &Engine{db: db, dialect: dial}
log.Info("Connect database success")
return
}
// Close database connection
func (engine *Engine) Close() {
if err := engine.db.Close(); err != nil {
log.Error("Failed to close database:", err)
return
}
log.Info("Close database success")
}
// NewSession creates a new session for next operations
func (engine *Engine) NewSession() *session.Session {
return session.New(engine.db, engine.dialect)
}
================================================
FILE: gee-orm/day2-reflect-schema/geeorm_test.go
================================================
package geeorm
import (
_ "github.com/mattn/go-sqlite3"
"testing"
)
func OpenDB(t *testing.T) *Engine {
t.Helper()
engine, err := NewEngine("sqlite3", "gee.db")
if err != nil {
t.Fatal("failed to connect", err)
}
return engine
}
func TestNewEngine(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
}
================================================
FILE: gee-orm/day2-reflect-schema/go.mod
================================================
module geeorm
go 1.13
require github.com/mattn/go-sqlite3 v2.0.3+incompatible
================================================
FILE: gee-orm/day2-reflect-schema/log/log.go
================================================
package log
import (
"io/ioutil"
"log"
"os"
"sync"
)
var (
errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile)
infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile)
loggers = []*log.Logger{errorLog, infoLog}
mu sync.Mutex
)
// log methods
var (
Error = errorLog.Println
Errorf = errorLog.Printf
Info = infoLog.Println
Infof = infoLog.Printf
)
// log levels
const (
InfoLevel = iota
ErrorLevel
Disabled
)
// SetLevel controls log level
func SetLevel(level int) {
mu.Lock()
defer mu.Unlock()
for _, logger := range loggers {
logger.SetOutput(os.Stdout)
}
if ErrorLevel < level {
errorLog.SetOutput(ioutil.Discard)
}
if InfoLevel < level {
infoLog.SetOutput(ioutil.Discard)
}
}
================================================
FILE: gee-orm/day2-reflect-schema/log/log_test.go
================================================
package log
import (
"os"
"testing"
)
func TestSetLevel(t *testing.T) {
SetLevel(ErrorLevel)
if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout {
t.Fatal("failed to set log level")
}
SetLevel(Disabled)
if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout {
t.Fatal("failed to set log level")
}
}
================================================
FILE: gee-orm/day2-reflect-schema/schema/schema.go
================================================
package schema
import (
"geeorm/dialect"
"go/ast"
"reflect"
)
// Field represents a column of database
type Field struct {
Name string
Type string
Tag string
}
// Schema represents a table of database
type Schema struct {
Model interface{}
Name string
Fields []*Field
FieldNames []string
fieldMap map[string]*Field
}
// GetField returns field by name
func (schema *Schema) GetField(name string) *Field {
return schema.fieldMap[name]
}
// RecordValues returns the values of dest's member variables
func (schema *Schema) RecordValues(dest interface{}) []interface{} {
destValue := reflect.Indirect(reflect.ValueOf(dest))
var fieldValues []interface{}
for _, field := range schema.Fields {
fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface())
}
return fieldValues
}
// ITableName lets a model override its default table name
type ITableName interface {
TableName() string
}
// Parse parses a struct into a Schema instance
func Parse(dest interface{}, d dialect.Dialect) *Schema {
modelType := reflect.Indirect(reflect.ValueOf(dest)).Type()
var tableName string
t, ok := dest.(ITableName)
if !ok {
tableName = modelType.Name()
} else {
tableName = t.TableName()
}
schema := &Schema{
Model: dest,
Name: tableName,
fieldMap: make(map[string]*Field),
}
for i := 0; i < modelType.NumField(); i++ {
p := modelType.Field(i)
if !p.Anonymous && ast.IsExported(p.Name) {
field := &Field{
Name: p.Name,
Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))),
}
if v, ok := p.Tag.Lookup("geeorm"); ok {
field.Tag = v
}
schema.Fields = append(schema.Fields, field)
schema.FieldNames = append(schema.FieldNames, p.Name)
schema.fieldMap[p.Name] = field
}
}
return schema
}
================================================
FILE: gee-orm/day2-reflect-schema/schema/schema_test.go
================================================
package schema
import (
"geeorm/dialect"
"testing"
)
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
var TestDial, _ = dialect.GetDialect("sqlite3")
func TestParse(t *testing.T) {
schema := Parse(&User{}, TestDial)
if schema.Name != "User" || len(schema.Fields) != 2 {
t.Fatal("failed to parse User struct")
}
if schema.GetField("Name").Tag != "PRIMARY KEY" {
t.Fatal("failed to parse primary key")
}
}
func TestSchema_RecordValues(t *testing.T) {
schema := Parse(&User{}, TestDial)
values := schema.RecordValues(&User{"Tom", 18})
name := values[0].(string)
age := values[1].(int)
if name != "Tom" || age != 18 {
t.Fatal("failed to get values")
}
}
type UserTest struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func (u *UserTest) TableName() string {
return "ns_user_test"
}
func TestSchema_TableName(t *testing.T) {
schema := Parse(&UserTest{}, TestDial)
if schema.Name != "ns_user_test" || len(schema.Fields) != 2 {
t.Fatal("failed to parse UserTest struct")
}
}
================================================
FILE: gee-orm/day2-reflect-schema/session/raw.go
================================================
package session
import (
"database/sql"
"geeorm/dialect"
"geeorm/log"
"geeorm/schema"
"strings"
)
// Session keeps a pointer to sql.DB and executes all kinds of
// database operations.
type Session struct {
db *sql.DB
dialect dialect.Dialect
refTable *schema.Schema
sql strings.Builder
sqlVars []interface{}
}
// New creates an instance of Session
func New(db *sql.DB, dialect dialect.Dialect) *Session {
return &Session{
db: db,
dialect: dialect,
}
}
// Clear resets the state of a session
func (s *Session) Clear() {
s.sql.Reset()
s.sqlVars = nil
}
// DB returns *sql.DB
func (s *Session) DB() *sql.DB {
return s.db
}
// Exec raw sql with sqlVars
func (s *Session) Exec() (result sql.Result, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// QueryRow gets a record from db
func (s *Session) QueryRow() *sql.Row {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
return s.DB().QueryRow(s.sql.String(), s.sqlVars...)
}
// QueryRows gets a list of records from db
func (s *Session) QueryRows() (rows *sql.Rows, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// Raw appends sql and sqlVars
func (s *Session) Raw(sql string, values ...interface{}) *Session {
s.sql.WriteString(sql)
s.sql.WriteString(" ")
s.sqlVars = append(s.sqlVars, values...)
return s
}
================================================
FILE: gee-orm/day2-reflect-schema/session/raw_test.go
================================================
package session
import (
"database/sql"
"os"
"testing"
"geeorm/dialect"
_ "github.com/mattn/go-sqlite3"
)
var (
TestDB *sql.DB
TestDial, _ = dialect.GetDialect("sqlite3")
)
func TestMain(m *testing.M) {
TestDB, _ = sql.Open("sqlite3", "../gee.db")
code := m.Run()
_ = TestDB.Close()
os.Exit(code)
}
func NewSession() *Session {
return New(TestDB, TestDial)
}
func TestSession_Exec(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec()
if count, err := result.RowsAffected(); err != nil || count != 2 {
t.Fatal("expect 2, but got", count)
}
}
func TestSession_QueryRows(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
row := s.Raw("SELECT count(*) FROM User").QueryRow()
var count int
if err := row.Scan(&count); err != nil || count != 0 {
t.Fatal("failed to query db", err)
}
}
================================================
FILE: gee-orm/day2-reflect-schema/session/table.go
================================================
package session
import (
"fmt"
"geeorm/log"
"reflect"
"strings"
"geeorm/schema"
)
// Model assigns refTable
func (s *Session) Model(value interface{}) *Session {
// nil or different model, update refTable
if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) {
s.refTable = schema.Parse(value, s.dialect)
}
return s
}
// RefTable returns a Schema instance that contains all parsed fields
func (s *Session) RefTable() *schema.Schema {
if s.refTable == nil {
log.Error("Model is not set")
}
return s.refTable
}
// CreateTable creates a table in the database from the current model
func (s *Session) CreateTable() error {
table := s.RefTable()
var columns []string
for _, field := range table.Fields {
columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag))
}
desc := strings.Join(columns, ",")
_, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec()
return err
}
// DropTable drops a table with the name of model
func (s *Session) DropTable() error {
_, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec()
return err
}
// HasTable returns true if the table exists
func (s *Session) HasTable() bool {
sql, values := s.dialect.TableExistSQL(s.RefTable().Name)
row := s.Raw(sql, values...).QueryRow()
var tmp string
_ = row.Scan(&tmp)
return tmp == s.RefTable().Name
}
================================================
FILE: gee-orm/day2-reflect-schema/session/table_test.go
================================================
package session
import (
"testing"
)
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func TestSession_CreateTable(t *testing.T) {
s := NewSession().Model(&User{})
_ = s.DropTable()
_ = s.CreateTable()
if !s.HasTable() {
t.Fatal("Failed to create table User")
}
}
func TestSession_Model(t *testing.T) {
s := NewSession().Model(&User{})
table := s.RefTable()
s.Model(&Session{})
if table.Name != "User" || s.RefTable().Name != "Session" {
t.Fatal("Failed to change model")
}
}
================================================
FILE: gee-orm/day3-save-query/clause/clause.go
================================================
package clause
import (
"strings"
)
// Clause contains SQL conditions
type Clause struct {
sql map[Type]string
sqlVars map[Type][]interface{}
}
// Type is the type of Clause
type Type int
// Supported types for Clause
const (
INSERT Type = iota
VALUES
SELECT
LIMIT
WHERE
ORDERBY
)
// Set adds a sub clause of specific type
func (c *Clause) Set(name Type, vars ...interface{}) {
if c.sql == nil {
c.sql = make(map[Type]string)
c.sqlVars = make(map[Type][]interface{})
}
sql, vars := generators[name](vars...)
c.sql[name] = sql
c.sqlVars[name] = vars
}
// Build generates the final SQL and SQLVars from the sub clauses, in the given order
func (c *Clause) Build(orders ...Type) (string, []interface{}) {
var sqls []string
var vars []interface{}
for _, order := range orders {
if sql, ok := c.sql[order]; ok {
sqls = append(sqls, sql)
vars = append(vars, c.sqlVars[order]...)
}
}
return strings.Join(sqls, " "), vars
}
================================================
FILE: gee-orm/day3-save-query/clause/clause_test.go
================================================
package clause
import (
"reflect"
"testing"
)
func TestClause_Set(t *testing.T) {
var clause Clause
clause.Set(INSERT, "User", []string{"Name", "Age"})
sql := clause.sql[INSERT]
vars := clause.sqlVars[INSERT]
t.Log(sql, vars)
if sql != "INSERT INTO User (Name,Age)" || len(vars) != 0 {
t.Fatal("failed to get clause")
}
}
func testSelect(t *testing.T) {
var clause Clause
clause.Set(LIMIT, 3)
clause.Set(SELECT, "User", []string{"*"})
clause.Set(WHERE, "Name = ?", "Tom")
clause.Set(ORDERBY, "Age ASC")
sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT)
t.Log(sql, vars)
if sql != "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{"Tom", 3}) {
t.Fatal("failed to build SQLVars")
}
}
func TestClause_Build(t *testing.T) {
t.Run("select", func(t *testing.T) {
testSelect(t)
})
}
================================================
FILE: gee-orm/day3-save-query/clause/generator.go
================================================
package clause
import (
"fmt"
"strings"
)
type generator func(values ...interface{}) (string, []interface{})
var generators map[Type]generator
func init() {
generators = make(map[Type]generator)
generators[INSERT] = _insert
generators[VALUES] = _values
generators[SELECT] = _select
generators[LIMIT] = _limit
generators[WHERE] = _where
generators[ORDERBY] = _orderBy
}
func genBindVars(num int) string {
var vars []string
for i := 0; i < num; i++ {
vars = append(vars, "?")
}
return strings.Join(vars, ", ")
}
func _insert(values ...interface{}) (string, []interface{}) {
// INSERT INTO $tableName ($fields)
tableName := values[0]
fields := strings.Join(values[1].([]string), ",")
return fmt.Sprintf("INSERT INTO %s (%v)", tableName, fields), []interface{}{}
}
func _values(values ...interface{}) (string, []interface{}) {
// VALUES ($v1), ($v2), ...
var bindStr string
var sql strings.Builder
var vars []interface{}
sql.WriteString("VALUES ")
for i, value := range values {
v := value.([]interface{})
if bindStr == "" {
bindStr = genBindVars(len(v))
}
sql.WriteString(fmt.Sprintf("(%v)", bindStr))
if i+1 != len(values) {
sql.WriteString(", ")
}
vars = append(vars, v...)
}
return sql.String(), vars
}
func _select(values ...interface{}) (string, []interface{}) {
// SELECT $fields FROM $tableName
tableName := values[0]
fields := strings.Join(values[1].([]string), ",")
return fmt.Sprintf("SELECT %v FROM %s", fields, tableName), []interface{}{}
}
func _limit(values ...interface{}) (string, []interface{}) {
// LIMIT $num
return "LIMIT ?", values
}
func _where(values ...interface{}) (string, []interface{}) {
// WHERE $desc
desc, vars := values[0], values[1:]
return fmt.Sprintf("WHERE %s", desc), vars
}
func _orderBy(values ...interface{}) (string, []interface{}) {
return fmt.Sprintf("ORDER BY %s", values[0]), []interface{}{}
}
================================================
FILE: gee-orm/day3-save-query/dialect/dialect.go
================================================
package dialect
import "reflect"
var dialectsMap = map[string]Dialect{}
// Dialect is an interface that contains the methods a dialect has to implement
type Dialect interface {
DataTypeOf(typ reflect.Value) string
TableExistSQL(tableName string) (string, []interface{})
}
// RegisterDialect registers a dialect in the global map
func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
}
// GetDialect returns the dialect from the global map if it exists
func GetDialect(name string) (dialect Dialect, ok bool) {
dialect, ok = dialectsMap[name]
return
}
================================================
FILE: gee-orm/day3-save-query/dialect/sqlite3.go
================================================
package dialect
import (
"fmt"
"reflect"
"time"
)
type sqlite3 struct{}
var _ Dialect = (*sqlite3)(nil)
func init() {
RegisterDialect("sqlite3", &sqlite3{})
}
// DataTypeOf returns the sqlite3 column type for the given Go value
func (s *sqlite3) DataTypeOf(typ reflect.Value) string {
switch typ.Kind() {
case reflect.Bool:
return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer"
case reflect.Int64, reflect.Uint64:
return "bigint"
case reflect.Float32, reflect.Float64:
return "real"
case reflect.String:
return "text"
case reflect.Array, reflect.Slice:
return "blob"
case reflect.Struct:
if _, ok := typ.Interface().(time.Time); ok {
return "datetime"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind()))
}
// TableExistSQL returns the SQL that checks whether the table exists in the database
func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
}
================================================
FILE: gee-orm/day3-save-query/dialect/sqlite3_test.go
================================================
package dialect
import (
"reflect"
"testing"
)
func TestDataTypeOf(t *testing.T) {
dial := &sqlite3{}
cases := []struct {
Value interface{}
Type string
}{
{"Tom", "text"},
{123, "integer"},
{1.2, "real"},
{[]int{1, 2, 3}, "blob"},
}
for _, c := range cases {
if typ := dial.DataTypeOf(reflect.ValueOf(c.Value)); typ != c.Type {
t.Fatalf("expect %s, but got %s", c.Type, typ)
}
}
}
================================================
FILE: gee-orm/day3-save-query/geeorm.go
================================================
package geeorm
import (
"database/sql"
"fmt"
"geeorm/dialect"
"geeorm/log"
"geeorm/session"
)
// Engine is the main struct of geeorm, manages all db sessions and transactions.
type Engine struct {
db *sql.DB
dialect dialect.Dialect
}
// NewEngine creates an instance of Engine.
// It connects to the database and pings it to check that it is alive.
func NewEngine(driver, source string) (e *Engine, err error) {
db, err := sql.Open(driver, source)
if err != nil {
log.Error(err)
return
}
// Send a ping to make sure the database connection is alive.
if err = db.Ping(); err != nil {
log.Error(err)
return
}
// Make sure the requested dialect is registered; report it as an error
// so callers never receive a nil Engine together with a nil error.
dial, ok := dialect.GetDialect(driver)
if !ok {
err = fmt.Errorf("dialect %s not found", driver)
log.Error(err)
return
}
e = &Engine{db: db, dialect: dial}
log.Info("Connect database success")
return
}
// Close database connection
func (engine *Engine) Close() {
if err := engine.db.Close(); err != nil {
log.Error("Failed to close database:", err)
return
}
log.Info("Close database success")
}
// NewSession creates a new session for next operations
func (engine *Engine) NewSession() *session.Session {
return session.New(engine.db, engine.dialect)
}
================================================
FILE: gee-orm/day3-save-query/geeorm_test.go
================================================
package geeorm
import (
_ "github.com/mattn/go-sqlite3"
"testing"
)
func OpenDB(t *testing.T) *Engine {
t.Helper()
engine, err := NewEngine("sqlite3", "gee.db")
if err != nil {
t.Fatal("failed to connect", err)
}
return engine
}
func TestNewEngine(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
}
================================================
FILE: gee-orm/day3-save-query/go.mod
================================================
module geeorm
go 1.13
require github.com/mattn/go-sqlite3 v2.0.3+incompatible
================================================
FILE: gee-orm/day3-save-query/log/log.go
================================================
package log
import (
"io/ioutil"
"log"
"os"
"sync"
)
var (
errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile)
infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile)
loggers = []*log.Logger{errorLog, infoLog}
mu sync.Mutex
)
// log methods
var (
Error = errorLog.Println
Errorf = errorLog.Printf
Info = infoLog.Println
Infof = infoLog.Printf
)
// log levels
const (
InfoLevel = iota
ErrorLevel
Disabled
)
// SetLevel controls log level
func SetLevel(level int) {
mu.Lock()
defer mu.Unlock()
for _, logger := range loggers {
logger.SetOutput(os.Stdout)
}
if ErrorLevel < level {
errorLog.SetOutput(ioutil.Discard)
}
if InfoLevel < level {
infoLog.SetOutput(ioutil.Discard)
}
}
================================================
FILE: gee-orm/day3-save-query/log/log_test.go
================================================
package log
import (
"os"
"testing"
)
func TestSetLevel(t *testing.T) {
SetLevel(ErrorLevel)
if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout {
t.Fatal("failed to set log level")
}
SetLevel(Disabled)
if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout {
t.Fatal("failed to set log level")
}
}
================================================
FILE: gee-orm/day3-save-query/schema/schema.go
================================================
package schema
import (
"geeorm/dialect"
"go/ast"
"reflect"
)
// Field represents a column of database
type Field struct {
Name string
Type string
Tag string
}
// Schema represents a table of database
type Schema struct {
Model interface{}
Name string
Fields []*Field
FieldNames []string
fieldMap map[string]*Field
}
// GetField returns field by name
func (schema *Schema) GetField(name string) *Field {
return schema.fieldMap[name]
}
// RecordValues returns the values of dest's member variables
func (schema *Schema) RecordValues(dest interface{}) []interface{} {
destValue := reflect.Indirect(reflect.ValueOf(dest))
var fieldValues []interface{}
for _, field := range schema.Fields {
fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface())
}
return fieldValues
}
// ITableName lets a model override its default table name
type ITableName interface {
TableName() string
}
// Parse parses a struct into a Schema instance
func Parse(dest interface{}, d dialect.Dialect) *Schema {
modelType := reflect.Indirect(reflect.ValueOf(dest)).Type()
var tableName string
t, ok := dest.(ITableName)
if !ok {
tableName = modelType.Name()
} else {
tableName = t.TableName()
}
schema := &Schema{
Model: dest,
Name: tableName,
fieldMap: make(map[string]*Field),
}
for i := 0; i < modelType.NumField(); i++ {
p := modelType.Field(i)
if !p.Anonymous && ast.IsExported(p.Name) {
field := &Field{
Name: p.Name,
Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))),
}
if v, ok := p.Tag.Lookup("geeorm"); ok {
field.Tag = v
}
schema.Fields = append(schema.Fields, field)
schema.FieldNames = append(schema.FieldNames, p.Name)
schema.fieldMap[p.Name] = field
}
}
return schema
}
================================================
FILE: gee-orm/day3-save-query/schema/schema_test.go
================================================
package schema
import (
"geeorm/dialect"
"testing"
)
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
var TestDial, _ = dialect.GetDialect("sqlite3")
func TestParse(t *testing.T) {
schema := Parse(&User{}, TestDial)
if schema.Name != "User" || len(schema.Fields) != 2 {
t.Fatal("failed to parse User struct")
}
if schema.GetField("Name").Tag != "PRIMARY KEY" {
t.Fatal("failed to parse primary key")
}
}
func TestSchema_RecordValues(t *testing.T) {
schema := Parse(&User{}, TestDial)
values := schema.RecordValues(&User{"Tom", 18})
name := values[0].(string)
age := values[1].(int)
if name != "Tom" || age != 18 {
t.Fatal("failed to get values")
}
}
type UserTest struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func (u *UserTest) TableName() string {
return "ns_user_test"
}
func TestSchema_TableName(t *testing.T) {
schema := Parse(&UserTest{}, TestDial)
if schema.Name != "ns_user_test" || len(schema.Fields) != 2 {
t.Fatal("failed to parse UserTest table name")
}
}
================================================
FILE: gee-orm/day3-save-query/session/raw.go
================================================
package session
import (
"database/sql"
"geeorm/clause"
"geeorm/dialect"
"geeorm/log"
"geeorm/schema"
"strings"
)
// Session keeps a pointer to sql.DB and executes all kinds of
// database operations.
type Session struct {
db *sql.DB
dialect dialect.Dialect
refTable *schema.Schema
clause clause.Clause
sql strings.Builder
sqlVars []interface{}
}
// New creates an instance of Session
func New(db *sql.DB, dialect dialect.Dialect) *Session {
return &Session{
db: db,
dialect: dialect,
}
}
// Clear resets the state of a session
func (s *Session) Clear() {
s.sql.Reset()
s.sqlVars = nil
s.clause = clause.Clause{}
}
// DB returns *sql.DB
func (s *Session) DB() *sql.DB {
return s.db
}
// Exec executes the raw SQL with sqlVars
func (s *Session) Exec() (result sql.Result, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// QueryRow gets a record from db
func (s *Session) QueryRow() *sql.Row {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
return s.DB().QueryRow(s.sql.String(), s.sqlVars...)
}
// QueryRows gets a list of records from db
func (s *Session) QueryRows() (rows *sql.Rows, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// Raw appends sql and sqlVars
func (s *Session) Raw(sql string, values ...interface{}) *Session {
s.sql.WriteString(sql)
s.sql.WriteString(" ")
s.sqlVars = append(s.sqlVars, values...)
return s
}
================================================
FILE: gee-orm/day3-save-query/session/raw_test.go
================================================
package session
import (
"database/sql"
"os"
"testing"
"geeorm/dialect"
_ "github.com/mattn/go-sqlite3"
)
var (
TestDB *sql.DB
TestDial, _ = dialect.GetDialect("sqlite3")
)
func TestMain(m *testing.M) {
TestDB, _ = sql.Open("sqlite3", "../gee.db")
code := m.Run()
_ = TestDB.Close()
os.Exit(code)
}
func NewSession() *Session {
return New(TestDB, TestDial)
}
func TestSession_Exec(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec()
if count, err := result.RowsAffected(); err != nil || count != 2 {
t.Fatal("expect 2, but got", count)
}
}
func TestSession_QueryRows(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
row := s.Raw("SELECT count(*) FROM User").QueryRow()
var count int
if err := row.Scan(&count); err != nil || count != 0 {
t.Fatal("failed to query db", err)
}
}
================================================
FILE: gee-orm/day3-save-query/session/record.go
================================================
package session
import (
"geeorm/clause"
"reflect"
)
// Insert one or more records in database
func (s *Session) Insert(values ...interface{}) (int64, error) {
recordValues := make([]interface{}, 0)
for _, value := range values {
table := s.Model(value).RefTable()
s.clause.Set(clause.INSERT, table.Name, table.FieldNames)
recordValues = append(recordValues, table.RecordValues(value))
}
s.clause.Set(clause.VALUES, recordValues...)
sql, vars := s.clause.Build(clause.INSERT, clause.VALUES)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
return result.RowsAffected()
}
// Find gets all eligible records
func (s *Session) Find(values interface{}) error {
destSlice := reflect.Indirect(reflect.ValueOf(values))
destType := destSlice.Type().Elem()
table := s.Model(reflect.New(destType).Elem().Interface()).RefTable()
s.clause.Set(clause.SELECT, table.Name, table.FieldNames)
sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT)
rows, err := s.Raw(sql, vars...).QueryRows()
if err != nil {
return err
}
for rows.Next() {
dest := reflect.New(destType).Elem()
var values []interface{}
for _, name := range table.FieldNames {
values = append(values, dest.FieldByName(name).Addr().Interface())
}
if err := rows.Scan(values...); err != nil {
_ = rows.Close() // release the underlying connection before returning
return err
}
destSlice.Set(reflect.Append(destSlice, dest))
}
return rows.Close()
}
================================================
FILE: gee-orm/day3-save-query/session/record_test.go
================================================
package session
import "testing"
var (
user1 = &User{"Tom", 18}
user2 = &User{"Sam", 25}
user3 = &User{"Jack", 25}
)
func testRecordInit(t *testing.T) *Session {
t.Helper()
s := NewSession().Model(&User{})
err1 := s.DropTable()
err2 := s.CreateTable()
_, err3 := s.Insert(user1, user2)
if err1 != nil || err2 != nil || err3 != nil {
t.Fatal("failed init test records")
}
return s
}
func TestSession_Insert(t *testing.T) {
s := testRecordInit(t)
affected, err := s.Insert(user3)
if err != nil || affected != 1 {
t.Fatal("failed to create record")
}
}
func TestSession_Find(t *testing.T) {
s := testRecordInit(t)
var users []User
if err := s.Find(&users); err != nil || len(users) != 2 {
t.Fatal("failed to query all")
}
}
================================================
FILE: gee-orm/day3-save-query/session/table.go
================================================
package session
import (
"fmt"
"geeorm/log"
"reflect"
"strings"
"geeorm/schema"
)
// Model assigns refTable
func (s *Session) Model(value interface{}) *Session {
// nil or different model, update refTable
if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) {
s.refTable = schema.Parse(value, s.dialect)
}
return s
}
// RefTable returns a Schema instance that contains all parsed fields
func (s *Session) RefTable() *schema.Schema {
if s.refTable == nil {
log.Error("Model is not set")
}
return s.refTable
}
// CreateTable creates a table in the database from a model
func (s *Session) CreateTable() error {
table := s.RefTable()
var columns []string
for _, field := range table.Fields {
columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag))
}
desc := strings.Join(columns, ",")
_, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec()
return err
}
// DropTable drops a table with the name of model
func (s *Session) DropTable() error {
_, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec()
return err
}
// HasTable returns true if the table exists
func (s *Session) HasTable() bool {
sql, values := s.dialect.TableExistSQL(s.RefTable().Name)
row := s.Raw(sql, values...).QueryRow()
var tmp string
_ = row.Scan(&tmp)
return tmp == s.RefTable().Name
}
================================================
FILE: gee-orm/day3-save-query/session/table_test.go
================================================
package session
import (
"testing"
)
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func TestSession_CreateTable(t *testing.T) {
s := NewSession().Model(&User{})
_ = s.DropTable()
_ = s.CreateTable()
if !s.HasTable() {
t.Fatal("Failed to create table User")
}
}
func TestSession_Model(t *testing.T) {
s := NewSession().Model(&User{})
table := s.RefTable()
s.Model(&Session{})
if table.Name != "User" || s.RefTable().Name != "Session" {
t.Fatal("Failed to change model")
}
}
================================================
FILE: gee-orm/day4-chain-operation/clause/clause.go
================================================
package clause
import (
"strings"
)
// Clause contains SQL conditions
type Clause struct {
sql map[Type]string
sqlVars map[Type][]interface{}
}
// Type is the type of Clause
type Type int
// Supported clause types
const (
INSERT Type = iota
VALUES
SELECT
LIMIT
WHERE
ORDERBY
UPDATE
DELETE
COUNT
)
// Set adds a sub clause of specific type
func (c *Clause) Set(name Type, vars ...interface{}) {
if c.sql == nil {
c.sql = make(map[Type]string)
c.sqlVars = make(map[Type][]interface{})
}
sql, vars := generators[name](vars...)
c.sql[name] = sql
c.sqlVars[name] = vars
}
// Build generates the final SQL and SQLVars in the given clause order
func (c *Clause) Build(orders ...Type) (string, []interface{}) {
var sqls []string
var vars []interface{}
for _, order := range orders {
if sql, ok := c.sql[order]; ok {
sqls = append(sqls, sql)
vars = append(vars, c.sqlVars[order]...)
}
}
return strings.Join(sqls, " "), vars
}
================================================
FILE: gee-orm/day4-chain-operation/clause/clause_test.go
================================================
package clause
import (
"reflect"
"testing"
)
func TestClause_Set(t *testing.T) {
var clause Clause
clause.Set(INSERT, "User", []string{"Name", "Age"})
sql := clause.sql[INSERT]
vars := clause.sqlVars[INSERT]
t.Log(sql, vars)
if sql != "INSERT INTO User (Name,Age)" || len(vars) != 0 {
t.Fatal("failed to get clause")
}
}
func testSelect(t *testing.T) {
var clause Clause
clause.Set(LIMIT, 3)
clause.Set(SELECT, "User", []string{"*"})
clause.Set(WHERE, "Name = ?", "Tom")
clause.Set(ORDERBY, "Age ASC")
sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT)
t.Log(sql, vars)
if sql != "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{"Tom", 3}) {
t.Fatal("failed to build SQLVars")
}
}
func testUpdate(t *testing.T) {
var clause Clause
clause.Set(UPDATE, "User", map[string]interface{}{"Age": 30})
clause.Set(WHERE, "Name = ?", "Tom")
sql, vars := clause.Build(UPDATE, WHERE)
t.Log(sql, vars)
if sql != "UPDATE User SET Age = ? WHERE Name = ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{30, "Tom"}) {
t.Fatal("failed to build SQLVars")
}
}
func testDelete(t *testing.T) {
var clause Clause
clause.Set(DELETE, "User")
clause.Set(WHERE, "Name = ?", "Tom")
sql, vars := clause.Build(DELETE, WHERE)
t.Log(sql, vars)
if sql != "DELETE FROM User WHERE Name = ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{"Tom"}) {
t.Fatal("failed to build SQLVars")
}
}
func TestClause_Build(t *testing.T) {
t.Run("select", func(t *testing.T) {
testSelect(t)
})
t.Run("update", func(t *testing.T) {
testUpdate(t)
})
t.Run("delete", func(t *testing.T) {
testDelete(t)
})
}
================================================
FILE: gee-orm/day4-chain-operation/clause/generator.go
================================================
package clause
import (
"fmt"
"strings"
)
type generator func(values ...interface{}) (string, []interface{})
var generators map[Type]generator
func init() {
generators = make(map[Type]generator)
generators[INSERT] = _insert
generators[VALUES] = _values
generators[SELECT] = _select
generators[LIMIT] = _limit
generators[WHERE] = _where
generators[ORDERBY] = _orderBy
generators[UPDATE] = _update
generators[DELETE] = _delete
generators[COUNT] = _count
}
func genBindVars(num int) string {
var vars []string
for i := 0; i < num; i++ {
vars = append(vars, "?")
}
return strings.Join(vars, ", ")
}
func _insert(values ...interface{}) (string, []interface{}) {
// INSERT INTO $tableName ($fields)
tableName := values[0]
fields := strings.Join(values[1].([]string), ",")
return fmt.Sprintf("INSERT INTO %s (%v)", tableName, fields), []interface{}{}
}
func _values(values ...interface{}) (string, []interface{}) {
// VALUES ($v1), ($v2), ...
var bindStr string
var sql strings.Builder
var vars []interface{}
sql.WriteString("VALUES ")
for i, value := range values {
v := value.([]interface{})
if bindStr == "" {
bindStr = genBindVars(len(v))
}
sql.WriteString(fmt.Sprintf("(%v)", bindStr))
if i+1 != len(values) {
sql.WriteString(", ")
}
vars = append(vars, v...)
}
return sql.String(), vars
}
func _select(values ...interface{}) (string, []interface{}) {
// SELECT $fields FROM $tableName
tableName := values[0]
fields := strings.Join(values[1].([]string), ",")
return fmt.Sprintf("SELECT %v FROM %s", fields, tableName), []interface{}{}
}
func _limit(values ...interface{}) (string, []interface{}) {
// LIMIT $num
return "LIMIT ?", values
}
func _where(values ...interface{}) (string, []interface{}) {
// WHERE $desc
desc, vars := values[0], values[1:]
return fmt.Sprintf("WHERE %s", desc), vars
}
func _orderBy(values ...interface{}) (string, []interface{}) {
return fmt.Sprintf("ORDER BY %s", values[0]), []interface{}{}
}
func _update(values ...interface{}) (string, []interface{}) {
tableName := values[0]
m := values[1].(map[string]interface{})
var keys []string
var vars []interface{}
for k, v := range m {
keys = append(keys, k+" = ?")
vars = append(vars, v)
}
return fmt.Sprintf("UPDATE %s SET %s", tableName, strings.Join(keys, ", ")), vars
}
func _delete(values ...interface{}) (string, []interface{}) {
return fmt.Sprintf("DELETE FROM %s", values[0]), []interface{}{}
}
func _count(values ...interface{}) (string, []interface{}) {
return _select(values[0], []string{"count(*)"})
}
================================================
FILE: gee-orm/day4-chain-operation/dialect/dialect.go
================================================
package dialect
import "reflect"
var dialectsMap = map[string]Dialect{}
// Dialect is the interface that every database dialect has to implement
type Dialect interface {
DataTypeOf(typ reflect.Value) string
TableExistSQL(tableName string) (string, []interface{})
}
// RegisterDialect registers a dialect in the global map
func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
}
// GetDialect returns the dialect registered under name, if it exists
func GetDialect(name string) (dialect Dialect, ok bool) {
dialect, ok = dialectsMap[name]
return
}
================================================
FILE: gee-orm/day4-chain-operation/dialect/sqlite3.go
================================================
package dialect
import (
"fmt"
"reflect"
"time"
)
type sqlite3 struct{}
var _ Dialect = (*sqlite3)(nil)
func init() {
RegisterDialect("sqlite3", &sqlite3{})
}
// DataTypeOf returns the sqlite3 column type for a Go value
func (s *sqlite3) DataTypeOf(typ reflect.Value) string {
switch typ.Kind() {
case reflect.Bool:
return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer"
case reflect.Int64, reflect.Uint64:
return "bigint"
case reflect.Float32, reflect.Float64:
return "real"
case reflect.String:
return "text"
case reflect.Array, reflect.Slice:
return "blob"
case reflect.Struct:
if _, ok := typ.Interface().(time.Time); ok {
return "datetime"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind()))
}
// TableExistSQL returns the SQL that checks whether the table exists in the database
func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
}
================================================
FILE: gee-orm/day4-chain-operation/dialect/sqlite3_test.go
================================================
package dialect
import (
"reflect"
"testing"
)
func TestDataTypeOf(t *testing.T) {
dial := &sqlite3{}
cases := []struct {
Value interface{}
Type string
}{
{"Tom", "text"},
{123, "integer"},
{1.2, "real"},
{[]int{1, 2, 3}, "blob"},
}
for _, c := range cases {
if typ := dial.DataTypeOf(reflect.ValueOf(c.Value)); typ != c.Type {
t.Fatalf("expect %s, but got %s", c.Type, typ)
}
}
}
================================================
FILE: gee-orm/day4-chain-operation/geeorm.go
================================================
package geeorm
import (
"database/sql"
"fmt"
"geeorm/dialect"
"geeorm/log"
"geeorm/session"
)
// Engine is the main struct of geeorm; it manages all db sessions and transactions.
type Engine struct {
db *sql.DB
dialect dialect.Dialect
}
// NewEngine creates an instance of Engine.
// It connects to the database and pings it to check that it is alive.
func NewEngine(driver, source string) (e *Engine, err error) {
db, err := sql.Open(driver, source)
if err != nil {
log.Error(err)
return
}
// Send a ping to make sure the database connection is alive.
if err = db.Ping(); err != nil {
log.Error(err)
return
}
// make sure the specific dialect exists
dial, ok := dialect.GetDialect(driver)
if !ok {
// report the failure to the caller instead of returning (nil, nil)
err = fmt.Errorf("dialect %s not found", driver)
log.Error(err)
return
}
e = &Engine{db: db, dialect: dial}
log.Info("Connect database success")
return
}
// Close closes the database connection
func (engine *Engine) Close() {
if err := engine.db.Close(); err != nil {
log.Error("Failed to close database")
return
}
log.Info("Close database success")
}
// NewSession creates a new session for next operations
func (engine *Engine) NewSession() *session.Session {
return session.New(engine.db, engine.dialect)
}
================================================
FILE: gee-orm/day4-chain-operation/geeorm_test.go
================================================
package geeorm
import (
_ "github.com/mattn/go-sqlite3"
"testing"
)
func OpenDB(t *testing.T) *Engine {
t.Helper()
engine, err := NewEngine("sqlite3", "gee.db")
if err != nil {
t.Fatal("failed to connect", err)
}
return engine
}
func TestNewEngine(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
}
================================================
FILE: gee-orm/day4-chain-operation/go.mod
================================================
module geeorm
go 1.13
require github.com/mattn/go-sqlite3 v2.0.3+incompatible
================================================
FILE: gee-orm/day4-chain-operation/log/log.go
================================================
package log
import (
"io/ioutil"
"log"
"os"
"sync"
)
var (
errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile)
infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile)
loggers = []*log.Logger{errorLog, infoLog}
mu sync.Mutex
)
// log methods
var (
Error = errorLog.Println
Errorf = errorLog.Printf
Info = infoLog.Println
Infof = infoLog.Printf
)
// log levels
const (
InfoLevel = iota
ErrorLevel
Disabled
)
// SetLevel controls log level
func SetLevel(level int) {
mu.Lock()
defer mu.Unlock()
for _, logger := range loggers {
logger.SetOutput(os.Stdout)
}
if ErrorLevel < level {
errorLog.SetOutput(ioutil.Discard)
}
if InfoLevel < level {
infoLog.SetOutput(ioutil.Discard)
}
}
================================================
FILE: gee-orm/day4-chain-operation/log/log_test.go
================================================
package log
import (
"os"
"testing"
)
func TestSetLevel(t *testing.T) {
SetLevel(ErrorLevel)
if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout {
t.Fatal("failed to set log level")
}
SetLevel(Disabled)
if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout {
t.Fatal("failed to set log level")
}
}
================================================
FILE: gee-orm/day4-chain-operation/schema/schema.go
================================================
package schema
import (
"geeorm/dialect"
"go/ast"
"reflect"
)
// Field represents a column of database
type Field struct {
Name string
Type string
Tag string
}
// Schema represents a table of database
type Schema struct {
Model interface{}
Name string
Fields []*Field
FieldNames []string
fieldMap map[string]*Field
}
// GetField returns field by name
func (schema *Schema) GetField(name string) *Field {
return schema.fieldMap[name]
}
// RecordValues returns the values of dest's member variables
func (schema *Schema) RecordValues(dest interface{}) []interface{} {
destValue := reflect.Indirect(reflect.ValueOf(dest))
var fieldValues []interface{}
for _, field := range schema.Fields {
fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface())
}
return fieldValues
}
// ITableName is implemented by models that provide a custom table name
type ITableName interface {
TableName() string
}
// Parse parses a struct into a Schema instance
func Parse(dest interface{}, d dialect.Dialect) *Schema {
modelType := reflect.Indirect(reflect.ValueOf(dest)).Type()
var tableName string
t, ok := dest.(ITableName)
if !ok {
tableName = modelType.Name()
} else {
tableName = t.TableName()
}
schema := &Schema{
Model: dest,
Name: tableName,
fieldMap: make(map[string]*Field),
}
for i := 0; i < modelType.NumField(); i++ {
p := modelType.Field(i)
if !p.Anonymous && ast.IsExported(p.Name) {
field := &Field{
Name: p.Name,
Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))),
}
if v, ok := p.Tag.Lookup("geeorm"); ok {
field.Tag = v
}
schema.Fields = append(schema.Fields, field)
schema.FieldNames = append(schema.FieldNames, p.Name)
schema.fieldMap[p.Name] = field
}
}
return schema
}
================================================
FILE: gee-orm/day4-chain-operation/schema/schema_test.go
================================================
package schema
import (
"geeorm/dialect"
"testing"
)
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
var TestDial, _ = dialect.GetDialect("sqlite3")
func TestParse(t *testing.T) {
schema := Parse(&User{}, TestDial)
if schema.Name != "User" || len(schema.Fields) != 2 {
t.Fatal("failed to parse User struct")
}
if schema.GetField("Name").Tag != "PRIMARY KEY" {
t.Fatal("failed to parse primary key")
}
}
func TestSchema_RecordValues(t *testing.T) {
schema := Parse(&User{}, TestDial)
values := schema.RecordValues(&User{"Tom", 18})
name := values[0].(string)
age := values[1].(int)
if name != "Tom" || age != 18 {
t.Fatal("failed to get values")
}
}
type UserTest struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func (u *UserTest) TableName() string {
return "ns_user_test"
}
func TestSchema_TableName(t *testing.T) {
schema := Parse(&UserTest{}, TestDial)
if schema.Name != "ns_user_test" || len(schema.Fields) != 2 {
t.Fatal("failed to parse UserTest table name")
}
}
================================================
FILE: gee-orm/day4-chain-operation/session/raw.go
================================================
package session
import (
"database/sql"
"geeorm/clause"
"geeorm/dialect"
"geeorm/log"
"geeorm/schema"
"strings"
)
// Session keeps a pointer to sql.DB and executes all kinds of
// database operations.
type Session struct {
db *sql.DB
dialect dialect.Dialect
refTable *schema.Schema
clause clause.Clause
sql strings.Builder
sqlVars []interface{}
}
// New creates an instance of Session
func New(db *sql.DB, dialect dialect.Dialect) *Session {
return &Session{
db: db,
dialect: dialect,
}
}
// Clear resets the state of a session
func (s *Session) Clear() {
s.sql.Reset()
s.sqlVars = nil
s.clause = clause.Clause{}
}
// DB returns *sql.DB
func (s *Session) DB() *sql.DB {
return s.db
}
// Exec executes the raw SQL with sqlVars
func (s *Session) Exec() (result sql.Result, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// QueryRow gets a record from db
func (s *Session) QueryRow() *sql.Row {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
return s.DB().QueryRow(s.sql.String(), s.sqlVars...)
}
// QueryRows gets a list of records from db
func (s *Session) QueryRows() (rows *sql.Rows, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// Raw appends sql and sqlVars
func (s *Session) Raw(sql string, values ...interface{}) *Session {
s.sql.WriteString(sql)
s.sql.WriteString(" ")
s.sqlVars = append(s.sqlVars, values...)
return s
}
================================================
FILE: gee-orm/day4-chain-operation/session/raw_test.go
================================================
package session
import (
"database/sql"
"os"
"testing"
"geeorm/dialect"
_ "github.com/mattn/go-sqlite3"
)
var (
TestDB *sql.DB
TestDial, _ = dialect.GetDialect("sqlite3")
)
func TestMain(m *testing.M) {
TestDB, _ = sql.Open("sqlite3", "../gee.db")
code := m.Run()
_ = TestDB.Close()
os.Exit(code)
}
func NewSession() *Session {
return New(TestDB, TestDial)
}
func TestSession_Exec(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec()
if count, err := result.RowsAffected(); err != nil || count != 2 {
t.Fatal("expect 2, but got", count)
}
}
func TestSession_QueryRows(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
row := s.Raw("SELECT count(*) FROM User").QueryRow()
var count int
if err := row.Scan(&count); err != nil || count != 0 {
t.Fatal("failed to query db", err)
}
}
================================================
FILE: gee-orm/day4-chain-operation/session/record.go
================================================
package session
import (
"errors"
"geeorm/clause"
"reflect"
)
// Insert one or more records in database
func (s *Session) Insert(values ...interface{}) (int64, error) {
recordValues := make([]interface{}, 0)
for _, value := range values {
table := s.Model(value).RefTable()
s.clause.Set(clause.INSERT, table.Name, table.FieldNames)
recordValues = append(recordValues, table.RecordValues(value))
}
s.clause.Set(clause.VALUES, recordValues...)
sql, vars := s.clause.Build(clause.INSERT, clause.VALUES)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
return result.RowsAffected()
}
// Find gets all eligible records
func (s *Session) Find(values interface{}) error {
destSlice := reflect.Indirect(reflect.ValueOf(values))
destType := destSlice.Type().Elem()
table := s.Model(reflect.New(destType).Elem().Interface()).RefTable()
s.clause.Set(clause.SELECT, table.Name, table.FieldNames)
sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT)
rows, err := s.Raw(sql, vars...).QueryRows()
if err != nil {
return err
}
for rows.Next() {
dest := reflect.New(destType).Elem()
var values []interface{}
for _, name := range table.FieldNames {
values = append(values, dest.FieldByName(name).Addr().Interface())
}
if err := rows.Scan(values...); err != nil {
_ = rows.Close() // release the underlying connection before returning
return err
}
destSlice.Set(reflect.Append(destSlice, dest))
}
return rows.Close()
}
// First gets the 1st row
func (s *Session) First(value interface{}) error {
dest := reflect.Indirect(reflect.ValueOf(value))
destSlice := reflect.New(reflect.SliceOf(dest.Type())).Elem()
if err := s.Limit(1).Find(destSlice.Addr().Interface()); err != nil {
return err
}
if destSlice.Len() == 0 {
return errors.New("NOT FOUND")
}
dest.Set(destSlice.Index(0))
return nil
}
// Limit adds limit condition to clause
func (s *Session) Limit(num int) *Session {
s.clause.Set(clause.LIMIT, num)
return s
}
// Where adds where condition to clause
func (s *Session) Where(desc string, args ...interface{}) *Session {
var vars []interface{}
s.clause.Set(clause.WHERE, append(append(vars, desc), args...)...)
return s
}
// OrderBy adds order by condition to clause
func (s *Session) OrderBy(desc string) *Session {
s.clause.Set(clause.ORDERBY, desc)
return s
}
// Update updates the records matched by the where clause.
// It accepts either a map[string]interface{}
// or a key-value list: "Name", "Tom", "Age", 18, ...
func (s *Session) Update(kv ...interface{}) (int64, error) {
m, ok := kv[0].(map[string]interface{})
if !ok {
m = make(map[string]interface{})
for i := 0; i < len(kv); i += 2 {
m[kv[i].(string)] = kv[i+1]
}
}
s.clause.Set(clause.UPDATE, s.RefTable().Name, m)
sql, vars := s.clause.Build(clause.UPDATE, clause.WHERE)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
return result.RowsAffected()
}
// Delete records with where clause
func (s *Session) Delete() (int64, error) {
s.clause.Set(clause.DELETE, s.RefTable().Name)
sql, vars := s.clause.Build(clause.DELETE, clause.WHERE)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
return result.RowsAffected()
}
// Count records with where clause
func (s *Session) Count() (int64, error) {
s.clause.Set(clause.COUNT, s.RefTable().Name)
sql, vars := s.clause.Build(clause.COUNT, clause.WHERE)
row := s.Raw(sql, vars...).QueryRow()
var tmp int64
if err := row.Scan(&tmp); err != nil {
return 0, err
}
return tmp, nil
}
================================================
FILE: gee-orm/day4-chain-operation/session/record_test.go
================================================
package session
import "testing"
var (
user1 = &User{"Tom", 18}
user2 = &User{"Sam", 25}
user3 = &User{"Jack", 25}
)
func testRecordInit(t *testing.T) *Session {
t.Helper()
s := NewSession().Model(&User{})
err1 := s.DropTable()
err2 := s.CreateTable()
_, err3 := s.Insert(user1, user2)
if err1 != nil || err2 != nil || err3 != nil {
t.Fatal("failed init test records")
}
return s
}
func TestSession_Insert(t *testing.T) {
s := testRecordInit(t)
affected, err := s.Insert(user3)
if err != nil || affected != 1 {
t.Fatal("failed to create record")
}
}
func TestSession_Find(t *testing.T) {
s := testRecordInit(t)
var users []User
if err := s.Find(&users); err != nil || len(users) != 2 {
t.Fatal("failed to query all")
}
}
func TestSession_First(t *testing.T) {
s := testRecordInit(t)
u := &User{}
err := s.First(u)
if err != nil || u.Name != "Tom" || u.Age != 18 {
t.Fatal("failed to query first")
}
}
func TestSession_Limit(t *testing.T) {
s := testRecordInit(t)
var users []User
err := s.Limit(1).Find(&users)
if err != nil || len(users) != 1 {
t.Fatal("failed to query with limit condition")
}
}
func TestSession_Where(t *testing.T) {
s := testRecordInit(t)
var users []User
_, err1 := s.Insert(user3)
err2 := s.Where("Age = ?", 25).Find(&users)
if err1 != nil || err2 != nil || len(users) != 2 {
t.Fatal("failed to query with where condition")
}
}
func TestSession_OrderBy(t *testing.T) {
s := testRecordInit(t)
u := &User{}
err := s.OrderBy("Age DESC").First(u)
if err != nil || u.Age != 25 {
t.Fatal("failed to query with order by condition")
}
}
func TestSession_Update(t *testing.T) {
s := testRecordInit(t)
affected, _ := s.Where("Name = ?", "Tom").Update("Age", 30)
u := &User{}
_ = s.OrderBy("Age DESC").First(u)
if affected != 1 || u.Age != 30 {
t.Fatal("failed to update")
}
}
func TestSession_DeleteAndCount(t *testing.T) {
s := testRecordInit(t)
affected, _ := s.Where("Name = ?", "Tom").Delete()
count, _ := s.Count()
if affected != 1 || count != 1 {
t.Fatal("failed to delete or count")
}
}
================================================
FILE: gee-orm/day4-chain-operation/session/table.go
================================================
package session
import (
"fmt"
"geeorm/log"
"reflect"
"strings"
"geeorm/schema"
)
// Model assigns refTable
func (s *Session) Model(value interface{}) *Session {
// nil or different model, update refTable
if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) {
s.refTable = schema.Parse(value, s.dialect)
}
return s
}
// RefTable returns a Schema instance that contains all parsed fields
func (s *Session) RefTable() *schema.Schema {
if s.refTable == nil {
log.Error("Model is not set")
}
return s.refTable
}
// CreateTable creates a table in the database from a model
func (s *Session) CreateTable() error {
table := s.RefTable()
var columns []string
for _, field := range table.Fields {
columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag))
}
desc := strings.Join(columns, ",")
_, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec()
return err
}
// DropTable drops a table with the name of model
func (s *Session) DropTable() error {
_, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec()
return err
}
// HasTable returns true if the table exists
func (s *Session) HasTable() bool {
sql, values := s.dialect.TableExistSQL(s.RefTable().Name)
row := s.Raw(sql, values...).QueryRow()
var tmp string
_ = row.Scan(&tmp)
return tmp == s.RefTable().Name
}
================================================
FILE: gee-orm/day4-chain-operation/session/table_test.go
================================================
package session
import (
"testing"
)
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func TestSession_CreateTable(t *testing.T) {
s := NewSession().Model(&User{})
_ = s.DropTable()
_ = s.CreateTable()
if !s.HasTable() {
t.Fatal("Failed to create table User")
}
}
func TestSession_Model(t *testing.T) {
s := NewSession().Model(&User{})
table := s.RefTable()
s.Model(&Session{})
if table.Name != "User" || s.RefTable().Name != "Session" {
t.Fatal("Failed to change model")
}
}
================================================
FILE: gee-orm/day5-hooks/clause/clause.go
================================================
package clause
import (
"strings"
)
// Clause contains SQL conditions
type Clause struct {
sql map[Type]string
sqlVars map[Type][]interface{}
}
// Type is the type of Clause
type Type int
// Supported types for Clause
const (
INSERT Type = iota
VALUES
SELECT
LIMIT
WHERE
ORDERBY
UPDATE
DELETE
COUNT
)
// Set adds a sub clause of specific type
func (c *Clause) Set(name Type, vars ...interface{}) {
if c.sql == nil {
c.sql = make(map[Type]string)
c.sqlVars = make(map[Type][]interface{})
}
sql, vars := generators[name](vars...)
c.sql[name] = sql
c.sqlVars[name] = vars
}
// Build generates the final SQL and SQLVars
func (c *Clause) Build(orders ...Type) (string, []interface{}) {
var sqls []string
var vars []interface{}
for _, order := range orders {
if sql, ok := c.sql[order]; ok {
sqls = append(sqls, sql)
vars = append(vars, c.sqlVars[order]...)
}
}
return strings.Join(sqls, " "), vars
}
================================================
FILE: gee-orm/day5-hooks/clause/clause_test.go
================================================
package clause
import (
"reflect"
"testing"
)
func TestClause_Set(t *testing.T) {
var clause Clause
clause.Set(INSERT, "User", []string{"Name", "Age"})
sql := clause.sql[INSERT]
vars := clause.sqlVars[INSERT]
t.Log(sql, vars)
if sql != "INSERT INTO User (Name,Age)" || len(vars) != 0 {
t.Fatal("failed to get clause")
}
}
func testSelect(t *testing.T) {
var clause Clause
clause.Set(LIMIT, 3)
clause.Set(SELECT, "User", []string{"*"})
clause.Set(WHERE, "Name = ?", "Tom")
clause.Set(ORDERBY, "Age ASC")
sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT)
t.Log(sql, vars)
if sql != "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{"Tom", 3}) {
t.Fatal("failed to build SQLVars")
}
}
func testUpdate(t *testing.T) {
var clause Clause
clause.Set(UPDATE, "User", map[string]interface{}{"Age": 30})
clause.Set(WHERE, "Name = ?", "Tom")
sql, vars := clause.Build(UPDATE, WHERE)
t.Log(sql, vars)
if sql != "UPDATE User SET Age = ? WHERE Name = ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{30, "Tom"}) {
t.Fatal("failed to build SQLVars")
}
}
func testDelete(t *testing.T) {
var clause Clause
clause.Set(DELETE, "User")
clause.Set(WHERE, "Name = ?", "Tom")
sql, vars := clause.Build(DELETE, WHERE)
t.Log(sql, vars)
if sql != "DELETE FROM User WHERE Name = ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{"Tom"}) {
t.Fatal("failed to build SQLVars")
}
}
func TestClause_Build(t *testing.T) {
t.Run("select", func(t *testing.T) {
testSelect(t)
})
t.Run("update", func(t *testing.T) {
testUpdate(t)
})
t.Run("delete", func(t *testing.T) {
testDelete(t)
})
}
================================================
FILE: gee-orm/day5-hooks/clause/generator.go
================================================
package clause
import (
"fmt"
"strings"
)
type generator func(values ...interface{}) (string, []interface{})
var generators map[Type]generator
func init() {
generators = make(map[Type]generator)
generators[INSERT] = _insert
generators[VALUES] = _values
generators[SELECT] = _select
generators[LIMIT] = _limit
generators[WHERE] = _where
generators[ORDERBY] = _orderBy
generators[UPDATE] = _update
generators[DELETE] = _delete
generators[COUNT] = _count
}
func genBindVars(num int) string {
var vars []string
for i := 0; i < num; i++ {
vars = append(vars, "?")
}
return strings.Join(vars, ", ")
}
func _insert(values ...interface{}) (string, []interface{}) {
// INSERT INTO $tableName ($fields)
tableName := values[0]
fields := strings.Join(values[1].([]string), ",")
return fmt.Sprintf("INSERT INTO %s (%v)", tableName, fields), []interface{}{}
}
func _values(values ...interface{}) (string, []interface{}) {
// VALUES ($v1), ($v2), ...
var bindStr string
var sql strings.Builder
var vars []interface{}
sql.WriteString("VALUES ")
for i, value := range values {
v := value.([]interface{})
if bindStr == "" {
bindStr = genBindVars(len(v))
}
sql.WriteString(fmt.Sprintf("(%v)", bindStr))
if i+1 != len(values) {
sql.WriteString(", ")
}
vars = append(vars, v...)
}
return sql.String(), vars
}
func _select(values ...interface{}) (string, []interface{}) {
// SELECT $fields FROM $tableName
tableName := values[0]
fields := strings.Join(values[1].([]string), ",")
return fmt.Sprintf("SELECT %v FROM %s", fields, tableName), []interface{}{}
}
func _limit(values ...interface{}) (string, []interface{}) {
// LIMIT $num
return "LIMIT ?", values
}
func _where(values ...interface{}) (string, []interface{}) {
// WHERE $desc
desc, vars := values[0], values[1:]
return fmt.Sprintf("WHERE %s", desc), vars
}
func _orderBy(values ...interface{}) (string, []interface{}) {
return fmt.Sprintf("ORDER BY %s", values[0]), []interface{}{}
}
func _update(values ...interface{}) (string, []interface{}) {
tableName := values[0]
m := values[1].(map[string]interface{})
var keys []string
var vars []interface{}
for k, v := range m {
keys = append(keys, k+" = ?")
vars = append(vars, v)
}
return fmt.Sprintf("UPDATE %s SET %s", tableName, strings.Join(keys, ", ")), vars
}
func _delete(values ...interface{}) (string, []interface{}) {
return fmt.Sprintf("DELETE FROM %s", values[0]), []interface{}{}
}
func _count(values ...interface{}) (string, []interface{}) {
return _select(values[0], []string{"count(*)"})
}
================================================
FILE: gee-orm/day5-hooks/dialect/dialect.go
================================================
package dialect
import "reflect"
var dialectsMap = map[string]Dialect{}
// Dialect is an interface that contains the methods a dialect has to implement
type Dialect interface {
DataTypeOf(typ reflect.Value) string
TableExistSQL(tableName string) (string, []interface{})
}
// RegisterDialect registers a dialect in the global dialectsMap
func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
}
// GetDialect gets the dialect from the global dialectsMap if it exists
func GetDialect(name string) (dialect Dialect, ok bool) {
dialect, ok = dialectsMap[name]
return
}
================================================
FILE: gee-orm/day5-hooks/dialect/sqlite3.go
================================================
package dialect
import (
"fmt"
"reflect"
"time"
)
type sqlite3 struct{}
var _ Dialect = (*sqlite3)(nil)
func init() {
RegisterDialect("sqlite3", &sqlite3{})
}
// DataTypeOf returns the sqlite3 column type for a Go value
func (s *sqlite3) DataTypeOf(typ reflect.Value) string {
switch typ.Kind() {
case reflect.Bool:
return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer"
case reflect.Int64, reflect.Uint64:
return "bigint"
case reflect.Float32, reflect.Float64:
return "real"
case reflect.String:
return "text"
case reflect.Array, reflect.Slice:
return "blob"
case reflect.Struct:
if _, ok := typ.Interface().(time.Time); ok {
return "datetime"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind()))
}
// TableExistSQL returns the SQL that checks whether the table exists in the database
func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
}
================================================
FILE: gee-orm/day5-hooks/dialect/sqlite3_test.go
================================================
package dialect
import (
"reflect"
"testing"
)
func TestDataTypeOf(t *testing.T) {
dial := &sqlite3{}
cases := []struct {
Value interface{}
Type string
}{
{"Tom", "text"},
{123, "integer"},
{1.2, "real"},
{[]int{1, 2, 3}, "blob"},
}
for _, c := range cases {
if typ := dial.DataTypeOf(reflect.ValueOf(c.Value)); typ != c.Type {
t.Fatalf("expect %s, but got %s", c.Type, typ)
}
}
}
================================================
FILE: gee-orm/day5-hooks/geeorm.go
================================================
package geeorm
import (
"database/sql"
"fmt"
"geeorm/dialect"
"geeorm/log"
"geeorm/session"
)
// Engine is the main struct of geeorm, manages all db sessions and transactions.
type Engine struct {
db *sql.DB
dialect dialect.Dialect
}
// NewEngine creates an instance of Engine.
// It connects to the database and pings it to test whether it's alive.
func NewEngine(driver, source string) (e *Engine, err error) {
db, err := sql.Open(driver, source)
if err != nil {
log.Error(err)
return
}
// Send a ping to make sure the database connection is alive.
if err = db.Ping(); err != nil {
log.Error(err)
return
}
// make sure the specific dialect exists
dial, ok := dialect.GetDialect(driver)
if !ok {
// report the missing dialect to the caller instead of returning (nil, nil)
err = fmt.Errorf("dialect %s not found", driver)
log.Error(err)
return
}
e = &Engine{db: db, dialect: dial}
log.Info("Connect database success")
return
}
// Close closes the database connection
func (engine *Engine) Close() {
if err := engine.db.Close(); err != nil {
log.Error("Failed to close database")
// do not report success after a failed close
return
}
log.Info("Close database success")
}
// NewSession creates a new session for next operations
func (engine *Engine) NewSession() *session.Session {
return session.New(engine.db, engine.dialect)
}
================================================
FILE: gee-orm/day5-hooks/geeorm_test.go
================================================
package geeorm
import (
_ "github.com/mattn/go-sqlite3"
"testing"
)
func OpenDB(t *testing.T) *Engine {
t.Helper()
engine, err := NewEngine("sqlite3", "gee.db")
if err != nil {
t.Fatal("failed to connect", err)
}
return engine
}
func TestNewEngine(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
}
================================================
FILE: gee-orm/day5-hooks/go.mod
================================================
module geeorm
go 1.13
require github.com/mattn/go-sqlite3 v2.0.3+incompatible
================================================
FILE: gee-orm/day5-hooks/log/log.go
================================================
package log
import (
"io/ioutil"
"log"
"os"
"sync"
)
var (
errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile)
infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile)
loggers = []*log.Logger{errorLog, infoLog}
mu sync.Mutex
)
// log methods
var (
Error = errorLog.Println
Errorf = errorLog.Printf
Info = infoLog.Println
Infof = infoLog.Printf
)
// log levels
const (
InfoLevel = iota
ErrorLevel
Disabled
)
// SetLevel controls log level
func SetLevel(level int) {
mu.Lock()
defer mu.Unlock()
for _, logger := range loggers {
logger.SetOutput(os.Stdout)
}
if ErrorLevel < level {
errorLog.SetOutput(ioutil.Discard)
}
if InfoLevel < level {
infoLog.SetOutput(ioutil.Discard)
}
}
================================================
FILE: gee-orm/day5-hooks/log/log_test.go
================================================
package log
import (
"os"
"testing"
)
func TestSetLevel(t *testing.T) {
SetLevel(ErrorLevel)
if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout {
t.Fatal("failed to set log level")
}
SetLevel(Disabled)
if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout {
t.Fatal("failed to set log level")
}
}
================================================
FILE: gee-orm/day5-hooks/schema/schema.go
================================================
package schema
import (
"geeorm/dialect"
"go/ast"
"reflect"
)
// Field represents a column of database
type Field struct {
Name string
Type string
Tag string
}
// Schema represents a table of database
type Schema struct {
Model interface{}
Name string
Fields []*Field
FieldNames []string
fieldMap map[string]*Field
}
// GetField returns field by name
func (schema *Schema) GetField(name string) *Field {
return schema.fieldMap[name]
}
// RecordValues returns the values of dest's member variables
func (schema *Schema) RecordValues(dest interface{}) []interface{} {
destValue := reflect.Indirect(reflect.ValueOf(dest))
var fieldValues []interface{}
for _, field := range schema.Fields {
fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface())
}
return fieldValues
}
// ITableName lets a model override the default table name
type ITableName interface {
TableName() string
}
// Parse a struct to a Schema instance
func Parse(dest interface{}, d dialect.Dialect) *Schema {
modelType := reflect.Indirect(reflect.ValueOf(dest)).Type()
var tableName string
t, ok := dest.(ITableName)
if !ok {
tableName = modelType.Name()
} else {
tableName = t.TableName()
}
schema := &Schema{
Model: dest,
Name: tableName,
fieldMap: make(map[string]*Field),
}
for i := 0; i < modelType.NumField(); i++ {
p := modelType.Field(i)
if !p.Anonymous && ast.IsExported(p.Name) {
field := &Field{
Name: p.Name,
Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))),
}
if v, ok := p.Tag.Lookup("geeorm"); ok {
field.Tag = v
}
schema.Fields = append(schema.Fields, field)
schema.FieldNames = append(schema.FieldNames, p.Name)
schema.fieldMap[p.Name] = field
}
}
return schema
}
================================================
FILE: gee-orm/day5-hooks/schema/schema_test.go
================================================
package schema
import (
"geeorm/dialect"
"testing"
)
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
var TestDial, _ = dialect.GetDialect("sqlite3")
func TestParse(t *testing.T) {
schema := Parse(&User{}, TestDial)
if schema.Name != "User" || len(schema.Fields) != 2 {
t.Fatal("failed to parse User struct")
}
if schema.GetField("Name").Tag != "PRIMARY KEY" {
t.Fatal("failed to parse primary key")
}
}
func TestSchema_RecordValues(t *testing.T) {
schema := Parse(&User{}, TestDial)
values := schema.RecordValues(&User{"Tom", 18})
name := values[0].(string)
age := values[1].(int)
if name != "Tom" || age != 18 {
t.Fatal("failed to get values")
}
}
type UserTest struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func (u *UserTest) TableName() string {
return "ns_user_test"
}
func TestSchema_TableName(t *testing.T) {
schema := Parse(&UserTest{}, TestDial)
if schema.Name != "ns_user_test" || len(schema.Fields) != 2 {
t.Fatal("failed to parse UserTest struct")
}
}
================================================
FILE: gee-orm/day5-hooks/session/hooks.go
================================================
package session
import (
"geeorm/log"
"reflect"
)
// Hooks constants
const (
BeforeQuery = "BeforeQuery"
AfterQuery = "AfterQuery"
BeforeUpdate = "BeforeUpdate"
AfterUpdate = "AfterUpdate"
BeforeDelete = "BeforeDelete"
AfterDelete = "AfterDelete"
BeforeInsert = "BeforeInsert"
AfterInsert = "AfterInsert"
)
// CallMethod calls the registered hooks
func (s *Session) CallMethod(method string, value interface{}) {
fm := reflect.ValueOf(s.RefTable().Model).MethodByName(method)
if value != nil {
fm = reflect.ValueOf(value).MethodByName(method)
}
param := []reflect.Value{reflect.ValueOf(s)}
if fm.IsValid() {
if v := fm.Call(param); len(v) > 0 {
if err, ok := v[0].Interface().(error); ok {
log.Error(err)
}
}
}
}
================================================
FILE: gee-orm/day5-hooks/session/hooks_test.go
================================================
package session
import (
"geeorm/log"
"testing"
)
type Account struct {
ID int `geeorm:"PRIMARY KEY"`
Password string
}
func (account *Account) BeforeInsert(s *Session) error {
log.Info("before insert", account)
account.ID += 1000
return nil
}
func (account *Account) AfterQuery(s *Session) error {
log.Info("after query", account)
account.Password = "******"
return nil
}
func TestSession_CallMethod(t *testing.T) {
s := NewSession().Model(&Account{})
_ = s.DropTable()
_ = s.CreateTable()
_, _ = s.Insert(&Account{1, "123456"}, &Account{2, "qwerty"})
u := &Account{}
err := s.First(u)
if err != nil || u.ID != 1001 || u.Password != "******" {
t.Fatal("Failed to call hooks after query, got", u)
}
}
================================================
FILE: gee-orm/day5-hooks/session/raw.go
================================================
package session
import (
"database/sql"
"geeorm/clause"
"geeorm/dialect"
"geeorm/log"
"geeorm/schema"
"strings"
)
// Session keeps a pointer to sql.DB and provides execution of all
// kinds of database operations.
type Session struct {
db *sql.DB
dialect dialect.Dialect
refTable *schema.Schema
clause clause.Clause
sql strings.Builder
sqlVars []interface{}
}
// New creates an instance of Session
func New(db *sql.DB, dialect dialect.Dialect) *Session {
return &Session{
db: db,
dialect: dialect,
}
}
// Clear resets the state of a session
func (s *Session) Clear() {
s.sql.Reset()
s.sqlVars = nil
s.clause = clause.Clause{}
}
// DB returns *sql.DB
func (s *Session) DB() *sql.DB {
return s.db
}
// Exec executes raw sql with sqlVars
func (s *Session) Exec() (result sql.Result, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// QueryRow gets a record from db
func (s *Session) QueryRow() *sql.Row {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
return s.DB().QueryRow(s.sql.String(), s.sqlVars...)
}
// QueryRows gets a list of records from db
func (s *Session) QueryRows() (rows *sql.Rows, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// Raw appends sql and sqlVars
func (s *Session) Raw(sql string, values ...interface{}) *Session {
s.sql.WriteString(sql)
s.sql.WriteString(" ")
s.sqlVars = append(s.sqlVars, values...)
return s
}
================================================
FILE: gee-orm/day5-hooks/session/raw_test.go
================================================
package session
import (
"database/sql"
"os"
"testing"
"geeorm/dialect"
_ "github.com/mattn/go-sqlite3"
)
var (
TestDB *sql.DB
TestDial, _ = dialect.GetDialect("sqlite3")
)
func TestMain(m *testing.M) {
TestDB, _ = sql.Open("sqlite3", "../gee.db")
code := m.Run()
_ = TestDB.Close()
os.Exit(code)
}
func NewSession() *Session {
return New(TestDB, TestDial)
}
func TestSession_Exec(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec()
if count, err := result.RowsAffected(); err != nil || count != 2 {
t.Fatal("expect 2, but got", count)
}
}
func TestSession_QueryRows(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
row := s.Raw("SELECT count(*) FROM User").QueryRow()
var count int
if err := row.Scan(&count); err != nil || count != 0 {
t.Fatal("failed to query db", err)
}
}
================================================
FILE: gee-orm/day5-hooks/session/record.go
================================================
package session
import (
"errors"
"geeorm/clause"
"reflect"
)
// Insert one or more records in database
func (s *Session) Insert(values ...interface{}) (int64, error) {
recordValues := make([]interface{}, 0)
for _, value := range values {
s.CallMethod(BeforeInsert, value)
table := s.Model(value).RefTable()
s.clause.Set(clause.INSERT, table.Name, table.FieldNames)
recordValues = append(recordValues, table.RecordValues(value))
}
s.clause.Set(clause.VALUES, recordValues...)
sql, vars := s.clause.Build(clause.INSERT, clause.VALUES)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
s.CallMethod(AfterInsert, nil)
return result.RowsAffected()
}
// Find gets all eligible records
func (s *Session) Find(values interface{}) error {
s.CallMethod(BeforeQuery, nil)
destSlice := reflect.Indirect(reflect.ValueOf(values))
destType := destSlice.Type().Elem()
table := s.Model(reflect.New(destType).Elem().Interface()).RefTable()
s.clause.Set(clause.SELECT, table.Name, table.FieldNames)
sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT)
rows, err := s.Raw(sql, vars...).QueryRows()
if err != nil {
return err
}
for rows.Next() {
dest := reflect.New(destType).Elem()
var values []interface{}
for _, name := range table.FieldNames {
values = append(values, dest.FieldByName(name).Addr().Interface())
}
if err := rows.Scan(values...); err != nil {
// release the result set before returning the scan error
_ = rows.Close()
return err
}
s.CallMethod(AfterQuery, dest.Addr().Interface())
destSlice.Set(reflect.Append(destSlice, dest))
}
return rows.Close()
}
// First gets the 1st row
func (s *Session) First(value interface{}) error {
dest := reflect.Indirect(reflect.ValueOf(value))
destSlice := reflect.New(reflect.SliceOf(dest.Type())).Elem()
if err := s.Limit(1).Find(destSlice.Addr().Interface()); err != nil {
return err
}
if destSlice.Len() == 0 {
return errors.New("NOT FOUND")
}
dest.Set(destSlice.Index(0))
return nil
}
// Limit adds limit condition to clause
func (s *Session) Limit(num int) *Session {
s.clause.Set(clause.LIMIT, num)
return s
}
// Where adds where condition to clause
func (s *Session) Where(desc string, args ...interface{}) *Session {
var vars []interface{}
s.clause.Set(clause.WHERE, append(append(vars, desc), args...)...)
return s
}
// OrderBy adds order by condition to clause
func (s *Session) OrderBy(desc string) *Session {
s.clause.Set(clause.ORDERBY, desc)
return s
}
// Update updates the records matched by the where clause.
// It supports a single map[string]interface{} argument
// and also a kv list: "Name", "Tom", "Age", 18, ....
func (s *Session) Update(kv ...interface{}) (int64, error) {
s.CallMethod(BeforeUpdate, nil)
m, ok := kv[0].(map[string]interface{})
if !ok {
m = make(map[string]interface{})
for i := 0; i < len(kv); i += 2 {
m[kv[i].(string)] = kv[i+1]
}
}
s.clause.Set(clause.UPDATE, s.RefTable().Name, m)
sql, vars := s.clause.Build(clause.UPDATE, clause.WHERE)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
s.CallMethod(AfterUpdate, nil)
return result.RowsAffected()
}
// Delete records with where clause
func (s *Session) Delete() (int64, error) {
s.CallMethod(BeforeDelete, nil)
s.clause.Set(clause.DELETE, s.RefTable().Name)
sql, vars := s.clause.Build(clause.DELETE, clause.WHERE)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
s.CallMethod(AfterDelete, nil)
return result.RowsAffected()
}
// Count records with where clause
func (s *Session) Count() (int64, error) {
s.clause.Set(clause.COUNT, s.RefTable().Name)
sql, vars := s.clause.Build(clause.COUNT, clause.WHERE)
row := s.Raw(sql, vars...).QueryRow()
var tmp int64
if err := row.Scan(&tmp); err != nil {
return 0, err
}
return tmp, nil
}
================================================
FILE: gee-orm/day5-hooks/session/record_test.go
================================================
package session
import "testing"
var (
user1 = &User{"Tom", 18}
user2 = &User{"Sam", 25}
user3 = &User{"Jack", 25}
)
func testRecordInit(t *testing.T) *Session {
t.Helper()
s := NewSession().Model(&User{})
err1 := s.DropTable()
err2 := s.CreateTable()
_, err3 := s.Insert(user1, user2)
if err1 != nil || err2 != nil || err3 != nil {
t.Fatal("failed init test records")
}
return s
}
func TestSession_Insert(t *testing.T) {
s := testRecordInit(t)
affected, err := s.Insert(user3)
if err != nil || affected != 1 {
t.Fatal("failed to create record")
}
}
func TestSession_Find(t *testing.T) {
s := testRecordInit(t)
var users []User
if err := s.Find(&users); err != nil || len(users) != 2 {
t.Fatal("failed to query all")
}
}
func TestSession_First(t *testing.T) {
s := testRecordInit(t)
u := &User{}
err := s.First(u)
if err != nil || u.Name != "Tom" || u.Age != 18 {
t.Fatal("failed to query first")
}
}
func TestSession_Limit(t *testing.T) {
s := testRecordInit(t)
var users []User
err := s.Limit(1).Find(&users)
if err != nil || len(users) != 1 {
t.Fatal("failed to query with limit condition")
}
}
func TestSession_Where(t *testing.T) {
s := testRecordInit(t)
var users []User
_, err1 := s.Insert(user3)
err2 := s.Where("Age = ?", 25).Find(&users)
if err1 != nil || err2 != nil || len(users) != 2 {
t.Fatal("failed to query with where condition")
}
}
func TestSession_OrderBy(t *testing.T) {
s := testRecordInit(t)
u := &User{}
err := s.OrderBy("Age DESC").First(u)
if err != nil || u.Age != 25 {
t.Fatal("failed to query with order by condition")
}
}
func TestSession_Update(t *testing.T) {
s := testRecordInit(t)
affected, _ := s.Where("Name = ?", "Tom").Update("Age", 30)
u := &User{}
_ = s.OrderBy("Age DESC").First(u)
if affected != 1 || u.Age != 30 {
t.Fatal("failed to update")
}
}
func TestSession_DeleteAndCount(t *testing.T) {
s := testRecordInit(t)
affected, _ := s.Where("Name = ?", "Tom").Delete()
count, _ := s.Count()
if affected != 1 || count != 1 {
t.Fatal("failed to delete or count")
}
}
================================================
FILE: gee-orm/day5-hooks/session/table.go
================================================
package session
import (
"fmt"
"geeorm/log"
"reflect"
"strings"
"geeorm/schema"
)
// Model assigns refTable
func (s *Session) Model(value interface{}) *Session {
// nil or different model, update refTable
if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) {
s.refTable = schema.Parse(value, s.dialect)
}
return s
}
// RefTable returns a Schema instance that contains all parsed fields
func (s *Session) RefTable() *schema.Schema {
if s.refTable == nil {
log.Error("Model is not set")
}
return s.refTable
}
// CreateTable creates a table in the database for the current model
func (s *Session) CreateTable() error {
table := s.RefTable()
var columns []string
for _, field := range table.Fields {
columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag))
}
desc := strings.Join(columns, ",")
_, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec()
return err
}
// DropTable drops a table with the name of model
func (s *Session) DropTable() error {
_, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec()
return err
}
// HasTable returns true if the table exists
func (s *Session) HasTable() bool {
sql, values := s.dialect.TableExistSQL(s.RefTable().Name)
row := s.Raw(sql, values...).QueryRow()
var tmp string
_ = row.Scan(&tmp)
return tmp == s.RefTable().Name
}
================================================
FILE: gee-orm/day5-hooks/session/table_test.go
================================================
package session
import (
"testing"
)
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func TestSession_CreateTable(t *testing.T) {
s := NewSession().Model(&User{})
_ = s.DropTable()
_ = s.CreateTable()
if !s.HasTable() {
t.Fatal("Failed to create table User")
}
}
func TestSession_Model(t *testing.T) {
s := NewSession().Model(&User{})
table := s.RefTable()
s.Model(&Session{})
if table.Name != "User" || s.RefTable().Name != "Session" {
t.Fatal("Failed to change model")
}
}
================================================
FILE: gee-orm/day6-transaction/clause/clause.go
================================================
package clause
import (
"strings"
)
// Clause contains SQL conditions
type Clause struct {
sql map[Type]string
sqlVars map[Type][]interface{}
}
// Type is the type of Clause
type Type int
// Supported types for Clause
const (
INSERT Type = iota
VALUES
SELECT
LIMIT
WHERE
ORDERBY
UPDATE
DELETE
COUNT
)
// Set adds a sub clause of specific type
func (c *Clause) Set(name Type, vars ...interface{}) {
if c.sql == nil {
c.sql = make(map[Type]string)
c.sqlVars = make(map[Type][]interface{})
}
sql, vars := generators[name](vars...)
c.sql[name] = sql
c.sqlVars[name] = vars
}
// Build generates the final SQL and SQLVars
func (c *Clause) Build(orders ...Type) (string, []interface{}) {
var sqls []string
var vars []interface{}
for _, order := range orders {
if sql, ok := c.sql[order]; ok {
sqls = append(sqls, sql)
vars = append(vars, c.sqlVars[order]...)
}
}
return strings.Join(sqls, " "), vars
}
================================================
FILE: gee-orm/day6-transaction/clause/clause_test.go
================================================
package clause
import (
"reflect"
"testing"
)
func TestClause_Set(t *testing.T) {
var clause Clause
clause.Set(INSERT, "User", []string{"Name", "Age"})
sql := clause.sql[INSERT]
vars := clause.sqlVars[INSERT]
t.Log(sql, vars)
if sql != "INSERT INTO User (Name,Age)" || len(vars) != 0 {
t.Fatal("failed to get clause")
}
}
func testSelect(t *testing.T) {
var clause Clause
clause.Set(LIMIT, 3)
clause.Set(SELECT, "User", []string{"*"})
clause.Set(WHERE, "Name = ?", "Tom")
clause.Set(ORDERBY, "Age ASC")
sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT)
t.Log(sql, vars)
if sql != "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{"Tom", 3}) {
t.Fatal("failed to build SQLVars")
}
}
func testUpdate(t *testing.T) {
var clause Clause
clause.Set(UPDATE, "User", map[string]interface{}{"Age": 30})
clause.Set(WHERE, "Name = ?", "Tom")
sql, vars := clause.Build(UPDATE, WHERE)
t.Log(sql, vars)
if sql != "UPDATE User SET Age = ? WHERE Name = ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{30, "Tom"}) {
t.Fatal("failed to build SQLVars")
}
}
func testDelete(t *testing.T) {
var clause Clause
clause.Set(DELETE, "User")
clause.Set(WHERE, "Name = ?", "Tom")
sql, vars := clause.Build(DELETE, WHERE)
t.Log(sql, vars)
if sql != "DELETE FROM User WHERE Name = ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{"Tom"}) {
t.Fatal("failed to build SQLVars")
}
}
func TestClause_Build(t *testing.T) {
t.Run("select", func(t *testing.T) {
testSelect(t)
})
t.Run("update", func(t *testing.T) {
testUpdate(t)
})
t.Run("delete", func(t *testing.T) {
testDelete(t)
})
}
================================================
FILE: gee-orm/day6-transaction/clause/generator.go
================================================
package clause
import (
"fmt"
"strings"
)
type generator func(values ...interface{}) (string, []interface{})
var generators map[Type]generator
func init() {
generators = make(map[Type]generator)
generators[INSERT] = _insert
generators[VALUES] = _values
generators[SELECT] = _select
generators[LIMIT] = _limit
generators[WHERE] = _where
generators[ORDERBY] = _orderBy
generators[UPDATE] = _update
generators[DELETE] = _delete
generators[COUNT] = _count
}
func genBindVars(num int) string {
var vars []string
for i := 0; i < num; i++ {
vars = append(vars, "?")
}
return strings.Join(vars, ", ")
}
func _insert(values ...interface{}) (string, []interface{}) {
// INSERT INTO $tableName ($fields)
tableName := values[0]
fields := strings.Join(values[1].([]string), ",")
return fmt.Sprintf("INSERT INTO %s (%v)", tableName, fields), []interface{}{}
}
func _values(values ...interface{}) (string, []interface{}) {
// VALUES ($v1), ($v2), ...
var bindStr string
var sql strings.Builder
var vars []interface{}
sql.WriteString("VALUES ")
for i, value := range values {
v := value.([]interface{})
if bindStr == "" {
bindStr = genBindVars(len(v))
}
sql.WriteString(fmt.Sprintf("(%v)", bindStr))
if i+1 != len(values) {
sql.WriteString(", ")
}
vars = append(vars, v...)
}
return sql.String(), vars
}
func _select(values ...interface{}) (string, []interface{}) {
// SELECT $fields FROM $tableName
tableName := values[0]
fields := strings.Join(values[1].([]string), ",")
return fmt.Sprintf("SELECT %v FROM %s", fields, tableName), []interface{}{}
}
func _limit(values ...interface{}) (string, []interface{}) {
// LIMIT $num
return "LIMIT ?", values
}
func _where(values ...interface{}) (string, []interface{}) {
// WHERE $desc
desc, vars := values[0], values[1:]
return fmt.Sprintf("WHERE %s", desc), vars
}
func _orderBy(values ...interface{}) (string, []interface{}) {
return fmt.Sprintf("ORDER BY %s", values[0]), []interface{}{}
}
func _update(values ...interface{}) (string, []interface{}) {
tableName := values[0]
m := values[1].(map[string]interface{})
var keys []string
var vars []interface{}
for k, v := range m {
keys = append(keys, k+" = ?")
vars = append(vars, v)
}
return fmt.Sprintf("UPDATE %s SET %s", tableName, strings.Join(keys, ", ")), vars
}
func _delete(values ...interface{}) (string, []interface{}) {
return fmt.Sprintf("DELETE FROM %s", values[0]), []interface{}{}
}
func _count(values ...interface{}) (string, []interface{}) {
return _select(values[0], []string{"count(*)"})
}
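The generators above each emit one SQL fragment plus its bind variables, and `Build` joins the fragments in the requested order. The following is a minimal standalone sketch of how `_insert` and `_values` combine into one statement. `bindVars` and `buildInsert` are hypothetical names introduced for illustration only and are not part of geeorm.

```go
package main

import (
	"fmt"
	"strings"
)

// bindVars returns n "?" placeholders joined by ", ".
// It mirrors what genBindVars does; the name is an assumption of this sketch.
func bindVars(n int) string {
	if n <= 0 {
		return ""
	}
	return strings.TrimSuffix(strings.Repeat("?, ", n), ", ")
}

// buildInsert is a hypothetical helper that combines the INSERT and VALUES
// fragments into one statement and flattens all row values into one slice.
func buildInsert(table string, fields []string, rows ...[]interface{}) (string, []interface{}) {
	var sb strings.Builder
	fmt.Fprintf(&sb, "INSERT INTO %s (%s) VALUES ", table, strings.Join(fields, ","))
	var vars []interface{}
	for i, row := range rows {
		if i > 0 {
			sb.WriteString(", ")
		}
		// one placeholder group per row, sized by the number of fields
		sb.WriteString("(" + bindVars(len(fields)) + ")")
		vars = append(vars, row...)
	}
	return sb.String(), vars
}

func main() {
	sql, vars := buildInsert("User", []string{"Name", "Age"},
		[]interface{}{"Tom", 18}, []interface{}{"Sam", 25})
	fmt.Println(sql)  // INSERT INTO User (Name,Age) VALUES (?, ?), (?, ?)
	fmt.Println(vars) // [Tom 18 Sam 25]
}
```

Keeping the values out of the SQL string and passing them as bind variables lets the database driver handle escaping.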
================================================
FILE: gee-orm/day6-transaction/dialect/dialect.go
================================================
package dialect
import "reflect"
var dialectsMap = map[string]Dialect{}
// Dialect is an interface that contains the methods a dialect has to implement
type Dialect interface {
DataTypeOf(typ reflect.Value) string
TableExistSQL(tableName string) (string, []interface{})
}
// RegisterDialect registers a dialect in the global dialectsMap
func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
}
// GetDialect returns the dialect registered under name, if it exists
func GetDialect(name string) (dialect Dialect, ok bool) {
dialect, ok = dialectsMap[name]
return
}
================================================
FILE: gee-orm/day6-transaction/dialect/sqlite3.go
================================================
package dialect
import (
"fmt"
"reflect"
"time"
)
type sqlite3 struct{}
var _ Dialect = (*sqlite3)(nil)
func init() {
RegisterDialect("sqlite3", &sqlite3{})
}
// DataTypeOf returns the sqlite3 column type for a Go value
func (s *sqlite3) DataTypeOf(typ reflect.Value) string {
switch typ.Kind() {
case reflect.Bool:
return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer"
case reflect.Int64, reflect.Uint64:
return "bigint"
case reflect.Float32, reflect.Float64:
return "real"
case reflect.String:
return "text"
case reflect.Array, reflect.Slice:
return "blob"
case reflect.Struct:
if _, ok := typ.Interface().(time.Time); ok {
return "datetime"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind()))
}
// TableExistSQL returns the SQL that checks whether the table exists in the database
func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
}
================================================
FILE: gee-orm/day6-transaction/dialect/sqlite3_test.go
================================================
package dialect
import (
"reflect"
"testing"
)
func TestDataTypeOf(t *testing.T) {
dial := &sqlite3{}
cases := []struct {
Value interface{}
Type string
}{
{"Tom", "text"},
{123, "integer"},
{1.2, "real"},
{[]int{1, 2, 3}, "blob"},
}
for _, c := range cases {
if typ := dial.DataTypeOf(reflect.ValueOf(c.Value)); typ != c.Type {
t.Fatalf("expect %s, but got %s", c.Type, typ)
}
}
}
================================================
FILE: gee-orm/day6-transaction/geeorm.go
================================================
package geeorm
import (
"database/sql"
"fmt"
"geeorm/dialect"
"geeorm/log"
"geeorm/session"
)
// Engine is the main struct of geeorm, manages all db sessions and transactions.
type Engine struct {
db *sql.DB
dialect dialect.Dialect
}
// NewEngine creates an instance of Engine.
// It connects to the database and pings it to test whether it's alive.
func NewEngine(driver, source string) (e *Engine, err error) {
db, err := sql.Open(driver, source)
if err != nil {
log.Error(err)
return
}
// Send a ping to make sure the database connection is alive.
if err = db.Ping(); err != nil {
log.Error(err)
return
}
// make sure the specific dialect exists
dial, ok := dialect.GetDialect(driver)
if !ok {
// report the failure to the caller instead of returning (nil, nil)
err = fmt.Errorf("dialect %s not found", driver)
log.Error(err)
return
}
e = &Engine{db: db, dialect: dial}
log.Info("Connect database success")
return
}
// Close closes the database connection
func (engine *Engine) Close() {
if err := engine.db.Close(); err != nil {
log.Error("Failed to close database")
return
}
log.Info("Close database success")
}
// NewSession creates a new session for next operations
func (engine *Engine) NewSession() *session.Session {
return session.New(engine.db, engine.dialect)
}
// TxFunc will be called between tx.Begin() and tx.Commit()
// https://stackoverflow.com/questions/16184238/database-sql-tx-detecting-commit-or-rollback
type TxFunc func(*session.Session) (interface{}, error)
// Transaction runs f inside a transaction: it commits if f returns no error,
// and rolls back if f returns an error or panics
func (engine *Engine) Transaction(f TxFunc) (result interface{}, err error) {
s := engine.NewSession()
if err := s.Begin(); err != nil {
return nil, err
}
defer func() {
if p := recover(); p != nil {
_ = s.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
_ = s.Rollback() // err is non-nil; don't change it
} else {
err = s.Commit() // err is nil; if Commit returns error update err
}
}()
return f(s)
}
================================================
FILE: gee-orm/day6-transaction/geeorm_test.go
================================================
package geeorm
import (
"errors"
"geeorm/session"
"testing"
_ "github.com/mattn/go-sqlite3"
)
func OpenDB(t *testing.T) *Engine {
t.Helper()
engine, err := NewEngine("sqlite3", "gee.db")
if err != nil {
t.Fatal("failed to connect", err)
}
return engine
}
func TestNewEngine(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
}
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func transactionRollback(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
s := engine.NewSession()
_ = s.Model(&User{}).DropTable()
_, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) {
_ = s.Model(&User{}).CreateTable()
_, err = s.Insert(&User{"Tom", 18})
return nil, errors.New("Error")
})
if err == nil || s.HasTable() {
t.Fatal("failed to rollback")
}
}
func transactionCommit(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
s := engine.NewSession()
_ = s.Model(&User{}).DropTable()
_, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) {
_ = s.Model(&User{}).CreateTable()
_, err = s.Insert(&User{"Tom", 18})
return
})
u := &User{}
_ = s.First(u)
if err != nil || u.Name != "Tom" {
t.Fatal("failed to commit")
}
}
func TestEngine_Transaction(t *testing.T) {
t.Run("rollback", func(t *testing.T) {
transactionRollback(t)
})
t.Run("commit", func(t *testing.T) {
transactionCommit(t)
})
}
================================================
FILE: gee-orm/day6-transaction/go.mod
================================================
module geeorm
go 1.13
require github.com/mattn/go-sqlite3 v2.0.3+incompatible
================================================
FILE: gee-orm/day6-transaction/log/log.go
================================================
package log
import (
"io/ioutil"
"log"
"os"
"sync"
)
var (
errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile)
infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile)
loggers = []*log.Logger{errorLog, infoLog}
mu sync.Mutex
)
// log methods
var (
Error = errorLog.Println
Errorf = errorLog.Printf
Info = infoLog.Println
Infof = infoLog.Printf
)
// log levels
const (
InfoLevel = iota
ErrorLevel
Disabled
)
// SetLevel controls log level
func SetLevel(level int) {
mu.Lock()
defer mu.Unlock()
for _, logger := range loggers {
logger.SetOutput(os.Stdout)
}
if ErrorLevel < level {
errorLog.SetOutput(ioutil.Discard)
}
if InfoLevel < level {
infoLog.SetOutput(ioutil.Discard)
}
}
================================================
FILE: gee-orm/day6-transaction/log/log_test.go
================================================
package log
import (
"os"
"testing"
)
func TestSetLevel(t *testing.T) {
SetLevel(ErrorLevel)
if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout {
t.Fatal("failed to set log level")
}
SetLevel(Disabled)
if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout {
t.Fatal("failed to set log level")
}
}
================================================
FILE: gee-orm/day6-transaction/schema/schema.go
================================================
package schema
import (
"geeorm/dialect"
"go/ast"
"reflect"
)
// Field represents a column of database
type Field struct {
Name string
Type string
Tag string
}
// Schema represents a table of database
type Schema struct {
Model interface{}
Name string
Fields []*Field
FieldNames []string
fieldMap map[string]*Field
}
// GetField returns field by name
func (schema *Schema) GetField(name string) *Field {
return schema.fieldMap[name]
}
// RecordValues returns the values of dest's member variables in field order
func (schema *Schema) RecordValues(dest interface{}) []interface{} {
destValue := reflect.Indirect(reflect.ValueOf(dest))
var fieldValues []interface{}
for _, field := range schema.Fields {
fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface())
}
return fieldValues
}
// ITableName lets a model override the table name derived from its struct name
type ITableName interface {
TableName() string
}
// Parse a struct to a Schema instance
func Parse(dest interface{}, d dialect.Dialect) *Schema {
modelType := reflect.Indirect(reflect.ValueOf(dest)).Type()
var tableName string
t, ok := dest.(ITableName)
if !ok {
tableName = modelType.Name()
} else {
tableName = t.TableName()
}
schema := &Schema{
Model: dest,
Name: tableName,
fieldMap: make(map[string]*Field),
}
for i := 0; i < modelType.NumField(); i++ {
p := modelType.Field(i)
if !p.Anonymous && ast.IsExported(p.Name) {
field := &Field{
Name: p.Name,
Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))),
}
if v, ok := p.Tag.Lookup("geeorm"); ok {
field.Tag = v
}
schema.Fields = append(schema.Fields, field)
schema.FieldNames = append(schema.FieldNames, p.Name)
schema.fieldMap[p.Name] = field
}
}
return schema
}
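`Parse` walks the struct type with reflection, skips anonymous and unexported fields, and reads the `geeorm` tag of each remaining field. A reduced sketch of that walk is shown below. It records the Go kind instead of asking a dialect for the SQL type, and the `column` type and `columnsOf` function are hypothetical names used only for this illustration.

```go
package main

import (
	"fmt"
	"go/ast"
	"reflect"
)

// User is a sample model; note is unexported and should be skipped.
type User struct {
	Name string `geeorm:"PRIMARY KEY"`
	Age  int
	note string
}

// column is a hypothetical, simplified counterpart of schema.Field.
type column struct {
	Name string
	Kind string
	Tag  string
}

// columnsOf lists the exported, non-embedded fields of a struct or struct pointer.
func columnsOf(model interface{}) []column {
	t := reflect.Indirect(reflect.ValueOf(model)).Type()
	var cols []column
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.Anonymous || !ast.IsExported(f.Name) {
			continue
		}
		cols = append(cols, column{
			Name: f.Name,
			Kind: f.Type.Kind().String(),
			Tag:  f.Tag.Get("geeorm"), // empty when the tag is absent
		})
	}
	return cols
}

func main() {
	for _, c := range columnsOf(&User{}) {
		fmt.Printf("%s %s %q\n", c.Name, c.Kind, c.Tag)
	}
}
```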
================================================
FILE: gee-orm/day6-transaction/schema/schema_test.go
================================================
package schema
import (
"geeorm/dialect"
"testing"
)
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
var TestDial, _ = dialect.GetDialect("sqlite3")
func TestParse(t *testing.T) {
schema := Parse(&User{}, TestDial)
if schema.Name != "User" || len(schema.Fields) != 2 {
t.Fatal("failed to parse User struct")
}
if schema.GetField("Name").Tag != "PRIMARY KEY" {
t.Fatal("failed to parse primary key")
}
}
func TestSchema_RecordValues(t *testing.T) {
schema := Parse(&User{}, TestDial)
values := schema.RecordValues(&User{"Tom", 18})
name := values[0].(string)
age := values[1].(int)
if name != "Tom" || age != 18 {
t.Fatal("failed to get values")
}
}
type UserTest struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func (u *UserTest) TableName() string {
return "ns_user_test"
}
func TestSchema_TableName(t *testing.T) {
schema := Parse(&UserTest{}, TestDial)
if schema.Name != "ns_user_test" || len(schema.Fields) != 2 {
t.Fatal("failed to parse UserTest struct")
}
}
================================================
FILE: gee-orm/day6-transaction/session/hooks.go
================================================
package session
import (
"geeorm/log"
"reflect"
)
// Hooks constants
const (
BeforeQuery = "BeforeQuery"
AfterQuery = "AfterQuery"
BeforeUpdate = "BeforeUpdate"
AfterUpdate = "AfterUpdate"
BeforeDelete = "BeforeDelete"
AfterDelete = "AfterDelete"
BeforeInsert = "BeforeInsert"
AfterInsert = "AfterInsert"
)
// CallMethod calls the registered hooks
func (s *Session) CallMethod(method string, value interface{}) {
fm := reflect.ValueOf(s.RefTable().Model).MethodByName(method)
if value != nil {
fm = reflect.ValueOf(value).MethodByName(method)
}
param := []reflect.Value{reflect.ValueOf(s)}
if fm.IsValid() {
if v := fm.Call(param); len(v) > 0 {
if err, ok := v[0].Interface().(error); ok {
log.Error(err)
}
}
}
}
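`CallMethod` looks a hook up by name with reflection and calls it only if the model defines it. The lookup can be sketched in isolation as follows. In this simplified version the hook takes no arguments and the error is returned rather than logged; `callHook` is a hypothetical name and this `Account` type is a reduced copy for illustration.

```go
package main

import (
	"fmt"
	"reflect"
)

// Account is a sample model with one hook defined on the pointer receiver.
type Account struct {
	ID       int
	Password string
}

// AfterQuery masks the password; in this sketch hooks take no arguments.
func (a *Account) AfterQuery() error {
	a.Password = "******"
	return nil
}

// callHook calls the named method on value if it exists and returns its error.
// A missing hook is not an error: the model simply did not register one.
func callHook(value interface{}, method string) error {
	fm := reflect.ValueOf(value).MethodByName(method)
	if !fm.IsValid() {
		return nil
	}
	out := fm.Call(nil)
	if len(out) > 0 {
		if err, ok := out[0].Interface().(error); ok {
			return err
		}
	}
	return nil
}

func main() {
	a := &Account{ID: 1, Password: "123456"}
	fmt.Println(callHook(a, "AfterQuery"), a.Password) // <nil> ******
	fmt.Println(callHook(a, "BeforeDelete"))           // <nil>
}
```

The value passed in must be a pointer, because the hooks are declared on pointer receivers and `MethodByName` only finds methods in the method set of the value it is given.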
================================================
FILE: gee-orm/day6-transaction/session/hooks_test.go
================================================
package session
import (
"geeorm/log"
"testing"
)
type Account struct {
ID int `geeorm:"PRIMARY KEY"`
Password string
}
func (account *Account) BeforeInsert(s *Session) error {
log.Info("before insert", account)
account.ID += 1000
return nil
}
func (account *Account) AfterQuery(s *Session) error {
log.Info("after query", account)
account.Password = "******"
return nil
}
func TestSession_CallMethod(t *testing.T) {
s := NewSession().Model(&Account{})
_ = s.DropTable()
_ = s.CreateTable()
_, _ = s.Insert(&Account{1, "123456"}, &Account{2, "qwerty"})
u := &Account{}
err := s.First(u)
if err != nil || u.ID != 1001 || u.Password != "******" {
t.Fatal("Failed to call hooks after query, got", u)
}
}
================================================
FILE: gee-orm/day6-transaction/session/raw.go
================================================
package session
import (
"database/sql"
"geeorm/clause"
"geeorm/dialect"
"geeorm/log"
"geeorm/schema"
"strings"
)
// Session keeps a pointer to sql.DB and executes all kinds of
// database operations.
type Session struct {
db *sql.DB
dialect dialect.Dialect
tx *sql.Tx
refTable *schema.Schema
clause clause.Clause
sql strings.Builder
sqlVars []interface{}
}
// New creates an instance of Session
func New(db *sql.DB, dialect dialect.Dialect) *Session {
return &Session{
db: db,
dialect: dialect,
}
}
// Clear resets the state of a session
func (s *Session) Clear() {
s.sql.Reset()
s.sqlVars = nil
s.clause = clause.Clause{}
}
// CommonDB is a minimal function set of db
type CommonDB interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
Exec(query string, args ...interface{}) (sql.Result, error)
}
var _ CommonDB = (*sql.DB)(nil)
var _ CommonDB = (*sql.Tx)(nil)
// DB returns tx if a transaction has begun; otherwise it returns *sql.DB
func (s *Session) DB() CommonDB {
if s.tx != nil {
return s.tx
}
return s.db
}
// Exec raw sql with sqlVars
func (s *Session) Exec() (result sql.Result, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// QueryRow gets a record from db
func (s *Session) QueryRow() *sql.Row {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
return s.DB().QueryRow(s.sql.String(), s.sqlVars...)
}
// QueryRows gets a list of records from db
func (s *Session) QueryRows() (rows *sql.Rows, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// Raw appends sql and sqlVars
func (s *Session) Raw(sql string, values ...interface{}) *Session {
s.sql.WriteString(sql)
s.sql.WriteString(" ")
s.sqlVars = append(s.sqlVars, values...)
return s
}
================================================
FILE: gee-orm/day6-transaction/session/raw_test.go
================================================
package session
import (
"database/sql"
"os"
"testing"
"geeorm/dialect"
_ "github.com/mattn/go-sqlite3"
)
var (
TestDB *sql.DB
TestDial, _ = dialect.GetDialect("sqlite3")
)
func TestMain(m *testing.M) {
TestDB, _ = sql.Open("sqlite3", "../gee.db")
code := m.Run()
_ = TestDB.Close()
os.Exit(code)
}
func NewSession() *Session {
return New(TestDB, TestDial)
}
func TestSession_Exec(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec()
if count, err := result.RowsAffected(); err != nil || count != 2 {
t.Fatal("expect 2, but got", count)
}
}
func TestSession_QueryRows(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
row := s.Raw("SELECT count(*) FROM User").QueryRow()
var count int
if err := row.Scan(&count); err != nil || count != 0 {
t.Fatal("failed to query db", err)
}
}
================================================
FILE: gee-orm/day6-transaction/session/record.go
================================================
package session
import (
"errors"
"geeorm/clause"
"reflect"
)
// Insert one or more records in database
func (s *Session) Insert(values ...interface{}) (int64, error) {
recordValues := make([]interface{}, 0)
for _, value := range values {
s.CallMethod(BeforeInsert, value)
table := s.Model(value).RefTable()
s.clause.Set(clause.INSERT, table.Name, table.FieldNames)
recordValues = append(recordValues, table.RecordValues(value))
}
s.clause.Set(clause.VALUES, recordValues...)
sql, vars := s.clause.Build(clause.INSERT, clause.VALUES)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
s.CallMethod(AfterInsert, nil)
return result.RowsAffected()
}
// Find gets all eligible records
func (s *Session) Find(values interface{}) error {
s.CallMethod(BeforeQuery, nil)
destSlice := reflect.Indirect(reflect.ValueOf(values))
destType := destSlice.Type().Elem()
table := s.Model(reflect.New(destType).Elem().Interface()).RefTable()
s.clause.Set(clause.SELECT, table.Name, table.FieldNames)
sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT)
rows, err := s.Raw(sql, vars...).QueryRows()
if err != nil {
return err
}
for rows.Next() {
dest := reflect.New(destType).Elem()
var values []interface{}
for _, name := range table.FieldNames {
values = append(values, dest.FieldByName(name).Addr().Interface())
}
if err := rows.Scan(values...); err != nil {
_ = rows.Close() // release the connection before returning early
return err
}
s.CallMethod(AfterQuery, dest.Addr().Interface())
destSlice.Set(reflect.Append(destSlice, dest))
}
return rows.Close()
}
// First gets the 1st row
func (s *Session) First(value interface{}) error {
dest := reflect.Indirect(reflect.ValueOf(value))
destSlice := reflect.New(reflect.SliceOf(dest.Type())).Elem()
if err := s.Limit(1).Find(destSlice.Addr().Interface()); err != nil {
return err
}
if destSlice.Len() == 0 {
return errors.New("NOT FOUND")
}
dest.Set(destSlice.Index(0))
return nil
}
// Limit adds limit condition to clause
func (s *Session) Limit(num int) *Session {
s.clause.Set(clause.LIMIT, num)
return s
}
// Where adds where condition to clause
func (s *Session) Where(desc string, args ...interface{}) *Session {
var vars []interface{}
s.clause.Set(clause.WHERE, append(append(vars, desc), args...)...)
return s
}
// OrderBy adds order by condition to clause
func (s *Session) OrderBy(desc string) *Session {
s.clause.Set(clause.ORDERBY, desc)
return s
}
// Update records with where clause
// support map[string]interface{}
// also support kv list: "Name", "Tom", "Age", 18, ....
func (s *Session) Update(kv ...interface{}) (int64, error) {
s.CallMethod(BeforeUpdate, nil)
m, ok := kv[0].(map[string]interface{})
if !ok {
m = make(map[string]interface{})
for i := 0; i < len(kv); i += 2 {
m[kv[i].(string)] = kv[i+1]
}
}
s.clause.Set(clause.UPDATE, s.RefTable().Name, m)
sql, vars := s.clause.Build(clause.UPDATE, clause.WHERE)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
s.CallMethod(AfterUpdate, nil)
return result.RowsAffected()
}
// Delete records with where clause
func (s *Session) Delete() (int64, error) {
s.CallMethod(BeforeDelete, nil)
s.clause.Set(clause.DELETE, s.RefTable().Name)
sql, vars := s.clause.Build(clause.DELETE, clause.WHERE)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
s.CallMethod(AfterDelete, nil)
return result.RowsAffected()
}
// Count records with where clause
func (s *Session) Count() (int64, error) {
s.clause.Set(clause.COUNT, s.RefTable().Name)
sql, vars := s.clause.Build(clause.COUNT, clause.WHERE)
row := s.Raw(sql, vars...).QueryRow()
var tmp int64
if err := row.Scan(&tmp); err != nil {
return 0, err
}
return tmp, nil
}
================================================
FILE: gee-orm/day6-transaction/session/record_test.go
================================================
package session
import "testing"
var (
user1 = &User{"Tom", 18}
user2 = &User{"Sam", 25}
user3 = &User{"Jack", 25}
)
func testRecordInit(t *testing.T) *Session {
t.Helper()
s := NewSession().Model(&User{})
err1 := s.DropTable()
err2 := s.CreateTable()
_, err3 := s.Insert(user1, user2)
if err1 != nil || err2 != nil || err3 != nil {
t.Fatal("failed init test records")
}
return s
}
func TestSession_Insert(t *testing.T) {
s := testRecordInit(t)
affected, err := s.Insert(user3)
if err != nil || affected != 1 {
t.Fatal("failed to create record")
}
}
func TestSession_Find(t *testing.T) {
s := testRecordInit(t)
var users []User
if err := s.Find(&users); err != nil || len(users) != 2 {
t.Fatal("failed to query all")
}
}
func TestSession_First(t *testing.T) {
s := testRecordInit(t)
u := &User{}
err := s.First(u)
if err != nil || u.Name != "Tom" || u.Age != 18 {
t.Fatal("failed to query first")
}
}
func TestSession_Limit(t *testing.T) {
s := testRecordInit(t)
var users []User
err := s.Limit(1).Find(&users)
if err != nil || len(users) != 1 {
t.Fatal("failed to query with limit condition")
}
}
func TestSession_Where(t *testing.T) {
s := testRecordInit(t)
var users []User
_, err1 := s.Insert(user3)
err2 := s.Where("Age = ?", 25).Find(&users)
if err1 != nil || err2 != nil || len(users) != 2 {
t.Fatal("failed to query with where condition")
}
}
func TestSession_OrderBy(t *testing.T) {
s := testRecordInit(t)
u := &User{}
err := s.OrderBy("Age DESC").First(u)
if err != nil || u.Age != 25 {
t.Fatal("failed to query with order by condition")
}
}
func TestSession_Update(t *testing.T) {
s := testRecordInit(t)
affected, _ := s.Where("Name = ?", "Tom").Update("Age", 30)
u := &User{}
_ = s.OrderBy("Age DESC").First(u)
if affected != 1 || u.Age != 30 {
t.Fatal("failed to update")
}
}
func TestSession_DeleteAndCount(t *testing.T) {
s := testRecordInit(t)
affected, _ := s.Where("Name = ?", "Tom").Delete()
count, _ := s.Count()
if affected != 1 || count != 1 {
t.Fatal("failed to delete or count")
}
}
================================================
FILE: gee-orm/day6-transaction/session/table.go
================================================
package session
import (
"fmt"
"geeorm/log"
"reflect"
"strings"
"geeorm/schema"
)
// Model assigns refTable
func (s *Session) Model(value interface{}) *Session {
// nil or different model, update refTable
if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) {
s.refTable = schema.Parse(value, s.dialect)
}
return s
}
// RefTable returns a Schema instance that contains all parsed fields
func (s *Session) RefTable() *schema.Schema {
if s.refTable == nil {
log.Error("Model is not set")
}
return s.refTable
}
// CreateTable creates a table in database with a model
func (s *Session) CreateTable() error {
table := s.RefTable()
var columns []string
for _, field := range table.Fields {
columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag))
}
desc := strings.Join(columns, ",")
_, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec()
return err
}
// DropTable drops a table with the name of model
func (s *Session) DropTable() error {
_, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec()
return err
}
// HasTable returns true if the table exists
func (s *Session) HasTable() bool {
sql, values := s.dialect.TableExistSQL(s.RefTable().Name)
row := s.Raw(sql, values...).QueryRow()
var tmp string
_ = row.Scan(&tmp)
return tmp == s.RefTable().Name
}
================================================
FILE: gee-orm/day6-transaction/session/table_test.go
================================================
package session
import (
"testing"
)
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func TestSession_CreateTable(t *testing.T) {
s := NewSession().Model(&User{})
_ = s.DropTable()
_ = s.CreateTable()
if !s.HasTable() {
t.Fatal("Failed to create table User")
}
}
func TestSession_Model(t *testing.T) {
s := NewSession().Model(&User{})
table := s.RefTable()
s.Model(&Session{})
if table.Name != "User" || s.RefTable().Name != "Session" {
t.Fatal("Failed to change model")
}
}
================================================
FILE: gee-orm/day6-transaction/session/transaction.go
================================================
package session
import "geeorm/log"
// Begin a transaction
func (s *Session) Begin() (err error) {
log.Info("transaction begin")
if s.tx, err = s.db.Begin(); err != nil {
log.Error(err)
return
}
return
}
// Commit a transaction
func (s *Session) Commit() (err error) {
log.Info("transaction commit")
if err = s.tx.Commit(); err != nil {
log.Error(err)
}
return
}
// Rollback a transaction
func (s *Session) Rollback() (err error) {
log.Info("transaction rollback")
if err = s.tx.Rollback(); err != nil {
log.Error(err)
}
return
}
================================================
FILE: gee-orm/day7-migrate/clause/clause.go
================================================
package clause
import (
"strings"
)
// Clause contains SQL conditions
type Clause struct {
sql map[Type]string
sqlVars map[Type][]interface{}
}
// Type is the type of Clause
type Type int
// Supported types for Clause
const (
INSERT Type = iota
VALUES
SELECT
LIMIT
WHERE
ORDERBY
UPDATE
DELETE
COUNT
)
// Set adds a sub clause of specific type
func (c *Clause) Set(name Type, vars ...interface{}) {
if c.sql == nil {
c.sql = make(map[Type]string)
c.sqlVars = make(map[Type][]interface{})
}
sql, vars := generators[name](vars...)
c.sql[name] = sql
c.sqlVars[name] = vars
}
// Build generates the final SQL and SQLVars
func (c *Clause) Build(orders ...Type) (string, []interface{}) {
var sqls []string
var vars []interface{}
for _, order := range orders {
if sql, ok := c.sql[order]; ok {
sqls = append(sqls, sql)
vars = append(vars, c.sqlVars[order]...)
}
}
return strings.Join(sqls, " "), vars
}
================================================
FILE: gee-orm/day7-migrate/clause/clause_test.go
================================================
package clause
import (
"reflect"
"testing"
)
func TestClause_Set(t *testing.T) {
var clause Clause
clause.Set(INSERT, "User", []string{"Name", "Age"})
sql := clause.sql[INSERT]
vars := clause.sqlVars[INSERT]
t.Log(sql, vars)
if sql != "INSERT INTO User (Name,Age)" || len(vars) != 0 {
t.Fatal("failed to get clause")
}
}
func testSelect(t *testing.T) {
var clause Clause
clause.Set(LIMIT, 3)
clause.Set(SELECT, "User", []string{"*"})
clause.Set(WHERE, "Name = ?", "Tom")
clause.Set(ORDERBY, "Age ASC")
sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT)
t.Log(sql, vars)
if sql != "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{"Tom", 3}) {
t.Fatal("failed to build SQLVars")
}
}
func testUpdate(t *testing.T) {
var clause Clause
clause.Set(UPDATE, "User", map[string]interface{}{"Age": 30})
clause.Set(WHERE, "Name = ?", "Tom")
sql, vars := clause.Build(UPDATE, WHERE)
t.Log(sql, vars)
if sql != "UPDATE User SET Age = ? WHERE Name = ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{30, "Tom"}) {
t.Fatal("failed to build SQLVars")
}
}
func testDelete(t *testing.T) {
var clause Clause
clause.Set(DELETE, "User")
clause.Set(WHERE, "Name = ?", "Tom")
sql, vars := clause.Build(DELETE, WHERE)
t.Log(sql, vars)
if sql != "DELETE FROM User WHERE Name = ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{"Tom"}) {
t.Fatal("failed to build SQLVars")
}
}
func TestClause_Build(t *testing.T) {
t.Run("select", func(t *testing.T) {
testSelect(t)
})
t.Run("update", func(t *testing.T) {
testUpdate(t)
})
t.Run("delete", func(t *testing.T) {
testDelete(t)
})
}
================================================
FILE: gee-orm/day7-migrate/clause/generator.go
================================================
package clause
import (
"fmt"
"strings"
)
type generator func(values ...interface{}) (string, []interface{})
var generators map[Type]generator
func init() {
generators = make(map[Type]generator)
generators[INSERT] = _insert
generators[VALUES] = _values
generators[SELECT] = _select
generators[LIMIT] = _limit
generators[WHERE] = _where
generators[ORDERBY] = _orderBy
generators[UPDATE] = _update
generators[DELETE] = _delete
generators[COUNT] = _count
}
func genBindVars(num int) string {
var vars []string
for i := 0; i < num; i++ {
vars = append(vars, "?")
}
return strings.Join(vars, ", ")
}
func _insert(values ...interface{}) (string, []interface{}) {
// INSERT INTO $tableName ($fields)
tableName := values[0]
fields := strings.Join(values[1].([]string), ",")
return fmt.Sprintf("INSERT INTO %s (%v)", tableName, fields), []interface{}{}
}
func _values(values ...interface{}) (string, []interface{}) {
// VALUES ($v1), ($v2), ...
var bindStr string
var sql strings.Builder
var vars []interface{}
sql.WriteString("VALUES ")
for i, value := range values {
v := value.([]interface{})
if bindStr == "" {
bindStr = genBindVars(len(v))
}
sql.WriteString(fmt.Sprintf("(%v)", bindStr))
if i+1 != len(values) {
sql.WriteString(", ")
}
vars = append(vars, v...)
}
return sql.String(), vars
}
func _select(values ...interface{}) (string, []interface{}) {
// SELECT $fields FROM $tableName
tableName := values[0]
fields := strings.Join(values[1].([]string), ",")
return fmt.Sprintf("SELECT %v FROM %s", fields, tableName), []interface{}{}
}
func _limit(values ...interface{}) (string, []interface{}) {
// LIMIT $num
return "LIMIT ?", values
}
func _where(values ...interface{}) (string, []interface{}) {
// WHERE $desc
desc, vars := values[0], values[1:]
return fmt.Sprintf("WHERE %s", desc), vars
}
func _orderBy(values ...interface{}) (string, []interface{}) {
return fmt.Sprintf("ORDER BY %s", values[0]), []interface{}{}
}
func _update(values ...interface{}) (string, []interface{}) {
tableName := values[0]
m := values[1].(map[string]interface{})
var keys []string
var vars []interface{}
for k, v := range m {
keys = append(keys, k+" = ?")
vars = append(vars, v)
}
return fmt.Sprintf("UPDATE %s SET %s", tableName, strings.Join(keys, ", ")), vars
}
func _delete(values ...interface{}) (string, []interface{}) {
return fmt.Sprintf("DELETE FROM %s", values[0]), []interface{}{}
}
func _count(values ...interface{}) (string, []interface{}) {
return _select(values[0], []string{"count(*)"})
}
================================================
FILE: gee-orm/day7-migrate/dialect/dialect.go
================================================
package dialect
import "reflect"
var dialectsMap = map[string]Dialect{}
// Dialect is an interface that contains the methods a dialect has to implement
type Dialect interface {
DataTypeOf(typ reflect.Value) string
TableExistSQL(tableName string) (string, []interface{})
}
// RegisterDialect registers a dialect in the global dialectsMap
func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
}
// GetDialect returns the dialect registered under name, if it exists
func GetDialect(name string) (dialect Dialect, ok bool) {
dialect, ok = dialectsMap[name]
return
}
================================================
FILE: gee-orm/day7-migrate/dialect/sqlite3.go
================================================
package dialect
import (
"fmt"
"reflect"
"time"
)
type sqlite3 struct{}
var _ Dialect = (*sqlite3)(nil)
func init() {
RegisterDialect("sqlite3", &sqlite3{})
}
// DataTypeOf maps the kind of a Go value to the matching sqlite3 column type
func (s *sqlite3) DataTypeOf(typ reflect.Value) string {
switch typ.Kind() {
case reflect.Bool:
return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer"
case reflect.Int64, reflect.Uint64:
return "bigint"
case reflect.Float32, reflect.Float64:
return "real"
case reflect.String:
return "text"
case reflect.Array, reflect.Slice:
return "blob"
case reflect.Struct:
if _, ok := typ.Interface().(time.Time); ok {
return "datetime"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind()))
}
// TableExistSQL returns the SQL statement that checks whether the table exists in the database
func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
}
================================================
FILE: gee-orm/day7-migrate/dialect/sqlite3_test.go
================================================
package dialect
import (
"reflect"
"testing"
)
func TestDataTypeOf(t *testing.T) {
dial := &sqlite3{}
cases := []struct {
Value interface{}
Type string
}{
{"Tom", "text"},
{123, "integer"},
{1.2, "real"},
{[]int{1, 2, 3}, "blob"},
}
for _, c := range cases {
if typ := dial.DataTypeOf(reflect.ValueOf(c.Value)); typ != c.Type {
t.Fatalf("expect %s, but got %s", c.Type, typ)
}
}
}
================================================
FILE: gee-orm/day7-migrate/geeorm.go
================================================
package geeorm
import (
"database/sql"
"fmt"
"geeorm/dialect"
"geeorm/log"
"geeorm/session"
"strings"
)
// Engine is the main struct of geeorm, manages all db sessions and transactions.
type Engine struct {
db *sql.DB
dialect dialect.Dialect
}
// NewEngine creates an instance of Engine:
// it connects to the database and pings it to verify that it is alive
func NewEngine(driver, source string) (e *Engine, err error) {
db, err := sql.Open(driver, source)
if err != nil {
log.Error(err)
return
}
// Send a ping to make sure the database connection is alive.
if err = db.Ping(); err != nil {
log.Error(err)
return
}
// make sure the specific dialect exists
dial, ok := dialect.GetDialect(driver)
if !ok {
// report the missing dialect to the caller instead of returning (nil, nil)
err = fmt.Errorf("dialect %s not found", driver)
log.Error(err)
return
}
e = &Engine{db: db, dialect: dial}
log.Info("Connect database success")
return
}
// Close database connection
func (engine *Engine) Close() {
if err := engine.db.Close(); err != nil {
log.Error("Failed to close database")
return
}
log.Info("Close database success")
}
// NewSession creates a new session for next operations
func (engine *Engine) NewSession() *session.Session {
return session.New(engine.db, engine.dialect)
}
// TxFunc will be called between tx.Begin() and tx.Commit()
// https://stackoverflow.com/questions/16184238/database-sql-tx-detecting-commit-or-rollback
type TxFunc func(*session.Session) (interface{}, error)
// Transaction executes f inside a transaction: it commits automatically if no error occurs and rolls back otherwise
func (engine *Engine) Transaction(f TxFunc) (result interface{}, err error) {
s := engine.NewSession()
if err := s.Begin(); err != nil {
return nil, err
}
defer func() {
if p := recover(); p != nil {
_ = s.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
_ = s.Rollback() // err is non-nil; don't change it
} else {
err = s.Commit() // err is nil; if Commit returns error update err
}
}()
return f(s)
}
// difference returns a - b
func difference(a []string, b []string) (diff []string) {
mapB := make(map[string]bool)
for _, v := range b {
mapB[v] = true
}
for _, v := range a {
if _, ok := mapB[v]; !ok {
diff = append(diff, v)
}
}
return
}
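// Migrate updates the table to match the model: it adds columns for new fields and drops columns that no longer have a field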
func (engine *Engine) Migrate(value interface{}) error {
_, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) {
if !s.Model(value).HasTable() {
log.Infof("table %s doesn't exist", s.RefTable().Name)
return nil, s.CreateTable()
}
table := s.RefTable()
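rows, err := s.Raw(fmt.Sprintf("SELECT * FROM %s LIMIT 1", table.Name)).QueryRows()
if err != nil {
return
}
columns, err := rows.Columns()
_ = rows.Close() // release the result set before altering the table
if err != nil {
return
}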
addCols := difference(table.FieldNames, columns)
delCols := difference(columns, table.FieldNames)
log.Infof("added cols %v, deleted cols %v", addCols, delCols)
for _, col := range addCols {
f := table.GetField(col)
sqlStr := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s;", table.Name, f.Name, f.Type)
if _, err = s.Raw(sqlStr).Exec(); err != nil {
return
}
}
if len(delCols) == 0 {
return
}
tmp := "tmp_" + table.Name
fieldStr := strings.Join(table.FieldNames, ", ")
s.Raw(fmt.Sprintf("CREATE TABLE %s AS SELECT %s from %s;", tmp, fieldStr, table.Name))
s.Raw(fmt.Sprintf("DROP TABLE %s;", table.Name))
s.Raw(fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", tmp, table.Name))
_, err = s.Exec()
return
})
return err
}
================================================
FILE: gee-orm/day7-migrate/geeorm_test.go
================================================
package geeorm
import (
"errors"
"geeorm/session"
"reflect"
"testing"
_ "github.com/mattn/go-sqlite3"
)
func OpenDB(t *testing.T) *Engine {
t.Helper()
engine, err := NewEngine("sqlite3", "gee.db")
if err != nil {
t.Fatal("failed to connect", err)
}
return engine
}
func TestNewEngine(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
}
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func transactionRollback(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
s := engine.NewSession()
_ = s.Model(&User{}).DropTable()
_, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) {
_ = s.Model(&User{}).CreateTable()
_, err = s.Insert(&User{"Tom", 18})
return nil, errors.New("Error")
})
if err == nil || s.HasTable() {
t.Fatal("failed to rollback")
}
}
func transactionCommit(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
s := engine.NewSession()
_ = s.Model(&User{}).DropTable()
_, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) {
_ = s.Model(&User{}).CreateTable()
_, err = s.Insert(&User{"Tom", 18})
return
})
u := &User{}
_ = s.First(u)
if err != nil || u.Name != "Tom" {
t.Fatal("failed to commit")
}
}
func TestEngine_Transaction(t *testing.T) {
t.Run("rollback", func(t *testing.T) {
transactionRollback(t)
})
t.Run("commit", func(t *testing.T) {
transactionCommit(t)
})
}
func TestEngine_Migrate(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
s := engine.NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text PRIMARY KEY, XXX integer);").Exec()
_, _ = s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec()
if err := engine.Migrate(&User{}); err != nil {
t.Fatal("failed to migrate table User:", err)
}
rows, _ := s.Raw("SELECT * FROM User").QueryRows()
columns, _ := rows.Columns()
if !reflect.DeepEqual(columns, []string{"Name", "Age"}) {
t.Fatal("Failed to migrate table User, got columns", columns)
}
}
================================================
FILE: gee-orm/day7-migrate/go.mod
================================================
module geeorm
go 1.13
require github.com/mattn/go-sqlite3 v2.0.3+incompatible
================================================
FILE: gee-orm/day7-migrate/log/log.go
================================================
package log
import (
"io/ioutil"
"log"
"os"
"sync"
)
var (
errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile)
infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile)
loggers = []*log.Logger{errorLog, infoLog}
mu sync.Mutex
)
// log methods
var (
Error = errorLog.Println
Errorf = errorLog.Printf
Info = infoLog.Println
Infof = infoLog.Printf
)
// log levels
const (
InfoLevel = iota
ErrorLevel
Disabled
)
// SetLevel controls log level
func SetLevel(level int) {
mu.Lock()
defer mu.Unlock()
for _, logger := range loggers {
logger.SetOutput(os.Stdout)
}
if ErrorLevel < level {
errorLog.SetOutput(ioutil.Discard)
}
if InfoLevel < level {
infoLog.SetOutput(ioutil.Discard)
}
}
================================================
FILE: gee-orm/day7-migrate/log/log_test.go
================================================
package log
import (
"os"
"testing"
)
func TestSetLevel(t *testing.T) {
SetLevel(ErrorLevel)
if infoLog.Writer() == os.Stdout || errorLog.Writer() != os.Stdout {
t.Fatal("failed to set log level")
}
SetLevel(Disabled)
if infoLog.Writer() == os.Stdout || errorLog.Writer() == os.Stdout {
t.Fatal("failed to set log level")
}
}
================================================
FILE: gee-orm/day7-migrate/schema/schema.go
================================================
package schema
import (
"geeorm/dialect"
"go/ast"
"reflect"
)
// Field represents a column of database
type Field struct {
Name string
Type string
Tag string
}
// Schema represents a table of database
type Schema struct {
Model interface{}
Name string
Fields []*Field
FieldNames []string
fieldMap map[string]*Field
}
// GetField returns field by name
func (schema *Schema) GetField(name string) *Field {
return schema.fieldMap[name]
}
// RecordValues returns the values of dest's member variables, in field order
func (schema *Schema) RecordValues(dest interface{}) []interface{} {
destValue := reflect.Indirect(reflect.ValueOf(dest))
var fieldValues []interface{}
for _, field := range schema.Fields {
fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface())
}
return fieldValues
}
type ITableName interface {
TableName() string
}
// Parse a struct to a Schema instance
func Parse(dest interface{}, d dialect.Dialect) *Schema {
modelType := reflect.Indirect(reflect.ValueOf(dest)).Type()
var tableName string
t, ok := dest.(ITableName)
if !ok {
tableName = modelType.Name()
} else {
tableName = t.TableName()
}
schema := &Schema{
Model: dest,
Name: tableName,
fieldMap: make(map[string]*Field),
}
for i := 0; i < modelType.NumField(); i++ {
p := modelType.Field(i)
if !p.Anonymous && ast.IsExported(p.Name) {
field := &Field{
Name: p.Name,
Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))),
}
if v, ok := p.Tag.Lookup("geeorm"); ok {
field.Tag = v
}
schema.Fields = append(schema.Fields, field)
schema.FieldNames = append(schema.FieldNames, p.Name)
schema.fieldMap[p.Name] = field
}
}
return schema
}
================================================
FILE: gee-orm/day7-migrate/schema/schema_test.go
================================================
package schema
import (
"geeorm/dialect"
"testing"
)
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
var TestDial, _ = dialect.GetDialect("sqlite3")
func TestParse(t *testing.T) {
schema := Parse(&User{}, TestDial)
if schema.Name != "User" || len(schema.Fields) != 2 {
t.Fatal("failed to parse User struct")
}
if schema.GetField("Name").Tag != "PRIMARY KEY" {
t.Fatal("failed to parse primary key")
}
}
func TestSchema_RecordValues(t *testing.T) {
schema := Parse(&User{}, TestDial)
values := schema.RecordValues(&User{"Tom", 18})
name := values[0].(string)
age := values[1].(int)
if name != "Tom" || age != 18 {
t.Fatal("failed to get values")
}
}
type UserTest struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func (u *UserTest) TableName() string {
return "ns_user_test"
}
func TestSchema_TableName(t *testing.T) {
schema := Parse(&UserTest{}, TestDial)
if schema.Name != "ns_user_test" || len(schema.Fields) != 2 {
t.Fatal("failed to parse User struct")
}
}
================================================
FILE: gee-orm/day7-migrate/session/hooks.go
================================================
package session
import (
"geeorm/log"
"reflect"
)
// Hooks constants
const (
BeforeQuery = "BeforeQuery"
AfterQuery = "AfterQuery"
BeforeUpdate = "BeforeUpdate"
AfterUpdate = "AfterUpdate"
BeforeDelete = "BeforeDelete"
AfterDelete = "AfterDelete"
BeforeInsert = "BeforeInsert"
AfterInsert = "AfterInsert"
)
// CallMethod calls the registered hooks
func (s *Session) CallMethod(method string, value interface{}) {
fm := reflect.ValueOf(s.RefTable().Model).MethodByName(method)
if value != nil {
fm = reflect.ValueOf(value).MethodByName(method)
}
param := []reflect.Value{reflect.ValueOf(s)}
if fm.IsValid() {
if v := fm.Call(param); len(v) > 0 {
if err, ok := v[0].Interface().(error); ok {
log.Error(err)
}
}
}
return
}
================================================
FILE: gee-orm/day7-migrate/session/hooks_test.go
================================================
package session
import (
"geeorm/log"
"testing"
)
type Account struct {
ID int `geeorm:"PRIMARY KEY"`
Password string
}
func (account *Account) BeforeInsert(s *Session) error {
log.Info("before insert", account)
account.ID += 1000
return nil
}
func (account *Account) AfterQuery(s *Session) error {
log.Info("after query", account)
account.Password = "******"
return nil
}
func TestSession_CallMethod(t *testing.T) {
s := NewSession().Model(&Account{})
_ = s.DropTable()
_ = s.CreateTable()
_, _ = s.Insert(&Account{1, "123456"}, &Account{2, "qwerty"})
u := &Account{}
err := s.First(u)
if err != nil || u.ID != 1001 || u.Password != "******" {
t.Fatal("Failed to call hooks after query, got", u)
}
}
================================================
FILE: gee-orm/day7-migrate/session/raw.go
================================================
package session
import (
"database/sql"
"geeorm/clause"
"geeorm/dialect"
"geeorm/log"
"geeorm/schema"
"strings"
)
// Session keeps a pointer to sql.DB and provides the execution of all
// kinds of database operations.
type Session struct {
db *sql.DB
dialect dialect.Dialect
tx *sql.Tx
refTable *schema.Schema
clause clause.Clause
sql strings.Builder
sqlVars []interface{}
}
// New creates an instance of Session
func New(db *sql.DB, dialect dialect.Dialect) *Session {
return &Session{
db: db,
dialect: dialect,
}
}
// Clear resets the state of a session
func (s *Session) Clear() {
s.sql.Reset()
s.sqlVars = nil
s.clause = clause.Clause{}
}
// CommonDB is a minimal function set of db
type CommonDB interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
Exec(query string, args ...interface{}) (sql.Result, error)
}
var _ CommonDB = (*sql.DB)(nil)
var _ CommonDB = (*sql.Tx)(nil)
// DB returns tx if a transaction has begun; otherwise it returns *sql.DB
func (s *Session) DB() CommonDB {
if s.tx != nil {
return s.tx
}
return s.db
}
// Exec raw sql with sqlVars
func (s *Session) Exec() (result sql.Result, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// QueryRow gets a record from db
func (s *Session) QueryRow() *sql.Row {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
return s.DB().QueryRow(s.sql.String(), s.sqlVars...)
}
// QueryRows gets a list of records from db
func (s *Session) QueryRows() (rows *sql.Rows, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// Raw appends sql and sqlVars
func (s *Session) Raw(sql string, values ...interface{}) *Session {
s.sql.WriteString(sql)
s.sql.WriteString(" ")
s.sqlVars = append(s.sqlVars, values...)
return s
}
================================================
FILE: gee-orm/day7-migrate/session/raw_test.go
================================================
package session
import (
"database/sql"
"os"
"testing"
"geeorm/dialect"
_ "github.com/mattn/go-sqlite3"
)
var (
TestDB *sql.DB
TestDial, _ = dialect.GetDialect("sqlite3")
)
func TestMain(m *testing.M) {
TestDB, _ = sql.Open("sqlite3", "../gee.db")
code := m.Run()
_ = TestDB.Close()
os.Exit(code)
}
func NewSession() *Session {
return New(TestDB, TestDial)
}
func TestSession_Exec(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec()
if count, err := result.RowsAffected(); err != nil || count != 2 {
t.Fatal("expect 2, but got", count)
}
}
func TestSession_QueryRows(t *testing.T) {
s := NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
row := s.Raw("SELECT count(*) FROM User").QueryRow()
var count int
if err := row.Scan(&count); err != nil || count != 0 {
t.Fatal("failed to query db", err)
}
}
================================================
FILE: gee-orm/day7-migrate/session/record.go
================================================
package session
import (
"errors"
"geeorm/clause"
"reflect"
)
// Insert one or more records in database
func (s *Session) Insert(values ...interface{}) (int64, error) {
recordValues := make([]interface{}, 0)
for _, value := range values {
s.CallMethod(BeforeInsert, value)
table := s.Model(value).RefTable()
s.clause.Set(clause.INSERT, table.Name, table.FieldNames)
recordValues = append(recordValues, table.RecordValues(value))
}
s.clause.Set(clause.VALUES, recordValues...)
sql, vars := s.clause.Build(clause.INSERT, clause.VALUES)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
s.CallMethod(AfterInsert, nil)
return result.RowsAffected()
}
// Find gets all eligible records
func (s *Session) Find(values interface{}) error {
s.CallMethod(BeforeQuery, nil)
destSlice := reflect.Indirect(reflect.ValueOf(values))
destType := destSlice.Type().Elem()
table := s.Model(reflect.New(destType).Elem().Interface()).RefTable()
s.clause.Set(clause.SELECT, table.Name, table.FieldNames)
sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT)
rows, err := s.Raw(sql, vars...).QueryRows()
if err != nil {
return err
}
for rows.Next() {
dest := reflect.New(destType).Elem()
var values []interface{}
for _, name := range table.FieldNames {
values = append(values, dest.FieldByName(name).Addr().Interface())
}
if err := rows.Scan(values...); err != nil {
_ = rows.Close() // release the result set before bailing out
return err
}
s.CallMethod(AfterQuery, dest.Addr().Interface())
destSlice.Set(reflect.Append(destSlice, dest))
}
return rows.Close()
}
// First gets the 1st row
func (s *Session) First(value interface{}) error {
dest := reflect.Indirect(reflect.ValueOf(value))
destSlice := reflect.New(reflect.SliceOf(dest.Type())).Elem()
if err := s.Limit(1).Find(destSlice.Addr().Interface()); err != nil {
return err
}
if destSlice.Len() == 0 {
return errors.New("NOT FOUND")
}
dest.Set(destSlice.Index(0))
return nil
}
// Limit adds limit condition to clause
func (s *Session) Limit(num int) *Session {
s.clause.Set(clause.LIMIT, num)
return s
}
// Where adds where condition to clause
func (s *Session) Where(desc string, args ...interface{}) *Session {
var vars []interface{}
s.clause.Set(clause.WHERE, append(append(vars, desc), args...)...)
return s
}
// OrderBy adds order by condition to clause
func (s *Session) OrderBy(desc string) *Session {
s.clause.Set(clause.ORDERBY, desc)
return s
}
// Update records with where clause
// support map[string]interface{}
// also support kv list: "Name", "Tom", "Age", 18, ....
func (s *Session) Update(kv ...interface{}) (int64, error) {
s.CallMethod(BeforeUpdate, nil)
m, ok := kv[0].(map[string]interface{})
if !ok {
m = make(map[string]interface{})
for i := 0; i < len(kv); i += 2 {
m[kv[i].(string)] = kv[i+1]
}
}
s.clause.Set(clause.UPDATE, s.RefTable().Name, m)
sql, vars := s.clause.Build(clause.UPDATE, clause.WHERE)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
s.CallMethod(AfterUpdate, nil)
return result.RowsAffected()
}
// Delete records with where clause
func (s *Session) Delete() (int64, error) {
s.CallMethod(BeforeDelete, nil)
s.clause.Set(clause.DELETE, s.RefTable().Name)
sql, vars := s.clause.Build(clause.DELETE, clause.WHERE)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
s.CallMethod(AfterDelete, nil)
return result.RowsAffected()
}
// Count records with where clause
func (s *Session) Count() (int64, error) {
s.clause.Set(clause.COUNT, s.RefTable().Name)
sql, vars := s.clause.Build(clause.COUNT, clause.WHERE)
row := s.Raw(sql, vars...).QueryRow()
var tmp int64
if err := row.Scan(&tmp); err != nil {
return 0, err
}
return tmp, nil
}
================================================
FILE: gee-orm/day7-migrate/session/record_test.go
================================================
package session
import "testing"
var (
user1 = &User{"Tom", 18}
user2 = &User{"Sam", 25}
user3 = &User{"Jack", 25}
)
func testRecordInit(t *testing.T) *Session {
t.Helper()
s := NewSession().Model(&User{})
err1 := s.DropTable()
err2 := s.CreateTable()
_, err3 := s.Insert(user1, user2)
if err1 != nil || err2 != nil || err3 != nil {
t.Fatal("failed init test records")
}
return s
}
func TestSession_Insert(t *testing.T) {
s := testRecordInit(t)
affected, err := s.Insert(user3)
if err != nil || affected != 1 {
t.Fatal("failed to create record")
}
}
func TestSession_Find(t *testing.T) {
s := testRecordInit(t)
var users []User
if err := s.Find(&users); err != nil || len(users) != 2 {
t.Fatal("failed to query all")
}
}
func TestSession_First(t *testing.T) {
s := testRecordInit(t)
u := &User{}
err := s.First(u)
if err != nil || u.Name != "Tom" || u.Age != 18 {
t.Fatal("failed to query first")
}
}
func TestSession_Limit(t *testing.T) {
s := testRecordInit(t)
var users []User
err := s.Limit(1).Find(&users)
if err != nil || len(users) != 1 {
t.Fatal("failed to query with limit condition")
}
}
func TestSession_Where(t *testing.T) {
s := testRecordInit(t)
var users []User
_, err1 := s.Insert(user3)
err2 := s.Where("Age = ?", 25).Find(&users)
if err1 != nil || err2 != nil || len(users) != 2 {
t.Fatal("failed to query with where condition")
}
}
func TestSession_OrderBy(t *testing.T) {
s := testRecordInit(t)
u := &User{}
err := s.OrderBy("Age DESC").First(u)
if err != nil || u.Age != 25 {
t.Fatal("failed to query with order by condition")
}
}
func TestSession_Update(t *testing.T) {
s := testRecordInit(t)
affected, _ := s.Where("Name = ?", "Tom").Update("Age", 30)
u := &User{}
_ = s.OrderBy("Age DESC").First(u)
if affected != 1 || u.Age != 30 {
t.Fatal("failed to update")
}
}
func TestSession_DeleteAndCount(t *testing.T) {
s := testRecordInit(t)
affected, _ := s.Where("Name = ?", "Tom").Delete()
count, _ := s.Count()
if affected != 1 || count != 1 {
t.Fatal("failed to delete or count")
}
}
================================================
FILE: gee-orm/day7-migrate/session/table.go
================================================
package session
import (
"fmt"
"geeorm/log"
"reflect"
"strings"
"geeorm/schema"
)
// Model assigns refTable
func (s *Session) Model(value interface{}) *Session {
// nil or different model, update refTable
if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) {
s.refTable = schema.Parse(value, s.dialect)
}
return s
}
// RefTable returns a Schema instance that contains all parsed fields
func (s *Session) RefTable() *schema.Schema {
if s.refTable == nil {
log.Error("Model is not set")
}
return s.refTable
}
// CreateTable creates a table in database with a model
func (s *Session) CreateTable() error {
table := s.RefTable()
var columns []string
for _, field := range table.Fields {
columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag))
}
desc := strings.Join(columns, ",")
_, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec()
return err
}
// DropTable drops a table with the name of model
func (s *Session) DropTable() error {
_, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec()
return err
}
// HasTable returns true if the table exists
func (s *Session) HasTable() bool {
sql, values := s.dialect.TableExistSQL(s.RefTable().Name)
row := s.Raw(sql, values...).QueryRow()
var tmp string
_ = row.Scan(&tmp)
return tmp == s.RefTable().Name
}
================================================
FILE: gee-orm/day7-migrate/session/table_test.go
================================================
package session
import (
"testing"
)
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func TestSession_CreateTable(t *testing.T) {
s := NewSession().Model(&User{})
_ = s.DropTable()
_ = s.CreateTable()
if !s.HasTable() {
t.Fatal("Failed to create table User")
}
}
func TestSession_Model(t *testing.T) {
s := NewSession().Model(&User{})
table := s.RefTable()
s.Model(&Session{})
if table.Name != "User" || s.RefTable().Name != "Session" {
t.Fatal("Failed to change model")
}
}
================================================
FILE: gee-orm/day7-migrate/session/transaction.go
================================================
package session
import "geeorm/log"
// Begin a transaction
func (s *Session) Begin() (err error) {
log.Info("transaction begin")
if s.tx, err = s.db.Begin(); err != nil {
log.Error(err)
return
}
return
}
// Commit a transaction
func (s *Session) Commit() (err error) {
log.Info("transaction commit")
if err = s.tx.Commit(); err != nil {
log.Error(err)
}
return
}
// Rollback a transaction
func (s *Session) Rollback() (err error) {
log.Info("transaction rollback")
if err = s.tx.Rollback(); err != nil {
log.Error(err)
}
return
}
================================================
FILE: gee-orm/doc/geeorm-day1.md
================================================
---
title: Build an ORM Framework by Hand - GeeORM Day 1 - database/sql Basics
date: 2020-03-07 23:00:00
description: A 7-day tutorial on implementing the ORM framework GeeORM from scratch in Go (golang), modeled on the implementations of gorm and xorm. This article covers basic SQLite operations (connecting to a database, creating tables, inserting and deleting records) and using the Go standard library database/sql to operate a SQLite database, including executing statements (Exec) and querying (Query, QueryRow).
tags:
- Go
nav: From Scratch
categories:
- ORM Framework - GeeORM
keywords:
- Go
- ORM framework from scratch
- database/sql
- sqlite
image: post/geeorm/geeorm_sm.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days from Scratch in Go Series
book_title: Day1 database/sql Basics
---
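This is the first article in [Implementing the ORM Framework GeeORM from Scratch in Go in 7 Days](https://geektutu.com/post/geeorm.html). It covers:
- Basic SQLite operations (connecting to a database, creating tables, inserting and deleting records).
- Connecting to and operating a SQLite database with the Go standard library database/sql, plus a thin wrapper around it. **About 150 lines of code.**
## 1 A First Look at SQLite
> SQLite is a C-language library that implements a small, fast, self-contained, high-reliability, full-featured, SQL database engine.
> -- [SQLite official website](https://sqlite.org/index.html)
SQLite is a lightweight relational database that follows the ACID transaction principles. It can be embedded directly into an application and, unlike MySQL or PostgreSQL, does not need a separate server process to be started first. SQLite stores its data in a single file on disk, which makes it very convenient to use and well suited to beginners learning relational databases. All development and testing of GeeORM is based on SQLite.
On Ubuntu, installing SQLite takes a single command, and it works without any configuration.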
```bash
apt-get install sqlite3
```
Next, connect to the database (gee.db). If gee.db does not exist, it is created. Once connected, you are in SQLite's command-line mode; run `.help` to list all available commands.
```bash
> sqlite3 gee.db
SQLite version 3.22.0 2018-01-22 18:45:57
Enter ".help" for usage hints.
sqlite>
```
Use a SQL statement to create a table `User` with two fields, a string `Name` and an integer `Age`.
```bash
sqlite> CREATE TABLE User(Name text, Age integer);
```
Insert two records:
```bash
sqlite> INSERT INTO User(Name, Age) VALUES ('Tom', 18), ('Jack', 25);
```
Run a few simple queries. Before running them, use `.head on` to turn on column headers so that the results are easier to read.
```bash
sqlite> .head on
# Find the records where `Age > 20`.
sqlite> SELECT * FROM User WHERE Age > 20;
Name|Age
Jack|25
# Count the records.
sqlite> SELECT COUNT(*) FROM User;
COUNT(*)
2
```
Use `.table` to list all tables in the current database, and run `.schema` followed by a table name to see the SQL statement that created that table.
```bash
sqlite> .table
User
sqlite> .schema User
CREATE TABLE User(Name text, Age integer);
```
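That is all we need to know about SQLite for now; the commands above are enough to complete today's task. For more usage, see [Common SQLite Commands](https://geektutu.com/post/cheat-sheet-sqlite.html).
## 2 The database/sql Standard Library
Go provides the standard library `database/sql` for interacting with databases. Let's write a demo to see how it is used.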
```go
package main
import (
"database/sql"
"log"
_ "github.com/mattn/go-sqlite3"
)
func main() {
db, _ := sql.Open("sqlite3", "gee.db")
defer func() { _ = db.Close() }()
_, _ = db.Exec("DROP TABLE IF EXISTS User;")
_, _ = db.Exec("CREATE TABLE User(Name text);")
result, err := db.Exec("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam")
if err == nil {
affected, _ := result.RowsAffected()
log.Println(affected)
}
row := db.QueryRow("SELECT Name FROM User LIMIT 1")
var name string
if err := row.Scan(&name); err == nil {
log.Println(name)
}
}
```
> go-sqlite3 depends on gcc. To run this code on Windows, install [mingw](http://mingw.org/) or another toolchain that includes the gcc compiler.
Run `go run .`; the output is as follows.
```bash
> go run .
2020/03/07 20:28:37 2
2020/03/07 20:28:37 Tom
```
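- Use `sql.Open()` to open the database. The first argument is the driver name; the import statement `_ "github.com/mattn/go-sqlite3"` registers the sqlite3 driver when the package is imported. The second argument is the name of the database, which for SQLite is the file name; the file is created if it does not exist. The call returns a pointer to a `sql.DB` instance.
- `Exec()` executes a SQL statement. If the statement is a query, `Exec()` does not return the matching records, so queries usually use `Query()` or `QueryRow()`: the former can return multiple records, the latter returns a single record.
- `Exec()`, `Query()` and `QueryRow()` accept one or more arguments. The first is the SQL statement; the rest are the values for the `?` placeholders in the statement. Placeholders are generally used to prevent SQL injection.
- `QueryRow()` returns a `*sql.Row`. `row.Scan()` takes one or more pointers as arguments and stores the values of the corresponding columns in them. In this example there is only one column, `Name`, so passing the string pointer `&name` is enough to receive the query result.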
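To make the point about placeholders concrete, here is a minimal sketch that does not touch a database at all. `unsafeQuery` and `safeQuery` are hypothetical helper names introduced only for this illustration: the first splices user input into the SQL text, the second keeps the SQL text fixed and hands the input over as a separate value, which is the shape that `Exec()`, `Query()` and `QueryRow()` expect.

```go
package main

import "fmt"

// unsafeQuery (hypothetical helper) builds the statement by splicing the
// input into the SQL text, so quotes inside the input change the statement.
func unsafeQuery(name string) string {
	return fmt.Sprintf("SELECT Name FROM User WHERE Name = '%s'", name)
}

// safeQuery (hypothetical helper) keeps the SQL text constant and returns
// the input as a separate value for the `?` placeholder.
func safeQuery(name string) (string, []interface{}) {
	return "SELECT Name FROM User WHERE Name = ?", []interface{}{name}
}

func main() {
	input := "x' OR '1'='1"

	// The WHERE clause now contains an extra OR condition taken from the input.
	fmt.Println(unsafeQuery(input))

	// The statement text is unchanged; the input travels separately as data.
	query, args := safeQuery(input)
	fmt.Println(query, args)
}
```

With `database/sql`, the second form would be passed on as `db.QueryRow(query, args...)`, leaving it to the driver to bind the value.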
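With basic SQL statements and the Go standard library `database/sql` covered, we can start building the prototype of the ORM framework.
## 3 Implementing a Simple log Library
Developing a framework or library is not easy, and detailed logs help us locate problems quickly. So before writing the core code, let's implement a simple log library in a few dozen lines.
> Why not use the standard log library directly? The standard `log` package has no log levels and does not print the file name and line number by default, which makes it hard to find out quickly where an error occurred.
This simple log library has the following features:
- Supports log levels (three levels: Info, Error, Disabled).
- Uses different colors to distinguish logs of different levels.
- Prints the file name and line number of the code that wrote the log.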
```bash
go mod init geeorm
```
First create a module named geeorm, then create the file log/log.go to hold the logging code. GeeORM now looks like this:
```bash
day1-database-sql/
|--log/
|--log.go
|--go.mod
```
Step one: create two logger instances, one for Info logs and one for Error logs.
[day1-database-sql/log/log.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day1-database-sql/log)
```go
package log
import (
"io/ioutil"
"log"
"os"
"sync"
)
var (
errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile)
infoLog = log.New(os.Stdout, "\033[34m[info ]\033[0m ", log.LstdFlags|log.Lshortfile)
loggers = []*log.Logger{errorLog, infoLog}
mu sync.Mutex
)
// log methods
var (
Error = errorLog.Println
Errorf = errorLog.Printf
Info = infoLog.Println
Infof = infoLog.Printf
)
```
- `[info ]` is printed in blue and `[error]` in red.
- `log.Lshortfile` adds the file name and line number to each log line.
- Four functions are exposed: `Error`, `Errorf`, `Info` and `Infof`.
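The effect of the prefix and of `log.Lshortfile` can be seen in isolation with the small sketch below. The helper name `render` is made up for this example, and it deviates from the loggers above in two ways: it writes to a buffer rather than `os.Stdout` so the line can be inspected, and it leaves out `log.LstdFlags` so the output does not depend on the current time.

```go
package main

import (
	"bytes"
	"fmt"
	"log"
)

// render (hypothetical helper) writes msg through a logger that uses the
// given prefix plus log.Lshortfile and returns the formatted line.
func render(prefix, msg string) string {
	var buf bytes.Buffer
	logger := log.New(&buf, prefix, log.Lshortfile)
	logger.Println(msg)
	return buf.String()
}

func main() {
	// \033[34m switches the terminal to blue, \033[0m resets the color.
	line := render("\033[34m[info ]\033[0m ", "hello")

	// %q shows the escape sequences literally instead of coloring the text.
	// The line consists of the prefix, "<file>:<line>: " and the message.
	fmt.Printf("%q\n", line)
}
```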
Step two: support setting the log level (InfoLevel, ErrorLevel, Disabled).
```go
// log levels
const (
InfoLevel = iota
ErrorLevel
Disabled
)
// SetLevel controls log level
func SetLevel(level int) {
mu.Lock()
defer mu.Unlock()
for _, logger := range loggers {
logger.SetOutput(os.Stdout)
}
if ErrorLevel < level {
errorLog.SetOutput(ioutil.Discard)
}
if InfoLevel < level {
infoLog.SetOutput(ioutil.Discard)
}
}
```
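- This part is very simple: the three levels are declared as three constants, and whether a log is printed is controlled through the logger's output.
- If the level is set to ErrorLevel, the output of infoLog is redirected to `ioutil.Discard`, so Info logs are no longer printed.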
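The following sketch replays the same switching logic on a buffer, so that the effect of each level is visible. It is an illustration under assumptions, not the package above: the `leveled` type, `newLeveled` and `demo` are names invented here, and the loggers carry no flags or colors to keep the output short.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"log"
)

// The same three levels as in the article.
const (
	InfoLevel = iota
	ErrorLevel
	Disabled
)

// leveled (hypothetical type) bundles two loggers that share one writer.
type leveled struct {
	out      io.Writer
	infoLog  *log.Logger
	errorLog *log.Logger
}

func newLeveled(out io.Writer) *leveled {
	return &leveled{
		out:      out,
		infoLog:  log.New(out, "[info ] ", 0),
		errorLog: log.New(out, "[error] ", 0),
	}
}

// SetLevel first re-enables both loggers, then silences every logger whose
// level is below the requested one by pointing it at ioutil.Discard.
func (l *leveled) SetLevel(level int) {
	l.infoLog.SetOutput(l.out)
	l.errorLog.SetOutput(l.out)
	if ErrorLevel < level {
		l.errorLog.SetOutput(ioutil.Discard)
	}
	if InfoLevel < level {
		l.infoLog.SetOutput(ioutil.Discard)
	}
}

// demo logs one Info and one Error message at the given level and returns
// whatever actually reached the shared writer.
func demo(level int) string {
	var buf bytes.Buffer
	l := newLeveled(&buf)
	l.SetLevel(level)
	l.infoLog.Println("connected")
	l.errorLog.Println("boom")
	return buf.String()
}

func main() {
	fmt.Printf("%q\n", demo(InfoLevel))  // both messages
	fmt.Printf("%q\n", demo(ErrorLevel)) // only the error message
	fmt.Printf("%q\n", demo(Disabled))   // nothing
}
```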
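With that, a simple log library with level support is complete.
## 4 The Core Struct Session
Create a new folder named session in the root directory to implement the interaction with the database. Today we only implement the part that executes raw SQL statements directly; this code lives in `session/raw.go`.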
[day1-database-sql/session/raw.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day1-database-sql/session)
```go
package session
import (
"database/sql"
"geeorm/log"
"strings"
)
type Session struct {
db *sql.DB
sql strings.Builder
sqlVars []interface{}
}
func New(db *sql.DB) *Session {
return &Session{db: db}
}
func (s *Session) Clear() {
s.sql.Reset()
s.sqlVars = nil
}
func (s *Session) DB() *sql.DB {
return s.db
}
func (s *Session) Raw(sql string, values ...interface{}) *Session {
s.sql.WriteString(sql)
s.sql.WriteString(" ")
s.sqlVars = append(s.sqlVars, values...)
return s
}
```
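- The Session struct currently has only three fields. The first is `db *sql.DB`, the pointer returned by `sql.Open()` once the database has been opened successfully.
- The second and third fields hold the SQL statement being assembled and the values for its placeholders. Calling `Raw()` changes these two fields.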
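How the two fields evolve across chained `Raw()` calls can be traced with a stripped-down sketch. The `rawBuilder` type below is an assumption made for this example: it copies only the `sql` and `sqlVars` fields and the `Raw()`/`Clear()` logic, and has no database behind it.

```go
package main

import (
	"fmt"
	"strings"
)

// rawBuilder (hypothetical type) mirrors only the sql and sqlVars fields
// of Session.
type rawBuilder struct {
	sql     strings.Builder
	sqlVars []interface{}
}

// Raw appends a SQL fragment followed by a space, and collects the values
// for the placeholders in that fragment.
func (b *rawBuilder) Raw(sql string, values ...interface{}) *rawBuilder {
	b.sql.WriteString(sql)
	b.sql.WriteString(" ")
	b.sqlVars = append(b.sqlVars, values...)
	return b
}

// Clear resets both fields so the builder can be reused.
func (b *rawBuilder) Clear() {
	b.sql.Reset()
	b.sqlVars = nil
}

func main() {
	b := &rawBuilder{}
	b.Raw("SELECT Name FROM User").Raw("WHERE Age > ?", 20).Raw("LIMIT ?", 1)
	// Each fragment is followed by one space, including the last one.
	fmt.Printf("%q %v\n", b.sql.String(), b.sqlVars)

	b.Clear()
	fmt.Printf("%q %v\n", b.sql.String(), b.sqlVars)
}
```

Returning the receiver from `Raw()` is what allows the calls to be chained; the trailing space after every fragment keeps consecutive fragments from running together.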
Next, wrap the three raw methods `Exec()`, `Query()` and `QueryRow()`.
```go
// Exec raw sql with sqlVars
func (s *Session) Exec() (result sql.Result, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
// QueryRow gets a record from db
func (s *Session) QueryRow() *sql.Row {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
return s.DB().QueryRow(s.sql.String(), s.sqlVars...)
}
// QueryRows gets a list of records from db
func (s *Session) QueryRows() (rows *sql.Rows, err error) {
defer s.Clear()
log.Info(s.sql.String(), s.sqlVars)
if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil {
log.Error(err)
}
return
}
```
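- The wrappers serve two purposes. The first is uniform logging, covering both the executed SQL statement and any error.
- The second is to clear `(s *Session).sql` and `(s *Session).sqlVars` after execution, so that a Session can be reused: one session can execute SQL many times.
## 5 Core Structure: Engine
Session handles the interaction with the database. The preparation beforehand (such as connecting to and testing the database) and the cleanup afterwards (closing the connection) are left to Engine. Engine is the entry point through which users interact with GeeORM. The code lives in `geeorm.go` in the root directory.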
[day1-database-sql/geeorm.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day1-database-sql)
```go
package geeorm
import (
"database/sql"
"geeorm/log"
"geeorm/session"
)
type Engine struct {
db *sql.DB
}
func NewEngine(driver, source string) (e *Engine, err error) {
db, err := sql.Open(driver, source)
if err != nil {
log.Error(err)
return
}
// Send a ping to make sure the database connection is alive.
if err = db.Ping(); err != nil {
log.Error(err)
return
}
e = &Engine{db: db}
log.Info("Connect database success")
return
}
func (engine *Engine) Close() {
if err := engine.db.Close(); err != nil {
log.Error("Failed to close database")
// Do not report success after a failed close.
return
}
log.Info("Close database success")
}
func (engine *Engine) NewSession() *session.Session {
return session.New(engine.db)
}
```
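The logic of Engine is simple. The most important function is `NewEngine`, which does two things.
- Connects to the database and returns a `*sql.DB`.
- Calls `db.Ping()` to check that the database can actually be reached.
In addition, Engine provides the `NewSession()` method, so that a session can be created from an `Engine` instance and used to interact with the database. At this point the skeleton of GeeORM has taken shape.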
```bash
day1-database-sql/
|--log/ # logging
|--log.go
|--session/ # database interaction
|--raw.go
|--geeorm.go # user-facing entry point
|--go.mod
```
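## 6 Testing
GeeORM's unit tests are fairly complete; see `log_test.go`, `raw_test.go` and `geeorm_test.go`. They are not walked through here. Next, we use geeorm as if it were a third-party library.
Create a `cmd_test` directory in the root directory for the test code, and add a file `main.go`.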
[day1-database-sql/cmd_test/main.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day1-database-sql/cmd_test)
```go
package main
import (
"fmt"
"geeorm"
_ "github.com/mattn/go-sqlite3"
)
func main() {
engine, _ := geeorm.NewEngine("sqlite3", "gee.db")
defer engine.Close()
s := engine.NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text);").Exec()
result, _ := s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec()
count, _ := result.RowsAffected()
fmt.Printf("Exec success, %d affected\n", count)
}
```
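Run `go run main.go` and check the log output.

The log contains one error line, *table User already exists*, because the main function executes the statement that creates table `User` twice. Each log line is marked with the file and line number that produced it, and different levels are printed in different colors.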
## Appendix: Recommended Reading
- [A Concise Go Tutorial](https://geektutu.com/post/quick-golang.html)
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)
- [SQLite Common Commands Cheat Sheet](https://geektutu.com/post/cheat-sheet-sqlite.html)
================================================
FILE: gee-orm/doc/geeorm-day2.md
================================================
---
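title: Write an ORM Framework by Hand - GeeORM Day 2 Mapping Objects to Table Schemas
date: 2020-03-08 00:20:00
description: A tutorial on implementing the ORM framework GeeORM from scratch in Go (golang) in 7 days (7 days implement golang object relational mapping framework from scratch tutorial), modeled on the implementations of gorm and xorm. Uses reflection (reflect) to obtain the name and fields of any struct and map them to a database table; uses a dialect to isolate the differences between databases so the framework is easy to extend; creates and drops database tables.
tags:
- Go
nav: From Scratch
categories:
- ORM Framework - GeeORM
keywords:
- Go language
- ORM framework from scratch
- database/sql
- sqlite
- reflect
- table mapping
image: post/geeorm/geeorm_sm.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days from Scratch in Go Series
book_title: Day2 Mapping Objects to Table Schemas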
---
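This is the second article in the series [Implementing the ORM Framework GeeORM from Scratch in Go in 7 Days](https://geektutu.com/post/geeorm.html).
- Use a dialect to isolate the differences between databases, making the framework easy to extend.
- Use reflection (reflect) to obtain the name and fields of any struct and map them to a database table.
- Create and drop database tables. **About 150 lines of code**
## 1 Dialect
Types in SQL differ from types in Go. For example, Go's `int`, `int8`, `int16` and similar types all correspond to the `integer` type in SQLite. The first step of an ORM mapping is therefore to decide how Go types map to database types.
Databases also differ in the data types they support, and even equivalent features may be written differently in SQL. An ORM framework usually has to support several databases, so we extract the parts that differ and implement them once per database, which maximizes reuse and decoupling. This part of the code is called the `dialect`.
Create a folder `dialect` in the root directory and add a file `dialect.go` to it, abstracting the parts that differ between databases.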
[day2-reflect-schema/dialect/dialect.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day2-reflect-schema/dialect)
```go
package dialect
import "reflect"
var dialectsMap = map[string]Dialect{}
type Dialect interface {
DataTypeOf(typ reflect.Value) string
TableExistSQL(tableName string) (string, []interface{})
}
func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
}
func GetDialect(name string) (dialect Dialect, ok bool) {
dialect, ok = dialectsMap[name]
return
}
```
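The `Dialect` interface has 2 methods:
- `DataTypeOf` converts a Go type to the corresponding data type of the database.
- `TableExistSQL` returns the SQL statement that checks whether a table exists; its parameter is the table name.
Of course, databases differ in far more than these two places. As the ORM framework grows, the dialect implementation will grow with it, while the rest of the framework stays unaffected.
We also declare two functions, `RegisterDialect` and `GetDialect`, to register and look up dialect instances. To add support for another database, call `RegisterDialect` to register it globally.
Next, create `sqlite3.go` in the `dialect` directory to add support for SQLite.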
[day2-reflect-schema/dialect/sqlite3.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day2-reflect-schema/dialect)
```go
package dialect
import (
"fmt"
"reflect"
"time"
)
type sqlite3 struct{}
var _ Dialect = (*sqlite3)(nil)
func init() {
RegisterDialect("sqlite3", &sqlite3{})
}
func (s *sqlite3) DataTypeOf(typ reflect.Value) string {
switch typ.Kind() {
case reflect.Bool:
return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer"
case reflect.Int64, reflect.Uint64:
return "bigint"
case reflect.Float32, reflect.Float64:
return "real"
case reflect.String:
return "text"
case reflect.Array, reflect.Slice:
return "blob"
case reflect.Struct:
if _, ok := typ.Interface().(time.Time); ok {
return "datetime"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind()))
}
func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
}
```
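- The implementation of `sqlite3.go` is somewhat tedious, but the overall logic is clear. `DataTypeOf` maps Go types to SQLite data types. `TableExistSQL` returns the SQL statement that checks whether table `tableName` exists in SQLite.
- The `init()` function registers the sqlite3 dialect globally when the package is first loaded.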
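## 2 Schema
Dialect handles the conversion of some database-specific SQL. Next comes the core conversion of an ORM framework: the conversion between objects and tables. Given an arbitrary object, convert it into a table schema of a relational database.
What is needed to create a table in a database?
- Table name: the struct name.
- Column names and column types: the fields and their types.
- Extra constraints (such as NOT NULL or PRIMARY KEY): the field tags (Go uses tags for this; languages such as Java and Python use annotations).
A concrete example: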
```go
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
```
The expected schema statement:
```sql
CREATE TABLE `User` (`Name` text PRIMARY KEY, `Age` integer);
```
This code is placed in a subpackage, in `schema/schema.go`.
[day2-reflect-schema/schema/schema.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day2-reflect-schema/schema)
```go
package schema
import (
"geeorm/dialect"
"go/ast"
"reflect"
)
// Field represents a column of database
type Field struct {
Name string
Type string
Tag string
}
// Schema represents a table of database
type Schema struct {
Model interface{}
Name string
Fields []*Field
FieldNames []string
fieldMap map[string]*Field
}
func (schema *Schema) GetField(name string) *Field {
return schema.fieldMap[name]
}
```
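- Field has 3 members: the column name Name, the type Type, and the constraint Tag.
- Schema mainly holds the mapped object Model, the table name Name and the columns Fields.
- FieldNames holds all column names, and fieldMap records the mapping from column name to Field, so later lookups do not need to iterate over Fields.
Next, implement the Parse function, which parses an arbitrary object into a Schema instance.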
```go
func Parse(dest interface{}, d dialect.Dialect) *Schema {
modelType := reflect.Indirect(reflect.ValueOf(dest)).Type()
schema := &Schema{
Model: dest,
Name: modelType.Name(),
fieldMap: make(map[string]*Field),
}
for i := 0; i < modelType.NumField(); i++ {
p := modelType.Field(i)
if !p.Anonymous && ast.IsExported(p.Name) {
field := &Field{
Name: p.Name,
Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))),
}
if v, ok := p.Tag.Lookup("geeorm"); ok {
field.Tag = v
}
schema.Fields = append(schema.Fields, field)
schema.FieldNames = append(schema.FieldNames, p.Name)
schema.fieldMap[p.Name] = field
}
}
return schema
}
```
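- `TypeOf()` and `ValueOf()` are the two most basic and most important functions of the reflect package; they return the type and the value of their argument. Because the input is designed to be a pointer to an object, `reflect.Indirect()` is needed to obtain the instance the pointer refers to.
- `modelType.Name()` returns the struct name, which is used as the table name.
- `NumField()` returns the number of fields of the struct; each field is then obtained by index with `p := modelType.Field(i)`.
- `p.Name` is the field name and `p.Type` is the field type, which `(Dialect).DataTypeOf()` converts to the database column type; `p.Tag` holds the extra constraints.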
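The reflection steps listed above can be run end to end in a small program. This is a sketch under assumptions: `kindToSQL` is a deliberately tiny stand-in for a real dialect, and `columns` and `createTableSQL` are hypothetical helpers, not part of GeeORM.

```go
package main

import (
	"fmt"
	"go/ast"
	"reflect"
	"strings"
)

type User struct {
	Name  string `geeorm:"PRIMARY KEY"`
	Age   int
	notes string // unexported, so it is skipped below
}

// kindToSQL is a deliberately tiny stand-in for a real Dialect.
var kindToSQL = map[reflect.Kind]string{
	reflect.String: "text",
	reflect.Int:    "integer",
}

// columns walks the fields of dest (a struct or a pointer to a struct) and
// returns the struct name plus one "name type [tag]" definition per
// exported, non-embedded field.
func columns(dest interface{}) (table string, defs []string) {
	t := reflect.Indirect(reflect.ValueOf(dest)).Type()
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.Anonymous || !ast.IsExported(f.Name) {
			continue
		}
		def := f.Name + " " + kindToSQL[f.Type.Kind()]
		if tag, ok := f.Tag.Lookup("geeorm"); ok {
			def += " " + tag
		}
		defs = append(defs, def)
	}
	return t.Name(), defs
}

// createTableSQL joins the column definitions into a CREATE TABLE statement.
func createTableSQL(dest interface{}) string {
	table, defs := columns(dest)
	return fmt.Sprintf("CREATE TABLE %s (%s);", table, strings.Join(defs, ", "))
}

func main() {
	fmt.Println(createTableSQL(&User{}))
	// prints: CREATE TABLE User (Name text PRIMARY KEY, Age integer);
}
```

Because `reflect.Indirect` is a no-op on non-pointers, `createTableSQL(User{})` and `createTableSQL(&User{})` produce the same statement.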
Write a test case to verify the Parse function.
```go
// schema_test.go
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
var TestDial, _ = dialect.GetDialect("sqlite3")
func TestParse(t *testing.T) {
schema := Parse(&User{}, TestDial)
if schema.Name != "User" || len(schema.Fields) != 2 {
t.Fatal("failed to parse User struct")
}
if schema.GetField("Name").Tag != "PRIMARY KEY" {
t.Fatal("failed to parse primary key")
}
}
```
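## 3 Session
The core job of Session is to interact with the database, so creating and dropping tables is implemented in the `session` subpackage. Before that, the Session struct needs a few adjustments.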
```go
type Session struct {
db *sql.DB
dialect dialect.Dialect
refTable *schema.Schema
sql strings.Builder
sqlVars []interface{}
}
func New(db *sql.DB, dialect dialect.Dialect) *Session {
return &Session{
db: db,
dialect: dialect,
}
}
```
- `Session` gains two new fields, dialect and refTable.
- The constructor `New` now takes 2 parameters, db and dialect.
Create `table.go` in the `session` folder to hold the code that operates on database tables.
[day2-reflect-schema/session/table.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day2-reflect-schema/session)
```go
func (s *Session) Model(value interface{}) *Session {
// nil or different model, update refTable
if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) {
s.refTable = schema.Parse(value, s.dialect)
}
return s
}
func (s *Session) RefTable() *schema.Schema {
if s.refTable == nil {
log.Error("Model is not set")
}
return s.refTable
}
```
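- `Model()` assigns refTable. Parsing is relatively expensive, so the result is stored in the field refTable; even if `Model()` is called many times, refTable is not updated as long as the type of the struct passed in does not change.
- `RefTable()` returns refTable, and logs an error if refTable has not been set.
Next, implement creating a table, dropping a table and checking whether a table exists. The three methods follow the same logic: use the table and column information returned by `RefTable()` to assemble a SQL statement, then execute it through the raw SQL interface.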
```go
func (s *Session) CreateTable() error {
table := s.RefTable()
var columns []string
for _, field := range table.Fields {
columns = append(columns, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag))
}
desc := strings.Join(columns, ",")
_, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec()
return err
}
func (s *Session) DropTable() error {
_, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec()
return err
}
func (s *Session) HasTable() bool {
sql, values := s.dialect.TableExistSQL(s.RefTable().Name)
row := s.Raw(sql, values...).QueryRow()
var tmp string
_ = row.Scan(&tmp)
return tmp == s.RefTable().Name
}
```
Implement the corresponding test case in `table_test.go`:
```go
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func TestSession_CreateTable(t *testing.T) {
s := NewSession().Model(&User{})
_ = s.DropTable()
_ = s.CreateTable()
if !s.HasTable() {
t.Fatal("Failed to create table User")
}
}
```
## 4 Engine
Because the Session constructor now depends on a dialect, Engine needs some minor adjustments.
[day2-reflect-schema/geeorm.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day2-reflect-schema)
```go
type Engine struct {
db *sql.DB
dialect dialect.Dialect
}
func NewEngine(driver, source string) (e *Engine, err error) {
db, err := sql.Open(driver, source)
if err != nil {
log.Error(err)
return
}
// Send a ping to make sure the database connection is alive.
if err = db.Ping(); err != nil {
log.Error(err)
return
}
// make sure the specific dialect exists
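dial, ok := dialect.GetDialect(driver)
if !ok {
// Set err so that callers never receive a nil Engine together with a nil error.
err = fmt.Errorf("dialect %s not found", driver)
log.Error(err)
return
}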
e = &Engine{db: db, dialect: dial}
log.Info("Connect database success")
return
}
func (engine *Engine) NewSession() *session.Session {
return session.New(engine.db, engine.dialect)
}
```
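- `NewEngine` looks up the dialect for the driver when creating the Engine instance.
- `NewSession` passes the dialect to the constructor New when creating a Session instance.
That completes day 2. To summarize today's results:
- 1) To support different databases, a Dialect layer maps data types and database-specific SQL statements, hiding the differences between databases.
- 2) Schema uses reflection (reflect) to map a struct to a table schema, including the table name, column names, column types and field tags.
- 3) SQL statements for create, drop and table-exists checks provide the basic operations on database tables.
## Appendix: Recommended Reading
- [A Concise Go Tutorial](https://geektutu.com/post/quick-golang.html)
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)
- [Go Reflect: Improving Reflection Performance](https://geektutu.com/post/hpg-reflect.html)
- [SQLite Common Commands Cheat Sheet](https://geektutu.com/post/cheat-sheet-sqlite.html)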
================================================
FILE: gee-orm/doc/geeorm-day3.md
================================================
---
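title: Write an ORM Framework by Hand - GeeORM Day 3 Inserting and Querying Records
date: 2020-03-08 01:00:00
description: A tutorial on implementing the ORM framework GeeORM from scratch in Go (golang) in 7 days (7 days implement golang object relational mapping framework from scratch tutorial), modeled on the implementations of gorm and xorm. Implements inserting records; uses reflection (reflect) to convert database records into instances of the corresponding struct, implementing queries (select).
tags:
- Go
nav: From Scratch
categories:
- ORM Framework - GeeORM
keywords:
- Go language
- ORM framework from scratch
- database/sql
- sqlite
- insert into
- select from
image: post/geeorm/geeorm_sm.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days from Scratch in Go Series
book_title: Day3 Inserting and Querying Records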
---
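This is the third article in the series [Implementing the ORM Framework GeeORM from Scratch in Go in 7 Days](https://geektutu.com/post/geeorm.html).
- Implement inserting records.
- Use reflection (reflect) to convert database records into instances of the corresponding struct, implementing queries (select). **About 150 lines of code**
## 1 Building SQL Statements with Clause
From day 3 on, GeeORM has to deal with more complex operations such as queries. A query statement usually consists of several clauses. A SELECT statement typically looks like this: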
```sql
SELECT col1, col2, ...
FROM table_name
WHERE [ conditions ]
GROUP BY col1
HAVING [ conditions ]
```
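In other words, constructing the complete SQL statement in one go is hard, so we split out the construction of SQL statements and implement it in the subpackage `clause`.
First, implement the generation rule for each clause in `clause/generator.go`.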
[day3-save-query/clause/generator.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day3-save-query/clause)
```go
package clause
import (
"fmt"
"strings"
)
type generator func(values ...interface{}) (string, []interface{})
var generators map[Type]generator
func init() {
generators = make(map[Type]generator)
generators[INSERT] = _insert
generators[VALUES] = _values
generators[SELECT] = _select
generators[LIMIT] = _limit
generators[WHERE] = _where
generators[ORDERBY] = _orderBy
}
func genBindVars(num int) string {
var vars []string
for i := 0; i < num; i++ {
vars = append(vars, "?")
}
return strings.Join(vars, ", ")
}
func _insert(values ...interface{}) (string, []interface{}) {
// INSERT INTO $tableName ($fields)
tableName := values[0]
fields := strings.Join(values[1].([]string), ",")
return fmt.Sprintf("INSERT INTO %s (%v)", tableName, fields), []interface{}{}
}
func _values(values ...interface{}) (string, []interface{}) {
// VALUES ($v1), ($v2), ...
var bindStr string
var sql strings.Builder
var vars []interface{}
sql.WriteString("VALUES ")
for i, value := range values {
v := value.([]interface{})
if bindStr == "" {
bindStr = genBindVars(len(v))
}
sql.WriteString(fmt.Sprintf("(%v)", bindStr))
if i+1 != len(values) {
sql.WriteString(", ")
}
vars = append(vars, v...)
}
return sql.String(), vars
}
func _select(values ...interface{}) (string, []interface{}) {
// SELECT $fields FROM $tableName
tableName := values[0]
fields := strings.Join(values[1].([]string), ",")
return fmt.Sprintf("SELECT %v FROM %s", fields, tableName), []interface{}{}
}
func _limit(values ...interface{}) (string, []interface{}) {
// LIMIT $num
return "LIMIT ?", values
}
func _where(values ...interface{}) (string, []interface{}) {
// WHERE $desc
desc, vars := values[0], values[1:]
return fmt.Sprintf("WHERE %s", desc), vars
}
func _orderBy(values ...interface{}) (string, []interface{}) {
return fmt.Sprintf("ORDER BY %s", values[0]), []interface{}{}
}
```
Then implement the struct `Clause` in `clause/clause.go`, which joins the individual clauses together.
[day3-save-query/clause/clause.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day3-save-query/clause)
```go
package clause
import "strings"
type Clause struct {
sql map[Type]string
sqlVars map[Type][]interface{}
}
type Type int
const (
INSERT Type = iota
VALUES
SELECT
LIMIT
WHERE
ORDERBY
)
func (c *Clause) Set(name Type, vars ...interface{}) {
if c.sql == nil {
c.sql = make(map[Type]string)
c.sqlVars = make(map[Type][]interface{})
}
sql, vars := generators[name](vars...)
c.sql[name] = sql
c.sqlVars[name] = vars
}
func (c *Clause) Build(orders ...Type) (string, []interface{}) {
var sqls []string
var vars []interface{}
for _, order := range orders {
if sql, ok := c.sql[order]; ok {
sqls = append(sqls, sql)
vars = append(vars, c.sqlVars[order]...)
}
}
return strings.Join(sqls, " "), vars
}
```
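- `Set` calls the generator for the given `Type` to produce the SQL of that clause.
- `Build` assembles the final SQL statement in the order of the `Type` values passed in.
Implement the corresponding test case in `clause_test.go`: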
```go
func testSelect(t *testing.T) {
var clause Clause
clause.Set(LIMIT, 3)
clause.Set(SELECT, "User", []string{"*"})
clause.Set(WHERE, "Name = ?", "Tom")
clause.Set(ORDERBY, "Age ASC")
sql, vars := clause.Build(SELECT, WHERE, ORDERBY, LIMIT)
t.Log(sql, vars)
if sql != "SELECT * FROM User WHERE Name = ? ORDER BY Age ASC LIMIT ?" {
t.Fatal("failed to build SQL")
}
if !reflect.DeepEqual(vars, []interface{}{"Tom", 3}) {
t.Fatal("failed to build SQLVars")
}
}
func TestClause_Build(t *testing.T) {
t.Run("select", func(t *testing.T) {
testSelect(t)
})
}
```
## 2 Implementing Insert
First, add the field clause to Session.
```go
// session/raw.go
type Session struct {
db *sql.DB
dialect dialect.Dialect
refTable *schema.Schema
clause clause.Clause
sql strings.Builder
sqlVars []interface{}
}
func (s *Session) Clear() {
s.sql.Reset()
s.sqlVars = nil
s.clause = clause.Clause{}
}
```
clause can already generate simple INSERT and SELECT statements, so the corresponding features can now be implemented in session.
The SQL for an INSERT usually looks like this:
```sql
INSERT INTO table_name(col1, col2, col3, ...) VALUES
(A1, A2, A3, ...),
(B1, B2, B3, ...),
...
```
In an ORM framework, the desired way to call Insert is:
```go
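// NewEngine returns (engine, error); error handling is omitted for brevity.
engine, _ := geeorm.NewEngine("sqlite3", "gee.db")
s := engine.NewSession()
u1 := &User{Name: "Tom", Age: 18}
u2 := &User{Name: "Sam", Age: 25}
s.Insert(u1, u2) // accepts any number of records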
```
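In other words, one more step is needed: following the order of the columns in the database, pick the corresponding values out of each object and lay them out flat. That is, `u1` and `u2` become `("Tom", 18), ("Sam", 25)`.
So before implementing Insert, `Schema` needs a new method `RecordValues` that performs this conversion.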
[day3-save-query/schema/schema.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day3-save-query/schema)
```go
func (schema *Schema) RecordValues(dest interface{}) []interface{} {
destValue := reflect.Indirect(reflect.ValueOf(dest))
var fieldValues []interface{}
for _, field := range schema.Fields {
fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface())
}
return fieldValues
}
```
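The flattening performed by `RecordValues` can be tried in isolation. The sketch below assumes a free function `recordValues` that receives the field names directly instead of reading them from a Schema; both the function name and that signature are illustrative.

```go
package main

import (
	"fmt"
	"reflect"
)

type User struct {
	Name string
	Age  int
}

// recordValues flattens dest (a struct or a pointer to a struct) into a
// slice holding the values of the named fields, in the given order.
func recordValues(dest interface{}, fieldNames []string) []interface{} {
	v := reflect.Indirect(reflect.ValueOf(dest))
	values := make([]interface{}, 0, len(fieldNames))
	for _, name := range fieldNames {
		values = append(values, v.FieldByName(name).Interface())
	}
	return values
}

func main() {
	names := []string{"Name", "Age"}
	fmt.Println(recordValues(&User{"Tom", 18}, names)) // prints: [Tom 18]
	fmt.Println(recordValues(&User{"Sam", 25}, names)) // prints: [Sam 25]
}
```

The order of `fieldNames` decides the order of the values, so the same slice of names has to be used for the column list of the INSERT statement.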
Create record.go in the session folder to hold the code for creating, deleting, querying and updating records.
[day3-save-query/session/record.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day3-save-query/session)
```go
package session
import (
"geeorm/clause"
"reflect"
)
func (s *Session) Insert(values ...interface{}) (int64, error) {
recordValues := make([]interface{}, 0)
for _, value := range values {
table := s.Model(value).RefTable()
s.clause.Set(clause.INSERT, table.Name, table.FieldNames)
recordValues = append(recordValues, table.RecordValues(value))
}
s.clause.Set(clause.VALUES, recordValues...)
sql, vars := s.clause.Build(clause.INSERT, clause.VALUES)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
return result.RowsAffected()
}
```
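All later SQL construction follows the same pattern as in `Insert`, in two steps:
- 1) Call `clause.Set()` several times to build each clause.
- 2) Call `clause.Build()` once to assemble the final SQL statement in the order passed in.
Once the statement is built, call `Raw().Exec()` to execute it.
## 3 Implementing Find
The desired call looks like this: pass in a pointer to a slice, and the query results are stored in that slice.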
```go
// NewEngine returns (engine, error); error handling is omitted for brevity.
engine, _ := geeorm.NewEngine("sqlite3", "gee.db")
s := engine.NewSession()
var users []User
s.Find(&users)
```
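The difficulty of Find is the mirror image of Insert. Insert has to flatten the field values of existing objects, while Find has to construct objects from flattened field values. This again requires reflection (reflect).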
```go
func (s *Session) Find(values interface{}) error {
destSlice := reflect.Indirect(reflect.ValueOf(values))
destType := destSlice.Type().Elem()
table := s.Model(reflect.New(destType).Elem().Interface()).RefTable()
s.clause.Set(clause.SELECT, table.Name, table.FieldNames)
sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT)
rows, err := s.Raw(sql, vars...).QueryRows()
if err != nil {
return err
}
for rows.Next() {
dest := reflect.New(destType).Elem()
var values []interface{}
for _, name := range table.FieldNames {
values = append(values, dest.FieldByName(name).Addr().Interface())
}
if err := rows.Scan(values...); err != nil {
return err
}
destSlice.Set(reflect.Append(destSlice, dest))
}
return rows.Close()
}
```
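The implementation of Find is more involved and consists of the following steps:
- 1) `destSlice.Type().Elem()` returns the element type `destType` of the slice; `reflect.New()` creates an instance of `destType`, which is passed to `Model()` to map out the table schema `RefTable()`.
- 2) Based on the table schema, clause builds the SELECT statement, and the query returns all matching records as `rows`.
- 3) For each row, reflection creates an instance `dest` of `destType`, and the addresses of all its fields are collected into the slice `values`.
- 4) `rows.Scan()` assigns the value of each column of the row to the corresponding field in `values`.
- 5) `dest` is appended to the slice `destSlice`. The loop continues until all records have been appended.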
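Steps 3 to 5 can be exercised without a database. In this sketch the rows are plain Go slices and `fakeScan` is a hypothetical stand-in for `(*sql.Rows).Scan`; it does none of the type conversion that `database/sql` performs, so each value must already have the type of its destination field.

```go
package main

import (
	"fmt"
	"reflect"
)

type User struct {
	Name string
	Age  int
}

// fakeScan stands in for (*sql.Rows).Scan: it copies one row of column
// values into the destination pointers, position by position. It panics
// if a value's type does not match its destination.
func fakeScan(row []interface{}, dests ...interface{}) error {
	if len(row) != len(dests) {
		return fmt.Errorf("expected %d destinations, got %d", len(row), len(dests))
	}
	for i, d := range dests {
		reflect.ValueOf(d).Elem().Set(reflect.ValueOf(row[i]))
	}
	return nil
}

// fill appends one element to the slice behind slicePtr for every row.
func fill(slicePtr interface{}, fieldNames []string, rows [][]interface{}) error {
	destSlice := reflect.Indirect(reflect.ValueOf(slicePtr))
	destType := destSlice.Type().Elem()
	for _, row := range rows {
		// Step 3: create a new element and collect the addresses of its fields.
		dest := reflect.New(destType).Elem()
		var ptrs []interface{}
		for _, name := range fieldNames {
			ptrs = append(ptrs, dest.FieldByName(name).Addr().Interface())
		}
		// Step 4: write the column values through those addresses.
		if err := fakeScan(row, ptrs...); err != nil {
			return err
		}
		// Step 5: append the filled element to the caller's slice.
		destSlice.Set(reflect.Append(destSlice, dest))
	}
	return nil
}

func main() {
	var users []User
	rows := [][]interface{}{{"Tom", 18}, {"Sam", 25}}
	if err := fill(&users, []string{"Name", "Age"}, rows); err != nil {
		fmt.Println("fill failed:", err)
		return
	}
	fmt.Printf("%+v\n", users)
}
```

`dest` comes from `reflect.New(destType).Elem()`, which makes it addressable; that is the reason `FieldByName(name).Addr()` is allowed here.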
## 4 Testing
Create `record_test.go` in the session folder and add the test cases.
> `User` and `NewSession()` are defined in raw_test.go.
[day3-save-query/session/record_test.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day3-save-query/session)
```go
package session
import "testing"
var (
user1 = &User{"Tom", 18}
user2 = &User{"Sam", 25}
user3 = &User{"Jack", 25}
)
func testRecordInit(t *testing.T) *Session {
t.Helper()
s := NewSession().Model(&User{})
err1 := s.DropTable()
err2 := s.CreateTable()
_, err3 := s.Insert(user1, user2)
if err1 != nil || err2 != nil || err3 != nil {
t.Fatal("failed init test records")
}
return s
}
func TestSession_Insert(t *testing.T) {
s := testRecordInit(t)
affected, err := s.Insert(user3)
if err != nil || affected != 1 {
t.Fatal("failed to create record")
}
}
func TestSession_Find(t *testing.T) {
s := testRecordInit(t)
var users []User
if err := s.Find(&users); err != nil || len(users) != 2 {
t.Fatal("failed to query all")
}
}
```
## Appendix: Recommended Reading
- [A Concise Go Tutorial](https://geektutu.com/post/quick-golang.html)
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)
- [SQLite Common Commands Cheat Sheet](https://geektutu.com/post/cheat-sheet-sqlite.html)
- [Laws Of Reflection - golang.org](https://blog.golang.org/laws-of-reflection)
================================================
FILE: gee-orm/doc/geeorm-day4.md
================================================
---
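title: Write an ORM Framework by Hand - GeeORM Day 4 Chained Operations, Update and Delete
date: 2020-03-08 16:00:00
description: A tutorial on implementing the ORM framework GeeORM from scratch in Go (golang) in 7 days (7 days implement golang object relational mapping framework from scratch tutorial), modeled on the implementations of gorm and xorm. Uses chained (chain) operations to combine query conditions (where, order by, limit, etc.); implements updating (update), deleting (delete) and counting (count) records.
tags:
- Go
nav: From Scratch
categories:
- ORM Framework - GeeORM
keywords:
- Go language
- ORM framework from scratch
- database/sql
- sqlite
- chain operation
- delete from
image: post/geeorm/geeorm_sm.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days from Scratch in Go Series
book_title: Day4 Chained Operations, Update and Delete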
---
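This is the fourth article in the series [Implementing the ORM Framework GeeORM from Scratch in Go in 7 Days](https://geektutu.com/post/geeorm.html).
- Use chained calls to combine query conditions (where, order by, limit, etc.).
- Implement updating (update), deleting (delete) and counting (count) records. **About 100 lines of code**
## 1 Supporting Update, Delete and Count
### 1.1 Clause Generators
clause is responsible for building SQL statements, so to support update, delete and count, the natural first step is to implement generators for the update, delete and count clauses in clause.
Step 1: add three new enumeration values of type `Type`: UPDATE, DELETE and COUNT.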
[day4-chain-operation/clause/clause.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day4-chain-operation/clause)
```go
// Support types for Clause
const (
INSERT Type = iota
VALUES
SELECT
LIMIT
WHERE
ORDERBY
UPDATE
DELETE
COUNT
)
```
Step 2: implement the generator for each clause and register it in the global variable `generators`.
[day4-chain-operation/clause/generator.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day4-chain-operation/clause)
```go
func init() {
generators = make(map[Type]generator)
generators[INSERT] = _insert
generators[VALUES] = _values
generators[SELECT] = _select
generators[LIMIT] = _limit
generators[WHERE] = _where
generators[ORDERBY] = _orderBy
generators[UPDATE] = _update
generators[DELETE] = _delete
generators[COUNT] = _count
}
func _update(values ...interface{}) (string, []interface{}) {
tableName := values[0]
m := values[1].(map[string]interface{})
var keys []string
var vars []interface{}
for k, v := range m {
keys = append(keys, k+" = ?")
vars = append(vars, v)
}
return fmt.Sprintf("UPDATE %s SET %s", tableName, strings.Join(keys, ", ")), vars
}
func _delete(values ...interface{}) (string, []interface{}) {
return fmt.Sprintf("DELETE FROM %s", values[0]), []interface{}{}
}
func _count(values ...interface{}) (string, []interface{}) {
return _select(values[0], []string{"count(*)"})
}
```
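- `_update` takes 2 arguments: the first is the table name, the second is a map holding the key-value pairs to update.
- `_delete` takes a single argument, the table name.
- `_count` takes a single argument, the table name, and reuses the `_select` generator.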
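One detail of a map-based generator is worth seeing in code: Go does not define the iteration order of a map, so the column order of the generated `SET` list can differ between runs. The statement stays correct because each key and its value are appended together, but for stable logs and string-comparing tests a variant that sorts the keys can help. The following is such a variant, offered as an assumption-laden sketch; `updateClause` is a hypothetical name and this is not the generator shown above.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// updateClause builds "UPDATE <table> SET k1 = ?, k2 = ?, ..." with the
// keys in sorted order, and returns the values in the matching order.
func updateClause(table string, kv map[string]interface{}) (string, []interface{}) {
	keys := make([]string, 0, len(kv))
	for k := range kv {
		keys = append(keys, k)
	}
	sort.Strings(keys) // fixed order instead of random map iteration order

	sets := make([]string, 0, len(keys))
	vars := make([]interface{}, 0, len(keys))
	for _, k := range keys {
		sets = append(sets, k+" = ?")
		vars = append(vars, kv[k])
	}
	return fmt.Sprintf("UPDATE %s SET %s", table, strings.Join(sets, ", ")), vars
}

func main() {
	sql, vars := updateClause("User", map[string]interface{}{"Name": "Tom", "Age": 30})
	fmt.Println(sql, vars)
	// prints: UPDATE User SET Age = ?, Name = ? [30 Tom]
}
```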
### 1.2 The Update Method
The clause generators are ready. As with Insert and Find, what remains is to assemble the SQL statement in a fixed order in `session/record.go` and execute it.
[day4-chain-operation/session/record.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day4-chain-operation/session)
```go
// support map[string]interface{}
// also support kv list: "Name", "Tom", "Age", 18, ....
func (s *Session) Update(kv ...interface{}) (int64, error) {
m, ok := kv[0].(map[string]interface{})
if !ok {
m = make(map[string]interface{})
for i := 0; i < len(kv); i += 2 {
m[kv[i].(string)] = kv[i+1]
}
}
s.clause.Set(clause.UPDATE, s.RefTable().Name, m)
sql, vars := s.clause.Build(clause.UPDATE, clause.WHERE)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
return result.RowsAffected()
}
```
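What is special about Update is that it accepts 2 kinds of arguments: a flat list of key-value pairs, or a map of key-value pairs. Because the generator takes a map, `Update` checks the type of its arguments at run time and converts them automatically if they are not a map.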
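The argument normalisation can be sketched as a standalone function. `toMap` is a hypothetical helper, and unlike the conversion described above it validates its input: an empty list, an odd number of elements or a non-string key yields an error rather than a run-time panic.

```go
package main

import (
	"errors"
	"fmt"
)

// toMap normalises Update-style arguments: either a single
// map[string]interface{} or a flat list "k1", v1, "k2", v2, ...
func toMap(kv ...interface{}) (map[string]interface{}, error) {
	if len(kv) == 0 {
		return nil, errors.New("no values to update")
	}
	// A map as the first argument is used as is.
	if m, ok := kv[0].(map[string]interface{}); ok {
		return m, nil
	}
	if len(kv)%2 != 0 {
		return nil, errors.New("key/value list must have an even length")
	}
	m := make(map[string]interface{}, len(kv)/2)
	for i := 0; i < len(kv); i += 2 {
		key, ok := kv[i].(string)
		if !ok {
			return nil, fmt.Errorf("key at position %d is not a string", i)
		}
		m[key] = kv[i+1]
	}
	return m, nil
}

func main() {
	fmt.Println(toMap("Name", "Tom", "Age", 18))
	fmt.Println(toMap(map[string]interface{}{"Age": 30}))
	fmt.Println(toMap("Name"))
}
```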
### 1.3 The Delete Method
```go
// Delete records with where clause
func (s *Session) Delete() (int64, error) {
s.clause.Set(clause.DELETE, s.RefTable().Name)
sql, vars := s.clause.Build(clause.DELETE, clause.WHERE)
result, err := s.Raw(sql, vars...).Exec()
if err != nil {
return 0, err
}
return result.RowsAffected()
}
```
### 1.4 The Count Method
```go
// Count records with where clause
func (s *Session) Count() (int64, error) {
s.clause.Set(clause.COUNT, s.RefTable().Name)
sql, vars := s.clause.Build(clause.COUNT, clause.WHERE)
row := s.Raw(sql, vars...).QueryRow()
var tmp int64
if err := row.Scan(&tmp); err != nil {
return 0, err
}
return tmp, nil
}
```
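## 2 Chained Calls (chain)
Chaining is a programming style that makes code more concise and readable. The principle is simple: after a method is called on an object, it returns a reference or pointer to that same object, so another method of the object can be called right away. In general, when an object needs several method calls in a row to set its properties, it is a good candidate for chaining.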
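The principle can be sketched with a small builder whose setters all return their receiver. The `query` type and its methods are hypothetical and exist only to illustrate chaining; they are not part of GeeORM.

```go
package main

import (
	"fmt"
	"strings"
)

// query is a hypothetical builder used only to illustrate chaining.
type query struct {
	table     string
	where     string
	whereVars []interface{}
	orderBy   string
	limit     int
}

// from starts a query on the given table.
func from(table string) *query { return &query{table: table} }

// Each setter stores its argument and returns the receiver, which is what
// makes q.Where(...).OrderBy(...).Limit(...) possible.
func (q *query) Where(desc string, args ...interface{}) *query {
	q.where, q.whereVars = desc, args
	return q
}

func (q *query) OrderBy(desc string) *query {
	q.orderBy = desc
	return q
}

func (q *query) Limit(n int) *query {
	q.limit = n
	return q
}

// Build assembles the clauses in a fixed order, regardless of the order in
// which the setters were called.
func (q *query) Build() (string, []interface{}) {
	parts := []string{"SELECT * FROM " + q.table}
	vars := append([]interface{}{}, q.whereVars...)
	if q.where != "" {
		parts = append(parts, "WHERE "+q.where)
	}
	if q.orderBy != "" {
		parts = append(parts, "ORDER BY "+q.orderBy)
	}
	if q.limit > 0 {
		parts = append(parts, "LIMIT ?")
		vars = append(vars, q.limit)
	}
	return strings.Join(parts, " "), vars
}

func main() {
	sql, vars := from("User").Limit(3).Where("Age > ?", 18).OrderBy("Age DESC").Build()
	fmt.Println(sql, vars)
	// prints: SELECT * FROM User WHERE Age > ? ORDER BY Age DESC LIMIT ? [18 3]
}
```

Because the setters only record state and the final method does the assembly, the caller is free to chain them in any order.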
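Building SQL statements fits this description well. A SQL statement consists of several clauses; a SELECT statement, for example, often needs a query condition (WHERE), a limit on the number of rows returned (LIMIT) and so on. The ideal call looks like this: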
```go
// NewEngine returns (engine, error); error handling is omitted for brevity.
engine, _ := geeorm.NewEngine("sqlite3", "gee.db")
s := engine.NewSession()
var users []User
s.Where("Age > 18").Limit(3).Find(&users)
```
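The example shows that query conditions such as `WHERE`, `LIMIT` and `ORDER BY` are well suited to chaining. The generators for these clauses were already implemented earlier, so all that is left is to add the corresponding methods in `session/record.go`.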
[day4-chain-operation/session/record.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day4-chain-operation/session)
```go
// Limit adds limit condition to clause
func (s *Session) Limit(num int) *Session {
s.clause.Set(clause.LIMIT, num)
return s
}
// Where adds where condition to clause
func (s *Session) Where(desc string, args ...interface{}) *Session {
var vars []interface{}
s.clause.Set(clause.WHERE, append(append(vars, desc), args...)...)
return s
}
// OrderBy adds order by condition to clause
func (s *Session) OrderBy(desc string) *Session {
s.clause.Set(clause.ORDERBY, desc)
return s
}
```
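## 3 First: Returning a Single Record
Often a SQL statement is expected to return exactly one record, for example when looking up a student by student ID. With chaining, the First method is easy to implement.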
```go
func (s *Session) First(value interface{}) error {
dest := reflect.Indirect(reflect.ValueOf(value))
destSlice := reflect.New(reflect.SliceOf(dest.Type())).Elem()
if err := s.Limit(1).Find(destSlice.Addr().Interface()); err != nil {
return err
}
if destSlice.Len() == 0 {
return errors.New("NOT FOUND")
}
dest.Set(destSlice.Index(0))
return nil
}
```
First can be used like this:
```go
u := &User{}
_ = s.OrderBy("Age DESC").First(u)
```
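> How it works: based on the type passed in, build a slice via reflection, call `Limit(1)` to restrict the number of rows returned, and call `Find` to obtain the result.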
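The `reflect.SliceOf` trick can be tried without a database. In this sketch `findAll` is a hypothetical stand-in for `Find` that reads from an in-memory `[]User`, so `first` only works for `*User` here; the reflection calls are the same ones used above.

```go
package main

import (
	"errors"
	"fmt"
	"reflect"
)

type User struct {
	Name string
	Age  int
}

// findAll stands in for Session.Find: it appends at most limit records
// from an in-memory table to the slice behind slicePtr.
func findAll(slicePtr interface{}, table []User, limit int) error {
	dest := reflect.ValueOf(slicePtr).Elem()
	for i, u := range table {
		if i >= limit {
			break
		}
		dest.Set(reflect.Append(dest, reflect.ValueOf(u)))
	}
	return nil
}

// first fills value (a pointer to a struct) with the first record.
func first(value interface{}, table []User) error {
	dest := reflect.Indirect(reflect.ValueOf(value))
	// Build an empty, addressable slice whose element type is the type of *value.
	destSlice := reflect.New(reflect.SliceOf(dest.Type())).Elem()
	if err := findAll(destSlice.Addr().Interface(), table, 1); err != nil {
		return err
	}
	if destSlice.Len() == 0 {
		return errors.New("NOT FOUND")
	}
	dest.Set(destSlice.Index(0))
	return nil
}

func main() {
	table := []User{{"Tom", 18}, {"Sam", 25}}
	u := &User{}
	fmt.Println(first(u, table), *u)
	fmt.Println(first(&User{}, nil))
}
```

`reflect.New(...)` returns a pointer to a new slice, and taking `.Elem()` of it gives an addressable slice value; that is why `destSlice.Addr()` can later hand a `*[]User` to the find function.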
## 4 Testing
Next, add a few test cases to `record_test.go` to check that everything works.
```go
package session
import "testing"
var (
user1 = &User{"Tom", 18}
user2 = &User{"Sam", 25}
user3 = &User{"Jack", 25}
)
func testRecordInit(t *testing.T) *Session {
t.Helper()
s := NewSession().Model(&User{})
err1 := s.DropTable()
err2 := s.CreateTable()
_, err3 := s.Insert(user1, user2)
if err1 != nil || err2 != nil || err3 != nil {
t.Fatal("failed init test records")
}
return s
}
func TestSession_Limit(t *testing.T) {
s := testRecordInit(t)
var users []User
err := s.Limit(1).Find(&users)
if err != nil || len(users) != 1 {
t.Fatal("failed to query with limit condition")
}
}
func TestSession_Update(t *testing.T) {
s := testRecordInit(t)
affected, _ := s.Where("Name = ?", "Tom").Update("Age", 30)
u := &User{}
_ = s.OrderBy("Age DESC").First(u)
if affected != 1 || u.Age != 30 {
t.Fatal("failed to update")
}
}
func TestSession_DeleteAndCount(t *testing.T) {
s := testRecordInit(t)
affected, _ := s.Where("Name = ?", "Tom").Delete()
count, _ := s.Count()
if affected != 1 || count != 1 {
t.Fatal("failed to delete or count")
}
}
```
## Appendix: Recommended Reading
- [A Concise Go Tutorial](https://geektutu.com/post/quick-golang.html)
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)
- [SQLite Common Commands Cheat Sheet](https://geektutu.com/post/cheat-sheet-sqlite.html)
================================================
FILE: gee-orm/doc/geeorm-day5.md
================================================
---
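title: Write an ORM Framework by Hand - GeeORM Day 5 Implementing Hooks
date: 2020-03-08 18:00:00
description: A tutorial on implementing the ORM framework GeeORM from scratch in Go (golang) in 7 days (7 days implement golang object relational mapping framework from scratch tutorial), modeled on the implementations of gorm and xorm. Uses reflection (reflect) to find the hooks bound to a struct and call them; supports calling hooks before and after create, read, update and delete (CRUD) operations.
tags:
- Go
nav: From Scratch
categories:
- ORM Framework - GeeORM
keywords:
- Go language
- ORM framework from scratch
- database/sql
- sqlite
- hooks
- BeforeUpdate
image: post/geeorm/geeorm_sm.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days from Scratch in Go Series
book_title: Day5 Implementing Hooks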
---
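This is the fifth article in the series [Implementing the ORM Framework GeeORM from Scratch in Go in 7 Days](https://geektutu.com/post/geeorm.html).
- Use reflection (reflect) to find the hooks bound to a struct and call them.
- Support calling hooks before and after create, read, update and delete (CRUD) operations. **About 50 lines of code**
## 1 The Hook Mechanism
The main idea of a hook is to plant extension points in advance wherever functionality might later be added; when the logic at such a point needs to be changed or extended, the extending type or method is attached there. Hooks are used very widely. For example, with the Travis continuous integration service supported by GitHub, a `git push` event triggers Travis to pull the new code and build it. Hooks are also common in IDEs, for instance formatting the code automatically when `Ctrl + s` is pressed. Another example is the `hot reload` mechanism common in front-end development: when the front-end code changes, it is automatically compiled and bundled and the browser is told to refresh the page, so that what you write is what you get.
Whether a hook mechanism is well designed depends on whether the extension points are well chosen. For continuous integration, rebuilding again and again while the code has not changed is pointless, so hooks should be placed where the code may change, such as before and after an MR or PR is merged.
Where are suitable extension points for an ORM framework? Clearly, before and after creating, deleting, querying and updating records.
For example, suppose we design an `Account` type with a sensitive field `Password` that must be masked after every query before the record can be used. If an `AfterQuery` hook is available that masks `Password` automatically after each query, a lot of redundant code can be saved.
## 2 Implementing Hooks
GeeORM's hooks are bound to structs, that is, each struct implements its own hooks. The hook code is in `session/hooks.go`.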
[day5-hooks/session/hooks.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day5-hooks/session)
```go
package session
import (
"geeorm/log"
"reflect"
)
// Hooks constants
const (
BeforeQuery = "BeforeQuery"
AfterQuery = "AfterQuery"
BeforeUpdate = "BeforeUpdate"
AfterUpdate = "AfterUpdate"
BeforeDelete = "BeforeDelete"
AfterDelete = "AfterDelete"
BeforeInsert = "BeforeInsert"
AfterInsert = "AfterInsert"
)
// CallMethod calls the registered hooks
func (s *Session) CallMethod(method string, value interface{}) {
fm := reflect.ValueOf(s.RefTable().Model).MethodByName(method)
if value != nil {
fm = reflect.ValueOf(value).MethodByName(method)
}
param := []reflect.Value{reflect.ValueOf(s)}
if fm.IsValid() {
if v := fm.Call(param); len(v) > 0 {
if err, ok := v[0].Interface().(error); ok {
log.Error(err)
}
}
}
return
}
```
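- The hook mechanism is also implemented with reflection. `s.RefTable().Model` or `value` is the object the current session is operating on, and `MethodByName` looks up the method of that object by name.
- The hook is called with `s *Session` as its argument. Every hook takes a `*Session` parameter.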
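The lookup-and-call step can be tried on its own. This sketch makes two simplifying assumptions: the hook takes no arguments (there is no Session here), and `callHook` is a hypothetical helper that reports whether the method was found instead of logging.

```go
package main

import (
	"fmt"
	"reflect"
)

type Account struct {
	ID       int
	Password string
}

// AfterQuery masks the password. It has a pointer receiver, so it is only
// found when the hook is looked up on a *Account.
func (a *Account) AfterQuery() error {
	a.Password = "******"
	return nil
}

// callHook looks up a method by name on value and calls it if it exists.
// It reports whether the method was found and any error it returned.
func callHook(value interface{}, method string) (called bool, err error) {
	fm := reflect.ValueOf(value).MethodByName(method)
	if !fm.IsValid() {
		return false, nil // no such hook: nothing to do
	}
	out := fm.Call(nil)
	if len(out) > 0 {
		if e, ok := out[0].Interface().(error); ok {
			err = e
		}
	}
	return true, err
}

func main() {
	acc := &Account{ID: 1, Password: "123456"}
	called, err := callHook(acc, "AfterQuery")
	fmt.Println(called, err, acc.Password)

	called, err = callHook(acc, "BeforeInsert")
	fmt.Println(called, err)
}
```

Because the hook is declared on the pointer type, passing an `Account` value instead of `&Account{...}` would make `MethodByName` return an invalid value and the hook would be skipped silently.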
Next, call `CallMethod()` inside the Find, Insert, Update and Delete methods. For example, `Find` becomes:
```go
// Find gets all eligible records
func (s *Session) Find(values interface{}) error {
s.CallMethod(BeforeQuery, nil)
// ...
for rows.Next() {
dest := reflect.New(destType).Elem()
// ...
s.CallMethod(AfterQuery, dest.Addr().Interface())
// ...
}
return rows.Close()
}
```
- The `AfterQuery` hook can operate on every row.
## 3 Testing
Create `session/hooks_test.go` and add the corresponding test case.
```go
package session
import (
"geeorm/log"
"testing"
)
type Account struct {
ID int `geeorm:"PRIMARY KEY"`
Password string
}
func (account *Account) BeforeInsert(s *Session) error {
log.Info("before insert", account)
account.ID += 1000
return nil
}
func (account *Account) AfterQuery(s *Session) error {
log.Info("after query", account)
account.Password = "******"
return nil
}
func TestSession_CallMethod(t *testing.T) {
s := NewSession().Model(&Account{})
_ = s.DropTable()
_ = s.CreateTable()
_, _ = s.Insert(&Account{1, "123456"}, &Account{2, "qwerty"})
u := &Account{}
err := s.First(u)
if err != nil || u.ID != 1001 || u.Password != "******" {
t.Fatal("Failed to call hooks after query, got", u)
}
}
```
This test case exercises two hooks, `BeforeInsert` and `AfterQuery`.
- `BeforeInsert` adds 1000 to `account.ID`.
- `AfterQuery` masks the password, showing it as six `*` characters.
## Appendix: Recommended Reading
- [A Concise Go Tutorial](https://geektutu.com/post/quick-golang.html)
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)
- [SQLite Common Commands Cheat Sheet](https://geektutu.com/post/cheat-sheet-sqlite.html)
================================================
FILE: gee-orm/doc/geeorm-day6.md
================================================
---
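title: Build an ORM Framework by Hand - GeeORM Day 6 Transaction Support
date: 2020-03-08 21:00:00
description: A tutorial on implementing the ORM framework GeeORM from scratch in 7 days with Go/golang (7 days implement golang object relational mapping framework from scratch tutorial), building an ORM framework by hand with reference to the implementations of gorm and xorm. Introduces database transactions, then wraps transactions so that a user-defined callback function runs as an atomic operation.
tags:
- Go
nav: From Scratch
categories:
- ORM Framework - GeeORM
keywords:
- Go language
- ORM framework from scratch
- database/sql
- sqlite
- transaction
image: post/geeorm/geeorm_sm.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days from Scratch in Go Series
book_title: Day6 Transaction Support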
---
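This is the sixth article in the series [Implementing the ORM Framework GeeORM from Scratch in 7 Days with Go](https://geektutu.com/post/geeorm.html).
- Introduce database transactions.
- Wrap transactions so that a user-defined callback function runs as an atomic operation. **About 100 lines of code.**
## 1 ACID Properties of Transactions
> A database transaction is a sequence of database operations that accesses and possibly modifies various data items. The operations are either all executed or none executed; they form an indivisible unit of work. A transaction consists of all database operations executed between the start and the end of the transaction.
Take a simple example: a bank transfer. When A transfers 10,000 yuan to B, the database has to perform at least two operations:
- 1) Subtract 10,000 yuan from A's account.
- 2) Add 10,000 yuan to B's account.
Either both operations are executed, meaning the transfer succeeded, or, if either one fails, everything done so far must be rolled back, meaning the transfer failed. An outcome in which one operation completes and the other fails is unacceptable. This is exactly the kind of scenario that database transactions are meant to handle.
A database that supports transactions must provide the four ACID properties.
- 1) Atomicity: the operations in a transaction are indivisible; either all of them complete or none of them takes effect.
- 2) Consistency: a transaction takes the database from one valid state to another; every integrity constraint that held before the transaction still holds after it.
- 3) Isolation: concurrently running transactions do not interfere with each other, and the intermediate results of a transaction are invisible to other transactions. Ideally, running several transactions concurrently gives the same result as running them serially in some order.
- 4) Durability: once a transaction has been committed, its changes to the database must not be lost, even if the database fails.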
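Atomicity can be illustrated without a database at all. The sketch below is a toy in-memory model, assuming a plain `map` of balances and a hypothetical `transfer` function; it only mimics the all-or-nothing behaviour and says nothing about isolation or durability.

```go
package main

import (
	"errors"
	"fmt"
)

// transfer moves amount from one account to another inside balances.
// Both steps are applied to a draft copy first; the draft replaces the
// original only if every step succeeds, so a failure leaves balances untouched.
func transfer(balances map[string]int, from, to string, amount int) error {
	draft := make(map[string]int, len(balances))
	for k, v := range balances {
		draft[k] = v
	}
	// Step 1: debit the sender.
	if draft[from] < amount {
		return errors.New("insufficient funds") // "rollback": the draft is discarded
	}
	draft[from] -= amount
	// Step 2: credit the receiver.
	if _, ok := draft[to]; !ok {
		return errors.New("unknown receiver") // "rollback": the debit is discarded too
	}
	draft[to] += amount
	// "Commit": publish all changes together.
	for k, v := range draft {
		balances[k] = v
	}
	return nil
}

func main() {
	balances := map[string]int{"A": 20000, "B": 0}
	fmt.Println(transfer(balances, "A", "B", 10000), balances) // succeeds
	fmt.Println(transfer(balances, "A", "C", 10000), balances) // fails, balances unchanged
}
```

A real database achieves the same effect with journaling or write-ahead logging rather than by copying data, but the contract seen by the caller is the same.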
## 2 Transactions in SQLite and the Go Standard Library
What does the raw SQL for a transaction look like in SQLite?
```sql
sqlite> BEGIN;
sqlite> DELETE FROM User WHERE Age > 25;
sqlite> INSERT INTO User VALUES ("Tom", 25), ("Jack", 18);
sqlite> COMMIT;
```
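`BEGIN` starts a transaction, `COMMIT` commits it, and `ROLLBACK` rolls it back. Every transaction starts with `BEGIN` and ends with `COMMIT` or `ROLLBACK`.
The Go standard library package `database/sql` provides an API for transactions. A simple example shows how the standard library supports them.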
```go
package main
import (
"database/sql"
_ "github.com/mattn/go-sqlite3"
"log"
)
func main() {
db, _ := sql.Open("sqlite3", "gee.db")
defer func() { _ = db.Close() }()
_, _ = db.Exec("CREATE TABLE IF NOT EXISTS User(`Name` text);")
tx, _ := db.Begin()
_, err1 := tx.Exec("INSERT INTO User(`Name`) VALUES (?)", "Tom")
_, err2 := tx.Exec("INSERT INTO User(`Name`) VALUES (?)", "Jack")
if err1 != nil || err2 != nil {
_ = tx.Rollback()
log.Println("Rollback", err1, err2)
} else {
_ = tx.Commit()
log.Println("Commit")
}
}
```
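Transactions in Go closely mirror the raw SQL statements. Call `db.Begin()` to get a `*sql.Tx` object and use `tx.Exec()` to run a series of operations. If an error occurs, roll back with `tx.Rollback()`; otherwise commit with `tx.Commit()`.
## 3 Transaction Support in GeeORM
Until now, every GeeORM operation was committed automatically as soon as it was executed, and operations were independent of each other. SQL statements were executed directly through the `sql.DB` object; to support transactions they must be executed through `sql.Tx` instead. We add a member `tx *sql.Tx` to the `Session` struct: when `tx` is not nil, SQL statements are executed through `tx`, otherwise through `db`. This stays compatible with the existing execution path while adding transaction support.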
[day6-transaction/session/raw.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day6-transaction/session)
```go
type Session struct {
db *sql.DB
dialect dialect.Dialect
tx *sql.Tx
refTable *schema.Schema
clause clause.Clause
sql strings.Builder
sqlVars []interface{}
}
// CommonDB is a minimal function set of db
type CommonDB interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
Exec(query string, args ...interface{}) (sql.Result, error)
}
var _ CommonDB = (*sql.DB)(nil)
var _ CommonDB = (*sql.Tx)(nil)
// DB returns tx if a tx begins. otherwise return *sql.DB
func (s *Session) DB() CommonDB {
if s.tx != nil {
return s.tx
}
return s.db
}
```
Create a new file `session/transaction.go` that wraps the three transaction operations: Begin, Commit and Rollback.
[day6-transaction/session/transaction.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day6-transaction/session)
```go
package session
import "geeorm/log"
func (s *Session) Begin() (err error) {
log.Info("transaction begin")
if s.tx, err = s.db.Begin(); err != nil {
log.Error(err)
return
}
return
}
func (s *Session) Commit() (err error) {
log.Info("transaction commit")
if err = s.tx.Commit(); err != nil {
log.Error(err)
}
return
}
func (s *Session) Rollback() (err error) {
log.Info("transaction rollback")
if err = s.tx.Rollback(); err != nil {
log.Error(err)
}
return
}
```
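- Call `s.db.Begin()` to get a `*sql.Tx` object and assign it to `s.tx`.
- Another purpose of the wrappers is to log uniformly, which makes problems easier to locate.
The final step is to offer users a foolproof, one-call interface in `geeorm.go`.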
[day6-transaction/geeorm.go](https://github.com/geektutu/7days-golang/tree/master/gee-orm/day6-transaction)
```go
type TxFunc func(*session.Session) (interface{}, error)
func (engine *Engine) Transaction(f TxFunc) (result interface{}, err error) {
s := engine.NewSession()
if err := s.Begin(); err != nil {
return nil, err
}
defer func() {
if p := recover(); p != nil {
_ = s.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
_ = s.Rollback() // err is non-nil; don't change it
} else {
err = s.Commit() // err is nil; if Commit returns error update err
}
}()
return f(s)
}
```
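> The implementation of Transaction is based on [stackoverflow](https://stackoverflow.com/questions/16184238/database-sql-tx-detecting-commit-or-rollback)
Users only need to put all operations into a callback function and pass it to `engine.Transaction()`. If any error occurs, the transaction is rolled back automatically; if no error occurs, it is committed.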
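The key to this pattern is that the deferred function can read and overwrite the named result `err` after the callback has returned. The sketch below reproduces the same control flow without a database, assuming a hypothetical `txn` type that only records whether it was committed or rolled back.

```go
package main

import (
	"errors"
	"fmt"
)

// txn is a hypothetical stand-in for a session; it only records its final state.
type txn struct{ state string }

func (t *txn) Commit() error   { t.state = "committed"; return nil }
func (t *txn) Rollback() error { t.state = "rolled back"; return nil }

// errBoom is a sample error used by the demo callback.
var errBoom = errors.New("boom")

// runInTx has the same shape as Transaction above: the deferred function
// inspects the named result err (and any panic) to decide how to finish.
func runInTx(t *txn, f func(*txn) (interface{}, error)) (result interface{}, err error) {
	defer func() {
		if p := recover(); p != nil {
			_ = t.Rollback()
			panic(p) // re-throw after rolling back
		} else if err != nil {
			_ = t.Rollback() // keep the callback's error
		} else {
			err = t.Commit() // report a failed commit to the caller
		}
	}()
	return f(t)
}

func main() {
	ok, bad := &txn{}, &txn{}
	_, _ = runInTx(ok, func(*txn) (interface{}, error) { return 1, nil })
	_, err := runInTx(bad, func(*txn) (interface{}, error) { return nil, errBoom })
	fmt.Println(ok.state, "/", bad.state, "/", err)
}
```

Without named results the deferred function could not turn a failed `Commit` into the function's return value, which is why the signature names both `result` and `err`.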
## 4 Testing
Add test cases to `geeorm_test.go` to see how Transaction works.
```go
func OpenDB(t *testing.T) *Engine {
t.Helper()
engine, err := NewEngine("sqlite3", "gee.db")
if err != nil {
t.Fatal("failed to connect", err)
}
return engine
}
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func TestEngine_Transaction(t *testing.T) {
t.Run("rollback", func(t *testing.T) {
transactionRollback(t)
})
t.Run("commit", func(t *testing.T) {
transactionCommit(t)
})
}
```
First, the rollback case:
```go
func transactionRollback(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
s := engine.NewSession()
_ = s.Model(&User{}).DropTable()
_, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) {
_ = s.Model(&User{}).CreateTable()
_, err = s.Insert(&User{"Tom", 18})
return nil, errors.New("Error")
})
if err == nil || s.HasTable() {
t.Fatal("failed to rollback")
}
}
```
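- If this case ran to completion successfully, it would create a table `User` and insert one record.
- It deliberately returns a custom error, so the transaction is rolled back and the table is not created.
Next, the commit case: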
```go
func transactionCommit(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
s := engine.NewSession()
_ = s.Model(&User{}).DropTable()
_, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) {
_ = s.Model(&User{}).CreateTable()
_, err = s.Insert(&User{"Tom", 18})
return
})
u := &User{}
_ = s.First(u)
if err != nil || u.Name != "Tom" {
t.Fatal("failed to commit")
}
}
```
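- Creating the table and inserting the record both succeed, and the inserted record is then found with `s.First()`.
## Appendix: Recommended Reading
- [A Concise Go Tutorial](https://geektutu.com/post/quick-golang.html)
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)
- [SQLite Common Commands Cheat Sheet](https://geektutu.com/post/cheat-sheet-sqlite.html)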
================================================
FILE: gee-orm/doc/geeorm-day7.md
================================================
---
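title: Build an ORM Framework by Hand - GeeORM Day 7 Database Migration (Migrate)
date: 2020-03-08 23:00:00
description: A tutorial on implementing the ORM framework GeeORM from scratch in 7 days with Go/golang (7 days implement golang object relational mapping framework from scratch tutorial), building an ORM framework by hand with reference to the implementations of gorm and xorm. When a struct changes, the fields of the database table are migrated automatically; only adding and dropping fields is supported, changing field types is not.
tags:
- Go
nav: From Scratch
categories:
- ORM Framework - GeeORM
keywords:
- Go language
- ORM framework from scratch
- database/sql
- sqlite
- migrate
image: post/geeorm/geeorm_sm.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days from Scratch in Go Series
book_title: Day7 Database Migration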
---
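This is the seventh article in the series [Implementing the ORM Framework GeeORM from Scratch in 7 Days with Go](https://geektutu.com/post/geeorm.html).
- When a struct changes, the fields of the database table are migrated automatically.
- Only adding and dropping fields is supported; changing field types is not. **About 70 lines of code.**
## 1 Migrating with SQL Statements
Database migration has always been one of the biggest headaches for database operations staff. Adding or dropping columns on a single table is fairly easy, but once complex relationships such as foreign keys are involved, migration becomes very difficult.
GeeORM's Migrate only targets the simplest scenario: it supports adding and dropping fields, but not changing field types.
Before implementing Migrate, let's see how to add and drop fields with raw SQL.
### 1.1 Adding a Field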
```sql
ALTER TABLE table_name ADD COLUMN col_name col_type;
```
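Most databases support adding or renaming a column with the `ALTER` keyword.
### 1.2 Dropping a Field
> See [sqlite delete or add column - stackoverflow](https://stackoverflow.com/questions/8442147/how-to-delete-or-add-column-in-sqlite)
For SQLite, dropping a field is not as easy as adding one (versions before 3.35.0 have no `ALTER TABLE ... DROP COLUMN`). A workable approach is to run the following steps: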
```sql
CREATE TABLE new_table AS SELECT col1, col2, ... FROM old_table;
DROP TABLE old_table;
ALTER TABLE new_table RENAME TO old_table;
```
- Step 1: copy the fields to keep from `old_table` into `new_table`.
- Step 2: drop `old_table`.
- Step 3: rename `new_table` to `old_table`.
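The three steps translate directly into generated SQL. The following is a minimal sketch, assuming a hypothetical helper `dropColumnSQL` and a `tmp_` prefix for the temporary table; it only builds the statements and does not execute them.

```go
package main

import (
	"fmt"
	"strings"
)

// dropColumnSQL (hypothetical helper) returns the three statements that
// rebuild table with only the columns in keep. Identifiers are interpolated
// as-is, so they must come from trusted code, never from user input.
func dropColumnSQL(table string, keep []string) []string {
	tmp := "tmp_" + table // temporary table name, an arbitrary choice
	cols := strings.Join(keep, ", ")
	return []string{
		fmt.Sprintf("CREATE TABLE %s AS SELECT %s FROM %s;", tmp, cols, table),
		fmt.Sprintf("DROP TABLE %s;", table),
		fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", tmp, table),
	}
}

func main() {
	for _, stmt := range dropColumnSQL("User", []string{"Name", "Age"}) {
		fmt.Println(stmt)
	}
}
```

Note that `CREATE TABLE ... AS SELECT` copies only column names and data: constraints such as `PRIMARY KEY` on the old table are not carried over to the new one.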
## 2 Implementing Migrate in GeeORM
Following the raw SQL commands above and using the transaction support implemented earlier, implement the Migrate method in `geeorm.go`.
```go
// difference returns a - b
func difference(a []string, b []string) (diff []string) {
mapB := make(map[string]bool)
for _, v := range b {
mapB[v] = true
}
for _, v := range a {
if _, ok := mapB[v]; !ok {
diff = append(diff, v)
}
}
return
}
// Migrate table
func (engine *Engine) Migrate(value interface{}) error {
_, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) {
if !s.Model(value).HasTable() {
log.Infof("table %s doesn't exist", s.RefTable().Name)
return nil, s.CreateTable()
}
table := s.RefTable()
rows, _ := s.Raw(fmt.Sprintf("SELECT * FROM %s LIMIT 1", table.Name)).QueryRows()
columns, _ := rows.Columns()
addCols := difference(table.FieldNames, columns)
delCols := difference(columns, table.FieldNames)
log.Infof("added cols %v, deleted cols %v", addCols, delCols)
for _, col := range addCols {
f := table.GetField(col)
sqlStr := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s;", table.Name, f.Name, f.Type)
if _, err = s.Raw(sqlStr).Exec(); err != nil {
return
}
}
if len(delCols) == 0 {
return
}
tmp := "tmp_" + table.Name
fieldStr := strings.Join(table.FieldNames, ", ")
s.Raw(fmt.Sprintf("CREATE TABLE %s AS SELECT %s from %s;", tmp, fieldStr, table.Name))
s.Raw(fmt.Sprintf("DROP TABLE %s;", table.Name))
s.Raw(fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", tmp, table.Name))
_, err = s.Exec()
return
})
return err
}
```
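- `difference` computes the set difference of two field slices: new table - old table = added fields, old table - new table = dropped fields.
- Fields are added with `ALTER` statements.
- Fields are dropped by creating a new table and renaming it.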
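To make the two set differences concrete, here is a small self-contained run with the same logic as `difference`, applied to an old table with columns `Name` and `XXX` and a new struct with fields `Name` and `Age` (the variable names are illustrative).

```go
package main

import "fmt"

// difference returns the elements of a that are not in b (a - b).
func difference(a, b []string) (diff []string) {
	inB := make(map[string]bool, len(b))
	for _, v := range b {
		inB[v] = true
	}
	for _, v := range a {
		if !inB[v] {
			diff = append(diff, v)
		}
	}
	return
}

func main() {
	structFields := []string{"Name", "Age"} // fields of the new struct
	tableColumns := []string{"Name", "XXX"} // columns of the existing table
	fmt.Println("add:", difference(structFields, tableColumns)) // add: [Age]
	fmt.Println("del:", difference(tableColumns, structFields)) // del: [XXX]
}
```

The comparison is by exact column name, so a renamed field shows up as one dropped column plus one added column, and the data in the old column is lost.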
## 3 Testing
Add a test case for Migrate to `geeorm_test.go`:
```go
type User struct {
Name string `geeorm:"PRIMARY KEY"`
Age int
}
func TestEngine_Migrate(t *testing.T) {
engine := OpenDB(t)
defer engine.Close()
s := engine.NewSession()
_, _ = s.Raw("DROP TABLE IF EXISTS User;").Exec()
_, _ = s.Raw("CREATE TABLE User(Name text PRIMARY KEY, XXX integer);").Exec()
_, _ = s.Raw("INSERT INTO User(`Name`) values (?), (?)", "Tom", "Sam").Exec()
engine.Migrate(&User{})
rows, _ := s.Raw("SELECT * FROM User").QueryRows()
columns, _ := rows.Columns()
if !reflect.DeepEqual(columns, []string{"Name", "Age"}) {
t.Fatal("Failed to migrate table User, got columns", columns)
}
}
```
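- First, assume the original `User` table has two fields, `Name` and `XXX`, and that after a business change the fields of the `User` struct become `Name` and `Age`.
- So the original field `XXX` must be dropped and the field `Age` added.
- After calling `Migrate(&User{})`, the new table has the columns `Name` and `Age`.
## 4 Summary
GeeORM's overall implementation is fairly rough: database migration, for example, only considers the simplest scenario. It also implements few features: nested structs, foreign keys and composite primary keys are not covered. ORM frameworks are generally large for several reasons. First, staying as close as possible to what the database offers takes a lot of code for each feature. Second, databases differ considerably, and the more features a framework implements, the more these differences show; sometimes good performance requires special handling for each database. Finally, some ORM frameworks support both relational and non-relational databases, which demands a higher level of abstraction that is not confined to the SQL layer.
With only about 800 lines of code, GeeORM cannot do all of this. However, the goal of GeeORM is not to be a production-ready ORM framework but to explain as much as possible of how an ORM framework works in principle, for example:
- how a framework hides the differences between databases;
- how table structures in the database map to objects in the programming language;
- how to model query conditions elegantly, where chained calls are a good choice;
- why ORM frameworks usually offer hooks as an extension mechanism;
- how transactions work and how an ORM framework integrates support for them;
- some hard problems, such as database migration.
- ...
On these points, I think GeeORM has achieved its goal.
## Appendix: Recommended Reading
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)
- [SQLite Common Commands Cheat Sheet](https://geektutu.com/post/cheat-sheet-sqlite.html)
- [sqlite delete or add column - stackoverflow](https://stackoverflow.com/questions/8442147/how-to-delete-or-add-column-in-sqlite)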
================================================
FILE: gee-orm/doc/geeorm.md
================================================
---
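title: Implementing the ORM Framework GeeORM from Scratch in 7 Days with Go
date: 2020-03-01 01:00:00
description: A tutorial on implementing the ORM framework GeeORM from scratch in 7 days with Go/golang (7 days implement golang object relational mapping framework from scratch tutorial), building an ORM framework by hand with reference to the implementations of gorm and xorm. Features include mapping between objects and table structures, creating and dropping tables, inserting, deleting, querying and updating records, transaction support, database migration (migrate), hooks and more.
tags:
- Go
nav: From Scratch
categories:
- ORM Framework - GeeORM
keywords:
- Go language
- ORM framework from scratch
- build an ORM framework by hand
- database/sql
- sqlite3
image: post/geeorm/geeorm_sm.jpg
github: https://github.com/geektutu/7days-golang
book: 7 Days from Scratch in Go Series
book_title: Day0 Preface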
---

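## 1 About ORM Frameworks
> Object Relational Mapping (ORM) automatically persists the objects of an object-oriented program into a relational database, using metadata that describes the mapping between objects and the database.
How are objects and the database mapped to each other?
| Database | Object-oriented programming language |
|:---:|:---:|
| table | class/struct |
| record, row | object |
| field, column | attribute |
Let's look at a concrete example to understand ORM.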
```sql
CREATE TABLE `User` (`Name` text, `Age` integer);
INSERT INTO `User` (`Name`, `Age`) VALUES ("Tom", 18);
SELECT * FROM `User`;
```
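The first SQL statement creates the table `User` in the database with two fields, `Name` and `Age`; the second inserts a record into the table; the last returns all records in the table.
With an ORM framework, we could write this instead: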
```go
type User struct {
Name string
Age int
}
orm.CreateTable(&User{})
orm.Save(&User{"Tom", 18})
var users []User
orm.Find(&users)
```
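An ORM framework acts as a bridge between objects and the database. With an ORM, you avoid writing tedious SQL and operate on a relational database simply by working with concrete objects.
So how do we implement an ORM framework?
- `CreateTable` needs to derive the struct name `User` from the argument `&User{}` and use it as the table name, use the members `Name` and `Age` as column names, and also know the type of each member.
- `Save` needs to know the value of each member.
- `Find` only has the empty slice passed in as `&[]User` to work with: from it, it derives the struct name, i.e. the table name `User`, fetches all records from the database, converts them into `User` objects and appends them to the slice.
If these methods only accepted arguments of type `User`, they would be easy to implement. But an ORM framework is generic: it must be able to convert any valid object into tables and records in the database. For example: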
```go
type Account struct {
Username string
Password string
}
orm.CreateTable(&Account{})
```
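This raises an important question: given a pointer of arbitrary type, how do we obtain information about the struct it points to? This is where Go's reflection mechanism (`reflect`) comes in. Through reflection we can get the struct name, members, methods and other information of an object, for example: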
```go
typ := reflect.Indirect(reflect.ValueOf(&Account{})).Type()
fmt.Println(typ.Name()) // Account
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
fmt.Println(field.Name) // Username Password
}
```
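- `reflect.ValueOf()` returns the reflection value of the pointer.
- `reflect.Indirect()` returns the reflection value of the object the pointer points to.
- `(reflect.Type).Name()` returns the type name as a string.
- `(reflect.Type).Field(i)` returns the i-th member.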
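Besides mapping objects to table structures and records, what else does the design of an ORM framework need to consider?
1) The SQL dialects of MySQL, PostgreSQL, SQLite and other databases differ. How can an ORM framework adapt to multiple databases without the developer noticing?
2) If the fields of an object change, can the table structure be updated automatically, i.e. is automatic database migration (migrate) supported?
3) Databases offer many features, such as transactions. Which of them can the ORM framework support?
4) ...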
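For the first question, a common approach is to hide database-specific SQL behind an interface and select an implementation by driver name. The sketch below assumes a hypothetical one-method `Dialect` interface and a `GetDialect` lookup; the names and shape are illustrative, not a description of any particular framework's API.

```go
package main

import "fmt"

// Dialect hides database-specific SQL behind one interface (hypothetical shape).
type Dialect interface {
	// TableExistSQL returns a query, plus its arguments, that yields a row
	// if the given table exists.
	TableExistSQL(table string) (string, []interface{})
}

type sqlite3 struct{}

func (sqlite3) TableExistSQL(table string) (string, []interface{}) {
	return "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?",
		[]interface{}{table}
}

type mysql struct{}

func (mysql) TableExistSQL(table string) (string, []interface{}) {
	return "SELECT table_name FROM information_schema.tables " +
		"WHERE table_schema = DATABASE() AND table_name = ?", []interface{}{table}
}

// dialects maps a driver name to its dialect implementation.
var dialects = map[string]Dialect{"sqlite3": sqlite3{}, "mysql": mysql{}}

// GetDialect looks up the dialect registered for a driver name.
func GetDialect(driver string) (Dialect, bool) {
	d, ok := dialects[driver]
	return d, ok
}

func main() {
	for _, name := range []string{"sqlite3", "mysql"} {
		d, _ := GetDialect(name)
		query, args := d.TableExistSQL("User")
		fmt.Println(name, "=>", query, args)
	}
}
```

The rest of the framework calls only the interface, so supporting another database means adding one more implementation rather than touching every query site.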
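## 2 About GeeORM
Databases have a great many features. Replacing SQL with an ORM is fine for simple create, read, update and delete operations, but many features are hard to replace with an ORM. Complex multi-table joins, for instance, may be supported by the ORM, but for performance reasons hand-written SQL is likely to be more efficient.
So when designing and implementing an ORM framework, features have to be prioritized.
The most widely used ORM frameworks in Go are [gorm](https://github.com/jinzhu/gorm) and [xorm](https://github.com/go-xorm/xorm). Beyond the basics, such as table operations and inserting, deleting, querying and updating records, gorm also implements associations (one-to-one, one-to-many, etc.), callback plugins and more; xorm implements read/write splitting (configuring multiple databases), data synchronization, import/export and more.
At the time of writing, gorm is thoroughly rewriting v1, and a v2 release does not look likely in the short term. Compared with gorm-v1, xorm has a clearer design. GeeORM's design mainly follows xorm, with some implementation details borrowed from gorm. The main goal of GeeORM is to understand the principles of ORM framework design; the implementation is not very robust, and some complex features, such as gorm's associations and xorm's read/write splitting, are not implemented. The features currently supported are:
- Creating, dropping and migrating tables.
- Inserting, deleting, querying and updating records, with chained query conditions.
- Setting a single primary key.
- Hooks (before or after create/update/delete/query).
- Transactions.
- ...
`GeeORM` is implemented in 7 days. The part completed each day can be run and tested independently; like building blocks, the independent features combine into the final ORM framework. Each day's code is around 100 lines and comes with fairly complete unit tests.
## 3 Contents
- Day 1: [database/sql Basics](https://geektutu.com/post/geeorm-day1.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day1-database-sql)
- Day 2: [Object-Table Schema Mapping](https://geektutu.com/post/geeorm-day2.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day2-reflect-schema)
- Day 3: [Insert and Query](https://geektutu.com/post/geeorm-day3.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day3-save-query)
- Day 4: [Chained Operations, Update and Delete](https://geektutu.com/post/geeorm-day4.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day4-chain-operation)
- Day 5: [Implementing Hooks](https://geektutu.com/post/geeorm-day5.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day5-hooks)
- Day 6: [Transaction Support](https://geektutu.com/post/geeorm-day6.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day6-transaction)
- Day 7: [Database Migration (Migrate)](https://geektutu.com/post/geeorm-day7.html) | [Code](https://github.com/geektutu/7days-golang/blob/master/gee-orm/day7-migrate)
## Appendix: Recommended Reading
- [A Concise Go Tutorial](https://geektutu.com/post/quick-golang.html)
- [A Concise Guide to Unit Testing with Go Test](https://geektutu.com/post/quick-go-test.html)
- [Go Reflect: Improving Reflection Performance](https://geektutu.com/post/hpg-reflect.html)
- [SQLite Common Commands Cheat Sheet](https://geektutu.com/post/cheat-sheet-sqlite.html)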
================================================
FILE: gee-orm/run_test.sh
================================================
#!/bin/bash
set -euo pipefail
cur=$PWD
for item in "$cur"/day*/
do
echo "$item"
cd "$item"
go test geeorm/... 2>&1 | grep -v warning
done
================================================
FILE: gee-rpc/day1-codec/codec/codec.go
================================================
package codec
import (
"io"
)
type Header struct {
ServiceMethod string // format "Service.Method"
Seq uint64 // sequence number chosen by client
Error string
}
type Codec interface {
io.Closer
ReadHeader(*Header) error
ReadBody(interface{}) error
Write(*Header, interface{}) error
}
type NewCodecFunc func(io.ReadWriteCloser) Codec
type Type string
const (
GobType Type = "application/gob"
JsonType Type = "application/json" // not implemented
)
var NewCodecFuncMap map[Type]NewCodecFunc
func init() {
NewCodecFuncMap = make(map[Type]NewCodecFunc)
NewCodecFuncMap[GobType] = NewGobCodec
}
================================================
FILE: gee-rpc/day1-codec/codec/gob.go
================================================
package codec
import (
"bufio"
"encoding/gob"
"io"
"log"
)
type GobCodec struct {
conn io.ReadWriteCloser
buf *bufio.Writer
dec *gob.Decoder
enc *gob.Encoder
}
var _ Codec = (*GobCodec)(nil)
func NewGobCodec(conn io.ReadWriteCloser) Codec {
buf := bufio.NewWriter(conn)
return &GobCodec{
conn: conn,
buf: buf,
dec: gob.NewDecoder(conn),
enc: gob.NewEncoder(buf),
}
}
func (c *GobCodec) ReadHeader(h *Header) error {
return c.dec.Decode(h)
}
func (c *GobCodec) ReadBody(body interface{}) error {
return c.dec.Decode(body)
}
func (c *GobCodec) Write(h *Header, body interface{}) (err error) {
defer func() {
_ = c.buf.Flush()
if err != nil {
_ = c.Close()
}
}()
if err = c.enc.Encode(h); err != nil {
log.Println("rpc: gob error encoding header:", err)
return
}
if err = c.enc.Encode(body); err != nil {
log.Println("rpc: gob error encoding body:", err)
return
}
return
}
func (c *GobCodec) Close() error {
return c.conn.Close()
}
================================================
FILE: gee-rpc/day1-codec/go.mod
================================================
module geerpc
go 1.13
================================================
FILE: gee-rpc/day1-codec/main/main.go
================================================
package main
import (
"encoding/json"
"fmt"
"geerpc"
"geerpc/codec"
"log"
"net"
"time"
)
func startServer(addr chan string) {
// pick a free port
l, err := net.Listen("tcp", ":0")
if err != nil {
log.Fatal("network error:", err)
}
log.Println("start rpc server on", l.Addr())
addr <- l.Addr().String()
geerpc.Accept(l)
}
func main() {
log.SetFlags(0)
addr := make(chan string)
go startServer(addr)
// in fact, following code is like a simple geerpc client
conn, _ := net.Dial("tcp", <-addr)
defer func() { _ = conn.Close() }()
time.Sleep(time.Second)
// send options
_ = json.NewEncoder(conn).Encode(geerpc.DefaultOption)
cc := codec.NewGobCodec(conn)
// send request & receive response
for i := 0; i < 5; i++ {
h := &codec.Header{
ServiceMethod: "Foo.Sum",
Seq: uint64(i),
}
_ = cc.Write(h, fmt.Sprintf("geerpc req %d", h.Seq))
_ = cc.ReadHeader(h)
var reply string
_ = cc.ReadBody(&reply)
log.Println("reply:", reply)
}
}
================================================
FILE: gee-rpc/day1-codec/server.go
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package geerpc
import (
"encoding/json"
"fmt"
"geerpc/codec"
"io"
"log"
"net"
"reflect"
"sync"
)
const MagicNumber = 0x3bef5c
type Option struct {
MagicNumber int // MagicNumber marks this as a geerpc request
CodecType codec.Type // client may choose different Codec to encode body
}
var DefaultOption = &Option{
MagicNumber: MagicNumber,
CodecType: codec.GobType,
}
// Server represents an RPC Server.
type Server struct{}
// NewServer returns a new Server.
func NewServer() *Server {
return &Server{}
}
// DefaultServer is the default instance of *Server.
var DefaultServer = NewServer()
// ServeConn runs the server on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
defer func() { _ = conn.Close() }()
var opt Option
if err := json.NewDecoder(conn).Decode(&opt); err != nil {
log.Println("rpc server: options error: ", err)
return
}
if opt.MagicNumber != MagicNumber {
log.Printf("rpc server: invalid magic number %x", opt.MagicNumber)
return
}
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
log.Printf("rpc server: invalid codec type %s", opt.CodecType)
return
}
server.serveCodec(f(conn))
}
// invalidRequest is a placeholder for response argv when error occurs
var invalidRequest = struct{}{}
func (server *Server) serveCodec(cc codec.Codec) {
sending := new(sync.Mutex) // make sure to send a complete response
wg := new(sync.WaitGroup) // wait until all requests are handled
for {
req, err := server.readRequest(cc)
if err != nil {
if req == nil {
break // it's not possible to recover, so close the connection
}
req.h.Error = err.Error()
server.sendResponse(cc, req.h, invalidRequest, sending)
continue
}
wg.Add(1)
go server.handleRequest(cc, req, sending, wg)
}
wg.Wait()
_ = cc.Close()
}
// request stores all information of a call
type request struct {
h *codec.Header // header of request
argv, replyv reflect.Value // argv and replyv of request
}
func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
var h codec.Header
if err := cc.ReadHeader(&h); err != nil {
if err != io.EOF && err != io.ErrUnexpectedEOF {
log.Println("rpc server: read header error:", err)
}
return nil, err
}
return &h, nil
}
func (server *Server) readRequest(cc codec.Codec) (*request, error) {
h, err := server.readRequestHeader(cc)
if err != nil {
return nil, err
}
req := &request{h: h}
// TODO: now we don't know the type of request argv
// day 1, just suppose it's string
req.argv = reflect.New(reflect.TypeOf(""))
if err = cc.ReadBody(req.argv.Interface()); err != nil {
log.Println("rpc server: read argv err:", err)
// return the error with a non-nil req so serveCodec can send an error response
return req, err
}
return req, nil
}
func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) {
sending.Lock()
defer sending.Unlock()
if err := cc.Write(h, body); err != nil {
log.Println("rpc server: write response error:", err)
}
}
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) {
// TODO, should call registered rpc methods to get the right replyv
// day 1, just print argv and send a hello message
defer wg.Done()
log.Println(req.h, req.argv.Elem())
req.replyv = reflect.ValueOf(fmt.Sprintf("geerpc resp %d", req.h.Seq))
server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
}
// Accept accepts connections on the listener and serves requests
// for each incoming connection.
func (server *Server) Accept(lis net.Listener) {
for {
conn, err := lis.Accept()
if err != nil {
log.Println("rpc server: accept error:", err)
return
}
go server.ServeConn(conn)
}
}
// Accept accepts connections on the listener and serves requests
// for each incoming connection.
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
================================================
FILE: gee-rpc/day2-client/client.go
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package geerpc
import (
"encoding/json"
"errors"
"fmt"
"geerpc/codec"
"io"
"log"
"net"
"sync"
)
// Call represents an active RPC.
type Call struct {
Seq uint64
ServiceMethod string // format "<service>.<method>"
Args interface{} // arguments to the function
Reply interface{} // reply from the function
Error error // if error occurs, it will be set
Done chan *Call // Strobes when call is complete.
}
func (call *Call) done() {
call.Done <- call
}
// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
cc codec.Codec
opt *Option
sending sync.Mutex // protect following
header codec.Header
mu sync.Mutex // protect following
seq uint64
pending map[uint64]*Call
closing bool // user has called Close
shutdown bool // server has told us to stop
}
var _ io.Closer = (*Client)(nil)
var ErrShutdown = errors.New("connection is shut down")
// Close the connection
func (client *Client) Close() error {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing {
return ErrShutdown
}
client.closing = true
return client.cc.Close()
}
// IsAvailable returns true if the client is still usable
func (client *Client) IsAvailable() bool {
client.mu.Lock()
defer client.mu.Unlock()
return !client.shutdown && !client.closing
}
func (client *Client) registerCall(call *Call) (uint64, error) {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing || client.shutdown {
return 0, ErrShutdown
}
call.Seq = client.seq
client.pending[call.Seq] = call
client.seq++
return call.Seq, nil
}
func (client *Client) removeCall(seq uint64) *Call {
client.mu.Lock()
defer client.mu.Unlock()
call := client.pending[seq]
delete(client.pending, seq)
return call
}
func (client *Client) terminateCalls(err error) {
client.sending.Lock()
defer client.sending.Unlock()
client.mu.Lock()
defer client.mu.Unlock()
client.shutdown = true
for _, call := range client.pending {
call.Error = err
call.done()
}
}
func (client *Client) send(call *Call) {
// make sure that the client will send a complete request
client.sending.Lock()
defer client.sending.Unlock()
// register this call.
seq, err := client.registerCall(call)
if err != nil {
call.Error = err
call.done()
return
}
// prepare request header
client.header.ServiceMethod = call.ServiceMethod
client.header.Seq = seq
client.header.Error = ""
// encode and send the request
if err := client.cc.Write(&client.header, call.Args); err != nil {
call := client.removeCall(seq)
// call may be nil, it usually means that Write partially failed,
// client has received the response and handled
if call != nil {
call.Error = err
call.done()
}
}
}
func (client *Client) receive() {
var err error
for err == nil {
var h codec.Header
if err = client.cc.ReadHeader(&h); err != nil {
break
}
call := client.removeCall(h.Seq)
switch {
case call == nil:
// it usually means that Write partially failed
// and call was already removed.
err = client.cc.ReadBody(nil)
case h.Error != "":
call.Error = errors.New(h.Error) // h.Error is data, not a format string
err = client.cc.ReadBody(nil)
call.done()
default:
err = client.cc.ReadBody(call.Reply)
if err != nil {
call.Error = errors.New("reading body " + err.Error())
}
call.done()
}
}
// an error occurred, so terminate all pending calls
client.terminateCalls(err)
}
// Go invokes the function asynchronously.
// It returns the Call structure representing the invocation.
func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call {
if done == nil {
done = make(chan *Call, 10)
} else if cap(done) == 0 {
log.Panic("rpc client: done channel is unbuffered")
}
call := &Call{
ServiceMethod: serviceMethod,
Args: args,
Reply: reply,
Done: done,
}
client.send(call)
return call
}
// Call invokes the named function, waits for it to complete,
// and returns its error status.
func (client *Client) Call(serviceMethod string, args, reply interface{}) error {
call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
return call.Error
}
func parseOptions(opts ...*Option) (*Option, error) {
// if no option is given, or nil is passed, fall back to DefaultOption
if len(opts) == 0 || opts[0] == nil {
return DefaultOption, nil
}
if len(opts) != 1 {
return nil, errors.New("number of options is more than 1")
}
opt := opts[0]
opt.MagicNumber = DefaultOption.MagicNumber
if opt.CodecType == "" {
opt.CodecType = DefaultOption.CodecType
}
return opt, nil
}
func NewClient(conn net.Conn, opt *Option) (*Client, error) {
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
err := fmt.Errorf("invalid codec type %s", opt.CodecType)
log.Println("rpc client: codec error:", err)
return nil, err
}
// send options to the server
if err := json.NewEncoder(conn).Encode(opt); err != nil {
log.Println("rpc client: options error: ", err)
_ = conn.Close()
return nil, err
}
return newClientCodec(f(conn), opt), nil
}
func newClientCodec(cc codec.Codec, opt *Option) *Client {
client := &Client{
seq: 1, // seq starts with 1, 0 means invalid call
cc: cc,
opt: opt,
pending: make(map[uint64]*Call),
}
go client.receive()
return client
}
// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
// close the connection if client is nil
defer func() {
if err != nil {
_ = conn.Close()
}
}()
return NewClient(conn, opt)
}
================================================
FILE: gee-rpc/day2-client/codec/codec.go
================================================
package codec
import (
"io"
)
type Header struct {
ServiceMethod string // format "Service.Method"
Seq uint64 // sequence number chosen by client
Error string
}
type Codec interface {
io.Closer
ReadHeader(*Header) error
ReadBody(interface{}) error
Write(*Header, interface{}) error
}
type NewCodecFunc func(io.ReadWriteCloser) Codec
type Type string
const (
GobType Type = "application/gob"
JsonType Type = "application/json" // not implemented
)
var NewCodecFuncMap map[Type]NewCodecFunc
func init() {
NewCodecFuncMap = make(map[Type]NewCodecFunc)
NewCodecFuncMap[GobType] = NewGobCodec
}
================================================
FILE: gee-rpc/day2-client/codec/gob.go
================================================
package codec
import (
"bufio"
"encoding/gob"
"io"
"log"
)
type GobCodec struct {
conn io.ReadWriteCloser
buf *bufio.Writer
dec *gob.Decoder
enc *gob.Encoder
}
var _ Codec = (*GobCodec)(nil)
func NewGobCodec(conn io.ReadWriteCloser) Codec {
buf := bufio.NewWriter(conn)
return &GobCodec{
conn: conn,
buf: buf,
dec: gob.NewDecoder(conn),
enc: gob.NewEncoder(buf),
}
}
func (c *GobCodec) ReadHeader(h *Header) error {
return c.dec.Decode(h)
}
func (c *GobCodec) ReadBody(body interface{}) error {
return c.dec.Decode(body)
}
func (c *GobCodec) Write(h *Header, body interface{}) (err error) {
defer func() {
_ = c.buf.Flush()
if err != nil {
_ = c.Close()
}
}()
if err = c.enc.Encode(h); err != nil {
log.Println("rpc: gob error encoding header:", err)
return
}
if err = c.enc.Encode(body); err != nil {
log.Println("rpc: gob error encoding body:", err)
return
}
return
}
func (c *GobCodec) Close() error {
return c.conn.Close()
}
================================================
FILE: gee-rpc/day2-client/go.mod
================================================
module geerpc
go 1.13
================================================
FILE: gee-rpc/day2-client/main/main.go
================================================
package main
import (
"fmt"
"geerpc"
"log"
"net"
"sync"
"time"
)
func startServer(addr chan string) {
// pick a free port
l, err := net.Listen("tcp", ":0")
if err != nil {
log.Fatal("network error:", err)
}
log.Println("start rpc server on", l.Addr())
addr <- l.Addr().String()
geerpc.Accept(l)
}
func main() {
log.SetFlags(0)
addr := make(chan string)
go startServer(addr)
client, err := geerpc.Dial("tcp", <-addr)
if err != nil {
log.Fatal("dial error:", err)
}
defer func() { _ = client.Close() }()
time.Sleep(time.Second)
// send request & receive response
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
args := fmt.Sprintf("geerpc req %d", i)
var reply string
if err := client.Call("Foo.Sum", args, &reply); err != nil {
log.Fatal("call Foo.Sum error:", err)
}
log.Println("reply:", reply)
}(i)
}
wg.Wait()
}
================================================
FILE: gee-rpc/day2-client/server.go
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package geerpc
import (
"encoding/json"
"fmt"
"geerpc/codec"
"io"
"log"
"net"
"reflect"
"sync"
)
const MagicNumber = 0x3bef5c
type Option struct {
MagicNumber int // MagicNumber marks this as a geerpc request
CodecType codec.Type // client may choose different Codec to encode body
}
var DefaultOption = &Option{
MagicNumber: MagicNumber,
CodecType: codec.GobType,
}
// Server represents an RPC Server.
type Server struct{}
// NewServer returns a new Server.
func NewServer() *Server {
return &Server{}
}
// DefaultServer is the default instance of *Server.
var DefaultServer = NewServer()
// ServeConn runs the server on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
defer func() { _ = conn.Close() }()
var opt Option
if err := json.NewDecoder(conn).Decode(&opt); err != nil {
log.Println("rpc server: options error: ", err)
return
}
if opt.MagicNumber != MagicNumber {
log.Printf("rpc server: invalid magic number %x", opt.MagicNumber)
return
}
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
log.Printf("rpc server: invalid codec type %s", opt.CodecType)
return
}
server.serveCodec(f(conn))
}
// invalidRequest is a placeholder response body used when an error occurs
var invalidRequest = struct{}{}
func (server *Server) serveCodec(cc codec.Codec) {
sending := new(sync.Mutex) // make sure to send a complete response
wg := new(sync.WaitGroup) // wait until all requests are handled
for {
req, err := server.readRequest(cc)
if err != nil {
if req == nil {
break // it's not possible to recover, so close the connection
}
req.h.Error = err.Error()
server.sendResponse(cc, req.h, invalidRequest, sending)
continue
}
wg.Add(1)
go server.handleRequest(cc, req, sending, wg)
}
wg.Wait()
_ = cc.Close()
}
// request stores all information of a call
type request struct {
h *codec.Header // header of request
argv, replyv reflect.Value // argv and replyv of request
}
func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
var h codec.Header
if err := cc.ReadHeader(&h); err != nil {
if err != io.EOF && err != io.ErrUnexpectedEOF {
log.Println("rpc server: read header error:", err)
}
return nil, err
}
return &h, nil
}
func (server *Server) readRequest(cc codec.Codec) (*request, error) {
h, err := server.readRequestHeader(cc)
if err != nil {
return nil, err
}
req := &request{h: h}
// TODO: the type of the request argv is not known yet
// day 1: just assume it is a string
req.argv = reflect.New(reflect.TypeOf(""))
if err = cc.ReadBody(req.argv.Interface()); err != nil {
log.Println("rpc server: read argv err:", err)
// report the failure so that serveCodec replies with an error
return req, err
}
return req, nil
}
func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) {
sending.Lock()
defer sending.Unlock()
if err := cc.Write(h, body); err != nil {
log.Println("rpc server: write response error:", err)
}
}
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) {
// TODO, should call registered rpc methods to get the right replyv
// day 1, just print argv and send a hello message
defer wg.Done()
log.Println(req.h, req.argv.Elem())
req.replyv = reflect.ValueOf(fmt.Sprintf("geerpc resp %d", req.h.Seq))
server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
}
// Accept accepts connections on the listener and serves requests
// for each incoming connection.
func (server *Server) Accept(lis net.Listener) {
for {
conn, err := lis.Accept()
if err != nil {
log.Println("rpc server: accept error:", err)
return
}
go server.ServeConn(conn)
}
}
// Accept accepts connections on the listener and serves requests
// for each incoming connection.
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
================================================
FILE: gee-rpc/day3-service/client.go
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package geerpc
import (
"encoding/json"
"errors"
"fmt"
"geerpc/codec"
"io"
"log"
"net"
"sync"
)
// Call represents an active RPC.
type Call struct {
Seq uint64
ServiceMethod string // format "Service.Method"
Args interface{} // arguments to the function
Reply interface{} // reply from the function
Error error // if error occurs, it will be set
Done chan *Call // Strobes when call is complete.
}
func (call *Call) done() {
call.Done <- call
}
// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
cc codec.Codec
opt *Option
sending sync.Mutex // protect following
header codec.Header
mu sync.Mutex // protect following
seq uint64
pending map[uint64]*Call
closing bool // user has called Close
shutdown bool // server has told us to stop
}
var _ io.Closer = (*Client)(nil)
var ErrShutdown = errors.New("connection is shut down")
// Close the connection
func (client *Client) Close() error {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing {
return ErrShutdown
}
client.closing = true
return client.cc.Close()
}
// IsAvailable returns true if the client is neither closing nor shut down
func (client *Client) IsAvailable() bool {
client.mu.Lock()
defer client.mu.Unlock()
return !client.shutdown && !client.closing
}
func (client *Client) registerCall(call *Call) (uint64, error) {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing || client.shutdown {
return 0, ErrShutdown
}
call.Seq = client.seq
client.pending[call.Seq] = call
client.seq++
return call.Seq, nil
}
func (client *Client) removeCall(seq uint64) *Call {
client.mu.Lock()
defer client.mu.Unlock()
call := client.pending[seq]
delete(client.pending, seq)
return call
}
func (client *Client) terminateCalls(err error) {
client.sending.Lock()
defer client.sending.Unlock()
client.mu.Lock()
defer client.mu.Unlock()
client.shutdown = true
for _, call := range client.pending {
call.Error = err
call.done()
}
}
func (client *Client) send(call *Call) {
// make sure that the client will send a complete request
client.sending.Lock()
defer client.sending.Unlock()
// register this call.
seq, err := client.registerCall(call)
if err != nil {
call.Error = err
call.done()
return
}
// prepare request header
client.header.ServiceMethod = call.ServiceMethod
client.header.Seq = seq
client.header.Error = ""
// encode and send the request
if err := client.cc.Write(&client.header, call.Args); err != nil {
call := client.removeCall(seq)
// call may be nil; this usually means that Write partially failed,
// but the client has already received and handled the response
if call != nil {
call.Error = err
call.done()
}
}
}
func (client *Client) receive() {
var err error
for err == nil {
var h codec.Header
if err = client.cc.ReadHeader(&h); err != nil {
break
}
call := client.removeCall(h.Seq)
switch {
case call == nil:
// it usually means that Write partially failed
// and call was already removed.
err = client.cc.ReadBody(nil)
case h.Error != "":
call.Error = errors.New(h.Error) // h.Error is data, not a format string
err = client.cc.ReadBody(nil)
call.done()
default:
err = client.cc.ReadBody(call.Reply)
if err != nil {
call.Error = errors.New("reading body " + err.Error())
}
call.done()
}
}
// an error occurred, so terminate all pending calls
client.terminateCalls(err)
}
// Go invokes the function asynchronously.
// It returns the Call structure representing the invocation.
func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call {
if done == nil {
done = make(chan *Call, 10)
} else if cap(done) == 0 {
log.Panic("rpc client: done channel is unbuffered")
}
call := &Call{
ServiceMethod: serviceMethod,
Args: args,
Reply: reply,
Done: done,
}
client.send(call)
return call
}
// Call invokes the named function, waits for it to complete,
// and returns its error status.
func (client *Client) Call(serviceMethod string, args, reply interface{}) error {
call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
return call.Error
}
func parseOptions(opts ...*Option) (*Option, error) {
// use the default option if none is given or nil is passed
if len(opts) == 0 || opts[0] == nil {
return DefaultOption, nil
}
if len(opts) != 1 {
return nil, errors.New("number of options is more than 1")
}
opt := opts[0]
opt.MagicNumber = DefaultOption.MagicNumber
if opt.CodecType == "" {
opt.CodecType = DefaultOption.CodecType
}
return opt, nil
}
func NewClient(conn net.Conn, opt *Option) (*Client, error) {
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
err := fmt.Errorf("invalid codec type %s", opt.CodecType)
log.Println("rpc client: codec error:", err)
return nil, err
}
// send options to the server
if err := json.NewEncoder(conn).Encode(opt); err != nil {
log.Println("rpc client: options error: ", err)
return nil, err
}
return newClientCodec(f(conn), opt), nil
}
func newClientCodec(cc codec.Codec, opt *Option) *Client {
client := &Client{
seq: 1, // seq starts with 1, 0 means invalid call
cc: cc,
opt: opt,
pending: make(map[uint64]*Call),
}
go client.receive()
return client
}
// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
// close the connection if creating the client failed
defer func() {
if err != nil {
_ = conn.Close()
}
}()
return NewClient(conn, opt)
}
================================================
FILE: gee-rpc/day3-service/codec/codec.go
================================================
package codec
import (
"io"
)
type Header struct {
ServiceMethod string // format "Service.Method"
Seq uint64 // sequence number chosen by client
Error string
}
type Codec interface {
io.Closer
ReadHeader(*Header) error
ReadBody(interface{}) error
Write(*Header, interface{}) error
}
type NewCodecFunc func(io.ReadWriteCloser) Codec
type Type string
const (
GobType Type = "application/gob"
JsonType Type = "application/json" // not implemented
)
var NewCodecFuncMap map[Type]NewCodecFunc
func init() {
NewCodecFuncMap = make(map[Type]NewCodecFunc)
NewCodecFuncMap[GobType] = NewGobCodec
}
================================================
FILE: gee-rpc/day3-service/codec/gob.go
================================================
package codec
import (
"bufio"
"encoding/gob"
"io"
"log"
)
type GobCodec struct {
conn io.ReadWriteCloser
buf *bufio.Writer
dec *gob.Decoder
enc *gob.Encoder
}
var _ Codec = (*GobCodec)(nil)
func NewGobCodec(conn io.ReadWriteCloser) Codec {
buf := bufio.NewWriter(conn)
return &GobCodec{
conn: conn,
buf: buf,
dec: gob.NewDecoder(conn),
enc: gob.NewEncoder(buf),
}
}
func (c *GobCodec) ReadHeader(h *Header) error {
return c.dec.Decode(h)
}
func (c *GobCodec) ReadBody(body interface{}) error {
return c.dec.Decode(body)
}
func (c *GobCodec) Write(h *Header, body interface{}) (err error) {
defer func() {
_ = c.buf.Flush()
if err != nil {
_ = c.Close()
}
}()
if err = c.enc.Encode(h); err != nil {
log.Println("rpc: gob error encoding header:", err)
return
}
if err = c.enc.Encode(body); err != nil {
log.Println("rpc: gob error encoding body:", err)
return
}
return
}
func (c *GobCodec) Close() error {
return c.conn.Close()
}
================================================
FILE: gee-rpc/day3-service/go.mod
================================================
module geerpc
go 1.13
================================================
FILE: gee-rpc/day3-service/main/main.go
================================================
package main
import (
"geerpc"
"log"
"net"
"sync"
"time"
)
type Foo int
type Args struct{ Num1, Num2 int }
func (f Foo) Sum(args Args, reply *int) error {
*reply = args.Num1 + args.Num2
return nil
}
func startServer(addr chan string) {
var foo Foo
if err := geerpc.Register(&foo); err != nil {
log.Fatal("register error:", err)
}
// pick a free port
l, err := net.Listen("tcp", ":0")
if err != nil {
log.Fatal("network error:", err)
}
log.Println("start rpc server on", l.Addr())
addr <- l.Addr().String()
geerpc.Accept(l)
}
func main() {
log.SetFlags(0)
addr := make(chan string)
go startServer(addr)
client, err := geerpc.Dial("tcp", <-addr)
if err != nil {
log.Fatal("dial error:", err)
}
defer func() { _ = client.Close() }()
time.Sleep(time.Second)
// send request & receive response
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
args := &Args{Num1: i, Num2: i * i}
var reply int
if err := client.Call("Foo.Sum", args, &reply); err != nil {
log.Fatal("call Foo.Sum error:", err)
}
log.Printf("%d + %d = %d", args.Num1, args.Num2, reply)
}(i)
}
wg.Wait()
}
================================================
FILE: gee-rpc/day3-service/server.go
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package geerpc
import (
"encoding/json"
"errors"
"geerpc/codec"
"io"
"log"
"net"
"reflect"
"strings"
"sync"
)
const MagicNumber = 0x3bef5c
type Option struct {
MagicNumber int // MagicNumber marks this as a geerpc request
CodecType codec.Type // client may choose different Codec to encode body
}
var DefaultOption = &Option{
MagicNumber: MagicNumber,
CodecType: codec.GobType,
}
// Server represents an RPC Server.
type Server struct {
serviceMap sync.Map
}
// NewServer returns a new Server.
func NewServer() *Server {
return &Server{}
}
// DefaultServer is the default instance of *Server.
var DefaultServer = NewServer()
// ServeConn runs the server on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
defer func() { _ = conn.Close() }()
var opt Option
if err := json.NewDecoder(conn).Decode(&opt); err != nil {
log.Println("rpc server: options error: ", err)
return
}
if opt.MagicNumber != MagicNumber {
log.Printf("rpc server: invalid magic number %x", opt.MagicNumber)
return
}
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
log.Printf("rpc server: invalid codec type %s", opt.CodecType)
return
}
server.serveCodec(f(conn))
}
// invalidRequest is a placeholder response body used when an error occurs
var invalidRequest = struct{}{}
func (server *Server) serveCodec(cc codec.Codec) {
sending := new(sync.Mutex) // make sure to send a complete response
wg := new(sync.WaitGroup) // wait until all requests are handled
for {
req, err := server.readRequest(cc)
if err != nil {
if req == nil {
break // it's not possible to recover, so close the connection
}
req.h.Error = err.Error()
server.sendResponse(cc, req.h, invalidRequest, sending)
continue
}
wg.Add(1)
go server.handleRequest(cc, req, sending, wg)
}
wg.Wait()
_ = cc.Close()
}
// request stores all information of a call
type request struct {
h *codec.Header // header of request
argv, replyv reflect.Value // argv and replyv of request
mtype *methodType
svc *service
}
func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
var h codec.Header
if err := cc.ReadHeader(&h); err != nil {
if err != io.EOF && err != io.ErrUnexpectedEOF {
log.Println("rpc server: read header error:", err)
}
return nil, err
}
return &h, nil
}
func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) {
dot := strings.LastIndex(serviceMethod, ".")
if dot < 0 {
err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod)
return
}
serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:]
svci, ok := server.serviceMap.Load(serviceName)
if !ok {
err = errors.New("rpc server: can't find service " + serviceName)
return
}
svc = svci.(*service)
mtype = svc.method[methodName]
if mtype == nil {
err = errors.New("rpc server: can't find method " + methodName)
}
return
}
func (server *Server) readRequest(cc codec.Codec) (*request, error) {
h, err := server.readRequestHeader(cc)
if err != nil {
return nil, err
}
req := &request{h: h}
req.svc, req.mtype, err = server.findService(h.ServiceMethod)
if err != nil {
return req, err
}
req.argv = req.mtype.newArgv()
req.replyv = req.mtype.newReplyv()
// make sure that argvi is a pointer, because ReadBody needs a pointer as its parameter
argvi := req.argv.Interface()
if req.argv.Type().Kind() != reflect.Ptr {
argvi = req.argv.Addr().Interface()
}
if err = cc.ReadBody(argvi); err != nil {
log.Println("rpc server: read body err:", err)
return req, err
}
return req, nil
}
func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) {
sending.Lock()
defer sending.Unlock()
if err := cc.Write(h, body); err != nil {
log.Println("rpc server: write response error:", err)
}
}
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) {
defer wg.Done()
err := req.svc.call(req.mtype, req.argv, req.replyv)
if err != nil {
req.h.Error = err.Error()
server.sendResponse(cc, req.h, invalidRequest, sending)
return
}
server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
}
// Accept accepts connections on the listener and serves requests
// for each incoming connection.
func (server *Server) Accept(lis net.Listener) {
for {
conn, err := lis.Accept()
if err != nil {
log.Println("rpc server: accept error:", err)
return
}
go server.ServeConn(conn)
}
}
// Accept accepts connections on the listener and serves requests
// for each incoming connection.
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
// Register publishes in the server the set of methods of the
// receiver value that satisfy the following conditions:
// - exported method of exported type
// - two arguments, both of exported type
// - the second argument is a pointer
// - one return value, of type error
func (server *Server) Register(rcvr interface{}) error {
s := newService(rcvr)
if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup {
return errors.New("rpc: service already defined: " + s.name)
}
return nil
}
// Register publishes the receiver's methods in the DefaultServer.
func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
================================================
FILE: gee-rpc/day3-service/service.go
================================================
package geerpc
import (
"go/ast"
"log"
"reflect"
"sync/atomic"
)
type methodType struct {
method reflect.Method
ArgType reflect.Type
ReplyType reflect.Type
numCalls uint64
}
func (m *methodType) NumCalls() uint64 {
return atomic.LoadUint64(&m.numCalls)
}
func (m *methodType) newArgv() reflect.Value {
var argv reflect.Value
// arg may be a pointer type, or a value type
if m.ArgType.Kind() == reflect.Ptr {
argv = reflect.New(m.ArgType.Elem())
} else {
argv = reflect.New(m.ArgType).Elem()
}
return argv
}
func (m *methodType) newReplyv() reflect.Value {
// reply must be a pointer type
replyv := reflect.New(m.ReplyType.Elem())
switch m.ReplyType.Elem().Kind() {
case reflect.Map:
replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem()))
case reflect.Slice:
replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0))
}
return replyv
}
type service struct {
name string
typ reflect.Type
rcvr reflect.Value
method map[string]*methodType
}
func newService(rcvr interface{}) *service {
s := new(service)
s.rcvr = reflect.ValueOf(rcvr)
s.name = reflect.Indirect(s.rcvr).Type().Name()
s.typ = reflect.TypeOf(rcvr)
if !ast.IsExported(s.name) {
log.Fatalf("rpc server: %s is not a valid service name", s.name)
}
s.registerMethods()
return s
}
func (s *service) registerMethods() {
s.method = make(map[string]*methodType)
for i := 0; i < s.typ.NumMethod(); i++ {
method := s.typ.Method(i)
mType := method.Type
if mType.NumIn() != 3 || mType.NumOut() != 1 {
continue
}
if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
continue
}
argType, replyType := mType.In(1), mType.In(2)
if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) {
continue
}
s.method[method.Name] = &methodType{
method: method,
ArgType: argType,
ReplyType: replyType,
}
log.Printf("rpc server: register %s.%s\n", s.name, method.Name)
}
}
func (s *service) call(m *methodType, argv, replyv reflect.Value) error {
atomic.AddUint64(&m.numCalls, 1)
f := m.method.Func
returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv})
if errInter := returnValues[0].Interface(); errInter != nil {
return errInter.(error)
}
return nil
}
func isExportedOrBuiltinType(t reflect.Type) bool {
return ast.IsExported(t.Name()) || t.PkgPath() == ""
}
================================================
FILE: gee-rpc/day3-service/service_test.go
================================================
package geerpc
import (
"fmt"
"reflect"
"testing"
)
type Foo int
type Args struct{ Num1, Num2 int }
func (f Foo) Sum(args Args, reply *int) error {
*reply = args.Num1 + args.Num2
return nil
}
// sum is not an exported method, so it must not be registered
func (f Foo) sum(args Args, reply *int) error {
*reply = args.Num1 + args.Num2
return nil
}
func _assert(condition bool, msg string, v ...interface{}) {
if !condition {
panic(fmt.Sprintf("assertion failed: "+msg, v...))
}
}
func TestNewService(t *testing.T) {
var foo Foo
s := newService(&foo)
_assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method))
mType := s.method["Sum"]
_assert(mType != nil, "wrong Method, Sum shouldn't be nil")
}
func TestMethodType_Call(t *testing.T) {
var foo Foo
s := newService(&foo)
mType := s.method["Sum"]
argv := mType.newArgv()
replyv := mType.newReplyv()
argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3}))
err := s.call(mType, argv, replyv)
_assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls() == 1, "failed to call Foo.Sum")
}
================================================
FILE: gee-rpc/day4-timeout/client.go
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package geerpc
import (
"context"
"encoding/json"
"errors"
"fmt"
"geerpc/codec"
"io"
"log"
"net"
"sync"
"time"
)
// Call represents an active RPC.
type Call struct {
Seq uint64
ServiceMethod string // format "Service.Method"
Args interface{} // arguments to the function
Reply interface{} // reply from the function
Error error // if error occurs, it will be set
Done chan *Call // Strobes when call is complete.
}
func (call *Call) done() {
call.Done <- call
}
// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
cc codec.Codec
opt *Option
sending sync.Mutex // protect following
header codec.Header
mu sync.Mutex // protect following
seq uint64
pending map[uint64]*Call
closing bool // user has called Close
shutdown bool // server has told us to stop
}
var _ io.Closer = (*Client)(nil)
var ErrShutdown = errors.New("connection is shut down")
// Close the connection
func (client *Client) Close() error {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing {
return ErrShutdown
}
client.closing = true
return client.cc.Close()
}
// IsAvailable returns true if the client is neither closing nor shut down
func (client *Client) IsAvailable() bool {
client.mu.Lock()
defer client.mu.Unlock()
return !client.shutdown && !client.closing
}
func (client *Client) registerCall(call *Call) (uint64, error) {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing || client.shutdown {
return 0, ErrShutdown
}
call.Seq = client.seq
client.pending[call.Seq] = call
client.seq++
return call.Seq, nil
}
func (client *Client) removeCall(seq uint64) *Call {
client.mu.Lock()
defer client.mu.Unlock()
call := client.pending[seq]
delete(client.pending, seq)
return call
}
func (client *Client) terminateCalls(err error) {
client.sending.Lock()
defer client.sending.Unlock()
client.mu.Lock()
defer client.mu.Unlock()
client.shutdown = true
for _, call := range client.pending {
call.Error = err
call.done()
}
}
func (client *Client) send(call *Call) {
// make sure that the client will send a complete request
client.sending.Lock()
defer client.sending.Unlock()
// register this call.
seq, err := client.registerCall(call)
if err != nil {
call.Error = err
call.done()
return
}
// prepare request header
client.header.ServiceMethod = call.ServiceMethod
client.header.Seq = seq
client.header.Error = ""
// encode and send the request
if err := client.cc.Write(&client.header, call.Args); err != nil {
call := client.removeCall(seq)
// call may be nil; this usually means that Write partially failed,
// but the client has already received and handled the response
if call != nil {
call.Error = err
call.done()
}
}
}
func (client *Client) receive() {
var err error
for err == nil {
var h codec.Header
if err = client.cc.ReadHeader(&h); err != nil {
break
}
call := client.removeCall(h.Seq)
switch {
case call == nil:
// it usually means that Write partially failed
// and call was already removed.
err = client.cc.ReadBody(nil)
case h.Error != "":
call.Error = errors.New(h.Error) // h.Error is data, not a format string
err = client.cc.ReadBody(nil)
call.done()
default:
err = client.cc.ReadBody(call.Reply)
if err != nil {
call.Error = errors.New("reading body " + err.Error())
}
call.done()
}
}
// an error occurred, so terminate all pending calls
client.terminateCalls(err)
}
// Go invokes the function asynchronously.
// It returns the Call structure representing the invocation.
func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call {
if done == nil {
done = make(chan *Call, 10)
} else if cap(done) == 0 {
log.Panic("rpc client: done channel is unbuffered")
}
call := &Call{
ServiceMethod: serviceMethod,
Args: args,
Reply: reply,
Done: done,
}
client.send(call)
return call
}
// Call invokes the named function, waits for it to complete,
// and returns its error status.
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
select {
case <-ctx.Done():
client.removeCall(call.Seq)
return errors.New("rpc client: call failed: " + ctx.Err().Error())
case call := <-call.Done:
return call.Error
}
}
func parseOptions(opts ...*Option) (*Option, error) {
// use the default option if none is given or nil is passed
if len(opts) == 0 || opts[0] == nil {
return DefaultOption, nil
}
if len(opts) != 1 {
return nil, errors.New("number of options is more than 1")
}
opt := opts[0]
opt.MagicNumber = DefaultOption.MagicNumber
if opt.CodecType == "" {
opt.CodecType = DefaultOption.CodecType
}
return opt, nil
}
func NewClient(conn net.Conn, opt *Option) (client *Client, err error) {
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
err = fmt.Errorf("invalid codec type %s", opt.CodecType)
log.Println("rpc client: codec error:", err)
return
}
// send options to the server
if err = json.NewEncoder(conn).Encode(opt); err != nil {
log.Println("rpc client: options error: ", err)
return
}
return newClientCodec(f(conn), opt), nil
}
func newClientCodec(cc codec.Codec, opt *Option) *Client {
client := &Client{
seq: 1, // seq starts with 1, 0 means invalid call
cc: cc,
opt: opt,
pending: make(map[uint64]*Call),
}
go client.receive()
return client
}
type clientResult struct {
client *Client
err error
}
type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error)
func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
conn, err := net.DialTimeout(network, address, opt.ConnectTimeout)
if err != nil {
return nil, err
}
// close the connection if creating the client failed or timed out
defer func() {
if err != nil {
_ = conn.Close()
}
}()
// buffered, so the goroutine below can still send and exit after a timeout
ch := make(chan clientResult, 1)
go func() {
client, err := f(conn, opt)
ch <- clientResult{client: client, err: err}
}()
if opt.ConnectTimeout == 0 {
result := <-ch
return result.client, result.err
}
select {
case <-time.After(opt.ConnectTimeout):
return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout)
case result := <-ch:
return result.client, result.err
}
}
// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (*Client, error) {
return dialTimeout(NewClient, network, address, opts...)
}
================================================
FILE: gee-rpc/day4-timeout/client_test.go
================================================
package geerpc
import (
"context"
"net"
"strings"
"testing"
"time"
)
type Bar int
func (b Bar) Timeout(argv int, reply *int) error {
time.Sleep(time.Second * 2)
return nil
}
func startServer(addr chan string) {
var b Bar
_ = Register(&b)
// pick a free port
l, _ := net.Listen("tcp", ":0")
addr <- l.Addr().String()
Accept(l)
}
func TestClient_dialTimeout(t *testing.T) {
t.Parallel()
l, _ := net.Listen("tcp", ":0")
f := func(conn net.Conn, opt *Option) (client *Client, err error) {
_ = conn.Close()
time.Sleep(time.Second * 2)
return nil, nil
}
t.Run("timeout", func(t *testing.T) {
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second})
_assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error")
})
t.Run("0", func(t *testing.T) {
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0})
_assert(err == nil, "0 means no limit")
})
}
func TestClient_Call(t *testing.T) {
t.Parallel()
addrCh := make(chan string)
go startServer(addrCh)
addr := <-addrCh
time.Sleep(time.Second)
t.Run("client timeout", func(t *testing.T) {
client, _ := Dial("tcp", addr)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() // release the timer held by the context
var reply int
err := client.Call(ctx, "Bar.Timeout", 1, &reply)
_assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error")
})
t.Run("server handle timeout", func(t *testing.T) {
client, _ := Dial("tcp", addr, &Option{
HandleTimeout: time.Second,
})
var reply int
err := client.Call(context.Background(), "Bar.Timeout", 1, &reply)
_assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error")
})
}
================================================
FILE: gee-rpc/day4-timeout/codec/codec.go
================================================
package codec
import (
"io"
)
type Header struct {
ServiceMethod string // format "Service.Method"
Seq uint64 // sequence number chosen by client
Error string
}
type Codec interface {
io.Closer
ReadHeader(*Header) error
ReadBody(interface{}) error
Write(*Header, interface{}) error
}
type NewCodecFunc func(io.ReadWriteCloser) Codec
type Type string
const (
GobType Type = "application/gob"
JsonType Type = "application/json" // not implemented
)
var NewCodecFuncMap map[Type]NewCodecFunc
func init() {
NewCodecFuncMap = make(map[Type]NewCodecFunc)
NewCodecFuncMap[GobType] = NewGobCodec
}
================================================
FILE: gee-rpc/day4-timeout/codec/gob.go
================================================
package codec
import (
"bufio"
"encoding/gob"
"io"
"log"
)
type GobCodec struct {
conn io.ReadWriteCloser
buf *bufio.Writer
dec *gob.Decoder
enc *gob.Encoder
}
var _ Codec = (*GobCodec)(nil)
func NewGobCodec(conn io.ReadWriteCloser) Codec {
buf := bufio.NewWriter(conn)
return &GobCodec{
conn: conn,
buf: buf,
dec: gob.NewDecoder(conn),
enc: gob.NewEncoder(buf),
}
}
func (c *GobCodec) ReadHeader(h *Header) error {
return c.dec.Decode(h)
}
func (c *GobCodec) ReadBody(body interface{}) error {
return c.dec.Decode(body)
}
func (c *GobCodec) Write(h *Header, body interface{}) (err error) {
defer func() {
_ = c.buf.Flush()
if err != nil {
_ = c.Close()
}
}()
if err = c.enc.Encode(h); err != nil {
log.Println("rpc: gob error encoding header:", err)
return
}
if err = c.enc.Encode(body); err != nil {
log.Println("rpc: gob error encoding body:", err)
return
}
return
}
func (c *GobCodec) Close() error {
return c.conn.Close()
}
================================================
FILE: gee-rpc/day4-timeout/go.mod
================================================
module geerpc
go 1.13
================================================
FILE: gee-rpc/day4-timeout/main/main.go
================================================
package main
import (
"context"
"geerpc"
"log"
"net"
"sync"
"time"
)
type Foo int
type Args struct{ Num1, Num2 int }
func (f Foo) Sum(args Args, reply *int) error {
*reply = args.Num1 + args.Num2
return nil
}
func startServer(addr chan string) {
var foo Foo
if err := geerpc.Register(&foo); err != nil {
log.Fatal("register error:", err)
}
// pick a free port
l, err := net.Listen("tcp", ":0")
if err != nil {
log.Fatal("network error:", err)
}
log.Println("start rpc server on", l.Addr())
addr <- l.Addr().String()
geerpc.Accept(l)
}
func main() {
log.SetFlags(0)
addr := make(chan string)
go startServer(addr)
client, _ := geerpc.Dial("tcp", <-addr)
defer func() { _ = client.Close() }()
time.Sleep(time.Second)
// send request & receive response
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
args := &Args{Num1: i, Num2: i * i}
var reply int
if err := client.Call(context.Background(), "Foo.Sum", args, &reply); err != nil {
log.Fatal("call Foo.Sum error:", err)
}
log.Printf("%d + %d = %d", args.Num1, args.Num2, reply)
}(i)
}
wg.Wait()
}
================================================
FILE: gee-rpc/day4-timeout/server.go
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package geerpc
import (
"encoding/json"
"errors"
"fmt"
"geerpc/codec"
"io"
"log"
"net"
"reflect"
"strings"
"sync"
"time"
)
const MagicNumber = 0x3bef5c
type Option struct {
MagicNumber int // MagicNumber marks this as a geerpc request
CodecType codec.Type // client may choose different Codec to encode body
ConnectTimeout time.Duration // 0 means no limit
HandleTimeout time.Duration // 0 means no limit
}
var DefaultOption = &Option{
MagicNumber: MagicNumber,
CodecType: codec.GobType,
ConnectTimeout: time.Second * 10,
}
// Server represents an RPC Server.
type Server struct {
serviceMap sync.Map
}
// NewServer returns a new Server.
func NewServer() *Server {
return &Server{}
}
// DefaultServer is the default instance of *Server.
var DefaultServer = NewServer()
// ServeConn runs the server on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
defer func() { _ = conn.Close() }()
var opt Option
if err := json.NewDecoder(conn).Decode(&opt); err != nil {
log.Println("rpc server: options error: ", err)
return
}
if opt.MagicNumber != MagicNumber {
log.Printf("rpc server: invalid magic number %x", opt.MagicNumber)
return
}
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
log.Printf("rpc server: invalid codec type %s", opt.CodecType)
return
}
server.serveCodec(f(conn), &opt)
}
// invalidRequest is a placeholder for the response body when an error occurs
var invalidRequest = struct{}{}
func (server *Server) serveCodec(cc codec.Codec, opt *Option) {
sending := new(sync.Mutex) // make sure to send a complete response
wg := new(sync.WaitGroup) // wait until all requests are handled
for {
req, err := server.readRequest(cc)
if err != nil {
if req == nil {
break // it's not possible to recover, so close the connection
}
req.h.Error = err.Error()
server.sendResponse(cc, req.h, invalidRequest, sending)
continue
}
wg.Add(1)
go server.handleRequest(cc, req, sending, wg, opt.HandleTimeout)
}
wg.Wait()
_ = cc.Close()
}
// request stores all information of a call
type request struct {
h *codec.Header // header of request
argv, replyv reflect.Value // argv and replyv of request
mtype *methodType
svc *service
}
func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
var h codec.Header
if err := cc.ReadHeader(&h); err != nil {
if err != io.EOF && err != io.ErrUnexpectedEOF {
log.Println("rpc server: read header error:", err)
}
return nil, err
}
return &h, nil
}
func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) {
dot := strings.LastIndex(serviceMethod, ".")
if dot < 0 {
err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod)
return
}
serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:]
svci, ok := server.serviceMap.Load(serviceName)
if !ok {
err = errors.New("rpc server: can't find service " + serviceName)
return
}
svc = svci.(*service)
mtype = svc.method[methodName]
if mtype == nil {
err = errors.New("rpc server: can't find method " + methodName)
}
return
}
func (server *Server) readRequest(cc codec.Codec) (*request, error) {
h, err := server.readRequestHeader(cc)
if err != nil {
return nil, err
}
req := &request{h: h}
req.svc, req.mtype, err = server.findService(h.ServiceMethod)
if err != nil {
return req, err
}
req.argv = req.mtype.newArgv()
req.replyv = req.mtype.newReplyv()
// make sure that argvi is a pointer, ReadBody needs a pointer as parameter
argvi := req.argv.Interface()
if req.argv.Type().Kind() != reflect.Ptr {
argvi = req.argv.Addr().Interface()
}
if err = cc.ReadBody(argvi); err != nil {
log.Println("rpc server: read body err:", err)
return req, err
}
return req, nil
}
func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) {
sending.Lock()
defer sending.Unlock()
if err := cc.Write(h, body); err != nil {
log.Println("rpc server: write response error:", err)
}
}
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
defer wg.Done()
called := make(chan struct{})
sent := make(chan struct{})
// abandoned is closed on timeout so that the worker goroutine below can exit
// instead of blocking forever on the unbuffered called channel
abandoned := make(chan struct{})
go func() {
err := req.svc.call(req.mtype, req.argv, req.replyv)
select {
case called <- struct{}{}:
case <-abandoned:
// the timeout response has already been sent, so drop this result
return
}
if err != nil {
req.h.Error = err.Error()
server.sendResponse(cc, req.h, invalidRequest, sending)
sent <- struct{}{}
return
}
server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
sent <- struct{}{}
}()
if timeout == 0 {
<-called
<-sent
return
}
select {
case <-time.After(timeout):
close(abandoned)
req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout)
server.sendResponse(cc, req.h, invalidRequest, sending)
case <-called:
<-sent
}
}
// Accept accepts connections on the listener and serves requests
// for each incoming connection.
func (server *Server) Accept(lis net.Listener) {
for {
conn, err := lis.Accept()
if err != nil {
log.Println("rpc server: accept error:", err)
return
}
go server.ServeConn(conn)
}
}
// Accept accepts connections on the listener and serves requests
// for each incoming connection.
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
// Register publishes in the server the set of methods of the
// receiver value that satisfy the following conditions:
// - exported method of exported type
// - two arguments, both of exported type
// - the second argument is a pointer
// - one return value, of type error
func (server *Server) Register(rcvr interface{}) error {
s := newService(rcvr)
if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup {
return errors.New("rpc: service already defined: " + s.name)
}
return nil
}
// Register publishes the receiver's methods in the DefaultServer.
func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
================================================
FILE: gee-rpc/day4-timeout/service.go
================================================
package geerpc
import (
"go/ast"
"log"
"reflect"
"sync/atomic"
)
type methodType struct {
method reflect.Method
ArgType reflect.Type
ReplyType reflect.Type
numCalls uint64
}
func (m *methodType) NumCalls() uint64 {
return atomic.LoadUint64(&m.numCalls)
}
func (m *methodType) newArgv() reflect.Value {
var argv reflect.Value
// arg may be a pointer type, or a value type
if m.ArgType.Kind() == reflect.Ptr {
argv = reflect.New(m.ArgType.Elem())
} else {
argv = reflect.New(m.ArgType).Elem()
}
return argv
}
func (m *methodType) newReplyv() reflect.Value {
// reply must be a pointer type
replyv := reflect.New(m.ReplyType.Elem())
switch m.ReplyType.Elem().Kind() {
case reflect.Map:
replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem()))
case reflect.Slice:
replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0))
}
return replyv
}
type service struct {
name string
typ reflect.Type
rcvr reflect.Value
method map[string]*methodType
}
func newService(rcvr interface{}) *service {
s := new(service)
s.rcvr = reflect.ValueOf(rcvr)
s.name = reflect.Indirect(s.rcvr).Type().Name()
s.typ = reflect.TypeOf(rcvr)
if !ast.IsExported(s.name) {
log.Fatalf("rpc server: %s is not a valid service name", s.name)
}
s.registerMethods()
return s
}
func (s *service) registerMethods() {
s.method = make(map[string]*methodType)
for i := 0; i < s.typ.NumMethod(); i++ {
method := s.typ.Method(i)
mType := method.Type
if mType.NumIn() != 3 || mType.NumOut() != 1 {
continue
}
if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
continue
}
argType, replyType := mType.In(1), mType.In(2)
if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) {
continue
}
s.method[method.Name] = &methodType{
method: method,
ArgType: argType,
ReplyType: replyType,
}
log.Printf("rpc server: register %s.%s\n", s.name, method.Name)
}
}
func (s *service) call(m *methodType, argv, replyv reflect.Value) error {
atomic.AddUint64(&m.numCalls, 1)
f := m.method.Func
returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv})
if errInter := returnValues[0].Interface(); errInter != nil {
return errInter.(error)
}
return nil
}
func isExportedOrBuiltinType(t reflect.Type) bool {
return ast.IsExported(t.Name()) || t.PkgPath() == ""
}
================================================
FILE: gee-rpc/day4-timeout/service_test.go
================================================
package geerpc
import (
"fmt"
"reflect"
"testing"
)
type Foo int
type Args struct{ Num1, Num2 int }
func (f Foo) Sum(args Args, reply *int) error {
*reply = args.Num1 + args.Num2
return nil
}
// it's not an exported method
func (f Foo) sum(args Args, reply *int) error {
*reply = args.Num1 + args.Num2
return nil
}
func _assert(condition bool, msg string, v ...interface{}) {
if !condition {
panic(fmt.Sprintf("assertion failed: "+msg, v...))
}
}
func TestNewService(t *testing.T) {
var foo Foo
s := newService(&foo)
_assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method))
mType := s.method["Sum"]
_assert(mType != nil, "wrong Method, Sum shouldn't be nil")
}
func TestMethodType_Call(t *testing.T) {
var foo Foo
s := newService(&foo)
mType := s.method["Sum"]
argv := mType.newArgv()
replyv := mType.newReplyv()
argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3}))
err := s.call(mType, argv, replyv)
_assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls() == 1, "failed to call Foo.Sum")
}
================================================
FILE: gee-rpc/day5-http-debug/client.go
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package geerpc
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"geerpc/codec"
"io"
"log"
"net"
"net/http"
"strings"
"sync"
"time"
)
// Call represents an active RPC.
type Call struct {
Seq uint64
ServiceMethod string // format "<service>.<method>"
Args interface{} // arguments to the function
Reply interface{} // reply from the function
Error error // if error occurs, it will be set
Done chan *Call // Strobes when call is complete.
}
func (call *Call) done() {
call.Done <- call
}
// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
cc codec.Codec
opt *Option
sending sync.Mutex // protect following
header codec.Header
mu sync.Mutex // protect following
seq uint64
pending map[uint64]*Call
closing bool // user has called Close
shutdown bool // server has told us to stop
}
var _ io.Closer = (*Client)(nil)
var ErrShutdown = errors.New("connection is shut down")
// Close the connection
func (client *Client) Close() error {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing {
return ErrShutdown
}
client.closing = true
return client.cc.Close()
}
// IsAvailable returns true if the client is still usable
func (client *Client) IsAvailable() bool {
client.mu.Lock()
defer client.mu.Unlock()
return !client.shutdown && !client.closing
}
func (client *Client) registerCall(call *Call) (uint64, error) {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing || client.shutdown {
return 0, ErrShutdown
}
call.Seq = client.seq
client.pending[call.Seq] = call
client.seq++
return call.Seq, nil
}
func (client *Client) removeCall(seq uint64) *Call {
client.mu.Lock()
defer client.mu.Unlock()
call := client.pending[seq]
delete(client.pending, seq)
return call
}
func (client *Client) terminateCalls(err error) {
client.sending.Lock()
defer client.sending.Unlock()
client.mu.Lock()
defer client.mu.Unlock()
client.shutdown = true
for _, call := range client.pending {
call.Error = err
call.done()
}
}
func (client *Client) send(call *Call) {
// make sure that the client will send a complete request
client.sending.Lock()
defer client.sending.Unlock()
// register this call.
seq, err := client.registerCall(call)
if err != nil {
call.Error = err
call.done()
return
}
// prepare request header
client.header.ServiceMethod = call.ServiceMethod
client.header.Seq = seq
client.header.Error = ""
// encode and send the request
if err := client.cc.Write(&client.header, call.Args); err != nil {
call := client.removeCall(seq)
// call may be nil; this usually means that Write partially failed
// and the client has already received and handled the response
if call != nil {
call.Error = err
call.done()
}
}
}
func (client *Client) receive() {
var err error
for err == nil {
var h codec.Header
if err = client.cc.ReadHeader(&h); err != nil {
break
}
call := client.removeCall(h.Seq)
switch {
case call == nil:
// it usually means that Write partially failed
// and call was already removed.
err = client.cc.ReadBody(nil)
case h.Error != "":
call.Error = errors.New(h.Error) // not fmt.Errorf: h.Error is not a format string
err = client.cc.ReadBody(nil)
call.done()
default:
err = client.cc.ReadBody(call.Reply)
if err != nil {
call.Error = errors.New("reading body " + err.Error())
}
call.done()
}
}
// an error occurred, so terminate all pending calls
client.terminateCalls(err)
}
// Go invokes the function asynchronously.
// It returns the Call structure representing the invocation.
func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call {
if done == nil {
done = make(chan *Call, 10)
} else if cap(done) == 0 {
log.Panic("rpc client: done channel is unbuffered")
}
call := &Call{
ServiceMethod: serviceMethod,
Args: args,
Reply: reply,
Done: done,
}
client.send(call)
return call
}
// Call invokes the named function, waits for it to complete,
// and returns its error status.
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
select {
case <-ctx.Done():
client.removeCall(call.Seq)
return errors.New("rpc client: call failed: " + ctx.Err().Error())
case call := <-call.Done:
return call.Error
}
}
func parseOptions(opts ...*Option) (*Option, error) {
// if opts is nil or pass nil as parameter
if len(opts) == 0 || opts[0] == nil {
return DefaultOption, nil
}
if len(opts) != 1 {
return nil, errors.New("number of options is more than 1")
}
opt := opts[0]
opt.MagicNumber = DefaultOption.MagicNumber
if opt.CodecType == "" {
opt.CodecType = DefaultOption.CodecType
}
return opt, nil
}
func NewClient(conn net.Conn, opt *Option) (*Client, error) {
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
err := fmt.Errorf("invalid codec type %s", opt.CodecType)
log.Println("rpc client: codec error:", err)
return nil, err
}
// send options to the server
if err := json.NewEncoder(conn).Encode(opt); err != nil {
log.Println("rpc client: options error: ", err)
_ = conn.Close()
return nil, err
}
return newClientCodec(f(conn), opt), nil
}
func newClientCodec(cc codec.Codec, opt *Option) *Client {
client := &Client{
seq: 1, // seq starts with 1, 0 means invalid call
cc: cc,
opt: opt,
pending: make(map[uint64]*Call),
}
go client.receive()
return client
}
type clientResult struct {
client *Client
err error
}
type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error)
func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
conn, err := net.DialTimeout(network, address, opt.ConnectTimeout)
if err != nil {
return nil, err
}
// close the connection if client is nil
defer func() {
if err != nil {
_ = conn.Close()
}
}()
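// buffered so that the goroutine below can exit even when the
// select has already returned on timeout and nobody receives
ch := make(chan clientResult, 1)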
go func() {
client, err := f(conn, opt)
ch <- clientResult{client: client, err: err}
}()
if opt.ConnectTimeout == 0 {
result := <-ch
return result.client, result.err
}
select {
case <-time.After(opt.ConnectTimeout):
return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout)
case result := <-ch:
return result.client, result.err
}
}
// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (*Client, error) {
return dialTimeout(NewClient, network, address, opts...)
}
// NewHTTPClient creates a Client instance that uses HTTP as the transport protocol
func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) {
_, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath))
// Require successful HTTP response
// before switching to RPC protocol.
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
if err == nil && resp.Status == connected {
return NewClient(conn, opt)
}
if err == nil {
err = errors.New("unexpected HTTP response: " + resp.Status)
}
return nil, err
}
// DialHTTP connects to an HTTP RPC server at the specified network address
// listening on the default HTTP RPC path.
func DialHTTP(network, address string, opts ...*Option) (*Client, error) {
return dialTimeout(NewHTTPClient, network, address, opts...)
}
// XDial calls different functions to connect to an RPC server
// according to the first parameter rpcAddr.
// rpcAddr uses a general format (protocol@addr) to represent an RPC server,
// e.g. http@10.0.0.1:7001, tcp@10.0.0.1:9999, unix@/tmp/geerpc.sock
func XDial(rpcAddr string, opts ...*Option) (*Client, error) {
parts := strings.Split(rpcAddr, "@")
if len(parts) != 2 {
return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr)
}
protocol, addr := parts[0], parts[1]
switch protocol {
case "http":
return DialHTTP("tcp", addr, opts...)
default:
// tcp, unix or other transport protocol
return Dial(protocol, addr, opts...)
}
}
================================================
FILE: gee-rpc/day5-http-debug/client_test.go
================================================
package geerpc
import (
"context"
"net"
"os"
"runtime"
"strings"
"testing"
"time"
)
type Bar int
func (b Bar) Timeout(argv int, reply *int) error {
time.Sleep(time.Second * 2)
return nil
}
func startServer(addr chan string) {
var b Bar
_ = Register(&b)
// pick a free port
l, _ := net.Listen("tcp", ":0")
addr <- l.Addr().String()
Accept(l)
}
func TestClient_dialTimeout(t *testing.T) {
t.Parallel()
l, _ := net.Listen("tcp", ":0")
f := func(conn net.Conn, opt *Option) (client *Client, err error) {
_ = conn.Close()
time.Sleep(time.Second * 2)
return nil, nil
}
t.Run("timeout", func(t *testing.T) {
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second})
_assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error")
})
t.Run("0", func(t *testing.T) {
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0})
_assert(err == nil, "0 means no limit")
})
}
func TestClient_Call(t *testing.T) {
t.Parallel()
addrCh := make(chan string)
go startServer(addrCh)
addr := <-addrCh
time.Sleep(time.Second)
t.Run("client timeout", func(t *testing.T) {
client, _ := Dial("tcp", addr)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() // release the timer held by the context
var reply int
err := client.Call(ctx, "Bar.Timeout", 1, &reply)
_assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error")
})
t.Run("server handle timeout", func(t *testing.T) {
client, _ := Dial("tcp", addr, &Option{
HandleTimeout: time.Second,
})
var reply int
err := client.Call(context.Background(), "Bar.Timeout", 1, &reply)
_assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error")
})
}
func TestXDial(t *testing.T) {
if runtime.GOOS == "linux" {
addr := "/tmp/geerpc.sock"
_ = os.Remove(addr)
// listen in the test goroutine: t.Fatal must not be called from another goroutine
l, err := net.Listen("unix", addr)
if err != nil {
t.Fatal("failed to listen unix socket")
}
go Accept(l)
_, err = XDial("unix@" + addr)
_assert(err == nil, "failed to connect unix socket")
}
}
================================================
FILE: gee-rpc/day5-http-debug/codec/codec.go
================================================
package codec
import (
"io"
)
type Header struct {
ServiceMethod string // format "Service.Method"
Seq uint64 // sequence number chosen by client
Error string
}
type Codec interface {
io.Closer
ReadHeader(*Header) error
ReadBody(interface{}) error
Write(*Header, interface{}) error
}
type NewCodecFunc func(io.ReadWriteCloser) Codec
type Type string
const (
GobType Type = "application/gob"
JsonType Type = "application/json" // not implemented
)
var NewCodecFuncMap map[Type]NewCodecFunc
func init() {
NewCodecFuncMap = make(map[Type]NewCodecFunc)
NewCodecFuncMap[GobType] = NewGobCodec
}
================================================
FILE: gee-rpc/day5-http-debug/codec/gob.go
================================================
package codec
import (
"bufio"
"encoding/gob"
"io"
"log"
)
type GobCodec struct {
conn io.ReadWriteCloser
buf *bufio.Writer
dec *gob.Decoder
enc *gob.Encoder
}
var _ Codec = (*GobCodec)(nil)
func NewGobCodec(conn io.ReadWriteCloser) Codec {
buf := bufio.NewWriter(conn)
return &GobCodec{
conn: conn,
buf: buf,
dec: gob.NewDecoder(conn),
enc: gob.NewEncoder(buf),
}
}
func (c *GobCodec) ReadHeader(h *Header) error {
return c.dec.Decode(h)
}
func (c *GobCodec) ReadBody(body interface{}) error {
return c.dec.Decode(body)
}
func (c *GobCodec) Write(h *Header, body interface{}) (err error) {
defer func() {
_ = c.buf.Flush()
if err != nil {
_ = c.Close()
}
}()
if err = c.enc.Encode(h); err != nil {
log.Println("rpc: gob error encoding header:", err)
return
}
if err = c.enc.Encode(body); err != nil {
log.Println("rpc: gob error encoding body:", err)
return
}
return
}
func (c *GobCodec) Close() error {
return c.conn.Close()
}
================================================
FILE: gee-rpc/day5-http-debug/debug.go
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package geerpc
import (
"fmt"
"html/template"
"net/http"
)
const debugText = `
GeeRPC Services
{{range .}}
Service {{.Name}}