Repository: facebookgo/dvara Branch: master Commit: 95e05462c790 Files: 19 Total size: 88.3 KB Directory structure: gitextract_py9bc1vq/ ├── .travis.yml ├── cmd/ │ └── dvara/ │ ├── logger.go │ └── main.go ├── common_test.go ├── doc.go ├── license ├── patents ├── protocol.go ├── protocol_test.go ├── proxy.go ├── proxy_test.go ├── readme.md ├── replica_set.go ├── replica_set_test.go ├── response_rewriter.go ├── response_rewriter_test.go ├── rs_state.go ├── rs_state_test.go └── state.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .travis.yml ================================================ language: go env: GO_RUN_LONG_TEST=1 go: - 1.5 install: - go get -t ./... ================================================ FILE: cmd/dvara/logger.go ================================================ package main import "log" // stdLogger provides a logger backed by the standard library logger. This is a // placeholder until we can open source our logger. type stdLogger struct{} func (l *stdLogger) Error(args ...interface{}) { log.Print(args...) } func (l *stdLogger) Errorf(format string, args ...interface{}) { log.Printf(format, args...) } func (l *stdLogger) Warn(args ...interface{}) { log.Print(args...) } func (l *stdLogger) Warnf(format string, args ...interface{}) { log.Printf(format, args...) } func (l *stdLogger) Info(args ...interface{}) { log.Print(args...) } func (l *stdLogger) Infof(format string, args ...interface{}) { log.Printf(format, args...) } func (l *stdLogger) Debug(args ...interface{}) { log.Print(args...) } func (l *stdLogger) Debugf(format string, args ...interface{}) { log.Printf(format, args...) 
} ================================================ FILE: cmd/dvara/main.go ================================================ package main import ( "flag" "fmt" "os" "os/signal" "syscall" "time" "github.com/facebookgo/dvara" "github.com/facebookgo/inject" "github.com/facebookgo/startstop" "github.com/facebookgo/stats" ) func main() { if err := Main(); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } } func Main() error { messageTimeout := flag.Duration("message_timeout", 2*time.Minute, "timeout for one message to be proxied") clientIdleTimeout := flag.Duration("client_idle_timeout", 60*time.Minute, "idle timeout for client connections") serverIdleTimeout := flag.Duration("server_idle_timeout", 1*time.Hour, "idle timeout for server connections") serverClosePoolSize := flag.Uint("server_close_pool_size", 100, "number of goroutines that will handle closing server connections") getLastErrorTimeout := flag.Duration("get_last_error_timeout", time.Minute, "timeout for getLastError pinning") maxPerClientConnections := flag.Uint("max_per_client_connections", 100, "maximum number of connections per client") maxConnections := flag.Uint("max_connections", 100, "maximum number of connections per mongo") portStart := flag.Int("port_start", 6000, "start of port range") portEnd := flag.Int("port_end", 6010, "end of port range") addrs := flag.String("addrs", "localhost:27017", "comma separated list of mongo addresses") flag.Parse() replicaSet := dvara.ReplicaSet{ Addrs: *addrs, PortStart: *portStart, PortEnd: *portEnd, MessageTimeout: *messageTimeout, ClientIdleTimeout: *clientIdleTimeout, ServerIdleTimeout: *serverIdleTimeout, ServerClosePoolSize: *serverClosePoolSize, GetLastErrorTimeout: *getLastErrorTimeout, MaxConnections: *maxConnections, MaxPerClientConnections: *maxPerClientConnections, } var statsClient stats.HookClient var log stdLogger var graph inject.Graph err := graph.Provide( &inject.Object{Value: &log}, &inject.Object{Value: &replicaSet}, &inject.Object{Value: 
&statsClient}, ) if err != nil { return err } if err := graph.Populate(); err != nil { return err } objects := graph.Objects() if err := startstop.Start(objects, &log); err != nil { return err } defer startstop.Stop(objects, &log) ch := make(chan os.Signal, 2) signal.Notify(ch, syscall.SIGTERM, syscall.SIGINT) <-ch signal.Stop(ch) return nil } ================================================ FILE: common_test.go ================================================ package dvara import ( "os" "testing" "time" "gopkg.in/mgo.v2" "github.com/facebookgo/ensure" "github.com/facebookgo/inject" "github.com/facebookgo/mgotest" "github.com/facebookgo/startstop" "github.com/facebookgo/stats" ) var ( disableSlowTests = os.Getenv("GO_RUN_LONG_TEST") == "" veryVerbose = os.Getenv("VERY_VERBOSE") == "1" ) type tLogger struct { TB testing.TB } func (l *tLogger) Error(args ...interface{}) { if veryVerbose { l.TB.Log(args...) } } func (l *tLogger) Errorf(format string, args ...interface{}) { if veryVerbose { l.TB.Logf(format, args...) } } func (l *tLogger) Warn(args ...interface{}) { if veryVerbose { l.TB.Log(args...) } } func (l *tLogger) Warnf(format string, args ...interface{}) { if veryVerbose { l.TB.Logf(format, args...) } } func (l *tLogger) Info(args ...interface{}) { if veryVerbose { l.TB.Log(args...) } } func (l *tLogger) Infof(format string, args ...interface{}) { if veryVerbose { l.TB.Logf(format, args...) } } func (l *tLogger) Debug(args ...interface{}) { if veryVerbose { l.TB.Log(args...) } } func (l *tLogger) Debugf(format string, args ...interface{}) { if veryVerbose { l.TB.Logf(format, args...) 
} } type stopper interface { Stop() } type Harness struct { T testing.TB Stopper stopper // This is either mgotest.Server or mgotest.ReplicaSet ReplicaSet *ReplicaSet Graph *inject.Graph Log *tLogger } func newHarnessInternal(url string, s stopper, t testing.TB) *Harness { replicaSet := ReplicaSet{ Addrs: url, PortStart: 2000, PortEnd: 3000, MaxConnections: 5, MinIdleConnections: 5, ServerIdleTimeout: 5 * time.Minute, ServerClosePoolSize: 5, ClientIdleTimeout: 5 * time.Minute, MaxPerClientConnections: 250, GetLastErrorTimeout: 5 * time.Minute, MessageTimeout: time.Minute, } log := tLogger{TB: t} var graph inject.Graph err := graph.Provide( &inject.Object{Value: &log}, &inject.Object{Value: &replicaSet}, &inject.Object{Value: &stats.HookClient{}}, ) ensure.Nil(t, err) ensure.Nil(t, graph.Populate()) objects := graph.Objects() ensure.Nil(t, startstop.Start(objects, &log)) return &Harness{ T: t, Stopper: s, ReplicaSet: &replicaSet, Graph: &graph, Log: &log, } } type SingleHarness struct { *Harness MgoServer *mgotest.Server } func NewSingleHarness(t testing.TB) *SingleHarness { mgoserver := mgotest.NewStartedServer(t) return &SingleHarness{ Harness: newHarnessInternal(mgoserver.URL(), mgoserver, t), MgoServer: mgoserver, } } type ReplicaSetHarness struct { *Harness MgoReplicaSet *mgotest.ReplicaSet } func NewReplicaSetHarness(n uint, t testing.TB) *ReplicaSetHarness { if disableSlowTests { t.Skip("disabled because it's slow") } mgoRS := mgotest.NewReplicaSet(n, t) return &ReplicaSetHarness{ Harness: newHarnessInternal(mgoRS.Addrs()[n-1], mgoRS, t), MgoReplicaSet: mgoRS, } } func (h *Harness) Stop() { defer h.Stopper.Stop() ensure.Nil(h.T, startstop.Stop(h.Graph.Objects(), h.Log)) } func (h *Harness) ProxySession() *mgo.Session { return h.Dial(h.ReplicaSet.ProxyMembers()[0]) } func (h *Harness) RealSession() *mgo.Session { return h.Dial(h.ReplicaSet.lastState.Addrs()[0]) } func (h *Harness) Dial(u string) *mgo.Session { session, err := mgo.Dial(u) ensure.Nil(h.T, err, 
u) session.SetSafe(&mgo.Safe{FSync: true, W: 1}) session.SetSyncTimeout(time.Minute) session.SetSocketTimeout(time.Minute) return session } ================================================ FILE: doc.go ================================================ // Package dvara provides a library to enable setting up a proxy server // for mongo. package dvara ================================================ FILE: license ================================================ BSD License For dvara software Copyright (c) 2015, Facebook, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
================================================ FILE: patents ================================================ Additional Grant of Patent Rights Version 2 "Software" means the dvara software distributed by Facebook, Inc. Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable (subject to the termination provision below) license under any Necessary Claims, to make, have made, use, sell, offer to sell, import, and otherwise transfer the Software. For avoidance of doubt, no license is granted under Facebook’s rights in any patent claims that are infringed by (i) modifications to the Software made by you or any third party or (ii) the Software in combination with any software or other technology. The license granted hereunder will terminate, automatically and without notice, if you (or any of your subsidiaries, corporate affiliates or agents) initiate directly or indirectly, or take a direct financial interest in, any Patent Assertion: (i) against Facebook or any of its subsidiaries or corporate affiliates, (ii) against any party if such Patent Assertion arises in whole or in part from any software, technology, product or service of Facebook or any of its subsidiaries or corporate affiliates, or (iii) against any party relating to the Software. Notwithstanding the foregoing, if Facebook or any of its subsidiaries or corporate affiliates files a lawsuit alleging patent infringement against you in the first instance, and you respond by filing a patent infringement counterclaim in that lawsuit against that party that is unrelated to the Software, the license granted hereunder will not terminate under section (i) of this paragraph due to such counterclaim. A "Necessary Claim" is a claim of a patent owned by Facebook that is necessarily infringed by the Software standing alone. 
A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, or contributory infringement or inducement to infringe any patent, including a cross-claim or counterclaim. ================================================ FILE: protocol.go ================================================ package dvara import ( "errors" "fmt" "io" ) var ( errWrite = errors.New("incorrect number of bytes written") ) // Look at http://docs.mongodb.org/meta-driver/latest/legacy/mongodb-wire-protocol/ for the protocol. // OpCode allow identifying the type of operation: // // http://docs.mongodb.org/meta-driver/latest/legacy/mongodb-wire-protocol/#request-opcodes type OpCode int32 // String returns a human readable representation of the OpCode. func (c OpCode) String() string { switch c { default: return "UNKNOWN" case OpReply: return "REPLY" case OpMessage: return "MESSAGE" case OpUpdate: return "UPDATE" case OpInsert: return "INSERT" case Reserved: return "RESERVED" case OpQuery: return "QUERY" case OpGetMore: return "GET_MORE" case OpDelete: return "DELETE" case OpKillCursors: return "KILL_CURSORS" } } // IsMutation tells us if the operation will mutate data. These operations can // be followed up by a getLastErr operation. func (c OpCode) IsMutation() bool { return c == OpInsert || c == OpUpdate || c == OpDelete } // HasResponse tells us if the operation will have a response from the server. 
func (c OpCode) HasResponse() bool {
	return c == OpQuery || c == OpGetMore
}

// The full set of known op codes (OpReply is what the server sends back):
// http://docs.mongodb.org/meta-driver/latest/legacy/mongodb-wire-protocol/#request-opcodes
const (
	OpReply       = OpCode(1)
	OpMessage     = OpCode(1000)
	OpUpdate      = OpCode(2001)
	OpInsert      = OpCode(2002)
	Reserved      = OpCode(2003)
	OpQuery       = OpCode(2004)
	OpGetMore     = OpCode(2005)
	OpDelete      = OpCode(2006)
	OpKillCursors = OpCode(2007)
)

