Repository: lithdew/reliable Branch: master Commit: 0afae4e9ed8a Files: 19 Total size: 57.8 KB Directory structure: gitextract_f9yxkokx/ ├── .gitignore ├── LICENSE ├── README.md ├── conn.go ├── conn_test.go ├── endpoint.go ├── endpoint_test.go ├── error.go ├── examples/ │ ├── basic/ │ │ └── main.go │ └── benchmark/ │ └── main.go ├── fuzz.go ├── go.mod ├── go.sum ├── options.go ├── packet.go ├── packet_test.go ├── pool.go ├── protocol.go └── protocol_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .idea/ corpus/ crashers/ suppressions/ reliable-fuzz.zip ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2020 Kenta Iwasaki Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================ # reliable [![MIT License](https://img.shields.io/apm/l/atomic-design-ui.svg?)](LICENSE) [![go.dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white&style=flat-square)](https://pkg.go.dev/github.com/lithdew/reliable) [![Discord Chat](https://img.shields.io/discord/697002823123992617)](https://discord.gg/HZEbkeQ) **reliable** is a reliability layer for UDP connections in Go. With at most 9 bytes of packet overhead, what **reliable** does for your UDP-based application is: 1. handle acknowledging the receipt of packets you sent, 2. handle sending acknowledgements when too many are being buffered up, 3. handle resending sent packets whose receipt hasn't been acknowledged after some timeout, and 4. handle stopping/buffering up packets to be sent when the recipient's read buffer is suspected to be full. **This project is still a WIP!** Scrub through the source code, write some unit tests, help out with documentation, or open up a GitHub issue if you would like to help out or have any questions! ## Protocol ### Packet Header **reliable** uses the same packet header layout described in [`networkprotocol/reliable.io`](https://github.com/networkprotocol/reliable.io). All packets start with a single byte (8 bits) representing 8 different flags. Packets are sequential, and are numbered using an unsigned 16-bit integer included in the packet header unless the packet is marked to be unreliable. Packet acknowledgements (ACKs) are redundantly included in every sent packet using at most 6 bytes: up to two bytes representing an unsigned 16-bit packet sequence number (ack), and up to four bytes representing a 32-bit bitfield (ackBits). The packet header layout, much like [`networkprotocol/reliable.io`](https://github.com/networkprotocol/reliable.io), is delta-encoded and RLE-encoded to reduce the size overhead per packet.
### Packet ACKs Given a packet we have just received from our peer, for each set bit (i) in the bitfield (ackBits), we mark a packet we have sent as acknowledged if its sequence number is (ack - i). In the case of peer A sending packets to B, with B not sending any packets at all to A, B will send an empty packet for every 32 packets received from A so that A will be aware that B has acknowledged its packets. More explicitly, a counter (lui) is maintained representing the last consecutive packet sequence number that we have received and whose acknowledgement we have told our peer about. For example, if (lui) is 0, and we have sent acknowledgements for packets whose sequence numbers are 2, 3, 4, and 6, and we have then acknowledged packet sequence number 1, then (lui) would be 4. Upon updating (lui), if the next 32 consecutive sequence numbers are sequence numbers of packets we have previously received, we will increment (lui) by 32 and send a single empty packet containing the following packet acknowledgements: (ack=lui+31, ackBits=[lui,lui+31]). ### Packet Buffering Two fixed-size sequence buffers are maintained for packets that we have sent (wq), and packets that we have received (rq). The fixed size of these buffers must evenly divide 65536, the number of distinct values an unsigned 16-bit integer can represent. The data structure is described in [this blog post by Glenn Fiedler](https://gafferongames.com/post/reliable_ordered_messages/). We keep track of a counter (oui), representing the last consecutive sequence number of a packet we have sent that was acknowledged by our peer. For example, if we have sent packets whose sequence numbers are in the range [0, 256], and we have received acknowledgements for packets (0, 1, 2, 3, 4, 8, 9, 10, 11, 12), then (oui) would be 4. Let cap(q) be the fixed size or capacity of sequence buffer q.
While sending packets, we intermittently stop and buffer the sending of packets if we believe sending more packets would overflow the read buffer of our recipient. More explicitly, if the next packet we send would be assigned a packet number greater than (oui + cap(rq)), we stop all sends until (oui) has been incremented through the receipt of a packet from our peer. ### Retransmitting Lost Packets The logic for retransmitting stale, unacknowledged sent packets and maintaining acknowledgements was taken with credit to [this blog post by Glenn Fiedler](https://gafferongames.com/post/reliable_ordered_messages/). Packets are suspected to be lost if they are not acknowledged by their recipient after 100ms. Once a packet is suspected to be lost, it is resent. As of right now, packets are resent a maximum of 10 times. It might be wiser not to cap the number of times a packet may be resent, and to leave that decision up to the developer. However, that is open for discussion, which I am happy to have over on my Discord server or through a GitHub issue. ## Rationale On my quest to find a feasible solution to TCP head-of-line blocking, I looked through a _lot_ of reliable UDP libraries in Go, with the majority primarily suited for either file transfer or gaming: 1. A direct port of the reference C implementation of [networkprotocol/reliable.io](https://github.com/networkprotocol/reliable.io): [jakecoffman/rely](https://github.com/jakecoffman/rely) 2. A realtime multiplayer network gaming protocol: [obsilp/rmnp](https://github.com/obsilp/rmnp) 3. A reliable, production-grade ARQ protocol: [xtaci/kcp-go](https://github.com/xtaci/kcp-go) 4. An encrypted session-based streaming protocol: [ooclab/es](https://github.com/ooclab/es/tree/master/proto/udp) 5. A game networking protocol: [arl/udpnet](https://github.com/arl/udpnet) 6. A direct port of uTP (Micro Transport Protocol): [warjiang/utp](https://github.com/warjiang/utp/tree/master/utp) 7.
A protocol for sending arbitrarily large, chunked amounts of data: [go-guoyk/sptp](https://github.com/go-guoyk/sptp) 8. A small-scale fast transmission protocol: [spance/suft](https://github.com/spance/suft/) 9. A Go implementation of QUIC: [lucas-clemente/quic-go](https://github.com/lucas-clemente/quic-go) Going through all of them, I felt that they did just a little too much for me. For my work and side projects, I have been working heavily on decentralized p2p networking protocols. The nature of these protocols is that they suffer heavily from TCP head-of-line blocking when operating in high-latency/high-packet-loss environments. In many cases, a lot of the features provided by these libraries were either not needed, or honestly felt like they would best be handled and thought through by the developer using these libraries. For example: 1. handshaking/session management 2. packet fragmentation/reassembly 3. packet encryption/decryption So, I began working on a modular approach and decided to extract the reliability portion of the protocols I have built into a separate library. I feel that this approach is better than popular alternatives like QUIC or SCTP that may, depending on your circumstances, do just a bit too much for you. After all, getting _just_ the reliability bits of a UDP-based protocol correct and well-tested is hard enough. ## Todo 1. Estimate the round-trip time (RTT) and adjust the system's packet re-transmission delay based on it. 2. Encapsulate away protocol logic and `net.PacketConn`-related bits for a finer abstraction. 3. Keep a cache of the string representations of passed-in `net.UDPAddr`. 4. Reduce locking in as many code hot paths as possible. 5. Networking statistics (packet loss, RTT, etc.). 6. More unit tests. ## Usage **reliable** uses Go modules.
To include it in your project, run the following command: ``` $ go get github.com/lithdew/reliable ``` Should you just be looking to quickly get a project or demo up and running, use `Endpoint`. If you require more flexibility, consider directly working with `Conn`. Note that some sort of keep-alive mechanism or heartbeat system needs to be bootstrapped on top; otherwise, packets may be resent indefinitely, since they will fail to be acknowledged. ## Options 1. The read buffer size may be configured using `WithReadBufferSize`. The default read buffer size is 256. 2. The write buffer size may be configured using `WithWriteBufferSize`. The default write buffer size is 256. 3. The minimum period of time before we retransmit a packet that has yet to be acknowledged may be configured using `WithResendTimeout`. The default resend timeout is 100 milliseconds. 4. A packet handler, which is called back when a packet is received, may be configured using `WithEndpointPacketHandler` or `WithProtocolPacketHandler`. By default, a nil handler is provided which ignores all incoming packets. 5. An error handler, which is called when errors occur on a connection, may be configured using `WithEndpointErrorHandler` or `WithProtocolErrorHandler`. By default, a nil handler is provided which ignores all errors. 6. A byte buffer pool may be passed in using `WithBufferPool`. By default, a new byte buffer pool is instantiated. ## Benchmarks A benchmark was done using [`examples/benchmark`](examples/benchmark) from Japan to a DigitalOcean 2GB / 60 GB Disk / NYC3 server. The benchmark task was to spam 1400-byte packets from Japan to New York. Given a ping latency of roughly 220 milliseconds, the throughput was roughly 1.2 MiB/sec. Unit test benchmarks have also been performed, as shown below. ``` $ cat /proc/cpuinfo | grep 'model name' | uniq model name : Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz $ go test -bench=.
-benchtime=10s goos: linux goarch: amd64 pkg: github.com/lithdew/reliable BenchmarkEndpointWriteReliablePacket-8 2053717 5941 ns/op 183 B/op 9 allocs/op BenchmarkEndpointWriteUnreliablePacket-8 2472392 4866 ns/op 176 B/op 8 allocs/op BenchmarkMarshalPacketHeader-8 749060137 15.7 ns/op 0 B/op 0 allocs/op BenchmarkUnmarshalPacketHeader-8 835547473 14.6 ns/op 0 B/op 0 allocs/op ``` ## Example You may run the example below by executing the following command: ``` $ go run github.com/lithdew/reliable/examples/basic ``` This example demonstrates: 1. how to quickly construct two UDP endpoints listening on ports 44444 and 55555, and 2. how to have the UDP endpoint at port 44444 spam 1400-byte packets to the UDP endpoint at port 55555 as fast as possible. ```go package main import ( "bytes" "errors" "github.com/davecgh/go-spew/spew" "github.com/lithdew/reliable" "io" "log" "net" "os" "os/signal" "sync" "sync/atomic" "time" ) var ( PacketData = bytes.Repeat([]byte("x"), 1400) NumPackets = uint64(0) ) func check(err error) { if err != nil && !errors.Is(err, io.EOF) { log.Panic(err) } } func listen(addr string) net.PacketConn { conn, err := net.ListenPacket("udp", addr) check(err) return conn } func handler(buf []byte, _ net.Addr) { if bytes.Equal(buf, PacketData) || len(buf) == 0 { return } spew.Dump(buf) os.Exit(1) } func main() { exit := make(chan struct{}) var wg sync.WaitGroup wg.Add(2) ca := listen("127.0.0.1:44444") cb := listen("127.0.0.1:55555") a := reliable.NewEndpoint(ca, reliable.WithEndpointPacketHandler(handler)) b := reliable.NewEndpoint(cb, reliable.WithEndpointPacketHandler(handler)) defer func() { check(ca.SetDeadline(time.Now().Add(1 * time.Millisecond))) check(cb.SetDeadline(time.Now().Add(1 * time.Millisecond))) close(exit) check(a.Close()) check(b.Close()) check(ca.Close()) check(cb.Close()) wg.Wait() }() go a.Listen() go b.Listen() // The two goroutines below have endpoint A spam endpoint B, and print out how // many packets of data are being sent per 
second. go func() { defer wg.Done() for { select { case <-exit: return default: } check(a.WriteReliablePacket(PacketData, b.Addr())) atomic.AddUint64(&NumPackets, 1) } }() go func() { defer wg.Done() ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() for { select { case <-exit: return case <-ticker.C: numPackets := atomic.SwapUint64(&NumPackets, 0) numBytes := float64(numPackets) * 1400.0 / 1024.0 / 1024.0 log.Printf( "Sent %d packet(s) comprised of %.2f MiB worth of data.", numPackets, numBytes, ) } } }() ch := make(chan os.Signal, 1) signal.Notify(ch, os.Interrupt) <-ch } ``` ================================================ FILE: conn.go ================================================ package reliable import ( "fmt" "io" "net" ) type transmitFunc func(buf []byte) (bool, error) type Conn struct { addr net.Addr conn net.PacketConn protocol *Protocol } func NewConn(addr net.Addr, conn net.PacketConn, opts ...ProtocolOption) *Conn { p := NewProtocol(opts...) return &Conn{addr: addr, conn: conn, protocol: p} } func (c *Conn) WriteReliablePacket(buf []byte) error { buf, err := c.protocol.WritePacket(true, buf) if err != nil { return err } _, err = c.transmit(buf) return err } func (c *Conn) WriteUnreliablePacket(buf []byte) error { buf, err := c.protocol.WritePacket(false, buf) if err != nil { return err } _, err = c.transmit(buf) return err } func (c *Conn) Read(header PacketHeader, buf []byte) error { buf = c.protocol.ReadPacket(header, buf) if len(buf) != 0 { _, err := c.transmit(buf) return err } return nil } func (c *Conn) Close() { c.protocol.Close() } func (c *Conn) Run() { c.protocol.Run(c.transmit) } func (c *Conn) transmit(buf []byte) (EOF bool, err error) { n, err := c.conn.WriteTo(buf, c.addr) if err == nil && n != len(buf) { err = io.ErrShortWrite } EOF = isEOF(err) if err != nil && !EOF { err = fmt.Errorf("failed to transmit packet: %w", err) return } return } ================================================ FILE: conn_test.go 
================================================ package reliable import ( "bytes" "github.com/stretchr/testify/require" "go.uber.org/goleak" "math" "net" "sync/atomic" "testing" "time" ) func TestConnWriteReliablePacket(t *testing.T) { defer goleak.VerifyNone(t) data := bytes.Repeat([]byte("x"), 1400) actual := uint64(0) expected := uint64(65536) a, _ := net.ListenPacket("udp", "127.0.0.1:0") b, _ := net.ListenPacket("udp", "127.0.0.1:0") handler := func(buf []byte, _ uint16) { atomic.AddUint64(&actual, 1) require.EqualValues(t, data, buf) } ca := NewConn(a.LocalAddr(), a, WithProtocolPacketHandler(handler)) cb := NewConn(b.LocalAddr(), b, WithProtocolPacketHandler(handler)) go readLoop(t, a, ca) go readLoop(t, b, cb) defer func() { require.NoError(t, a.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, b.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, a.Close()) require.NoError(t, b.Close()) ca.Close() cb.Close() require.EqualValues(t, expected, atomic.LoadUint64(&actual)) }() for i := uint64(0); i < expected; i++ { require.NoError(t, ca.WriteReliablePacket(data)) } } func readLoop(t *testing.T, pc net.PacketConn, c *Conn) { var ( n int err error ) buf := make([]byte, math.MaxUint16+1) for { n, _, err = pc.ReadFrom(buf) if err != nil { break } header, buf, err := UnmarshalPacketHeader(buf[:n]) require.NoError(t, err) if err == nil { err = c.Read(header, buf) require.NoError(t, err) } } } ================================================ FILE: endpoint.go ================================================ package reliable import ( "io" "math" "net" "sync" "sync/atomic" "time" ) type EndpointPacketHandler func(buf []byte, addr net.Addr) type EndpointErrorHandler func(err error, addr net.Addr) type Endpoint struct { writeBufferSize uint16 // write buffer size that must be a divisor of 65536 readBufferSize uint16 // read buffer size that must be a divisor of 65536 updatePeriod time.Duration // how often time-dependant parts of the 
protocol get checked resendTimeout time.Duration // how long we wait until unacked packets should be resent mu sync.Mutex wg sync.WaitGroup pool *Pool ph EndpointPacketHandler eh EndpointErrorHandler addr net.Addr conn net.PacketConn conns map[string]*Conn closing uint32 } func NewEndpoint(conn net.PacketConn, opts ...EndpointOption) *Endpoint { e := &Endpoint{conn: conn, addr: conn.LocalAddr(), conns: make(map[string]*Conn)} for _, opt := range opts { opt.applyEndpoint(e) } if e.writeBufferSize == 0 { e.writeBufferSize = DefaultWriteBufferSize } if e.readBufferSize == 0 { e.readBufferSize = DefaultReadBufferSize } if e.resendTimeout == 0 { e.resendTimeout = DefaultResendTimeout } if e.updatePeriod == 0 { e.updatePeriod = DefaultUpdatePeriod } if e.pool == nil { e.pool = new(Pool) } return e } func (e *Endpoint) getConn(addr net.Addr) *Conn { id := addr.String() e.mu.Lock() defer e.mu.Unlock() conn := e.conns[id] if conn == nil { if atomic.LoadUint32(&e.closing) == 1 { return nil } conn = NewConn( addr, e.conn, WithWriteBufferSize(e.writeBufferSize), WithReadBufferSize(e.readBufferSize), WithUpdatePeriod(e.updatePeriod), WithResendTimeout(e.resendTimeout), WithBufferPool(e.pool), ) e.wg.Add(1) go func() { defer e.wg.Done() conn.Run() }() e.conns[id] = conn } return conn } func (e *Endpoint) clearConn(addr net.Addr) { id := addr.String() e.mu.Lock() conn := e.conns[id] delete(e.conns, id) e.mu.Unlock() conn.Close() } func (e *Endpoint) clearConns() { e.mu.Lock() conns := make([]*Conn, 0, len(e.conns)) for id, conn := range e.conns { conns = append(conns, conn) delete(e.conns, id) } e.mu.Unlock() for _, conn := range conns { conn.Close() } } func (e *Endpoint) Addr() net.Addr { return e.addr } func (e *Endpoint) WriteReliablePacket(buf []byte, addr net.Addr) error { conn := e.getConn(addr) if conn == nil { return io.EOF } return conn.WriteReliablePacket(buf) } func (e *Endpoint) WriteUnreliablePacket(buf []byte, addr net.Addr) error { conn := e.getConn(addr) if conn 
== nil { return io.EOF } return conn.WriteUnreliablePacket(buf) } func (e *Endpoint) Listen() { e.mu.Lock() e.wg.Add(1) e.mu.Unlock() defer e.wg.Done() var ( n int addr net.Addr err error ) buf := make([]byte, math.MaxUint16+1) for { n, addr, err = e.conn.ReadFrom(buf) if err != nil { break } conn := e.getConn(addr) if conn == nil { break } header, buf, err := UnmarshalPacketHeader(buf[:n]) if err != nil { e.clearConn(addr) if e.eh != nil { e.eh(err, e.addr) } continue } err = conn.Read(header, buf) if err != nil { e.clearConn(addr) if e.eh != nil { e.eh(err, e.addr) } continue } if e.ph != nil { e.ph(buf, e.addr) } } e.clearConns() } func (e *Endpoint) Close() error { atomic.StoreUint32(&e.closing, 1) e.wg.Wait() return nil } ================================================ FILE: endpoint_test.go ================================================ package reliable import ( "bytes" "github.com/stretchr/testify/require" "go.uber.org/goleak" "net" "sort" "strconv" "sync" "sync/atomic" "testing" "time" ) func newPacketConn(t testing.TB, addr string) net.PacketConn { t.Helper() conn, err := net.ListenPacket("udp", addr) require.NoError(t, err) return conn } func BenchmarkEndpointWriteReliablePacket(b *testing.B) { ca := newPacketConn(b, "127.0.0.1:0") cb := newPacketConn(b, "127.0.0.1:0") ea := NewEndpoint(ca) eb := NewEndpoint(cb) go ea.Listen() go eb.Listen() defer func() { require.NoError(b, ca.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(b, cb.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(b, ea.Close()) require.NoError(b, eb.Close()) require.NoError(b, ca.Close()) require.NoError(b, cb.Close()) }() data := bytes.Repeat([]byte("x"), 1400) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { if err := ea.WriteReliablePacket(data, eb.Addr()); err != nil && !isEOF(err) { b.Fatal(err) } } } func BenchmarkEndpointWriteUnreliablePacket(b *testing.B) { ca := newPacketConn(b, "127.0.0.1:0") cb := newPacketConn(b, "127.0.0.1:0") ea 
:= NewEndpoint(ca) eb := NewEndpoint(cb) go ea.Listen() go eb.Listen() defer func() { require.NoError(b, ca.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(b, cb.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(b, ea.Close()) require.NoError(b, eb.Close()) require.NoError(b, ca.Close()) require.NoError(b, cb.Close()) }() data := bytes.Repeat([]byte("x"), 1400) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { if err := ea.WriteUnreliablePacket(data, eb.Addr()); err != nil && !isEOF(err) { b.Fatal(err) } } } func TestEndpointWriteReliablePacket(t *testing.T) { defer goleak.VerifyNone(t) var mu sync.Mutex values := make(map[string]struct{}) actual := uint64(0) expected := uint64(65536) handler := func(buf []byte, _ net.Addr) { if len(buf) == 0 { return } atomic.AddUint64(&actual, 1) mu.Lock() _, exists := values[string(buf)] delete(values, string(buf)) mu.Unlock() require.True(t, exists) } ca := newPacketConn(t, "127.0.0.1:0") cb := newPacketConn(t, "127.0.0.1:0") a := NewEndpoint(ca, WithEndpointPacketHandler(handler)) b := NewEndpoint(cb, WithEndpointPacketHandler(handler)) go a.Listen() go b.Listen() defer func() { require.NoError(t, ca.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, cb.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, a.Close()) require.NoError(t, b.Close()) require.NoError(t, ca.Close()) require.NoError(t, cb.Close()) require.EqualValues(t, expected, atomic.LoadUint64(&actual)) }() for i := uint64(0); i < expected; i++ { data := strconv.AppendUint(nil, i, 10) mu.Lock() values[string(data)] = struct{}{} mu.Unlock() require.NoError(t, a.WriteReliablePacket(data, b.Addr())) } } func TestEndpointWriteReliablePacketEndToEnd(t *testing.T) { defer goleak.VerifyNone(t) actual := uint64(0) expected := uint64(512) handler := func(buf []byte, _ net.Addr) { if len(buf) == 0 { return } atomic.AddUint64(&actual, 1) } ca := newPacketConn(t, "127.0.0.1:0") cb := newPacketConn(t, 
"127.0.0.1:0") a := NewEndpoint(ca, WithEndpointPacketHandler(handler)) b := NewEndpoint(cb, WithEndpointPacketHandler(handler)) go a.Listen() go b.Listen() defer func() { require.NoError(t, ca.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, cb.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, a.Close()) require.NoError(t, b.Close()) require.NoError(t, ca.Close()) require.NoError(t, cb.Close()) require.EqualValues(t, expected*2, atomic.LoadUint64(&actual)) }() for i := uint64(0); i < expected; i++ { data := strconv.AppendUint(nil, i, 10) require.NoError(t, a.WriteReliablePacket(data, b.Addr())) require.NoError(t, b.WriteReliablePacket(data, a.Addr())) } } // Check whether race condition happen // Simulate write and read heavy condition by sending packet concurrently func TestRaceConditions(t *testing.T) { defer goleak.VerifyNone(t) var expected int = 1000 tr := newTestRaceConditions(expected) handler := func(buf []byte, _ net.Addr) { if len(buf) == 0 { return } tr.append(buf) } ca := newPacketConn(t, "127.0.0.1:0") cb := newPacketConn(t, "127.0.0.1:0") cc := newPacketConn(t, "127.0.0.1:0") cd := newPacketConn(t, "127.0.0.1:0") ce := newPacketConn(t, "127.0.0.1:0") a := NewEndpoint(ca, WithEndpointPacketHandler(handler)) b := NewEndpoint(cb, WithEndpointPacketHandler(handler)) c := NewEndpoint(cc, WithEndpointPacketHandler(handler)) d := NewEndpoint(cd, WithEndpointPacketHandler(handler)) e := NewEndpoint(ce, WithEndpointPacketHandler(handler)) go a.Listen() go b.Listen() go c.Listen() go d.Listen() go e.Listen() defer func() { tr.wait() // Note: Guarantee that all messages are deliverd time.Sleep(100 * time.Millisecond) require.NoError(t, ca.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, cb.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, cc.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, cd.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, 
ce.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, a.Close()) require.NoError(t, b.Close()) require.NoError(t, c.Close()) require.NoError(t, d.Close()) require.NoError(t, e.Close()) require.NoError(t, ca.Close()) require.NoError(t, cb.Close()) require.NoError(t, cc.Close()) require.NoError(t, cd.Close()) require.NoError(t, ce.Close()) require.EqualValues(t, tr.expected, uniqSort(tr.actual)) }() tr.wg.Add(1) sB := tr.expected[0 : len(tr.expected)/4] go func() { defer tr.done() for i := 0; i < len(sB); i++ { data := []byte(strconv.Itoa(sB[i])) err := a.WriteReliablePacket(data, b.Addr()) if err != nil { require.True(t, isEOF(err)) } } }() tr.wg.Add(1) sC := tr.expected[len(tr.expected)/4 : len(tr.expected)*2/4] go func() { defer tr.done() for i := 0; i < len(sC); i++ { data := []byte(strconv.Itoa(sC[i])) err := a.WriteReliablePacket(data, c.Addr()) if err != nil { require.True(t, isEOF(err)) } } }() tr.wg.Add(1) sD := tr.expected[len(tr.expected)*2/4 : len(tr.expected)*3/4] go func() { defer tr.done() for i := 0; i < len(sD); i++ { data := []byte(strconv.Itoa(sD[i])) err := a.WriteReliablePacket(data, d.Addr()) if err != nil { require.True(t, isEOF(err)) } } }() tr.wg.Add(1) sE := tr.expected[len(tr.expected)*3/4:] go func() { defer tr.done() for i := 0; i < len(sE); i++ { data := []byte(strconv.Itoa(sE[i])) err := a.WriteReliablePacket(data, e.Addr()) if err != nil { require.True(t, isEOF(err)) } } }() } // Note: This struct is test for TestRaceConditions // The purpose for this struct is to prevent race condition of WaitGroup type testRaceConditions struct { mu sync.Mutex wg sync.WaitGroup expected []int actual []int } func newTestRaceConditions(cap int) *testRaceConditions { return &testRaceConditions{expected: genNumSlice(cap)} } func (t *testRaceConditions) done() { t.mu.Lock() defer t.mu.Unlock() t.wg.Done() } func (t *testRaceConditions) wait() { t.wg.Wait() } func (t *testRaceConditions) append(buf []byte) { t.mu.Lock() defer t.mu.Unlock() 
num, _ := strconv.Atoi(string(buf)) t.actual = append(t.actual, num) } func genNumSlice(len int) (s []int) { for i := 0; i < len; i++ { s = append(s, i) } return } func uniqSort(s []int) (result []int) { sort.Ints(s) var pre int for i := 0; i < len(s); i++ { if i == 0 || s[i] != pre { result = append(result, s[i]) } pre = s[i] } return } ================================================ FILE: error.go ================================================ package reliable import ( "errors" "io" "net" ) func isEOF(err error) bool { if errors.Is(err, io.EOF) { return true } var netErr *net.OpError if errors.As(err, &netErr) { if netErr.Err.Error() == "use of closed network connection" { return true } if netErr.Timeout() { return true } } return false } ================================================ FILE: examples/basic/main.go ================================================ package main import ( "bytes" "errors" "github.com/davecgh/go-spew/spew" "github.com/lithdew/reliable" "io" "log" "net" "os" "os/signal" "sync" "sync/atomic" "time" ) var ( PacketData = bytes.Repeat([]byte("x"), 1400) NumPackets = uint64(0) ) func check(err error) { if err != nil && !errors.Is(err, io.EOF) { log.Panic(err) } } func listen(addr string) net.PacketConn { conn, err := net.ListenPacket("udp", addr) check(err) return conn } func handler(buf []byte, _ net.Addr) { if bytes.Equal(buf, PacketData) || len(buf) == 0 { return } spew.Dump(buf) os.Exit(1) } func main() { exit := make(chan struct{}) var wg sync.WaitGroup wg.Add(2) ca := listen("127.0.0.1:44444") cb := listen("127.0.0.1:55555") a := reliable.NewEndpoint(ca, reliable.WithEndpointPacketHandler(handler)) b := reliable.NewEndpoint(cb, reliable.WithEndpointPacketHandler(handler)) defer func() { check(ca.SetDeadline(time.Now().Add(1 * time.Millisecond))) check(cb.SetDeadline(time.Now().Add(1 * time.Millisecond))) close(exit) check(a.Close()) check(b.Close()) check(ca.Close()) check(cb.Close()) wg.Wait() }() go a.Listen() go b.Listen() // 
The two goroutines below have endpoint A spam endpoint B, and print out how // many packets of data are being sent per second. go func() { defer wg.Done() for { select { case <-exit: return default: } check(a.WriteReliablePacket(PacketData, b.Addr())) atomic.AddUint64(&NumPackets, 1) } }() go func() { defer wg.Done() ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() for { select { case <-exit: return case <-ticker.C: numPackets := atomic.SwapUint64(&NumPackets, 0) numBytes := float64(numPackets) * 1400.0 / 1024.0 / 1024.0 log.Printf( "Sent %d packet(s) comprised of %.2f MiB worth of data.", numPackets, numBytes, ) } } }() ch := make(chan os.Signal, 1) signal.Notify(ch, os.Interrupt) <-ch } ================================================ FILE: examples/benchmark/main.go ================================================ package main import ( "bytes" "errors" "flag" "github.com/lithdew/reliable" "io" "log" "net" "sync/atomic" "time" ) var ( listener bool ) func check(err error) { if err != nil && !errors.Is(err, io.EOF) { log.Panic(err) } } func listen(addr string) net.PacketConn { conn, err := net.ListenPacket("udp", addr) check(err) log.Printf("%s: Listening for peers.", conn.LocalAddr()) return conn } func main() { flag.BoolVar(&listener, "l", false, "either listen or dial") flag.Parse() host := flag.Arg(0) if !listener || host == "" { host = ":0" } conn := listen(host) counter := uint64(0) handler := func(buf []byte, addr net.Addr) { if len(buf) == 0 { return } //log.Printf("%s->%s: (seq=%d) (size=%d)", addr.String(), conn.LocalAddr().String(), seq, len(buf)) atomic.AddUint64(&counter, 1) } endpoint := reliable.NewEndpoint(conn, reliable.WithEndpointPacketHandler(handler)) go endpoint.Listen() defer func() { check(endpoint.Close()) check(conn.Close()) }() if listener { for range time.Tick(1 * time.Second) { numPackets := atomic.SwapUint64(&counter, 0) numBytes := float64(numPackets) * 1400.0 / 1024.0 / 1024.0 log.Printf("%s: Received %d packets (%.2f 
MiB).", conn.LocalAddr(), numPackets, numBytes) } } addr, err := net.ResolveUDPAddr("udp", flag.Arg(0)) check(err) data := bytes.Repeat([]byte("x"), 1400) go func() { for range time.Tick(1 * time.Second) { numPackets := atomic.SwapUint64(&counter, 0) numBytes := float64(numPackets) * 1400.0 / 1024.0 / 1024.0 log.Printf("%s: Sent %d packets (%.2f MiB).", conn.LocalAddr(), numPackets, numBytes) } }() for { check(endpoint.WriteReliablePacket(data, addr)) atomic.AddUint64(&counter, 1) } } ================================================ FILE: fuzz.go ================================================ // +build gofuzz package reliable import ( "bytes" "errors" "net" "time" ) func Fuzz(data []byte) int { ca, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { return -1 } cb, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { return -1 } chErr := make(chan error) handler := func(buf []byte, _ net.Addr) { if len(buf) == 0 || bytes.Equal(buf, data) { return } chErr <- errors.New("data miss match") } ea := NewEndpoint(ca, reliable.WithEndpointPacketHandler(handler)) eb := NewEndpoint(cb, reliable.WithEndpointPacketHandler(handler)) go ea.Listen() go eb.Listen() for i := 0; i < 65536; i++ { select { case <-chErr: return 0 default: if err := ea.WriteReliablePacket(data, eb.Addr()); err != nil && !isEOF(err) { return 0 } } } if err := ca.SetDeadline(time.Now().Add(1 * time.Millisecond)); err != nil { return 0 } if err := cb.SetDeadline(time.Now().Add(1 * time.Millisecond)); err != nil { return 0 } if err := ea.Close(); err != nil { return 0 } if err := eb.Close(); err != nil { return 0 } if err := ca.Close(); err != nil { return 0 } if err := cb.Close(); err != nil { return 0 } return 1 } ================================================ FILE: go.mod ================================================ module github.com/lithdew/reliable go 1.14 require ( github.com/davecgh/go-spew v1.1.1 github.com/dvyukov/go-fuzz v0.0.0-20200318091601-be3528f3a813 // indirect 
github.com/lithdew/bytesutil v0.0.0-20200409052507-d98389230a59 github.com/lithdew/seq v0.0.0-20200504083424-74d5d8117a05 github.com/stretchr/testify v1.5.1 github.com/valyala/bytebufferpool v1.0.0 go.uber.org/goleak v1.0.0 golang.org/x/lint v0.0.0-20200302205851-738671d3881b // indirect golang.org/x/tools v0.0.0-20200501005904-d351ea090f9b // indirect ) ================================================ FILE: go.sum ================================================ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dvyukov/go-fuzz v0.0.0-20200318091601-be3528f3a813 h1:NgO45/5mBLRVfiXerEFzH6ikcZ7DNRPS639xFg3ENzU= github.com/dvyukov/go-fuzz v0.0.0-20200318091601-be3528f3a813/go.mod h1:11Gm+ccJnvAhCNLlf5+cS9KjtbaD5I5zaZpFMsTHWTw= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lithdew/bytesutil v0.0.0-20200409052507-d98389230a59 h1:CQpoOQecHxhvgOU/ijue/yWuShZYDtNpI9bsD4Dkzrk= github.com/lithdew/bytesutil v0.0.0-20200409052507-d98389230a59/go.mod h1:89JlULMIJ/+YWzAp5aHXgAD2d02S2mY+a+PMgXDtoNs= github.com/lithdew/seq v0.0.0-20200504083424-74d5d8117a05 h1:j1UtG8NYCupA5xUwQ/vrTf/zjuNlZ0D1n7UtM8LhS58= github.com/lithdew/seq v0.0.0-20200504083424-74d5d8117a05/go.mod h1:4vVgbfmYc+ZIh0dy99HRrM6knnAtQXNI8MOx+1pUYso= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib 
v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.uber.org/goleak v1.0.0 h1:qsup4IcBdlmsnGfqyLl4Ntn3C2XCCuKAE7DwHpScyUo= go.uber.org/goleak v1.0.0/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0 h1:KU7oHjnv3XNWfa5COkzUifxZmxp1TyI7ImMXqFxLwvQ= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net 
v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11 h1:Yq9t9jnGoR+dBuitxdo9l6Q7xh/zOyNnYUtDKaQ3x0E= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200501005904-d351ea090f9b h1:2hSR2MyOaYEy6yJYg/CpErymr/m7xJEJpm9kfT7ZMg4= golang.org/x/tools v0.0.0-20200501005904-d351ea090f9b/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= ================================================ FILE: options.go ================================================ package reliable import "time" const ( DefaultWriteBufferSize uint16 = 256 DefaultReadBufferSize uint16 = 256 DefaultUpdatePeriod = 100 * time.Millisecond DefaultResendTimeout = 100 * time.Millisecond ) type ProtocolOption interface { applyProtocol(p *Protocol) } type EndpointOption interface { applyEndpoint(e *Endpoint) } type Option interface { ProtocolOption EndpointOption } type withBufferPool struct{ pool *Pool } func (o withBufferPool) applyProtocol(p *Protocol) { p.pool = o.pool } func (o withBufferPool) applyEndpoint(e *Endpoint) { e.pool = o.pool } func WithBufferPool(pool *Pool) Option { return withBufferPool{pool: pool} } type withWriteBufferSize struct{ writeBufferSize uint16 } func (o withWriteBufferSize) applyProtocol(p *Protocol) { p.writeBufferSize = o.writeBufferSize } func (o withWriteBufferSize) applyEndpoint(e *Endpoint) { e.writeBufferSize = o.writeBufferSize } func WithWriteBufferSize(writeBufferSize uint16) Option { if 65536%uint32(writeBufferSize) != 0 { panic("write buffer size must be smaller than 65536 and a power of two") } return withWriteBufferSize{writeBufferSize: writeBufferSize} } type withReadBufferSize struct{ readBufferSize uint16 } func (o withReadBufferSize) applyProtocol(p *Protocol) { p.readBufferSize = o.readBufferSize } func (o withReadBufferSize) 
applyEndpoint(e *Endpoint) { e.readBufferSize = o.readBufferSize } func WithReadBufferSize(readBufferSize uint16) Option { if 65536%uint32(readBufferSize) != 0 { panic("read buffer size must be smaller than 65536 and a power of two") } return withReadBufferSize{readBufferSize: readBufferSize} } type withProtocolPacketHandler struct{ ph ProtocolPacketHandler } type withEndpointPacketHandler struct{ ph EndpointPacketHandler } func (o withProtocolPacketHandler) applyProtocol(p *Protocol) { p.ph = o.ph } func (o withEndpointPacketHandler) applyEndpoint(e *Endpoint) { e.ph = o.ph } func WithProtocolPacketHandler(ph ProtocolPacketHandler) ProtocolOption { return withProtocolPacketHandler{ph: ph} } func WithEndpointPacketHandler(ph EndpointPacketHandler) EndpointOption { return withEndpointPacketHandler{ph: ph} } type withProtocolErrorHandler struct{ eh ProtocolErrorHandler } type withEndpointErrorHandler struct{ eh EndpointErrorHandler } func (o withProtocolErrorHandler) applyProtocol(p *Protocol) { p.eh = o.eh } func (o withEndpointErrorHandler) applyEndpoint(e *Endpoint) { e.eh = o.eh } func WithProtocolErrorHandler(eh ProtocolErrorHandler) ProtocolOption { return withProtocolErrorHandler{eh: eh} } func WithEndpointErrorHandler(eh EndpointErrorHandler) EndpointOption { return withEndpointErrorHandler{eh: eh} } type withUpdatePeriod struct{ updatePeriod time.Duration } func (o withUpdatePeriod) applyProtocol(p *Protocol) { p.updatePeriod = o.updatePeriod } func (o withUpdatePeriod) applyEndpoint(e *Endpoint) { e.updatePeriod = o.updatePeriod } func WithUpdatePeriod(updatePeriod time.Duration) Option { if updatePeriod == 0 { panic("update period of zero is not supported yet") } return withUpdatePeriod{updatePeriod: updatePeriod} } type withResendTimeout struct{ resendTimeout time.Duration } func (o withResendTimeout) applyProtocol(p *Protocol) { p.resendTimeout = o.resendTimeout } func (o withResendTimeout) applyEndpoint(e *Endpoint) { e.resendTimeout = o.resendTimeout } 
func WithResendTimeout(resendTimeout time.Duration) Option { if resendTimeout == 0 { panic("ack timeout of zero is not supported yet") } return withResendTimeout{resendTimeout: resendTimeout} } ================================================ FILE: packet.go ================================================ package reliable import ( "github.com/lithdew/bytesutil" "github.com/valyala/bytebufferpool" "io" "math/bits" "time" ) const ACKBitsetSize = 32 type ( Buffer = bytebufferpool.ByteBuffer Pool = bytebufferpool.Pool ) type writtenPacket struct { buf *Buffer // pooled contents of this packet acked bool // whether or not this packet was acked written time.Time // last time the packet was written resent byte // total number of times this packet was resent } func (p writtenPacket) shouldResend(now time.Time, resendTimeout time.Duration) bool { return !p.acked && p.resent < 10 && now.Sub(p.written) >= resendTimeout } type PacketHeaderFlag uint8 const ( FlagFragment PacketHeaderFlag = 1 << iota FlagA FlagB FlagC FlagD FlagACKEncoded FlagEmpty FlagUnordered ) func (p PacketHeaderFlag) Toggle(flag PacketHeaderFlag) PacketHeaderFlag { return p | flag } func (p PacketHeaderFlag) Toggled(flag PacketHeaderFlag) bool { return p&flag != 0 } func (p PacketHeaderFlag) AppendTo(dst []byte) []byte { return append(dst, byte(p)) } type PacketHeader struct { Sequence uint16 ACK uint16 ACKBits uint32 Unordered bool Empty bool } func (p PacketHeader) AppendTo(dst []byte) []byte { // Mark a flag byte to RLE-encode the ACK bitset. 
flag := PacketHeaderFlag(0) if p.ACKBits&0x000000FF != 0x000000FF { flag = flag.Toggle(FlagA) } if p.ACKBits&0x0000FF00 != 0x0000FF00 { flag = flag.Toggle(FlagB) } if p.ACKBits&0x00FF0000 != 0x00FF0000 { flag = flag.Toggle(FlagC) } if p.ACKBits&0xFF000000 != 0xFF000000 { flag = flag.Toggle(FlagD) } if p.Empty { flag = flag.Toggle(FlagEmpty) } if p.Unordered { flag = flag.Toggle(FlagUnordered) } diff := int(p.Sequence) - int(p.ACK) if diff < 0 { diff += 65536 } if diff <= 255 { flag = flag.Toggle(FlagACKEncoded) } // If the difference between the sequence number and the latest ACK'd sequence number can be represented by a // single byte, then represent it as a single byte and set the 5th bit of flag. // Marshal the flag and sequence number and latest ACK'd sequence number. dst = flag.AppendTo(dst) if p.Unordered { dst = bytesutil.AppendUint16BE(dst, p.ACK) } else { dst = bytesutil.AppendUint16BE(dst, p.Sequence) if diff <= 255 { dst = append(dst, uint8(diff)) } else { dst = bytesutil.AppendUint16BE(dst, p.ACK) } } // Marshal ACK bitset. if p.ACKBits&0x000000FF != 0x000000FF { dst = append(dst, uint8(p.ACKBits&0x000000FF)) } if p.ACKBits&0x0000FF00 != 0x0000FF00 { dst = append(dst, uint8((p.ACKBits&0x0000FF00)>>8)) } if p.ACKBits&0x00FF0000 != 0x00FF0000 { dst = append(dst, uint8((p.ACKBits&0x00FF0000)>>16)) } if p.ACKBits&0xFF000000 != 0xFF000000 { dst = append(dst, uint8((p.ACKBits&0xFF000000)>>24)) } return dst } func UnmarshalPacketHeader(buf []byte) (header PacketHeader, leftover []byte, err error) { flag := PacketHeaderFlag(0) // Read first 3 bytes (header, flag). 
if len(buf) < 3 { return header, buf, io.ErrUnexpectedEOF } flag, buf = PacketHeaderFlag(buf[0]), buf[1:] if flag.Toggled(FlagFragment) { return header, buf, io.ErrUnexpectedEOF } header.Empty = flag.Toggled(FlagEmpty) header.Unordered = flag.Toggled(FlagUnordered) if header.Unordered { if len(buf) < 2 { return header, buf, io.ErrUnexpectedEOF } header.ACK, buf = bytesutil.Uint16BE(buf[:2]), buf[2:] } else { header.Sequence, buf = bytesutil.Uint16BE(buf[:2]), buf[2:] // Read and decode the latest ACK'ed sequence number (either 1 or 2 bytes) using the RLE flag marker. if flag.Toggled(FlagACKEncoded) { if len(buf) < 1 { return header, buf, io.ErrUnexpectedEOF } header.ACK, buf = header.Sequence-uint16(buf[0]), buf[1:] } else { if len(buf) < 2 { return header, buf, io.ErrUnexpectedEOF } header.ACK, buf = bytesutil.Uint16BE(buf[:2]), buf[2:] } } if len(buf) < bits.OnesCount8(uint8(flag&(FlagA|FlagB|FlagC|FlagD))) { return header, buf, io.ErrUnexpectedEOF } // Read and decode ACK bitset using the RLE flag marker. 
header.ACKBits = 0xFFFFFFFF if flag.Toggled(FlagA) { header.ACKBits &= 0xFFFFFF00 header.ACKBits |= uint32(buf[0]) buf = buf[1:] } if flag.Toggled(FlagB) { header.ACKBits &= 0xFFFF00FF header.ACKBits |= uint32(buf[0]) << 8 buf = buf[1:] } if flag.Toggled(FlagC) { header.ACKBits &= 0xFF00FFFF header.ACKBits |= uint32(buf[0]) << 16 buf = buf[1:] } if flag.Toggled(FlagD) { header.ACKBits &= 0x00FFFFFF header.ACKBits |= uint32(buf[0]) << 24 buf = buf[1:] } return header, buf, nil } ================================================ FILE: packet_test.go ================================================ package reliable import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/bytebufferpool" "math" "testing" "testing/quick" ) func TestEncodeDecodePacketHeader(t *testing.T) { buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) f := func(seq, ack uint16, ackBits uint32) bool { header := PacketHeader{Sequence: seq, ACK: ack, ACKBits: ackBits} recovered, leftover, err := UnmarshalPacketHeader(header.AppendTo(buf.B[:0])) return assert.NoError(t, err) && assert.Len(t, leftover, 0) && assert.EqualValues(t, header, recovered) } require.NoError(t, quick.Check(f, &quick.Config{MaxCount: 1000})) } func BenchmarkMarshalPacketHeader(b *testing.B) { header := PacketHeader{Sequence: math.MaxUint16, ACK: math.MaxUint16, ACKBits: math.MaxUint32} buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { buf.B = header.AppendTo(buf.B[:0]) } } func BenchmarkUnmarshalPacketHeader(b *testing.B) { header := PacketHeader{Sequence: math.MaxUint16, ACK: math.MaxUint16, ACKBits: math.MaxUint32} buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) buf.B = header.AppendTo(buf.B) var ( recovered PacketHeader leftover []byte err error ) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { recovered, leftover, err = UnmarshalPacketHeader(buf.B) if err != nil { b.Fatalf("failed to 
unmarshal packet header: %s", err) } if leftover := len(leftover); leftover != 0 { b.Fatalf("got %d byte(s) leftover", leftover) } if recovered.Sequence != header.Sequence || recovered.ACK != header.ACK || recovered.ACKBits != header.ACKBits { b.Fatalf("got %#v, expected %#v", recovered, header) } } _ = recovered } ================================================ FILE: pool.go ================================================ package reliable import "math" var ( emptyBufferIndexCache [math.MaxUint16]uint32 ) func init() { emptyBufferIndexCache[0] = math.MaxUint32 for i := 1; i < math.MaxUint16; i *= 2 { copy(emptyBufferIndexCache[i:], emptyBufferIndexCache[:i]) } } func emptyBufferIndices(indices []uint32) { copy(indices[:], emptyBufferIndexCache[:len(indices)]) } ================================================ FILE: protocol.go ================================================ package reliable import ( "fmt" "github.com/lithdew/seq" "io" "sync" "time" ) type ProtocolPacketHandler func(buf []byte, seq uint16) type ProtocolErrorHandler func(err error) type Protocol struct { writeBufferSize uint16 // write buffer size that must be a divisor of 65536 readBufferSize uint16 // read buffer size that must be a divisor of 65536 updatePeriod time.Duration // how often time-dependant parts of the protocol get checked resendTimeout time.Duration // how long we wait until unacked packets should be resent pool *Pool ph ProtocolPacketHandler eh ProtocolErrorHandler mu sync.Mutex // mutex over everything die bool // is this conn closed? 
exit chan struct{} // signal channel to close the conn lui uint16 // last sent packet index that hasn't been sent via an ack yet oui uint16 // oldest sent packet index that hasn't been acked yet ouc sync.Cond // stop writes if the next write given oui may flood our peers read buffer ls time.Time // last time data was sent to our peer wi uint16 // write index ri uint16 // read index wq []uint32 // write queue rq []uint32 // read queue wqe []writtenPacket // write queue entries } func NewProtocol(opts ...ProtocolOption) *Protocol { p := &Protocol{exit: make(chan struct{})} for _, opt := range opts { opt.applyProtocol(p) } if p.writeBufferSize == 0 { p.writeBufferSize = DefaultWriteBufferSize } if p.readBufferSize == 0 { p.readBufferSize = DefaultReadBufferSize } if p.resendTimeout == 0 { p.resendTimeout = DefaultResendTimeout } if p.updatePeriod == 0 { p.updatePeriod = DefaultUpdatePeriod } if p.pool == nil { p.pool = new(Pool) } p.wq = make([]uint32, p.writeBufferSize) p.rq = make([]uint32, p.readBufferSize) emptyBufferIndices(p.wq) emptyBufferIndices(p.rq) p.wqe = make([]writtenPacket, p.writeBufferSize) p.ouc.L = &p.mu return p } func (p *Protocol) WritePacket(reliable bool, buf []byte) ([]byte, error) { p.mu.Lock() defer p.mu.Unlock() var ( idx uint16 ack uint16 ackBits uint32 ok = true ) if reliable { idx, ack, ackBits, ok = p.waitForNextWriteDetails() } else { ack, ackBits = p.nextAckDetails() } if !ok { return nil, io.EOF } p.trackAcked(ack) // log.Printf("%v: send (seq=%05d) (ack=%05d) (ack_bits=%032b) (size=%d) (reliable=%t)", &p, idx, ack, ackBits, len(buf), reliable) return p.write(PacketHeader{Sequence: idx, ACK: ack, ACKBits: ackBits, Unordered: !reliable}, buf), nil } func (p *Protocol) waitUntilReaderAvailable() { for !p.die && seq.GT(p.wi+1, p.oui+uint16(len(p.rq))) { p.ouc.Wait() } } func (p *Protocol) waitForNextWriteDetails() (idx uint16, ack uint16, ackBits uint32, ok bool) { p.waitUntilReaderAvailable() idx, ok = p.nextWriteIndex(), !p.die ack, 
ackBits = p.nextAckDetails() return idx, ack, ackBits, ok } func (p *Protocol) nextWriteIndex() (idx uint16) { idx, p.wi = p.wi, p.wi+1 return idx } func (p *Protocol) nextAckDetails() (ack uint16, ackBits uint32) { ack = p.ri - 1 ackBits = p.prepareAckBits(ack) return ack, ackBits } func (p *Protocol) prepareAckBits(ack uint16) (ackBits uint32) { for i, m := uint16(0), uint32(1); i < ACKBitsetSize; i, m = i+1, m<<1 { if p.rq[(ack-i)%uint16(len(p.rq))] != uint32(ack-i) { continue } ackBits |= m } return ackBits } func (p *Protocol) write(header PacketHeader, buf []byte) []byte { b := p.pool.Get() b.B = header.AppendTo(b.B) b.B = append(b.B, buf...) if header.Unordered { defer p.pool.Put(b) } if !header.Unordered { p.trackWrite(header.Sequence, b) } return b.B } func (p *Protocol) trackWrite(idx uint16, buf *Buffer) { if seq.GT(idx+1, p.wi) { p.clearWrites(p.wi, idx) p.wi = idx + 1 } i := idx % uint16(len(p.wq)) p.wq[i] = uint32(idx) if p.wqe[i].buf != nil { p.pool.Put(p.wqe[i].buf) } p.wqe[i].buf = buf p.wqe[i].acked = false p.wqe[i].written = time.Now() p.wqe[i].resent = 0 } func (p *Protocol) clearWrites(start, end uint16) { count, size := end-start+1, uint16(len(p.wq)) if count >= size { emptyBufferIndices(p.wq) return } first := p.wq[start%size:] length := uint16(len(first)) if count <= length { emptyBufferIndices(first[:count]) return } second := p.wq[:count-length] emptyBufferIndices(first) emptyBufferIndices(second) } func (p *Protocol) ReadPacket(header PacketHeader, buf []byte) []byte { p.mu.Lock() defer p.mu.Unlock() p.readAckBits(header.ACK, header.ACKBits) if !header.Unordered && !p.trackRead(header.Sequence) { return nil } p.trackUnacked() if header.Empty { return nil } if p.ph != nil { p.ph(buf, header.Sequence) } // log.Printf("%v: recv (seq=%05d) (ack=%05d) (ack_bits=%032b) (size=%d) (reliable=%t)", &p, header.Sequence, header.ACK, header.ACKBits, len(buf), !header.Unordered) return p.writeAcksIfNecessary() } func (p *Protocol) 
createAckIfNecessary() (header PacketHeader, needed bool) { lui := p.lui for i := uint16(0); i < ACKBitsetSize; i++ { if p.rq[(lui+i)%uint16(len(p.rq))] != uint32(lui+i) { return header, needed } } lui += ACKBitsetSize p.lui = lui p.ls = time.Now() p.waitUntilReaderAvailable() header.Sequence, header.ACK = p.nextWriteIndex(), lui-1 header.ACKBits = p.prepareAckBits(header.ACK) header.Empty = true needed = !p.die return header, needed } func (p *Protocol) writeAcksIfNecessary() []byte { for { header, needed := p.createAckIfNecessary() if !needed { return nil } // log.Printf("%v: ack (seq=%05d) (ack=%05d) (ack_bits=%032b)", &p, header.Sequence, header.ACK, header.ACKBits) return p.write(header, nil) } } func (p *Protocol) readAckBits(ack uint16, ackBits uint32) { for idx := uint16(0); idx < ACKBitsetSize; idx, ackBits = idx+1, ackBits>>1 { if ackBits&1 == 0 { continue } i := (ack - idx) % uint16(len(p.wq)) if p.wq[i] != uint32(ack-idx) || p.wqe[i].acked { continue } if p.wqe[i].buf != nil { p.pool.Put(p.wqe[i].buf) } p.wqe[i].buf = nil p.wqe[i].acked = true } } func (p *Protocol) trackRead(idx uint16) bool { i := idx % uint16(len(p.rq)) if p.rq[i] == uint32(idx) { // duplicate packet return false } if seq.GT(idx+1, p.ri) { p.clearReads(p.ri, idx) p.ri = idx + 1 } p.rq[i] = uint32(idx) return true } func (p *Protocol) clearReads(start, end uint16) { count, size := end-start+1, uint16(len(p.rq)) if count >= size { emptyBufferIndices(p.rq) return } first := p.rq[start%size:] length := uint16(len(first)) if count <= length { emptyBufferIndices(first[:count]) return } second := p.rq[:count-length] emptyBufferIndices(first) emptyBufferIndices(second) } func (p *Protocol) trackAcked(ack uint16) { lui := p.lui for lui <= ack { if p.rq[lui%uint16(len(p.rq))] != uint32(lui) { break } lui++ } p.lui = lui p.ls = time.Now() } func (p *Protocol) trackUnacked() { oui := p.oui for { i := oui % uint16(len(p.wq)) if p.wq[i] != uint32(oui) || !p.wqe[i].acked { break } oui++ } p.oui = 
oui p.ouc.Broadcast() } func (p *Protocol) close() bool { if p.die { return false } close(p.exit) p.die = true p.ouc.Broadcast() return true } func (p *Protocol) Close() { p.mu.Lock() defer p.mu.Unlock() if !p.close() { return } } func (p *Protocol) Run(transmit transmitFunc) { ticker := time.NewTicker(p.updatePeriod) defer ticker.Stop() for { select { case <-p.exit: return case <-ticker.C: if err := p.retransmitUnackedPackets(transmit); err != nil && p.eh != nil { p.eh(err) } } } } func (p *Protocol) retransmitUnackedPackets(transmit transmitFunc) error { p.mu.Lock() defer p.mu.Unlock() for idx := uint16(0); idx < uint16(len(p.wq)); idx++ { i := (p.oui + idx) % uint16(len(p.wq)) if p.wq[i] != uint32(p.oui+idx) || !p.wqe[i].shouldResend(time.Now(), p.resendTimeout) { continue } // log.Printf("%v: resend (seq=%d)", &p, p.oui+idx) if isEOF, err := transmit(p.wqe[i].buf.B); err != nil { return fmt.Errorf("failed to retransmit unacked packet: %w", err) } else if isEOF { break } p.wqe[i].written = time.Now() p.wqe[i].resent++ } return nil } ================================================ FILE: protocol_test.go ================================================ package reliable import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" "sync" "testing" ) func testConnWaitForWriteDetails(inc uint16) func(t testing.TB) { return func(t testing.TB) { defer goleak.VerifyNone(t) p := NewProtocol() p.wi = uint16(len(p.rq)) var wg sync.WaitGroup wg.Add(8) ch := make(chan uint16, 8) for i := 0; i < 8; i++ { go func() { defer wg.Done() p.ouc.L.Lock() idx, _, _, _ := p.waitForNextWriteDetails() p.ouc.L.Unlock() ch <- idx }() } for i := 0; i < 8; i++ { p.ouc.L.Lock() p.oui += inc p.ouc.Broadcast() p.ouc.L.Unlock() } wg.Wait() expected := make(map[uint16]struct{}, 8) close(ch) for idx := range ch { expected[idx] = struct{}{} } for i := 0; i < 8; i++ { actual := uint16(len(p.wq) + i) require.Contains(t, expected, actual) delete(expected, actual) } } } func 
TestConnWaitForWriteDetails(t *testing.T) { testConnWaitForWriteDetails(1)(t) testConnWaitForWriteDetails(2)(t) testConnWaitForWriteDetails(4)(t) }