// messageHeader is the mongo MessageHeader
type messageHeader struct {
	// MessageLength is the total message size, including this header
	MessageLength int32

	// RequestID is the identifier for this message
	RequestID int32

	// ResponseTo is the RequestID of the message being responded to. used in DB responses
	ResponseTo int32

	// OpCode is the request type, see consts above.
	OpCode OpCode
}

// ToWire converts the messageHeader to the wire protocol: headerLen bytes
// holding MessageLength, RequestID, ResponseTo and OpCode as consecutive
// little-endian int32 values.
func (m messageHeader) ToWire() []byte {
	var d [headerLen]byte
	b := d[:]
	setInt32(b, 0, m.MessageLength)
	setInt32(b, 4, m.RequestID)
	setInt32(b, 8, m.ResponseTo)
	setInt32(b, 12, int32(m.OpCode))
	return b
}

// FromWire decodes the layout produced by ToWire from the start of b into m.
// b must hold at least headerLen bytes, otherwise this panics.
func (m *messageHeader) FromWire(b []byte) {
	m.MessageLength = getInt32(b, 0)
	m.RequestID = getInt32(b, 4)
	m.ResponseTo = getInt32(b, 8)
	m.OpCode = OpCode(getInt32(b, 12))
}

// WriteTo writes the wire encoding of m to w. A short write without an error
// from w is reported as errWrite.
//
// NOTE: unlike io.WriterTo this returns only an error (go vet's stdmethods
// check flags the name); callers depend on this signature, so it is kept.
func (m *messageHeader) WriteTo(w io.Writer) error {
	b := m.ToWire()
	n, err := w.Write(b)
	if err != nil {
		return err
	}
	if n != len(b) {
		return errWrite
	}
	return nil
}

// String returns a string representation of the message header. Useful for debugging.
func (m *messageHeader) String() string { return fmt.Sprintf( "opCode:%s (%d) msgLen:%d reqID:%d respID:%d", m.OpCode, m.OpCode, m.MessageLength, m.RequestID, m.ResponseTo, ) } func readHeader(r io.Reader) (*messageHeader, error) { var d [headerLen]byte b := d[:] if _, err := io.ReadFull(r, b); err != nil { return nil, err } h := messageHeader{} h.FromWire(b) return &h, nil } // copyMessage copies reads & writes an entire message. func copyMessage(w io.Writer, r io.Reader) error { h, err := readHeader(r) if err != nil { return err } if err := h.WriteTo(w); err != nil { return err } _, err = io.CopyN(w, r, int64(h.MessageLength-headerLen)) return err } // readDocument read an entire BSON document. This document can be used with // bson.Unmarshal. func readDocument(r io.Reader) ([]byte, error) { var sizeRaw [4]byte if _, err := io.ReadFull(r, sizeRaw[:]); err != nil { return nil, err } size := getInt32(sizeRaw[:], 0) doc := make([]byte, size) setInt32(doc, 0, size) if _, err := io.ReadFull(r, doc[4:]); err != nil { return nil, err } return doc, nil } const x00 = byte(0) // readCString reads a null turminated string as defined by BSON from the // reader. Note, the return value includes the trailing null byte. func readCString(r io.Reader) ([]byte, error) { var b []byte var n [1]byte for { if _, err := io.ReadFull(r, n[:]); err != nil { return nil, err } b = append(b, n[0]) if n[0] == x00 { return b, nil } } } // all data in the MongoDB wire protocol is little-endian. // all the read/write functions below are little-endian. 
func getInt32(b []byte, pos int) int32 { return (int32(b[pos+0])) | (int32(b[pos+1]) << 8) | (int32(b[pos+2]) << 16) | (int32(b[pos+3]) << 24) } func setInt32(b []byte, pos int, i int32) { b[pos] = byte(i) b[pos+1] = byte(i >> 8) b[pos+2] = byte(i >> 16) b[pos+3] = byte(i >> 24) } ================================================ FILE: protocol_test.go ================================================ package dvara import ( "bytes" "errors" "io" "testing" ) type testReader struct { read func([]byte) (int, error) } func (t testReader) Read(b []byte) (int, error) { return t.read(b) } type testWriter struct { write func([]byte) (int, error) } func (t testWriter) Write(b []byte) (int, error) { return t.write(b) } func TestOpStrings(t *testing.T) { t.Parallel() cases := []struct { OpCode OpCode String string }{ {OpCode(0), "UNKNOWN"}, {OpReply, "REPLY"}, {OpMessage, "MESSAGE"}, {OpUpdate, "UPDATE"}, {OpInsert, "INSERT"}, {Reserved, "RESERVED"}, {OpQuery, "QUERY"}, {OpGetMore, "GET_MORE"}, {OpDelete, "DELETE"}, {OpKillCursors, "KILL_CURSORS"}, } for _, c := range cases { if c.OpCode.String() != c.String { t.Fatalf("for code %d expected %s but got %s", c.OpCode, c.String, c.OpCode) } } } func TestMsgHeaderString(t *testing.T) { t.Parallel() m := &messageHeader{ OpCode: OpQuery, MessageLength: 10, RequestID: 42, ResponseTo: 43, } if m.String() != "opCode:QUERY (2004) msgLen:10 reqID:42 respID:43" { t.Fatalf("did not find expected string, instead found: %s", m) } } func TestCopyEmptyMessage(t *testing.T) { t.Parallel() msg := messageHeader{} msgBytes := msg.ToWire() r := bytes.NewReader(msgBytes) var w bytes.Buffer if err := copyMessage(&w, r); err != nil { t.Fatal(err) } if !bytes.Equal(msgBytes, w.Bytes()) { t.Fatalf("did not get expected bytes %v got %v", msgBytes, w.Bytes()) } } func TestCopyMessageFromReadError(t *testing.T) { t.Parallel() expectedErr := errors.New("foo") r := testReader{ read: func(b []byte) (int, error) { return 0, expectedErr }, } var w bytes.Buffer 
if err := copyMessage(&w, r); err != expectedErr { t.Fatalf("did not get expected error, instead got: %s", err) } } func TestCopyMessageFromWriteError(t *testing.T) { t.Parallel() msg := messageHeader{} r := bytes.NewReader(msg.ToWire()) expectedErr := errors.New("foo") w := testWriter{ write: func(b []byte) (int, error) { return 0, expectedErr }, } if err := copyMessage(w, r); err != expectedErr { t.Fatalf("did not get expected error, instead got: %s", err) } } func TestCopyMessageFromWriteLengthError(t *testing.T) { t.Parallel() msg := messageHeader{} r := bytes.NewReader(msg.ToWire()) w := testWriter{ write: func(b []byte) (int, error) { return 0, nil }, } if err := copyMessage(w, r); err != errWrite { t.Fatalf("did not get expected error, instead got: %s", err) } } func TestReadDocumentEmpty(t *testing.T) { t.Parallel() doc, err := readDocument(bytes.NewReader([]byte{})) if err != io.EOF { t.Fatal("did not find expected error") } if len(doc) != 0 { t.Fatal("was expecting an empty document") } } func TestReadDocumentPartial(t *testing.T) { t.Parallel() first := true r := testReader{ read: func(b []byte) (int, error) { if first { first = false setInt32(b, 0, 5) return 4, nil } return 0, io.EOF }, } doc, err := readDocument(r) if err != io.EOF { t.Fatalf("did not find expected error, instead got %s %v", err, doc) } if len(doc) != 0 { t.Fatal("was expecting an empty document") } } func TestReadCString(t *testing.T) { t.Parallel() cases := []struct { Data []byte Expected []byte Error error }{ {nil, nil, io.EOF}, {[]byte{0}, []byte{0}, nil}, {[]byte{1, 2, 3, 0}, []byte{1, 2, 3, 0}, nil}, {[]byte{1, 0, 3}, []byte{1, 0}, nil}, } for _, c := range cases { cstring, err := readCString(bytes.NewReader(c.Data)) if err != c.Error { t.Fatalf("did not find expected error, instead got %s %v", err, cstring) } if !bytes.Equal(c.Expected, cstring) { t.Fatalf("did not find expected %v instead got %v", c.Expected, cstring) } } } ================================================ FILE: 
proxy.go ================================================ package dvara import ( "errors" "fmt" "io" "net" "os" "strings" "sync" "time" "github.com/facebookgo/rpool" "github.com/facebookgo/stats" ) const headerLen = 16 var ( errZeroMaxConnections = errors.New("dvara: MaxConnections cannot be 0") errZeroMaxPerClientConnections = errors.New("dvara: MaxPerClientConnections cannot be 0") errNormalClose = errors.New("dvara: normal close") errClientReadTimeout = errors.New("dvara: client read timeout") timeInPast = time.Now() ) // Proxy sends stuff from clients to mongo servers. type Proxy struct { Log Logger ReplicaSet *ReplicaSet ClientListener net.Listener // Listener for incoming client connections ProxyAddr string // Address for incoming client connections MongoAddr string // Address for destination Mongo server wg sync.WaitGroup closed chan struct{} serverPool rpool.Pool stats stats.Client maxPerClientConnections *maxPerClientConnections } // String representation for debugging. func (p *Proxy) String() string { return fmt.Sprintf("proxy %s => mongo %s", p.ProxyAddr, p.MongoAddr) } // Start the proxy. func (p *Proxy) Start() error { if p.ReplicaSet.MaxConnections == 0 { return errZeroMaxConnections } if p.ReplicaSet.MaxPerClientConnections == 0 { return errZeroMaxPerClientConnections } p.closed = make(chan struct{}) p.maxPerClientConnections = newMaxPerClientConnections(p.ReplicaSet.MaxPerClientConnections) p.serverPool = rpool.Pool{ New: p.newServerConn, CloseErrorHandler: p.serverCloseErrorHandler, Max: p.ReplicaSet.MaxConnections, MinIdle: p.ReplicaSet.MinIdleConnections, IdleTimeout: p.ReplicaSet.ServerIdleTimeout, ClosePoolSize: p.ReplicaSet.ServerClosePoolSize, } // plug stats if we can if p.ReplicaSet.Stats != nil { // Drop the default port suffix to make them pretty in production. dbName := strings.TrimSuffix(p.MongoAddr, ":27017") // We want 2 sets of keys, one specific to the proxy, and another shared // with others. 
p.serverPool.Stats = stats.PrefixClient( []string{ "mongoproxy.server.pool.", fmt.Sprintf("mongoproxy.%s.server.pool.", dbName), }, p.ReplicaSet.Stats, ) p.stats = stats.PrefixClient( []string{ "mongoproxy.", fmt.Sprintf("mongoproxy.%s.", dbName), }, p.ReplicaSet.Stats, ) } go p.clientAcceptLoop() return nil } // Stop the proxy. func (p *Proxy) Stop() error { return p.stop(false) } func (p *Proxy) stop(hard bool) error { if err := p.ClientListener.Close(); err != nil { return err } close(p.closed) if !hard { p.wg.Wait() } p.serverPool.Close() return nil } func (p *Proxy) checkRSChanged() bool { addrs := p.ReplicaSet.lastState.Addrs() r, err := p.ReplicaSet.ReplicaSetStateCreator.FromAddrs(addrs, p.ReplicaSet.Name) if err != nil { p.Log.Errorf("all nodes possibly down?: %s", err) return true } if err := r.AssertEqual(p.ReplicaSet.lastState); err != nil { p.Log.Error(err) go p.ReplicaSet.Restart() return true } return false } // Open up a new connection to the server. Retry 7 times, doubling the sleep // each time. This means we'll a total of 12.75 seconds with the last wait // being 6.4 seconds. func (p *Proxy) newServerConn() (io.Closer, error) { retrySleep := 50 * time.Millisecond for retryCount := 7; retryCount > 0; retryCount-- { c, err := net.Dial("tcp", p.MongoAddr) if err == nil { return c, nil } p.Log.Error(err) // abort if rs changed if p.checkRSChanged() { return nil, errNormalClose } time.Sleep(retrySleep) retrySleep = retrySleep * 2 } return nil, fmt.Errorf("could not connect to %s", p.MongoAddr) } // getServerConn gets a server connection from the pool. func (p *Proxy) getServerConn() (net.Conn, error) { c, err := p.serverPool.Acquire() if err != nil { return nil, err } return c.(net.Conn), nil } func (p *Proxy) serverCloseErrorHandler(err error) { p.Log.Error(err) } // proxyMessage proxies a message, possibly it's response, and possibly a // follow up call. 
func (p *Proxy) proxyMessage( h *messageHeader, client net.Conn, server net.Conn, lastError *LastError, ) error { p.Log.Debugf("proxying message %s from %s for %s", h, client.RemoteAddr(), p) deadline := time.Now().Add(p.ReplicaSet.MessageTimeout) server.SetDeadline(deadline) client.SetDeadline(deadline) // OpQuery may need to be transformed and need special handling in order to // make the proxy transparent. if h.OpCode == OpQuery { stats.BumpSum(p.stats, "message.with.response", 1) return p.ReplicaSet.ProxyQuery.Proxy(h, client, server, lastError) } // Anything besides a getlasterror call (which requires an OpQuery) resets // the lastError. if lastError.Exists() { p.Log.Debug("reset getLastError cache") lastError.Reset() } // For other Ops we proxy the header & raw body over. if err := h.WriteTo(server); err != nil { p.Log.Error(err) return err } if _, err := io.CopyN(server, client, int64(h.MessageLength-headerLen)); err != nil { p.Log.Error(err) return err } // For Ops with responses we proxy the raw response message over. if h.OpCode.HasResponse() { stats.BumpSum(p.stats, "message.with.response", 1) if err := copyMessage(client, server); err != nil { p.Log.Error(err) return err } } return nil } // clientAcceptLoop accepts new clients and creates a clientServeLoop for each // new client that connects to the proxy. func (p *Proxy) clientAcceptLoop() { for { p.wg.Add(1) c, err := p.ClientListener.Accept() if err != nil { p.wg.Done() if strings.Contains(err.Error(), "use of closed network connection") { break } p.Log.Error(err) continue } go p.clientServeLoop(c) } } // clientServeLoop loops on a single client connected to the proxy and // dispatches its requests. 
func (p *Proxy) clientServeLoop(c net.Conn) { remoteIP := c.RemoteAddr().(*net.TCPAddr).IP.String() // enforce per-client max connection limit if p.maxPerClientConnections.inc(remoteIP) { c.Close() stats.BumpSum(p.stats, "client.rejected.max.connections", 1) p.Log.Errorf("rejecting client connection due to max connections limit: %s", remoteIP) return } // turn on TCP keep-alive and set it to the recommended period of 2 minutes // http://docs.mongodb.org/manual/faq/diagnostics/#faq-keepalive if conn, ok := c.(*net.TCPConn); ok { conn.SetKeepAlivePeriod(2 * time.Minute) conn.SetKeepAlive(true) } c = teeIf(fmt.Sprintf("client %s <=> %s", c.RemoteAddr(), p), c) p.Log.Infof("client %s connected to %s", c.RemoteAddr(), p) stats.BumpSum(p.stats, "client.connected", 1) defer func() { p.Log.Infof("client %s disconnected from %s", c.RemoteAddr(), p) p.wg.Done() if err := c.Close(); err != nil { p.Log.Error(err) } p.maxPerClientConnections.dec(remoteIP) }() var lastError LastError for { h, err := p.idleClientReadHeader(c) if err != nil { if err != errNormalClose { p.Log.Error(err) } return } mpt := stats.BumpTime(p.stats, "message.proxy.time") serverConn, err := p.getServerConn() if err != nil { if err != errNormalClose { p.Log.Error(err) } return } scht := stats.BumpTime(p.stats, "server.conn.held.time") for { err := p.proxyMessage(h, c, serverConn, &lastError) if err != nil { p.serverPool.Discard(serverConn) p.Log.Error(err) stats.BumpSum(p.stats, "message.proxy.error", 1) if ne, ok := err.(net.Error); ok && ne.Timeout() { stats.BumpSum(p.stats, "message.proxy.timeout", 1) } if err == errRSChanged { go p.ReplicaSet.Restart() } return } // One message was proxied, stop it's timer. mpt.End() if !h.OpCode.IsMutation() { break } // If the operation we just performed was a mutation, we always make the // follow up request on the same server because it's possibly a getLastErr // call which expects this behavior. 
stats.BumpSum(p.stats, "message.with.mutation", 1) h, err = p.gleClientReadHeader(c) if err != nil { // Client did not make _any_ query within the GetLastErrorTimeout. // Return the server to the pool and wait go back to outer loop. if err == errClientReadTimeout { break } // Prevent noise of normal client disconnects, but log if anything else. if err != errNormalClose { p.Log.Error(err) } // We need to return our server to the pool (it's still good as far // as we know). p.serverPool.Release(serverConn) return } // Successfully read message when waiting for the getLastError call. mpt = stats.BumpTime(p.stats, "message.proxy.time") } p.serverPool.Release(serverConn) scht.End() stats.BumpSum(p.stats, "message.proxy.success", 1) } } // We wait for upto ClientIdleTimeout in MessageTimeout increments and keep // checking if we're waiting to be closed. This ensures that at worse we // wait for MessageTimeout when closing even when we're idling. func (p *Proxy) idleClientReadHeader(c net.Conn) (*messageHeader, error) { h, err := p.clientReadHeader(c, p.ReplicaSet.ClientIdleTimeout) if err == errClientReadTimeout { stats.BumpSum(p.stats, "client.idle.timeout", 1) } return h, err } func (p *Proxy) gleClientReadHeader(c net.Conn) (*messageHeader, error) { h, err := p.clientReadHeader(c, p.ReplicaSet.GetLastErrorTimeout) if err == errClientReadTimeout { stats.BumpSum(p.stats, "client.gle.timeout", 1) } return h, err } func (p *Proxy) clientReadHeader(c net.Conn, timeout time.Duration) (*messageHeader, error) { t := stats.BumpTime(p.stats, "client.read.header.time") type headerError struct { header *messageHeader error error } resChan := make(chan headerError) c.SetReadDeadline(time.Now().Add(timeout)) go func() { h, err := readHeader(c) resChan <- headerError{header: h, error: err} }() closed := false var response headerError select { case response = <-resChan: // all good case <-p.closed: closed = true c.SetReadDeadline(timeInPast) response = <-resChan } // Successfully 
read a header. if response.error == nil { t.End() return response.header, nil } // Client side disconnected. if response.error == io.EOF { stats.BumpSum(p.stats, "client.clean.disconnect", 1) return nil, errNormalClose } // We hit our ReadDeadline. if ne, ok := response.error.(net.Error); ok && ne.Timeout() { if closed { stats.BumpSum(p.stats, "client.clean.disconnect", 1) return nil, errNormalClose } return nil, errClientReadTimeout } // Some other unknown error. stats.BumpSum(p.stats, "client.error.disconnect", 1) p.Log.Error(response.error) return nil, response.error } var teeIfEnable = os.Getenv("MONGOPROXY_TEE") == "1" type teeConn struct { context string net.Conn } func (t teeConn) Read(b []byte) (int, error) { n, err := t.Conn.Read(b) if n > 0 { fmt.Fprintf(os.Stdout, "READ %s: %s %v\n", t.context, b[0:n], b[0:n]) } return n, err } func (t teeConn) Write(b []byte) (int, error) { n, err := t.Conn.Write(b) if n > 0 { fmt.Fprintf(os.Stdout, "WRIT %s: %s %v\n", t.context, b[0:n], b[0:n]) } return n, err } func teeIf(context string, c net.Conn) net.Conn { if teeIfEnable { return teeConn{ context: context, Conn: c, } } return c } type maxPerClientConnections struct { max uint counts map[string]uint mutex sync.Mutex } func newMaxPerClientConnections(max uint) *maxPerClientConnections { return &maxPerClientConnections{ max: max, counts: make(map[string]uint), } } func (m *maxPerClientConnections) inc(remoteIP string) bool { m.mutex.Lock() defer m.mutex.Unlock() current := m.counts[remoteIP] if current >= m.max { return true } m.counts[remoteIP] = current + 1 return false } func (m *maxPerClientConnections) dec(remoteIP string) { m.mutex.Lock() defer m.mutex.Unlock() current := m.counts[remoteIP] // delete rather than having entries with 0 connections if current == 1 { delete(m.counts, remoteIP) } else { m.counts[remoteIP] = current - 1 } } ================================================ FILE: proxy_test.go ================================================ package 
dvara import ( "fmt" "strings" "testing" "github.com/facebookgo/ensure" "github.com/facebookgo/inject" "github.com/facebookgo/mgotest" "github.com/facebookgo/startstop" "github.com/facebookgo/stats" "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" ) func TestParallelInsertWithUniqueIndex(t *testing.T) { t.Parallel() if disableSlowTests { t.Skip("TestParallelInsertWithUniqueIndex disabled because it's slow") } h := NewSingleHarness(t) defer h.Stop() limit := 20000 c := make(chan int, limit) for i := 0; i < 3; i++ { go inserter(h.ProxySession(), c, limit) } set := make(map[int]bool) for k := range c { if set[k] { t.Fatal("Double write on same value") } set[k] = true if len(set) == limit { break } } } func inserter(s *mgo.Session, channel chan int, limit int) { defer s.Close() c := s.DB("test").C("test") c.EnsureIndex(mgo.Index{Key: []string{"phoneNum"}, Unique: true}) for i := 1; i <= limit; i++ { if err := c.Insert(bson.M{"phoneNum": i}); err == nil { channel <- i } } } func TestSimpleCRUD(t *testing.T) { t.Parallel() p := NewSingleHarness(t) defer p.Stop() session := p.ProxySession() defer session.Close() collection := session.DB("test").C("coll1") data := map[string]interface{}{ "_id": 1, "name": "abc", } err := collection.Insert(data) if err != nil { t.Fatal("insertion error", err) } n, err := collection.Count() if err != nil { t.Fatal(err) } if n != 1 { t.Fatalf("expecting 1 got %d", n) } result := make(map[string]interface{}) collection.Find(bson.M{"_id": 1}).One(&result) if result["name"] != "abc" { t.Fatal("expecting name abc got", result) } err = collection.DropCollection() if err != nil { t.Fatal(err) } } // inserting data with same id field twice should fail func TestIDConstraint(t *testing.T) { t.Parallel() p := NewSingleHarness(t) defer p.Stop() session := p.ProxySession() defer session.Close() collection := session.DB("test").C("coll1") data := map[string]interface{}{ "_id": 1, "name": "abc", } err := collection.Insert(data) if err != nil { t.Fatal("insertion 
error", err) } err = collection.Insert(data) if err == nil { t.Fatal("insertion failed on same id without write concern") } } // inserting data voilating index clause on a separate connection should fail func TestEnsureIndex(t *testing.T) { t.Parallel() p := NewSingleHarness(t) defer p.Stop() session := p.ProxySession() collection := session.DB("test").C("coll1") index := mgo.Index{ Key: []string{"lastname", "firstname"}, Unique: true, DropDups: true, Background: true, // See notes. Sparse: true, } err := collection.EnsureIndex(index) ensure.Nil(t, err) err = collection.Insert( map[string]string{ "firstname": "harvey", "lastname": "dent", }, ) if err != nil { t.Fatal("insertion error", err) } session.Close() session = p.ProxySession() defer session.Close() collection = session.DB("test").C("coll1") err = collection.Insert( map[string]string{ "firstname": "harvey", "lastname": "dent", }, ) ensure.NotNil(t, err) } // inserting same data after dropping an index should work func TestDropIndex(t *testing.T) { t.Parallel() p := NewSingleHarness(t) defer p.Stop() session := p.ProxySession() collection := session.DB("test").C("coll1") index := mgo.Index{ Key: []string{"lastname", "firstname"}, Unique: true, DropDups: true, Background: true, // See notes. 
Sparse: true, } err := collection.EnsureIndex(index) if err != nil { t.Fatal("ensure index call failed") } err = collection.Insert( map[string]string{ "firstname": "harvey", "lastname": "dent", }, ) if err != nil { t.Fatal("insertion error", err) } collection.DropIndex("lastname", "firstname") session.Close() session = p.ProxySession() defer session.Close() collection = session.DB("test").C("coll1") err = collection.Insert( map[string]string{ "firstname": "harvey", "lastname": "dent", }, ) if err != nil { t.Fatal("drop index did not work") } } func TestRemoval(t *testing.T) { t.Parallel() p := NewSingleHarness(t) defer p.Stop() session := p.ProxySession() defer session.Close() collection := session.DB("test").C("coll1") if err := collection.Insert(bson.M{"S": "hello", "I": 24}); err != nil { t.Fatal(err) } if err := collection.Remove(bson.M{"S": "hello", "I": 24}); err != nil { t.Fatal(err) } var res []interface{} collection.Find(bson.M{"S": "hello", "I": 24}).All(&res) if res != nil { t.Fatal("found object after delete", res) } if err := collection.Remove(bson.M{"S": "hello", "I": 24}); err == nil { t.Fatal("removing nonexistant document should error") } } func TestUpdate(t *testing.T) { t.Parallel() p := NewSingleHarness(t) defer p.Stop() session := p.ProxySession() defer session.Close() collection := session.DB("test").C("coll1") if err := collection.Insert(bson.M{"_id": "1234", "name": "Alfred"}); err != nil { t.Fatal(err) } var result map[string]interface{} collection.Find(nil).One(&result) if result["name"] != "Alfred" { t.Fatal("insert failed") } if err := collection.Update(bson.M{"_id": "1234"}, bson.M{"name": "Jeeves"}); err != nil { t.Fatal("update failed with", err) } collection.Find(nil).One(&result) if result["name"] != "Jeeves" { t.Fatal("update failed") } if err := collection.Update(bson.M{"_id": "00000"}, bson.M{"name": "Jeeves"}); err == nil { t.Fatal("update failed") } } func TestStopChattyClient(t *testing.T) { t.Parallel() p := 
NewSingleHarness(t) session := p.ProxySession() defer session.Close() fin := make(chan struct{}) go func() { collection := session.DB("test").C("coll1") i := 0 for { select { default: collection.Insert(bson.M{"value": i}) i++ case <-fin: return } } }() close(fin) p.Stop() } func TestStopIdleClient(t *testing.T) { t.Parallel() p := NewSingleHarness(t) session := p.ProxySession() defer session.Close() if err := session.DB("test").C("col").Insert(bson.M{"v": 1}); err != nil { t.Fatal(err) } p.Stop() } func TestZeroMaxConnections(t *testing.T) { t.Parallel() p := &Proxy{ReplicaSet: &ReplicaSet{}} err := p.Start() if err != errZeroMaxConnections { t.Fatal("did not get expected error") } } func TestNoAddrsGiven(t *testing.T) { t.Parallel() replicaSet := ReplicaSet{MaxConnections: 1} log := tLogger{TB: t} var graph inject.Graph err := graph.Provide( &inject.Object{Value: &log}, &inject.Object{Value: &replicaSet}, &inject.Object{Value: &stats.HookClient{}}, ) ensure.Nil(t, err) ensure.Nil(t, graph.Populate()) objects := graph.Objects() err = startstop.Start(objects, &log) if err != errNoAddrsGiven { t.Fatalf("did not get expected error, got: %s", err) } } func TestSingleNodeWhenExpectingRS(t *testing.T) { t.Parallel() mgoserver := mgotest.NewStartedServer(t) defer mgoserver.Stop() replicaSet := ReplicaSet{ Addrs: fmt.Sprintf("127.0.0.1:%d,127.0.0.1:%d", mgoserver.Port, mgoserver.Port+1), MaxConnections: 1, } log := tLogger{TB: t} var graph inject.Graph err := graph.Provide( &inject.Object{Value: &log}, &inject.Object{Value: &replicaSet}, &inject.Object{Value: &stats.HookClient{}}, ) ensure.Nil(t, err) ensure.Nil(t, graph.Populate()) objects := graph.Objects() err = startstop.Start(objects, &log) if err == nil || !strings.Contains(err.Error(), "was expecting it to be in a replica set") { t.Fatalf("did not get expected error, got: %s", err) } } func TestStopListenerCloseError(t *testing.T) { t.Parallel() p := NewSingleHarness(t) p.Stop() err := p.ReplicaSet.Stop() if err == 
nil || !strings.Contains(err.Error(), "use of closed network connection") { t.Fatalf("did not get expected error, instead got: %s", err) } } func TestMongoGoingAwayAndReturning(t *testing.T) { t.Parallel() p := NewSingleHarness(t) session := p.ProxySession() defer session.Close() collection := session.DB("test").C("coll1") if err := collection.Insert(bson.M{"value": 1}); err != nil { t.Fatal(err) } p.MgoServer.Stop() p.MgoServer.Start() // For now we can only gurantee that eventually things will work again. In an // ideal world the very first client connection after mongo returns should // work, and we shouldn't need a loop here. for { collection = session.Copy().DB("test").C("coll1") if err := collection.Insert(bson.M{"value": 3}); err == nil { break } } p.Stop() } func benchmarkInsertRead(b *testing.B, session *mgo.Session) { defer session.Close() col := session.DB("test").C("col") col.EnsureIndex(mgo.Index{Key: []string{"answer"}, Unique: true}) insertDocs := bson.D{bson.DocElem{Name: "answer"}} inserted := bson.M{} b.ResetTimer() for i := 0; i < b.N; i++ { insertDocs[0].Value = i if err := col.Insert(insertDocs); err != nil { b.Fatal(err) } if err := col.Find(insertDocs).One(inserted); err != nil { b.Fatal(err) } if _, ok := inserted["_id"]; !ok { b.Fatalf("no _id found: %+v", inserted) } } } func BenchmarkInsertReadProxy(b *testing.B) { p := NewSingleHarness(b) benchmarkInsertRead(b, p.ProxySession()) } func BenchmarkInsertReadDirect(b *testing.B) { p := NewSingleHarness(b) benchmarkInsertRead(b, p.RealSession()) } ================================================ FILE: readme.md ================================================ dvara [![Build Status](https://secure.travis-ci.org/facebookgo/dvara.png)](http://travis-ci.org/facebookgo/dvara) ===== **NOTE**: dvara is no longer in use and we are not accepting new pull requests for it. Fork away and use it if it works for you! --- dvara provides a connection pooling proxy for [MongoDB](http://www.mongodb.org/). 
For more information look at the associated blog post: http://blog.parse.com/2014/06/23/dvara/. To build from source you'll need [Go](http://golang.org/). With it you can install it using: go get github.com/facebookgo/dvara/cmd/dvara Library documentation: https://godoc.org/github.com/facebookgo/dvara ================================================ FILE: replica_set.go ================================================ package dvara import ( "errors" "flag" "fmt" "net" "os" "strings" "sync" "time" "github.com/facebookgo/stackerr" "github.com/facebookgo/stats" ) var hardRestart = flag.Bool( "hard_restart", true, "if true will drop clients on restart", ) // Logger allows for simple text logging. type Logger interface { Error(args ...interface{}) Errorf(format string, args ...interface{}) Warn(args ...interface{}) Warnf(format string, args ...interface{}) Info(args ...interface{}) Infof(format string, args ...interface{}) Debug(args ...interface{}) Debugf(format string, args ...interface{}) } var errNoAddrsGiven = errors.New("dvara: no seed addresses given for ReplicaSet") // ReplicaSet manages the real => proxy address mapping. // NewReplicaSet returns the ReplicaSet given the list of seed servers. It is // required for the seed servers to be a strict subset of the actual members if // they are reachable. That is, if two of the addresses are members of // different replica sets, it will be considered an error. type ReplicaSet struct { Log Logger `inject:""` ReplicaSetStateCreator *ReplicaSetStateCreator `inject:""` ProxyQuery *ProxyQuery `inject:""` // Stats if provided will be used to record interesting stats. Stats stats.Client `inject:""` // Comma separated list of mongo addresses. This is the list of "seed" // servers, and one of two conditions must be met for each entry here -- it's // either alive and part of the same replica set as all others listed, or is // not reachable. 
Addrs string // PortStart and PortEnd define the port range within which proxies will be // allocated. PortStart int PortEnd int // Maximum number of connections that will be established to each mongo node. MaxConnections uint // MinIdleConnections is the number of idle server connections we'll keep // around. MinIdleConnections uint // ServerIdleTimeout is the duration after which a server connection will be // considered idle. ServerIdleTimeout time.Duration // ServerClosePoolSize is the number of goroutines that will handle closing // server connections. ServerClosePoolSize uint // ClientIdleTimeout is how long until we'll consider a client connection // idle and disconnect and release it's resources. ClientIdleTimeout time.Duration // MaxPerClientConnections is how many client connections are allowed from a // single client. MaxPerClientConnections uint // GetLastErrorTimeout is how long we'll hold on to an acquired server // connection expecting a possibly getLastError call. GetLastErrorTimeout time.Duration // MessageTimeout is used to determine the timeout for a single message to be // proxied. MessageTimeout time.Duration // Name is the name of the replica set to connect to. Nodes that are not part // of this replica set will be ignored. If this is empty, the first replica set // will be used Name string proxyToReal map[string]string realToProxy map[string]string ignoredReal map[string]ReplicaState proxies map[string]*Proxy restarter *sync.Once lastState *ReplicaSetState } // Start starts proxies to support this ReplicaSet. 
func (r *ReplicaSet) Start() error { r.proxyToReal = make(map[string]string) r.realToProxy = make(map[string]string) r.ignoredReal = make(map[string]ReplicaState) r.proxies = make(map[string]*Proxy) if r.Addrs == "" { return errNoAddrsGiven } rawAddrs := strings.Split(r.Addrs, ",") var err error r.lastState, err = r.ReplicaSetStateCreator.FromAddrs(rawAddrs, r.Name) if err != nil { return err } healthyAddrs := r.lastState.Addrs() // Ensure we have at least one health address. if len(healthyAddrs) == 0 { return stackerr.Newf("no healthy primaries or secondaries: %s", r.Addrs) } // Add discovered nodes to seed address list. Over time if the original seed // nodes have gone away and new nodes have joined this ensures that we'll // still be able to connect. r.Addrs = strings.Join(uniq(append(rawAddrs, healthyAddrs...)), ",") r.restarter = new(sync.Once) for _, addr := range healthyAddrs { listener, err := r.newListener() if err != nil { return err } p := &Proxy{ Log: r.Log, ReplicaSet: r, ClientListener: listener, ProxyAddr: r.proxyAddr(listener), MongoAddr: addr, } if err := r.add(p); err != nil { return err } } // add the ignored hosts, unless lastRS is nil (single node mode) if r.lastState.lastRS != nil { for _, member := range r.lastState.lastRS.Members { if _, ok := r.realToProxy[member.Name]; !ok { r.ignoredReal[member.Name] = member.State } } } var wg sync.WaitGroup wg.Add(len(r.proxies)) errch := make(chan error, len(r.proxies)) for _, p := range r.proxies { go func(p *Proxy) { defer wg.Done() if err := p.Start(); err != nil { r.Log.Error(err) errch <- stackerr.Wrap(err) } }(p) } wg.Wait() select { default: return nil case err := <-errch: return err } } // Stop stops all the associated proxies for this ReplicaSet. 
func (r *ReplicaSet) Stop() error { return r.stop(false) } func (r *ReplicaSet) stop(hard bool) error { var wg sync.WaitGroup wg.Add(len(r.proxies)) errch := make(chan error, len(r.proxies)) for _, p := range r.proxies { go func(p *Proxy) { defer wg.Done() if err := p.stop(hard); err != nil { r.Log.Error(err) errch <- stackerr.Wrap(err) } }(p) } wg.Wait() select { default: return nil case err := <-errch: return err } } // Restart stops all the proxies and restarts them. This is used when we detect // an RS config change, like when an election happens. func (r *ReplicaSet) Restart() { r.restarter.Do(func() { r.Log.Info("restart triggered") if err := r.stop(*hardRestart); err != nil { // We log and ignore this hoping for a successful start anyways. r.Log.Errorf("stop failed for restart: %s", err) } else { r.Log.Info("successfully stopped for restart") } if err := r.Start(); err != nil { // We panic here because we can't repair from here and are pretty much // fucked. panic(fmt.Errorf("start failed for restart: %s", err)) } r.Log.Info("successfully restarted") }) } func (r *ReplicaSet) proxyAddr(l net.Listener) string { _, port, err := net.SplitHostPort(l.Addr().String()) if err != nil { panic(err) } return fmt.Sprintf("%s:%s", r.proxyHostname(), port) } func (r *ReplicaSet) proxyHostname() string { const home = "127.0.0.1" hostname, err := os.Hostname() if err != nil { r.Log.Error(err) return home } // The follow logic ensures that the hostname resolves to a local address. // If it doesn't we don't use it since it probably wont work anyways. 
hostnameAddrs, err := net.LookupHost(hostname) if err != nil { r.Log.Error(err) return home } interfaceAddrs, err := net.InterfaceAddrs() if err != nil { r.Log.Error(err) return home } for _, ia := range interfaceAddrs { sa := ia.String() for _, ha := range hostnameAddrs { // check for an exact match or a match ignoring the suffix bits if sa == ha || strings.HasPrefix(sa, ha+"/") { return hostname } } } r.Log.Warnf("hostname %s doesn't resolve to the current host", hostname) return home } func (r *ReplicaSet) newListener() (net.Listener, error) { for i := r.PortStart; i <= r.PortEnd; i++ { listener, err := net.Listen("tcp", fmt.Sprintf(":%d", i)) if err == nil { return listener, nil } } return nil, fmt.Errorf( "could not find a free port in range %d-%d", r.PortStart, r.PortEnd, ) } // add a proxy/mongo mapping. func (r *ReplicaSet) add(p *Proxy) error { if _, ok := r.proxyToReal[p.ProxyAddr]; ok { return fmt.Errorf("proxy %s already used in ReplicaSet", p.ProxyAddr) } if _, ok := r.realToProxy[p.MongoAddr]; ok { return fmt.Errorf("mongo %s already exists in ReplicaSet", p.MongoAddr) } r.Log.Infof("added %s", p) r.proxyToReal[p.ProxyAddr] = p.MongoAddr r.realToProxy[p.MongoAddr] = p.ProxyAddr r.proxies[p.ProxyAddr] = p return nil } // Proxy returns the corresponding proxy address for the given real mongo // address. func (r *ReplicaSet) Proxy(h string) (string, error) { p, ok := r.realToProxy[h] if !ok { if s, ok := r.ignoredReal[h]; ok { return "", &ProxyMapperError{ RealHost: h, State: s, } } return "", fmt.Errorf("mongo %s is not in ReplicaSet", h) } return p, nil } // ProxyMembers returns the list of proxy members in this ReplicaSet. func (r *ReplicaSet) ProxyMembers() []string { members := make([]string, 0, len(r.proxyToReal)) for r := range r.proxyToReal { members = append(members, r) } return members } // SameRS checks if the given replSetGetStatusResponse is the same as the last // state. 
func (r *ReplicaSet) SameRS(o *replSetGetStatusResponse) bool { return r.lastState.SameRS(o) } // SameIM checks if the given isMasterResponse is the same as the last state. func (r *ReplicaSet) SameIM(o *isMasterResponse) bool { return r.lastState.SameIM(o) } // ProxyMapperError occurs when a known host is being ignored and does not have // a corresponding proxy address. type ProxyMapperError struct { RealHost string State ReplicaState } func (p *ProxyMapperError) Error() string { return fmt.Sprintf("error mapping host %s in state %s", p.RealHost, p.State) } // uniq takes a slice of strings and returns a new slice with duplicates // removed. func uniq(set []string) []string { m := make(map[string]struct{}, len(set)) for _, s := range set { m[s] = struct{}{} } news := make([]string, 0, len(m)) for s := range m { news = append(news, s) } return news } ================================================ FILE: replica_set_test.go ================================================ package dvara import ( "fmt" "testing" "github.com/facebookgo/subset" "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" ) func TestReplicaSetMembers(t *testing.T) { t.Parallel() h := NewReplicaSetHarness(3, t) defer h.Stop() proxyMembers := h.ReplicaSet.ProxyMembers() session := h.ProxySession() defer session.Close() status, err := replSetGetStatus(session) if err != nil { t.Fatal(err) } outerProxyResponseCheckLoop: for _, m := range status.Members { for _, p := range proxyMembers { if m.Name == p { continue outerProxyResponseCheckLoop } } t.Fatalf("Unexpected member: %s", m.Name) } } func TestStopNodeInReplica(t *testing.T) { t.Parallel() h := NewReplicaSetHarness(2, t) defer h.Stop() const dbName = "test" const colName = "foo" const keyName = "answer" d := bson.M{"answer": "42"} s := h.ProxySession() defer s.Close() s.SetSafe(&mgo.Safe{W: 2, WMode: "majority"}) if err := s.DB(dbName).C(colName).Insert(d); err != nil { t.Fatal(err) } h.MgoReplicaSet.Servers[0].Stop() s.SetMode(mgo.Monotonic, true) var 
actual bson.M if err := s.DB(dbName).C(colName).Find(d).One(&actual); err != nil { t.Fatal(err) } subset.Assert(t, d, actual) } func TestProxyNotInReplicaSet(t *testing.T) { t.Parallel() h := NewSingleHarness(t) defer h.Stop() addr := "127.0.0.1:666" expected := fmt.Sprintf("mongo %s is not in ReplicaSet", addr) _, err := h.ReplicaSet.Proxy(addr) if err == nil || err.Error() != expected { t.Fatalf("did not get expected error, got: %s", err) } } func TestAddSameProxyToReplicaSet(t *testing.T) { t.Parallel() r := &ReplicaSet{ Log: &tLogger{TB: t}, proxyToReal: make(map[string]string), realToProxy: make(map[string]string), proxies: make(map[string]*Proxy), } p := &Proxy{ ProxyAddr: "1", MongoAddr: "2", } if err := r.add(p); err != nil { t.Fatal(err) } expected := fmt.Sprintf("proxy %s already used in ReplicaSet", p.ProxyAddr) err := r.add(p) if err == nil || err.Error() != expected { t.Fatalf("did not get expected error, got: %s", err) } } func TestAddSameMongoToReplicaSet(t *testing.T) { t.Parallel() r := &ReplicaSet{ Log: &tLogger{TB: t}, proxyToReal: make(map[string]string), realToProxy: make(map[string]string), proxies: make(map[string]*Proxy), } p := &Proxy{ ProxyAddr: "1", MongoAddr: "2", } if err := r.add(p); err != nil { t.Fatal(err) } p = &Proxy{ ProxyAddr: "3", MongoAddr: p.MongoAddr, } expected := fmt.Sprintf("mongo %s already exists in ReplicaSet", p.MongoAddr) err := r.add(p) if err == nil || err.Error() != expected { t.Fatalf("did not get expected error, got: %s", err) } } func TestNewListenerZeroZeroRandomPort(t *testing.T) { t.Parallel() r := &ReplicaSet{} l, err := r.newListener() if err != nil { t.Fatal(err) } l.Close() } func TestNewListenerError(t *testing.T) { t.Parallel() r := &ReplicaSet{PortStart: 1, PortEnd: 1} _, err := r.newListener() expected := "could not find a free port in range 1-1" if err == nil || err.Error() != expected { t.Fatalf("did not get expected error, got: %s", err) } } ================================================ FILE: 
response_rewriter.go ================================================ package dvara import ( "bytes" "errors" "flag" "fmt" "io" "io/ioutil" "strings" "github.com/davecgh/go-spew/spew" "gopkg.in/mgo.v2/bson" ) var ( proxyAllQueries = flag.Bool( "dvara.proxy-all", false, "if true all queries will be proxied and logger", ) adminCollectionName = []byte("admin.$cmd\000") cmdCollectionSuffix = []byte(".$cmd\000") ) // ProxyQuery proxies an OpQuery and a corresponding response. type ProxyQuery struct { Log Logger `inject:""` GetLastErrorRewriter *GetLastErrorRewriter `inject:""` IsMasterResponseRewriter *IsMasterResponseRewriter `inject:""` ReplSetGetStatusResponseRewriter *ReplSetGetStatusResponseRewriter `inject:""` } // Proxy proxies an OpQuery and a corresponding response. func (p *ProxyQuery) Proxy( h *messageHeader, client io.ReadWriter, server io.ReadWriter, lastError *LastError, ) error { // https://github.com/mongodb/mongo/search?q=lastError.disableForCommand // Shows the logic we need to be in sync with. Unfortunately it isn't a // simple check to determine this, and may change underneath us at the mongo // layer. 
resetLastError := true parts := [][]byte{h.ToWire()} var flags [4]byte if _, err := io.ReadFull(client, flags[:]); err != nil { p.Log.Error(err) return err } parts = append(parts, flags[:]) fullCollectionName, err := readCString(client) if err != nil { p.Log.Error(err) return err } parts = append(parts, fullCollectionName) var rewriter responseRewriter if *proxyAllQueries || bytes.HasSuffix(fullCollectionName, cmdCollectionSuffix) { var twoInt32 [8]byte if _, err := io.ReadFull(client, twoInt32[:]); err != nil { p.Log.Error(err) return err } parts = append(parts, twoInt32[:]) queryDoc, err := readDocument(client) if err != nil { p.Log.Error(err) return err } parts = append(parts, queryDoc) var q bson.D if err := bson.Unmarshal(queryDoc, &q); err != nil { p.Log.Error(err) return err } p.Log.Debugf( "buffered OpQuery for %s: %s", fullCollectionName[:len(fullCollectionName)-1], spew.Sdump(q), ) if hasKey(q, "getLastError") { return p.GetLastErrorRewriter.Rewrite( h, parts, client, server, lastError, ) } if hasKey(q, "isMaster") { rewriter = p.IsMasterResponseRewriter } if bytes.Equal(adminCollectionName, fullCollectionName) && hasKey(q, "replSetGetStatus") { rewriter = p.ReplSetGetStatusResponseRewriter } if rewriter != nil { // If forShell is specified, we don't want to reset the last error. See // comment above around resetLastError for details. 
resetLastError = hasKey(q, "forShell") } } if resetLastError && lastError.Exists() { p.Log.Debug("reset getLastError cache") lastError.Reset() } var written int for _, b := range parts { n, err := server.Write(b) if err != nil { p.Log.Error(err) return err } written += n } pending := int64(h.MessageLength) - int64(written) if _, err := io.CopyN(server, client, pending); err != nil { p.Log.Error(err) return err } if rewriter != nil { if err := rewriter.Rewrite(client, server); err != nil { return err } return nil } if err := copyMessage(client, server); err != nil { p.Log.Error(err) return err } return nil } // LastError holds the last known error. type LastError struct { header *messageHeader rest bytes.Buffer } // Exists returns true if this instance contains a cached error. func (l *LastError) Exists() bool { return l.header != nil } // Reset resets the stored error clearing it. func (l *LastError) Reset() { l.header = nil l.rest.Reset() } // GetLastErrorRewriter handles getLastError requests and proxies, caches or // sends cached responses as necessary. type GetLastErrorRewriter struct { Log Logger `inject:""` } // Rewrite handles getLastError requests. func (r *GetLastErrorRewriter) Rewrite( h *messageHeader, parts [][]byte, client io.ReadWriter, server io.ReadWriter, lastError *LastError, ) error { if !lastError.Exists() { // We're going to be performing a real getLastError query and caching the // response. 
var written int for _, b := range parts { n, err := server.Write(b) if err != nil { r.Log.Error(err) return err } written += n } pending := int64(h.MessageLength) - int64(written) if _, err := io.CopyN(server, client, pending); err != nil { r.Log.Error(err) return err } var err error if lastError.header, err = readHeader(server); err != nil { r.Log.Error(err) return err } pending = int64(lastError.header.MessageLength - headerLen) if _, err = io.CopyN(&lastError.rest, server, pending); err != nil { r.Log.Error(err) return err } r.Log.Debugf("caching new getLastError response: %s", lastError.rest.Bytes()) } else { // We need to discard the pending bytes from the client from the query // before we send it our cached response. var written int for _, b := range parts { written += len(b) } pending := int64(h.MessageLength) - int64(written) if _, err := io.CopyN(ioutil.Discard, client, pending); err != nil { r.Log.Error(err) return err } // Modify and send the cached response for this request. lastError.header.ResponseTo = h.RequestID r.Log.Debugf("using cached getLastError response: %s", lastError.rest.Bytes()) } if err := lastError.header.WriteTo(client); err != nil { r.Log.Error(err) return err } if _, err := client.Write(lastError.rest.Bytes()); err != nil { r.Log.Error(err) return err } return nil } var errRSChanged = errors.New("dvara: replset config changed") // ProxyMapper maps real mongo addresses to their corresponding proxy // addresses. type ProxyMapper interface { Proxy(h string) (string, error) } // ReplicaStateCompare provides the last ReplicaSetState and allows for // checking if it has changed as we rewrite/proxy the isMaster & // replSetGetStatus queries. 
type ReplicaStateCompare interface { SameRS(o *replSetGetStatusResponse) bool SameIM(o *isMasterResponse) bool } type responseRewriter interface { Rewrite(client io.Writer, server io.Reader) error } type replyPrefix [20]byte var emptyPrefix replyPrefix // ReplyRW provides common helpers for rewriting replies from the server. type ReplyRW struct { Log Logger `inject:""` } // ReadOne reads a 1 document response, from the server, unmarshals it into v // and returns the various parts. func (r *ReplyRW) ReadOne(server io.Reader, v interface{}) (*messageHeader, replyPrefix, int32, error) { h, err := readHeader(server) if err != nil { r.Log.Error(err) return nil, emptyPrefix, 0, err } if h.OpCode != OpReply { err := fmt.Errorf("readOneReplyDoc: expected op %s, got %s", OpReply, h.OpCode) return nil, emptyPrefix, 0, err } var prefix replyPrefix if _, err := io.ReadFull(server, prefix[:]); err != nil { r.Log.Error(err) return nil, emptyPrefix, 0, err } numDocs := getInt32(prefix[:], 16) if numDocs != 1 { err := fmt.Errorf("readOneReplyDoc: can only handle 1 result document, got: %d", numDocs) return nil, emptyPrefix, 0, err } rawDoc, err := readDocument(server) if err != nil { r.Log.Error(err) return nil, emptyPrefix, 0, err } if err := bson.Unmarshal(rawDoc, v); err != nil { r.Log.Error(err) return nil, emptyPrefix, 0, err } return h, prefix, int32(len(rawDoc)), nil } // WriteOne writes a rewritten response to the client. 
func (r *ReplyRW) WriteOne(client io.Writer, h *messageHeader, prefix replyPrefix, oldDocLen int32, v interface{}) error { newDoc, err := bson.Marshal(v) if err != nil { return err } h.MessageLength = h.MessageLength - oldDocLen + int32(len(newDoc)) parts := [][]byte{h.ToWire(), prefix[:], newDoc} for _, p := range parts { if _, err := client.Write(p); err != nil { return err } } return nil } type isMasterResponse struct { Hosts []string `bson:"hosts,omitempty"` Primary string `bson:"primary,omitempty"` Me string `bson:"me,omitempty"` Extra bson.M `bson:",inline"` } // IsMasterResponseRewriter rewrites the response for the "isMaster" query. type IsMasterResponseRewriter struct { Log Logger `inject:""` ProxyMapper ProxyMapper `inject:""` ReplyRW *ReplyRW `inject:""` ReplicaStateCompare ReplicaStateCompare `inject:""` } // Rewrite rewrites the response for the "isMaster" query. func (r *IsMasterResponseRewriter) Rewrite(client io.Writer, server io.Reader) error { var err error var q isMasterResponse h, prefix, docLen, err := r.ReplyRW.ReadOne(server, &q) if err != nil { return err } if !r.ReplicaStateCompare.SameIM(&q) { return errRSChanged } var newHosts []string for _, h := range q.Hosts { newH, err := r.ProxyMapper.Proxy(h) if err != nil { if pme, ok := err.(*ProxyMapperError); ok { if pme.State != ReplicaStateArbiter { r.Log.Errorf("dropping member %s in state %s", h, pme.State) } continue } // unknown err return err } newHosts = append(newHosts, newH) } q.Hosts = newHosts if q.Primary != "" { // failure in mapping the primary is fatal if q.Primary, err = r.ProxyMapper.Proxy(q.Primary); err != nil { return err } } if q.Me != "" { // failure in mapping me is fatal if q.Me, err = r.ProxyMapper.Proxy(q.Me); err != nil { return err } } return r.ReplyRW.WriteOne(client, h, prefix, docLen, q) } type statusMember struct { Name string `bson:"name"` State ReplicaState `bson:"stateStr,omitempty"` Self bool `bson:"self,omitempty"` Extra bson.M `bson:",inline"` } type 
replSetGetStatusResponse struct { Name string `bson:"set,omitempty"` Members []statusMember `bson:"members"` Extra map[string]interface{} `bson:",inline"` } // ReplSetGetStatusResponseRewriter rewrites the "replSetGetStatus" response. type ReplSetGetStatusResponseRewriter struct { Log Logger `inject:""` ProxyMapper ProxyMapper `inject:""` ReplyRW *ReplyRW `inject:""` ReplicaStateCompare ReplicaStateCompare `inject:""` } // Rewrite rewrites the "replSetGetStatus" response. func (r *ReplSetGetStatusResponseRewriter) Rewrite(client io.Writer, server io.Reader) error { var err error var q replSetGetStatusResponse h, prefix, docLen, err := r.ReplyRW.ReadOne(server, &q) if err != nil { return err } if !r.ReplicaStateCompare.SameRS(&q) { return errRSChanged } var newMembers []statusMember for _, m := range q.Members { newH, err := r.ProxyMapper.Proxy(m.Name) if err != nil { if pme, ok := err.(*ProxyMapperError); ok { if pme.State != ReplicaStateArbiter { r.Log.Errorf("dropping member %s in state %s", h, pme.State) } continue } // unknown err return err } m.Name = newH newMembers = append(newMembers, m) } q.Members = newMembers return r.ReplyRW.WriteOne(client, h, prefix, docLen, q) } // case insensitive check for the specified key name in the top level. 
func hasKey(d bson.D, k string) bool { for _, v := range d { if strings.EqualFold(v.Name, k) { return true } } return false } ================================================ FILE: response_rewriter_test.go ================================================ package dvara import ( "bytes" "errors" "io" "reflect" "strings" "testing" "github.com/davecgh/go-spew/spew" "github.com/facebookgo/ensure" "github.com/facebookgo/inject" "github.com/facebookgo/startstop" "gopkg.in/mgo.v2/bson" ) var errInvalidBSON = errors.New("invalid BSON") type invalidBSON int func (i invalidBSON) GetBSON() (interface{}, error) { return nil, errInvalidBSON } var errProxyNotFound = errors.New("proxy not found") type fakeProxyMapper struct { m map[string]string } func (t fakeProxyMapper) Proxy(h string) (string, error) { if t.m != nil { if r, ok := t.m[h]; ok { return r, nil } } return "", errProxyNotFound } type fakeReplicaStateCompare struct{ sameRS, sameIM bool } func (f fakeReplicaStateCompare) SameRS(o *replSetGetStatusResponse) bool { return f.sameRS } func (f fakeReplicaStateCompare) SameIM(o *isMasterResponse) bool { return f.sameIM } func fakeReader(h messageHeader, rest []byte) io.Reader { return bytes.NewReader(append(h.ToWire(), rest...)) } func fakeSingleDocReply(v interface{}) io.Reader { b, err := bson.Marshal(v) if err != nil { panic(err) } b = append( []byte{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, }, b..., ) h := messageHeader{ OpCode: OpReply, MessageLength: int32(headerLen + len(b)), } return fakeReader(h, b) } type fakeReadWriter struct { io.Reader io.Writer } func TestResponseRWReadOne(t *testing.T) { t.Parallel() cases := []struct { Name string Server io.Reader Error string }{ { Name: "no header", Server: bytes.NewReader(nil), Error: "EOF", }, { Name: "non reply op", Server: bytes.NewReader((messageHeader{OpCode: OpDelete}).ToWire()), Error: "expected op REPLY, got DELETE", }, { Name: "EOF before flags", Server: bytes.NewReader((messageHeader{OpCode: 
OpReply}).ToWire()), Error: "EOF", }, { Name: "more than 1 document", Server: fakeReader( messageHeader{OpCode: OpReply}, []byte{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, }, ), Error: "can only handle 1 result document, got: 2", }, { Name: "EOF before document", Server: fakeReader( messageHeader{OpCode: OpReply}, []byte{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, }, ), Error: "EOF", }, { Name: "corrupted document", Server: fakeReader( messageHeader{OpCode: OpReply}, []byte{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 5, 0, 0, 0, 1, }, ), Error: "Document is corrupted", }, } for _, c := range cases { r := &ReplyRW{Log: &tLogger{TB: t}} m := bson.M{} _, _, _, err := r.ReadOne(c.Server, m) if err == nil { t.Errorf("was expecting an error for case %s", c.Name) } if !strings.Contains(err.Error(), c.Error) { t.Errorf("did not get expected error for case %s instead got %s", c.Name, err) } } } func TestResponseRWWriteOne(t *testing.T) { errWrite := errors.New("write error") t.Parallel() cases := []struct { Name string Client io.Writer Header messageHeader Prefix replyPrefix DocLen int32 Value interface{} Error string }{ { Name: "invalid bson", Value: invalidBSON(0), Error: errInvalidBSON.Error(), }, { Name: "write error", Value: map[string]string{}, Client: testWriter{ write: func(b []byte) (int, error) { return 0, errWrite }, }, Error: errWrite.Error(), }, } for _, c := range cases { r := &ReplyRW{Log: &tLogger{TB: t}} err := r.WriteOne(c.Client, &c.Header, c.Prefix, c.DocLen, c.Value) if err == nil { t.Errorf("was expecting an error for case %s", c.Name) } if !strings.Contains(err.Error(), c.Error) { t.Errorf("did not get expected error for case %s instead got %s", c.Name, err) } } } func TestIsMasterResponseRewriterFailures(t *testing.T) { t.Parallel() cases := []struct { Name string Client io.Writer Server io.Reader ProxyMapper ProxyMapper ReplicaStateCompare ReplicaStateCompare Error string }{ { Name: "no header", 
Server: bytes.NewReader(nil), Error: "EOF", }, { Name: "unknown host in 'hosts'", Server: fakeSingleDocReply( map[string]interface{}{ "hosts": []string{"foo"}, }, ), Error: errProxyNotFound.Error(), ProxyMapper: fakeProxyMapper{}, ReplicaStateCompare: fakeReplicaStateCompare{sameIM: true, sameRS: true}, }, { Name: "unknown host in 'primary'", Server: fakeSingleDocReply( map[string]interface{}{ "primary": "foo", }, ), Error: errProxyNotFound.Error(), ProxyMapper: fakeProxyMapper{}, ReplicaStateCompare: fakeReplicaStateCompare{sameIM: true, sameRS: true}, }, { Name: "unknown host in 'me'", Server: fakeSingleDocReply( map[string]interface{}{ "me": "foo", }, ), Error: errProxyNotFound.Error(), ProxyMapper: fakeProxyMapper{}, ReplicaStateCompare: fakeReplicaStateCompare{sameIM: true, sameRS: true}, }, { Name: "different im", Server: fakeSingleDocReply(map[string]interface{}{}), Error: errRSChanged.Error(), ProxyMapper: nil, ReplicaStateCompare: fakeReplicaStateCompare{sameIM: false, sameRS: true}, }, } for _, c := range cases { r := &IsMasterResponseRewriter{ Log: &tLogger{TB: t}, ProxyMapper: c.ProxyMapper, ReplicaStateCompare: c.ReplicaStateCompare, ReplyRW: &ReplyRW{ Log: &tLogger{TB: t}, }, } err := r.Rewrite(c.Client, c.Server) if err == nil { t.Errorf("was expecting an error for case %s", c.Name) } if !strings.Contains(err.Error(), c.Error) { t.Errorf("did not get expected error for case %s instead got %s", c.Name, err) } } } func TestIsMasterResponseRewriterSuccess(t *testing.T) { proxyMapper := fakeProxyMapper{ m: map[string]string{ "a": "1", "b": "2", "c": "3", }, } in := bson.M{ "hosts": []interface{}{"a", "b", "c"}, "me": "a", "primary": "b", "foo": "bar", } out := bson.M{ "hosts": []interface{}{"1", "2", "3"}, "me": "1", "primary": "2", "foo": "bar", } r := &IsMasterResponseRewriter{ Log: &tLogger{TB: t}, ProxyMapper: proxyMapper, ReplicaStateCompare: fakeReplicaStateCompare{sameIM: true, sameRS: true}, ReplyRW: &ReplyRW{ Log: &tLogger{TB: t}, }, } var 
client bytes.Buffer if err := r.Rewrite(&client, fakeSingleDocReply(in)); err != nil { t.Fatal(err) } actualOut := bson.M{} doc := client.Bytes()[headerLen+len(emptyPrefix):] if err := bson.Unmarshal(doc, &actualOut); err != nil { t.Fatal(err) } if !reflect.DeepEqual(out, actualOut) { spew.Dump(out) spew.Dump(actualOut) t.Fatal("did not get expected output") } } func TestReplSetGetStatusResponseRewriterFailures(t *testing.T) { t.Parallel() cases := []struct { Name string Client io.Writer Server io.Reader ProxyMapper ProxyMapper ReplicaStateCompare ReplicaStateCompare Error string }{ { Name: "no header", Server: bytes.NewReader(nil), Error: "EOF", }, { Name: "unknown member name", Server: fakeSingleDocReply( map[string]interface{}{ "members": []map[string]interface{}{ { "name": "foo", }, }, }, ), Error: errProxyNotFound.Error(), ProxyMapper: fakeProxyMapper{}, ReplicaStateCompare: fakeReplicaStateCompare{sameIM: true, sameRS: true}, }, { Name: "diffferent rs", Server: fakeSingleDocReply(map[string]interface{}{}), Error: errRSChanged.Error(), ProxyMapper: nil, ReplicaStateCompare: fakeReplicaStateCompare{sameIM: true, sameRS: false}, }, } for _, c := range cases { r := &ReplSetGetStatusResponseRewriter{ Log: &tLogger{TB: t}, ProxyMapper: c.ProxyMapper, ReplicaStateCompare: c.ReplicaStateCompare, ReplyRW: &ReplyRW{ Log: &tLogger{TB: t}, }, } err := r.Rewrite(c.Client, c.Server) if err == nil { t.Errorf("was expecting an error for case %s", c.Name) } if !strings.Contains(err.Error(), c.Error) { t.Errorf("did not get expected error for case %s instead got %s", c.Name, err) } } } func TestReplSetGetStatusResponseRewriterSuccess(t *testing.T) { proxyMapper := fakeProxyMapper{ m: map[string]string{ "a": "1", "b": "2", "c": "3", }, } in := bson.M{ "members": []interface{}{ bson.M{ "name": "a", "stateStr": "PRIMARY", }, bson.M{ "name": "b", }, bson.M{ "name": "c", "stateStr": "ARBITER", }, }, } out := bson.M{ "members": []interface{}{ bson.M{ "name": "1", "stateStr": 
"PRIMARY", }, bson.M{ "name": "2", }, bson.M{ "name": "3", "stateStr": "ARBITER", }, }, } r := &ReplSetGetStatusResponseRewriter{ Log: &tLogger{TB: t}, ProxyMapper: proxyMapper, ReplicaStateCompare: fakeReplicaStateCompare{sameIM: true, sameRS: true}, ReplyRW: &ReplyRW{ Log: &tLogger{TB: t}, }, } var client bytes.Buffer if err := r.Rewrite(&client, fakeSingleDocReply(in)); err != nil { t.Fatal(err) } actualOut := bson.M{} doc := client.Bytes()[headerLen+len(emptyPrefix):] if err := bson.Unmarshal(doc, &actualOut); err != nil { t.Fatal(err) } if !reflect.DeepEqual(out, actualOut) { spew.Dump(out) spew.Dump(actualOut) t.Fatal("did not get expected output") } } func TestProxyQuery(t *testing.T) { t.Parallel() var p ProxyQuery log := tLogger{TB: t} var graph inject.Graph err := graph.Provide( &inject.Object{Value: &fakeProxyMapper{}}, &inject.Object{Value: &fakeReplicaStateCompare{}}, &inject.Object{Value: &log}, &inject.Object{Value: &p}, ) ensure.Nil(t, err) ensure.Nil(t, graph.Populate()) objects := graph.Objects() ensure.Nil(t, startstop.Start(objects, &log)) defer startstop.Stop(objects, &log) cases := []struct { Name string Header *messageHeader Client io.ReadWriter Error string }{ { Name: "EOF while reading flags from client", Header: &messageHeader{}, Client: new(bytes.Buffer), Error: "EOF", }, { Name: "EOF while reading collection name", Header: &messageHeader{}, Client: fakeReadWriter{ Reader: bytes.NewReader( []byte{0, 0, 0, 0}, // flags int32 before collection name ), }, Error: "EOF", }, { Name: "EOF while reading skip/return", Header: &messageHeader{}, Client: fakeReadWriter{ Reader: bytes.NewReader( append( []byte{0, 0, 0, 0}, // flags int32 before collection name adminCollectionName..., ), ), }, Error: "EOF", }, { Name: "EOF while reading query document", Header: &messageHeader{}, Client: fakeReadWriter{ Reader: io.MultiReader( bytes.NewReader([]byte{0, 0, 0, 0}), // flags int32 before collection name bytes.NewReader(adminCollectionName), 
bytes.NewReader( []byte{ 0, 0, 0, 0, // numberToSkip int32 0, 0, 0, 0, // numberToReturn int32 1, // partial bson document length header }), ), }, Error: "EOF", }, { Name: "error while unmarshaling query document", Header: &messageHeader{}, Client: fakeReadWriter{ Reader: io.MultiReader( bytes.NewReader([]byte{0, 0, 0, 0}), // flags int32 before collection name bytes.NewReader(adminCollectionName), bytes.NewReader( []byte{ 0, 0, 0, 0, // numberToSkip int32 0, 0, 0, 0, // numberToReturn int32 5, 0, 0, 0, // bson document length header 1, // bson document }), ), }, Error: "Document is corrupted", }, } for _, c := range cases { err := p.Proxy(c.Header, c.Client, nil, nil) if err == nil || !strings.Contains(err.Error(), c.Error) { t.Fatalf("did not find expected error for %s, instead found %s", c.Name, err) } } } ================================================ FILE: rs_state.go ================================================ package dvara import ( "fmt" "sort" "time" "github.com/davecgh/go-spew/spew" "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" ) const errNotReplSet = "not running with --replSet" // ReplicaSetState is a snapshot of the RS configuration at some point in time. type ReplicaSetState struct { lastRS *replSetGetStatusResponse lastIM *isMasterResponse singleAddr string // this is only set when we're not running against a RS } // NewReplicaSetState creates a new ReplicaSetState using the given address. func NewReplicaSetState(addr string) (*ReplicaSetState, error) { info := &mgo.DialInfo{ Addrs: []string{addr}, Direct: true, Timeout: 5 * time.Second, } session, err := mgo.DialWithInfo(info) if err != nil { return nil, err } session.SetMode(mgo.Monotonic, true) session.SetSyncTimeout(5 * time.Second) session.SetSocketTimeout(5 * time.Second) defer session.Close() var r ReplicaSetState if r.lastRS, err = replSetGetStatus(session); err != nil { // This error indicates we're in Single Node Mode. That's okay. 
if err.Error() != errNotReplSet { return nil, err } r.singleAddr = addr } if r.lastIM, err = isMaster(session); err != nil { return nil, err } if r.lastRS != nil && len(r.lastRS.Members) == 1 { n := r.lastRS.Members[0] if n.State != "PRIMARY" || n.State != "SECONDARY" { return nil, fmt.Errorf("single node RS in bad state: %s", spew.Sdump(r)) } } // nodes starting up are invalid if r.lastRS != nil { for _, member := range r.lastRS.Members { if member.Self && member.State == "STARTUP" { return nil, fmt.Errorf("node is busy starting up: %s", member.Name) } } } return &r, nil } // AssertEqual checks if the given ReplicaSetState equals this one. It returns // a rich error message including the entire state for easier debugging. func (r *ReplicaSetState) AssertEqual(o *ReplicaSetState) error { if r.Equal(o) { return nil } return fmt.Errorf( "conflicting ReplicaSetState:\n%s\nVS\n%s", spew.Sdump(r), spew.Sdump(o), ) } // Equal returns true if the given ReplicaSetState is the same as this one. func (r *ReplicaSetState) Equal(o *ReplicaSetState) bool { return r.SameIM(o.lastIM) && r.SameRS(o.lastRS) } // SameRS checks if the given replSetGetStatusResponse is the same as the one // we have. func (r *ReplicaSetState) SameRS(o *replSetGetStatusResponse) bool { return sameRSMembers(r.lastRS, o) } // SameIM checks if the given isMasterResponse is the same as the one we have. func (r *ReplicaSetState) SameIM(o *isMasterResponse) bool { return sameIMMembers(r.lastIM, o) } // Addrs returns the addresses of members in primary or secondary state. func (r *ReplicaSetState) Addrs() []string { if r.singleAddr != "" { return []string{r.singleAddr} } var members []string for _, m := range r.lastRS.Members { if m.State == ReplicaStatePrimary || m.State == ReplicaStateSecondary { members = append(members, m.Name) } } return members } // ReplicaSetStateCreator allows for creating a ReplicaSetState from a given // set of seed addresses. 
type ReplicaSetStateCreator struct { Log Logger `inject:""` } // FromAddrs creates a ReplicaSetState from the given set of see addresses. It // requires the addresses to be part of the same Replica Set. func (c *ReplicaSetStateCreator) FromAddrs(addrs []string, replicaSetName string) (*ReplicaSetState, error) { var r *ReplicaSetState for _, addr := range addrs { ar, err := NewReplicaSetState(addr) if err != nil { c.Log.Errorf("ignoring failure against address %s: %s", addr, err) continue } if replicaSetName != "" { if ar.lastRS == nil { c.Log.Errorf( "ignoring standalone node %q not in expected replset: %q", addr, replicaSetName, ) continue } if ar.lastRS.Name != replicaSetName { c.Log.Errorf( "ignoring node %q not in expected replset: %q vs %q", addr, ar.lastRS.Name, replicaSetName, ) continue } } // First successful address. if r == nil { r = ar continue } // Ensure same as already established ReplicaSetState. if err := r.AssertEqual(ar); err != nil { return nil, err } } if r == nil { return nil, fmt.Errorf("could not connect to any provided addresses: %v", addrs) } // Check if we're expecting an RS but got a single node. 
if r.singleAddr != "" && len(addrs) != 1 { return nil, fmt.Errorf( "node %s is not in a replica set but was expecting it to be in a"+ " replica set with members %v", r.singleAddr, addrs, ) } return r, nil } var ( replSetGetStatusQuery = bson.D{ bson.DocElem{Name: "replSetGetStatus", Value: 1}, } isMasterQuery = bson.D{ bson.DocElem{Name: "isMaster", Value: 1}, } ) func replSetGetStatus(s *mgo.Session) (*replSetGetStatusResponse, error) { var res replSetGetStatusResponse if err := s.Run(replSetGetStatusQuery, &res); err != nil { return nil, err } return &res, nil } func isMaster(s *mgo.Session) (*isMasterResponse, error) { var res isMasterResponse if err := s.Run(isMasterQuery, &res); err != nil { return nil, fmt.Errorf("error in isMaster: %s", err) } return &res, nil } func sameRSMembers(a *replSetGetStatusResponse, b *replSetGetStatusResponse) bool { if (a == nil || len(a.Members) == 0) && (b == nil || len(b.Members) == 0) { return true } if a == nil || b == nil { return false } l := len(a.Members) if l != len(b.Members) { return false } aMembers := make([]string, 0, l) bMembers := make([]string, 0, l) for i := 0; i < l; i++ { aM := a.Members[i] aMembers = append(aMembers, fmt.Sprintf("%s:%s", aM.Name, aM.State)) bM := b.Members[i] bMembers = append(bMembers, fmt.Sprintf("%s:%s", bM.Name, bM.State)) } sort.Strings(aMembers) sort.Strings(bMembers) for i := 0; i < l; i++ { if aMembers[i] != bMembers[i] { return false } } return true } var emptyIsMasterResponse = isMasterResponse{} func sameIMMembers(a *isMasterResponse, b *isMasterResponse) bool { if a == nil && b == nil { return true } if a == nil { a = &emptyIsMasterResponse } if b == nil { b = &emptyIsMasterResponse } l := len(a.Hosts) if l != len(b.Hosts) { return false } aHosts := make([]string, 0, l+1) bHosts := make([]string, 0, l+1) for i := 0; i < l; i++ { aHosts = append(aHosts, a.Hosts[i]) bHosts = append(bHosts, b.Hosts[i]) } sort.Strings(aHosts) sort.Strings(bHosts) aHosts = append(aHosts, a.Primary) 
bHosts = append(bHosts, b.Primary) for i := range aHosts { if aHosts[i] != bHosts[i] { return false } } return true } ================================================ FILE: rs_state_test.go ================================================ package dvara import ( "testing" "github.com/facebookgo/mgotest" ) func TestSameRSMembers(t *testing.T) { t.Parallel() cases := []struct { Name string A *replSetGetStatusResponse B *replSetGetStatusResponse }{ { Name: "the same", A: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "a", State: "b"}, }, }, B: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "a", State: "b"}, }, }, }, { Name: "out of order", A: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "a", State: "b"}, {Name: "c", State: "d"}, }, }, B: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "c", State: "d"}, {Name: "a", State: "b"}, }, }, }, { Name: "both nil", }, { Name: "A nil B empty", B: &replSetGetStatusResponse{}, }, { Name: "A empty B nil", A: &replSetGetStatusResponse{}, }, } for _, c := range cases { if !sameRSMembers(c.A, c.B) { t.Fatalf("failed %s", c.Name) } } } func TestNotSameRSMembers(t *testing.T) { t.Parallel() cases := []struct { Name string A *replSetGetStatusResponse B *replSetGetStatusResponse }{ { Name: "different name", A: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "a", State: "b"}, }, }, B: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "b", State: "b"}, }, }, }, { Name: "different state", A: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "a", State: "b"}, }, }, B: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "a", State: "c"}, }, }, }, { Name: "subset A", A: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "a", State: "b"}, {Name: "b", State: "c"}, }, }, B: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "a", State: "b"}, }, }, }, { Name: "subset B", A: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "a", State: "b"}, 
}, }, B: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "a", State: "b"}, {Name: "b", State: "c"}, }, }, }, { Name: "nil A", B: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "b", State: "b"}, }, }, }, { Name: "nil B", A: &replSetGetStatusResponse{ Members: []statusMember{ {Name: "a", State: "b"}, }, }, }, } for _, c := range cases { if sameRSMembers(c.A, c.B) { t.Fatalf("failed %s", c.Name) } } } func TestSameIMMembers(t *testing.T) { t.Parallel() cases := []struct { Name string A *isMasterResponse B *isMasterResponse }{ { Name: "the same", A: &isMasterResponse{ Hosts: []string{"a", "b"}, }, B: &isMasterResponse{ Hosts: []string{"a", "b"}, }, }, { Name: "out of order", A: &isMasterResponse{ Hosts: []string{"a", "b"}, }, B: &isMasterResponse{ Hosts: []string{"b", "a"}, }, }, { Name: "both nil", }, { Name: "A nil B empty", B: &isMasterResponse{}, }, { Name: "A empty B nil", A: &isMasterResponse{}, }, } for _, c := range cases { if !sameIMMembers(c.A, c.B) { t.Fatalf("failed %s", c.Name) } } } func TestNotSameIMMembers(t *testing.T) { t.Parallel() cases := []struct { Name string A *isMasterResponse B *isMasterResponse }{ { Name: "different name", A: &isMasterResponse{ Hosts: []string{"a"}, }, B: &isMasterResponse{ Hosts: []string{"b"}, }, }, { Name: "subset A", A: &isMasterResponse{ Hosts: []string{"a", "b"}, }, B: &isMasterResponse{ Hosts: []string{"a"}, }, }, { Name: "subset B", A: &isMasterResponse{ Hosts: []string{"a"}, }, B: &isMasterResponse{ Hosts: []string{"a", "b"}, }, }, { Name: "nil A", B: &isMasterResponse{ Hosts: []string{"a"}, }, }, { Name: "nil B", A: &isMasterResponse{ Hosts: []string{"b"}, }, }, } for _, c := range cases { if sameIMMembers(c.A, c.B) { t.Fatalf("failed %s", c.Name) } } } func TestSingleNodeNewReplicaSetState(t *testing.T) { t.Parallel() mgo := mgotest.NewStartedServer(t) defer mgo.Stop() rs, err := NewReplicaSetState(mgo.URL()) if err != nil { t.Fatal(err) } if rs.singleAddr != mgo.URL() { t.Fatalf("expected %s 
got %s", mgo.URL(), rs.singleAddr) } } func TestNewReplicaSetStateFailure(t *testing.T) { t.Parallel() mgo := mgotest.NewStartedServer(t) mgo.Stop() _, err := NewReplicaSetState(mgo.URL()) const expected = "no reachable servers" if err == nil || err.Error() != expected { t.Fatalf("unexpected error: %s", err) } } func TestSingleNodeNewReplicaSetStateAddrs(t *testing.T) { t.Parallel() mgo := mgotest.NewStartedServer(t) defer mgo.Stop() rs, err := NewReplicaSetState(mgo.URL()) if err != nil { t.Fatal(err) } addrs := rs.Addrs() if len(addrs) != 1 || addrs[0] != mgo.URL() { t.Fatalf("unexpected addrs %v", addrs) } } func TestIgnoreMismatchingReplicaSets(t *testing.T) { if disableSlowTests { t.Skip("disabled because it's slow") } t.Parallel() creator := ReplicaSetStateCreator{ Log: &tLogger{TB: t}, } replicaSet := mgotest.NewReplicaSet(2, t) singleMongo := mgotest.NewStartedServer(t) defer func() { replicaSet.Stop() singleMongo.Stop() }() urls := replicaSet.Addrs() urls = append(urls, singleMongo.URL()) state, err := creator.FromAddrs(urls, "rs") if err != nil { t.Fatalf("unexpected error: %s", err) } if state.lastRS.Name != "rs" { t.Fatalf("unexpected replicaset: %s", state.lastRS.Name) } _, err = creator.FromAddrs(urls, "") if err == nil { t.Fatalf("missing expected error: %s", err) } } ================================================ FILE: state.go ================================================ package dvara // ReplicaState is the state of a node in the replica. type ReplicaState string const ( // ReplicaStatePrimary indicates the node is a primary. ReplicaStatePrimary = ReplicaState("PRIMARY") // ReplicaStateSecondary indicates the node is a secondary. ReplicaStateSecondary = ReplicaState("SECONDARY") // ReplicaStateArbiter indicates the node is an arbiter. ReplicaStateArbiter = ReplicaState("ARBITER") )