Repository: Workiva/go-datastructures Branch: master Commit: 89d15facb2e3 Files: 180 Total size: 812.0 KB Directory structure: gitextract_a5x2r3bc/ ├── .github/ │ ├── CODEOWNERS │ └── workflows/ │ └── tests.yaml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── augmentedtree/ │ ├── atree.go │ ├── atree_test.go │ ├── interface.go │ ├── intervals.go │ ├── intervals_test.go │ ├── mock_test.go │ └── multidimensional_test.go ├── aviary.yaml ├── batcher/ │ ├── batcher.go │ └── batcher_test.go ├── bitarray/ │ ├── and.go │ ├── and_test.go │ ├── bitarray.go │ ├── bitarray_test.go │ ├── bitmap.go │ ├── bitmap_test.go │ ├── block.go │ ├── block_test.go │ ├── encoding.go │ ├── encoding_test.go │ ├── error.go │ ├── interface.go │ ├── iterator.go │ ├── nand.go │ ├── nand_test.go │ ├── or.go │ ├── or_test.go │ ├── sparse_bitarray.go │ ├── sparse_bitarray_test.go │ └── util.go ├── btree/ │ ├── _link/ │ │ ├── interface.go │ │ ├── key.go │ │ ├── mock_test.go │ │ ├── node.go │ │ ├── node_test.go │ │ ├── tree.go │ │ └── tree_test.go │ ├── immutable/ │ │ ├── add.go │ │ ├── cacher.go │ │ ├── config.go │ │ ├── delete.go │ │ ├── error.go │ │ ├── interface.go │ │ ├── item.go │ │ ├── node.go │ │ ├── node_gen.go │ │ ├── path.go │ │ ├── query.go │ │ ├── rt.go │ │ ├── rt_gen.go │ │ └── rt_test.go │ ├── palm/ │ │ ├── action.go │ │ ├── interface.go │ │ ├── key.go │ │ ├── mock_test.go │ │ ├── node.go │ │ ├── tree.go │ │ └── tree_test.go │ └── plus/ │ ├── btree.go │ ├── btree_test.go │ ├── interface.go │ ├── iterator.go │ ├── mock_test.go │ ├── node.go │ └── node_test.go ├── cache/ │ ├── cache.go │ └── cache_test.go ├── common/ │ └── interface.go ├── datastructures.go ├── documentation.md ├── fibheap/ │ ├── Test Generator/ │ │ ├── EnqDecrKey.py │ │ ├── EnqDelete.py │ │ ├── EnqDeqMin.py │ │ ├── Merge.py │ │ └── README.md │ ├── benchmarks.txt │ ├── fibheap.go │ ├── fibheap_examples_test.go │ ├── fibheap_single_example_test.go │ └── fibheap_test.go ├── futures/ │ ├── futures.go │ ├── 
futures_test.go │ ├── selectable.go │ └── selectable_test.go ├── go.mod ├── go.sum ├── graph/ │ ├── simple.go │ └── simple_test.go ├── hashmap/ │ └── fastinteger/ │ ├── hash.go │ ├── hash_test.go │ ├── hashmap.go │ └── hashmap_test.go ├── list/ │ ├── persistent.go │ └── persistent_test.go ├── mock/ │ ├── batcher.go │ └── rangetree.go ├── numerics/ │ ├── hilbert/ │ │ ├── hilbert.go │ │ └── hilbert_test.go │ └── optimization/ │ ├── global.go │ ├── nelder_mead.go │ └── nelder_mead_test.go ├── queue/ │ ├── error.go │ ├── mock_test.go │ ├── priority_queue.go │ ├── priority_queue_test.go │ ├── queue.go │ ├── queue_test.go │ ├── ring.go │ └── ring_test.go ├── rangetree/ │ ├── entries.go │ ├── entries_test.go │ ├── error.go │ ├── immutable.go │ ├── immutable_test.go │ ├── interface.go │ ├── mock_test.go │ ├── node.go │ ├── ordered.go │ ├── ordered_test.go │ ├── orderedtree.go │ ├── orderedtree_test.go │ └── skiplist/ │ ├── mock_test.go │ ├── skiplist.go │ └── skiplist_test.go ├── rtree/ │ ├── hilbert/ │ │ ├── action.go │ │ ├── cpu.prof │ │ ├── hilbert.go │ │ ├── mock_test.go │ │ ├── node.go │ │ ├── rectangle.go │ │ ├── tree.go │ │ └── tree_test.go │ └── interface.go ├── set/ │ ├── dict.go │ └── dict_test.go ├── slice/ │ ├── int64.go │ ├── int64_test.go │ └── skip/ │ ├── interface.go │ ├── iterator.go │ ├── iterator_test.go │ ├── mock_test.go │ ├── node.go │ ├── skip.go │ └── skip_test.go ├── sort/ │ ├── interface.go │ ├── sort.go │ ├── sort_test.go │ ├── symmerge.go │ └── symmerge_test.go ├── threadsafe/ │ └── err/ │ ├── error.go │ └── error_test.go ├── tree/ │ └── avl/ │ ├── avl.go │ ├── avl_test.go │ ├── interface.go │ ├── mock_test.go │ └── node.go └── trie/ ├── ctrie/ │ ├── ctrie.go │ └── ctrie_test.go ├── dtrie/ │ ├── dtrie.go │ ├── dtrie_test.go │ ├── node.go │ └── util.go ├── xfast/ │ ├── iterator.go │ ├── iterator_test.go │ ├── mock_test.go │ ├── xfast.go │ └── xfast_test.go └── yfast/ ├── entries.go ├── entries_test.go ├── interface.go ├── iterator.go ├── 
mock_test.go ├── yfast.go └── yfast_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/CODEOWNERS ================================================ @Workiva/skreams ================================================ FILE: .github/workflows/tests.yaml ================================================ name: "Tests" on: pull_request: push: branches: - 'master' tags: - '*' permissions: pull-requests: write contents: write id-token: write jobs: Tests: runs-on: ubuntu-latest strategy: matrix: go: [ '1.15', 'stable' ] name: Tests on Go ${{ matrix.go }} steps: - name: Checkout Repo uses: actions/checkout@v4 with: path: go/src/github.com/Workiva/go-datastructures - name: Setup Go uses: actions/setup-go@v5.0.0 with: go-version: ${{ matrix.go }} # go install does not work because it needs credentials - name: install go2xunit run: | git clone https://github.com/tebeka/go2xunit.git cd go2xunit git checkout v1.4.10 go install cd .. - name: Run Tests timeout-minutes: 10 run: | cd go/src/github.com/Workiva/go-datastructures go test ./... 
| tee ${{github.workspace}}/go-test.txt - name: XML output run: | mkdir artifacts go2xunit -input ./go-test.txt -output ./artifacts/tests_go_version-${{ matrix.go }}.xml - name: Upload Test Results uses: actions/upload-artifact@v4 with: name: go-datastructures test go ${{ matrix.go }} path: ./artifacts/tests_go_version-${{ matrix.go }}.xml retention-days: 7 - uses: anchore/sbom-action@v0 with: path: ./ format: cyclonedx-json artifact-name: ${{ matrix.go }}-sbom.spdx ================================================ FILE: .gitignore ================================================ *.out *.test .idea ================================================ FILE: Dockerfile ================================================ FROM golang:1.16-alpine3.13 AS build-go ARG GIT_SSH_KEY ARG KNOWN_HOSTS_CONTENT WORKDIR /go/src/github.com/Workiva/go-datastructures/ ADD . /go/src/github.com/Workiva/go-datastructures/ ARG GOPATH=/go/ ENV PATH $GOPATH/bin:$PATH RUN echo "Starting the script section" && \ go mod vendor && \ echo "script section completed" ARG BUILD_ARTIFACTS_DEPENDENCIES=/go/src/github.com/Workiva/go-datastructures/go.mod FROM scratch ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ go-datastructures ================= Go-datastructures is a collection of useful, performant, and threadsafe Go datastructures. ### NOTE: only tested with Go 1.3+. #### Augmented Tree Interval tree for collision in n-dimensional ranges. Implemented via a red-black augmented tree. Extra dimensions are handled in simultaneous inserts/queries to save space although this may result in suboptimal time complexity. Intersection determined using bit arrays. In a single dimension, inserts, deletes, and queries should be in O(log n) time. #### Bitarray Bitarray used to detect existence without having to resort to hashing with hashmaps. Requires entities have a uint64 unique identifier. Two implementations exist, regular and sparse. Sparse saves a great deal of space but insertions are O(log n). There are some useful functions on the BitArray interface to detect intersection between two bitarrays. This package also includes bitmaps of length 32 and 64 that provide increased speed and O(1) for all operations by storing the bitmaps in unsigned integers rather than arrays. #### Futures A helpful tool to send a "broadcast" message to listeners. Channels have the issue that once one listener takes a message from a channel the other listeners aren't notified. 
There were many cases when I wanted to notify many listeners of a single event and this package helps. #### Queue Package contains both a normal and priority queue. Both implementations never block on send and grow as much as necessary. Both also only return errors if you attempt to push to a disposed queue and will not panic like sending a message on a closed channel. The priority queue also allows you to place items in priority order inside the queue. If you give a useful hint to the regular queue, it is actually faster than a channel. The priority queue is somewhat slow currently and targeted for an update to a Fibonacci heap. Also included in the queue package is a MPMC threadsafe ring buffer. This is a block full/empty queue, but will return a blocked thread if the queue is disposed while a thread is blocked. This can be used to synchronize goroutines and ensure goroutines quit so objects can be GC'd. Threadsafety is achieved using only CAS operations making this queue quite fast. Benchmarks can be found in that package. #### Fibonacci Heap A standard Fibonacci heap providing the usual operations. Can be useful in executing Dijkstra or Prim's algorithms in the theoretically minimal time. Also useful as a general-purpose priority queue. The special thing about Fibonacci heaps versus other heap variants is the cheap decrease-key operation. This heap has a constant complexity for find minimum, insert and merge of two heaps, an amortized constant complexity for decrease key and O(log(n)) complexity for a deletion or dequeue minimum. In practice the constant factors are large, so Fibonacci heaps could be slower than Pairing heaps, depending on usage. Benchmarks - in the project subfolder. The heap has not been designed for thread-safety. #### Range Tree Useful to determine if n-dimensional points fall within an n-dimensional range. 
Not a typical range tree however, as we are actually using an n-dimensional sorted list of points as this proved to be simpler and faster than attempting a traditional range tree while saving space on any dimension greater than one. Inserts are typical BBST times at O(log n^d) where d is the number of dimensions. #### Set Our Set implementation is very simple, accepts items of type `interface{}` and includes only a few methods. If your application requires a richer Set implementation over lists of type `sort.Interface`, see [xtgo/set](https://github.com/xtgo/set) and [goware/set](https://github.com/goware/set). #### Threadsafe A package that is meant to contain some commonly used items but in a threadsafe way. Example: there's a threadsafe error in there as I commonly found myself wanting to set an error in many threads at the same time (yes, I know, but channels are slow). #### AVL Tree This is an example of a branch copy immutable AVL BBST. Any operation on a node makes a copy of that node's branch. Because of this, this tree is inherently threadsafe although the writes will likely still need to be serialized. This structure is good if your use case is a large number of reads and infrequent writes as reads will be highly available but writes somewhat slow due to the copying. This structure serves as a basis for a large number of functional data structures. #### X-Fast Trie An interesting design that treats integers as words and uses a trie structure to reduce time complexities by matching prefixes. This structure is really fast for finding values or making predecessor/successor types of queries, but also results in greater than linear space consumption. The exact time complexities can be found in that package. #### Y-Fast Trie An extension of the X-Fast trie in which an X-Fast trie is combined with some other ordered data structure to reduce space consumption and improve CRUD types of operations. 
These secondary structures are often BSTs, but our implementation uses a simple ordered list as I believe this improves cache locality. We also use fixed size buckets to aid in parallelization of operations. Exact time complexities are in that package. #### Fast integer hashmap A datastructure used for checking existence but without knowing the bounds of your data. If you have a limited small bounds, the bitarray package might be a better choice. This implementation uses a fairly simple hashing algorithm combined with linear probing and a flat datastructure to provide optimal performance up to a few million integers (faster than the native Golang implementation). Beyond that, the native implementation is faster (I believe they are using a large -ary B-tree). In the future, this will be implemented with a B-tree for scale. #### Skiplist An ordered structure that provides amortized logarithmic operations but without the complication of rotations that are required by BSTs. In testing, however, the performance of the skip list is often far worse than the guaranteed log n time of a BBST. Tall nodes tend to "cast shadows", especially when large bitsizes are required as the optimum maximum height for a node is often based on this. More detailed performance characteristics are provided in that package. #### Sort The sort package implements a multithreaded bucket sort that can be up to 3x faster than the native Golang sort package. These buckets are then merged using a symmetrical merge, similar to the stable sort in the Golang package. However, our algorithm is modified so that two sorted lists can be merged by using symmetrical decomposition. #### Numerics Early work on some nonlinear optimization problems. The initial implementation allows a simple use case with either linear or nonlinear constraints. You can find min/max or target an optimal value. The package currently employs a probabilistic global restart system in an attempt to avoid local critical points. 
More details can be found in that package. #### B+ Tree Initial implementation of a B+ tree. The Delete method still needs to be added, as well as some performance optimization. Specific performance characteristics can be found in that package. Despite the theoretical superiority of BSTs, the B-tree often has better all-around performance due to cache locality. The current implementation is mutable, but the immutable AVL tree can be used to build an immutable version. Unfortunately, to make the B-tree generic we require an interface and the most expensive operation in CPU profiling is the interface method which in turn calls into runtime.assertI2T. We need generics. #### Immutable B Tree A btree based on two principles, immutability and concurrency. Somewhat slow for single value lookups and puts, it is very fast for bulk operations. A persister can be injected to make this index persistent. #### Ctrie A concurrent, lock-free hash array mapped trie with efficient non-blocking snapshots. For lookups, Ctries have comparable performance to concurrent skip lists and concurrent hashmaps. One key advantage of Ctries is they are dynamically allocated. Memory consumption is always proportional to the number of keys in the Ctrie, while hashmaps typically have to grow and shrink. Lookups, inserts, and removes are O(logn). One interesting advantage Ctries have over traditional concurrent data structures is support for lock-free, linearizable, constant-time snapshots. Most concurrent data structures do not support snapshots, instead opting for locks or requiring a quiescent state. This allows Ctries to have O(1) iterator creation and clear operations and O(logn) size retrieval. #### Dtrie A persistent hash trie that dynamically expands or shrinks to provide efficient memory allocation. Being persistent, the Dtrie is immutable and any modification yields a new version of the Dtrie rather than changing the original. Bitmapped nodes allow for O(log32(n)) get, remove, and update operations. 
Insertions are O(n) and iteration is O(1). #### Persistent List A persistent, immutable linked list. All write operations yield a new, updated structure which preserves and reuses previous versions. This uses a very functional, cons-style of list manipulation. Insert, get, remove, and size operations are O(n) as you would expect. #### Simple Graph A mutable, non-persistent undirected graph where parallel edges and self-loops are not permitted. Operations to add an edge as well as retrieve the total number of vertices/edges are O(1) while the operation to retrieve the vertices adjacent to a target is O(n). For more details see [wikipedia](https://en.wikipedia.org/wiki/Graph_(discrete_mathematics)#Simple_graph) ### Installation 1. Install Go 1.3 or higher. 2. Run `go get github.com/Workiva/go-datastructures/...` ### Updating When new code is merged to master, you can use go get -u github.com/Workiva/go-datastructures/... to retrieve the latest version of go-datastructures. ### Testing To run all the unit tests use these commands: cd $GOPATH/src/github.com/Workiva/go-datastructures go get -t -u ./... go test ./... Once you've done this once, you can simply use this command to run all unit tests: go test ./... ### Contributing Requirements to commit here: - Branch off master, PR back to master. - `gofmt`'d code. 
- Compliance with [these guidelines](https://code.google.com/p/go-wiki/wiki/CodeReviewComments) - Unit test coverage - [Good commit messages](http://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html) ### Maintainers - Alexander Campbell <[alexander.campbell@workiva.com](mailto:alexander.campbell@workiva.com)> - Dustin Hiatt <[dustin.hiatt@workiva.com](mailto:dustin.hiatt@workiva.com)> - Ryan Jackson <[ryan.jackson@workiva.com](mailto:ryan.jackson@workiva.com)> ================================================ FILE: augmentedtree/atree.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package augmentedtree func intervalOverlaps(n *node, low, high int64, interval Interval, maxDimension uint64) bool { if !overlaps(n.interval.HighAtDimension(1), high, n.interval.LowAtDimension(1), low) { return false } if interval == nil { return true } for i := uint64(2); i <= maxDimension; i++ { if !n.interval.OverlapsAtDimension(interval, i) { return false } } return true } func overlaps(high, otherHigh, low, otherLow int64) bool { return high >= otherLow && low <= otherHigh } // compare returns an int indicating which direction the node // should go. 
func compare(nodeLow, ivLow int64, nodeID, ivID uint64) int { if ivLow > nodeLow { return 1 } if ivLow < nodeLow { return 0 } return intFromBool(ivID > nodeID) } type node struct { interval Interval max, min int64 // max value held by children children [2]*node // array to hold left/right red bool // indicates if this node is red id uint64 // we store the id locally to reduce the number of calls to the method on the interface } func (n *node) query(low, high int64, interval Interval, maxDimension uint64, fn func(node *node)) { if n.children[0] != nil && overlaps(n.children[0].max, high, n.children[0].min, low) { n.children[0].query(low, high, interval, maxDimension, fn) } if intervalOverlaps(n, low, high, interval, maxDimension) { fn(n) } if n.children[1] != nil && overlaps(n.children[1].max, high, n.children[1].min, low) { n.children[1].query(low, high, interval, maxDimension, fn) } } func (n *node) adjustRanges() { for i := 0; i <= 1; i++ { if n.children[i] != nil { n.children[i].adjustRanges() } } n.adjustRange() } func (n *node) adjustRange() { setMin(n) setMax(n) } func newDummy() node { return node{ children: [2]*node{}, } } func newNode(interval Interval, min, max int64, dimension uint64) *node { itn := &node{ interval: interval, min: min, max: max, red: true, children: [2]*node{}, } if interval != nil { itn.id = interval.ID() } return itn } type tree struct { root *node maxDimension, number uint64 dummy node } func (t *tree) Traverse(fn func(id Interval)) { nodes := []*node{t.root} for len(nodes) != 0 { c := nodes[len(nodes)-1] nodes = nodes[:len(nodes)-1] if c != nil { fn(c.interval) if c.children[0] != nil { nodes = append(nodes, c.children[0]) } if c.children[1] != nil { nodes = append(nodes, c.children[1]) } } } } func (tree *tree) resetDummy() { tree.dummy.children[0], tree.dummy.children[1] = nil, nil tree.dummy.red = false } // Len returns the number of items in this tree. 
func (tree *tree) Len() uint64 {
	return tree.number
}

// add will add the provided interval to the tree.
//
// The tree is walked once from the root. On the way down, a node with two
// red children is color-flipped, every visited node's min/max is widened to
// cover iv, and a red node under a red parent is fixed with a single or
// double rotation around the grandparent. The walk stops at the first node
// whose id equals iv's id, which is either an existing node or the node
// just created for iv.
func (tree *tree) add(iv Interval) {
	if tree.root == nil {
		root := newNode(iv, iv.LowAtDimension(1), iv.HighAtDimension(1), 1)
		root.red = false
		tree.root = root
		tree.number++
		return
	}

	tree.resetDummy()

	var (
		sentinel            = tree.dummy // local copy; acts as parent of the root
		grandParent, parent *node
		current             = tree.root
		dir, last           int
		otherLast           = 1
		id                  = iv.ID()
		high                = iv.HighAtDimension(1)
		low                 = iv.LowAtDimension(1)
		anchor              = &sentinel // node whose child slot receives a rotated subtree
	)

	// set this AFTER clearing dummy
	anchor.children[1] = tree.root

	for {
		switch {
		case current == nil:
			// Fell off the tree: this is where iv belongs.
			current = newNode(iv, low, high, 1)
			parent.children[dir] = current
			tree.number++
		case isRed(current.children[0]) && isRed(current.children[1]):
			// Color flip.
			current.red = true
			current.children[0].red = false
			current.children[1].red = false
		}

		// Widen the bounds of every node on the path to cover iv.
		if high > current.max {
			current.max = high
		}
		if low < current.min {
			current.min = low
		}

		// Two reds in a row: rotate around the grandparent.
		if isRed(parent) && isRed(current) {
			slot := intFromBool(anchor.children[1] == grandParent)
			if current == parent.children[last] {
				anchor.children[slot] = rotate(grandParent, otherLast)
			} else {
				anchor.children[slot] = doubleRotate(grandParent, otherLast)
			}
		}

		if current.id == id {
			break
		}

		last = dir
		otherLast = takeOpposite(last)
		dir = compare(current.interval.LowAtDimension(1), low, current.id, id)

		if grandParent != nil {
			anchor = grandParent
		}
		grandParent, parent, current = parent, current, current.children[dir]
	}

	// Rotations may have replaced the root; it is always colored black.
	tree.root = sentinel.children[1]
	tree.root.red = false
}

// Add will add the provided intervals to this tree.
func (tree *tree) Add(intervals ...Interval) {
	for i := range intervals {
		tree.add(intervals[i])
	}
}

// delete will remove the provided interval from the tree.
//
// The removal is performed top-down in a single pass.  The walk always
// continues to the bottom of the search path, recolouring and rotating on
// the way so that the node finally unlinked can be removed without a
// separate rebalancing pass.  The node whose ID matches iv is remembered in
// found; once the bottom is reached the contents of the last node visited
// are copied into found and that last node is unlinked.  Cached min/max
// values are NOT repaired here; no code in this function recomputes them.
func (tree *tree) delete(iv Interval) {
	if tree.root == nil {
		return
	}

	tree.resetDummy()
	var (
		// dummy is a local copy of the sentinel and acts as the parent of
		// the root for the duration of the walk.
		dummy                      = tree.dummy
		found, parent, grandParent *node
		last, otherDir, otherLast  int // keeping track of last direction
		id                         = iv.ID()
		dir                        = 1
		node                       = &dummy
		ivLow                      = iv.LowAtDimension(1)
	)

	node.children[1] = tree.root
	for node.children[dir] != nil {
		last = dir
		otherLast = takeOpposite(last)

		// Step one level down.
		grandParent, parent, node = parent, node, node.children[dir]

		dir = compare(node.interval.LowAtDimension(1), ivLow, node.id, id)
		otherDir = takeOpposite(dir)

		// Remember the match but keep descending.
		if node.id == id {
			found = node
		}

		// Rebalance only when both node and the child we are about to
		// step into are black.
		if !isRed(node) && !isRed(node.children[dir]) {
			if isRed(node.children[otherDir]) {
				// The other child is red: rotate it above node.
				parent.children[last] = rotate(node, dir)
				parent = parent.children[last]
			} else if !isRed(node.children[otherDir]) {
				// t is node's sibling.
				t := parent.children[otherLast]

				if t != nil {
					if !isRed(t.children[otherLast]) && !isRed(t.children[last]) {
						// Sibling has no red child: colour flip.
						parent.red = false
						node.red = true
						t.red = true
					} else {
						// Sibling has a red child: rotate around parent
						// and re-attach the result beneath grandParent.
						localDir := intFromBool(grandParent.children[1] == parent)

						if isRed(t.children[last]) {
							grandParent.children[localDir] = doubleRotate(
								parent, last,
							)
						} else if isRed(t.children[otherLast]) {
							grandParent.children[localDir] = rotate(
								parent, last,
							)
						}

						// Fix up the colours of the rotated subtree.
						node.red = true
						grandParent.children[localDir].red = true
						grandParent.children[localDir].children[0].red = false
						grandParent.children[localDir].children[1].red = false
					}
				}
			}
		}
	}

	if found != nil {
		tree.number--
		// Move the bottom node's payload into the matched node, then
		// unlink the bottom node, promoting its only possible child.
		found.interval, found.max, found.min, found.id = node.interval, node.max, node.min, node.id
		parentDir := intFromBool(parent.children[1] == node)
		childDir := intFromBool(node.children[0] == nil)

		parent.children[parentDir] = node.children[childDir]
	}

	// Rotations may have replaced the root; it is always coloured black.
	tree.root = dummy.children[1]
	if tree.root != nil {
		tree.root.red = false
	}
}

// Delete will remove the provided intervals from this tree.
func (tree *tree) Delete(intervals ...Interval) { for _, iv := range intervals { tree.delete(iv) } if tree.root != nil { tree.root.adjustRanges() } } // Query will return a list of intervals that intersect the provided // interval. The provided interval's ID method is ignored so the // provided ID is irrelevant. func (tree *tree) Query(interval Interval) Intervals { if tree.root == nil { return nil } var ( Intervals = intervalsPool.Get().(Intervals) ivLow = interval.LowAtDimension(1) ivHigh = interval.HighAtDimension(1) ) tree.root.query(ivLow, ivHigh, interval, tree.maxDimension, func(node *node) { Intervals = append(Intervals, node.interval) }) return Intervals } func isRed(node *node) bool { return node != nil && node.red } func setMax(parent *node) { parent.max = parent.interval.HighAtDimension(1) if parent.children[0] != nil && parent.children[0].max > parent.max { parent.max = parent.children[0].max } if parent.children[1] != nil && parent.children[1].max > parent.max { parent.max = parent.children[1].max } } func setMin(parent *node) { parent.min = parent.interval.LowAtDimension(1) if parent.children[0] != nil && parent.children[0].min < parent.min { parent.min = parent.children[0].min } if parent.children[1] != nil && parent.children[1].min < parent.min { parent.min = parent.children[1].min } if parent.interval.LowAtDimension(1) < parent.min { parent.min = parent.interval.LowAtDimension(1) } } func rotate(parent *node, dir int) *node { otherDir := takeOpposite(dir) child := parent.children[otherDir] parent.children[otherDir] = child.children[dir] child.children[dir] = parent parent.red = true child.red = false child.max = parent.max setMax(child) setMax(parent) setMin(child) setMin(parent) return child } func doubleRotate(parent *node, dir int) *node { otherDir := takeOpposite(dir) parent.children[otherDir] = rotate(parent.children[otherDir], otherDir) return rotate(parent, dir) } func intFromBool(value bool) int { if value { return 1 } return 0 } func 
takeOpposite(value int) int { return 1 - value } func newTree(maxDimension uint64) *tree { return &tree{ maxDimension: maxDimension, dummy: newDummy(), } } // New constructs and returns a new interval tree with the max // dimensions provided. func New(dimensions uint64) Tree { return newTree(dimensions) } ================================================ FILE: augmentedtree/atree_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package augmentedtree import ( "testing" "github.com/stretchr/testify/assert" ) func min(one, two int64) int64 { if one == -1 { return two } if two == -1 { return one } if one > two { return two } return one } func max(one, two int64) int64 { if one == -1 { return two } if two == -1 { return one } if one > two { return one } return two } func checkRedBlack(tb testing.TB, node *node, dimension int) (int64, int64, int64) { lh, rh := 0, 0 if node == nil { return 1, -1, -1 } if isRed(node) { if isRed(node.children[0]) || isRed(node.children[1]) { tb.Errorf(`Node is red and has red children: %+v`, node) } } fn := func(min, max int64) { if min != -1 && min < node.min { tb.Errorf(`Min not set correctly: %+v, node: %+v`, min, node) } if max != -1 && max > node.max { tb.Errorf(`Max not set correctly: %+v, node: %+v`, max, node) } } left, minL, maxL := checkRedBlack(tb, node.children[0], dimension) fn(minL, maxL) right, minR, maxR := checkRedBlack(tb, node.children[1], dimension) fn(minR, maxR) min := 
min(minL, minR) if min == -1 && node.min != node.interval.LowAtDimension(1) { tb.Errorf(`Min not set correctly, node: %+v`, node) } else if min != -1 && node.children[0] != nil && node.children[0].min != node.min { tb.Errorf(`Min not set correctly: node: %+v, child: %+v`, node, node.children[0]) } else if min != -1 && node.children[0] == nil && node.min != node.interval.LowAtDimension(1) { tb.Errorf(`Min not set correctly: %+v`, node) } max := max(maxL, maxR) if max == -1 && node.max != node.interval.HighAtDimension(1) { tb.Errorf(`Max not set correctly, node: %+v`, node) } else if max > node.interval.HighAtDimension(1) && max != node.max { tb.Errorf(`Max not set correctly, max: %+v, node: %+v`, max, node) } if left != 0 && right != 0 && lh != rh { tb.Errorf(`Black violation: left: %d, right: %d`, left, right) } if left != 0 && right != 0 { if isRed(node) { return left, node.min, node.max } return left + 1, node.min, node.max } return 0, node.min, node.max } func constructSingleDimensionTestTree(number int) (*tree, Intervals) { tree := newTree(1) ivs := make(Intervals, 0, number) for i := 0; i < number; i++ { iv := constructSingleDimensionInterval(int64(i), int64(i)+10, uint64(i)) ivs = append(ivs, iv) } tree.Add(ivs...) 
return tree, ivs } func TestSimpleAddNilRoot(t *testing.T) { it := newTree(1) iv := constructSingleDimensionInterval(5, 10, 0) it.Add(iv) expected := newNode(iv, 5, 10, 1) expected.red = false assert.Equal(t, expected, it.root) assert.Equal(t, uint64(1), it.Len()) checkRedBlack(t, it.root, 1) } func TestSimpleAddRootLeft(t *testing.T) { it := newTree(1) iv := constructSingleDimensionInterval(5, 10, 0) it.Add(iv) expectedRoot := newNode(iv, 4, 11, 1) expectedRoot.red = false iv = constructSingleDimensionInterval(4, 11, 1) it.Add(iv) expectedChild := newNode(iv, 4, 11, 1) expectedRoot.children[0] = expectedChild assert.Equal(t, expectedRoot, it.root) assert.Equal(t, uint64(2), it.Len()) checkRedBlack(t, it.root, 1) } func TestSimpleAddRootRight(t *testing.T) { it := newTree(1) iv := constructSingleDimensionInterval(5, 10, 0) it.Add(iv) expectedRoot := newNode(iv, 5, 11, 1) expectedRoot.red = false iv = constructSingleDimensionInterval(7, 11, 1) it.Add(iv) expectedChild := newNode(iv, 7, 11, 1) expectedRoot.children[1] = expectedChild assert.Equal(t, expectedRoot, it.root) assert.Equal(t, uint64(2), it.Len()) checkRedBlack(t, it.root, 1) } func TestAddRootLeftAndRight(t *testing.T) { it := newTree(1) iv := constructSingleDimensionInterval(5, 10, 0) it.Add(iv) expectedRoot := newNode(iv, 4, 12, 1) expectedRoot.red = false iv = constructSingleDimensionInterval(4, 11, 1) it.Add(iv) expectedLeft := newNode(iv, 4, 11, 1) expectedRoot.children[0] = expectedLeft iv = constructSingleDimensionInterval(7, 12, 1) it.Add(iv) expectedRight := newNode(iv, 7, 12, 1) expectedRoot.children[1] = expectedRight assert.Equal(t, expectedRoot, it.root) assert.Equal(t, uint64(3), it.Len()) checkRedBlack(t, it.root, 1) } func TestAddRebalanceInOrder(t *testing.T) { it := newTree(1) for i := int64(0); i < 10; i++ { iv := constructSingleDimensionInterval(i, i+1, uint64(i)) it.add(iv) } checkRedBlack(t, it.root, 1) result := it.Query(constructSingleDimensionInterval(0, 10, 0)) assert.Len(t, 
result, 10) assert.Equal(t, uint64(10), it.Len()) } func TestAddRebalanceOutOfOrder(t *testing.T) { it := newTree(1) for i := int64(9); i >= 0; i-- { iv := constructSingleDimensionInterval(i, i+1, uint64(i)) it.add(iv) } checkRedBlack(t, it.root, 1) result := it.Query(constructSingleDimensionInterval(0, 10, 0)) assert.Len(t, result, 10) assert.Equal(t, uint64(10), it.Len()) } func TestAddRebalanceRandomOrder(t *testing.T) { it := newTree(1) starts := []int64{0, 4, 2, 1, 3} for _, start := range starts { iv := constructSingleDimensionInterval(start, start+1, uint64(start)) it.add(iv) } checkRedBlack(t, it.root, 1) result := it.Query(constructSingleDimensionInterval(0, 10, 0)) assert.Len(t, result, 5) assert.Equal(t, uint64(5), it.Len()) } func TestAddLargeNumberOfItems(t *testing.T) { numItems := int64(1000) it := newTree(1) for i := int64(0); i < numItems; i++ { iv := constructSingleDimensionInterval(i, i+1, uint64(i)) it.add(iv) } checkRedBlack(t, it.root, 1) result := it.Query(constructSingleDimensionInterval(0, numItems, 0)) assert.Len(t, result, int(numItems)) assert.Equal(t, uint64(numItems), it.Len()) } func BenchmarkAddItems(b *testing.B) { numItems := int64(1000) intervals := make(Intervals, 0, numItems) for i := int64(0); i < numItems; i++ { iv := constructSingleDimensionInterval(i, i+1, uint64(i)) intervals = append(intervals, iv) } b.ResetTimer() for i := 0; i < b.N; i++ { it := newTree(1) it.Add(intervals...) } } func BenchmarkQueryItems(b *testing.B) { numItems := int64(1000) intervals := make(Intervals, 0, numItems) for i := int64(0); i < numItems; i++ { iv := constructSingleDimensionInterval(i, i+1, uint64(i)) intervals = append(intervals, iv) } it := newTree(1) it.Add(intervals...) 
b.ResetTimer() for i := 0; i < b.N; i++ { it.Query(constructSingleDimensionInterval(0, numItems, 0)) } } func constructSingleDimensionQueryTestTree() ( *tree, Interval, Interval, Interval) { it := newTree(1) iv1 := constructSingleDimensionInterval(6, 10, 0) it.Add(iv1) iv2 := constructSingleDimensionInterval(4, 5, 1) it.Add(iv2) iv3 := constructSingleDimensionInterval(7, 12, 2) it.Add(iv3) return it, iv1, iv2, iv3 } func TestSimpleQuery(t *testing.T) { it, iv1, iv2, _ := constructSingleDimensionQueryTestTree() result := it.Query(constructSingleDimensionInterval(3, 6, 0)) expected := Intervals{iv2, iv1} assert.Equal(t, expected, result) } func TestRightQuery(t *testing.T) { it, iv1, _, iv3 := constructSingleDimensionQueryTestTree() result := it.Query(constructSingleDimensionInterval(6, 8, 0)) expected := Intervals{iv1, iv3} assert.Equal(t, expected, result) } func TestLeftQuery(t *testing.T) { it, _, iv2, _ := constructSingleDimensionQueryTestTree() result := it.Query(constructSingleDimensionInterval(3, 5, 0)) expected := Intervals{iv2} assert.Equal(t, expected, result) } func TestMatchingQuery(t *testing.T) { it, _, iv2, _ := constructSingleDimensionQueryTestTree() result := it.Query(constructSingleDimensionInterval(4, 5, 0)) expected := Intervals{iv2} assert.Equal(t, expected, result) } func TestNoMatchLeft(t *testing.T) { it, _, _, _ := constructSingleDimensionQueryTestTree() result := it.Query(constructSingleDimensionInterval(1, 3, 0)) expected := Intervals{} assert.Equal(t, expected, result) } func TestNoMatchRight(t *testing.T) { it, _, _, _ := constructSingleDimensionQueryTestTree() result := it.Query(constructSingleDimensionInterval(13, 13, 0)) expected := Intervals{} assert.Equal(t, expected, result) } func TestAllQuery(t *testing.T) { it, iv1, iv2, iv3 := constructSingleDimensionQueryTestTree() result := it.Query(constructSingleDimensionInterval(1, 14, 0)) expected := Intervals{iv2, iv1, iv3} assert.Equal(t, expected, result) } func TestQueryDuplicate(t 
*testing.T) { it, _, iv2, _ := constructSingleDimensionQueryTestTree() iv4 := constructSingleDimensionInterval(4, 5, 3) it.Add(iv4) result := it.Query(constructSingleDimensionInterval(4, 5, 0)) expected := Intervals{iv2, iv4} assert.Equal(t, expected, result) } func TestRootDelete(t *testing.T) { it := newTree(1) iv := constructSingleDimensionInterval(1, 5, 1) it.add(iv) it.Delete(iv) checkRedBlack(t, it.root, 1) result := it.Query(constructSingleDimensionInterval(1, 10, 0)) assert.Len(t, result, 0) assert.Equal(t, uint64(0), it.Len()) } func TestDeleteLeft(t *testing.T) { it, iv1, iv2, iv3 := constructSingleDimensionQueryTestTree() it.Delete(iv2) expected := Intervals{iv1, iv3} result := it.Query(constructSingleDimensionInterval(0, 10, 0)) checkRedBlack(t, it.root, 1) assert.Equal(t, expected, result) assert.Equal(t, uint64(2), it.Len()) } func TestDeleteRight(t *testing.T) { it, iv1, iv2, iv3 := constructSingleDimensionQueryTestTree() it.Delete(iv3) expected := Intervals{iv2, iv1} result := it.Query(constructSingleDimensionInterval(0, 10, 0)) checkRedBlack(t, it.root, 1) assert.Equal(t, expected, result) assert.Equal(t, uint64(2), it.Len()) } func TestDeleteCenter(t *testing.T) { it, iv1, iv2, iv3 := constructSingleDimensionQueryTestTree() it.Delete(iv1) expected := Intervals{iv2, iv3} result := it.Query(constructSingleDimensionInterval(0, 10, 0)) checkRedBlack(t, it.root, 1) assert.Equal(t, expected, result) assert.Equal(t, uint64(2), it.Len()) } func TestDeleteRebalanceInOrder(t *testing.T) { it := newTree(1) var toDelete *mockInterval for i := int64(0); i < 10; i++ { iv := constructSingleDimensionInterval(i, i+1, uint64(i)) it.add(iv) if i == 5 { toDelete = iv } } it.Delete(toDelete) checkRedBlack(t, it.root, 1) result := it.Query(constructSingleDimensionInterval(0, 10, 0)) assert.Len(t, result, 9) assert.Equal(t, uint64(9), it.Len()) } func TestDeleteRebalanceOutOfOrder(t *testing.T) { it := newTree(1) var toDelete *mockInterval for i := int64(9); i >= 0; i-- 
{ iv := constructSingleDimensionInterval(i, i+1, uint64(i)) it.add(iv) if i == 5 { toDelete = iv } } it.Delete(toDelete) checkRedBlack(t, it.root, 1) result := it.Query(constructSingleDimensionInterval(0, 10, 0)) assert.Len(t, result, 9) assert.Equal(t, uint64(9), it.Len()) } func TestDeleteRebalanceRandomOrder(t *testing.T) { it := newTree(1) starts := []int64{0, 4, 2, 1, 3} var toDelete *mockInterval for _, start := range starts { iv := constructSingleDimensionInterval(start, start+1, uint64(start)) it.add(iv) if start == 1 { toDelete = iv } } it.Delete(toDelete) checkRedBlack(t, it.root, 1) result := it.Query(constructSingleDimensionInterval(0, 10, 0)) assert.Len(t, result, 4) assert.Equal(t, uint64(4), it.Len()) } func TestDeleteEmptyTree(t *testing.T) { it := newTree(1) it.Delete(constructSingleDimensionInterval(0, 1, 1)) assert.Equal(t, uint64(0), it.Len()) } func BenchmarkDeleteItems(b *testing.B) { numItems := int64(1000) intervals := make(Intervals, 0, numItems) for i := int64(0); i < numItems; i++ { iv := constructSingleDimensionInterval(i, i+1, uint64(i)) intervals = append(intervals, iv) } trees := make([]*tree, 0, b.N) for i := 0; i < b.N; i++ { it := newTree(1) it.Add(intervals...) trees = append(trees, it) } b.ResetTimer() for i := 0; i < b.N; i++ { trees[i].Delete(intervals...) } } func TestAddDuplicateRanges(t *testing.T) { it := newTree(1) iv1 := constructSingleDimensionInterval(0, 10, 1) iv2 := constructSingleDimensionInterval(0, 10, 2) iv3 := constructSingleDimensionInterval(0, 10, 3) it.Add(iv1, iv2, iv3) it.Delete(iv1, iv2, iv3) assert.Equal(t, uint64(0), it.Len()) } func TestAddDeleteDuplicatesRebalanceInOrder(t *testing.T) { it := newTree(1) intervals := make(Intervals, 0, 10) for i := 0; i < 10; i++ { iv := constructSingleDimensionInterval(0, 10, uint64(i)) intervals = append(intervals, iv) } it.Add(intervals...) it.Delete(intervals...) 
assert.Equal(t, uint64(0), it.Len()) } func TestAddDeleteDuplicatesRebalanceReverseOrder(t *testing.T) { it := newTree(1) intervals := make(Intervals, 0, 10) for i := 9; i >= 0; i-- { iv := constructSingleDimensionInterval(0, 10, uint64(i)) intervals = append(intervals, iv) } it.Add(intervals...) it.Delete(intervals...) assert.Equal(t, uint64(0), it.Len()) } func TestAddDeleteDuplicatesRebalanceRandomOrder(t *testing.T) { it := newTree(1) starts := []int{0, 4, 2, 1, 3} intervals := make(Intervals, 0, 5) for _, start := range starts { iv := constructSingleDimensionInterval(0, 10, uint64(start)) intervals = append(intervals, iv) } it.Add(intervals...) it.Delete(intervals...) assert.Equal(t, uint64(0), it.Len()) } func TestInsertDuplicateIntervalsToRoot(t *testing.T) { tree := newTree(1) iv1 := constructSingleDimensionInterval(0, 10, 1) iv2 := constructSingleDimensionInterval(0, 10, 1) iv3 := constructSingleDimensionInterval(0, 10, 1) tree.Add(iv1, iv2, iv3) checkRedBlack(t, tree.root, 1) } func TestInsertDuplicateIntervalChildren(t *testing.T) { tree, _ := constructSingleDimensionTestTree(20) iv1 := constructSingleDimensionInterval(0, 10, 21) iv2 := constructSingleDimensionInterval(0, 10, 21) tree.Add(iv1, iv2) checkRedBlack(t, tree.root, 1) result := tree.Query(constructSingleDimensionInterval(0, 10, 0)) assert.Contains(t, result, iv1) } func TestTraverse(t *testing.T) { tree := newTree(1) tree.Traverse(func(i Interval) { assert.Fail(t, `traverse should not be called for empty tree`) }) top := 30 for i := 0; i <= top; i++ { tree.Add(constructSingleDimensionInterval(int64(i*10), int64((i+1)*10), uint64(i))) } found := map[uint64]bool{} tree.Traverse(func(id Interval) { found[id.ID()] = true }) for i := 0; i <= top; i++ { if found, _ := found[uint64(i)]; !found { t.Errorf("could not find expected interval %d", i) } } } ================================================ FILE: augmentedtree/interface.go ================================================ /* Copyright 2014 
Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

/*
Package augmentedtree is designed to be useful when checking
for intersection of ranges in n-dimensions.  For instance, imagine
an xy plane containing a rectangle defined by the points (0, 0) and
(10, 10).  The augmented tree can tell you whether that rectangle
overlaps a rectangle defined by (-5, -5) and (5, 5) (true in this
case).  You can also check intersections against a point by
constructing a range that consists solely of that single point.

The current tree is a simple top-down red-black binary search tree.

TODO: Add a bottom-up implementation to assist with duplicate
range handling.
*/
package augmentedtree

// Interval is the interface that must be implemented by any
// item added to the interval tree.  This interface is similar to the
// interval found in the rangetree package and it should be possible
// for the same struct to implement both interfaces.  Note that ranges
// here are inclusive.  It is also expected that the provided interval
// is immutable and that the returned values never change.  Changing
// them results in undefined behavior.
type Interval interface {
	// LowAtDimension returns an integer representing the lower bound
	// at the requested dimension.
	LowAtDimension(uint64) int64
	// HighAtDimension returns an integer representing the higher bound
	// at the requested dimension.
	HighAtDimension(uint64) int64
	// OverlapsAtDimension should return a bool indicating if the provided
	// interval overlaps this interval at the dimension requested.
	OverlapsAtDimension(Interval, uint64) bool
	// ID should be a unique ID representing this interval.  This
	// is used to identify which interval to delete from the tree if
	// there are duplicates.
	ID() uint64
}

// Tree defines the object that is returned from the
// tree constructor.  We use a Tree interface here because
// the returned tree could be a single dimension or many
// dimensions.
type Tree interface {
	// Add will add the provided intervals to the tree.
	Add(intervals ...Interval)
	// Len returns the number of intervals in the tree.
	Len() uint64
	// Delete will remove the provided intervals from the tree.
	Delete(intervals ...Interval)
	// Query will return a list of intervals that intersect the provided
	// interval.  The provided interval's ID method is ignored so the
	// provided ID is irrelevant.
	Query(interval Interval) Intervals
	// Traverse will traverse the tree and pass every interval found
	// to the provided function, in an undefined order.
	Traverse(func(Interval))
}

================================================
FILE: augmentedtree/intervals.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package augmentedtree

import "sync"

// intervalsPool recycles Intervals lists.  A list created by the pool
// starts empty with room for 10 entries.
var intervalsPool = sync.Pool{
	New: func() interface{} {
		return make(Intervals, 0, 10)
	},
}

// Intervals represents a list of Intervals.
type Intervals []Interval // Dispose will free any consumed resources and allow this list to be // re-allocated. func (ivs *Intervals) Dispose() { for i := 0; i < len(*ivs); i++ { (*ivs)[i] = nil } *ivs = (*ivs)[:0] intervalsPool.Put(*ivs) } ================================================ FILE: augmentedtree/intervals_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package augmentedtree import ( "testing" "github.com/stretchr/testify/assert" ) func TestIntervalsDispose(t *testing.T) { intervals := intervalsPool.Get().(Intervals) intervals = append(intervals, constructSingleDimensionInterval(0, 1, 0)) intervals.Dispose() assert.Len(t, intervals, 0) } ================================================ FILE: augmentedtree/mock_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package augmentedtree import "fmt" type dimension struct { low, high int64 } type mockInterval struct { dimensions []*dimension id uint64 } func (mi *mockInterval) checkDimension(dimension uint64) { if dimension > uint64(len(mi.dimensions)) { panic(fmt.Sprintf(`Dimension: %d out of range.`, dimension)) } } func (mi *mockInterval) LowAtDimension(dimension uint64) int64 { return mi.dimensions[dimension-1].low } func (mi *mockInterval) HighAtDimension(dimension uint64) int64 { return mi.dimensions[dimension-1].high } func (mi *mockInterval) OverlapsAtDimension(iv Interval, dimension uint64) bool { return mi.HighAtDimension(dimension) > iv.LowAtDimension(dimension) && mi.LowAtDimension(dimension) < iv.HighAtDimension(dimension) } func (mi mockInterval) ID() uint64 { return mi.id } func constructSingleDimensionInterval(low, high int64, id uint64) *mockInterval { return &mockInterval{[]*dimension{&dimension{low: low, high: high}}, id} } func constructMultiDimensionInterval(id uint64, dimensions ...*dimension) *mockInterval { return &mockInterval{dimensions: dimensions, id: id} } ================================================ FILE: augmentedtree/multidimensional_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

package augmentedtree

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// constructMultiDimensionQueryTestTree builds a two-dimensional tree
// holding three intervals whose range is the same in both dimensions:
// [5, 10] (ID 0), [4, 5] (ID 1) and [7, 12] (ID 2).  It returns the tree
// followed by the three intervals in insertion order.
func constructMultiDimensionQueryTestTree() (
	*tree, Interval, Interval, Interval) {

	it := newTree(2)

	iv1 := constructMultiDimensionInterval(
		0, &dimension{low: 5, high: 10}, &dimension{low: 5, high: 10},
	)
	it.Add(iv1)

	iv2 := constructMultiDimensionInterval(
		1, &dimension{low: 4, high: 5}, &dimension{low: 4, high: 5},
	)
	it.Add(iv2)

	iv3 := constructMultiDimensionInterval(
		2, &dimension{low: 7, high: 12}, &dimension{low: 7, high: 12},
	)
	it.Add(iv3)

	return it, iv1, iv2, iv3
}

// TestRootAddMultipleDimensions adds a single two-dimensional interval and
// queries it with one overlapping and one disjoint range.
func TestRootAddMultipleDimensions(t *testing.T) {
	it := newTree(2)
	iv := constructMultiDimensionInterval(
		1, &dimension{low: 0, high: 5}, &dimension{low: 1, high: 6},
	)

	it.Add(iv)

	checkRedBlack(t, it.root, 1)
	result := it.Query(
		constructMultiDimensionInterval(
			0, &dimension{0, 10}, &dimension{0, 10},
		),
	)
	assert.Equal(t, Intervals{iv}, result)

	result = it.Query(
		constructMultiDimensionInterval(
			0, &dimension{100, 200}, &dimension{100, 200},
		),
	)
	assert.Len(t, result, 0)
}

// TestMultipleAddMultipleDimensions runs several queries against the
// three-interval two-dimensional tree.
func TestMultipleAddMultipleDimensions(t *testing.T) {
	it, iv1, iv2, iv3 := constructMultiDimensionQueryTestTree()

	checkRedBlack(t, it.root, 1)

	// Covers everything.
	result := it.Query(
		constructMultiDimensionInterval(
			0, &dimension{0, 100}, &dimension{0, 100},
		),
	)
	assert.Equal(t, Intervals{iv2, iv1, iv3}, result)

	result = it.Query(
		constructMultiDimensionInterval(
			0, &dimension{3, 5}, &dimension{3, 5},
		),
	)
	assert.Equal(t, Intervals{iv2}, result)

	result = it.Query(
		constructMultiDimensionInterval(
			0, &dimension{5, 8}, &dimension{5, 8},
		),
	)
	assert.Equal(t, Intervals{iv1, iv3}, result)

	result = it.Query(
		constructMultiDimensionInterval(
			0, &dimension{11, 15}, &dimension{11, 15},
		),
	)
	assert.Equal(t, Intervals{iv3}, result)

	// Beyond everything.
	result = it.Query(
		constructMultiDimensionInterval(
			0, &dimension{15, 20}, &dimension{15, 20},
		),
	)
	assert.Len(t, result, 0)
}

// TestAddRebalanceInOrderMultiDimensions adds ten two-dimensional
// intervals in ascending order.
func TestAddRebalanceInOrderMultiDimensions(t *testing.T) {
	it := newTree(2)

	for i := int64(0); i < 10; i++ {
		iv := constructMultiDimensionInterval(
			uint64(i), &dimension{i, i + 1}, &dimension{i, i + 1},
		)
		it.Add(iv)
	}

	checkRedBlack(t, it.root, 1)
	result := it.Query(
		constructMultiDimensionInterval(
			0, &dimension{0, 10}, &dimension{0, 10},
		),
	)
	assert.Len(t, result, 10)
	assert.Equal(t, uint64(10), it.Len())
}

// TestAddRebalanceReverseOrderMultiDimensions adds ten two-dimensional
// intervals in descending order.
func TestAddRebalanceReverseOrderMultiDimensions(t *testing.T) {
	it := newTree(2)

	for i := int64(9); i >= 0; i-- {
		iv := constructMultiDimensionInterval(
			uint64(i), &dimension{i, i + 1}, &dimension{i, i + 1},
		)
		it.Add(iv)
	}

	checkRedBlack(t, it.root, 1)
	result := it.Query(
		constructMultiDimensionInterval(
			0, &dimension{0, 10}, &dimension{0, 10},
		),
	)
	assert.Len(t, result, 10)
	assert.Equal(t, uint64(10), it.Len())
}

// TestAddRebalanceRandomOrderMultiDimensions adds five two-dimensional
// intervals in a fixed shuffled order.
func TestAddRebalanceRandomOrderMultiDimensions(t *testing.T) {
	it := newTree(2)

	starts := []int64{0, 4, 2, 1, 3}
	for i, start := range starts {
		iv := constructMultiDimensionInterval(
			uint64(i), &dimension{start, start + 1}, &dimension{start, start + 1},
		)
		it.Add(iv)
	}

	checkRedBlack(t, it.root, 1)
	result := it.Query(
		constructMultiDimensionInterval(
			0, &dimension{0, 10}, &dimension{0, 10},
		),
	)
	assert.Len(t, result, 5)
	assert.Equal(t, uint64(5), it.Len())
}

// TestAddLargeNumbersMultiDimensions adds 1000 two-dimensional intervals
// and expects a covering query to return all of them.
func TestAddLargeNumbersMultiDimensions(t *testing.T) {
	numItems := int64(1000)
	it := newTree(2)

	for i := int64(0); i < numItems; i++ {
		iv := constructMultiDimensionInterval(
			uint64(i), &dimension{i, i + 1}, &dimension{i, i + 1},
		)
		it.Add(iv)
	}

	checkRedBlack(t, it.root, 1)
	result := it.Query(
		constructMultiDimensionInterval(
			0, &dimension{0, numItems}, &dimension{0, numItems},
		),
	)
	assert.Len(t, result, int(numItems))
	assert.Equal(t, uint64(numItems), it.Len())
}

// BenchmarkAddItemsMultiDimensions measures adding b.N two-dimensional
// intervals, one per iteration, to a single tree.
func BenchmarkAddItemsMultiDimensions(b *testing.B) {
	numItems := int64(b.N)
	intervals := make(Intervals, 0, numItems)

	for i := int64(0); i < numItems; i++ {
		iv := constructMultiDimensionInterval(
			uint64(i), &dimension{i, i + 1}, &dimension{i, i + 1},
		)
		intervals = append(intervals, iv)
	}

	it := newTree(2)
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		it.Add(intervals[int64(i)%numItems])
	}
}

// BenchmarkQueryItemsMultiDimensions measures a query covering all 1000
// two-dimensional intervals.
func BenchmarkQueryItemsMultiDimensions(b *testing.B) {
	numItems := int64(1000)
	intervals := make(Intervals, 0, numItems)

	for i := int64(0); i < numItems; i++ {
		iv := constructMultiDimensionInterval(
			uint64(i), &dimension{i, i + 1}, &dimension{i, i + 1},
		)
		intervals = append(intervals, iv)
	}

	it := newTree(2)
	it.Add(intervals...)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		it.Query(
			constructMultiDimensionInterval(
				0, &dimension{0, numItems}, &dimension{0, numItems},
			),
		)
	}
}

// TestRootDeleteMultiDimensions deletes the only interval of a
// two-dimensional tree.
func TestRootDeleteMultiDimensions(t *testing.T) {
	it := newTree(2)
	iv := constructMultiDimensionInterval(
		0, &dimension{low: 5, high: 10}, &dimension{low: 5, high: 10},
	)
	it.Add(iv)

	it.Delete(iv)

	checkRedBlack(t, it.root, 1)
	result := it.Query(
		constructMultiDimensionInterval(
			0, &dimension{0, 100}, &dimension{0, 100},
		),
	)
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(0), it.Len())
}

// TestDeleteMultiDimensions deletes iv1 from the three-interval tree and
// repeats the queries of TestMultipleAddMultipleDimensions.
func TestDeleteMultiDimensions(t *testing.T) {
	it, iv1, iv2, iv3 := constructMultiDimensionQueryTestTree()

	checkRedBlack(t, it.root, 1)

	it.Delete(iv1)

	result := it.Query(
		constructMultiDimensionInterval(
			0, &dimension{0, 100}, &dimension{0, 100},
		),
	)
	assert.Equal(t, Intervals{iv2, iv3}, result)

	result = it.Query(
		constructMultiDimensionInterval(
			0, &dimension{3, 5}, &dimension{3, 5},
		),
	)
	assert.Equal(t, Intervals{iv2}, result)

	result = it.Query(
		constructMultiDimensionInterval(
			0, &dimension{5, 8}, &dimension{5, 8},
		),
	)
	assert.Equal(t, Intervals{iv3}, result)

	result = it.Query(
		constructMultiDimensionInterval(
			0, &dimension{11, 15}, &dimension{11, 15},
		),
	)
	assert.Equal(t, Intervals{iv3}, result)

	result = it.Query(
		constructMultiDimensionInterval(
			0, &dimension{15, 20}, &dimension{15, 20},
		),
	)
	assert.Len(t, result, 0)
}

// TestDeleteRebalanceInOrderMultiDimensions deletes one of ten
// two-dimensional intervals that were added in ascending order.
func TestDeleteRebalanceInOrderMultiDimensions(t *testing.T) {
	it := newTree(2)

	var toDelete *mockInterval

	for i := int64(0); i < 10; i++ {
		iv := constructMultiDimensionInterval(
			uint64(i), &dimension{i, i + 1}, &dimension{i, i + 1},
		)
		it.Add(iv)
		if i == 5 {
			toDelete = iv
		}
	}

	it.Delete(toDelete)

	checkRedBlack(t, it.root, 1)
	result := it.Query(
		constructMultiDimensionInterval(
			0, &dimension{0, 10}, &dimension{0, 10},
		),
	)
	assert.Len(t, result, 9)
	assert.Equal(t, uint64(9), it.Len())
}

// TestDeleteRebalanceReverseOrderMultiDimensions deletes one of ten
// two-dimensional intervals that were added in descending order.
func TestDeleteRebalanceReverseOrderMultiDimensions(t *testing.T) {
	it := newTree(2)

	var toDelete *mockInterval

	for i := int64(9); i >= 0; i-- {
		iv := constructMultiDimensionInterval(
			uint64(i), &dimension{i, i + 1}, &dimension{i, i + 1},
		)
		it.Add(iv)
		if i == 5 {
			toDelete = iv
		}
	}

	it.Delete(toDelete)

	checkRedBlack(t, it.root, 1)
	result := it.Query(
		constructMultiDimensionInterval(
			0, &dimension{0, 10}, &dimension{0, 10},
		),
	)
	assert.Len(t, result, 9)
	assert.Equal(t, uint64(9), it.Len())
}

// TestDeleteRebalanceRandomOrderMultiDimensions deletes one of five
// two-dimensional intervals that were added in a fixed shuffled order.
func TestDeleteRebalanceRandomOrderMultiDimensions(t *testing.T) {
	it := newTree(2)

	starts := []int64{0, 4, 2, 1, 3}

	var toDelete *mockInterval

	for i, start := range starts {
		iv := constructMultiDimensionInterval(
			uint64(i), &dimension{start, start + 1}, &dimension{start, start + 1},
		)
		it.Add(iv)
		if start == 1 {
			toDelete = iv
		}
	}

	it.Delete(toDelete)

	checkRedBlack(t, it.root, 1)
	result := it.Query(
		constructMultiDimensionInterval(
			0, &dimension{0, 10}, &dimension{0, 10},
		),
	)
	assert.Len(t, result, 4)
	assert.Equal(t, uint64(4), it.Len())
}

// TestDeleteEmptyTreeMultiDimensions deletes from an empty two-dimensional
// tree.
func TestDeleteEmptyTreeMultiDimensions(t *testing.T) {
	it := newTree(2)

	it.Delete(
		constructMultiDimensionInterval(
			0, &dimension{0, 10}, &dimension{0, 10},
		),
	)
	assert.Equal(t, uint64(0), it.Len())
}

// BenchmarkDeleteItemsMultiDimensions measures deleting all 1000 intervals
// from two-dimensional trees that are built before the timer starts.
func BenchmarkDeleteItemsMultiDimensions(b *testing.B) {
	numItems := int64(1000)
	intervals := make(Intervals, 0, numItems)

	for i := int64(0); i < numItems; i++ {
		iv := constructMultiDimensionInterval(
			uint64(i), &dimension{i, i + 1}, &dimension{i, i + 1},
		)
		intervals = append(intervals, iv)
	}

	trees := make([]*tree, 0, b.N)
	for i := 0; i < b.N; i++ {
		it := newTree(2)
		it.Add(intervals...)
		trees = append(trees, it)
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		trees[i].Delete(intervals...)
	}
}

// TestAddDeleteDuplicatesRebalanceInOrderMultiDimensions adds and deletes
// ten two-dimensional intervals sharing one range, IDs ascending.
func TestAddDeleteDuplicatesRebalanceInOrderMultiDimensions(t *testing.T) {
	it := newTree(2)
	intervals := make(Intervals, 0, 10)

	for i := 0; i < 10; i++ {
		iv := constructMultiDimensionInterval(
			uint64(i), &dimension{0, 10}, &dimension{0, 10},
		)
		intervals = append(intervals, iv)
	}

	it.Add(intervals...)
	it.Delete(intervals...)
	assert.Equal(t, uint64(0), it.Len())
}

// TestAddDeleteDuplicatesRebalanceReverseOrderMultiDimensions adds and
// deletes ten two-dimensional intervals sharing one range, IDs descending.
func TestAddDeleteDuplicatesRebalanceReverseOrderMultiDimensions(t *testing.T) {
	it := newTree(2)
	intervals := make(Intervals, 0, 10)

	for i := 9; i >= 0; i-- {
		iv := constructMultiDimensionInterval(
			uint64(i), &dimension{0, 10}, &dimension{0, 10},
		)
		intervals = append(intervals, iv)
	}

	it.Add(intervals...)
	it.Delete(intervals...)
	assert.Equal(t, uint64(0), it.Len())
}

// TestAddDeleteDuplicatesRebalanceRandomOrderMultiDimensions adds and
// deletes five two-dimensional intervals sharing one range, IDs in a fixed
// shuffled order.
func TestAddDeleteDuplicatesRebalanceRandomOrderMultiDimensions(t *testing.T) {
	it := newTree(2)
	intervals := make(Intervals, 0, 5)
	starts := []int{0, 4, 2, 1, 3}

	for _, start := range starts {
		iv := constructMultiDimensionInterval(
			uint64(start), &dimension{0, 10}, &dimension{0, 10},
		)
		intervals = append(intervals, iv)
	}

	it.Add(intervals...)
	it.Delete(intervals...)
	assert.Equal(t, uint64(0), it.Len())
}

================================================
FILE: aviary.yaml
================================================
version: 1

exclude:
  - tests?/

raven_monitored_classes: null

raven_monitored_files: null

raven_monitored_functions: null

raven_monitored_keywords: null

================================================
FILE: batcher/batcher.go
================================================
/*
Copyright 2015 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License. */ package batcher import ( "errors" "time" ) // I honestly can't believe I'm doing this, but go's sync package doesn't // have a TryLock function. // Could probably do this with atomics type mutex struct { // This is really more of a semaphore design, but eh // Full -> locked, empty -> unlocked lock chan struct{} } func newMutex() *mutex { return &mutex{lock: make(chan struct{}, 1)} } func (m *mutex) Lock() { m.lock <- struct{}{} } func (m *mutex) Unlock() { <-m.lock } func (m *mutex) TryLock() bool { select { case m.lock <- struct{}{}: return true default: return false } } // Batcher provides an API for accumulating items into a batch for processing. type Batcher interface { // Put adds items to the batcher. Put(interface{}) error // Get retrieves a batch from the batcher. This call will block until // one of the conditions for a "complete" batch is reached. Get() ([]interface{}, error) // Flush forcibly completes the batch currently being built Flush() error // Dispose will dispose of the batcher. Any calls to Put or Flush // will return ErrDisposed, calls to Get will return an error iff // there are no more ready batches. Dispose() // IsDisposed will determine if the batcher is disposed IsDisposed() bool } // ErrDisposed is the error returned for a disposed Batcher var ErrDisposed = errors.New("batcher: disposed") // CalculateBytes evaluates the number of bytes in an item added to a Batcher. type CalculateBytes func(interface{}) uint type basicBatcher struct { maxTime time.Duration maxItems uint maxBytes uint calculateBytes CalculateBytes disposed bool items []interface{} batchChan chan []interface{} availableBytes uint lock *mutex } // New creates a new Batcher using the provided arguments. 
// Batch readiness can be determined in three ways: // - Maximum number of bytes per batch // - Maximum number of items per batch // - Maximum amount of time waiting for a batch // Values of zero for one of these fields indicate they should not be // taken into account when evaluating the readiness of a batch. // This provides an ordering guarantee for any given thread such that if a // thread places two items in the batcher, Get will guarantee the first // item is returned before the second, whether before the second in the same // batch, or in an earlier batch. func New(maxTime time.Duration, maxItems, maxBytes, queueLen uint, calculate CalculateBytes) (Batcher, error) { if maxBytes > 0 && calculate == nil { return nil, errors.New("batcher: must provide CalculateBytes function") } return &basicBatcher{ maxTime: maxTime, maxItems: maxItems, maxBytes: maxBytes, calculateBytes: calculate, items: make([]interface{}, 0, maxItems), batchChan: make(chan []interface{}, queueLen), lock: newMutex(), }, nil } // Put adds items to the batcher. func (b *basicBatcher) Put(item interface{}) error { b.lock.Lock() if b.disposed { b.lock.Unlock() return ErrDisposed } b.items = append(b.items, item) if b.calculateBytes != nil { b.availableBytes += b.calculateBytes(item) } if b.ready() { // To guarantee ordering this MUST be in the lock, otherwise multiple // flush calls could be blocked at the same time, in which case // there's no guarantee each batch is placed into the channel in // the proper order b.flush() } b.lock.Unlock() return nil } // Get retrieves a batch from the batcher. This call will block until // one of the conditions for a "complete" batch is reached. func (b *basicBatcher) Get() ([]interface{}, error) { // Don't check disposed yet so any items remaining in the queue // will be returned properly. 
var timeout <-chan time.Time if b.maxTime > 0 { timeout = time.After(b.maxTime) } select { case items, ok := <-b.batchChan: // If there's something on the batch channel, we definitely want that. if !ok { return nil, ErrDisposed } return items, nil case <-timeout: // It's possible something was added to the channel after something // was received on the timeout channel, in which case that must // be returned first to satisfy our ordering guarantees. // We can't just grab the lock here in case the batch channel is full, // in which case a Put or Flush will be blocked and holding // onto the lock. In that case, there should be something on the // batch channel for { if b.lock.TryLock() { // We have a lock, try to read from channel first in case // something snuck in select { case items, ok := <-b.batchChan: b.lock.Unlock() if !ok { return nil, ErrDisposed } return items, nil default: } // If that is unsuccessful, nothing was added to the channel, // and the temp buffer can't have changed because of the lock, // so grab that items := b.items b.items = make([]interface{}, 0, b.maxItems) b.availableBytes = 0 b.lock.Unlock() return items, nil } else { // If we didn't get a lock, there are two cases: // 1) The batch chan is full. // 2) A Put or Flush temporarily has the lock. // In either case, trying to read something off the batch chan, // and going back to trying to get a lock if unsuccessful // works. select { case items, ok := <-b.batchChan: if !ok { return nil, ErrDisposed } return items, nil default: } } } } } // Flush forcibly completes the batch currently being built func (b *basicBatcher) Flush() error { // This is the same pattern as a Put b.lock.Lock() if b.disposed { b.lock.Unlock() return ErrDisposed } b.flush() b.lock.Unlock() return nil } // Dispose will dispose of the batcher. Any calls to Put or Flush // will return ErrDisposed, calls to Get will return an error iff // there are no more ready batches. 
Any items not flushed and retrieved // by a Get may or may not be retrievable after calling this. func (b *basicBatcher) Dispose() { for { if b.lock.TryLock() { // We've got a lock if b.disposed { b.lock.Unlock() return } b.disposed = true b.items = nil b.drainBatchChan() close(b.batchChan) b.lock.Unlock() } else { // Two cases here: // 1) Something is blocked and holding onto the lock // 2) Something temporarily has a lock // For case 1, we have to clear at least some space so the blocked // Put/Flush can release the lock. For case 2, nothing bad // will happen here b.drainBatchChan() } } } // IsDisposed will determine if the batcher is disposed func (b *basicBatcher) IsDisposed() bool { b.lock.Lock() disposed := b.disposed b.lock.Unlock() return disposed } // flush adds the batch currently being built to the queue of completed batches. // flush is not threadsafe, so should be synchronized externally. func (b *basicBatcher) flush() { b.batchChan <- b.items b.items = make([]interface{}, 0, b.maxItems) b.availableBytes = 0 } func (b *basicBatcher) ready() bool { if b.maxItems != 0 && uint(len(b.items)) >= b.maxItems { return true } if b.maxBytes != 0 && b.availableBytes >= b.maxBytes { return true } return false } func (b *basicBatcher) drainBatchChan() { for { select { case <-b.batchChan: default: return } } } ================================================ FILE: batcher/batcher_test.go ================================================ /* Copyright 2015 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ package batcher import ( "sync" "testing" "time" "github.com/stretchr/testify/assert" ) func TestNoCalculateBytes(t *testing.T) { _, err := New(0, 0, 100, 5, nil) assert.Error(t, err) } func TestMaxItems(t *testing.T) { assert := assert.New(t) b, err := New(0, 100, 100000, 10, func(str interface{}) uint { return uint(len(str.(string))) }) assert.Nil(err) for i := 0; i < 1000; i++ { assert.Nil(b.Put("foo bar baz")) } batch, err := b.Get() assert.Len(batch, 100) assert.Nil(err) } func TestMaxBytes(t *testing.T) { assert := assert.New(t) b, err := New(0, 10000, 100, 10, func(str interface{}) uint { return uint(len(str.(string))) }) assert.Nil(err) go func() { for i := 0; i < 1000; i++ { b.Put("a") } }() batch, err := b.Get() assert.Len(batch, 100) assert.Nil(err) } func TestMaxTime(t *testing.T) { assert := assert.New(t) b, err := New(time.Millisecond*200, 100000, 100000, 10, func(str interface{}) uint { return uint(len(str.(string))) }, ) assert.Nil(err) go func() { for i := 0; i < 10000; i++ { b.Put("a") time.Sleep(time.Millisecond) } }() before := time.Now() batch, err := b.Get() // This delta is normally 1-3 ms but running tests in CI with -race causes // this to run much slower. For now, just bump up the threshold. 
assert.InDelta(200, time.Since(before).Seconds()*1000, 100) assert.True(len(batch) > 0) assert.Nil(err) } func TestFlush(t *testing.T) { assert := assert.New(t) b, err := New(0, 10, 10, 10, func(str interface{}) uint { return uint(len(str.(string))) }) assert.Nil(err) b.Put("a") wait := make(chan bool) go func() { batch, err := b.Get() assert.Equal([]interface{}{"a"}, batch) assert.Nil(err) wait <- true }() b.Flush() <-wait } func TestMultiConsumer(t *testing.T) { assert := assert.New(t) b, err := New(0, 100, 100000, 10, func(str interface{}) uint { return uint(len(str.(string))) }) assert.Nil(err) var wg sync.WaitGroup wg.Add(5) for i := 0; i < 5; i++ { go func() { batch, err := b.Get() assert.Len(batch, 100) assert.Nil(err) wg.Done() }() } go func() { for i := 0; i < 500; i++ { b.Put("a") } }() wg.Wait() } func TestDispose(t *testing.T) { assert := assert.New(t) b, err := New(1, 2, 100000, 2, func(str interface{}) uint { return uint(len(str.(string))) }) assert.Nil(err) b.Put("a") b.Put("b") b.Put("c") batch1, err := b.Get() assert.Equal([]interface{}{"a", "b"}, batch1) assert.Nil(err) batch2, err := b.Get() assert.Equal([]interface{}{"c"}, batch2) assert.Nil(err) b.Put("d") b.Put("e") b.Put("f") b.Dispose() _, err = b.Get() assert.Equal(ErrDisposed, err) assert.Equal(ErrDisposed, b.Put("j")) assert.Equal(ErrDisposed, b.Flush()) } func TestIsDisposed(t *testing.T) { assert := assert.New(t) b, err := New(0, 10, 10, 10, func(str interface{}) uint { return uint(len(str.(string))) }) assert.Nil(err) assert.False(b.IsDisposed()) b.Dispose() assert.True(b.IsDisposed()) } ================================================ FILE: bitarray/and.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package bitarray func andSparseWithSparseBitArray(sba, other *sparseBitArray) BitArray { max := maxInt64(int64(len(sba.indices)), int64(len(other.indices))) indices := make(uintSlice, 0, max) blocks := make(blocks, 0, max) selfIndex := 0 otherIndex := 0 var resultBlock block // move through the array and compare the blocks if they happen to // intersect for { if selfIndex == len(sba.indices) || otherIndex == len(other.indices) { // One of the arrays has been exhausted. We don't need // to compare anything else for a bitwise and; the // operation is complete. break } selfValue := sba.indices[selfIndex] otherValue := other.indices[otherIndex] switch { case otherValue < selfValue: // The `sba` bitarray has a block with a index position // greater than us. We want to compare with that block // if possible, so move our `other` index closer to that // block's index. otherIndex++ case otherValue > selfValue: // This is the exact logical inverse of the above case. selfIndex++ default: // Here, our indices match for both `sba` and `other`. // Time to do the bitwise AND operation and add a block // to our result list if the block has values in it. 
resultBlock = sba.blocks[selfIndex].and(other.blocks[otherIndex]) if resultBlock > 0 { indices = append(indices, selfValue) blocks = append(blocks, resultBlock) } selfIndex++ otherIndex++ } } return &sparseBitArray{ indices: indices, blocks: blocks, } } func andSparseWithDenseBitArray(sba *sparseBitArray, other *bitArray) BitArray { if other.IsEmpty() { return newSparseBitArray() } // Use a duplicate of the sparse array to store the results of the // bitwise and. More memory-efficient than allocating a new dense bit // array. // // NOTE: this could be faster if we didn't copy the values as well // (since they are overwritten), but I don't want this method to know // too much about the internals of sparseBitArray. The performance hit // should be minor anyway. ba := sba.copy() // Run through the sparse array and attempt comparisons wherever // possible against the dense bit array. for selfIndex, selfValue := range ba.indices { if selfValue >= uint64(len(other.blocks)) { // The dense bit array has been exhausted. This is the // annoying case because we have to trim the sparse // array to the size of the dense array. ba.blocks = ba.blocks[:selfIndex-1] ba.indices = ba.indices[:selfIndex-1] // once this is done, there are no more comparisons. 
// We're ready to return break } ba.blocks[selfIndex] = ba.blocks[selfIndex].and( other.blocks[selfValue]) } // Ensure any zero'd blocks in the resulting sparse // array are deleted for i := 0; i < len(ba.blocks); i++ { if ba.blocks[i] == 0 { ba.blocks.deleteAtIndex(int64(i)) ba.indices.deleteAtIndex(int64(i)) i-- } } return ba } func andDenseWithDenseBitArray(dba, other *bitArray) BitArray { min := minUint64(uint64(len(dba.blocks)), uint64(len(other.blocks))) ba := newBitArray(min * s) for i := uint64(0); i < min; i++ { ba.blocks[i] = dba.blocks[i].and(other.blocks[i]) } ba.setLowest() ba.setHighest() return ba } ================================================ FILE: bitarray/and_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package bitarray import ( "testing" "github.com/stretchr/testify/assert" ) // checkBit is a helper method for these unit tests func checkBit(t *testing.T, ba BitArray, position uint64, expected bool) { ok, err := ba.GetBit(position) if assert.NoError(t, err) { if expected { assert.True(t, ok, "Bitarray at position %d should be set", position) } else { assert.False(t, ok, "Bitarray at position %d should be unset", position) } } } func TestAndSparseWithSparseBitArray(t *testing.T) { sba := newSparseBitArray() other := newSparseBitArray() // bits for which only one of the arrays is set sba.SetBit(3) sba.SetBit(280) other.SetBit(9) other.SetBit(100) sba.SetBit(1000) other.SetBit(1001) // bits for which both arrays are set sba.SetBit(1) other.SetBit(1) sba.SetBit(2680) other.SetBit(2680) sba.SetBit(30) other.SetBit(30) ba := andSparseWithSparseBitArray(sba, other) // Bits in both checkBit(t, ba, 1, true) checkBit(t, ba, 30, true) checkBit(t, ba, 2680, true) // Bits in sba but not other checkBit(t, ba, 3, false) checkBit(t, ba, 280, false) checkBit(t, ba, 1000, false) // Bits in other but not sba checkBit(t, ba, 9, false) checkBit(t, ba, 100, false) checkBit(t, ba, 2, false) nums := ba.ToNums() assert.Equal(t, []uint64{1, 30, 2680}, nums) } func TestAndSparseWithDenseBitArray(t *testing.T) { sba := newSparseBitArray() other := newBitArray(300) other.SetBit(1) sba.SetBit(1) other.SetBit(150) sba.SetBit(150) sba.SetBit(155) other.SetBit(156) sba.SetBit(300) other.SetBit(300) ba := andSparseWithDenseBitArray(sba, other) // Bits in both checkBit(t, ba, 1, true) checkBit(t, ba, 150, true) checkBit(t, ba, 300, true) // Bits in sba but not other checkBit(t, ba, 155, false) // Bits in other but not sba checkBit(t, ba, 156, false) } // Make sure that the sparse array is trimmed correctly if compared against a // smaller dense bit array. 
func TestAndSparseWithSmallerDenseBitArray(t *testing.T) { sba := newSparseBitArray() other := newBitArray(512) other.SetBit(1) sba.SetBit(1) other.SetBit(150) sba.SetBit(150) sba.SetBit(155) sba.SetBit(500) other.SetBit(128) sba.SetBit(1500) sba.SetBit(1200) ba := andSparseWithDenseBitArray(sba, other) // Bits in both checkBit(t, ba, 1, true) checkBit(t, ba, 150, true) // Bits in sba but not other checkBit(t, ba, 155, false) checkBit(t, ba, 500, false) checkBit(t, ba, 1200, false) checkBit(t, ba, 1500, false) // Bits in other but not sba checkBit(t, ba, 128, false) } func TestAndDenseWithDenseBitArray(t *testing.T) { dba := newBitArray(1000) other := newBitArray(2000) dba.SetBit(1) other.SetBit(18) dba.SetBit(222) other.SetBit(222) other.SetBit(1501) ba := andDenseWithDenseBitArray(dba, other) checkBit(t, ba, 0, false) checkBit(t, ba, 1, false) checkBit(t, ba, 3, false) checkBit(t, ba, 18, false) checkBit(t, ba, 222, true) // check that the ba is the minimum of the size of `dba` and `other` // (dense bitarrays return an error on an out-of-bounds access) _, err := ba.GetBit(1500) assert.Equal(t, OutOfRangeError(1500), err) _, err = ba.GetBit(1501) assert.Equal(t, OutOfRangeError(1501), err) } func TestAndSparseWithEmptySparse(t *testing.T) { sba := newSparseBitArray() other := newSparseBitArray() sba.SetBit(5) ba := andSparseWithSparseBitArray(sba, other) checkBit(t, ba, 0, false) checkBit(t, ba, 5, false) checkBit(t, ba, 100, false) } func TestAndSparseWithEmptyDense(t *testing.T) { sba := newSparseBitArray() other := newBitArray(1000) sba.SetBit(5) ba := andSparseWithDenseBitArray(sba, other) checkBit(t, ba, 5, false) sba.Reset() other.SetBit(5) ba = andSparseWithDenseBitArray(sba, other) checkBit(t, ba, 5, false) } func TestAndDenseWithEmptyDense(t *testing.T) { dba := newBitArray(1000) other := newBitArray(1000) dba.SetBit(5) ba := andDenseWithDenseBitArray(dba, other) checkBit(t, ba, 5, false) dba.Reset() other.SetBit(5) ba = andDenseWithDenseBitArray(dba, 
other) checkBit(t, ba, 5, false) } ================================================ FILE: bitarray/bitarray.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* Package bitarray implements a bit array. Useful for tracking bool type values in a space efficient way. This is *NOT* a threadsafe package. */ package bitarray import "math/bits" // bitArray is a struct that maintains state of a bit array. type bitArray struct { blocks []block lowest uint64 highest uint64 anyset bool } func getIndexAndRemainder(k uint64) (uint64, uint64) { return k / s, k % s } func (ba *bitArray) setLowest() { for i := uint64(0); i < uint64(len(ba.blocks)); i++ { if ba.blocks[i] == 0 { continue } pos := ba.blocks[i].findRightPosition() ba.lowest = (i * s) + pos ba.anyset = true return } ba.anyset = false ba.lowest = 0 ba.highest = 0 } func (ba *bitArray) setHighest() { for i := len(ba.blocks) - 1; i >= 0; i-- { if ba.blocks[i] == 0 { continue } pos := ba.blocks[i].findLeftPosition() ba.highest = (uint64(i) * s) + pos ba.anyset = true return } ba.anyset = false ba.highest = 0 ba.lowest = 0 } // capacity returns the total capacity of the bit array. func (ba *bitArray) Capacity() uint64 { return uint64(len(ba.blocks)) * s } // ToNums converts this bitarray to a list of numbers contained within it. 
func (ba *bitArray) ToNums() []uint64 { nums := make([]uint64, 0, ba.highest-ba.lowest/4) for i, block := range ba.blocks { block.toNums(uint64(i)*s, &nums) } return nums } // SetBit sets a bit at the given index to true. func (ba *bitArray) SetBit(k uint64) error { if k >= ba.Capacity() { return OutOfRangeError(k) } if !ba.anyset { ba.lowest = k ba.highest = k ba.anyset = true } else { if k < ba.lowest { ba.lowest = k } else if k > ba.highest { ba.highest = k } } i, pos := getIndexAndRemainder(k) ba.blocks[i] = ba.blocks[i].insert(pos) return nil } // GetBit returns a bool indicating if the value at the given // index has been set. func (ba *bitArray) GetBit(k uint64) (bool, error) { if k >= ba.Capacity() { return false, OutOfRangeError(k) } i, pos := getIndexAndRemainder(k) result := ba.blocks[i]&block(1<>= fromOffset } for block != 0 { trailing := bits.TrailingZeros64(uint64(block)) if isFirstBlock { results[resultSize] = uint64(trailing) + (blockIndex << 6) + fromOffset } else { results[resultSize] = uint64(trailing) + (blockIndex << 6) } resultSize++ if resultSize == cap(results) { return results[:resultSize] } // Clear the bit we just added to the result, which is the last bit set in the block. Ex.: // block 01001100 // ^block 10110011 // (^block) + 1 10110100 // block & (^block) + 1 00000100 // block ^ mask 01001000 mask := block & ((^block) + 1) block = block ^ mask } } return results[:resultSize] } // ClearBit will unset a bit at the given index if it is set. func (ba *bitArray) ClearBit(k uint64) error { if k >= ba.Capacity() { return OutOfRangeError(k) } if !ba.anyset { // nothing is set, might as well bail return nil } i, pos := getIndexAndRemainder(k) ba.blocks[i] &^= block(1 << pos) if k == ba.highest { ba.setHighest() } else if k == ba.lowest { ba.setLowest() } return nil } // Count returns the number of set bits in this array. 
func (ba *bitArray) Count() int { count := 0 for _, block := range ba.blocks { count += bits.OnesCount64(uint64(block)) } return count } // Or will bitwise or two bit arrays and return a new bit array // representing the result. func (ba *bitArray) Or(other BitArray) BitArray { if dba, ok := other.(*bitArray); ok { return orDenseWithDenseBitArray(ba, dba) } return orSparseWithDenseBitArray(other.(*sparseBitArray), ba) } // And will bitwise and two bit arrays and return a new bit array // representing the result. func (ba *bitArray) And(other BitArray) BitArray { if dba, ok := other.(*bitArray); ok { return andDenseWithDenseBitArray(ba, dba) } return andSparseWithDenseBitArray(other.(*sparseBitArray), ba) } // Nand will return the result of doing a bitwise and not of the bit array // with the other bit array on each block. func (ba *bitArray) Nand(other BitArray) BitArray { if dba, ok := other.(*bitArray); ok { return nandDenseWithDenseBitArray(ba, dba) } return nandDenseWithSparseBitArray(ba, other.(*sparseBitArray)) } // Reset clears out the bit array. func (ba *bitArray) Reset() { for i := uint64(0); i < uint64(len(ba.blocks)); i++ { ba.blocks[i] &= block(0) } ba.anyset = false } // Equals returns a bool indicating if these two bit arrays are equal. func (ba *bitArray) Equals(other BitArray) bool { if other.Capacity() == 0 && ba.highest > 0 { return false } if other.Capacity() == 0 && !ba.anyset { return true } var selfIndex uint64 for iter := other.Blocks(); iter.Next(); { toIndex, otherBlock := iter.Value() if toIndex > selfIndex { for i := selfIndex; i < toIndex; i++ { if ba.blocks[i] > 0 { return false } } } selfIndex = toIndex if !ba.blocks[selfIndex].equals(otherBlock) { return false } selfIndex++ } lastIndex, _ := getIndexAndRemainder(ba.highest) if lastIndex >= selfIndex { return false } return true } // Intersects returns a bool indicating if the supplied bitarray intersects // this bitarray. 
This will check for intersection up to the length of the supplied // bitarray. If the supplied bitarray is longer than this bitarray, this // function returns false. func (ba *bitArray) Intersects(other BitArray) bool { if other.Capacity() > ba.Capacity() { return false } if sba, ok := other.(*sparseBitArray); ok { return ba.intersectsSparseBitArray(sba) } return ba.intersectsDenseBitArray(other.(*bitArray)) } // Blocks will return an iterator over this bit array. func (ba *bitArray) Blocks() Iterator { return newBitArrayIterator(ba) } func (ba *bitArray) IsEmpty() bool { return !ba.anyset } // complement flips all bits in this array. func (ba *bitArray) complement() { for i := uint64(0); i < uint64(len(ba.blocks)); i++ { ba.blocks[i] = ^ba.blocks[i] } ba.setLowest() if ba.anyset { ba.setHighest() } } func (ba *bitArray) intersectsSparseBitArray(other *sparseBitArray) bool { for i, index := range other.indices { if !ba.blocks[index].intersects(other.blocks[i]) { return false } } return true } func (ba *bitArray) intersectsDenseBitArray(other *bitArray) bool { for i, block := range other.blocks { if !ba.blocks[i].intersects(block) { return false } } return true } func (ba *bitArray) copy() BitArray { blocks := make(blocks, len(ba.blocks)) copy(blocks, ba.blocks) return &bitArray{ blocks: blocks, lowest: ba.lowest, highest: ba.highest, anyset: ba.anyset, } } // newBitArray returns a new dense BitArray at the specified size. This is a // separate private constructor so unit tests don't have to constantly cast the // BitArray interface to the concrete type. func newBitArray(size uint64, args ...bool) *bitArray { i, r := getIndexAndRemainder(size) if r > 0 { i++ } ba := &bitArray{ blocks: make([]block, i), anyset: false, } if len(args) > 0 && args[0] == true { for i := uint64(0); i < uint64(len(ba.blocks)); i++ { ba.blocks[i] = maximumBlock } ba.lowest = 0 ba.highest = i*s - 1 ba.anyset = true } return ba } // NewBitArray returns a new BitArray at the specified size. 
The // optional arg denotes whether this bitarray should be set to the // bitwise complement of the empty array, ie. sets all bits. func NewBitArray(size uint64, args ...bool) BitArray { return newBitArray(size, args...) } ================================================ FILE: bitarray/bitarray_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package bitarray import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestBitOperations(t *testing.T) { ba := newBitArray(10) err := ba.SetBit(5) if err != nil { t.Fatal(err) } result, err := ba.GetBit(5) if err != nil { t.Fatal(err) } if !result { t.Errorf(`Expected true at position: %d`, 5) } result, err = ba.GetBit(3) if err != nil { t.Fatal(err) } if result { t.Errorf(`Expected false at position %d`, 3) } err = ba.ClearBit(5) if err != nil { t.Fatal(err) } result, err = ba.GetBit(5) if err != nil { t.Fatal(err) } if result { t.Errorf(`Expected false at position: %d`, 5) } ba = newBitArray(24) err = ba.SetBit(16) if err != nil { t.Fatal(err) } result, err = ba.GetBit(16) if err != nil { t.Fatal(err) } if !result { t.Errorf(`Expected true at position: %d`, 16) } } func TestDuplicateOperation(t *testing.T) { ba := newBitArray(10) err := ba.SetBit(5) if err != nil { t.Fatal(err) } err = ba.SetBit(5) if err != nil { t.Fatal(err) } result, err := ba.GetBit(5) if err != nil { t.Fatal(err) } if !result { t.Errorf(`Expected true at 
position: %d`, 5) } err = ba.ClearBit(5) if err != nil { t.Fatal(err) } err = ba.ClearBit(5) if err != nil { t.Fatal(err) } result, err = ba.GetBit(5) if err != nil { t.Fatal(err) } if result { t.Errorf(`Expected false at position: %d`, 5) } } func TestOutOfBounds(t *testing.T) { ba := newBitArray(4) err := ba.SetBit(s + 1) if _, ok := err.(OutOfRangeError); !ok { t.Errorf(`Expected out of range error.`) } _, err = ba.GetBit(s + 1) if _, ok := err.(OutOfRangeError); !ok { t.Errorf(`Expected out of range error.`) } } func TestIsEmpty(t *testing.T) { ba := newBitArray(10) assert.True(t, ba.IsEmpty()) ba.SetBit(5) assert.False(t, ba.IsEmpty()) } func TestCount(t *testing.T) { ba := newBitArray(500) assert.Equal(t, 0, ba.Count()) require.NoError(t, ba.SetBit(0)) assert.Equal(t, 1, ba.Count()) require.NoError(t, ba.SetBit(40)) require.NoError(t, ba.SetBit(64)) require.NoError(t, ba.SetBit(100)) require.NoError(t, ba.SetBit(200)) require.NoError(t, ba.SetBit(469)) require.NoError(t, ba.SetBit(500)) assert.Equal(t, 7, ba.Count()) require.NoError(t, ba.ClearBit(200)) assert.Equal(t, 6, ba.Count()) ba.Reset() assert.Equal(t, 0, ba.Count()) } func TestClear(t *testing.T) { ba := newBitArray(10) err := ba.SetBit(5) if err != nil { t.Fatal(err) } err = ba.SetBit(9) if err != nil { t.Fatal(err) } ba.Reset() assert.False(t, ba.anyset) result, err := ba.GetBit(5) if err != nil { t.Fatal(err) } if result { t.Errorf(`BA not reset.`) } result, err = ba.GetBit(9) if err != nil { t.Fatal(err) } if result { t.Errorf(`BA not reset.`) } } func BenchmarkGetBit(b *testing.B) { numItems := uint64(168000) ba := newBitArray(numItems) for i := uint64(0); i < numItems; i++ { ba.SetBit(i) } b.ResetTimer() for i := 0; i < b.N; i++ { for j := uint64(0); j < numItems; j++ { ba.GetBit(j) } } } func TestGetSetBits(t *testing.T) { ba := newBitArray(1000) buf := make([]uint64, 0, 5) require.NoError(t, ba.SetBit(1)) require.NoError(t, ba.SetBit(4)) require.NoError(t, ba.SetBit(8)) require.NoError(t, 
ba.SetBit(63)) require.NoError(t, ba.SetBit(64)) require.NoError(t, ba.SetBit(200)) require.NoError(t, ba.SetBit(1000)) assert.Equal(t, []uint64(nil), ba.GetSetBits(0, nil)) assert.Equal(t, []uint64{}, ba.GetSetBits(0, []uint64{})) assert.Equal(t, []uint64{1, 4, 8, 63, 64}, ba.GetSetBits(0, buf)) assert.Equal(t, []uint64{63, 64, 200, 1000}, ba.GetSetBits(10, buf)) assert.Equal(t, []uint64{63, 64, 200, 1000}, ba.GetSetBits(63, buf)) assert.Equal(t, []uint64{200, 1000}, ba.GetSetBits(128, buf)) require.NoError(t, ba.ClearBit(4)) require.NoError(t, ba.ClearBit(64)) assert.Equal(t, []uint64{1, 8, 63, 200, 1000}, ba.GetSetBits(0, buf)) assert.Empty(t, ba.GetSetBits(1001, buf)) ba.Reset() assert.Empty(t, ba.GetSetBits(0, buf)) } func BenchmarkGetSetBits(b *testing.B) { numItems := uint64(168000) ba := newBitArray(numItems) for i := uint64(0); i < numItems; i++ { if i%13 == 0 || i%5 == 0 { require.NoError(b, ba.SetBit(i)) } } buf := make([]uint64, 0, ba.Capacity()) b.ResetTimer() for i := 0; i < b.N; i++ { ba.GetSetBits(0, buf) } } func TestEquality(t *testing.T) { ba := newBitArray(s + 1) other := newBitArray(s + 1) if !ba.Equals(other) { t.Errorf(`Expected equality.`) } ba.SetBit(s + 1) other.SetBit(s + 1) if !ba.Equals(other) { t.Errorf(`Expected equality.`) } other.SetBit(0) if ba.Equals(other) { t.Errorf(`Expected inequality.`) } } func BenchmarkEquality(b *testing.B) { ba := newBitArray(160000) other := newBitArray(ba.Capacity()) b.ResetTimer() for i := 0; i < b.N; i++ { ba.Equals(other) } } func TestIntersects(t *testing.T) { ba := newBitArray(10) other := newBitArray(ba.Capacity()) ba.SetBit(1) ba.SetBit(2) other.SetBit(1) if !ba.Intersects(other) { t.Errorf(`Is intersecting.`) } other.SetBit(5) if ba.Intersects(other) { t.Errorf(`Is not intersecting.`) } other = newBitArray(ba.Capacity() + 1) other.SetBit(1) if ba.Intersects(other) { t.Errorf(`Is not intersecting.`) } } func BenchmarkIntersects(b *testing.B) { ba := newBitArray(162432) other := 
newBitArray(ba.Capacity()) ba.SetBit(159999) other.SetBit(159999) b.ResetTimer() for i := 0; i < b.N; i++ { ba.Intersects(other) } } func TestComplement(t *testing.T) { ba := newBitArray(10) ba.SetBit(5) ba.complement() if ok, _ := ba.GetBit(5); ok { t.Errorf(`Expected clear.`) } if ok, _ := ba.GetBit(4); !ok { t.Errorf(`Expected set.`) } } func BenchmarkComplement(b *testing.B) { ba := newBitArray(160000) b.ResetTimer() for i := 0; i < b.N; i++ { ba.complement() } } func TestSetHighestLowest(t *testing.T) { ba := newBitArray(10) assert.False(t, ba.anyset) assert.Equal(t, uint64(0), ba.lowest) assert.Equal(t, uint64(0), ba.highest) ba.SetBit(5) assert.True(t, ba.anyset) assert.Equal(t, uint64(5), ba.lowest) assert.Equal(t, uint64(5), ba.highest) ba.SetBit(8) assert.Equal(t, uint64(5), ba.lowest) assert.Equal(t, uint64(8), ba.highest) } func TestGetBitAtCapacity(t *testing.T) { ba := newBitArray(s * 2) _, err := ba.GetBit(s * 2) assert.Error(t, err) } func TestSetBitAtCapacity(t *testing.T) { ba := newBitArray(s * 2) err := ba.SetBit(s * 2) assert.Error(t, err) } func TestClearBitAtCapacity(t *testing.T) { ba := newBitArray(s * 2) err := ba.ClearBit(s * 2) assert.Error(t, err) } func TestClearHighestLowest(t *testing.T) { ba := newBitArray(10) ba.SetBit(5) ba.ClearBit(5) assert.False(t, ba.anyset) assert.Equal(t, uint64(0), ba.lowest) assert.Equal(t, uint64(0), ba.highest) ba.SetBit(3) ba.SetBit(5) ba.SetBit(7) ba.ClearBit(7) assert.True(t, ba.anyset) assert.Equal(t, uint64(5), ba.highest) assert.Equal(t, uint64(3), ba.lowest) ba.SetBit(7) ba.ClearBit(3) assert.True(t, ba.anyset) assert.Equal(t, uint64(5), ba.lowest) assert.Equal(t, uint64(7), ba.highest) ba.ClearBit(7) assert.True(t, ba.anyset) assert.Equal(t, uint64(5), ba.lowest) assert.Equal(t, uint64(5), ba.highest) ba.ClearBit(5) assert.False(t, ba.anyset) assert.Equal(t, uint64(0), ba.lowest) assert.Equal(t, uint64(0), ba.highest) } func TestComplementResetsBounds(t *testing.T) { ba := newBitArray(5) 
ba.complement() assert.True(t, ba.anyset) assert.Equal(t, uint64(0), ba.lowest) assert.Equal(t, uint64(s-1), ba.highest) } func TestBitArrayIntersectsSparse(t *testing.T) { ba := newBitArray(s * 2) cba := newSparseBitArray() assert.True(t, ba.Intersects(cba)) cba.SetBit(5) assert.False(t, ba.Intersects(cba)) ba.SetBit(5) assert.True(t, ba.Intersects(cba)) cba.SetBit(s + 1) assert.False(t, ba.Intersects(cba)) ba.SetBit(s + 1) assert.True(t, ba.Intersects(cba)) } func TestBitArrayEqualsSparse(t *testing.T) { ba := newBitArray(s * 2) cba := newSparseBitArray() assert.True(t, ba.Equals(cba)) ba.SetBit(5) assert.False(t, ba.Equals(cba)) cba.SetBit(5) assert.True(t, ba.Equals(cba)) ba.SetBit(s + 1) assert.False(t, ba.Equals(cba)) cba.SetBit(s + 1) assert.True(t, ba.Equals(cba)) } func TestConstructorSetBitArray(t *testing.T) { ba := newBitArray(8, true) result, err := ba.GetBit(7) assert.Nil(t, err) assert.True(t, result) assert.Equal(t, s-1, ba.highest) assert.Equal(t, uint64(0), ba.lowest) assert.True(t, ba.anyset) } func TestCopyBitArray(t *testing.T) { ba := newBitArray(10) ba.SetBit(5) ba.SetBit(1) result := ba.copy().(*bitArray) assert.Equal(t, ba.anyset, result.anyset) assert.Equal(t, ba.lowest, result.lowest) assert.Equal(t, ba.highest, result.highest) assert.Equal(t, ba.blocks, result.blocks) } func BenchmarkDenseIntersectsCompressed(b *testing.B) { numBits := uint64(162432) ba := newBitArray(numBits) other := newSparseBitArray() for i := uint64(0); i < numBits; i++ { ba.SetBit(i) other.SetBit(i) } b.ResetTimer() for i := 0; i < b.N; i++ { ba.intersectsSparseBitArray(other) } } func TestBitArrayToNums(t *testing.T) { ba := newBitArray(s * 2) ba.SetBit(s - 1) ba.SetBit(s + 1) expected := []uint64{s - 1, s + 1} result := ba.ToNums() assert.Equal(t, expected, result) } func BenchmarkBitArrayToNums(b *testing.B) { numItems := uint64(1000) ba := newBitArray(numItems) for i := uint64(0); i < numItems; i++ { ba.SetBit(i) } b.ResetTimer() for i := 0; i < b.N; i++ { 
ba.ToNums() } } ================================================ FILE: bitarray/bitmap.go ================================================ /* Copyright (c) 2016, Theodore Butler All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ // Package bitmap contains bitmaps of length 32 and 64 for tracking bool // values without the need for arrays or hashing. 
package bitarray // Bitmap32 tracks 32 bool values within a uint32 type Bitmap32 uint32 // SetBit returns a Bitmap32 with the bit at the given position set to 1 func (b Bitmap32) SetBit(pos uint) Bitmap32 { return b | (1 << pos) } // ClearBit returns a Bitmap32 with the bit at the given position set to 0 func (b Bitmap32) ClearBit(pos uint) Bitmap32 { return b & ^(1 << pos) } // GetBit returns true if the bit at the given position in the Bitmap32 is 1 func (b Bitmap32) GetBit(pos uint) bool { return (b & (1 << pos)) != 0 } // PopCount returns the amount of bits set to 1 in the Bitmap32 func (b Bitmap32) PopCount() int { // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel b -= (b >> 1) & 0x55555555 b = (b>>2)&0x33333333 + b&0x33333333 b += b >> 4 b &= 0x0f0f0f0f b *= 0x01010101 return int(byte(b >> 24)) } // Bitmap64 tracks 64 bool values within a uint64 type Bitmap64 uint64 // SetBit returns a Bitmap64 with the bit at the given position set to 1 func (b Bitmap64) SetBit(pos uint) Bitmap64 { return b | (1 << pos) } // ClearBit returns a Bitmap64 with the bit at the given position set to 0 func (b Bitmap64) ClearBit(pos uint) Bitmap64 { return b & ^(1 << pos) } // GetBit returns true if the bit at the given position in the Bitmap64 is 1 func (b Bitmap64) GetBit(pos uint) bool { return (b & (1 << pos)) != 0 } // PopCount returns the amount of bits set to 1 in the Bitmap64 func (b Bitmap64) PopCount() int { // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel b -= (b >> 1) & 0x5555555555555555 b = (b>>2)&0x3333333333333333 + b&0x3333333333333333 b += b >> 4 b &= 0x0f0f0f0f0f0f0f0f b *= 0x0101010101010101 return int(byte(b >> 56)) } ================================================ FILE: bitarray/bitmap_test.go ================================================ /* Copyright (c) 2016, Theodore Butler All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/ package bitarray import ( "testing" "github.com/stretchr/testify/assert" ) func TestBitmap32_PopCount(t *testing.T) { b := []uint32{ uint32(0x55555555), // 0x55555555 = 01010101 01010101 01010101 01010101 uint32(0x33333333), // 0x33333333 = 00110011 00110011 00110011 00110011 uint32(0x0F0F0F0F), // 0x0F0F0F0F = 00001111 00001111 00001111 00001111 uint32(0x00FF00FF), // 0x00FF00FF = 00000000 11111111 00000000 11111111 uint32(0x0000FFFF), // 0x0000FFFF = 00000000 00000000 11111111 11111111 } for _, x := range b { assert.Equal(t, 16, Bitmap32(x).PopCount()) } } func TestBitmap64_PopCount(t *testing.T) { b := []uint64{ uint64(0x5555555555555555), uint64(0x3333333333333333), uint64(0x0F0F0F0F0F0F0F0F), uint64(0x00FF00FF00FF00FF), uint64(0x0000FFFF0000FFFF), } for _, x := range b { assert.Equal(t, 32, Bitmap64(x).PopCount()) } } func TestBitmap32_SetBit(t *testing.T) { m := Bitmap32(0) assert.Equal(t, Bitmap32(0x4), m.SetBit(2)) } func TestBitmap32_ClearBit(t *testing.T) { m := Bitmap32(0x4) assert.Equal(t, Bitmap32(0), m.ClearBit(2)) } func TestBitmap32_zGetBit(t *testing.T) { m := Bitmap32(0x55555555) assert.Equal(t, true, m.GetBit(2)) } func TestBitmap64_SetBit(t *testing.T) { m := Bitmap64(0) assert.Equal(t, Bitmap64(0x4), m.SetBit(2)) } func TestBitmap64_ClearBit(t *testing.T) { m := Bitmap64(0x4) assert.Equal(t, Bitmap64(0), m.ClearBit(2)) } func TestBitmap64_GetBit(t *testing.T) { m := Bitmap64(0x55555555) assert.Equal(t, true, m.GetBit(2)) } func BenchmarkBitmap32_PopCount(b *testing.B) { m := Bitmap32(0x33333333) b.ResetTimer() for i := b.N; i > 0; i-- { m.PopCount() } } func BenchmarkBitmap64_PopCount(b *testing.B) { m := Bitmap64(0x3333333333333333) b.ResetTimer() for i := b.N; i > 0; i-- { m.PopCount() } } ================================================ FILE: bitarray/block.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 
compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package bitarray import ( "fmt" "unsafe" ) // block defines how we split apart the bit array. This also determines the size // of s. This can be changed to any unsigned integer type: uint8, uint16, // uint32, and so on. type block uint64 // s denotes the size of any element in the block array. // For a block of uint64, s will be equal to 64 // For a block of uint32, s will be equal to 32 // and so on... const s = uint64(unsafe.Sizeof(block(0)) * 8) // maximumBlock represents a block of all 1s and is used in the constructors. const maximumBlock = block(0) | ^block(0) func (b block) toNums(offset uint64, nums *[]uint64) { for i := uint64(0); i < s; i++ { if b&block(1< 0 { *nums = append(*nums, i+offset) } } } func (b block) findLeftPosition() uint64 { for i := s - 1; i < s; i-- { test := block(1 << i) if b&test == test { return i } } return s } func (b block) findRightPosition() uint64 { for i := uint64(0); i < s; i++ { test := block(1 << i) if b&test == test { return i } } return s } func (b block) insert(position uint64) block { return b | block(1< selfValue: // Here, the sba array has blocks that the other array doesn't // have. In this case, we just copy exactly the sba array values indices = append(indices, selfValue) blocks = append(blocks, sba.blocks[selfIndex]) // This is the exact logical inverse of the above case. selfIndex++ default: // Here, our indices match for both `sba` and `other`. // Time to do the bitwise AND operation and add a block // to our result list if the block has values in it. 
resultBlock = sba.blocks[selfIndex].nand(other.blocks[otherIndex]) if resultBlock > 0 { indices = append(indices, selfValue) blocks = append(blocks, resultBlock) } selfIndex++ otherIndex++ } } return &sparseBitArray{ indices: indices, blocks: blocks, } } func nandSparseWithDenseBitArray(sba *sparseBitArray, other *bitArray) BitArray { // Since nand is non-commutative, the resulting array should be sparse, // and the same length or less than the sparse array indices := make(uintSlice, 0, len(sba.indices)) blocks := make(blocks, 0, len(sba.indices)) var resultBlock block // Loop through the sparse array and match it with the dense array. for selfIndex, selfValue := range sba.indices { if selfValue >= uint64(len(other.blocks)) { // Since the dense array is exhausted, just copy over the data // from the sparse array resultBlock = sba.blocks[selfIndex] indices = append(indices, selfValue) blocks = append(blocks, resultBlock) continue } resultBlock = sba.blocks[selfIndex].nand(other.blocks[selfValue]) if resultBlock > 0 { indices = append(indices, selfValue) blocks = append(blocks, resultBlock) } } return &sparseBitArray{ indices: indices, blocks: blocks, } } func nandDenseWithSparseBitArray(sba *bitArray, other *sparseBitArray) BitArray { // Since nand is non-commutative, the resulting array should be dense, // and the same length or less than the dense array tmp := sba.copy() ret := tmp.(*bitArray) // Loop through the other array and match it with the sba array. 
for otherIndex, otherValue := range other.indices {
		if otherValue >= uint64(len(ret.blocks)) {
			break
		}

		ret.blocks[otherValue] = sba.blocks[otherValue].nand(other.blocks[otherIndex])
	}

	ret.setLowest()
	ret.setHighest()

	return ret
}

// nandDenseWithDenseBitArray returns a new dense array holding dba "and not"
// other, block by block. The result always has as many blocks as dba.
//
// other may be shorter than dba: once other is exhausted the remaining blocks
// of dba are copied over unchanged, the same way nandSparseWithDenseBitArray
// treats an exhausted dense operand. Blocks of other beyond the length of dba
// are ignored.
func nandDenseWithDenseBitArray(dba, other *bitArray) BitArray {
	ba := newBitArray(uint64(len(dba.blocks)) * s)

	for i, blk := range dba.blocks {
		if i >= len(other.blocks) {
			// Nothing left to subtract; keep the rest of dba as is.
			copy(ba.blocks[i:], dba.blocks[i:])
			break
		}

		ba.blocks[i] = blk.nand(other.blocks[i])
	}

	// Blocks were written directly, so the cached bounds must be rebuilt.
	ba.setLowest()
	ba.setHighest()

	return ba
}

================================================
FILE: bitarray/nand_test.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ package bitarray import ( "testing" "github.com/stretchr/testify/assert" ) func TestNandSparseWithSparseBitArray(t *testing.T) { sba := newSparseBitArray() other := newSparseBitArray() // bits for which only one of the arrays is set sba.SetBit(3) sba.SetBit(280) other.SetBit(9) other.SetBit(100) sba.SetBit(1000) other.SetBit(1001) // bits for which both arrays are set sba.SetBit(1) other.SetBit(1) sba.SetBit(2680) other.SetBit(2680) sba.SetBit(30) other.SetBit(30) ba := nandSparseWithSparseBitArray(sba, other) // Bits in both checkBit(t, ba, 1, false) checkBit(t, ba, 30, false) checkBit(t, ba, 2680, false) // Bits in sba but not other checkBit(t, ba, 3, true) checkBit(t, ba, 280, true) checkBit(t, ba, 1000, true) // Bits in other but not sba checkBit(t, ba, 9, false) checkBit(t, ba, 100, false) checkBit(t, ba, 2, false) nums := ba.ToNums() assert.Equal(t, []uint64{3, 280, 1000}, nums) } func TestNandSparseWithDenseBitArray(t *testing.T) { sba := newSparseBitArray() other := newBitArray(300) other.SetBit(1) sba.SetBit(1) other.SetBit(150) sba.SetBit(150) sba.SetBit(155) other.SetBit(156) sba.SetBit(300) other.SetBit(300) ba := nandSparseWithDenseBitArray(sba, other) // Bits in both checkBit(t, ba, 1, false) checkBit(t, ba, 150, false) checkBit(t, ba, 300, false) // Bits in sba but not other checkBit(t, ba, 155, true) // Bits in other but not sba checkBit(t, ba, 156, false) nums := ba.ToNums() assert.Equal(t, []uint64{155}, nums) } func TestNandDenseWithSparseBitArray(t *testing.T) { sba := newBitArray(300) other := newSparseBitArray() other.SetBit(1) sba.SetBit(1) other.SetBit(150) sba.SetBit(150) sba.SetBit(155) other.SetBit(156) sba.SetBit(300) other.SetBit(300) ba := nandDenseWithSparseBitArray(sba, other) // Bits in both checkBit(t, ba, 1, false) checkBit(t, ba, 150, false) checkBit(t, ba, 300, false) // Bits in sba but not other checkBit(t, ba, 155, true) // Bits in other but not sba checkBit(t, ba, 156, false) nums := ba.ToNums() assert.Equal(t, 
[]uint64{155}, nums) } func TestNandSparseWithSmallerDenseBitArray(t *testing.T) { sba := newSparseBitArray() other := newBitArray(512) other.SetBit(1) sba.SetBit(1) other.SetBit(150) sba.SetBit(150) sba.SetBit(155) sba.SetBit(500) other.SetBit(128) sba.SetBit(1500) sba.SetBit(1200) ba := nandSparseWithDenseBitArray(sba, other) // Bits in both checkBit(t, ba, 1, false) checkBit(t, ba, 150, false) // Bits in sba but not other checkBit(t, ba, 155, true) checkBit(t, ba, 500, true) checkBit(t, ba, 1200, true) checkBit(t, ba, 1500, true) // Bits in other but not sba checkBit(t, ba, 128, false) nums := ba.ToNums() assert.Equal(t, []uint64{155, 500, 1200, 1500}, nums) } func TestNandDenseWithDenseBitArray(t *testing.T) { dba := newBitArray(1000) other := newBitArray(2000) dba.SetBit(1) other.SetBit(18) dba.SetBit(222) other.SetBit(222) other.SetBit(1501) ba := nandDenseWithDenseBitArray(dba, other) // Bits in both checkBit(t, ba, 222, false) // Bits in dba and not other checkBit(t, ba, 1, true) // Bits in other checkBit(t, ba, 18, false) // Bits in neither checkBit(t, ba, 0, false) checkBit(t, ba, 3, false) // check that the ba is the minimum of the size of `dba` and `other` // (dense bitarrays return an error on an out-of-bounds access) _, err := ba.GetBit(1500) assert.Equal(t, OutOfRangeError(1500), err) _, err = ba.GetBit(1501) assert.Equal(t, OutOfRangeError(1501), err) nums := ba.ToNums() assert.Equal(t, []uint64{1}, nums) } func TestNandSparseWithEmptySparse(t *testing.T) { sba := newSparseBitArray() other := newSparseBitArray() sba.SetBit(5) ba := nandSparseWithSparseBitArray(sba, other) checkBit(t, ba, 0, false) checkBit(t, ba, 5, true) checkBit(t, ba, 100, false) } func TestNandSparseWithEmptyDense(t *testing.T) { sba := newSparseBitArray() other := newBitArray(1000) sba.SetBit(5) ba := nandSparseWithDenseBitArray(sba, other) checkBit(t, ba, 5, true) sba.Reset() other.SetBit(5) ba = nandSparseWithDenseBitArray(sba, other) checkBit(t, ba, 5, false) } func 
TestNandDenseWithEmptyDense(t *testing.T) { dba := newBitArray(1000) other := newBitArray(1000) dba.SetBit(5) ba := nandDenseWithDenseBitArray(dba, other) checkBit(t, ba, 5, true) dba.Reset() other.SetBit(5) ba = nandDenseWithDenseBitArray(dba, other) checkBit(t, ba, 5, false) } func BenchmarkNandSparseWithSparse(b *testing.B) { numItems := uint64(160000) sba := newSparseBitArray() other := newSparseBitArray() for i := uint64(0); i < numItems; i += s { if i%200 == 0 { sba.SetBit(i) } else if i%300 == 0 { other.SetBit(i) } } b.ResetTimer() for i := 0; i < b.N; i++ { nandSparseWithSparseBitArray(sba, other) } } func BenchmarkNandSparseWithDense(b *testing.B) { numItems := uint64(160000) sba := newSparseBitArray() other := newBitArray(numItems) for i := uint64(0); i < numItems; i += s { if i%2 == 0 { sba.SetBit(i) } else if i%3 == 0 { other.SetBit(i) } } b.ResetTimer() for i := 0; i < b.N; i++ { nandSparseWithDenseBitArray(sba, other) } } func BenchmarkNandDenseWithSparse(b *testing.B) { numItems := uint64(160000) ba := newBitArray(numItems) other := newSparseBitArray() for i := uint64(0); i < numItems; i += s { if i%2 == 0 { ba.SetBit(i) } else if i%3 == 0 { other.SetBit(i) } } b.ResetTimer() for i := 0; i < b.N; i++ { nandDenseWithSparseBitArray(ba, other) } } func BenchmarkNandDenseWithDense(b *testing.B) { numItems := uint64(160000) dba := newBitArray(numItems) other := newBitArray(numItems) for i := uint64(0); i < numItems; i += s { if i%2 == 0 { dba.SetBit(i) } else if i%3 == 0 { other.SetBit(i) } } b.ResetTimer() for i := 0; i < b.N; i++ { nandDenseWithDenseBitArray(dba, other) } } ================================================ FILE: bitarray/or.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package bitarray func orSparseWithSparseBitArray(sba *sparseBitArray, other *sparseBitArray) BitArray { if len(other.indices) == 0 { return sba.copy() } if len(sba.indices) == 0 { return other.copy() } max := maxInt64(int64(len(sba.indices)), int64(len(other.indices))) indices := make(uintSlice, 0, max) blocks := make(blocks, 0, max) selfIndex := 0 otherIndex := 0 for { // last comparison was a real or, we are both exhausted now if selfIndex == len(sba.indices) && otherIndex == len(other.indices) { break } else if selfIndex == len(sba.indices) { indices = append(indices, other.indices[otherIndex:]...) blocks = append(blocks, other.blocks[otherIndex:]...) break } else if otherIndex == len(other.indices) { indices = append(indices, sba.indices[selfIndex:]...) blocks = append(blocks, sba.blocks[selfIndex:]...) 
break } selfValue := sba.indices[selfIndex] otherValue := other.indices[otherIndex] switch diff := int(otherValue) - int(selfValue); { case diff > 0: indices = append(indices, selfValue) blocks = append(blocks, sba.blocks[selfIndex]) selfIndex++ case diff < 0: indices = append(indices, otherValue) blocks = append(blocks, other.blocks[otherIndex]) otherIndex++ default: indices = append(indices, otherValue) blocks = append(blocks, sba.blocks[selfIndex].or(other.blocks[otherIndex])) selfIndex++ otherIndex++ } } return &sparseBitArray{ indices: indices, blocks: blocks, } } func orSparseWithDenseBitArray(sba *sparseBitArray, other *bitArray) BitArray { if other.Capacity() == 0 || !other.anyset { return sba.copy() } if sba.Capacity() == 0 { return other.copy() } max := maxUint64(uint64(sba.Capacity()), uint64(other.Capacity())) ba := newBitArray(max * s) selfIndex := 0 otherIndex := 0 for { if selfIndex == len(sba.indices) && otherIndex == len(other.blocks) { break } else if selfIndex == len(sba.indices) { copy(ba.blocks[otherIndex:], other.blocks[otherIndex:]) break } else if otherIndex == len(other.blocks) { for i, value := range sba.indices[selfIndex:] { ba.blocks[value] = sba.blocks[i+selfIndex] } break } selfValue := sba.indices[selfIndex] if selfValue == uint64(otherIndex) { ba.blocks[otherIndex] = sba.blocks[selfIndex].or(other.blocks[otherIndex]) selfIndex++ otherIndex++ continue } ba.blocks[otherIndex] = other.blocks[otherIndex] otherIndex++ } ba.setHighest() ba.setLowest() return ba } func orDenseWithDenseBitArray(dba *bitArray, other *bitArray) BitArray { if dba.Capacity() == 0 || !dba.anyset { return other.copy() } if other.Capacity() == 0 || !other.anyset { return dba.copy() } max := maxUint64(uint64(len(dba.blocks)), uint64(len(other.blocks))) ba := newBitArray(max * s) for i := uint64(0); i < max; i++ { if i == uint64(len(dba.blocks)) { copy(ba.blocks[i:], other.blocks[i:]) break } if i == uint64(len(other.blocks)) { copy(ba.blocks[i:], dba.blocks[i:]) 
break } ba.blocks[i] = dba.blocks[i].or(other.blocks[i]) } ba.setLowest() ba.setHighest() return ba } ================================================ FILE: bitarray/or_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package bitarray import ( "testing" "github.com/stretchr/testify/assert" ) func TestOrSparseWithSparseBitArray(t *testing.T) { sba := newSparseBitArray() other := newSparseBitArray() ctx := false for i := uint64(0); i < 1000; i += s { if ctx { sba.SetBit(i) } else { other.SetBit(i) } ctx = !ctx } sba.SetBit(s - 1) other.SetBit(s - 1) result := orSparseWithSparseBitArray(sba, other) for i := uint64(0); i < 1000; i += s { ok, err := result.GetBit(i) assert.Nil(t, err) assert.True(t, ok) } ok, err := result.GetBit(s - 1) assert.Nil(t, err) assert.True(t, ok) ok, err = result.GetBit(s - 2) assert.Nil(t, err) assert.False(t, ok) other.SetBit(2000) result = orSparseWithSparseBitArray(sba, other) ok, err = result.GetBit(2000) assert.Nil(t, err) assert.True(t, ok) sba.SetBit(2000) result = orSparseWithSparseBitArray(sba, other) ok, err = result.GetBit(2000) assert.Nil(t, err) assert.True(t, ok) } func BenchmarkOrSparseWithSparse(b *testing.B) { numItems := uint64(160000) sba := newSparseBitArray() other := newSparseBitArray() ctx := false for i := uint64(0); i < numItems; i += s { if ctx { sba.SetBit(i) } else { other.SetBit(i) } ctx = !ctx } b.ResetTimer() for i := 0; i < b.N; i++ { 
orSparseWithSparseBitArray(sba, other) } } func TestOrSparseWithDenseBitArray(t *testing.T) { sba := newSparseBitArray() other := newBitArray(2000) ctx := false for i := uint64(0); i < 1000; i += s { if ctx { sba.SetBit(i) } else { other.SetBit(i) } ctx = !ctx } other.SetBit(1500) other.SetBit(s - 1) sba.SetBit(s - 1) result := orSparseWithDenseBitArray(sba, other) for i := uint64(0); i < 1000; i += s { ok, err := result.GetBit(i) assert.Nil(t, err) assert.True(t, ok) } ok, err := result.GetBit(1500) assert.Nil(t, err) assert.True(t, ok) ok, err = result.GetBit(s - 1) assert.Nil(t, err) assert.True(t, ok) ok, err = result.GetBit(s - 2) assert.Nil(t, err) assert.False(t, ok) sba.SetBit(2500) result = orSparseWithDenseBitArray(sba, other) ok, err = result.GetBit(2500) assert.Nil(t, err) assert.True(t, ok) } func BenchmarkOrSparseWithDense(b *testing.B) { numItems := uint64(160000) sba := newSparseBitArray() other := newBitArray(numItems) ctx := false for i := uint64(0); i < numItems; i += s { if ctx { sba.SetBit(i) } else { other.SetBit(i) } ctx = !ctx } b.ResetTimer() for i := 0; i < b.N; i++ { orSparseWithDenseBitArray(sba, other) } } func TestOrDenseWithDenseBitArray(t *testing.T) { dba := newBitArray(1000) other := newBitArray(2000) ctx := false for i := uint64(0); i < 1000; i += s { if ctx { dba.SetBit(i) } else { other.SetBit(i) } ctx = !ctx } other.SetBit(1500) other.SetBit(s - 1) dba.SetBit(s - 1) result := orDenseWithDenseBitArray(dba, other) for i := uint64(0); i < 1000; i += s { ok, err := result.GetBit(i) assert.Nil(t, err) assert.True(t, ok) } ok, err := result.GetBit(s - 1) assert.Nil(t, err) assert.True(t, ok) ok, err = result.GetBit(1500) assert.Nil(t, err) assert.True(t, ok) ok, err = result.GetBit(1700) assert.Nil(t, err) assert.False(t, ok) } func BenchmarkOrDenseWithDense(b *testing.B) { numItems := uint64(160000) dba := newBitArray(numItems) other := newBitArray(numItems) ctx := false for i := uint64(0); i < numItems; i += s { if ctx { 
dba.SetBit(i) } else { other.SetBit(i) } ctx = !ctx } b.ResetTimer() for i := 0; i < b.N; i++ { orDenseWithDenseBitArray(dba, other) } } func TestOrSparseWithEmptySparse(t *testing.T) { sba := newSparseBitArray() other := newSparseBitArray() sba.SetBit(5) result := orSparseWithSparseBitArray(sba, other) assert.Equal(t, sba, result) sba.Reset() other.SetBit(5) result = orSparseWithSparseBitArray(sba, other) assert.Equal(t, other, result) } func TestOrSparseWithEmptyDense(t *testing.T) { sba := newSparseBitArray() other := newBitArray(1000) sba.SetBit(5) result := orSparseWithDenseBitArray(sba, other) assert.Equal(t, sba, result) sba.Reset() other.SetBit(5) result = orSparseWithDenseBitArray(sba, other) assert.Equal(t, other, result) } func TestOrDenseWithEmptyDense(t *testing.T) { dba := newBitArray(1000) other := newBitArray(1000) dba.SetBit(5) result := orDenseWithDenseBitArray(dba, other) assert.Equal(t, dba, result) dba.Reset() other.SetBit(5) result = orDenseWithDenseBitArray(dba, other) assert.Equal(t, other, result) } ================================================ FILE: bitarray/sparse_bitarray.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package bitarray import ( "math/bits" "sort" ) // uintSlice is an alias for a slice of ints. Len, Swap, and Less // are exported to fulfill an interface needed for the search // function in the sort library. type uintSlice []uint64 // Len returns the length of the slice. 
func (u uintSlice) Len() int64 { return int64(len(u)) } // Swap swaps values in this slice at the positions given. func (u uintSlice) Swap(i, j int64) { u[i], u[j] = u[j], u[i] } // Less returns a bool indicating if the value at position i is // less than position j. func (u uintSlice) Less(i, j int64) bool { return u[i] < u[j] } func (u uintSlice) search(x uint64) int64 { return int64(sort.Search(len(u), func(i int) bool { return uint64(u[i]) >= x })) } func (u *uintSlice) insert(x uint64) (int64, bool) { i := u.search(x) if i == int64(len(*u)) { *u = append(*u, x) return i, true } if (*u)[i] == x { return i, false } *u = append(*u, 0) copy((*u)[i+1:], (*u)[i:]) (*u)[i] = x return i, true } func (u *uintSlice) deleteAtIndex(i int64) { copy((*u)[i:], (*u)[i+1:]) (*u)[len(*u)-1] = 0 *u = (*u)[:len(*u)-1] } func (u uintSlice) get(x uint64) int64 { i := u.search(x) if i == int64(len(u)) { return -1 } if u[i] == x { return i } return -1 } type blocks []block func (b *blocks) insert(index int64) { if index == int64(len(*b)) { *b = append(*b, block(0)) return } *b = append(*b, block(0)) copy((*b)[index+1:], (*b)[index:]) (*b)[index] = block(0) } func (b *blocks) deleteAtIndex(i int64) { copy((*b)[i:], (*b)[i+1:]) (*b)[len(*b)-1] = block(0) *b = (*b)[:len(*b)-1] } type sparseBitArray struct { blocks blocks indices uintSlice } // SetBit sets the bit at the given position. func (sba *sparseBitArray) SetBit(k uint64) error { index, position := getIndexAndRemainder(k) i, inserted := sba.indices.insert(index) if inserted { sba.blocks.insert(i) } sba.blocks[i] = sba.blocks[i].insert(position) return nil } // GetBit gets the bit at the given position. func (sba *sparseBitArray) GetBit(k uint64) (bool, error) { index, position := getIndexAndRemainder(k) i := sba.indices.get(index) if i == -1 { return false, nil } return sba.blocks[i].get(position), nil } // GetSetBits gets the position of bits set in the array. 
func (sba *sparseBitArray) GetSetBits(from uint64, buffer []uint64) []uint64 { fromBlockIndex, fromOffset := getIndexAndRemainder(from) fromBlockLocation := sba.indices.search(fromBlockIndex) if int(fromBlockLocation) == len(sba.indices) { return buffer[:0] } return getSetBitsInBlocks( fromBlockIndex, fromOffset, sba.blocks[fromBlockLocation:], sba.indices[fromBlockLocation:], buffer, ) } // ToNums converts this sparse bitarray to a list of numbers contained // within it. func (sba *sparseBitArray) ToNums() []uint64 { if len(sba.indices) == 0 { return nil } diff := uint64(len(sba.indices)) * s nums := make([]uint64, 0, diff/4) for i, offset := range sba.indices { sba.blocks[i].toNums(offset*s, &nums) } return nums } // ClearBit clears the bit at the given position. func (sba *sparseBitArray) ClearBit(k uint64) error { index, position := getIndexAndRemainder(k) i := sba.indices.get(index) if i == -1 { return nil } sba.blocks[i] = sba.blocks[i].remove(position) if sba.blocks[i] == 0 { sba.blocks.deleteAtIndex(i) sba.indices.deleteAtIndex(i) } return nil } // Reset erases all values from this bitarray. func (sba *sparseBitArray) Reset() { sba.blocks = sba.blocks[:0] sba.indices = sba.indices[:0] } // Blocks returns an iterator to iterator of this bitarray's blocks. func (sba *sparseBitArray) Blocks() Iterator { return newCompressedBitArrayIterator(sba) } // Capacity returns the value of the highest possible *seen* value // in this sparse bitarray. func (sba *sparseBitArray) Capacity() uint64 { if len(sba.indices) == 0 { return 0 } return (sba.indices[len(sba.indices)-1] + 1) * s } // Equals returns a bool indicating if the provided bit array // equals this bitarray. 
func (sba *sparseBitArray) Equals(other BitArray) bool { if other.Capacity() == 0 && sba.Capacity() > 0 { return false } var selfIndex uint64 for iter := other.Blocks(); iter.Next(); { otherIndex, otherBlock := iter.Value() if len(sba.indices) == 0 { if otherBlock > 0 { return false } continue } if selfIndex >= uint64(len(sba.indices)) { return false } if otherIndex < sba.indices[selfIndex] { if otherBlock > 0 { return false } continue } if otherIndex > sba.indices[selfIndex] { return false } if !sba.blocks[selfIndex].equals(otherBlock) { return false } selfIndex++ } return true } // Count returns the number of set bits in this array. func (sba *sparseBitArray) Count() int { count := 0 for _, block := range sba.blocks { count += bits.OnesCount64(uint64(block)) } return count } // Or will perform a bitwise or operation with the provided bitarray and // return a new result bitarray. func (sba *sparseBitArray) Or(other BitArray) BitArray { if ba, ok := other.(*sparseBitArray); ok { return orSparseWithSparseBitArray(sba, ba) } return orSparseWithDenseBitArray(sba, other.(*bitArray)) } // And will perform a bitwise and operation with the provided bitarray and // return a new result bitarray. func (sba *sparseBitArray) And(other BitArray) BitArray { if ba, ok := other.(*sparseBitArray); ok { return andSparseWithSparseBitArray(sba, ba) } return andSparseWithDenseBitArray(sba, other.(*bitArray)) } // Nand will return the result of doing a bitwise and not of the bit array // with the other bit array on each block. func (sba *sparseBitArray) Nand(other BitArray) BitArray { if ba, ok := other.(*sparseBitArray); ok { return nandSparseWithSparseBitArray(sba, ba) } return nandSparseWithDenseBitArray(sba, other.(*bitArray)) } func (sba *sparseBitArray) IsEmpty() bool { // This works because the and, nand and delete functions only // keep values that have a non-zero block. 
return len(sba.indices) == 0 } func (sba *sparseBitArray) copy() *sparseBitArray { blocks := make(blocks, len(sba.blocks)) copy(blocks, sba.blocks) indices := make(uintSlice, len(sba.indices)) copy(indices, sba.indices) return &sparseBitArray{ blocks: blocks, indices: indices, } } // Intersects returns a bool indicating if the provided bit array // intersects with this bitarray. func (sba *sparseBitArray) Intersects(other BitArray) bool { if other.Capacity() == 0 { return true } var selfIndex int64 for iter := other.Blocks(); iter.Next(); { otherI, otherBlock := iter.Value() if len(sba.indices) == 0 { if otherBlock > 0 { return false } continue } // here we grab where the block should live in ourselves i := uintSlice(sba.indices[selfIndex:]).search(otherI) // this is a block we don't have, doesn't intersect if i == int64(len(sba.indices)) { return false } if sba.indices[i] != otherI { return false } if !sba.blocks[i].intersects(otherBlock) { return false } selfIndex = i } return true } func (sba *sparseBitArray) IntersectsBetween(other BitArray, start, stop uint64) bool { return true } func newSparseBitArray() *sparseBitArray { return &sparseBitArray{} } // NewSparseBitArray will create a bit array that consumes a great // deal less memory at the expense of longer sets and gets. func NewSparseBitArray() BitArray { return newSparseBitArray() } ================================================ FILE: bitarray/sparse_bitarray_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ package bitarray import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestGetCompressedBit(t *testing.T) { ba := newSparseBitArray() result, err := ba.GetBit(5) assert.Nil(t, err) assert.False(t, result) } func BenchmarkGetCompressedBit(b *testing.B) { numItems := 1000 ba := newSparseBitArray() for i := 0; i < numItems; i++ { ba.SetBit(uint64(i)) } b.ResetTimer() for i := 0; i < b.N; i++ { ba.GetBit(s) } } func TestGetSetCompressedBit(t *testing.T) { ba := newSparseBitArray() ba.SetBit(5) result, err := ba.GetBit(5) assert.Nil(t, err) assert.True(t, result) result, err = ba.GetBit(7) assert.Nil(t, err) assert.False(t, result) ba.SetBit(s * 2) result, _ = ba.GetBit(s * 2) assert.True(t, result) result, _ = ba.GetBit(s*2 + 1) assert.False(t, result) } func BenchmarkSetCompressedBit(b *testing.B) { numItems := 1000 ba := newSparseBitArray() b.ResetTimer() for i := 0; i < b.N; i++ { for j := 0; j < numItems; j++ { ba.SetBit(uint64(j)) } } } func TestGetSetCompressedBits(t *testing.T) { ba := newSparseBitArray() buf := make([]uint64, 0, 5) require.NoError(t, ba.SetBit(1)) require.NoError(t, ba.SetBit(4)) require.NoError(t, ba.SetBit(8)) require.NoError(t, ba.SetBit(63)) require.NoError(t, ba.SetBit(64)) require.NoError(t, ba.SetBit(200)) require.NoError(t, ba.SetBit(1000)) assert.Equal(t, []uint64(nil), ba.GetSetBits(0, nil)) assert.Equal(t, []uint64{}, ba.GetSetBits(0, []uint64{})) assert.Equal(t, []uint64{1, 4, 8, 63, 64}, ba.GetSetBits(0, buf)) assert.Equal(t, []uint64{63, 64, 200, 1000}, ba.GetSetBits(10, buf)) assert.Equal(t, []uint64{63, 64, 200, 1000}, ba.GetSetBits(63, buf)) assert.Equal(t, []uint64{200, 1000}, ba.GetSetBits(128, buf)) require.NoError(t, ba.ClearBit(4)) require.NoError(t, ba.ClearBit(64)) assert.Equal(t, []uint64{1, 8, 63, 200, 1000}, ba.GetSetBits(0, buf)) assert.Empty(t, ba.GetSetBits(1001, buf)) 
ba.Reset() assert.Empty(t, ba.GetSetBits(0, buf)) } func BenchmarkGetSetCompressedBits(b *testing.B) { ba := newSparseBitArray() for i := uint64(0); i < 168000; i++ { if i%13 == 0 || i%5 == 0 { require.NoError(b, ba.SetBit(i)) } } buf := make([]uint64, 0, ba.Capacity()) b.ResetTimer() for i := 0; i < b.N; i++ { ba.GetSetBits(0, buf) } } func TestCompressedCount(t *testing.T) { ba := newSparseBitArray() assert.Equal(t, 0, ba.Count()) require.NoError(t, ba.SetBit(0)) assert.Equal(t, 1, ba.Count()) require.NoError(t, ba.SetBit(40)) require.NoError(t, ba.SetBit(64)) require.NoError(t, ba.SetBit(100)) require.NoError(t, ba.SetBit(200)) require.NoError(t, ba.SetBit(469)) require.NoError(t, ba.SetBit(500)) assert.Equal(t, 7, ba.Count()) require.NoError(t, ba.ClearBit(200)) assert.Equal(t, 6, ba.Count()) ba.Reset() assert.Equal(t, 0, ba.Count()) } func TestClearCompressedBit(t *testing.T) { ba := newSparseBitArray() ba.SetBit(5) ba.ClearBit(5) result, err := ba.GetBit(5) assert.Nil(t, err) assert.False(t, result) assert.Len(t, ba.blocks, 0) assert.Len(t, ba.indices, 0) ba.SetBit(s * 2) ba.ClearBit(s * 2) result, _ = ba.GetBit(s * 2) assert.False(t, result) assert.Len(t, ba.indices, 0) assert.Len(t, ba.blocks, 0) } func BenchmarkClearCompressedBit(b *testing.B) { numItems := 1000 ba := newSparseBitArray() for i := 0; i < numItems; i++ { ba.SetBit(uint64(i)) } b.ResetTimer() for i := 0; i < b.N; i++ { ba.ClearBit(uint64(i)) } } func TestClearCompressedBitArray(t *testing.T) { ba := newSparseBitArray() ba.SetBit(5) ba.SetBit(s * 2) result, err := ba.GetBit(5) assert.Nil(t, err) assert.True(t, result) result, _ = ba.GetBit(s * 2) assert.True(t, result) ba.Reset() result, err = ba.GetBit(5) assert.Nil(t, err) assert.False(t, result) result, _ = ba.GetBit(s * 2) assert.False(t, result) } func TestCompressedEquals(t *testing.T) { ba := newSparseBitArray() other := newSparseBitArray() assert.True(t, ba.Equals(other)) ba.SetBit(5) assert.False(t, ba.Equals(other)) other.SetBit(5) 
assert.True(t, ba.Equals(other)) ba.ClearBit(5) assert.False(t, ba.Equals(other)) } func TestCompressedIntersects(t *testing.T) { ba := newSparseBitArray() other := newSparseBitArray() assert.True(t, ba.Intersects(other)) other.SetBit(5) assert.False(t, ba.Intersects(other)) assert.True(t, other.Intersects(ba)) ba.SetBit(5) assert.True(t, ba.Intersects(other)) assert.True(t, other.Intersects(ba)) other.SetBit(10) assert.False(t, ba.Intersects(other)) assert.True(t, other.Intersects(ba)) } func TestLongCompressedIntersects(t *testing.T) { ba := newSparseBitArray() other := newSparseBitArray() ba.SetBit(5) other.SetBit(5) assert.True(t, ba.Intersects(other)) other.SetBit(s * 2) assert.False(t, ba.Intersects(other)) assert.True(t, other.Intersects(ba)) ba.SetBit(s * 2) assert.True(t, ba.Intersects(other)) assert.True(t, other.Intersects(ba)) other.SetBit(s*2 + 1) assert.False(t, ba.Intersects(other)) assert.True(t, other.Intersects(ba)) } func BenchmarkCompressedIntersects(b *testing.B) { numItems := uint64(1000) ba := newSparseBitArray() other := newSparseBitArray() for i := uint64(0); i < numItems; i++ { ba.SetBit(i) other.SetBit(i) } b.ResetTimer() for i := 0; i < b.N; i++ { ba.Intersects(other) } } func TestSparseIntersectsBitArray(t *testing.T) { cba := newSparseBitArray() ba := newBitArray(s * 2) assert.True(t, cba.Intersects(ba)) ba.SetBit(5) assert.False(t, cba.Intersects(ba)) cba.SetBit(5) assert.True(t, cba.Intersects(ba)) cba.SetBit(10) assert.True(t, cba.Intersects(ba)) ba.SetBit(s + 1) assert.False(t, cba.Intersects(ba)) cba.SetBit(s + 1) assert.True(t, cba.Intersects(ba)) cba.SetBit(s * 3) assert.True(t, cba.Intersects(ba)) } func TestSparseEqualsBitArray(t *testing.T) { cba := newSparseBitArray() ba := newBitArray(s * 2) assert.True(t, cba.Equals(ba)) ba.SetBit(5) assert.False(t, cba.Equals(ba)) cba.SetBit(5) assert.True(t, cba.Equals(ba)) ba.SetBit(s + 1) assert.False(t, cba.Equals(ba)) cba.SetBit(s + 1) assert.True(t, cba.Equals(ba)) } func 
BenchmarkCompressedEquals(b *testing.B) { numItems := uint64(1000) cba := newSparseBitArray() other := newSparseBitArray() for i := uint64(0); i < numItems; i++ { cba.SetBit(i) other.SetBit(i) } b.ResetTimer() for i := 0; i < b.N; i++ { cba.Equals(other) } } func TestInsertPreviousBlockInSparse(t *testing.T) { sba := newSparseBitArray() sba.SetBit(s * 2) sba.SetBit(s - 1) result, err := sba.GetBit(s - 1) assert.Nil(t, err) assert.True(t, result) } func TestSparseBitArrayToNums(t *testing.T) { sba := newSparseBitArray() sba.SetBit(s - 1) sba.SetBit(s + 1) expected := []uint64{s - 1, s + 1} results := sba.ToNums() assert.Equal(t, expected, results) } func BenchmarkSparseBitArrayToNums(b *testing.B) { numItems := uint64(1000) sba := newSparseBitArray() for i := uint64(0); i < numItems; i++ { sba.SetBit(i) } b.ResetTimer() for i := 0; i < b.N; i++ { sba.ToNums() } } ================================================ FILE: bitarray/util.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package bitarray // maxInt64 returns the highest integer in the provided list of int64s func maxInt64(ints ...int64) int64 { maxInt := ints[0] for i := 1; i < len(ints); i++ { if ints[i] > maxInt { maxInt = ints[i] } } return maxInt } // maxUint64 returns the highest integer in the provided list of uint64s func maxUint64(ints ...uint64) uint64 { maxInt := ints[0] for i := 1; i < len(ints); i++ { if ints[i] > maxInt { maxInt = ints[i] } } return maxInt } // minUint64 returns the lowest integer in the provided list of int32s func minUint64(ints ...uint64) uint64 { minInt := ints[0] for i := 1; i < len(ints); i++ { if ints[i] < minInt { minInt = ints[i] } } return minInt } ================================================ FILE: btree/_link/interface.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package link // Keys is a typed list of Key interfaces. type Keys []Key type Key interface { // Compare should return an int indicating how this key relates // to the provided key. -1 will indicate less than, 0 will indicate // equality, and 1 will indicate greater than. Duplicate keys // are allowed, but duplicate IDs are not. Compare(Key) int } ================================================ FILE: btree/_link/key.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package link import "sort" func (keys Keys) search(key Key) int { return sort.Search(len(keys), func(i int) bool { return keys[i].Compare(key) >= 0 }) } func (keys *Keys) insert(key Key) Key { i := keys.search(key) return keys.insertAt(key, i) } func (keys *Keys) insertAt(key Key, i int) Key { if i == len(*keys) { *keys = append(*keys, key) return nil } if (*keys)[i].Compare(key) == 0 { //overwrite case oldKey := (*keys)[i] (*keys)[i] = key return oldKey } *keys = append(*keys, nil) copy((*keys)[i+1:], (*keys)[i:]) (*keys)[i] = key return nil } func (keys *Keys) split() (Key, Keys, Keys) { i := (len(*keys) / 2) - 1 middle := (*keys)[i] left, right := keys.splitAt(i) return middle, left, right } func (keys *Keys) splitAt(i int) (Keys, Keys) { right := make(Keys, len(*keys)-i-1, cap(*keys)) copy(right, (*keys)[i+1:]) for j := i + 1; j < len(*keys); j++ { (*keys)[j] = nil } *keys = (*keys)[:i+1] return *keys, right } func (keys Keys) last() Key { return keys[len(keys)-1] } func (keys Keys) first() Key { return keys[0] } func (keys Keys) needsSplit() bool { return cap(keys) == len(keys) } func (keys Keys) reverse() Keys { reversed := make(Keys, len(keys)) for i := len(keys) - 1; i >= 0; i-- { reversed[len(keys)-1-i] = keys[i] } return reversed } func chunkKeys(keys Keys, numParts int64) []Keys { parts := make([]Keys, numParts) for i := int64(0); i < numParts; i++ { parts[i] = keys[i*int64(len(keys))/numParts : (i+1)*int64(len(keys))/numParts] } return parts } ================================================ FILE: btree/_link/mock_test.go 
================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package link type mockKey uint64 func (mk mockKey) Compare(other Key) int { otherK := other.(mockKey) if mk < otherK { return -1 } if mk > otherK { return 1 } return 0 } ================================================ FILE: btree/_link/node.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package link import ( "log" "sync" ) func search(parent *node, key Key) Key { parent = getParent(parent, nil, key) parent.lock.RLock() parent = moveRight(parent, key, false) defer parent.lock.RUnlock() i := parent.search(key) if i == len(parent.keys) { return nil } return parent.keys[i] } func getParent(parent *node, stack *nodes, key Key) *node { var n *node for parent != nil && !parent.isLeaf { parent.lock.RLock() parent = moveRight(parent, key, false) // if this happens on the root this should always just return n = parent.searchNode(key) if stack != nil { stack.push(parent) } parent.lock.RUnlock() parent = n } return parent } func insert(tree *blink, parent *node, stack *nodes, key Key) Key { parent = getParent(parent, stack, key) parent.lock.Lock() parent = moveRight(parent, key, true) result := parent.insert(key) if result != nil { // overwrite parent.lock.Unlock() return result } if !parent.needsSplit() { parent.lock.Unlock() return nil } split(tree, parent, stack) return nil } func split(tree *blink, n *node, stack *nodes) { var l, r *node var k Key var parent *node for n.needsSplit() { k, l, r = n.split() parent = stack.pop() if parent == nil { tree.lock.Lock() if tree.root == nil || tree.root == n { parent = newNode(false, make(Keys, 0, tree.ary), make(nodes, 0, tree.ary+1)) parent.maxSeen = r.max() parent.keys.insert(k) parent.nodes.push(l) parent.nodes.push(r) tree.root = parent n.lock.Unlock() tree.lock.Unlock() return } parent = tree.root tree.lock.Unlock() } parent.lock.Lock() parent = moveRight(parent, r.key(), true) i := parent.search(k) parent.keys.insertAt(k, i) parent.nodes[i] = l parent.nodes.insertAt(r, i+1) n.lock.Unlock() n = parent } n.lock.Unlock() } func moveRight(n *node, key Key, getLock bool) *node { var right *node for { if len(n.keys) == 0 || n.right == nil { // this is either the node or the rightmost node return n } if key.Compare(n.max()) < 1 { return n } if getLock { n.right.lock.Lock() right = n.right n.lock.Unlock() } else { 
n.right.lock.RLock() right = n.right n.lock.RUnlock() } n = right } } type nodes []*node func (ns *nodes) reset() { for i := range *ns { (*ns)[i] = nil } *ns = (*ns)[:0] } func (ns *nodes) push(n *node) { *ns = append(*ns, n) } func (ns *nodes) pop() *node { if len(*ns) == 0 { return nil } n := (*ns)[len(*ns)-1] (*ns)[len(*ns)-1] = nil *ns = (*ns)[:len(*ns)-1] return n } func (ns *nodes) insertAt(n *node, i int) { if i == len(*ns) { *ns = append(*ns, n) return } *ns = append(*ns, nil) copy((*ns)[i+1:], (*ns)[i:]) (*ns)[i] = n } func (ns *nodes) splitAt(i int) (nodes, nodes) { length := len(*ns) - i right := make(nodes, length, cap(*ns)) copy(right, (*ns)[i+1:]) for j := i + 1; j < len(*ns); j++ { (*ns)[j] = nil } *ns = (*ns)[:i+1] return *ns, right } type node struct { keys Keys nodes nodes right *node lock sync.RWMutex isLeaf bool maxSeen Key } func (n *node) key() Key { return n.keys.last() } func (n *node) insert(key Key) Key { if !n.isLeaf { panic(`Can't only insert key in an internal node.`) } overwritten := n.keys.insert(key) return overwritten } func (n *node) insertNode(other *node) { key := other.key() i := n.keys.search(key) n.keys.insertAt(key, i) n.nodes.insertAt(other, i) } func (n *node) needsSplit() bool { return n.keys.needsSplit() } func (n *node) max() Key { if n.isLeaf { return n.keys.last() } return n.maxSeen } func (n *node) splitLeaf() (Key, *node, *node) { i := (len(n.keys) / 2) key := n.keys[i] _, rightKeys := n.keys.splitAt(i) nn := &node{ keys: rightKeys, right: n.right, isLeaf: true, } n.right = nn return key, n, nn } func (n *node) splitInternal() (Key, *node, *node) { i := (len(n.keys) / 2) key := n.keys[i] rightKeys := make(Keys, len(n.keys)-1-i, cap(n.keys)) rightNodes := make(nodes, len(rightKeys)+1, cap(n.nodes)) copy(rightKeys, n.keys[i+1:]) copy(rightNodes, n.nodes[i+1:]) // for garbage collection for j := i + 1; j < len(n.nodes); j++ { if j != len(n.keys) { n.keys[j] = nil } n.nodes[j] = nil } nn := newNode(false, rightKeys, 
rightNodes) nn.maxSeen = n.max() n.maxSeen = key n.keys = n.keys[:i] n.nodes = n.nodes[:i+1] n.right = nn return key, n, nn } func (n *node) split() (Key, *node, *node) { if n.isLeaf { return n.splitLeaf() } return n.splitInternal() } func (n *node) search(key Key) int { return n.keys.search(key) } func (n *node) searchNode(key Key) *node { i := n.search(key) return n.nodes[i] } func (n *node) print(output *log.Logger) { output.Printf(`NODE: %+v, %p`, n, n) if !n.isLeaf { for _, n := range n.nodes { if n == nil { output.Println(`NIL NODE`) continue } n.print(output) } } } func newNode(isLeaf bool, keys Keys, ns nodes) *node { return &node{ isLeaf: isLeaf, keys: keys, nodes: ns, } } ================================================ FILE: btree/_link/node_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package link import ( "testing" "github.com/stretchr/testify/assert" ) func newTestNode(isLeaf bool, ary int) *node { return &node{ isLeaf: isLeaf, keys: make(Keys, 0, ary), nodes: make(nodes, 0, ary+1), } } func checkTree(t testing.TB, tree *blink) bool { if tree.root == nil { return true } return checkNode(t, tree.root) } func checkNode(t testing.TB, n *node) bool { if len(n.keys) == 0 { assert.Len(t, n.nodes, 0) return false } if n.isLeaf { assert.Len(t, n.nodes, 0) return false } if !assert.Len(t, n.nodes, len(n.keys)+1) { return false } for i := 0; i < len(n.keys); i++ { if !assert.True(t, n.keys[i].Compare(n.nodes[i].key()) >= 0) { t.Logf(`N: %+v %p, n.keys[i]: %+v, n.nodes[i]: %+v`, n, n, n.keys[i], n.nodes[i]) return false } } if !assert.True(t, n.nodes[len(n.nodes)-1].key().Compare(n.keys.last()) > 0) { t.Logf(`m: %+v, %p, n.nodes[len(n.nodes)-1].key(): %+v, n.keys.last(): %+v`, n, n, n.nodes[len(n.nodes)-1].key(), n.keys.last()) return false } for _, child := range n.nodes { if !assert.NotNil(t, child) { return false } if !checkNode(t, child) { return false } } return true } func TestSplitInternalNodeOddAry(t *testing.T) { parent := newTestNode(false, 3) n1 := newTestNode(true, 3) n1.keys.insert(mockKey(1)) n2 := newTestNode(true, 3) n2.keys.insert(mockKey(5)) n3 := newTestNode(true, 3) n3.keys.insert(mockKey(10)) n4 := newTestNode(true, 3) n4.keys.insert(mockKey(15)) parent.nodes = nodes{n1, n2, n3, n4} parent.keys = Keys{mockKey(5), mockKey(10), mockKey(15)} key, l, r := parent.split() assert.Equal(t, mockKey(10), key) assert.Equal(t, Keys{mockKey(5)}, l.keys) assert.Equal(t, Keys{mockKey(15)}, r.keys) assert.Equal(t, nodes{n1, n2}, l.nodes) assert.Equal(t, nodes{n3, n4}, r.nodes) assert.Equal(t, l.right, r) assert.False(t, l.isLeaf) assert.False(t, r.isLeaf) } func TestSplitInternalNodeEvenAry(t *testing.T) { parent := newTestNode(false, 4) n1 := newTestNode(true, 4) n1.keys.insert(mockKey(1)) n2 := newTestNode(true, 4) n2.keys.insert(mockKey(5)) n3 
:= newTestNode(true, 4) n3.keys.insert(mockKey(10)) n4 := newTestNode(true, 4) n4.keys.insert(mockKey(15)) n5 := newTestNode(true, 4) n5.keys.insert(mockKey(20)) parent.nodes = nodes{n1, n2, n3, n4, n5} parent.keys = Keys{mockKey(5), mockKey(10), mockKey(15), mockKey(20)} key, l, r := parent.split() assert.Equal(t, mockKey(15), key) assert.Equal(t, Keys{mockKey(5), mockKey(10)}, l.keys) assert.Equal(t, Keys{mockKey(20)}, r.keys) assert.Equal(t, nodes{n1, n2, n3}, l.nodes) assert.Equal(t, nodes{n4, n5}, r.nodes) assert.Equal(t, l.right, r) assert.False(t, l.isLeaf) assert.False(t, r.isLeaf) } func TestSplitLeafNodeOddAry(t *testing.T) { parent := newTestNode(true, 3) k1 := mockKey(5) k2 := mockKey(15) k3 := mockKey(20) parent.keys = Keys{k1, k2, k3} key, l, r := parent.split() assert.Equal(t, k2, key) assert.Equal(t, Keys{k1, k2}, l.keys) assert.Equal(t, Keys{k3}, r.keys) assert.True(t, l.isLeaf) assert.True(t, r.isLeaf) assert.Equal(t, r, l.right) } func TestSplitLeafNodeEvenAry(t *testing.T) { parent := newTestNode(true, 4) k1 := mockKey(5) k2 := mockKey(15) k3 := mockKey(20) k4 := mockKey(25) parent.keys = Keys{k1, k2, k3, k4} key, l, r := parent.split() assert.Equal(t, k3, key) assert.Equal(t, Keys{k1, k2, k3}, l.keys) assert.Equal(t, Keys{k4}, r.keys) assert.True(t, l.isLeaf) assert.True(t, r.isLeaf) assert.Equal(t, r, l.right) } ================================================ FILE: btree/_link/tree.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ /* This is a b-link tree in progress from the following paper: http://www.csd.uoc.gr/~hy460/pdf/p650-lehman.pdf This is still a work in progress and the CRUD methods on the tree need to be parallelized. Until this is complete, there is no constructor method for this package. Time complexities: Space: O(n) Search: O(log n) Insert: O(log n) Delete: O(log n) Current benchmarks with 16 ary: BenchmarkSimpleAdd-8 1000000 1455 ns/op BenchmarkGet-8 2000000 704 ns/op B-link was chosen after examining this paper: http://www.vldb.org/journal/VLDBJ2/P361.pdf */ package link import ( "log" "sync" "sync/atomic" ) // numberOfItemsBeforeMultithread defines the number of items that have // to be called with a method before we multithread. const numberOfItemsBeforeMultithread = 10 type blink struct { root *node lock sync.RWMutex number, ary, numRoutines uint64 } func (blink *blink) insert(key Key, stack *nodes) Key { var parent *node blink.lock.Lock() if blink.root == nil { blink.root = newNode( true, make(Keys, 0, blink.ary), make(nodes, 0, blink.ary+1), ) blink.root.keys = make(Keys, 0, blink.ary) blink.root.isLeaf = true } parent = blink.root blink.lock.Unlock() result := insert(blink, parent, stack, key) if result == nil { atomic.AddUint64(&blink.number, 1) return nil } return result } func (blink *blink) multithreadedInsert(keys Keys) Keys { chunks := chunkKeys(keys, int64(blink.numRoutines)) overwritten := make(Keys, len(keys)) var offset uint64 var wg sync.WaitGroup wg.Add(len(chunks)) for _, chunk := range chunks { go func(chunk Keys, offset uint64) { defer wg.Done() stack := make(nodes, 0, blink.ary) for i := 0; i < len(chunk); i++ { result := blink.insert(chunk[i], &stack) stack.reset() overwritten[offset+uint64(i)] = result } }(chunk, offset) offset += uint64(len(chunk)) } wg.Wait() return overwritten } // Insert will insert the provided keys into the b-tree and return // 
a list of keys overwritten, if any. Each insert is an O(log n) // operation. func (blink *blink) Insert(keys ...Key) Keys { if len(keys) > numberOfItemsBeforeMultithread { return blink.multithreadedInsert(keys) } overwritten := make(Keys, 0, len(keys)) stack := make(nodes, 0, blink.ary) for _, k := range keys { overwritten = append(overwritten, blink.insert(k, &stack)) stack.reset() } return overwritten } // Len returns the number of items in this b-link tree. func (blink *blink) Len() uint64 { return atomic.LoadUint64(&blink.number) } func (blink *blink) get(key Key) Key { var parent *node blink.lock.RLock() parent = blink.root blink.lock.RUnlock() k := search(parent, key) if k == nil { return nil } if k.Compare(key) == 0 { return k } return nil } // Get will retrieve the keys if they exist in this tree. If not, // a nil is returned in the proper place in the list of keys. Each // lookup is O(log n) time complexity. func (blink *blink) Get(keys ...Key) Keys { found := make(Keys, 0, len(keys)) for _, k := range keys { found = append(found, blink.get(k)) } return found } func (blink *blink) print(output *log.Logger) { output.Println(`PRINTING B-LINK`) if blink.root == nil { return } blink.root.print(output) } func newTree(ary, numRoutines uint64) *blink { return &blink{ary: ary, numRoutines: numRoutines} } ================================================ FILE: btree/_link/tree_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ package link import ( "log" "math/rand" "os" "testing" "github.com/stretchr/testify/assert" ) func getConsoleLogger() *log.Logger { return log.New(os.Stderr, "", log.LstdFlags) } func generateRandomKeys(num int) Keys { keys := make(Keys, 0, num) for i := 0; i < num; i++ { keys = append(keys, mockKey(uint64(rand.Uint32()%uint32(100)))) } return keys } func generateKeys(num int) Keys { keys := make(Keys, 0, num) for i := 0; i < num; i++ { keys = append(keys, mockKey(uint64(i))) } return keys } func TestSimpleInsert(t *testing.T) { k1 := mockKey(5) tree := newTree(8, 1) result := tree.Insert(k1) assert.Equal(t, Keys{nil}, result) assert.Equal(t, uint64(1), tree.Len()) if !assert.Equal(t, Keys{k1}, tree.Get(k1)) { tree.print(getConsoleLogger()) } } func TestMultipleInsert(t *testing.T) { k1 := mockKey(10) k2 := mockKey(5) tree := newTree(8, 1) result := tree.Insert(k1, k2) assert.Equal(t, Keys{nil, nil}, result) assert.Equal(t, uint64(2), tree.Len()) assert.Equal(t, Keys{k1, k2}, tree.Get(k1, k2)) checkTree(t, tree) } func TestMultipleInsertCausesSplitOddAryReverseOrder(t *testing.T) { k1, k2, k3 := mockKey(15), mockKey(10), mockKey(5) tree := newTree(3, 1) result := tree.Insert(k1, k2, k3) assert.Equal(t, Keys{nil, nil, nil}, result) assert.Equal(t, uint64(3), tree.Len()) if !assert.Equal(t, Keys{k1, k2, k3}, tree.Get(k1, k2, k3)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleInsertCausesSplitOddAry(t *testing.T) { k1, k2, k3 := mockKey(5), mockKey(10), mockKey(15) tree := newTree(3, 1) result := tree.Insert(k1, k2, k3) assert.Equal(t, Keys{nil, nil, nil}, result) assert.Equal(t, uint64(3), tree.Len()) if !assert.Equal(t, Keys{k1, k2, k3}, tree.Get(k1, k2, k3)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleInsertCausesSplitOddAryRandomOrder(t *testing.T) { k1, k2, k3 := mockKey(10), mockKey(5), mockKey(15) tree := 
newTree(3, 1) result := tree.Insert(k1, k2, k3) assert.Equal(t, Keys{nil, nil, nil}, result) assert.Equal(t, uint64(3), tree.Len()) if !assert.Equal(t, Keys{k1, k2, k3}, tree.Get(k1, k2, k3)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleInsertCausesSplitEvenAryReverseOrder(t *testing.T) { k1, k2, k3, k4 := mockKey(20), mockKey(15), mockKey(10), mockKey(5) tree := newTree(4, 1) result := tree.Insert(k1, k2, k3, k4) assert.Equal(t, Keys{nil, nil, nil, nil}, result) assert.Equal(t, uint64(4), tree.Len()) if !assert.Equal(t, Keys{k1, k2, k3, k4}, tree.Get(k1, k2, k3, k4)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleInsertCausesSplitEvenAry(t *testing.T) { k1, k2, k3, k4 := mockKey(5), mockKey(10), mockKey(15), mockKey(20) tree := newTree(4, 1) result := tree.Insert(k1, k2, k3, k4) assert.Equal(t, Keys{nil, nil, nil, nil}, result) assert.Equal(t, uint64(4), tree.Len()) if !assert.Equal(t, Keys{k1, k2, k3, k4}, tree.Get(k1, k2, k3, k4)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleInsertCausesSplitEvenAryRandomOrder(t *testing.T) { k1, k2, k3, k4 := mockKey(10), mockKey(15), mockKey(20), mockKey(5) tree := newTree(4, 1) result := tree.Insert(k1, k2, k3, k4) assert.Equal(t, Keys{nil, nil, nil, nil}, result) assert.Equal(t, uint64(4), tree.Len()) if !assert.Equal(t, Keys{k1, k2, k3, k4}, tree.Get(k1, k2, k3, k4)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleInsertCausesSplitEvenAryMultiThreaded(t *testing.T) { keys := generateRandomKeys(16) tree := newTree(16, 8) result := tree.Insert(keys...) assert.Len(t, result, len(keys)) if !assert.Equal(t, keys, tree.Get(keys...)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleInsertCausesCascadingSplitsOddAry(t *testing.T) { keys := generateRandomKeys(1600) tree := newTree(9, 8) result := tree.Insert(keys...) 
assert.Len(t, result, len(keys)) // about all we can assert, random may produce duplicates checkTree(t, tree) if !assert.Equal(t, keys, tree.Get(keys...)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleInsertCausesCascadingSplitsOddAryReverseOrder(t *testing.T) { keys := generateKeys(30000) tree := newTree(17, 8) reversed := keys.reverse() result := tree.Insert(reversed...) assert.Len(t, result, len(keys)) // about all we can assert, random may produce duplicates if !assert.Equal(t, keys, tree.Get(keys...)) { //tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleInsertCausesCascadingSplitsEvenAry(t *testing.T) { keys := generateRandomKeys(200) tree := newTree(12, 8) result := tree.Insert(keys...) assert.Len(t, result, len(keys)) // about all we can assert, random may produce duplicates if !assert.Equal(t, keys, tree.Get(keys...)) { tree.print(getConsoleLogger()) } } func TestOverwriteOddAry(t *testing.T) { keys := generateRandomKeys(15) tree := newTree(3, 8) duplicate := mockKey(uint64(keys[0].(mockKey))) result := tree.Insert(keys...) assert.Len(t, result, len(keys)) oldLength := tree.Len() result = tree.Insert(duplicate) assert.Equal(t, Keys{keys[0]}, result) assert.Equal(t, oldLength, tree.Len()) } func TestOverwriteEvenAry(t *testing.T) { keys := generateRandomKeys(15) tree := newTree(12, 8) duplicate := mockKey(uint64(keys[0].(mockKey))) result := tree.Insert(keys...) assert.Len(t, result, len(keys)) oldLength := tree.Len() result = tree.Insert(duplicate) assert.Equal(t, Keys{keys[0]}, result) assert.Equal(t, oldLength, tree.Len()) } func BenchmarkSimpleAdd(b *testing.B) { numItems := 1000 keys := generateRandomKeys(numItems) tree := newTree(16, 8) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Insert(keys[i%numItems]) } } func BenchmarkGet(b *testing.B) { numItems := 1000 keys := generateRandomKeys(numItems) tree := newTree(16, 4) tree.Insert(keys...) 
b.ResetTimer() for i := 0; i < b.N; i++ { tree.Get(keys[i%numItems]) } } func BenchmarkBulkAdd(b *testing.B) { numItems := b.N keys := generateRandomKeys(numItems) b.ResetTimer() for i := 0; i < b.N; i++ { tree := newTree(64, 1) tree.Insert(keys...) } } ================================================ FILE: btree/immutable/add.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package btree import ( "runtime" "sort" "sync" terr "github.com/Workiva/go-datastructures/threadsafe/err" ) func (t *Tr) AddItems(its ...*Item) ([]*Item, error) { if len(its) == 0 { return nil, nil } keys := make(Keys, 0, len(its)) for _, item := range its { keys = append(keys, &Key{Value: item.Value, Payload: item.Payload}) } overwrittens, err := t.add(keys) if err != nil { return nil, err } return overwrittens.toItems(), nil } func (t *Tr) add(keys Keys) (Keys, error) { if t.Root == nil { n := t.createRoot() t.Root = n.ID t.context.addNode(n) } nodes, err := t.determinePaths(keys) if err != nil { return nil, err } var overwrittens Keys var wg sync.WaitGroup wg.Add(len(nodes)) var treeLock sync.Mutex localOverwrittens := make([]Keys, len(nodes)) tree := make(map[string]*path, runtime.NumCPU()) lerr := terr.New() i := 0 for id, bundles := range nodes { go func(i int, id string, bundles []*nodeBundle) { defer wg.Done() if len(bundles) == 0 { return } n, err := t.contextOrCachedNode(ID(id), true) if err != nil { lerr.Set(err) return } if 
!t.context.nodeExists(n.ID) { n = n.copy() t.context.addNode(n) } overwrittens, err := insertLastDimension(t, n, bundles) if err != nil { lerr.Set(err) return } localOverwrittens[i] = overwrittens path := bundles[0].path treeLock.Lock() tree[string(n.ID)] = path treeLock.Unlock() }(i, id, bundles) i++ } wg.Wait() if lerr.Get() != nil { return nil, lerr.Get() } t.walkupInsert(tree) for _, chunk := range localOverwrittens { overwrittens = append(overwrittens, chunk...) } t.Count += len(keys) - len(overwrittens) return overwrittens, nil } func (t *Tr) determinePaths(keys Keys) (map[string][]*nodeBundle, error) { chunks := splitKeys(keys, runtime.NumCPU()) var wg sync.WaitGroup wg.Add(len(chunks)) chunkPaths := make([]map[interface{}]*nodeBundle, len(chunks)) lerr := terr.New() for i := range chunks { go func(i int) { defer wg.Done() keys := chunks[i] if len(keys) == 0 { return } mp := make(map[interface{}]*nodeBundle, len(keys)) for _, key := range keys { path, err := t.iterativeFind( key.Value, t.Root, ) if err != nil { lerr.Set(err) return } mp[key.Value] = &nodeBundle{path: path, k: key} } chunkPaths[i] = mp }(i) } wg.Wait() if lerr.Get() != nil { return nil, lerr.Get() } nodes := make(map[string][]*nodeBundle, 10) for _, chunk := range chunkPaths { for _, pb := range chunk { nodes[string(pb.path.peek().n.ID)] = append(nodes[string(pb.path.pop().n.ID)], pb) } } return nodes, nil } func insertByMerge(comparator Comparator, n *Node, bundles []*nodeBundle) (Keys, error) { positions := make(map[interface{}]int, len(n.ChildValues)) overwrittens := make(Keys, 0, 10) for i, value := range n.ChildValues { positions[value] = i } for _, bundle := range bundles { if i, ok := positions[bundle.k.Value]; ok { overwrittens = append(overwrittens, n.ChildKeys[i]) n.ChildKeys[i] = bundle.k } else { n.ChildValues = append(n.ChildValues, bundle.k.Value) n.ChildKeys = append(n.ChildKeys, bundle.k) } } nsw := &nodeSortWrapper{ values: n.ChildValues, keys: n.ChildKeys, comparator: 
comparator, } sort.Sort(nsw) for i := 0; i < len(nsw.values); i++ { if nsw.values[i] != nil { nsw.values = nsw.values[i:] nsw.keys = nsw.keys[i:] break } nsw.keys[i] = nil } n.ChildValues = nsw.values n.ChildKeys = nsw.keys return overwrittens, nil } func insertLastDimension(t *Tr, n *Node, bundles []*nodeBundle) (Keys, error) { if n.IsLeaf && len(bundles) >= n.lenValues()/16 { // Found through empirical testing, it appears that the memmoves are more sensitive when dealing with interface{}'s. return insertByMerge(t.config.Comparator, n, bundles) } overwrittens := make(Keys, 0, len(bundles)) for _, bundle := range bundles { overwritten := n.insert(t.config.Comparator, bundle.k) if overwritten != nil { overwrittens = append(overwrittens, overwritten) } } return overwrittens, nil } func (t *Tr) iterativeSplit(n *Node) Keys { keys := make(Keys, 0, 10) for n.needsSplit(t.config.NodeWidth) { leftValue, leftNode := n.splitAt(t.config.NodeWidth / 2) t.context.addNode(leftNode) keys = append(keys, &Key{UUID: leftNode.ID, Value: leftValue}) } return keys } // walkupInsert walks up nodes during the insertion process and adds // any new keys due to splits. Each layer of the tree can have insertions // performed in parallel as splits are local changes. 
func (t *Tr) walkupInsert(nodes map[string]*path) error {
	// nodes maps the ID of a node modified on the current level to the
	// remaining path above it.  mapping remembers, per original parent ID,
	// the node instance that parent is represented by within this call.
	mapping := make(map[string]*Node, len(nodes))

	// one iteration per tree level, moving from the modified nodes upwards
	for len(nodes) > 0 {
		splitNodes := make(map[string]Keys)
		newNodes := make(map[string]*path)
		for id, path := range nodes {
			node := t.context.getNode(ID(id))

			parentPath := path.pop()
			if parentPath == nil {
				// nothing above this node: it becomes the root
				t.Root = node.ID
				continue
			}

			parent := parentPath.n
			// NOTE: within this loop body newNode is a local *Node that
			// shadows the package-level newNode function used further down.
			newNode := mapping[string(parent.ID)]
			if newNode == nil {
				if !t.context.nodeExists(parent.ID) {
					// the parent is not part of the context yet: work on
					// a copy and move the root reference along if needed
					cp := parent.copy()
					if string(t.Root) == string(parent.ID) {
						t.Root = cp.ID
					}

					t.context.addNode(cp)
					mapping[string(parent.ID)] = cp
					parent = cp
				} else {
					newNode = t.context.getNode(parent.ID)
					mapping[string(parent.ID)] = newNode
					parent = newNode
				}
			} else {
				parent = newNode
			}

			i := parentPath.i

			// point the parent at the current child, then queue the keys
			// produced by splitting the child for insertion into the parent
			parent.replaceKeyAt(&Key{UUID: node.ID}, i)
			splitNodes[string(parent.ID)] = append(splitNodes[string(parent.ID)], t.iterativeSplit(node)...)
			newNodes[string(parent.ID)] = path
		}

		var wg sync.WaitGroup
		wg.Add(len(splitNodes))
		lerr := terr.New()

		// each parent receives its queued keys on a separate goroutine
		for id, keys := range splitNodes {
			go func(id ID, keys Keys) {
				defer wg.Done()
				n, err := t.contextOrCachedNode(id, true)
				if err != nil {
					lerr.Set(err)
					return
				}
				for _, key := range keys {
					n.insert(t.config.Comparator, key)
				}
			}(ID(id), keys)
		}

		wg.Wait()

		if lerr.Get() != nil {
			return lerr.Get()
		}

		nodes = newNodes
	}

	// finally grow the tree upwards for as long as the root needs a split
	n := t.context.getNode(t.Root)
	for n.needsSplit(t.config.NodeWidth) {
		root := newNode()
		t.Root = root.ID
		t.context.addNode(root)
		root.appendChild(&Key{UUID: n.ID})
		keys := t.iterativeSplit(n)
		for _, key := range keys {
			root.insert(t.config.Comparator, key)
		}
		n = root
	}

	return nil
}

================================================
FILE: btree/immutable/cacher.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package btree import ( "sync" "time" "github.com/Workiva/go-datastructures/futures" ) // cacher provides a convenient construct for retrieving, // storing, and caching nodes; basically wrapper persister with a caching layer. // This ensures that we don't have to constantly // run to the persister to fetch nodes we are using over and over again. // TODO: this should probably evict items from the cache if the cache gets // too full. type cacher struct { lock sync.Mutex cache map[string]*futures.Future persister Persister } func (c *cacher) asyncLoadNode(t *Tr, key ID, completer chan interface{}) { n, err := c.loadNode(t, key) if err != nil { completer <- err return } if n == nil { completer <- ErrNodeNotFound return } completer <- n } // clear deletes all items from the cache. func (c *cacher) clear() { c.lock.Lock() defer c.lock.Unlock() c.cache = make(map[string]*futures.Future, 10) } // deleteFromCache will remove the provided ID from the cache. This // is a threadsafe operation. func (c *cacher) deleteFromCache(id ID) { c.lock.Lock() defer c.lock.Unlock() delete(c.cache, string(id)) } func (c *cacher) loadNode(t *Tr, key ID) (*Node, error) { items, err := c.persister.Load(key) if err != nil { return nil, err } n, err := nodeFromBytes(t, items[0].Payload) if err != nil { return nil, err } return n, nil } // getNode will return a Node matching the provided id. An error is returned // if the cacher could not go to the persister or the node could not be found. // All found nodes are cached so subsequent calls should be faster than // the initial. 
This blocks until the node is loaded, but is also threadsafe. func (c *cacher) getNode(t *Tr, key ID, useCache bool) (*Node, error) { if !useCache { return c.loadNode(t, key) } c.lock.Lock() future, ok := c.cache[string(key)] if ok { c.lock.Unlock() ifc, err := future.GetResult() if err != nil { return nil, err } return ifc.(*Node), nil } completer := make(chan interface{}, 1) future = futures.New(completer, 30*time.Second) c.cache[string(key)] = future c.lock.Unlock() go c.asyncLoadNode(t, key, completer) ifc, err := future.GetResult() if err != nil { c.deleteFromCache(key) return nil, err } if err, ok := ifc.(error); ok { c.deleteFromCache(key) return nil, err } return ifc.(*Node), nil } // newCacher is the constructor for a cacher that caches nodes for // an indefinite period of time. func newCacher(persister Persister) *cacher { return &cacher{ persister: persister, cache: make(map[string]*futures.Future, 10), } } ================================================ FILE: btree/immutable/config.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package btree // Config defines all the parameters available to the UB-tree. // Of most important are nodewidth and the persister to be used // during commit phase. type Config struct { // NodeWidth defines the branching factor of the tree. Any node // wider than this value will get split and when the width of a node // falls to less than half this value the node gets merged. 
This // ensures optimal performance while running to the key value store. NodeWidth int // Perister defines the key value store that the tree can use to // save and load nodes. Persister Persister // Comparator is the function used to determine ordering. Comparator Comparator `msg:"-"` } // DefaultConfig returns a configuration with the persister set. All other // fields are set to smart defaults for persistence. func DefaultConfig(persister Persister, comparator Comparator) Config { return Config{ NodeWidth: 10000, Persister: persister, Comparator: comparator, } } ================================================ FILE: btree/immutable/delete.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package btree import "bytes" func (t *Tr) DeleteItems(values ...interface{}) ([]*Item, error) { if len(values) == 0 { return nil, nil } keys := make(Keys, 0, len(values)) err := t.Apply(func(item *Item) { keys = append(keys, &Key{Value: item.Value, Payload: item.Payload}) }, values...) 
if err != nil { return nil, err } // we need to sort the keys to ensure that a multidelete // distributes deletes across a single node correctly keys = keys.sort(t.config.Comparator) err = t.delete(keys) if err != nil { return nil, err } t.Count -= len(keys) return keys.toItems(), nil } func (t *Tr) delete(keys Keys) error { if len(keys) == 0 { return nil } toDelete := make([]*Key, 0, len(keys)) for i := 0; i < len(keys); { key := keys[i] mapping := make(map[string]*Node, 10) path, err := t.iterativeFind(key.Value, t.Root) if err != nil { return err } pb := path.peek() node := pb.n isRoot := bytes.Compare(node.ID, t.Root) == 0 if !t.context.nodeExists(node.ID) { cp := node.copy() t.context.addNode(cp) mapping[string(node.ID)] = cp node = cp } base := node toDelete = append(toDelete, key) for j := i + 1; j <= len(keys); j++ { i = j if j == len(keys) { break } neighbor := keys[j] if t.config.Comparator(neighbor.Value, node.lastValue()) <= 0 { toDelete = append(toDelete, neighbor) } else { break } } if len(toDelete) > len(node.ChildValues)/4 { node.multiDelete(t.config.Comparator, toDelete...) } else { for _, k := range toDelete { node.delete(t.config.Comparator, k) } } toDelete = toDelete[:0] if isRoot { t.Root = node.ID continue } for pb.prev != nil { parentBundle := pb.prev parent := parentBundle.n isRoot := bytes.Compare(parent.ID, t.Root) == 0 if !t.context.nodeExists(parent.ID) { cp := parent.copy() t.context.addNode(cp) mapping[string(parent.ID)] = cp parent = cp } else { mapping[string(parent.ID)] = parent } if isRoot { t.Root = parent.ID } i := pb.prev.i parent.replaceKeyAt(&Key{UUID: node.ID}, i) node = parent pb = pb.prev } path.pop() err = t.walkupDelete(key, base, path, mapping) if err != nil { return err } } n := t.context.getNode(t.Root) if n.lenValues() == 0 { t.Root = nil } return nil } // walkupDelete is similar to walkupInsert but is only done one at a time. 
// This is because deletes can cause borrowing or merging with neighbors which makes
// the changes non-local.
// TODO: come up with a good way to parallelize this.
//
// key is the key whose removal triggered the walk, node is the node that
// may now be too small, path holds the remaining ancestors and mapping
// resolves an ancestor's original ID to the instance being mutated.
func (t *Tr) walkupDelete(key *Key, node *Node, path *path, mapping map[string]*Node) error {
	// minimum number of values a node may hold before it needs help
	needsMerged := t.config.NodeWidth / 2
	if needsMerged < 1 {
		needsMerged = 1
	}
	if node.lenValues() >= needsMerged {
		return nil
	}

	if string(node.ID) == string(t.Root) {
		// a root left with a single child is replaced by that child
		if node.lenKeys() == 1 {
			id := node.keyAt(0)
			t.Root = id.UUID
		}

		return nil
	}

	// getSibling loads the child at index i of parent, copying it into the
	// context (and re-pointing parent at the copy) if it is not there yet.
	var getSibling = func(parent *Node, i int) (*Node, error) {
		key := parent.keyAt(i)
		n, err := t.contextOrCachedNode(key.UUID, true)
		if err != nil {
			return nil, err
		}

		if !t.context.nodeExists(n.ID) {
			cp := t.copyNode(n)
			mapping[string(n.ID)] = cp
			parent.replaceKeyAt(&Key{UUID: cp.ID}, i)
			n = cp
		}

		return n, nil
	}

	parentBundle := path.pop()
	parent := mapping[string(parentBundle.n.ID)]

	// i is the slot of node inside parent; pick the right-hand neighbour
	// unless node is the last child, in which case use the left one
	_, i := parent.searchKey(t.config.Comparator, key.Value)
	siblingPosition := i
	if i == parent.lenValues() {
		siblingPosition--
	} else {
		siblingPosition++
	}

	sibling, err := getSibling(parent, siblingPosition)
	if err != nil {
		return err
	}

	prepend := false
	// things are just easier if we make this swap so we can grok
	// left to right always assuming node is on the left and sibling
	// is on the right
	if siblingPosition < i {
		node, sibling = sibling, node
		prepend = true
	}

	// first case, can we just borrow?  if so, simply shift values from one node
	// to the other.  Once done, replace the parent value with the middle value
	// shifted and return.
	if (sibling.lenValues()+node.lenValues())/2 >= needsMerged {
		if i == parent.lenValues() {
			i--
		}

		var key *Key
		var value interface{}
		for node.lenValues() < needsMerged || sibling.lenValues() < needsMerged {
			if prepend {
				// move the last entry of the left node to the front of
				// the right node (this key shadows the one declared above)
				correctedValue, key := node.popValue(), node.popKey()
				if node.IsLeaf {
					sibling.prependValue(correctedValue)
					sibling.prependKey(key)
					parent.replaceValueAt(i, node.lastValue())
				} else {
					// internal nodes rotate through the parent's separator
					parentValue := parent.valueAt(i)
					sibling.prependKey(key)
					sibling.prependValue(parentValue)
					parent.replaceValueAt(i, correctedValue)
				}
			} else {
				// move the first entry of the right node to the end of
				// the left node
				value, key = sibling.popFirstValue(), sibling.popFirstKey()
				correctedValue := value
				if !node.IsLeaf {
					correctedValue = parent.valueAt(i)
				}
				node.appendValue(correctedValue)
				node.appendChild(key)
				parent.replaceValueAt(i, value)
			}
		}

		return nil
	}

	// the harder case, we need to merge with sibling, pull a value down
	// from the parent, and recurse on this function

	// easier case, merge the nodes and delete value and child from parent
	if node.IsLeaf {
		node.append(sibling)
		if prepend {
			parent.deleteKeyAt(i)
		} else {
			parent.deleteKeyAt(i + 1)
		}

		if i == parent.lenValues() {
			i--
		}

		parent.deleteValueAt(i)
		// the parent lost an entry and may now be too small itself
		return t.walkupDelete(key, parent, path, mapping)
	}

	// harder case, need to pull a value down from the parent, insert
	// value into the left node, append the nodes, and then delete
	// the value from the parent

	valueIndex := i
	if i == parent.lenValues() {
		valueIndex--
	}

	parentValue := parent.valueAt(valueIndex)
	node.appendValue(parentValue)
	node.append(sibling)
	parent.deleteKeyAt(i)
	parent.deleteValueAt(valueIndex)
	parent.replaceKeyAt(&Key{UUID: node.ID}, valueIndex)
	return t.walkupDelete(key, parent, path, mapping)
}

================================================
FILE: btree/immutable/error.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package btree import "errors" // ErrNodeNotFound is returned when the cacher could not find a node. var ErrNodeNotFound = errors.New(`node not found`) // ErrTreeNotFound is returned when a tree with the provided key could // not be loaded. var ErrTreeNotFound = errors.New(`tree not found`) ================================================ FILE: btree/immutable/interface.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* Package btree provides a very specific set implementation for k/v lookup. This is based on a B PALM tree as described here: http://irvcvcs01.intel-research.net/publications/palm.pdf This tree is best interacted with in batches. Insertions and deletions are optimized for dealing with large amounts of data. Future work includes: 1) Optimization 2) Range scans Usage: rt := New(config) mutable := rt.AsMutable() ... operations rt, err := mutable.Commit() // saves all mutated nodes .. rt reading/operations Once a mutable has been committed, its further operations are undefined. 
*/

package btree

// Tree describes the common functionality of both the read-only and mutable
// forms of a btree.
type Tree interface {
	// Apply looks up each of the provided keys and calls fn with the
	// matching item.  If a key could not be found, it is skipped.
	// An error is returned if the tree could not be traversed.
	Apply(fn func(item *Item), keys ...interface{}) error
	// ID returns the identifier for this tree.
	ID() ID
	// Len returns the number of items in the tree.
	Len() int
}

// ReadableTree represents the operations that can be performed on a read-only
// version of the tree.  All reads of the readable tree are threadsafe and
// an indefinite number of mutable trees can be created from a single readable
// tree with the caveat that no mutable trees reflect any mutations to any other
// mutable tree.
type ReadableTree interface {
	Tree
	// AsMutable returns a mutable version of this tree.  The mutable version
	// has common mutations and you can create as many mutable versions of this
	// tree as you'd like.  However, the returned mutable is not threadsafe.
	AsMutable() MutableTree
}

// MutableTree represents a mutable version of the btree.  This interface
// is not threadsafe.
type MutableTree interface {
	Tree
	// Commit commits all mutated nodes to persistence and returns a
	// read-only version of this tree.  An error is returned if nodes
	// could not be committed to persistence.
	Commit() (ReadableTree, error)
	// AddItems adds the provided items to the btree.  Any existing items
	// are overwritten.  An error is returned if the tree could not be
	// traversed due to an error in the persistence layer.
	AddItems(items ...*Item) ([]*Item, error)
	// DeleteItems removes the items matching the provided keys and returns
	// the removed items.  An error is returned if the tree could not be
	// traversed.
	DeleteItems(keys ...interface{}) ([]*Item, error)
}

// Comparator is used to determine ordering in the tree.  If item1
// is less than item2, a negative number should be returned and
// vice versa.  If equal, 0 should be returned.
type Comparator func(item1, item2 interface{}) int // Payload is very basic and simply contains a key and a payload. type Payload struct { Key []byte Payload []byte } // Perister describes the interface of the different implementations. // Given that we expect that datastrutures are immutable, we never // have the need to delete. type Persister interface { Save(items ...*Payload) error Load(keys ...[]byte) ([]*Payload, error) } ================================================ FILE: btree/immutable/item.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package btree type Item struct { Value interface{} Payload []byte } type items []*Item func (its items) split(numParts int) []items { parts := make([]items, numParts) for i := int64(0); i < int64(numParts); i++ { parts[i] = its[i*int64(len(its))/int64(numParts) : (i+1)*int64(len(its))/int64(numParts)] } return parts } ================================================ FILE: btree/immutable/node.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ //go:generate msgp -tests=false -io=false package btree import ( "crypto/rand" "io" "sort" ) func newID() []byte { id := make([]byte, 16) _, err := io.ReadFull(rand.Reader, id) if err != nil { // This can't happen unless the system is badly // configured, like /dev/urandom isn't readable. panic("reading random: " + err.Error()) } return id } // ID exists because i'm tired of writing []byte type ID []byte // Key a convenience struct that holds both an id and a value. Internally, // this is how we reference items in nodes but consumers interface with // the tree using row/col/id. type Key struct { UUID ID `msg:"u"` Value interface{} `msg:"v"` Payload []byte `msg:"p"` } // ID returns the unique identifier. 
func (k Key) ID() []byte { return k.UUID[:16] // to maintain backwards compatibility } func (k Key) ToItem() *Item { return &Item{ Value: k.Value, Payload: k.Payload, } } type Keys []*Key func (k Keys) toItems() items { items := make(items, 0, len(k)) for _, key := range k { items = append(items, key.ToItem()) } return items } func (k Keys) sort(comparator Comparator) Keys { return (&keySortWrapper{comparator, k}).sort() } type keySortWrapper struct { comparator Comparator keys Keys } func (sw *keySortWrapper) Len() int { return len(sw.keys) } func (sw *keySortWrapper) Swap(i, j int) { sw.keys[i], sw.keys[j] = sw.keys[j], sw.keys[i] } func (sw *keySortWrapper) Less(i, j int) bool { return sw.comparator(sw.keys[i].Value, sw.keys[j].Value) < 0 } func (sw *keySortWrapper) sort() Keys { sort.Sort(sw) return sw.keys } func splitKeys(keys Keys, numParts int) []Keys { parts := make([]Keys, numParts) for i := int64(0); i < int64(numParts); i++ { parts[i] = keys[i*int64(len(keys))/int64(numParts) : (i+1)*int64(len(keys))/int64(numParts)] } return parts } // Node represents either a leaf node or an internal node. These // are the value containers. This is exported because code generation // requires it. Only exported fields are required to be persisted. We // use msgpack for optimal performance. type Node struct { // ID is the unique UUID that addresses this singular node. ID ID `msg:"id"` // IsLeaf is a bool indicating if this is a leaf node as opposed // to an internal node. The primary difference between these nodes // is that leaf nodes have an equal number of values and IDs while // internal nodes have n+1 ids. IsLeaf bool `msg:"il"` // ChildValues is only a temporary field that is used to house all // values for serialization purposes. ChildValues []interface{} `msg:"cv"` // ChildKeys is similar to child values but holds the IDs of children. ChildKeys Keys `msg:"ck"` } // copy makes a deep copy of this node. Required before any mutation. 
func (n *Node) copy() *Node { cpValues := make([]interface{}, len(n.ChildValues)) copy(cpValues, n.ChildValues) cpKeys := make(Keys, len(n.ChildKeys)) copy(cpKeys, n.ChildKeys) return &Node{ ID: newID(), IsLeaf: n.IsLeaf, ChildValues: cpValues, ChildKeys: cpKeys, } } // searchKey returns the key associated with the provided value. If the // provided value is greater than the highest value in this node and this // node is an internal node, this method returns the last ID and an index // equal to lenValues. func (n *Node) searchKey(comparator Comparator, value interface{}) (*Key, int) { i := n.search(comparator, value) if n.IsLeaf && i == len(n.ChildValues) { // not found return nil, i } if n.IsLeaf { // equal number of ids and values return n.ChildKeys[i], i } if i == len(n.ChildValues) { // we need to go to the farthest node to the write return n.ChildKeys[len(n.ChildKeys)-1], i } return n.ChildKeys[i], i } // insert adds the provided key to this node and returns any ID that has // been overwritten. This method should only be called on leaf nodes. func (n *Node) insert(comparator Comparator, key *Key) *Key { var overwrittenKey *Key i := n.search(comparator, key.Value) if i == len(n.ChildValues) { n.ChildValues = append(n.ChildValues, key.Value) } else { if n.ChildValues[i] == key.Value { overwrittenKey = n.ChildKeys[i] n.ChildKeys[i] = key return overwrittenKey } else { n.ChildValues = append(n.ChildValues, 0) copy(n.ChildValues[i+1:], n.ChildValues[i:]) n.ChildValues[i] = key.Value } } if n.IsLeaf && i == len(n.ChildKeys) { n.ChildKeys = append(n.ChildKeys, key) } else { n.ChildKeys = append(n.ChildKeys, nil) copy(n.ChildKeys[i+1:], n.ChildKeys[i:]) n.ChildKeys[i] = key } return overwrittenKey } // delete removes the provided key from the node and returns any key that // was deleted. Returns nil of the key could not be found. 
func (n *Node) delete(comparator Comparator, key *Key) *Key { i := n.search(comparator, key.Value) if i == len(n.ChildValues) { return nil } n.deleteValueAt(i) n.deleteKeyAt(i) return key } func (n *Node) multiDelete(comparator Comparator, keys ...*Key) { indices := make([]int, 0, len(keys)) for _, k := range keys { i := n.search(comparator, k.Value) if i < len(n.ChildValues) { indices = append(indices, i) } } for _, i := range indices { n.ChildValues[i] = nil n.ChildKeys[i] = nil } if len(indices) == len(n.ChildValues) { n.ChildKeys = n.ChildKeys[:0] n.ChildValues = n.ChildValues[:0] return } // get the indices in the correct order for the next stage // which is removing the nils sort.Ints(indices) // iterate through the list moving all values up to overwrite the // nils and place all nils at the "back" for i, j := range indices { index := j - i // correct for previous copies copy(n.ChildValues[index:], n.ChildValues[index+1:]) copy(n.ChildKeys[index:], n.ChildKeys[index+1:]) } n.ChildValues = n.ChildValues[:len(n.ChildValues)-len(indices)] n.ChildKeys = n.ChildKeys[:len(n.ChildKeys)-len(indices)] } // replaceKeyAt replaces the key at index i with the provided id. This does // not do any bounds checking. func (n *Node) replaceKeyAt(key *Key, i int) { n.ChildKeys[i] = key } // flatten returns a flattened list of values and IDs. Useful for serialization. func (n *Node) flatten() ([]interface{}, Keys) { return n.ChildValues, n.ChildKeys } // iter returns an iterator that will iterate through the provided Morton // numbers as they exist in this node. 
func (n *Node) iter(comparator Comparator, start, stop interface{}) iterator { pointer := n.search(comparator, start) pointer-- return &sliceIterator{ stop: stop, n: n, pointer: pointer, comparator: comparator, } } func (n *Node) valueAt(i int) interface{} { return n.ChildValues[i] } func (n *Node) keyAt(i int) *Key { return n.ChildKeys[i] } func (n *Node) needsSplit(max int) bool { return n.lenValues() > max } func (n *Node) lastValue() interface{} { return n.ChildValues[len(n.ChildValues)-1] } func (n *Node) firstValue() interface{} { return n.ChildValues[0] } func (n *Node) append(other *Node) { n.ChildValues = append(n.ChildValues, other.ChildValues...) n.ChildKeys = append(n.ChildKeys, other.ChildKeys...) } func (n *Node) replaceValueAt(i int, value interface{}) { n.ChildValues[i] = value } func (n *Node) deleteValueAt(i int) { copy(n.ChildValues[i:], n.ChildValues[i+1:]) n.ChildValues[len(n.ChildValues)-1] = 0 // or the zero value of T n.ChildValues = n.ChildValues[:len(n.ChildValues)-1] } func (n *Node) deleteKeyAt(i int) { copy(n.ChildKeys[i:], n.ChildKeys[i+1:]) n.ChildKeys[len(n.ChildKeys)-1] = nil // or the zero value of T n.ChildKeys = n.ChildKeys[:len(n.ChildKeys)-1] } func (n *Node) splitLeafAt(i int) (interface{}, *Node) { left := newNode() left.IsLeaf = n.IsLeaf left.ID = newID() value := n.ChildValues[i] leftValues := make([]interface{}, i+1) copy(leftValues, n.ChildValues[:i+1]) n.ChildValues = n.ChildValues[i+1:] leftKeys := make(Keys, i+1) copy(leftKeys, n.ChildKeys[:i+1]) for j := 0; j <= i; j++ { n.ChildKeys[j] = nil } n.ChildKeys = n.ChildKeys[i+1:] left.ChildValues = leftValues left.ChildKeys = leftKeys return value, left } // splitInternalAt is a method that generates a new set of children // for an internal node and returns the new set and the value that // separates them. 
func (n *Node) splitInternalAt(i int) (interface{}, *Node) { left := newNode() left.IsLeaf = n.IsLeaf left.ID = newID() value := n.ChildValues[i] leftValues := make([]interface{}, i) copy(leftValues, n.ChildValues[:i]) n.ChildValues = n.ChildValues[i+1:] leftKeys := make(Keys, i+1) copy(leftKeys, n.ChildKeys[:i+1]) for j := 0; j <= i; j++ { n.ChildKeys[j] = nil } n.ChildKeys = n.ChildKeys[i+1:] left.ChildKeys = leftKeys left.ChildValues = leftValues return value, left } // splitAt breaks this node into two parts and conceptually // returns the left part func (n *Node) splitAt(i int) (interface{}, *Node) { if n.IsLeaf { return n.splitLeafAt(i) } return n.splitInternalAt(i) } func (n *Node) lenKeys() int { return len(n.ChildKeys) } func (n *Node) lenValues() int { return len(n.ChildValues) } func (n *Node) appendChild(key *Key) { n.ChildKeys = append(n.ChildKeys, key) } func (n *Node) appendValue(value interface{}) { n.ChildValues = append(n.ChildValues, value) } func (n *Node) popFirstKey() *Key { key := n.ChildKeys[0] n.deleteKeyAt(0) return key } func (n *Node) popFirstValue() interface{} { value := n.ChildValues[0] n.deleteValueAt(0) return value } func (n *Node) popKey() *Key { key := n.ChildKeys[len(n.ChildKeys)-1] n.deleteKeyAt(len(n.ChildKeys) - 1) return key } func (n *Node) popValue() interface{} { value := n.ChildValues[len(n.ChildValues)-1] n.deleteValueAt(len(n.ChildValues) - 1) return value } func (n *Node) prependKey(key *Key) { n.ChildKeys = append(n.ChildKeys, nil) copy(n.ChildKeys[1:], n.ChildKeys) n.ChildKeys[0] = key } func (n *Node) prependValue(value interface{}) { n.ChildValues = append(n.ChildValues, nil) copy(n.ChildValues[1:], n.ChildValues) n.ChildValues[0] = value } func (n *Node) search(comparator Comparator, value interface{}) int { return sort.Search(len(n.ChildValues), func(i int) bool { return comparator(n.ChildValues[i], value) >= 0 }) } // nodeFromBytes returns a new node struct deserialized from the provided // bytes. 
An error is returned for any deserialization errors. func nodeFromBytes(t *Tr, data []byte) (*Node, error) { n := &Node{} _, err := n.UnmarshalMsg(data) if err != nil { panic(err) return nil, err } return n, nil } // newNode returns a node with a random id and empty values and children. // IsLeaf is false by default. func newNode() *Node { return &Node{ ID: newID(), } } type sliceIterator struct { stop interface{} n *Node pointer int comparator Comparator } func (s *sliceIterator) next() bool { s.pointer++ if s.n.IsLeaf { return s.pointer < len(s.n.ChildValues) && s.comparator(s.stop, s.n.ChildValues[s.pointer]) >= 0 } else { if s.pointer >= len(s.n.ChildKeys) { return false } if s.pointer == len(s.n.ChildValues) { return true } if s.comparator(s.stop, s.n.ChildValues[s.pointer]) < 0 { return false } } return true } func (s *sliceIterator) value() (*Key, int) { return s.n.ChildKeys[s.pointer], s.pointer } type iterator interface { next() bool value() (*Key, int) } type nodeBundle struct { path *path k *Key } type nodeSortWrapper struct { values []interface{} keys Keys comparator Comparator } func (n *nodeSortWrapper) Len() int { return len(n.values) } func (n *nodeSortWrapper) Swap(i, j int) { n.values[i], n.values[j] = n.values[j], n.values[i] n.keys[i], n.keys[j] = n.keys[j], n.keys[i] } func (n *nodeSortWrapper) Less(i, j int) bool { return n.comparator(n.values[i], n.values[j]) < 0 } func splitValues(values []interface{}, numParts int) [][]interface{} { parts := make([][]interface{}, numParts) for i := int64(0); i < int64(numParts); i++ { parts[i] = values[i*int64(len(values))/int64(numParts) : (i+1)*int64(len(values))/int64(numParts)] } return parts } ================================================ FILE: btree/immutable/node_gen.go ================================================ package btree // NOTE: THIS FILE WAS PRODUCED BY THE // MSGP CODE GENERATION TOOL (github.com/tinylib/msgp) // DO NOT EDIT import ( "github.com/tinylib/msgp/msgp" ) // MarshalMsg 
implements msgp.Marshaler func (z ID) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) o = msgp.AppendBytes(o, []byte(z)) return } // UnmarshalMsg implements msgp.Unmarshaler func (z *ID) UnmarshalMsg(bts []byte) (o []byte, err error) { { var tmp []byte tmp, bts, err = msgp.ReadBytesBytes(bts, []byte((*z))) (*z) = ID(tmp) } if err != nil { return } o = bts return } func (z ID) Msgsize() (s int) { s = msgp.BytesPrefixSize + len([]byte(z)) return } // MarshalMsg implements msgp.Marshaler func (z *Key) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) // map header, size 3 // string "u" o = append(o, 0x83, 0xa1, 0x75) o = msgp.AppendBytes(o, []byte(z.UUID)) // string "v" o = append(o, 0xa1, 0x76) o, err = msgp.AppendIntf(o, z.Value) if err != nil { return } // string "p" o = append(o, 0xa1, 0x70) o = msgp.AppendBytes(o, z.Payload) return } // UnmarshalMsg implements msgp.Unmarshaler func (z *Key) UnmarshalMsg(bts []byte) (o []byte, err error) { var field []byte _ = field var isz uint32 isz, bts, err = msgp.ReadMapHeaderBytes(bts) if err != nil { return } for isz > 0 { isz-- field, bts, err = msgp.ReadMapKeyZC(bts) if err != nil { return } switch msgp.UnsafeString(field) { case "u": { var tmp []byte tmp, bts, err = msgp.ReadBytesBytes(bts, []byte(z.UUID)) z.UUID = ID(tmp) } if err != nil { return } case "v": z.Value, bts, err = msgp.ReadIntfBytes(bts) if err != nil { return } case "p": z.Payload, bts, err = msgp.ReadBytesBytes(bts, z.Payload) if err != nil { return } default: bts, err = msgp.Skip(bts) if err != nil { return } } } o = bts return } func (z *Key) Msgsize() (s int) { s = 1 + 2 + msgp.BytesPrefixSize + len([]byte(z.UUID)) + 2 + msgp.GuessSize(z.Value) + 2 + msgp.BytesPrefixSize + len(z.Payload) return } // MarshalMsg implements msgp.Marshaler func (z Keys) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) o = msgp.AppendArrayHeader(o, uint32(len(z))) for xvk := range z { if 
z[xvk] == nil { o = msgp.AppendNil(o) } else { o, err = z[xvk].MarshalMsg(o) if err != nil { return } } } return } // UnmarshalMsg implements msgp.Unmarshaler func (z *Keys) UnmarshalMsg(bts []byte) (o []byte, err error) { var xsz uint32 xsz, bts, err = msgp.ReadArrayHeaderBytes(bts) if err != nil { return } if cap((*z)) >= int(xsz) { (*z) = (*z)[:xsz] } else { (*z) = make(Keys, xsz) } for bzg := range *z { if msgp.IsNil(bts) { bts, err = msgp.ReadNilBytes(bts) if err != nil { return } (*z)[bzg] = nil } else { if (*z)[bzg] == nil { (*z)[bzg] = new(Key) } bts, err = (*z)[bzg].UnmarshalMsg(bts) if err != nil { return } } } o = bts return } func (z Keys) Msgsize() (s int) { s = msgp.ArrayHeaderSize for bai := range z { if z[bai] == nil { s += msgp.NilSize } else { s += z[bai].Msgsize() } } return } // MarshalMsg implements msgp.Marshaler func (z *Node) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) // map header, size 4 // string "id" o = append(o, 0x84, 0xa2, 0x69, 0x64) o = msgp.AppendBytes(o, []byte(z.ID)) // string "il" o = append(o, 0xa2, 0x69, 0x6c) o = msgp.AppendBool(o, z.IsLeaf) // string "cv" o = append(o, 0xa2, 0x63, 0x76) o = msgp.AppendArrayHeader(o, uint32(len(z.ChildValues))) for cmr := range z.ChildValues { o, err = msgp.AppendIntf(o, z.ChildValues[cmr]) if err != nil { return } } // string "ck" o = append(o, 0xa2, 0x63, 0x6b) o = msgp.AppendArrayHeader(o, uint32(len(z.ChildKeys))) for ajw := range z.ChildKeys { if z.ChildKeys[ajw] == nil { o = msgp.AppendNil(o) } else { o, err = z.ChildKeys[ajw].MarshalMsg(o) if err != nil { return } } } return } // UnmarshalMsg implements msgp.Unmarshaler func (z *Node) UnmarshalMsg(bts []byte) (o []byte, err error) { var field []byte _ = field var isz uint32 isz, bts, err = msgp.ReadMapHeaderBytes(bts) if err != nil { return } for isz > 0 { isz-- field, bts, err = msgp.ReadMapKeyZC(bts) if err != nil { return } switch msgp.UnsafeString(field) { case "id": { var tmp []byte tmp, bts, err 
= msgp.ReadBytesBytes(bts, []byte(z.ID)) z.ID = ID(tmp) } if err != nil { return } case "il": z.IsLeaf, bts, err = msgp.ReadBoolBytes(bts) if err != nil { return } case "cv": var xsz uint32 xsz, bts, err = msgp.ReadArrayHeaderBytes(bts) if err != nil { return } if cap(z.ChildValues) >= int(xsz) { z.ChildValues = z.ChildValues[:xsz] } else { z.ChildValues = make([]interface{}, xsz) } for cmr := range z.ChildValues { z.ChildValues[cmr], bts, err = msgp.ReadIntfBytes(bts) if err != nil { return } } case "ck": var xsz uint32 xsz, bts, err = msgp.ReadArrayHeaderBytes(bts) if err != nil { return } if cap(z.ChildKeys) >= int(xsz) { z.ChildKeys = z.ChildKeys[:xsz] } else { z.ChildKeys = make(Keys, xsz) } for ajw := range z.ChildKeys { if msgp.IsNil(bts) { bts, err = msgp.ReadNilBytes(bts) if err != nil { return } z.ChildKeys[ajw] = nil } else { if z.ChildKeys[ajw] == nil { z.ChildKeys[ajw] = new(Key) } bts, err = z.ChildKeys[ajw].UnmarshalMsg(bts) if err != nil { return } } } default: bts, err = msgp.Skip(bts) if err != nil { return } } } o = bts return } func (z *Node) Msgsize() (s int) { s = 1 + 3 + msgp.BytesPrefixSize + len([]byte(z.ID)) + 3 + msgp.BoolSize + 3 + msgp.ArrayHeaderSize for cmr := range z.ChildValues { s += msgp.GuessSize(z.ChildValues[cmr]) } s += 3 + msgp.ArrayHeaderSize for ajw := range z.ChildKeys { if z.ChildKeys[ajw] == nil { s += msgp.NilSize } else { s += z.ChildKeys[ajw].Msgsize() } } return } ================================================ FILE: btree/immutable/path.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package btree /* This file contains logic pertaining to keeping track of the path followed to find a particular node while descending the tree. */ type pathBundle struct { // i defines the child index of the n. i int n *Node prev *pathBundle } // path is simply a linked list of pathBundles. We only ever // go in one direction and there's no need to search so a linked list // makes sense. type path struct { head *pathBundle tail *pathBundle } func (p *path) append(pb *pathBundle) { if p.head == nil { p.head = pb p.tail = pb return } pb.prev = p.tail p.tail = pb } // pop removes the last item from the path. Note that it also nils // out the returned pathBundle's prev field. Returns nil if no items // remain. func (p *path) pop() *pathBundle { if pb := p.tail; pb != nil { p.tail = pb.prev pb.prev = nil return pb } return nil } func (p *path) peek() *pathBundle { return p.tail } ================================================ FILE: btree/immutable/query.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

package btree

import (
	"runtime"
	"sync"

	terr "github.com/Workiva/go-datastructures/threadsafe/err"
)

// Apply looks up every provided key and calls fn with the matching item.
// Lookups are spread over one goroutine per CPU; fn itself is called from
// the calling goroutine after all lookups have finished, in the order the
// keys were passed in.  Keys that are not in the tree are skipped.  If any
// lookup fails, an error is returned and fn is not called at all.
//
// The keys are used as map keys to remember their position, so they must
// be hashable, and a key passed more than once is reported only once.
func (t *Tr) Apply(fn func(item *Item), keys ...interface{}) error {
	if t.Root == nil || len(keys) == 0 {
		return nil
	}

	// remember where each key was in the argument list so results can be
	// stored in input order regardless of which goroutine finds them
	positions := make(map[interface{}]int, len(keys))
	for i, key := range keys {
		positions[key] = i
	}

	chunks := splitValues(keys, runtime.NumCPU())
	var wg sync.WaitGroup
	wg.Add(len(chunks))
	lerr := terr.New()
	result := make(Keys, len(keys))

	for i := 0; i < len(chunks); i++ {
		go func(i int) {
			defer wg.Done()

			chunk := chunks[i]
			if len(chunk) == 0 {
				return
			}

			for _, value := range chunk {
				n, _, err := t.iterativeFindWithoutPath(value, t.Root)
				if err != nil {
					lerr.Set(err)
					return
				}

				// no leaf can contain this value
				if n == nil {
					continue
				}

				// searchKey lands on the first value >= the target, so
				// confirm it is an exact match before recording it
				k, _ := n.searchKey(t.config.Comparator, value)
				if k != nil && t.config.Comparator(k.Value, value) == 0 {
					result[positions[value]] = k
				}
			}
		}(i)
	}

	wg.Wait()

	if lerr.Get() != nil {
		return lerr.Get()
	}

	// slots of keys that were not found are still nil
	for _, k := range result {
		if k == nil {
			continue
		}

		item := k.ToItem()
		fn(item)
	}

	return nil
}

// filter performs an after fetch filtering of the values in the provided node.
// Due to the nature of the UB-Tree, we may get results in the node that
// aren't in the provided range.  The returned list of keys is not necessarily
// in the correct row-major order.
func (t *Tr) filter(start, stop interface{}, n *Node, fn func(key *Key) bool) bool { for iter := n.iter(t.config.Comparator, start, stop); iter.next(); { id, _ := iter.value() if !fn(id) { return false } } return true } func (t *Tr) iter(start, stop interface{}, fn func(*Key) bool) error { if len(t.Root) == 0 { return nil } cur := start seen := make(map[string]struct{}, 10) for t.config.Comparator(stop, cur) > 0 { n, highestValue, err := t.iterativeFindWithoutPath(cur, t.Root) if err != nil { return err } if n == nil && highestValue == nil { break } else if n != nil { if _, ok := seen[string(n.ID)]; ok { break } if !t.filter(cur, stop, n, fn) { break } } cur = n.lastValue() seen[string(n.ID)] = struct{}{} } return nil } // iterativeFind searches for the node with the provided value. This // is an iterative function and returns an error if there was a problem // with persistence. func (t *Tr) iterativeFind(value interface{}, id ID) (*path, error) { if len(id) == 0 { // can't find a matching node return nil, nil } path := &path{} var n *Node var err error var i int var key *Key for { n, err = t.contextOrCachedNode(id, t.mutable) if err != nil { return nil, err } key, i = n.searchKey(t.config.Comparator, value) pb := &pathBundle{i: i, n: n} path.append(pb) if n.IsLeaf { return path, nil } id = key.ID() } return path, nil } func (t *Tr) iterativeFindWithoutPath(value interface{}, id ID) (*Node, interface{}, error) { var n *Node var err error var i int var key *Key var highestValue interface{} for { n, err = t.contextOrCachedNode(id, t.mutable) if err != nil { return nil, highestValue, err } if n.IsLeaf { if t.config.Comparator(n.lastValue(), value) < 0 { return nil, highestValue, nil } highestValue = n.lastValue() return n, highestValue, nil } key, i = n.searchKey(t.config.Comparator, value) if i < n.lenValues() { highestValue = n.valueAt(i) } id = key.ID() } return n, highestValue, nil } ================================================ FILE: btree/immutable/rt.go 
================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ //go:generate msgp -tests=false -io=false package btree import "sync" // context is used to keep track of the nodes in this mutable // that have been created. This is basically any node that had // to be touched to perform mutations. Further mutations will visit // this context first so we don't have to constantly copy if // we don't need to. type context struct { lock sync.RWMutex seenNodes map[string]*Node } func (c *context) nodeExists(id ID) bool { c.lock.RLock() defer c.lock.RUnlock() _, ok := c.seenNodes[string(id)] return ok } func (c *context) addNode(n *Node) { c.lock.Lock() defer c.lock.Unlock() c.seenNodes[string(n.ID)] = n } func (c *context) getNode(id ID) *Node { c.lock.RLock() defer c.lock.RUnlock() return c.seenNodes[string(id)] } func newContext() *context { return &context{ seenNodes: make(map[string]*Node, 10), } } // Tr itself is exported so that the code generated for serialization/deserialization // works on Tr. Exported fields on Tr are those fields that need to be // serialized. type Tr struct { UUID ID `msg:"u"` Count int `msg:"c"` config Config Root ID `msg:"r"` cacher *cacher context *context NodeWidth int `msg:"nw"` mutable bool } func (t *Tr) createRoot() *Node { n := newNode() n.IsLeaf = true return n } // contextOrCachedNode is a convenience function for either fetching // a node from the context or persistence. 
func (t *Tr) contextOrCachedNode(id ID, cache bool) (*Node, error) { if t.context != nil { n := t.context.getNode(id) if n != nil { return n, nil } } return t.cacher.getNode(t, id, cache) } func (t *Tr) ID() ID { return t.UUID } // toBytes encodes this tree into a byte array. Panics if unable // as this error has to be fixed in code. func (t *Tr) toBytes() []byte { buf, err := t.MarshalMsg(nil) if err != nil { panic(`unable to encode tree`) } return buf } // reset is called on a tree to empty the context and clear the cache. func (t *Tr) reset() { t.cacher.clear() t.context = nil } // commit will gather up all created nodes and serialize them into // items that can be persisted. func (t *Tr) commit() []*Payload { items := make([]*Payload, 0, len(t.context.seenNodes)) for _, n := range t.context.seenNodes { n.ChildValues, n.ChildKeys = n.flatten() buf, err := n.MarshalMsg(nil) if err != nil { panic(`unable to encode node`) } n.ChildValues, n.ChildKeys = nil, nil item := &Payload{n.ID, buf} items = append(items, item) } return items } func (t *Tr) copyNode(n *Node) *Node { if t.context.nodeExists(n.ID) { return n } cp := n.copy() t.context.addNode(cp) return cp } func (t *Tr) Len() int { return t.Count } func (t *Tr) AsMutable() MutableTree { return &Tr{ Count: t.Count, UUID: newID(), Root: t.Root, config: t.config, cacher: t.cacher, context: newContext(), NodeWidth: t.NodeWidth, mutable: true, } } func (t *Tr) Commit() (ReadableTree, error) { t.NodeWidth = t.config.NodeWidth items := make([]*Payload, 0, len(t.context.seenNodes)) items = append(items, t.commit()...) // save self items = append(items, &Payload{t.ID(), t.toBytes()}) err := t.config.Persister.Save(items...) 
if err != nil { return nil, err } t.reset() t.context = nil return t, nil } func treeFromBytes(p Persister, data []byte, comparator Comparator) (*Tr, error) { t := &Tr{} _, err := t.UnmarshalMsg(data) if err != nil { return nil, err } cfg := DefaultConfig(p, comparator) if t.NodeWidth > 0 { cfg.NodeWidth = t.NodeWidth } t.config = cfg t.cacher = newCacher(cfg.Persister) return t, nil } func newTree(cfg Config) *Tr { return &Tr{ config: cfg, UUID: newID(), cacher: newCacher(cfg.Persister), } } // New creates a new ReadableTree using the provided config. func New(cfg Config) ReadableTree { return newTree(cfg) } // Load returns a ReadableTree from persistence. The provided // config should contain a persister that can be used for this purpose. // An error is returned if the tree could not be found or an error // occurred in the persistence layer. func Load(p Persister, id []byte, comparator Comparator) (ReadableTree, error) { items, err := p.Load(id) if err != nil { return nil, err } if len(items) == 0 || items[0] == nil { return nil, ErrTreeNotFound } rt, err := treeFromBytes(p, items[0].Payload, comparator) if err != nil { return nil, err } return rt, nil } ================================================ FILE: btree/immutable/rt_gen.go ================================================ package btree // NOTE: THIS FILE WAS PRODUCED BY THE // MSGP CODE GENERATION TOOL (github.com/tinylib/msgp) // DO NOT EDIT import ( "github.com/tinylib/msgp/msgp" ) // MarshalMsg implements msgp.Marshaler func (z *Tr) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) // map header, size 4 // string "u" o = append(o, 0x84, 0xa1, 0x75) o, err = z.UUID.MarshalMsg(o) if err != nil { return } // string "c" o = append(o, 0xa1, 0x63) o = msgp.AppendInt(o, z.Count) // string "r" o = append(o, 0xa1, 0x72) o, err = z.Root.MarshalMsg(o) if err != nil { return } // string "nw" o = append(o, 0xa2, 0x6e, 0x77) o = msgp.AppendInt(o, z.NodeWidth) return } // UnmarshalMsg 
implements msgp.Unmarshaler func (z *Tr) UnmarshalMsg(bts []byte) (o []byte, err error) { var field []byte _ = field var isz uint32 isz, bts, err = msgp.ReadMapHeaderBytes(bts) if err != nil { return } for isz > 0 { isz-- field, bts, err = msgp.ReadMapKeyZC(bts) if err != nil { return } switch msgp.UnsafeString(field) { case "u": bts, err = z.UUID.UnmarshalMsg(bts) if err != nil { return } case "c": z.Count, bts, err = msgp.ReadIntBytes(bts) if err != nil { return } case "r": bts, err = z.Root.UnmarshalMsg(bts) if err != nil { return } case "nw": z.NodeWidth, bts, err = msgp.ReadIntBytes(bts) if err != nil { return } default: bts, err = msgp.Skip(bts) if err != nil { return } } } o = bts return } func (z *Tr) Msgsize() (s int) { s = 1 + 2 + z.UUID.Msgsize() + 2 + msgp.IntSize + 2 + z.Root.Msgsize() + 3 + msgp.IntSize return } ================================================ FILE: btree/immutable/rt_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package btree import ( "log" "math/rand" "sort" "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type ephemeral struct { mp map[string]*Payload lock sync.RWMutex } func (e *ephemeral) Save(items ...*Payload) error { e.lock.Lock() defer e.lock.Unlock() if len(items) == 0 { return nil } for _, item := range items { e.mp[string(item.Key)] = item } return nil } func (e *ephemeral) Load(keys ...[]byte) ([]*Payload, error) { e.lock.RLock() defer e.lock.RUnlock() if len(keys) == 0 { return nil, nil } items := make([]*Payload, 0, len(keys)) for _, k := range keys { items = append(items, e.mp[string(k)]) } return items, nil } const ( maxValue = int64(100000) ) func init() { rand.Seed(time.Now().Unix()) } type valueSortWrapper struct { comparator Comparator values []interface{} } func (v *valueSortWrapper) Len() int { return len(v.values) } func (v *valueSortWrapper) Swap(i, j int) { v.values[i], v.values[j] = v.values[j], v.values[i] } func (v *valueSortWrapper) Less(i, j int) bool { return v.comparator(v.values[i], v.values[j]) < 0 } func (v *valueSortWrapper) sort() { sort.Sort(v) } func reverse(items items) items { for i := 0; i < len(items)/2; i++ { items[i], items[len(items)-1-i] = items[len(items)-1-i], items[i] } return items } var comparator = func(item1, item2 interface{}) int { int1, int2 := item1.(int64), item2.(int64) if int1 < int2 { return -1 } if int1 > int2 { return 1 } return 0 } // orderedItems is going to contain our "master" copy of items in // sorted order. Because the operations on a flat list are well // understood, we can use this type to do generative type testing and // confirm the results. 
type orderedItems []*Item func (o orderedItems) Len() int { return len(o) } func (o orderedItems) Swap(i, j int) { o[i], o[j] = o[j], o[i] } func (o orderedItems) Less(i, j int) bool { return comparator(o[i].Value, o[j].Value) < 0 } func (o orderedItems) equal(item1, item2 *Item) bool { return comparator(item1.Value, item2.Value) == 0 } func (o orderedItems) copy() orderedItems { cp := make(orderedItems, len(o)) copy(cp, o) return cp } func (o orderedItems) search(value interface{}) int { return sort.Search(len(o), func(i int) bool { return comparator(o[i].Value, value) >= 0 }) } func (o orderedItems) add(item *Item) orderedItems { cp := make(orderedItems, len(o)) copy(cp, o) i := cp.search(item.Value) if i < len(o) && o.equal(o[i], item) { cp[i] = item return cp } if i == len(cp) { cp = append(cp, item) return cp } cp = append(cp, nil) copy(cp[i+1:], cp[i:]) cp[i] = item return cp } func (o orderedItems) delete(item *Item) orderedItems { i := o.search(item.Value) if i == len(o) { return o } if !o.equal(o[i], item) { return o } cp := make(orderedItems, len(o)) copy(cp, o) copy(cp[i:], cp[i+1:]) cp[len(cp)-1] = nil // or the zero value of T cp = cp[:len(cp)-1] return cp } func (o orderedItems) toItems() items { cp := make(items, 0, len(o)) for _, item := range o { cp = append(cp, item) } return cp } func (o orderedItems) query(start, stop interface{}) items { items := make(items, 0, len(o)) for i := o.search(start); i < len(o); i++ { if comparator(o[i], stop) > 0 { break } items = append(items, o[i]) } return items } func generateRandomQuery() (interface{}, interface{}) { start := int64(rand.Intn(int(maxValue))) offset := int64(rand.Intn(100)) return start, start + offset } func newItem(value interface{}) *Item { return &Item{ Value: value, Payload: newID(), } } func newEphemeral() Persister { return &ephemeral{ mp: make(map[string]*Payload), } } type delayedPersister struct { Persister } func (d *delayedPersister) Load(keys ...[]byte) ([]*Payload, error) { 
time.Sleep(5 * time.Millisecond) return d.Persister.Load(keys...) } func newDelayed() Persister { return &delayedPersister{newEphemeral()} } func defaultConfig() Config { return Config{ NodeWidth: 10, // easy number to test with Persister: newEphemeral(), Comparator: comparator, } } func generateRandomItem() *Item { return newItem(int64(rand.Intn(int(maxValue)))) } // generateRandomItems will generate a list of random items with // no duplicates. func generateRandomItems(num int) items { items := make(items, 0, num) mp := make(map[interface{}]struct{}, num) for len(items) < num { c := generateRandomItem() if _, ok := mp[c.Value]; ok { continue } mp[c.Value] = struct{}{} items = append(items, c) } return items } // generateLinearItems is similar to random item generation except that // items are returned in sorted order. func generateLinearItems(num int) items { items := make(items, 0, num) for i := 0; i < num; i++ { c := newItem(int64(i)) items = append(items, c) } return items } func toOrdered(items items) orderedItems { oc := make(orderedItems, 0, len(items)) for _, item := range items { oc = oc.add(item) } return oc } // the following 3 methods are in the _test file as they are only used // in a testing environment. func (t *Tr) toList(values ...interface{}) (items, error) { items := make(items, 0, t.Count) err := t.Apply(func(item *Item) { items = append(items, item) }, values...) 
return items, err } func (t *Tr) pprint(id ID) { n, _ := t.contextOrCachedNode(id, true) if n == nil { log.Printf(`NODE: %+v`, n) return } log.Printf(`NODE: %+v, LEN(ids): %+v, LEN(values): %+v`, n, n.lenKeys(), n.lenValues()) for i, key := range n.ChildKeys { child, _ := t.contextOrCachedNode(key.ID(), true) if child == nil { continue } log.Printf(`CHILD %d: %+v`, i, child) } for _, key := range n.ChildKeys { child, _ := t.contextOrCachedNode(key.ID(), true) if child == nil { continue } t.pprint(key.ID()) } } func (t *Tr) verify(id ID, tb testing.TB) (interface{}, interface{}) { n, err := t.contextOrCachedNode(id, true) require.NoError(tb, err) cp := n.copy() // copy the values and sort them, ensure node values are sorted cpValues := cp.ChildValues (&valueSortWrapper{comparator: comparator, values: cpValues}).sort() assert.Equal(tb, cpValues, n.ChildValues) if !assert.False(tb, n.needsSplit(t.config.NodeWidth)) { tb.Logf(`NODE NEEDS SPLIT: NODE: %+v`, n) } if string(t.Root) != string(n.ID) { assert.True(tb, n.lenValues() >= t.config.NodeWidth/2) } if n.IsLeaf { assert.Equal(tb, n.lenValues(), n.lenKeys()) // assert lens are equal return n.firstValue(), n.lastValue() // return last value } else { for _, key := range n.ChildKeys { assert.Empty(tb, key.Payload) } } for i, key := range n.ChildKeys { min, max := t.verify(key.ID(), tb) if i == 0 { assert.True(tb, comparator(max, n.valueAt(i)) <= 0) } else if i == n.lenValues() { assert.True(tb, comparator(min, n.lastValue()) > 0) } else { assert.True(tb, comparator(max, n.valueAt(i)) <= 0) assert.True(tb, comparator(min, n.valueAt(i-1)) > 0) } } return n.firstValue(), n.lastValue() } func itemsToValues(items ...*Item) []interface{} { values := make([]interface{}, 0, len(items)) for _, item := range items { values = append(values, item.Value) } return values } func TestNodeSplit(t *testing.T) { number := 100 items := generateLinearItems(number) cfg := defaultConfig() rt := New(cfg) mutable := rt.AsMutable() _, err := 
mutable.AddItems(items...) require.NoError(t, err) assert.Equal(t, number, mutable.Len()) mutable.(*Tr).verify(mutable.(*Tr).Root, t) result, err := mutable.(*Tr).toList(itemsToValues(items[5:10]...)...) require.NoError(t, err) if !assert.Equal(t, items[5:10], result) { mutable.(*Tr).pprint(mutable.(*Tr).Root) for i, c := range items[5:10] { t.Logf(`EXPECTED: %+v, RESULT: %+v`, c, result[i]) } t.FailNow() } mutable = rt.AsMutable() for _, c := range items { _, err := mutable.AddItems(c) require.NoError(t, err) } result, err = mutable.(*Tr).toList(itemsToValues(items...)...) require.NoError(t, err) assert.Equal(t, items, result) mutable.(*Tr).verify(mutable.(*Tr).Root, t) rt, err = mutable.Commit() require.NoError(t, err) rt, err = Load(cfg.Persister, rt.ID(), comparator) result, err = mutable.(*Tr).toList(itemsToValues(items...)...) require.NoError(t, err) assert.Equal(t, items, result) rt.(*Tr).verify(rt.(*Tr).Root, t) } func TestReverseNodeSplit(t *testing.T) { number := 400 items := generateLinearItems(number) reversed := make([]*Item, len(items)) copy(reversed, items) reversed = reverse(reversed) rt := New(defaultConfig()) mutable := rt.AsMutable() _, err := mutable.AddItems(reversed...) require.NoError(t, err) result, err := mutable.(*Tr).toList(itemsToValues(items...)...) require.NoError(t, err) if !assert.Equal(t, items, result) { for _, c := range result { t.Logf(`RESULT: %+v`, c) } } mutable = rt.AsMutable() for _, c := range reversed { _, err := mutable.AddItems(c) require.NoError(t, err) } result, err = mutable.(*Tr).toList(itemsToValues(items...)...) 
require.NoError(t, err) assert.Equal(t, items, result) mutable.(*Tr).verify(mutable.(*Tr).Root, t) } func TestDuplicate(t *testing.T) { item1 := newItem(int64(1)) item2 := newItem(int64(1)) rt := New(defaultConfig()) mutable := rt.AsMutable() _, err := mutable.AddItems(item1) require.NoError(t, err) _, err = mutable.AddItems(item2) require.NoError(t, err) assert.Equal(t, 1, mutable.Len()) result, err := mutable.(*Tr).toList(int64(1)) require.NoError(t, err) assert.Equal(t, items{item2}, result) mutable.(*Tr).verify(mutable.(*Tr).Root, t) } func TestCommit(t *testing.T) { items := generateRandomItems(5) rt := New(defaultConfig()) mutable := rt.AsMutable() _, err := mutable.AddItems(items...) require.Nil(t, err) rt, err = mutable.Commit() require.NoError(t, err) expected := toOrdered(items).toItems() result, err := rt.(*Tr).toList(itemsToValues(expected...)...) require.NoError(t, err) if !assert.Equal(t, expected, result) { require.Equal(t, len(expected), len(result)) for i, c := range expected { if !assert.Equal(t, c, result[i]) { t.Logf(`EXPECTED: %+v, RESULT: %+v`, c, result[i]) } } } rt.(*Tr).verify(rt.(*Tr).Root, t) } func TestRandom(t *testing.T) { items := generateRandomItems(1000) rt := New(defaultConfig()) mutable := rt.AsMutable() _, err := mutable.AddItems(items...) require.Nil(t, err) require.NoError(t, err) expected := toOrdered(items).toItems() result, err := mutable.(*Tr).toList(itemsToValues(expected...)...) if !assert.Equal(t, expected, result) { assert.Equal(t, len(expected), len(result)) for i, c := range expected { assert.Equal(t, c, result[i]) } } mutable.(*Tr).verify(mutable.(*Tr).Root, t) } func TestLoad(t *testing.T) { cfg := defaultConfig() rt := New(cfg) mutable := rt.AsMutable() items := generateRandomItems(1000) _, err := mutable.AddItems(items...) 
require.NoError(t, err) id := mutable.ID() _, err = mutable.Commit() require.NoError(t, err) rt, err = Load(cfg.Persister, id, comparator) require.NoError(t, err) sort.Sort(orderedItems(items)) result, err := rt.(*Tr).toList(itemsToValues(items...)...) require.NoError(t, err) assert.Equal(t, items, result) rt.(*Tr).verify(rt.(*Tr).Root, t) } func TestDeleteFromRoot(t *testing.T) { number := 5 cfg := defaultConfig() rt := New(cfg) mutable := rt.AsMutable() items := generateLinearItems(number) mutable.AddItems(items...) mutable.DeleteItems(items[0].Value, items[1].Value, items[2].Value) result, err := mutable.(*Tr).toList(itemsToValues(items...)...) require.Nil(t, err) assert.Equal(t, items[3:], result) assert.Equal(t, 2, mutable.Len()) mutable.(*Tr).verify(mutable.(*Tr).Root, t) } func TestDeleteAllFromRoot(t *testing.T) { num := 5 cfg := defaultConfig() rt := New(cfg) mutable := rt.AsMutable() items := generateLinearItems(num) mutable.AddItems(items...) mutable.DeleteItems(itemsToValues(items...)...) result, err := mutable.(*Tr).toList(itemsToValues(items...)...) require.Nil(t, err) assert.Empty(t, result) assert.Equal(t, 0, mutable.Len()) } func TestDeleteAfterSplitIncreasing(t *testing.T) { num := 11 cfg := defaultConfig() rt := New(cfg) mutable := rt.AsMutable() items := generateLinearItems(num) mutable.AddItems(items...) for i := 0; i < num-1; i++ { mutable.DeleteItems(itemsToValues(items[i])...) result, err := mutable.(*Tr).toList(itemsToValues(items...)...) require.Nil(t, err) assert.Equal(t, items[i+1:], result) mutable.(*Tr).verify(mutable.(*Tr).Root, t) } } func TestDeleteMultipleLevelsRandomlyBulk(t *testing.T) { num := 200 cfg := defaultConfig() rt := New(cfg) mutable := rt.AsMutable() items := generateRandomItems(num) mutable.AddItems(items...) mutable.DeleteItems(itemsToValues(items[:100]...)...) result, _ := mutable.(*Tr).toList(itemsToValues(items...)...) 
assert.Len(t, result, 100) } func TestDeleteAfterSplitDecreasing(t *testing.T) { num := 11 cfg := defaultConfig() rt := New(cfg) mutable := rt.AsMutable() items := generateLinearItems(num) mutable.AddItems(items...) for i := num - 1; i >= 0; i-- { mutable.DeleteItems(itemsToValues(items[i])...) result, err := mutable.(*Tr).toList(itemsToValues(items...)...) require.Nil(t, err) assert.Equal(t, items[:i], result) if i > 0 { mutable.(*Tr).verify(mutable.(*Tr).Root, t) } } } func TestDeleteMultipleLevels(t *testing.T) { num := 20 cfg := defaultConfig() rt := New(cfg) mutable := rt.AsMutable() items := generateRandomItems(num) mutable.AddItems(items...) ordered := toOrdered(items) for i, c := range ordered { _, err := mutable.DeleteItems(c.Value) require.NoError(t, err) result, err := mutable.(*Tr).toList(itemsToValues(ordered...)...) require.NoError(t, err) if !assert.Equal(t, ordered[i+1:].toItems(), result) { log.Printf(`LEN EXPECTED: %+v, RESULT: %+v`, len(ordered[i+1:]), len(result)) mutable.(*Tr).pprint(mutable.(*Tr).Root) assert.Equal(t, len(ordered[i+1:]), len(result)) for i, c := range ordered[i+1:] { log.Printf(`EXPECTED: %+v`, c) if i < len(result) { log.Printf(`RECEIVED: %+v`, result[i]) } } break } if len(ordered[i+1:]) > 0 { mutable.(*Tr).verify(mutable.(*Tr).Root, t) } } assert.Nil(t, mutable.(*Tr).Root) } func TestDeleteMultipleLevelsRandomly(t *testing.T) { num := 200 cfg := defaultConfig() rt := New(cfg) mutable := rt.AsMutable() items := generateRandomItems(num) mutable.AddItems(items...) ordered := toOrdered(items) for _, c := range items { _, err := mutable.DeleteItems(c.Value) require.NoError(t, err) ordered = ordered.delete(c) result, err := mutable.(*Tr).toList(itemsToValues(ordered...)...) 
require.NoError(t, err) assert.Equal(t, ordered.toItems(), result) if len(ordered) > 0 { mutable.(*Tr).verify(mutable.(*Tr).Root, t) } } assert.Nil(t, mutable.(*Tr).Root) } func TestDeleteMultipleLevelsWithCommit(t *testing.T) { num := 20 cfg := defaultConfig() rt := New(cfg) mutable := rt.AsMutable() items := generateRandomItems(num) mutable.AddItems(items...) rt, _ = mutable.Commit() rt, _ = Load(cfg.Persister, rt.ID(), comparator) result, err := rt.(*Tr).toList(itemsToValues(items...)...) require.NoError(t, err) assert.Equal(t, items, result) mutable = rt.AsMutable() for _, c := range items[:10] { _, err := mutable.DeleteItems(c.Value) require.Nil(t, err) } result, err = mutable.(*Tr).toList(itemsToValues(items[10:]...)...) require.Nil(t, err) assert.Equal(t, items[10:], result) mutable.(*Tr).verify(mutable.(*Tr).Root, t) result, err = rt.(*Tr).toList(itemsToValues(items...)...) require.NoError(t, err) assert.Equal(t, items, result) rt.(*Tr).verify(rt.(*Tr).Root, t) } func TestCommitAfterDelete(t *testing.T) { num := 15 cfg := defaultConfig() rt := New(cfg) mutable := rt.AsMutable() items := generateRandomItems(num) mutable.AddItems(items...) for _, c := range items[:5] { mutable.DeleteItems(c.Value) mutable.(*Tr).verify(mutable.(*Tr).Root, t) } rt, err := mutable.Commit() require.Nil(t, err) result, err := rt.(*Tr).toList(itemsToValues(items...)...) require.Nil(t, err) assert.Equal(t, items[5:], result) rt.(*Tr).verify(rt.(*Tr).Root, t) } func TestSecondCommitSplitsRoot(t *testing.T) { number := 15 cfg := defaultConfig() rt := New(cfg) items := generateLinearItems(number) mutable := rt.AsMutable() mutable.AddItems(items[:10]...) mutable.(*Tr).verify(mutable.(*Tr).Root, t) rt, _ = mutable.Commit() rt.(*Tr).verify(rt.(*Tr).Root, t) mutable = rt.AsMutable() mutable.AddItems(items[10:]...) mutable.(*Tr).verify(mutable.(*Tr).Root, t) result, err := mutable.(*Tr).toList(itemsToValues(items...)...) 
require.Nil(t, err) if !assert.Equal(t, items, result) { for i, c := range items { log.Printf(`EXPECTED: %+v, RECEIVED: %+v`, c, result[i]) } } } func TestSecondCommitMultipleSplits(t *testing.T) { num := 50 cfg := defaultConfig() rt := New(cfg) items := generateRandomItems(num) mutable := rt.AsMutable() mutable.AddItems(items[:25]...) mutable.(*Tr).verify(mutable.(*Tr).Root, t) rt, err := mutable.Commit() rt.(*Tr).verify(rt.(*Tr).Root, t) result, err := rt.(*Tr).toList(itemsToValues(items...)...) require.Nil(t, err) assert.Equal(t, items[:25], result) mutable = rt.AsMutable() mutable.AddItems(items[25:]...) mutable.(*Tr).verify(mutable.(*Tr).Root, t) sort.Sort(orderedItems(items)) result, err = mutable.(*Tr).toList(itemsToValues(items...)...) require.Nil(t, err) if !assert.Equal(t, items, result) { mutable.(*Tr).pprint(mutable.(*Tr).Root) } } func TestLargeAdd(t *testing.T) { cfg := defaultConfig() number := cfg.NodeWidth * 5 rt := New(cfg) items := generateLinearItems(number) mutable := rt.AsMutable() _, err := mutable.AddItems(items...) require.NoError(t, err) id := mutable.ID() result, err := mutable.(*Tr).toList(itemsToValues(items...)...) require.NoError(t, err) assert.Equal(t, items, result) _, err = mutable.Commit() require.NoError(t, err) rt, err = Load(cfg.Persister, id, comparator) require.NoError(t, err) result, err = rt.(*Tr).toList(itemsToValues(items...)...) require.NoError(t, err) assert.Equal(t, items, result) } func TestNodeInfiniteLoop(t *testing.T) { cfg := defaultConfig() rt := New(cfg) items := generateLinearItems(3) mutable := rt.AsMutable() _, err := mutable.AddItems(items...) require.NoError(t, err) result, err := mutable.DeleteItems(items[1].Value, items[2].Value) require.NoError(t, err) assert.Len(t, result, 2) } // all remaining tests are generative in nature to catch things // I can't think of. 
func TestGenerativeAdds(t *testing.T) { if testing.Short() { t.Skipf(`skipping generative add`) return } number := 100 cfg := defaultConfig() rt := New(cfg) oc := make(orderedItems, 0) for i := 0; i < number; i++ { num := int(rand.Int31n(100)) if num == 0 { num++ } items := generateRandomItems(num) mutated := oc.copy() for _, c := range items { mutated = mutated.add(c) } mutable := rt.AsMutable() _, err := mutable.AddItems(items...) require.Nil(t, err) mutable.(*Tr).verify(mutable.(*Tr).Root, t) rtMutated, err := mutable.Commit() require.Nil(t, err) rtMutated.(*Tr).verify(rtMutated.(*Tr).Root, t) result, err := rtMutated.(*Tr).toList(itemsToValues(mutated.toItems()...)...) require.Nil(t, err) if !assert.Equal(t, mutated.toItems(), result) { rtMutated.(*Tr).pprint(rtMutated.(*Tr).Root) if len(mutated) == len(result) { for i, c := range mutated.toItems() { log.Printf(`EXPECTED: %+v, RECEIVED: %+v`, c, result[i]) } } } assert.Equal(t, len(mutated), rtMutated.Len()) result, err = rt.(*Tr).toList(itemsToValues(oc.toItems()...)...) require.Nil(t, err) assert.Equal(t, oc.toItems(), result) oc = mutated rt = rtMutated } } func TestGenerativeDeletes(t *testing.T) { if testing.Short() { t.Skipf(`skipping generative delete`) return } number := 100 var err error cfg := defaultConfig() rt := New(cfg) oc := toOrdered(generateRandomItems(1000)) mutable := rt.AsMutable() mutable.AddItems(oc.toItems()...) mutable.(*Tr).verify(mutable.(*Tr).Root, t) rt, err = mutable.Commit() require.NoError(t, err) for i := 0; i < number; i++ { mutable = rt.AsMutable() index := rand.Intn(len(oc)) c := oc[index] mutated := oc.delete(c) result, err := rt.(*Tr).toList(itemsToValues(oc.toItems()...)...) require.NoError(t, err) assert.Equal(t, oc.toItems(), result) assert.Equal(t, len(oc), rt.Len()) _, err = mutable.DeleteItems(c.Value) require.NoError(t, err) mutable.(*Tr).verify(mutable.(*Tr).Root, t) result, err = mutable.(*Tr).toList(itemsToValues(mutated.toItems()...)...) 
require.NoError(t, err) assert.Equal(t, len(mutated), len(result)) require.Equal(t, mutated.toItems(), result) oc = mutated rt, err = mutable.Commit() require.NoError(t, err) } } func TestGenerativeOperations(t *testing.T) { if testing.Short() { t.Skipf(`skipping generative operations`) return } number := 100 cfg := defaultConfig() rt := New(cfg) // seed the tree items := generateRandomItems(1000) oc := toOrdered(items) mutable := rt.AsMutable() mutable.AddItems(items...) result, err := mutable.(*Tr).toList(itemsToValues(oc.toItems()...)...) require.NoError(t, err) require.Equal(t, oc.toItems(), result) rt, err = mutable.Commit() require.NoError(t, err) for i := 0; i < number; i++ { mutable = rt.AsMutable() if rand.Float64() < .5 && len(oc) > 0 { c := oc[rand.Intn(len(oc))] oc = oc.delete(c) _, err = mutable.DeleteItems(c.Value) require.NoError(t, err) mutable.(*Tr).verify(mutable.(*Tr).Root, t) result, err := mutable.(*Tr).toList(itemsToValues(oc.toItems()...)...) require.NoError(t, err) require.Equal(t, oc.toItems(), result) assert.Equal(t, len(oc), mutable.Len()) } else { c := generateRandomItem() oc = oc.add(c) _, err = mutable.AddItems(c) require.NoError(t, err) mutable.(*Tr).verify(mutable.(*Tr).Root, t) result, err = mutable.(*Tr).toList(itemsToValues(oc.toItems()...)...) require.NoError(t, err) require.Equal(t, oc.toItems(), result) assert.Equal(t, len(oc), mutable.Len()) } rt, err = mutable.Commit() require.NoError(t, err) } } func BenchmarkGetitems(b *testing.B) { number := 100 cfg := defaultConfig() cfg.Persister = newDelayed() rt := New(cfg) items := generateRandomItems(number) mutable := rt.AsMutable() _, err := mutable.AddItems(items...) require.NoError(b, err) rt, err = mutable.Commit() require.NoError(b, err) id := rt.ID() b.ResetTimer() for i := 0; i < b.N; i++ { rt, err = Load(cfg.Persister, id, comparator) require.NoError(b, err) _, err = rt.(*Tr).toList(itemsToValues(items...)...) 
require.NoError(b, err) } } func BenchmarkBulkAdd(b *testing.B) { number := 1000000 items := generateLinearItems(number) b.ResetTimer() for i := 0; i < b.N; i++ { tr := New(defaultConfig()) mutable := tr.AsMutable() mutable.AddItems(items...) } } ================================================ FILE: btree/palm/action.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package palm import ( "runtime" "sync" "sync/atomic" "github.com/Workiva/go-datastructures/common" ) type actions []action type action interface { operation() operation keys() common.Comparators complete() addNode(int64, *node) nodes() []*node } type getAction struct { result common.Comparators completer *sync.WaitGroup } func (ga *getAction) complete() { ga.completer.Done() } func (ga *getAction) operation() operation { return get } func (ga *getAction) keys() common.Comparators { return ga.result } func (ga *getAction) addNode(i int64, n *node) { return // not necessary for gets } func (ga *getAction) nodes() []*node { return nil } func newGetAction(keys common.Comparators) *getAction { result := make(common.Comparators, len(keys)) copy(result, keys) // don't want to mutate passed in keys ga := &getAction{ result: result, completer: new(sync.WaitGroup), } ga.completer.Add(1) return ga } type insertAction struct { result common.Comparators completer *sync.WaitGroup ns []*node } func (ia *insertAction) complete() { ia.completer.Done() } func (ia 
*insertAction) operation() operation { return add } func (ia *insertAction) keys() common.Comparators { return ia.result } func (ia *insertAction) addNode(i int64, n *node) { ia.ns[i] = n } func (ia *insertAction) nodes() []*node { return ia.ns } func newInsertAction(keys common.Comparators) *insertAction { result := make(common.Comparators, len(keys)) copy(result, keys) ia := &insertAction{ result: result, completer: new(sync.WaitGroup), ns: make([]*node, len(keys)), } ia.completer.Add(1) return ia } type removeAction struct { *insertAction } func (ra *removeAction) operation() operation { return remove } func newRemoveAction(keys common.Comparators) *removeAction { return &removeAction{ newInsertAction(keys), } } type applyAction struct { fn func(common.Comparator) bool start, stop common.Comparator completer *sync.WaitGroup } func (aa *applyAction) operation() operation { return apply } func (aa *applyAction) nodes() []*node { return nil } func (aa *applyAction) addNode(i int64, n *node) {} func (aa *applyAction) keys() common.Comparators { return nil } func (aa *applyAction) complete() { aa.completer.Done() } func newApplyAction(fn func(common.Comparator) bool, start, stop common.Comparator) *applyAction { aa := &applyAction{ fn: fn, start: start, stop: stop, completer: new(sync.WaitGroup), } aa.completer.Add(1) return aa } func minUint64(choices ...uint64) uint64 { min := choices[0] for i := 1; i < len(choices); i++ { if choices[i] < min { min = choices[i] } } return min } type interfaces []interface{} func executeInterfacesInParallel(ifs interfaces, fn func(interface{})) { if len(ifs) == 0 { return } done := int64(-1) numCPU := uint64(runtime.NumCPU()) if numCPU > 1 { numCPU-- } numCPU = minUint64(numCPU, uint64(len(ifs))) var wg sync.WaitGroup wg.Add(int(numCPU)) for i := uint64(0); i < numCPU; i++ { go func() { defer wg.Done() for { i := atomic.AddInt64(&done, 1) if i >= int64(len(ifs)) { return } fn(ifs[i]) } }() } wg.Wait() } func 
executeInterfacesInSerial(ifs interfaces, fn func(interface{})) { if len(ifs) == 0 { return } for _, ifc := range ifs { fn(ifc) } } ================================================ FILE: btree/palm/interface.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* Package palm implements parallel architecture-friendly latch-free modifications (PALM). Details can be found here: http://cs.unc.edu/~sewall/palm.pdf The primary purpose of the tree is to efficiently batch operations in such a way that locks are not required. This is most beneficial for in-memory indices. Otherwise, the operations have typical B-tree time complexities. You primarily see the benefits of multithreading in availability and bulk operations. Benchmarks: BenchmarkReadAndWrites-8 3000 483140 ns/op BenchmarkSimultaneousReadsAndWrites-8 300 4418123 ns/op BenchmarkBulkAdd-8 300 5569750 ns/op BenchmarkAdd-8 500000 2478 ns/op BenchmarkBulkAddToExisting-8 100 20552674 ns/op BenchmarkGet-8 2000000 629 ns/op BenchmarkBulkGet-8 5000 223249 ns/op BenchmarkDelete-8 500000 2421 ns/op BenchmarkBulkDelete-8 500 2790461 ns/op BenchmarkFindQuery-8 1000000 1166 ns/op BenchmarkExecuteQuery-8 10000 1290732 ns/op */ package palm import "github.com/Workiva/go-datastructures/common" // BTree is the interface returned from this package's constructor. type BTree interface { // Insert will insert the provided keys into the tree. 
Insert(...common.Comparator) // Delete will remove the provided keys from the tree. If no // matching key is found, this is a no-op. Delete(...common.Comparator) // Get will return a key matching the associated provided // key if it exists. Get(...common.Comparator) common.Comparators // Len returns the number of items in the tree. Len() uint64 // Query will return a list of Comparators that fall within the // provided start and stop Comparators. Start is inclusive while // stop is exclusive, ie [start, stop). Query(start, stop common.Comparator) common.Comparators // Dispose will clean up any resources used by this tree. This // must be called to prevent a memory leak. Dispose() } ================================================ FILE: btree/palm/key.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package palm import "github.com/Workiva/go-datastructures/common" func reverseKeys(cmps common.Comparators) common.Comparators { reversed := make(common.Comparators, len(cmps)) for i := len(cmps) - 1; i >= 0; i-- { reversed[len(cmps)-1-i] = cmps[i] } return reversed } func chunkKeys(keys common.Comparators, numParts int64) []common.Comparators { parts := make([]common.Comparators, numParts) for i := int64(0); i < numParts; i++ { parts[i] = keys[i*int64(len(keys))/numParts : (i+1)*int64(len(keys))/numParts] } return parts } ================================================ FILE: btree/palm/mock_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package palm import "github.com/Workiva/go-datastructures/common" type mockKey int func (mk mockKey) Compare(other common.Comparator) int { otherKey := other.(mockKey) if mk == otherKey { return 0 } if mk > otherKey { return 1 } return -1 } ================================================ FILE: btree/palm/node.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package palm import ( "log" "sort" "github.com/Workiva/go-datastructures/common" ) func getParent(parent *node, key common.Comparator) *node { var n *node for parent != nil && !parent.isLeaf { n = parent.searchNode(key) parent = n } return parent } type nodes struct { list []*node } func (ns *nodes) push(n *node) { ns.list = append(ns.list, n) } func (ns *nodes) splitAt(i, capacity uint64) (*nodes, *nodes) { i++ right := make([]*node, uint64(len(ns.list))-i, capacity) copy(right, ns.list[i:]) for j := i; j < uint64(len(ns.list)); j++ { ns.list[j] = nil } ns.list = ns.list[:i] return ns, &nodes{list: right} } func (ns *nodes) byPosition(pos uint64) *node { if pos >= uint64(len(ns.list)) { return nil } return ns.list[pos] } func (ns *nodes) insertAt(i uint64, n *node) { ns.list = append(ns.list, nil) copy(ns.list[i+1:], ns.list[i:]) ns.list[i] = n } func (ns *nodes) replaceAt(i uint64, n *node) { ns.list[i] = n } func (ns *nodes) len() uint64 { return uint64(len(ns.list)) } func newNodes(size uint64) *nodes { return &nodes{ list: make([]*node, 0, size), } } type keys struct { list common.Comparators } func (ks *keys) splitAt(i, capacity uint64) (*keys, *keys) { i++ right := make(common.Comparators, uint64(len(ks.list))-i, capacity) copy(right, ks.list[i:]) for j := i; j < uint64(len(ks.list)); j++ { ks.list[j] = nil } ks.list = ks.list[:i] return ks, &keys{list: right} } func (ks *keys) len() uint64 { return uint64(len(ks.list)) } func (ks *keys) byPosition(i uint64) common.Comparator { if i >= uint64(len(ks.list)) { return nil } return ks.list[i] } func (ks *keys) delete(k 
common.Comparator) common.Comparator { i := ks.search(k) if i >= uint64(len(ks.list)) { return nil } if ks.list[i].Compare(k) != 0 { return nil } old := ks.list[i] copy(ks.list[i:], ks.list[i+1:]) ks.list[len(ks.list)-1] = nil // GC ks.list = ks.list[:len(ks.list)-1] return old } func (ks *keys) search(key common.Comparator) uint64 { i := sort.Search(len(ks.list), func(i int) bool { return ks.list[i].Compare(key) > -1 }) return uint64(i) } func (ks *keys) insert(key common.Comparator) (common.Comparator, uint64) { i := ks.search(key) if i == uint64(len(ks.list)) { ks.list = append(ks.list, key) return nil, i } var old common.Comparator if ks.list[i].Compare(key) == 0 { old = ks.list[i] ks.list[i] = key } else { ks.insertAt(i, key) } return old, i } func (ks *keys) last() common.Comparator { return ks.list[len(ks.list)-1] } func (ks *keys) insertAt(i uint64, k common.Comparator) { ks.list = append(ks.list, nil) copy(ks.list[i+1:], ks.list[i:]) ks.list[i] = k } func (ks *keys) withPosition(k common.Comparator) (common.Comparator, uint64) { i := ks.search(k) if i == uint64(len(ks.list)) { return nil, i } if ks.list[i].Compare(k) == 0 { return ks.list[i], i } return nil, i } func newKeys(size uint64) *keys { return &keys{ list: make(common.Comparators, 0, size), } } type node struct { keys *keys nodes *nodes isLeaf bool parent, right *node } func (n *node) needsSplit(ary uint64) bool { return n.keys.len() >= ary } func (n *node) splitLeaf(i, capacity uint64) (common.Comparator, *node, *node) { key := n.keys.byPosition(i) _, rightKeys := n.keys.splitAt(i, capacity) nn := &node{ keys: rightKeys, nodes: newNodes(uint64(cap(n.nodes.list))), isLeaf: true, right: n.right, } n.right = nn return key, n, nn } func (n *node) splitInternal(i, capacity uint64) (common.Comparator, *node, *node) { key := n.keys.byPosition(i) n.keys.delete(key) _, rightKeys := n.keys.splitAt(i-1, capacity) _, rightNodes := n.nodes.splitAt(i, capacity) nn := newNode(false, rightKeys, rightNodes) for 
_, n := range rightNodes.list { n.parent = nn } return key, n, nn } func (n *node) split(i, capacity uint64) (common.Comparator, *node, *node) { if n.isLeaf { return n.splitLeaf(i, capacity) } return n.splitInternal(i, capacity) } func (n *node) search(key common.Comparator) uint64 { return n.keys.search(key) } func (n *node) searchNode(key common.Comparator) *node { i := n.search(key) return n.nodes.byPosition(uint64(i)) } func (n *node) key() common.Comparator { return n.keys.last() } func (n *node) print(output *log.Logger) { output.Printf(`NODE: %+v, %p`, n, n) for _, k := range n.keys.list { output.Printf(`KEY: %+v`, k) } if !n.isLeaf { for _, n := range n.nodes.list { if n == nil { output.Println(`NIL NODE`) continue } n.print(output) } } } // Compare is required by the skip.Entry interface but nodes are always // added by position so while this method is required it doesn't // need to return anything useful. func (n *node) Compare(e common.Comparator) int { return 0 } func newNode(isLeaf bool, keys *keys, ns *nodes) *node { return &node{ isLeaf: isLeaf, keys: keys, nodes: ns, } } ================================================ FILE: btree/palm/tree.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package palm import ( "log" "runtime" "sync" "sync/atomic" "github.com/Workiva/go-datastructures/common" "github.com/Workiva/go-datastructures/queue" ) type operation int const ( get operation = iota add remove apply ) const multiThreadAt = 400 // number of keys before we multithread lookups type keyBundle struct { key common.Comparator left, right *node } func (kb *keyBundle) dispose(ptree *ptree) { if ptree.kbRing.Len() == ptree.kbRing.Cap() { return } kb.key, kb.left, kb.right = nil, nil, nil ptree.kbRing.Put(kb) } type ptree struct { root *node _padding0 [8]uint64 number uint64 _padding1 [8]uint64 ary, bufferSize uint64 actions *queue.RingBuffer cache []interface{} buffer0 [8]uint64 disposed uint64 buffer1 [8]uint64 running uint64 _padding2 [8]uint64 kbRing *queue.RingBuffer disposeChannel chan bool mpChannel chan map[*node][]*keyBundle } func (ptree *ptree) checkAndRun(action action) { if ptree.actions.Len() > 0 { if action != nil { ptree.actions.Put(action) } if atomic.CompareAndSwapUint64(&ptree.running, 0, 1) { var a interface{} var err error for ptree.actions.Len() > 0 { a, err = ptree.actions.Get() if err != nil { return } ptree.cache = append(ptree.cache, a) if uint64(len(ptree.cache)) >= ptree.bufferSize { break } } go ptree.operationRunner(ptree.cache, true) } } else if action != nil { if atomic.CompareAndSwapUint64(&ptree.running, 0, 1) { switch action.operation() { case get: ptree.read(action) action.complete() ptree.reset() case add, remove: if len(action.keys()) > multiThreadAt { ptree.operationRunner(interfaces{action}, true) } else { ptree.operationRunner(interfaces{action}, false) } case apply: q := action.(*applyAction) n := getParent(ptree.root, q.start) ptree.apply(n, q) q.complete() ptree.reset() } } else { ptree.actions.Put(action) ptree.checkAndRun(nil) } } } func (ptree *ptree) init(bufferSize, ary uint64) { ptree.bufferSize = bufferSize ptree.ary = ary ptree.cache = make([]interface{}, 0, bufferSize) ptree.root = newNode(true, 
newKeys(ary), newNodes(ary)) ptree.actions = queue.NewRingBuffer(ptree.bufferSize) ptree.kbRing = queue.NewRingBuffer(1024) for i := uint64(0); i < ptree.kbRing.Cap(); i++ { ptree.kbRing.Put(&keyBundle{}) } ptree.disposeChannel = make(chan bool) ptree.mpChannel = make(chan map[*node][]*keyBundle, 1024) var wg sync.WaitGroup wg.Add(1) go ptree.disposer(&wg) wg.Wait() } func (ptree *ptree) newKeyBundle(key common.Comparator) *keyBundle { if ptree.kbRing.Len() == 0 { return &keyBundle{key: key} } ifc, err := ptree.kbRing.Get() if err != nil { return nil } kb := ifc.(*keyBundle) kb.key = key return kb } func (ptree *ptree) operationRunner(xns interfaces, threaded bool) { writeOperations, deleteOperations, toComplete := ptree.fetchKeys(xns, threaded) ptree.recursiveMutate(writeOperations, deleteOperations, false, threaded) for _, a := range toComplete { a.complete() } ptree.reset() } func (ptree *ptree) read(action action) { for i, k := range action.keys() { n := getParent(ptree.root, k) if n == nil { action.keys()[i] = nil } else { key, _ := n.keys.withPosition(k) if key == nil { action.keys()[i] = nil } else { action.keys()[i] = key } } } } func (ptree *ptree) fetchKeys(xns interfaces, inParallel bool) (map[*node][]*keyBundle, map[*node][]*keyBundle, actions) { if inParallel { ptree.fetchKeysInParallel(xns) } else { ptree.fetchKeysInSerial(xns) } writeOperations := make(map[*node][]*keyBundle) deleteOperations := make(map[*node][]*keyBundle) toComplete := make(actions, 0, len(xns)/2) for _, ifc := range xns { action := ifc.(action) switch action.operation() { case add: for i, n := range action.nodes() { writeOperations[n] = append(writeOperations[n], ptree.newKeyBundle(action.keys()[i])) } toComplete = append(toComplete, action) case remove: for i, n := range action.nodes() { deleteOperations[n] = append(deleteOperations[n], ptree.newKeyBundle(action.keys()[i])) } toComplete = append(toComplete, action) case get, apply: action.complete() } } return writeOperations, 
deleteOperations, toComplete } func (ptree *ptree) apply(n *node, aa *applyAction) { i := n.search(aa.start) if i == n.keys.len() { // nothing to apply against return } var k common.Comparator for n != nil { for j := i; j < n.keys.len(); j++ { k = n.keys.byPosition(j) if aa.stop.Compare(k) < 1 || !aa.fn(k) { return } } n = n.right i = 0 } } func (ptree *ptree) disposer(wg *sync.WaitGroup) { wg.Done() for { select { case mp := <-ptree.mpChannel: ptree.cleanMap(mp) case <-ptree.disposeChannel: return } } } func (ptree *ptree) fetchKeysInSerial(xns interfaces) { for _, ifc := range xns { action := ifc.(action) for i, key := range action.keys() { n := getParent(ptree.root, key) switch action.operation() { case add, remove: action.addNode(int64(i), n) case get: if n == nil { action.keys()[i] = nil } else { k, _ := n.keys.withPosition(key) if k == nil { action.keys()[i] = nil } else { action.keys()[i] = k } } case apply: q := action.(*applyAction) ptree.apply(n, q) } } } } func (ptree *ptree) reset() { for i := range ptree.cache { ptree.cache[i] = nil } ptree.cache = ptree.cache[:0] atomic.StoreUint64(&ptree.running, 0) ptree.checkAndRun(nil) } func (ptree *ptree) fetchKeysInParallel(xns []interface{}) { var forCache struct { i int64 buffer [8]uint64 // different cache lines js []int64 } for j := 0; j < len(xns); j++ { forCache.js = append(forCache.js, -1) } numCPU := runtime.NumCPU() if numCPU > 1 { numCPU-- } var wg sync.WaitGroup wg.Add(numCPU) for k := 0; k < numCPU; k++ { go func() { for { index := atomic.LoadInt64(&forCache.i) if index >= int64(len(xns)) { break } action := xns[index].(action) j := atomic.AddInt64(&forCache.js[index], 1) if j > int64(len(action.keys())) { // someone else is updating i continue } else if j == int64(len(action.keys())) { atomic.StoreInt64(&forCache.i, index+1) continue } n := getParent(ptree.root, action.keys()[j]) switch action.operation() { case add, remove: action.addNode(j, n) case get: if n == nil { action.keys()[j] = nil } else 
{ k, _ := n.keys.withPosition(action.keys()[j]) if k == nil { action.keys()[j] = nil } else { action.keys()[j] = k } } case apply: q := action.(*applyAction) ptree.apply(n, q) } } wg.Done() }() } wg.Wait() } func (ptree *ptree) splitNode(n, parent *node, nodes *[]*node, keys *common.Comparators) { if !n.needsSplit(ptree.ary) { return } length := n.keys.len() splitAt := ptree.ary - 1 for i := splitAt; i < length; i += splitAt { offset := length - i k, left, right := n.split(offset, ptree.ary) left.right = right *keys = append(*keys, k) *nodes = append(*nodes, left, right) left.parent = parent right.parent = parent } } func (ptree *ptree) applyNode(n *node, adds, deletes []*keyBundle) { for _, kb := range deletes { if n.keys.len() == 0 { break } deleted := n.keys.delete(kb.key) if deleted != nil { atomic.AddUint64(&ptree.number, ^uint64(0)) } } for _, kb := range adds { if n.keys.len() == 0 { oldKey, _ := n.keys.insert(kb.key) if n.isLeaf && oldKey == nil { atomic.AddUint64(&ptree.number, 1) } if kb.left != nil { n.nodes.push(kb.left) n.nodes.push(kb.right) } continue } oldKey, index := n.keys.insert(kb.key) if n.isLeaf && oldKey == nil { atomic.AddUint64(&ptree.number, 1) } if kb.left != nil { n.nodes.replaceAt(index, kb.left) n.nodes.insertAt(index+1, kb.right) } } } func (ptree *ptree) cleanMap(op map[*node][]*keyBundle) { for _, bundles := range op { for _, kb := range bundles { kb.dispose(ptree) } } } func (ptree *ptree) recursiveMutate(adds, deletes map[*node][]*keyBundle, setRoot, inParallel bool) { if len(adds) == 0 && len(deletes) == 0 { return } if setRoot && len(adds) > 1 { panic(`SHOULD ONLY HAVE ONE ROOT`) } ifs := make(interfaces, 0, len(adds)) for n := range adds { if n.parent == nil { setRoot = true } ifs = append(ifs, n) } for n := range deletes { if n.parent == nil { setRoot = true } if _, ok := adds[n]; !ok { ifs = append(ifs, n) } } var dummyRoot *node if setRoot { dummyRoot = &node{ keys: newKeys(ptree.ary), nodes: newNodes(ptree.ary), } } var 
write sync.Mutex nextLayerWrite := make(map[*node][]*keyBundle) nextLayerDelete := make(map[*node][]*keyBundle) var mutate func(interfaces, func(interface{})) if inParallel { mutate = executeInterfacesInParallel } else { mutate = executeInterfacesInSerial } mutate(ifs, func(ifc interface{}) { n := ifc.(*node) adds := adds[n] deletes := deletes[n] if len(adds) == 0 && len(deletes) == 0 { return } if setRoot { ptree.root = n } parent := n.parent if parent == nil { parent = dummyRoot setRoot = true } ptree.applyNode(n, adds, deletes) if n.needsSplit(ptree.ary) { keys := make(common.Comparators, 0, n.keys.len()) nodes := make([]*node, 0, n.nodes.len()) ptree.splitNode(n, parent, &nodes, &keys) write.Lock() for i, k := range keys { kb := ptree.newKeyBundle(k) kb.left = nodes[i*2] kb.right = nodes[i*2+1] nextLayerWrite[parent] = append(nextLayerWrite[parent], kb) } write.Unlock() } }) ptree.mpChannel <- adds ptree.mpChannel <- deletes ptree.recursiveMutate(nextLayerWrite, nextLayerDelete, setRoot, inParallel) } // Insert will add the provided keys to the tree. func (ptree *ptree) Insert(keys ...common.Comparator) { ia := newInsertAction(keys) ptree.checkAndRun(ia) ia.completer.Wait() } // Delete will remove the provided keys from the tree. If no // matching key is found, this is a no-op. func (ptree *ptree) Delete(keys ...common.Comparator) { ra := newRemoveAction(keys) ptree.checkAndRun(ra) ra.completer.Wait() } // Get will retrieve a list of keys from the provided keys. func (ptree *ptree) Get(keys ...common.Comparator) common.Comparators { ga := newGetAction(keys) ptree.checkAndRun(ga) ga.completer.Wait() return ga.result } // Len returns the number of items in the tree. func (ptree *ptree) Len() uint64 { return atomic.LoadUint64(&ptree.number) } // Query will return a list of Comparators that fall within the // provided start and stop Comparators. Start is inclusive while // stop is exclusive, ie [start, stop). 
func (ptree *ptree) Query(start, stop common.Comparator) common.Comparators { cmps := make(common.Comparators, 0, 32) aa := newApplyAction(func(cmp common.Comparator) bool { cmps = append(cmps, cmp) return true }, start, stop) ptree.checkAndRun(aa) aa.completer.Wait() return cmps } // Dispose will clean up any resources used by this tree. This // must be called to prevent a memory leak. func (ptree *ptree) Dispose() { if atomic.LoadUint64(&ptree.disposed) == 1 { return } ptree.actions.Dispose() atomic.StoreUint64(&ptree.disposed, 1) close(ptree.disposeChannel) } func (ptree *ptree) print(output *log.Logger) { println(`PRINTING TREE`) if ptree.root == nil { return } ptree.root.print(output) } func newTree(bufferSize, ary uint64) *ptree { ptree := &ptree{} ptree.init(bufferSize, ary) return ptree } // New will allocate, initialize, and return a new B-Tree based // on PALM principles. This type of tree is suited for in-memory // indices in a multi-threaded environment. func New(bufferSize, ary uint64) BTree { return newTree(bufferSize, ary) } ================================================ FILE: btree/palm/tree_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package palm import ( "log" "math/rand" "os" "sync" "testing" "github.com/stretchr/testify/assert" "github.com/Workiva/go-datastructures/common" ) func checkTree(t testing.TB, tree *ptree) bool { return true if tree.root == nil { return true } return checkNode(t, tree.root) } func checkNode(t testing.TB, n *node) bool { if n.keys.len() == 0 { assert.Equal(t, uint64(0), n.nodes.len()) return false } if n.isLeaf { assert.Equal(t, uint64(0), n.nodes.len()) return false } if !assert.Equal(t, n.keys.len()+1, n.nodes.len()) { return false } for i, k := range n.keys.list { nd := n.nodes.list[i] if !assert.NotNil(t, nd) { return false } if !assert.True(t, k.Compare(nd.key()) >= 0) { t.Logf(`N: %+v %p, n.keys[i]: %+v, n.nodes[i]: %+v`, n, n, k, nd) return false } } k := n.keys.last() nd := n.nodes.byPosition(n.nodes.len() - 1) if !assert.True(t, k.Compare(nd.key()) < 0) { t.Logf(`m: %+v, %p, n.nodes[len(n.nodes)-1].key(): %+v, n.keys.last(): %+v`, n, n, nd, k) return false } for _, child := range n.nodes.list { if !assert.NotNil(t, child) { return false } if !checkNode(t, child) { return false } } return true } func getConsoleLogger() *log.Logger { return log.New(os.Stderr, "", log.LstdFlags) } func generateRandomKeys(num int) common.Comparators { keys := make(common.Comparators, 0, num) for i := 0; i < num; i++ { m := rand.Int() keys = append(keys, mockKey(m%50)) } return keys } func generateKeys(num int) common.Comparators { keys := make(common.Comparators, 0, num) for i := 0; i < num; i++ { keys = append(keys, mockKey(i)) } return keys } func TestSimpleInsert(t *testing.T) { tree := newTree(16, 16) defer tree.Dispose() m1 := mockKey(1) tree.Insert(m1) assert.Equal(t, common.Comparators{m1}, tree.Get(m1)) assert.Equal(t, uint64(1), tree.Len()) checkTree(t, tree) } func TestSimpleDelete(t *testing.T) { tree := newTree(8, 8) defer tree.Dispose() m1 := mockKey(1) tree.Insert(m1) tree.Delete(m1) assert.Equal(t, uint64(0), tree.Len()) assert.Equal(t, 
common.Comparators{nil}, tree.Get(m1)) checkTree(t, tree) } func TestMultipleAdd(t *testing.T) { tree := newTree(16, 16) defer tree.Dispose() m1 := mockKey(1) m2 := mockKey(10) tree.Insert(m1, m2) if !assert.Equal(t, common.Comparators{m1, m2}, tree.Get(m1, m2)) { tree.print(getConsoleLogger()) } assert.Equal(t, uint64(2), tree.Len()) checkTree(t, tree) } func TestMultipleDelete(t *testing.T) { tree := newTree(16, 16) defer tree.Dispose() m1 := mockKey(1) m2 := mockKey(10) tree.Insert(m1, m2) tree.Delete(m1, m2) assert.Equal(t, uint64(0), tree.Len()) assert.Equal(t, common.Comparators{nil, nil}, tree.Get(m1, m2)) checkTree(t, tree) } func TestMultipleInsertCausesSplitOddAryReverseOrder(t *testing.T) { tree := newTree(3, 3) defer tree.Dispose() keys := generateKeys(100) reversed := reverseKeys(keys) tree.Insert(reversed...) if !assert.Equal(t, keys, tree.Get(keys...)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleDeleteOddAryReverseOrder(t *testing.T) { tree := newTree(3, 3) defer tree.Dispose() keys := generateKeys(100) reversed := reverseKeys(keys) tree.Insert(reversed...) assert.Equal(t, uint64(100), tree.Len()) tree.Delete(reversed...) assert.Equal(t, uint64(0), tree.Len()) for _, k := range reversed { assert.Equal(t, common.Comparators{nil}, tree.Get(k)) } checkTree(t, tree) } func TestMultipleInsertCausesSplitOddAry(t *testing.T) { tree := newTree(3, 3) defer tree.Dispose() keys := generateKeys(100) tree.Insert(keys...) if !assert.Equal(t, keys, tree.Get(keys...)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleInsertCausesSplitOddAryRandomOrder(t *testing.T) { tree := newTree(3, 3) defer tree.Dispose() keys := generateRandomKeys(10) tree.Insert(keys...) 
if !assert.Equal(t, keys, tree.Get(keys...)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleBulkInsertOddAry(t *testing.T) { tree := newTree(3, 3) defer tree.Dispose() keys1 := generateRandomKeys(100) keys2 := generateRandomKeys(100) tree.Insert(keys1...) if !assert.Equal(t, keys1, tree.Get(keys1...)) { tree.print(getConsoleLogger()) } tree.Insert(keys2...) if !assert.Equal(t, keys2, tree.Get(keys2...)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleBulkInsertEvenAry(t *testing.T) { tree := newTree(4, 4) defer tree.Dispose() keys1 := generateRandomKeys(100) keys2 := generateRandomKeys(100) tree.Insert(keys1...) tree.Insert(keys2...) if !assert.Equal(t, keys1, tree.Get(keys1...)) { tree.print(getConsoleLogger()) } if !assert.Equal(t, keys2, tree.Get(keys2...)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleInsertCausesSplitEvenAryReverseOrder(t *testing.T) { tree := newTree(4, 4) defer tree.Dispose() keys := generateKeys(100) reversed := reverseKeys(keys) tree.Insert(reversed...) if !assert.Equal(t, keys, tree.Get(keys...)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleInsertCausesSplitEvenAry(t *testing.T) { tree := newTree(4, 4) defer tree.Dispose() keys := generateKeys(100) tree.Insert(keys...) if !assert.Equal(t, keys, tree.Get(keys...)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestMultipleInsertCausesSplitEvenAryRandomOrder(t *testing.T) { tree := newTree(4, 4) defer tree.Dispose() keys := generateRandomKeys(100) tree.Insert(keys...) if !assert.Equal(t, keys, tree.Get(keys...)) { tree.print(getConsoleLogger()) } checkTree(t, tree) } func TestInsertOverwrite(t *testing.T) { tree := newTree(4, 4) defer tree.Dispose() keys := generateKeys(10) duplicate := mockKey(0) tree.Insert(keys...) 
tree.Insert(duplicate) assert.Equal(t, common.Comparators{duplicate}, tree.Get(duplicate)) checkTree(t, tree) } func TestSimultaneousReadsAndWrites(t *testing.T) { numLoops := 3 keys := make([]common.Comparators, 0, numLoops) for i := 0; i < numLoops; i++ { keys = append(keys, generateRandomKeys(10)) } tree := newTree(16, 16) defer tree.Dispose() var wg sync.WaitGroup wg.Add(numLoops) for i := 0; i < numLoops; i++ { go func(i int) { tree.Insert(keys[i]...) tree.Get(keys[i]...) wg.Done() }(i) } wg.Wait() for i := 0; i < numLoops; i++ { assert.Equal(t, keys[i], tree.Get(keys[i]...)) } checkTree(t, tree) } func TestInsertAndDelete(t *testing.T) { tree := newTree(1024, 1024) defer tree.Dispose() keys := generateKeys(100) keys1 := keys[:50] keys2 := keys[50:] tree.Insert(keys1...) assert.Equal(t, uint64(len(keys1)), tree.Len()) var wg sync.WaitGroup wg.Add(2) go func() { tree.Insert(keys2...) wg.Done() }() go func() { tree.Delete(keys1...) wg.Done() }() wg.Wait() assert.Equal(t, uint64(len(keys2)), tree.Len()) assert.Equal(t, keys2, tree.Get(keys2...)) } func TestInsertAndDeletesWithSplits(t *testing.T) { tree := newTree(3, 3) defer tree.Dispose() keys := generateKeys(100) keys1 := keys[:50] keys2 := keys[50:] tree.Insert(keys1...) assert.Equal(t, uint64(len(keys1)), tree.Len()) var wg sync.WaitGroup wg.Add(2) go func() { tree.Insert(keys2...) wg.Done() }() go func() { tree.Delete(keys1...) 
wg.Done() }() wg.Wait() assert.Equal(t, uint64(len(keys2)), tree.Len()) assert.Equal(t, keys2, tree.Get(keys2...)) } func TestSimpleQuery(t *testing.T) { tree := newTree(3, 3) defer tree.Dispose() m1 := mockKey(1) tree.Insert(m1) result := tree.Query(mockKey(0), mockKey(5)) assert.Equal(t, common.Comparators{m1}, result) result = tree.Query(mockKey(0), mockKey(1)) assert.Len(t, result, 0) result = tree.Query(mockKey(2), mockKey(10)) assert.Len(t, result, 0) result = tree.Query(mockKey(1), mockKey(10)) assert.Equal(t, common.Comparators{m1}, result) } func TestMultipleQuery(t *testing.T) { tree := newTree(3, 3) defer tree.Dispose() m1 := mockKey(1) m2 := mockKey(5) tree.Insert(m1, m2) result := tree.Query(mockKey(0), mockKey(10)) assert.Equal(t, common.Comparators{m1, m2}, result) result = tree.Query(mockKey(1), mockKey(5)) assert.Equal(t, common.Comparators{m1}, result) result = tree.Query(mockKey(6), mockKey(10)) assert.Len(t, result, 0) result = tree.Query(mockKey(5), mockKey(10)) assert.Equal(t, common.Comparators{m2}, result) } func TestCrossNodeQuery(t *testing.T) { tree := newTree(3, 3) defer tree.Dispose() keys := generateKeys(100) tree.Insert(keys...) result := tree.Query(mockKey(0), mockKey(len(keys))) if !assert.Equal(t, keys, result) { tree.print(getConsoleLogger()) } } func BenchmarkReadAndWrites(b *testing.B) { numItems := 1000 keys := make([]common.Comparators, 0, b.N) for i := 0; i < b.N; i++ { keys = append(keys, generateRandomKeys(numItems)) } tree := newTree(8, 8) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Insert(keys[i]...) tree.Get(keys[i]...) 
} } func BenchmarkSimultaneousReadsAndWrites(b *testing.B) { numItems := 10000 numRoutines := 8 keys := generateRandomKeys(numItems) chunks := chunkKeys(keys, int64(numRoutines)) trees := make([]*ptree, 0, numItems) for i := 0; i < b.N; i++ { trees = append(trees, newTree(8, 8)) } var wg sync.WaitGroup b.ResetTimer() for i := 0; i < b.N; i++ { wg.Add(numRoutines) for j := 0; j < numRoutines; j++ { go func(i, j int) { trees[i].Insert(chunks[j]...) trees[i].Get(chunks[j]...) wg.Done() }(i, j) } wg.Wait() } } func BenchmarkBulkAdd(b *testing.B) { numItems := 10000 keys := generateRandomKeys(numItems) trees := make([]*ptree, 0, b.N) for i := 0; i < b.N; i++ { trees = append(trees, newTree(8, 8)) } b.ResetTimer() for i := 0; i < b.N; i++ { trees[i].Insert(keys...) } } func BenchmarkAdd(b *testing.B) { numItems := b.N keys := generateRandomKeys(numItems) tree := newTree(8, 8) // writes will be amortized over node splits b.ResetTimer() for i := 0; i < b.N; i++ { tree.Insert(keys[i%numItems]) } } func BenchmarkBulkAddToExisting(b *testing.B) { numItems := 100000 keySet := make([]common.Comparators, 0, b.N) for i := 0; i < b.N; i++ { keySet = append(keySet, generateRandomKeys(numItems)) } tree := newTree(8, 8) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Insert(keySet[i]...) } } func BenchmarkGet(b *testing.B) { numItems := 10000 keys := generateRandomKeys(numItems) tree := newTree(8, 8) tree.Insert(keys...) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Get(keys[i%numItems]) } } func BenchmarkBulkGet(b *testing.B) { numItems := b.N keys := generateRandomKeys(numItems) tree := newTree(8, 8) tree.Insert(keys...) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Get(keys...) } } func BenchmarkDelete(b *testing.B) { numItems := b.N keys := generateRandomKeys(numItems) tree := newTree(8, 8) tree.Insert(keys...) 
b.ResetTimer() for i := 0; i < b.N; i++ { tree.Delete(keys[i%numItems]) } } func BenchmarkBulkDelete(b *testing.B) { numItems := 10000 keys := generateRandomKeys(numItems) trees := make([]*ptree, 0, b.N) for i := 0; i < b.N; i++ { tree := newTree(8, 8) tree.Insert(keys...) trees = append(trees, tree) } b.ResetTimer() for i := 0; i < b.N; i++ { trees[i].Delete(keys...) } } func BenchmarkFindQuery(b *testing.B) { numItems := b.N keys := generateKeys(numItems) tree := newTree(8, 8) tree.Insert(keys...) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Query(mockKey(numItems/2), mockKey(numItems/2+1)) } } func BenchmarkExecuteQuery(b *testing.B) { numItems := b.N keys := generateKeys(numItems) tree := newTree(8, 8) tree.Insert(keys...) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Query(mockKey(0), mockKey(numItems)) } } ================================================ FILE: btree/plus/btree.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* Package btree/plus implements the ubiquitous B+ tree. As of this writing, the tree is not quite finished. The delete-node merge functionality needs to be added. There are also some performance improvements that can be made, with some possible concurrency mechanisms. This is a mutable b-tree so it is not threadsafe. 
Performance characteristics: Space: O(n) Insert: O(log n) Search: O(log n) BenchmarkIteration-8 10000 109347 ns/op BenchmarkInsert-8 3000000 608 ns/op BenchmarkGet-8 3000000 627 ns/op */ package plus func keySearch(keys keys, key Key) int { low, high := 0, len(keys)-1 var mid int for low <= high { mid = (high + low) / 2 switch keys[mid].Compare(key) { case 1: low = mid + 1 case -1: high = mid - 1 case 0: return mid } } return low } type btree struct { root node nodeSize, number uint64 } func (tree *btree) insert(key Key) { if tree.root == nil { n := newLeafNode(tree.nodeSize) n.insert(tree, key) tree.number = 1 return } result := tree.root.insert(tree, key) if result { tree.number++ } if tree.root.needsSplit(tree.nodeSize) { tree.root = split(tree, nil, tree.root) } } // Insert will insert the provided keys into the btree. This is an // O(m*log n) operation where m is the number of keys to be inserted // and n is the number of items in the tree. func (tree *btree) Insert(keys ...Key) { for _, key := range keys { tree.insert(key) } } // Iter returns an iterator that can be used to traverse the b-tree // starting from the specified key or its successor. func (tree *btree) Iter(key Key) Iterator { if tree.root == nil { return nilIterator() } return tree.root.find(key) } func (tree *btree) get(key Key) Key { iter := tree.root.find(key) if !iter.Next() { return nil } if iter.Value().Compare(key) == 0 { return iter.Value() } return nil } // Get will retrieve any keys matching the provided keys in the tree. // Returns nil in any place of a key that couldn't be found. Each lookup // is an O(log n) operation. func (tree *btree) Get(keys ...Key) Keys { results := make(Keys, 0, len(keys)) for _, k := range keys { results = append(results, tree.get(k)) } return results } // Len returns the number of items in this tree. 
func (tree *btree) Len() uint64 { return tree.number } func newBTree(nodeSize uint64) *btree { return &btree{ nodeSize: nodeSize, root: newLeafNode(nodeSize), } } ================================================ FILE: btree/plus/btree_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package plus import ( "sync" "testing" "github.com/stretchr/testify/assert" ) func TestSearchKeys(t *testing.T) { keys := keys{newMockKey(1), newMockKey(2), newMockKey(4)} testKey := newMockKey(5) assert.Equal(t, 3, keySearch(keys, testKey)) testKey = newMockKey(2) assert.Equal(t, 1, keySearch(keys, testKey)) testKey = newMockKey(0) assert.Equal(t, 0, keySearch(keys, testKey)) testKey = newMockKey(3) assert.Equal(t, 2, keySearch(keys, testKey)) assert.Equal(t, 0, keySearch(nil, testKey)) } func TestTreeInsert2_3_4(t *testing.T) { tree := newBTree(3) keys := constructMockKeys(4) tree.Insert(keys...) assert.Len(t, tree.root.(*inode).keys, 2) assert.Len(t, tree.root.(*inode).nodes, 3) assert.IsType(t, &inode{}, tree.root) } func TestTreeInsert3_4_5(t *testing.T) { tree := newBTree(4) keys := constructMockKeys(5) tree.Insert(keys...) assert.Len(t, tree.root.(*inode).keys, 1) assert.Len(t, tree.root.(*inode).nodes, 2) assert.IsType(t, &inode{}, tree.root) } func TestTreeInsertQuery2_3_4(t *testing.T) { tree := newBTree(3) keys := constructMockKeys(4) tree.Insert(keys...) 
iter := tree.Iter(newMockKey(0)) result := iter.exhaust() assert.Equal(t, keys, result) } func TestTreeInsertQuery3_4_5(t *testing.T) { tree := newBTree(4) keys := constructMockKeys(5) tree.Insert(keys...) iter := tree.Iter(newMockKey(0)) result := iter.exhaust() assert.Equal(t, keys, result) } func TestTreeInsertReverseOrder2_3_4(t *testing.T) { tree := newBTree(3) keys := constructMockKeys(4) keys.reverse() tree.Insert(keys...) iter := tree.Iter(newMockKey(0)) result := iter.exhaust() keys.reverse() // we want to fetch things in the correct // ascending order assert.Equal(t, keys, result) } func TestTreeInsertReverseOrder3_4_5(t *testing.T) { tree := newBTree(4) keys := constructMockKeys(5) keys.reverse() tree.Insert(keys...) iter := tree.Iter(newMockKey(0)) result := iter.exhaust() keys.reverse() // we want to fetch things in the correct // ascending order assert.Equal(t, keys, result) } func TestTreeInsert3_4_5_WithEndDuplicate(t *testing.T) { tree := newBTree(4) keys := constructMockKeys(5) tree.Insert(keys...) duplicate := newMockKey(4) tree.Insert(duplicate) keys[4] = duplicate iter := tree.Iter(newMockKey(0)) result := iter.exhaust() assert.Equal(t, keys, result) } func TestTreeInsert3_4_5_WithMiddleDuplicate(t *testing.T) { tree := newBTree(4) keys := constructMockKeys(5) tree.Insert(keys...) duplicate := newMockKey(2) tree.Insert(duplicate) keys[2] = duplicate iter := tree.Iter(newMockKey(0)) result := iter.exhaust() assert.Equal(t, keys, result) } func TestTreeInsert3_4_5WithEarlyDuplicate(t *testing.T) { tree := newBTree(4) keys := constructMockKeys(5) tree.Insert(keys...) duplicate := newMockKey(0) tree.Insert(duplicate) keys[0] = duplicate iter := tree.Iter(newMockKey(0)) result := iter.exhaust() assert.Equal(t, keys, result) } func TestTreeInsert3_4_5WithDuplicateID(t *testing.T) { tree := newBTree(4) keys := constructMockKeys(5) key := newMockKey(2) tree.Insert(keys...) 
tree.Insert(key) iter := tree.Iter(newMockKey(0)) result := iter.exhaust() assert.Equal(t, keys, result) } func TestTreeInsert3_4_5MiddleQuery(t *testing.T) { tree := newBTree(4) keys := constructMockKeys(5) tree.Insert(keys...) iter := tree.Iter(newMockKey(2)) result := iter.exhaust() assert.Equal(t, keys[2:], result) } func TestTreeInsert3_4_5LateQuery(t *testing.T) { tree := newBTree(4) keys := constructMockKeys(5) tree.Insert(keys...) iter := tree.Iter(newMockKey(4)) result := iter.exhaust() assert.Equal(t, keys[4:], result) } func TestTreeInsert3_4_5AfterQuery(t *testing.T) { tree := newBTree(4) keys := constructMockKeys(5) tree.Insert(keys...) iter := tree.Iter(newMockKey(5)) result := iter.exhaust() assert.Len(t, result, 0) } func TestTreeInternalNodeSplit(t *testing.T) { tree := newBTree(4) keys := constructMockKeys(10) tree.Insert(keys...) iter := tree.Iter(newMockKey(0)) result := iter.exhaust() assert.Equal(t, keys, result) } func TestTreeInternalNodeSplitReverseOrder(t *testing.T) { tree := newBTree(4) keys := constructMockKeys(10) keys.reverse() tree.Insert(keys...) iter := tree.Iter(newMockKey(0)) result := iter.exhaust() keys.reverse() assert.Equal(t, keys, result) } func TestTreeInternalNodeSplitRandomOrder(t *testing.T) { ids := []int{6, 2, 9, 0, 3, 4, 7, 1, 8, 5} keys := make(keys, 0, len(ids)) for _, id := range ids { keys = append(keys, newMockKey(id)) } tree := newBTree(4) tree.Insert(keys...) iter := tree.Iter(newMockKey(0)) result := iter.exhaust() assert.Len(t, result, 10) for i, key := range result { assert.Equal(t, newMockKey(i), key) } } func TestTreeRandomOrderQuery(t *testing.T) { ids := []int{6, 2, 9, 0, 3, 4, 7, 1, 8, 5} keys := make(keys, 0, len(ids)) for _, id := range ids { keys = append(keys, newMockKey(id)) } tree := newBTree(4) tree.Insert(keys...) 
iter := tree.Iter(newMockKey(4)) result := iter.exhaust() assert.Len(t, result, 6) for i, key := range result { assert.Equal(t, newMockKey(i+4), key) } } func TestTreeGet(t *testing.T) { keys := constructRandomMockKeys(100) tree := newBTree(64) tree.Insert(keys...) assert.Equal(t, uint64(100), tree.Len()) fromTree := tree.Get(keys...) for _, key := range keys { assert.Contains(t, fromTree, key) } } func TestTreeGetNotFound(t *testing.T) { keys := constructMockKeys(5) tree := newBTree(64) tree.Insert(keys...) assert.Equal(t, Keys{nil}, tree.Get(newMockKey(20))) } func TestGetExactMatchesOnly(t *testing.T) { k1 := newMockKey(0) k2 := newMockKey(5) tree := newBTree(64) tree.Insert(k1, k2) assert.Equal(t, Keys{nil}, tree.Get(newMockKey(3))) } func BenchmarkIteration(b *testing.B) { numItems := 1000 ary := uint64(16) keys := constructMockKeys(numItems) tree := newBTree(ary) tree.Insert(keys...) searchKey := newMockKey(0) b.ResetTimer() for i := 0; i < b.N; i++ { iter := tree.Iter(searchKey) iter.exhaust() } } func BenchmarkInsert(b *testing.B) { numItems := b.N ary := uint64(16) keys := constructMockKeys(numItems) tree := newBTree(ary) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Insert(keys[i%numItems]) } } func BenchmarkBulkAdd(b *testing.B) { numItems := 10000 keys := constructRandomMockKeys(numItems) b.ResetTimer() for i := 0; i < b.N; i++ { tree := newBTree(1024) tree.Insert(keys...) } } func BenchmarkGet(b *testing.B) { numItems := b.N ary := uint64(16) keys := constructMockKeys(numItems) tree := newBTree(ary) tree.Insert(keys...) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Get(keys[i%numItems]) } } func BenchmarkBulkAddToExisting(b *testing.B) { numItems := 100000 keySet := make([]keys, 0, b.N) for i := 0; i < b.N; i++ { keySet = append(keySet, constructRandomMockKeys(numItems)) } tree := newBTree(1024) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Insert(keySet[i]...) 
} } func BenchmarkReadAndWrites(b *testing.B) { numItems := 1000 ks := make([]keys, 0, b.N) for i := 0; i < b.N; i++ { ks = append(ks, constructRandomMockKeys(numItems)) } tree := newBTree(16) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Insert(ks[i]...) tree.Get(ks[i]...) } } func BenchmarkSimultaneousReadsAndWrites(b *testing.B) { numItems := 10000 numRoutines := 8 keys := constructRandomMockKeys(numItems) chunks := chunkKeys(keys, int64(numRoutines)) trees := make([]*btree, 0, numItems) for i := 0; i < b.N; i++ { trees = append(trees, newBTree(8)) } var wg sync.WaitGroup var lock sync.Mutex b.ResetTimer() for i := 0; i < b.N; i++ { wg.Add(numRoutines) for j := 0; j < numRoutines; j++ { go func(i, j int) { lock.Lock() trees[i].Insert(chunks[j]...) trees[i].Get(chunks[j]...) lock.Unlock() wg.Done() }(i, j) } wg.Wait() } } ================================================ FILE: btree/plus/interface.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package plus // Keys is a typed list of Key interfaces. type Keys []Key type Key interface { // Compare should return an int indicating how this key relates // to the provided key. -1 will indicate less than, 0 will indicate // equality, and 1 will indicate greater than. Duplicate keys // are allowed, but duplicate IDs are not. Compare(Key) int } // Iterator will be called with matching keys until either false is // returned or we run out of keys to iterate. 
type Iterator interface {
	// Next will move the iterator to the next position and return
	// a bool indicating if there is a value at that position.
	Next() bool
	// Value returns a Key at the associated iterator position.  Returns
	// nil if the iterator is exhausted or if Next has not been called yet.
	Value() Key
	// exhaust is an internal helper method to iterate this iterator
	// until exhausted and returns the resulting list of keys.  Because
	// it is unexported, only types in this package can satisfy Iterator.
	exhaust() keys
}

================================================
FILE: btree/plus/iterator.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ package plus const iteratorExhausted = -2 type iterator struct { node *lnode index int } func (iter *iterator) Next() bool { if iter.index == iteratorExhausted { return false } iter.index++ if iter.index >= len(iter.node.keys) { iter.node = iter.node.pointer if iter.node == nil { iter.index = iteratorExhausted return false } iter.index = 0 } return true } func (iter *iterator) Value() Key { if iter.index == iteratorExhausted || iter.index < 0 || iter.index >= len(iter.node.keys) { return nil } return iter.node.keys[iter.index] } // exhaust is a test function that's not exported func (iter *iterator) exhaust() keys { keys := make(keys, 0, 10) for iter := iter; iter.Next(); { keys = append(keys, iter.Value()) } return keys } func nilIterator() *iterator { return &iterator{ index: iteratorExhausted, } } ================================================ FILE: btree/plus/mock_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package plus func chunkKeys(ks keys, numParts int64) []keys { parts := make([]keys, numParts) for i := int64(0); i < numParts; i++ { parts[i] = ks[i*int64(len(ks))/numParts : (i+1)*int64(len(ks))/numParts] } return parts } type mockKey struct { value int } func (mk *mockKey) Compare(other Key) int { key := other.(*mockKey) if key.value == mk.value { return 0 } if key.value > mk.value { return 1 } return -1 } func newMockKey(value int) *mockKey { return &mockKey{value} } ================================================ FILE: btree/plus/node.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package plus func split(tree *btree, parent, child node) node { if !child.needsSplit(tree.nodeSize) { return parent } key, left, right := child.split() if parent == nil { in := newInternalNode(tree.nodeSize) in.keys = append(in.keys, key) in.nodes = append(in.nodes, left) in.nodes = append(in.nodes, right) return in } p := parent.(*inode) i := p.search(key) // we want to ensure if the children are leaves we set // the left node's left sibling to point to left if cr, ok := left.(*lnode); ok { if i > 0 { p.nodes[i-1].(*lnode).pointer = cr } } p.keys.insertAt(i, key) p.nodes[i] = left p.nodes.insertAt(i+1, right) return parent } type node interface { insert(tree *btree, key Key) bool needsSplit(nodeSize uint64) bool // key is the median key while left and right nodes // represent the left and right nodes respectively split() (Key, node, node) search(key Key) int find(key Key) *iterator } type nodes []node func (nodes *nodes) insertAt(i int, node node) { if i == len(*nodes) { *nodes = append(*nodes, node) return } *nodes = append(*nodes, nil) copy((*nodes)[i+1:], (*nodes)[i:]) (*nodes)[i] = node } func (ns nodes) splitAt(i int) (nodes, nodes) { left := make(nodes, i, cap(ns)) right := make(nodes, len(ns)-i, cap(ns)) copy(left, ns[:i]) copy(right, ns[i:]) return left, right } type inode struct { keys keys nodes nodes } func (node *inode) search(key Key) int { return node.keys.search(key) } func (node *inode) find(key Key) *iterator { i := node.search(key) if i == len(node.keys) { return node.nodes[len(node.nodes)-1].find(key) } found := node.keys[i] switch found.Compare(key) { case 0, 1: return node.nodes[i+1].find(key) default: return node.nodes[i].find(key) } } func (n *inode) insert(tree *btree, key Key) bool { i := n.search(key) var child node if i == len(n.keys) { // we want the last child node in this case child = n.nodes[len(n.nodes)-1] } else { match := n.keys[i] switch match.Compare(key) { case 1, 0: child = n.nodes[i+1] default: child = n.nodes[i] } } 
result := child.insert(tree, key) if !result { // no change of state occurred return result } if child.needsSplit(tree.nodeSize) { split(tree, n, child) } return result } func (n *inode) needsSplit(nodeSize uint64) bool { return uint64(len(n.keys)) >= nodeSize } func (n *inode) split() (Key, node, node) { if len(n.keys) < 3 { return nil, nil, nil } i := len(n.keys) / 2 key := n.keys[i] ourKeys := make(keys, len(n.keys)-i-1, cap(n.keys)) otherKeys := make(keys, i, cap(n.keys)) copy(ourKeys, n.keys[i+1:]) copy(otherKeys, n.keys[:i]) left, right := n.nodes.splitAt(i + 1) otherNode := &inode{ keys: otherKeys, nodes: left, } n.keys = ourKeys n.nodes = right return key, otherNode, n } func newInternalNode(size uint64) *inode { return &inode{ keys: make(keys, 0, size), nodes: make(nodes, 0, size+1), } } type lnode struct { // points to the left leaf node is there is one pointer *lnode keys keys } func (node *lnode) search(key Key) int { return node.keys.search(key) } func (lnode *lnode) insert(tree *btree, key Key) bool { i := keySearch(lnode.keys, key) var inserted bool if i == len(lnode.keys) { // simple append will do lnode.keys = append(lnode.keys, key) inserted = true } else { if lnode.keys[i].Compare(key) == 0 { lnode.keys[i] = key } else { lnode.keys.insertAt(i, key) inserted = true } } if !inserted { return false } return true } func (node *lnode) find(key Key) *iterator { i := node.search(key) if i == len(node.keys) { if node.pointer == nil { return nilIterator() } return &iterator{ node: node.pointer, index: -1, } } iter := &iterator{ node: node, index: i - 1, } return iter } func (node *lnode) split() (Key, node, node) { if len(node.keys) < 2 { return nil, nil, nil } i := len(node.keys) / 2 key := node.keys[i] otherKeys := make(keys, i, cap(node.keys)) ourKeys := make(keys, len(node.keys)-i, cap(node.keys)) // we perform these copies so these slices don't all end up // pointing to the same underlying array which may make // for some very difficult to debug 
situations later. copy(otherKeys, node.keys[:i]) copy(ourKeys, node.keys[i:]) // this should release the original array for GC node.keys = ourKeys otherNode := &lnode{ keys: otherKeys, pointer: node, } return key, otherNode, node } func (lnode *lnode) needsSplit(nodeSize uint64) bool { return uint64(len(lnode.keys)) >= nodeSize } func newLeafNode(size uint64) *lnode { return &lnode{ keys: make(keys, 0, size), } } type keys []Key func (keys keys) search(key Key) int { return keySearch(keys, key) } func (keys *keys) insertAt(i int, key Key) { if i == len(*keys) { *keys = append(*keys, key) return } *keys = append(*keys, nil) copy((*keys)[i+1:], (*keys)[i:]) (*keys)[i] = key } func (keys keys) reverse() { for i := 0; i < len(keys)/2; i++ { keys[i], keys[len(keys)-i-1] = keys[len(keys)-i-1], keys[i] } } ================================================ FILE: btree/plus/node_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

package plus

import (
	"math/rand"
	"testing"

	"github.com/stretchr/testify/assert"
)

// constructMockPayloads returns num mock keys built from 0..num-1.
// It is currently identical to constructMockKeys.
func constructMockPayloads(num int) keys {
	keys := make(keys, 0, num)
	for i := 0; i < num; i++ {
		keys = append(keys, newMockKey(i))
	}

	return keys
}

// constructMockKeys returns num mock keys built from 0..num-1.
func constructMockKeys(num int) keys {
	keys := make(keys, 0, num)
	for i := 0; i < num; i++ {
		keys = append(keys, newMockKey(i))
	}

	return keys
}

// constructRandomMockKeys returns num mock keys built from random ints.
func constructRandomMockKeys(num int) keys {
	keys := make(keys, 0, num)
	for i := 0; i < num; i++ {
		keys = append(keys, newMockKey(rand.Int()))
	}

	return keys
}

// constructMockNodes builds num leaves of num keys each and chains every
// leaf to the one created after it through its pointer field.
func constructMockNodes(num int) nodes {
	nodes := make(nodes, 0, num)
	for i := 0; i < num; i++ {
		keys := make(keys, 0, num)
		for j := 0; j < num; j++ {
			keys = append(keys, newMockKey(j*i+j))
		}

		node := &lnode{
			keys: keys,
		}
		nodes = append(nodes, node)
		if i > 0 {
			nodes[i-1].(*lnode).pointer = node
		}
	}

	return nodes
}

// constructMockInternalNode wraps the given leaves in an inode whose keys
// are the first key of every leaf but the first. It returns nil when fewer
// than two leaves are supplied.
func constructMockInternalNode(nodes nodes) *inode {
	if len(nodes) < 2 {
		return nil
	}

	keys := make(keys, 0, len(nodes)-1)
	for i := 1; i < len(nodes); i++ {
		keys = append(keys, nodes[i].(*lnode).keys[0])
	}

	in := &inode{
		keys:  keys,
		nodes: nodes,
	}
	return in
}

// A single insert into an empty leaf stores exactly that key.
func TestLeafNodeInsert(t *testing.T) {
	tree := newBTree(3)
	n := newLeafNode(3)
	key := newMockKey(3)

	n.insert(tree, key)

	assert.Len(t, n.keys, 1)
	assert.Nil(t, n.pointer)
	assert.Equal(t, n.keys[0], key)
	assert.Equal(t, 0, n.keys[0].Compare(key))
}

// Inserting an equal key again reports false (nothing new was stored).
func TestDuplicateLeafNodeInsert(t *testing.T) {
	tree := newBTree(3)
	n := newLeafNode(3)
	k1 := newMockKey(3)
	k2 := newMockKey(3)

	assert.True(t, n.insert(tree, k1))
	assert.False(t, n.insert(tree, k2))
	assert.False(t, n.insert(tree, k1))
}

// Two distinct keys end up in the leaf in order.
func TestMultipleLeafNodeInsert(t *testing.T) {
	tree := newBTree(3)
	n := newLeafNode(3)
	k1 := newMockKey(3)
	k2 := newMockKey(4)

	assert.True(t, n.insert(tree, k1))
	n.insert(tree, k2)

	if !assert.Len(t, n.keys, 2) {
		return
	}
	assert.Nil(t, n.pointer)
	assert.Equal(t, k1, n.keys[0])
	assert.Equal(t, k2, n.keys[1])
}

// Splitting a four-key leaf yields two halves of two keys; the left half
// points at the right half.
func TestLeafNodeSplitEvenNumber(t *testing.T) {
	keys := constructMockPayloads(4)

	node := &lnode{
		keys: keys,
	}

	key, left, right := node.split()
	assert.Equal(t, keys[2], key)
	assert.Equal(t, left.(*lnode).keys, keys[:2])
	assert.Equal(t, right.(*lnode).keys, keys[2:])
	assert.Equal(t, left.(*lnode).pointer, right)
}

// Splitting a three-key leaf leaves one key on the left and two on the right.
func TestLeafNodeSplitOddNumber(t *testing.T) {
	keys := constructMockPayloads(3)

	node := &lnode{
		keys: keys,
	}

	key, left, right := node.split()
	assert.Equal(t, keys[1], key)
	assert.Equal(t, left.(*lnode).keys, keys[:1])
	assert.Equal(t, right.(*lnode).keys, keys[1:])
	assert.Equal(t, left.(*lnode).pointer, right)
}

// Two keys is the smallest leaf that can still be split.
func TestTwoKeysLeafNodeSplit(t *testing.T) {
	keys := constructMockPayloads(2)

	node := &lnode{
		keys: keys,
	}

	key, left, right := node.split()
	assert.Equal(t, keys[1], key)
	assert.Equal(t, left.(*lnode).keys, keys[:1])
	assert.Equal(t, right.(*lnode).keys, keys[1:])
	assert.Equal(t, left.(*lnode).pointer, right)
}

// A leaf with a single key refuses to split and returns nils.
func TestLessThanTwoKeysSplit(t *testing.T) {
	keys := constructMockPayloads(1)

	node := &lnode{
		keys: keys,
	}

	key, left, right := node.split()
	assert.Nil(t, key)
	assert.Nil(t, left)
	assert.Nil(t, right)
}

// An internal node with three keys and four children splits into two nodes
// of one key and two children each.
func TestInternalNodeSplit2_3_4(t *testing.T) {
	nodes := constructMockNodes(4)
	in := constructMockInternalNode(nodes)

	key, left, right := in.split()
	assert.Equal(t, nodes[3].(*lnode).keys[0], key)
	assert.Len(t, left.(*inode).keys, 1)
	assert.Len(t, right.(*inode).keys, 1)
	assert.Equal(t, nodes[:2], left.(*inode).nodes)
	assert.Equal(t, nodes[2:], right.(*inode).nodes)
}

// An internal node with four keys and five children splits into a left node
// of two keys / three children and a right node of one key / two children.
func TestInternalNodeSplit3_4_5(t *testing.T) {
	nodes := constructMockNodes(5)
	in := constructMockInternalNode(nodes)

	key, left, right := in.split()
	assert.Equal(t, nodes[4].(*lnode).keys[0], key)
	assert.Len(t, left.(*inode).keys, 2)
	assert.Len(t, right.(*inode).keys, 1)
	assert.Equal(t, nodes[:3], left.(*inode).nodes)
	assert.Equal(t, nodes[3:], right.(*inode).nodes)
}

// An internal node with fewer than three keys refuses to split.
func TestInternalNodeLessThan3Keys(t *testing.T) {
	nodes := constructMockNodes(2)
	in := constructMockInternalNode(nodes)

	key, left, right := in.split()
	assert.Nil(t, key)
	assert.Nil(t, left)
assert.Nil(t, right) } ================================================ FILE: cache/cache.go ================================================ package cache import ( "container/list" "sync" ) // Cache is a bounded-size in-memory cache of sized items with a configurable eviction policy type Cache interface { // Get retrieves items from the cache by key. // If an item for a particular key is not found, its position in the result will be nil. Get(keys ...string) []Item // Put adds an item to the cache. Put(key string, item Item) // Remove clears items with the given keys from the cache Remove(keys ...string) // Size returns the size of all items currently in the cache. Size() uint64 } // Item is an item in a cache type Item interface { // Size returns the item's size, in bytes Size() uint64 } // A tuple tracking a cached item and a reference to its node in the eviction list type cached struct { item Item element *list.Element } // Sets the provided list element on the cached item if it is not nil func (c *cached) setElementIfNotNil(element *list.Element) { if element != nil { c.element = element } } // Private cache implementation type cache struct { sync.Mutex // Lock for synchronizing Get, Put, Remove cap uint64 // Capacity bound size uint64 // Cumulative size items map[string]*cached // Map from keys to cached items keyList *list.List // List of cached items in order of increasing evictability recordAdd func(key string) *list.Element // Function called to indicate that an item with the given key was added recordAccess func(key string) *list.Element // Function called to indicate that an item with the given key was accessed } // CacheOption configures a cache. type CacheOption func(*cache) // Policy is a cache eviction policy for use with the EvictionPolicy CacheOption. type Policy uint8 const ( // LeastRecentlyAdded indicates a least-recently-added eviction policy. LeastRecentlyAdded Policy = iota // LeastRecentlyUsed indicates a least-recently-used eviction policy. 
LeastRecentlyUsed ) // EvictionPolicy sets the eviction policy to be used to make room for new items. // If not provided, default is LeastRecentlyUsed. func EvictionPolicy(policy Policy) CacheOption { return func(c *cache) { switch policy { case LeastRecentlyAdded: c.recordAccess = c.noop c.recordAdd = c.record case LeastRecentlyUsed: c.recordAccess = c.record c.recordAdd = c.noop } } } // New returns a cache with the requested options configured. // The cache consumes memory bounded by a fixed capacity, // plus tracking overhead linear in the number of items. func New(capacity uint64, options ...CacheOption) Cache { c := &cache{ cap: capacity, keyList: list.New(), items: map[string]*cached{}, } // Default LRU eviction policy EvictionPolicy(LeastRecentlyUsed)(c) for _, option := range options { option(c) } return c } func (c *cache) Get(keys ...string) []Item { c.Lock() defer c.Unlock() items := make([]Item, len(keys)) for i, key := range keys { cached := c.items[key] if cached == nil { items[i] = nil } else { c.recordAccess(key) items[i] = cached.item } } return items } func (c *cache) Put(key string, item Item) { c.Lock() defer c.Unlock() // Remove the item currently with this key (if any) c.remove(key) // Make sure there's room to add this item c.ensureCapacity(item.Size()) // Actually add the new item cached := &cached{item: item} cached.setElementIfNotNil(c.recordAdd(key)) cached.setElementIfNotNil(c.recordAccess(key)) c.items[key] = cached c.size += item.Size() } func (c *cache) Remove(keys ...string) { c.Lock() defer c.Unlock() for _, key := range keys { c.remove(key) } } func (c *cache) Size() uint64 { c.Lock() defer c.Unlock() return c.size } // Given the need to add some number of new bytes to the cache, // evict items according to the eviction policy until there is room. // The caller should hold the cache lock. 
func (c *cache) ensureCapacity(toAdd uint64) { mustRemove := int64(c.size+toAdd) - int64(c.cap) for mustRemove > 0 { key := c.keyList.Back().Value.(string) mustRemove -= int64(c.items[key].item.Size()) c.remove(key) } } // Remove the item associated with the given key. // The caller should hold the cache lock. func (c *cache) remove(key string) { if cached, ok := c.items[key]; ok { delete(c.items, key) c.size -= cached.item.Size() c.keyList.Remove(cached.element) } } // A no-op function that does nothing for the provided key func (c *cache) noop(string) *list.Element { return nil } // A function to record the given key and mark it as last to be evicted func (c *cache) record(key string) *list.Element { if item, ok := c.items[key]; ok { c.keyList.MoveToFront(item.element) return item.element } return c.keyList.PushFront(key) } ================================================ FILE: cache/cache_test.go ================================================ package cache import ( "container/list" "testing" "github.com/stretchr/testify/assert" ) func TestEvictionPolicy(t *testing.T) { c := &cache{keyList: list.New()} EvictionPolicy(LeastRecentlyUsed)(c) accessed, added := c.recordAccess("foo"), c.recordAdd("foo") assert.NotNil(t, accessed) assert.Nil(t, added) c = &cache{keyList: list.New()} EvictionPolicy(LeastRecentlyAdded)(c) accessed, added = c.recordAccess("foo"), c.recordAdd("foo") assert.Nil(t, accessed) assert.NotNil(t, added) } func TestNew(t *testing.T) { optionApplied := false option := func(*cache) { optionApplied = true } c := New(314159, option).(*cache) assert.Equal(t, uint64(314159), c.cap) assert.Equal(t, uint64(0), c.size) assert.NotNil(t, c.items) assert.NotNil(t, c.keyList) assert.True(t, optionApplied) accessed, added := c.recordAccess("foo"), c.recordAdd("foo") assert.NotNil(t, accessed) assert.Nil(t, added) } type testItem uint64 func (ti testItem) Size() uint64 { return uint64(ti) } func TestPutGetRemoveSize(t *testing.T) { keys := []string{"foo", 
"bar", "baz"} testCases := []struct { label string cache Cache useCache func(c Cache) expectedSize uint64 expectedItems []Item }{{ label: "Items added, key doesn't exist", cache: New(10000), useCache: func(c Cache) { c.Put("foo", testItem(1)) }, expectedSize: 1, expectedItems: []Item{testItem(1), nil, nil}, }, { label: "Items added, key exists", cache: New(10000), useCache: func(c Cache) { c.Put("foo", testItem(1)) c.Put("foo", testItem(10)) }, expectedSize: 10, expectedItems: []Item{testItem(10), nil, nil}, }, { label: "Items added, LRA eviction", cache: New(2, EvictionPolicy(LeastRecentlyAdded)), useCache: func(c Cache) { c.Put("foo", testItem(1)) c.Put("bar", testItem(1)) c.Get("foo") c.Put("baz", testItem(1)) }, expectedSize: 2, expectedItems: []Item{nil, testItem(1), testItem(1)}, }, { label: "Items added, LRU eviction", cache: New(2, EvictionPolicy(LeastRecentlyUsed)), useCache: func(c Cache) { c.Put("foo", testItem(1)) c.Put("bar", testItem(1)) c.Get("foo") c.Put("baz", testItem(1)) }, expectedSize: 2, expectedItems: []Item{testItem(1), nil, testItem(1)}, }, { label: "Items removed, key doesn't exist", cache: New(10000), useCache: func(c Cache) { c.Put("foo", testItem(1)) c.Remove("baz") }, expectedSize: 1, expectedItems: []Item{testItem(1), nil, nil}, }, { label: "Items removed, key exists", cache: New(10000), useCache: func(c Cache) { c.Put("foo", testItem(1)) c.Remove("foo") }, expectedSize: 0, expectedItems: []Item{nil, nil, nil}, }} for _, testCase := range testCases { t.Log(testCase.label) testCase.useCache(testCase.cache) assert.Equal(t, testCase.expectedSize, testCase.cache.Size()) assert.Equal(t, testCase.expectedItems, testCase.cache.Get(keys...)) } } ================================================ FILE: common/interface.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package common

// Comparator is a generic interface that represents items that can
// be compared.
type Comparator interface {
	// Compare compares this item with another. It returns a positive
	// number if this item is greater, 0 if the two are equal, and a
	// negative number if this item is less.
	Compare(Comparator) int
}

// Comparators is a typed list of type Comparator.
type Comparators []Comparator

================================================
FILE: datastructures.go
================================================
/*
Package datastructures exists solely to aid consumers of the go-datastructures
library when using dependency managers.  Depman, for instance, will work
correctly with any datastructure by simply importing this package instead of
each subpackage individually.
For more information about the datastructures package, see the README at http://github.com/Workiva/go-datastructures */ package datastructures import ( _ "github.com/Workiva/go-datastructures/augmentedtree" _ "github.com/Workiva/go-datastructures/bitarray" _ "github.com/Workiva/go-datastructures/btree/palm" _ "github.com/Workiva/go-datastructures/btree/plus" _ "github.com/Workiva/go-datastructures/fibheap" _ "github.com/Workiva/go-datastructures/futures" _ "github.com/Workiva/go-datastructures/hashmap/fastinteger" _ "github.com/Workiva/go-datastructures/numerics/optimization" _ "github.com/Workiva/go-datastructures/queue" _ "github.com/Workiva/go-datastructures/rangetree" _ "github.com/Workiva/go-datastructures/rangetree/skiplist" _ "github.com/Workiva/go-datastructures/set" _ "github.com/Workiva/go-datastructures/slice" _ "github.com/Workiva/go-datastructures/slice/skip" _ "github.com/Workiva/go-datastructures/sort" _ "github.com/Workiva/go-datastructures/threadsafe/err" _ "github.com/Workiva/go-datastructures/tree/avl" _ "github.com/Workiva/go-datastructures/trie/xfast" _ "github.com/Workiva/go-datastructures/trie/yfast" ) ================================================ FILE: documentation.md ================================================ # Introducing go-datastructures The goal of the go-datastructures library is to port implementations of some common datastructures to Go or to improve on some existing datastructures. These datastructures are designed to be reused by anyone in the community who needs them (and, hopefully, improved upon). Given the commonality and popularity of these datastructures in other languages, it is hoped that by open sourcing this library we leverage a great deal of institutional knowledge to improve upon the Go-specific implementations. # Datastructures ## Augmented Tree Designed for determining intersections between ranges.
For example, we can query the augmentedtree for any ranges that intersect with a cell, which can be represented as a range of size one (i.e., cell A1 can be represented as range A1:B2 where B2 is exclusive). In this way, we can walk through the graph looking for exact or approximate intersections. The current implementation supports n dimensions, but is quickest when an n-dimensional query can be reduced in its first dimension. That is, queries that filter only on dimensions other than the first will be slowest. The actual implementation is a top-down red-black binary search tree. ### Future Implement a bottom-up version as well. ## Bit Array Also known as a bitmap, a bitarray is useful for comparing two sets of data that can be represented as integers. It's useful because bitwise operations can compare a number of these integers at once instead of independently. For instance, the sets {1, 3, 5} and {3, 5, 7} can be intersected in a single clock cycle if these sets were represented in their associated bit array. Included in this package is the ability to convert a bitarray back to integers. There are two implementations of bit arrays in this package: one is dense and the other borrows concepts from linear algebra's compressed row sparse matrix to represent bitarrays in much smaller spaces. Unfortunately, the sparse version has logarithmic insertions and existence checks but retains some speed advantages when checking for intersections. Incidentally, this is one of two things needed to build a native Go database. ### Future Implement a dense but expandable bit array. Optimize the current package to utilize larger amounts of mechanical sympathy. ## Futures We ran into some cases where we wanted to indicate to a goroutine that an operation had started in another goroutine and to pause goroutines until the initial routine had completed.
You can do this with buffered channels, but it seems somewhat redundant to send the same result to a channel once per waiter just to ensure all waiting threads are alerted. Futures operate similarly to how ndb futures work in GAE and might be thought of as a "broadcast" channel. ## Queue Pretty self-explanatory, this package includes both a queue and a priority queue. Currently, waitgroups are used to orchestrate threads but with a proper constructor hint, this does end up being faster than channels when attempting to send data to a goroutine. The other advantage over a channel is that the queue will return an error if you attempt to put to a queue that has had Dispose called on it instead of panicking, as would happen if you attempted to send to a closed channel. I believe this is closer to Go's stated design goals. Speaking of Dispose, calling Dispose on a queue will immediately return an error to any waiting threads. ### Future When I get time, I'd like to implement a lockless ring buffer for further performance enhancements. ## Range Tree The range tree is a way to store n-dimensional points of data in a manner that permits logarithmic-complexity queries. These points are usually represented as points on a Cartesian graph represented by integers. There are two implementations of a range tree in this package, one that is mutable and one that is immutable. The mutable version can be faster, but involves lock contention if the consumer needs to ensure threadsafety. The immutable version is a copy-on-write range tree that is optimized by only copying portions of the rangetree on write and is best written to in batches. Operations on the immutable version are slower, but it is safe to read and write from this version at the same time from different threads. Although rangetrees are often represented as BBSTs as described above, the n-dimensional nature of this rangetree actually made the design easier to implement as a sparse n-dimensional array.
### Future Unite both implementations of the rangetree under the same interface. The implementations (especially the immutable one) could use some further performance optimizations. ## Fibonacci Heap The usual Fibonacci Heap with a floating-point priority key. Does a good job as a priority queue, especially for large n. Should be useful in writing an optimal solution for Dijkstra's and Prim's algorithms (because of its efficient decrease-key). ### Future I'd like to add a value interface{} pointer that will be able to hold any user data attached to each node in the heap. Another thing would be writing a fast implementation of Dijkstra and Prim using this structure. And a third would be analysing thread-safety and coming up with a thread-safe variant. ## Set Not much to say here. This is an unordered set backed by a Go map. This particular version is threadsafe which does hinder performance a bit, although reads can happen simultaneously. ### Future I'd like to experiment with a ground-up implementation of a hash map using the standard library's hash/fnv hashing function, which is a non-cryptographic hash that's proven to be very fast. I'd also like to experiment with a lockless hashmap. ## Slice Golang's standard library "sort" includes a slice of ints that contains some sorting and searching functions. This is like that standard library package but with Int64s, which requires a new package as Go doesn't want us to have generics. I also added a method for inserting into the slice. ## Threadsafe This package just wraps some common interfaces with a lock to make them threadsafe. Golang would tell us to forget about locks and use channels (even though channels themselves are just fancy abstractions around locks as evidenced in their source code) but I found some situations where I wanted to protect some memory that was accessible from multiple goroutines where channels would be ugly, slow, and unnecessary.
The only interface with an implementation thus far is error, which is useful if you need to indicate that an error was returned from logic running in any number of goroutines. # Going Forward There is a PR into the datastructures repo that contains some pieces required for implementing a B+ tree. With a B+ tree and bitmap, the pieces are in place to write a native Go database. Going forward, I'd like to take these pieces, expand upon them, and implement a fast database in Go. As always, any optimizations or bug fixes in any of this code would be greatly appreciated and encouraged :). These datastructures can be, and are, the foundations of many programs and algorithms, even if they are abstracted away in different libraries, which makes working with them a lot of fun and very informative. ================================================ FILE: fibheap/Test Generator/EnqDecrKey.py ================================================ l = [-2901939070.965906, 4539462982.372177, -6222008480.049856, -1400427921.5968666, 9866088144.060883, -2943107648.529664, 8985474333.11443, 9204710651.257133, 5354113876.8447075, 8122228442.770859, -8121418938.303131, 538431208.3261185, 9913821013.519611, -8722989752.449871, -3091279426.694975, 7229910558.195713, -2908838839.99403, 2835257231.305996, 3922059795.3656673, -9298869735.322557] print(l) l = sorted(l) print(l) a1 = l[19] a2 = -8722989752.449871 print(str(a1) + " -> " + str(a2)) l[19] = a2 b1 = l[18] b2 = -9698869735.322557 print(str(b1) + " -> " + str(b2)) l[18] = b2 c1 = l[17] c2 = -9804710651.257133 print(str(c1) + " -> " + str(c2)) l[17] = c2 print(sorted(l)) ================================================ FILE: fibheap/Test Generator/EnqDelete.py ================================================ l = [-2901939070.965906, 4539462982.372177, -6222008480.049856, -1400427921.5968666, 9866088144.060883, -2943107648.529664, 8985474333.11443, 9204710651.257133, 5354113876.8447075, 8122228442.770859, -8121418938.303131, 538431208.3261185,
9913821013.519611, -8722989752.449871, -3091279426.694975, 7229910558.195713, -2908838839.99403, 2835257231.305996, 3922059795.3656673, -9298869735.322557] print(l) l = sorted(l) print(l) a1 = l[19] a2 = 0 print(str(a1) + " -> " + str(a2)) l[19] = a2 b1 = l[18] b2 = 0 print(str(b1) + " -> " + str(b2)) l[18] = b2 c1 = l[17] c2 = 0 print(str(c1) + " -> " + str(c2)) l[17] = c2 print(sorted(l)) ================================================ FILE: fibheap/Test Generator/EnqDeqMin.py ================================================ #!/usr/bin/python3 import random l = [] for i in range(20): l.append(random.uniform(-1E10, 1E10)) print(l) l = sorted(l) print(l) ================================================ FILE: fibheap/Test Generator/Merge.py ================================================ import random l1 = [] l2 = [] for i in range(20): l1.append(random.uniform(-1E10, 1E10)) l2.append(random.uniform(-1E10, 1E10)) print(l1) print(l2) l = [] l.extend(l1) l.extend(l2) print(sorted(l)) ''' [6015943293.071386, -3878285748.0708866, 8674121166.062424, -1528465047.6118088, 7584260716.494843, -373958476.80486107, -6367787695.054295, 6813992306.719868, 5986097626.907181, 9011134545.052086, 7123644338.268343, 2646164210.08445, 4407427446.995375, -888196668.2563229, 7973918726.985172, -6529216482.09644, 6079069259.51853, -8415952427.784341, -6859960084.757652, -502409126.89040375] [9241165993.258648, -9423768405.578083, 3280085607.6687145, -5253703037.682413, 3858507441.2785892, 9896256282.896187, -9439606732.236805, 3082628799.5320206, 9453124863.59945, 9928066165.458393, 1135071669.4712334, 6380353457.986282, 8329064041.853199, 2382910730.445751, -8478491750.445316, 9607469190.690144, 5417691217.440792, -9698248424.421888, -3933774735.280322, -5984555343.381466] [-9698248424.421888, -9439606732.236805, -9423768405.578083, -8478491750.445316, -8415952427.784341, -6859960084.757652, -6529216482.09644, -6367787695.054295, -5984555343.381466, -5253703037.682413, 
-3933774735.280322, -3878285748.0708866, -1528465047.6118088, -888196668.2563229, -502409126.89040375, -373958476.80486107, 1135071669.4712334, 2382910730.445751, 2646164210.08445, 3082628799.5320206, 3280085607.6687145, 3858507441.2785892, 4407427446.995375, 5417691217.440792, 5986097626.907181, 6015943293.071386, 6079069259.51853, 6380353457.986282, 6813992306.719868, 7123644338.268343, 7584260716.494843, 7973918726.985172, 8329064041.853199, 8674121166.062424, 9011134545.052086, 9241165993.258648, 9453124863.59945, 9607469190.690144, 9896256282.896187, 9928066165.458393] ''' ================================================ FILE: fibheap/Test Generator/README.md ================================================ These are some Python helper scripts to help generate sample test arrays. ================================================ FILE: fibheap/benchmarks.txt ================================================ BenchmarkFibHeap_Enqueue-4 10000000 280 ns/op 64 B/op 1 allocs/op BenchmarkFibHeap_DequeueMin-4 100 16990302 ns/op 16007168 B/op 2 allocs/op BenchmarkFibHeap_DecreaseKey-4 20000000 900 ns/op 122 B/op 3 allocs/op BenchmarkFibHeap_Delete-4 100 19087592 ns/op 16007168 B/op 2 allocs/op BenchmarkFibHeap_Merge-4 3000000 482 ns/op 128 B/op 2 allocs/op PASS coverage: 96.2% of statements ok _/home/nikola/git/go-datastructures/fibheap 37.206s ================================================ FILE: fibheap/fibheap.go ================================================ /* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. Special thanks to Keith Schwarz (htiek@cs.stanford.edu), whose code and documentation have been used as a reference for the algorithm implementation. Java implementation: http://www.keithschwarz.com/interesting/code/?dir=fibonacci-heap Binomial heaps: http://web.stanford.edu/class/archive/cs/cs166/cs166.1166/lectures/08/Slides08.pdf Fibonacci heaps: http://web.stanford.edu/class/archive/cs/cs166/cs166.1166/lectures/09/Slides09.pdf */ /*Package fibheap is an implementation of a priority queue backed by a Fibonacci heap, as described by Fredman and Tarjan. Fibonacci heaps are interesting theoretically because they have asymptotically good runtime guarantees for many operations. In particular, insert, peek, and decrease-key all run in amortized O(1) time. dequeueMin and delete each run in amortized O(lg n) time. This allows algorithms that rely heavily on decrease-key to gain significant performance boosts. For example, Dijkstra's algorithm for single-source shortest paths can be shown to run in O(m + n lg n) using a Fibonacci heap, compared to O(m lg n) using a standard binary or binomial heap. Internally, a Fibonacci heap is represented as a circular, doubly-linked list of trees obeying the min-heap property. Each node stores pointers to its parent (if any) and some arbitrary child. Additionally, every node stores its degree (the number of children it has) and whether it is a "marked" node. Finally, each Fibonacci heap stores a pointer to the tree with the minimum value. To insert a node into a Fibonacci heap, a singleton tree is created and merged into the rest of the trees. The merge operation works by simply splicing together the doubly-linked lists of the two trees, then updating the min pointer to be the smaller of the minima of the two heaps. Peeking at the smallest element can therefore be accomplished by just looking at the min element. 
All of these operations complete in O(1) time. The tricky operations are dequeueMin and decreaseKey. dequeueMin works by removing the root of the tree containing the smallest element, then merging its children with the topmost roots. Then, the roots are scanned and merged so that there is only one tree of each degree in the root list. This works by maintaining a dynamic array of trees, each initially null, pointing to the roots of trees of each degree. The list is then scanned and this array is populated. Whenever a conflict is discovered, the appropriate trees are merged together until no more conflicts exist. The resulting trees are then put into the root list. A clever analysis using the potential method can be used to show that the amortized cost of this operation is O(lg n), see "Introduction to Algorithms, Second Edition" by Cormen, Leiserson, Rivest, and Stein for more details. The other hard operation is decreaseKey, which works as follows. First, we update the key of the node to be the new value. If the node is a root, or its new key is still no smaller than its parent's, the heap property holds and we're done. Otherwise, we cut the node from its parent, add it as a root, and then mark its parent. If the parent was already marked, we cut that node as well, recursively mark its parent, and continue this process. This can be shown to run in O(1) amortized time using yet another clever potential function. Finally, given this function, we can implement delete by decreasing a key to -\infty, then calling dequeueMin to extract it. */ package fibheap import ( "fmt" "math" ) /****************************************** ************** INTERFACE ***************** ******************************************/ // FloatingFibonacciHeap is an implementation of a fibonacci heap // with only floating-point priorities and no user data attached.
type FloatingFibonacciHeap struct {
	min  *Entry // The minimal element; nil while the heap is empty
	size uint   // Number of entries currently in the heap
}

// Entry is the entry type that will be used
// for each node of the Fibonacci heap
type Entry struct {
	// degree is the number of children hanging below this node.
	degree int
	// marked is set when a child is cut away from this node and cleared
	// when the node itself is cut or linked below another root.
	marked bool
	// next and prev chain the node into a circular sibling list, child
	// points at one arbitrary child, and parent is nil for every root.
	next, prev, child, parent *Entry
	// Priority is the numerical priority of the node
	Priority float64
}

// EmptyHeapError fires when the heap is empty and an operation could
// not be completed for that reason. Its string holds additional data.
type EmptyHeapError string

// Error returns the message carried by the error.
func (e EmptyHeapError) Error() string {
	return string(e)
}

// NilError fires when a heap or entry is nil and an operation could
// not be completed for that reason. Its string holds additional data.
type NilError string

// Error returns the message carried by the error.
func (e NilError) Error() string {
	return string(e)
}

// NewFloatFibHeap creates a new, empty, Fibonacci heap object.
func NewFloatFibHeap() FloatingFibonacciHeap {
	return FloatingFibonacciHeap{}
}

// Enqueue adds an element with the given priority to the heap and returns
// the entry created for it. The entry is the handle to pass to DecreaseKey
// or Delete later on.
func (heap *FloatingFibonacciHeap) Enqueue(priority float64) *Entry {
	singleton := newEntry(priority)

	// Merge singleton list with heap
	heap.min = mergeLists(heap.min, singleton)
	heap.size++
	return singleton
}

// Min returns the minimum element in the heap without removing it.
// It returns an EmptyHeapError when the heap holds no entries.
func (heap *FloatingFibonacciHeap) Min() (*Entry, error) {
	if heap.IsEmpty() {
		return nil, EmptyHeapError("Trying to get minimum element of empty heap")
	}
	return heap.min, nil
}

// IsEmpty answers: is the heap empty?
func (heap *FloatingFibonacciHeap) IsEmpty() bool {
	return heap.size == 0
}

// Size gives the number of elements in the heap
func (heap *FloatingFibonacciHeap) Size() uint {
	return heap.size
}

// DequeueMin removes and returns the minimal element in the heap.
// It returns an EmptyHeapError when the heap holds no entries.
func (heap *FloatingFibonacciHeap) DequeueMin() (*Entry, error) {
	if heap.IsEmpty() {
		return nil, EmptyHeapError("Cannot dequeue minimum of empty heap")
	}

	heap.size--

	// Copy pointer. Will need it later.
	min := heap.min

	if min.next == min { // This is the only root node
		heap.min = nil
	} else { // There are more root nodes
		min.prev.next = min.next
		min.next.prev = min.prev
		heap.min = min.next // Arbitrary element of the root list
	}

	// The children of the removed root become roots themselves.
	if min.child != nil {
		curr := min.child
		for ok := true; ok; ok = (curr != min.child) {
			curr.parent = nil
			curr = curr.next
		}
	}

	heap.min = mergeLists(heap.min, min.child)

	if heap.min == nil {
		// If there are no entries left, we're done.
		return min, nil
	}

	heap.consolidate()
	return min, nil
}

// consolidate links root trees of equal degree until every degree occurs
// at most once in the root list, and re-points heap.min at the smallest
// root. The heap must be non-empty.
func (heap *FloatingFibonacciHeap) consolidate() {
	// Snapshot the root list first: linking trees rewires the list, so it
	// cannot be walked and modified at the same time. The scratch slices
	// grow on demand instead of being sized by heap.size, which would cost
	// O(n) on every dequeue.
	var toVisit []*Entry
	for curr := heap.min; len(toVisit) == 0 || toVisit[0] != curr; curr = curr.next {
		toVisit = append(toVisit, curr)
	}

	// treeSlice[d] is the root of degree d seen so far, or nil.
	var treeSlice []*Entry

	for _, curr := range toVisit {
		for {
			for curr.degree >= len(treeSlice) {
				treeSlice = append(treeSlice, nil)
			}

			if treeSlice[curr.degree] == nil {
				treeSlice[curr.degree] = curr
				break
			}

			other := treeSlice[curr.degree]
			treeSlice[curr.degree] = nil

			// Determine which of two trees has the smaller root
			minT, maxT := curr, other
			if other.Priority < curr.Priority {
				minT, maxT = other, curr
			}

			// Break max out of the root list,
			// then merge it into min's child list
			maxT.next.prev = maxT.prev
			maxT.prev.next = maxT.next

			// Make it a singleton so that we can merge it
			maxT.prev = maxT
			maxT.next = maxT
			minT.child = mergeLists(minT.child, maxT)

			// Reparent max appropriately
			maxT.parent = minT

			// Clear max's mark, since it can now lose another child
			maxT.marked = false

			// Increase min's degree. It has another child.
			minT.degree++

			// Continue merging this tree
			curr = minT
		}

		// Update the global min based on this node. Note that we compare
		// for <= instead of < here. That's because if we just did a
		// reparent operation that merged two different trees of equal
		// priority, we need to make sure that the min pointer points to
		// the root-level one.
		if curr.Priority <= heap.min.Priority {
			heap.min = curr
		}
	}
}

// DecreaseKey decreases the key of the given element, sets it to the new
// given priority and returns the node if successfully set. It fails with an
// EmptyHeapError on an empty heap, a NilError on a nil node, and a plain
// error when newPriority is not strictly smaller than the current priority.
func (heap *FloatingFibonacciHeap) DecreaseKey(node *Entry, newPriority float64) (*Entry, error) {
	if heap.IsEmpty() {
		return nil, EmptyHeapError("Cannot decrease key in an empty heap")
	}

	if node == nil {
		return nil, NilError("Cannot decrease key: given node is nil")
	}

	if newPriority >= node.Priority {
		return nil, fmt.Errorf("The given new priority: %v, is larger than or equal to the old: %v",
			newPriority, node.Priority)
	}

	decreaseKeyUnchecked(heap, node, newPriority)
	return node, nil
}

// Delete deletes the given element in the heap. It fails with an
// EmptyHeapError on an empty heap and a NilError on a nil node.
//
// The node is removed by lowering its key to negative infinity and then
// dequeuing the minimum, so node.Priority reads -Inf afterwards. Negative
// infinity (rather than the smallest finite float) guarantees the node wins
// the comparison against every other stored priority, including -Inf itself.
func (heap *FloatingFibonacciHeap) Delete(node *Entry) error {
	if heap.IsEmpty() {
		return EmptyHeapError("Cannot delete element from an empty heap")
	}

	if node == nil {
		return NilError("Cannot delete node: given node is nil")
	}

	decreaseKeyUnchecked(heap, node, math.Inf(-1))
	_, err := heap.DequeueMin()
	return err
}

// Merge returns a new Fibonacci heap that contains
// all of the elements of the two heaps. Each of the input heaps is
// destructively modified by having all its elements removed. You can
// continue to use those heaps, but be aware that they will be empty
// after this call completes.
//
// Merging a heap with itself moves its elements into the result once.
// A NilError is returned when either heap is nil.
func (heap *FloatingFibonacciHeap) Merge(other *FloatingFibonacciHeap) (FloatingFibonacciHeap, error) {
	if heap == nil || other == nil {
		return FloatingFibonacciHeap{}, NilError("One of the heaps to merge is nil. Cannot merge")
	}

	result := FloatingFibonacciHeap{min: heap.min, size: heap.size}
	if other != heap {
		// Splicing a list into itself would keep the nodes intact but
		// count every element twice, hence the identity check above.
		result.min = mergeLists(heap.min, other.min)
		result.size += other.size
	}

	heap.min, heap.size = nil, 0
	other.min, other.size = nil, 0

	return result, nil
}

/******************************************
 ************** END INTERFACE *************
 ******************************************/

// ****************
// HELPER FUNCTIONS
// ****************

// newEntry builds a singleton node: a circular list of length one with no
// parent and no children.
func newEntry(priority float64) *Entry {
	result := &Entry{Priority: priority}
	result.next = result
	result.prev = result
	return result
}

// mergeLists splices two circular sibling lists into one and returns the
// argument with the smaller priority (two on ties). Either list may be nil,
// in which case the other one is returned unchanged.
func mergeLists(one, two *Entry) *Entry {
	switch {
	case one == nil:
		return two
	case two == nil:
		return one
	}

	// Both trees non-null; actually do the merge.
	oneNext := one.next
	one.next = two.next
	one.next.prev = one
	two.next = oneNext
	two.next.prev = two

	if one.Priority < two.Priority {
		return one
	}
	return two
}

// decreaseKeyUnchecked sets the node's priority without validating it,
// cuts the node from its parent when it no longer exceeds the parent's
// priority, and updates heap.min when the node is now the smallest entry.
func decreaseKeyUnchecked(heap *FloatingFibonacciHeap, node *Entry, priority float64) {
	node.Priority = priority

	// Ties are cut as well, so a node lowered to the same key as its
	// parent still ends up in the root list.
	if node.parent != nil && node.Priority <= node.parent.Priority {
		cutNode(heap, node)
	}

	if node.Priority <= heap.min.Priority {
		heap.min = node
	}
}

// cutNode detaches node from its parent and moves it to the root list.
// A parent that was already marked is cut in turn; otherwise it is marked.
// Calling it on a root only clears the root's mark.
func cutNode(heap *FloatingFibonacciHeap, node *Entry) {
	node.marked = false

	if node.parent == nil {
		return
	}

	// Rewire siblings if it has any
	if node.next != node {
		node.next.prev = node.prev
		node.prev.next = node.next
	}

	// Rewrite pointer if this is the representative child node
	if node.parent.child == node {
		if node.next != node {
			node.parent.child = node.next
		} else {
			node.parent.child = nil
		}
	}

	node.parent.degree--

	// Turn the node into a singleton before splicing it into the roots.
	node.prev = node
	node.next = node
	heap.min = mergeLists(heap.min, node)

	// cut parent recursively if marked
	if node.parent.marked {
		cutNode(heap, node.parent)
	} else {
		node.parent.marked = true
	}

	node.parent = nil
}
fibheap/fibheap_examples_test.go ================================================ package fibheap // Tests for the Fibonacci heap with floating point number priorities import ( "fmt" ) const SomeNumber float64 = 15.5 const SomeSmallerNumber float64 = -10.1 const SomeLargerNumber float64 = 112.211 func ExampleFloatingFibonacciHeap_Enqueue() { heap := NewFloatFibHeap() // The function returns a pointer // to the node that contains the new value node := heap.Enqueue(SomeNumber) fmt.Println(node.Priority) // Output: 15.5 } func ExampleFloatingFibonacciHeap_Min() { heap := NewFloatFibHeap() heap.Enqueue(SomeNumber) heap.Enqueue(SomeLargerNumber) min, _ := heap.Min() fmt.Println(min.Priority) // Output: 15.5 } func ExampleFloatingFibonacciHeap_IsEmpty() { heap := NewFloatFibHeap() fmt.Printf("Empty before insert? %v\n", heap.IsEmpty()) heap.Enqueue(SomeNumber) fmt.Printf("Empty after insert? %v\n", heap.IsEmpty()) // Output: // Empty before insert? true // Empty after insert? false } func ExampleFloatingFibonacciHeap_Size() { heap := NewFloatFibHeap() fmt.Printf("Size before insert: %v\n", heap.Size()) heap.Enqueue(SomeNumber) fmt.Printf("Size after insert: %v\n", heap.Size()) // Output: // Size before insert: 0 // Size after insert: 1 } func ExampleFloatingFibonacciHeap_DequeueMin() { heap := NewFloatFibHeap() heap.Enqueue(SomeNumber) node, _ := heap.DequeueMin() fmt.Printf("Dequeueing minimal element: %v\n", node.Priority) // Output: // Dequeueing minimal element: 15.5 } func ExampleFloatingFibonacciHeap_DecreaseKey() { heap := NewFloatFibHeap() node := heap.Enqueue(SomeNumber) min, _ := heap.Min() fmt.Printf("Minimal element before decreasing key: %v\n", min.Priority) heap.DecreaseKey(node, SomeSmallerNumber) min, _ = heap.Min() fmt.Printf("Minimal element after decreasing key: %v\n", min.Priority) // Output: // Minimal element before decreasing key: 15.5 // Minimal element after decreasing key: -10.1 } func ExampleFloatingFibonacciHeap_Delete() { heap := 
NewFloatFibHeap() node := heap.Enqueue(SomeNumber) heap.Enqueue(SomeLargerNumber) min, _ := heap.Min() fmt.Printf("Minimal element before deletion: %v\n", min.Priority) heap.Delete(node) min, _ = heap.Min() fmt.Printf("Minimal element after deletion: %v\n", min.Priority) // Output: // Minimal element before deletion: 15.5 // Minimal element after deletion: 112.211 } func ExampleFloatingFibonacciHeap_Merge() { heap1 := NewFloatFibHeap() heap2 := NewFloatFibHeap() heap1.Enqueue(SomeNumber) heap1.Enqueue(SomeLargerNumber) heap2.Enqueue(SomeSmallerNumber) min, _ := heap1.Min() fmt.Printf("Minimal element of heap 1: %v\n", min.Priority) min, _ = heap2.Min() fmt.Printf("Minimal element of heap 2: %v\n", min.Priority) heap, _ := heap1.Merge(&heap2) min, _ = heap.Min() fmt.Printf("Minimal element of merged heap: %v\n", min.Priority) // Output: // Minimal element of heap 1: 15.5 // Minimal element of heap 2: -10.1 // Minimal element of merged heap: -10.1 } ================================================ FILE: fibheap/fibheap_single_example_test.go ================================================ package fibheap // Example usage of the Fibonacci heap import ( "fmt" ) const SomeNumberAround0 float64 = -0.001 const SomeLargerNumberAround15 float64 = 15.77 const SomeNumberAroundMinus1000 float64 = -1002.2001 const SomeNumberAroundMinus1003 float64 = -1003.4 func Example() { heap1 := NewFloatFibHeap() fmt.Println("Created heap 1.") nodeh1_1 := heap1.Enqueue(SomeLargerNumberAround15) fmt.Printf("Heap 1 insert: %v\n", nodeh1_1.Priority) heap2 := NewFloatFibHeap() fmt.Println("Created heap 2.") fmt.Printf("Heap 2 is empty? %v\n", heap2.IsEmpty()) nodeh2_1 := heap2.Enqueue(SomeNumberAroundMinus1000) fmt.Printf("Heap 2 insert: %v\n", nodeh2_1.Priority) nodeh2_2 := heap2.Enqueue(SomeNumberAround0) fmt.Printf("Heap 2 insert: %v\n", nodeh2_2.Priority) fmt.Printf("Heap 1 size: %v\n", heap1.Size()) fmt.Printf("Heap 2 size: %v\n", heap2.Size()) fmt.Printf("Heap 1 is empty? 
%v\n", heap1.IsEmpty()) fmt.Printf("Heap 2 is empty? %v\n", heap2.IsEmpty()) fmt.Printf("\nMerge Heap 1 and Heap 2.\n") mergedHeap, _ := heap1.Merge(&heap2) fmt.Printf("Merged heap size: %v\n", mergedHeap.Size()) fmt.Printf("Set node with priority %v to new priority %v\n", SomeNumberAroundMinus1000, SomeNumberAroundMinus1003) mergedHeap.DecreaseKey(nodeh2_1, SomeNumberAroundMinus1003) min, _ := mergedHeap.DequeueMin() fmt.Printf("Dequeue minimum of merged heap: %v\n", min.Priority) fmt.Printf("Merged heap size: %v\n", mergedHeap.Size()) fmt.Printf("Delete from merged heap: %v\n", SomeNumberAround0) mergedHeap.Delete(nodeh2_2) fmt.Printf("Merged heap size: %v\n", mergedHeap.Size()) min, _ = mergedHeap.DequeueMin() fmt.Printf("Extracting minimum of merged heap: %v\n", min.Priority) fmt.Printf("Merged heap size: %v\n", mergedHeap.Size()) fmt.Printf("Merged heap is empty? %v\n", mergedHeap.IsEmpty()) // Output: // Created heap 1. // Heap 1 insert: 15.77 // Created heap 2. // Heap 2 is empty? true // Heap 2 insert: -1002.2001 // Heap 2 insert: -0.001 // Heap 1 size: 1 // Heap 2 size: 2 // Heap 1 is empty? false // Heap 2 is empty? false // // Merge Heap 1 and Heap 2. // Merged heap size: 3 // Set node with priority -1002.2001 to new priority -1003.4 // Dequeue minimum of merged heap: -1003.4 // Merged heap size: 2 // Delete from merged heap: -0.001 // Merged heap size: 1 // Extracting minimum of merged heap: 15.77 // Merged heap size: 0 // Merged heap is empty? true } ================================================ FILE: fibheap/fibheap_test.go ================================================ package fibheap // Tests for the Fibonacci heap with floating point number priorities import ( "testing" "math/rand" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // Go does not have constant arrays. // Settling for standard variables. 
var NumberSequence1 = [...]float64{6145466173.743959, 1717075442.6908855, -9223106115.008125, 6664774768.783949, -9185895273.675707, -2271628840.682966, -6843837387.469989, -3075112103.982916, -7315786187.596851, 9022422938.330479, 9230482598.051868, -2019031911.3141594, 4852342381.928253, 7767018098.497437, -5163143977.984332, 7265142312.343864, -9974588724.261246, -4721177341.970384, 6608275091.590723, -2509051968.8908787, -2608600569.397663, 4602079812.256586, 4204221071.262924, 2072073006.576254, -1375445006.5510921, 9753983872.378643, 3379810998.918478, -2120599284.15699, -9284902029.588614, 3804069225.763077, 4680667479.457649, 3550845076.5165443, 689351033.7409191, -6170564101.460268, 5769309548.4711685, -7203959673.554039, -1542719821.5259266, 8314666872.8992195, 4582459708.761353, 4558164249.709116, -409019759.7648945, 2050647646.0881348, 3337347280.2468243, 8841975976.437397, -1540752999.8368673, 4548535015.628077, -7013783667.095476, 2287926261.9939594, -2539231979.834078, -9359850979.452446, 5390795464.938633, -9969381716.563528, 3273172669.620493, -8839719143.511513, 9436856014.244781, 9032693590.852093, 748366072.01511, -8165322713.346881, -9745450118.0132, -6554663739.562494, -8350123090.830288, 4767099194.408716, -741610722.9710865, 978853190.937952, -4689006449.5764475, 6712607751.828266, 1834187952.9013042, 8144068220.835762, 2649156704.6132507, 5206492575.513319, 2355676989.886942, 6014313651.805082, 1559476573.9042358, -611075813.2161636, -3428570708.324188, 3758297334.844446, -73880069.57582092, 7939090089.227123, -6135368824.336376, 5680302744.840729, 7067968530.463007, -4736146992.716046, 6787733005.103142, 8291261997.956814, -7976948033.245457, -2717662205.411746, 1753831326.4953232, 3313929049.058649, -6798511690.417229, 4259620288.6441, -8795846089.203701, 666087815.4947224, -3189108786.1266823, 6098522858.07811, 3670419236.2020073, -4904172359.7338295, 7081860835.300518, 4838004130.57917, -8403025837.455175, 2858604246.067789, 
9767232443.473625, 1853770486.2323227, 2111315124.8128128, -789990089.2266369, 3855299652.837984, -5262051498.344847, 5195097083.198868, -9453697711.29756, -144320772.42621613, -3280154832.042288, 4327603656.616592, -4916338352.631529, 177342499.89391518, -6863008836.282527, -4462732551.435464, 563531299.3931465, 243815563.513546, -2177539298.657405, 9064363201.461056, 7752407089.025448, 5072315736.623476, 1676308335.832735, 2368433225.444128, 7191228067.770271, -7952866649.176966, 9029961422.270164, -3694580624.20329, 2396384720.634838, 2919689806.6469193, 2516309466.887434, 5711191379.798178, -7111997035.1143055, -5887152915.558975, 7074496594.814234, 72399466.26899147, 9162739770.93885, 545095642.1330223, 589248875.6552525, 5429718452.359911, 2670541446.0850983, 7074768275.337322, -9376701618.064901, -719716639.8418808, 5870465712.600103, 8906050348.824574, 5260686230.481573, 4525930216.3939705, -7558925556.569441, -3524217648.1943235, -8559543174.289785, -402353821.38601303, -2939238306.2766924, -8421788462.600799, 173509960.46243477, 2823962320.1096497, -2040044596.465724, 8093258879.034134, 1026657583.5726833, -5939324535.959578, 1869187366.0910244, -8488159448.309237, -9162642241.327745, 9198652822.209103, 9981219597.001732, 1245929264.1492062, 6333145610.418182, -5007933225.524759, -7507006648.70326, -8682109235.019928, 7572534048.487186, 9172777289.492256, -4374595711.753318, 7302929281.918972, 6813548014.888256, 7839035144.903576, -5126801855.122898, 6523728766.098036, -8063474434.226172, -1011764426.4069233, -5468146510.412097, -7725685149.169344, 5224407910.623154, 5337833362.662783, 3878206583.8412895, -9990847539.012056, 2828249626.7454433, -8802730816.790993, -6223950138.847174, -5003095866.683969, 3701841328.9391365, -7438103512.551224, -1879515137.467103, -6931067459.813007, -3591253518.1452456, -3249229927.5027523, 249923973.47061348, -7291235820.978601, -4073015010.864023, -3089932753.657503, 8220825130.164364} const Seq1FirstMinimum float64 = 
-9990847539.012056 const Seq1ThirdMinimum float64 = -9969381716.563528 const Seq1FifthMinimum float64 = -9453697711.29756 const Seq1LastMinimum float64 = 9981219597.001732 var NumberSequence2 = [...]float64{-2901939070.965906, 4539462982.372177, -6222008480.049856, -1400427921.5968666, 9866088144.060883, -2943107648.529664, 8985474333.11443, 9204710651.257133, 5354113876.8447075, 8122228442.770859, -8121418938.303131, 538431208.3261185, 9913821013.519611, -8722989752.449871, -3091279426.694975, 7229910558.195713, -2908838839.99403, 2835257231.305996, 3922059795.3656673, -9298869735.322557} const Seq2DecreaseKey1Orig float64 = 9913821013.519611 const Seq2DecreaseKey1Trgt float64 = -8722989752.449871 const Seq2DecreaseKey2Orig float64 = 9866088144.060883 const Seq2DecreaseKey2Trgt float64 = -9698869735.322557 const Seq2DecreaseKey3Orig float64 = 9204710651.257133 const Seq2DecreaseKey3Trgt float64 = -9804710651.257133 var NumberSequence2Sorted = [...]float64{-9804710651.257133, -9698869735.322557, -9298869735.322557, -8722989752.449871, -8722989752.449871, -8121418938.303131, -6222008480.049856, -3091279426.694975, -2943107648.529664, -2908838839.99403, -2901939070.965906, -1400427921.5968666, 538431208.3261185, 2835257231.305996, 3922059795.3656673, 4539462982.372177, 5354113876.8447075, 7229910558.195713, 8122228442.770859, 8985474333.11443} var NumberSequence2Deleted3ElemSorted = [...]float64{-9298869735.322557, -8722989752.449871, -8121418938.303131, -6222008480.049856, -3091279426.694975, -2943107648.529664, -2908838839.99403, -2901939070.965906, -1400427921.5968666, 538431208.3261185, 2835257231.305996, 3922059795.3656673, 4539462982.372177, 5354113876.8447075, 7229910558.195713, 8122228442.770859, 8985474333.11443} var NumberSequence3 = [...]float64{6015943293.071386, -3878285748.0708866, 8674121166.062424, -1528465047.6118088, 7584260716.494843, -373958476.80486107, -6367787695.054295, 6813992306.719868, 5986097626.907181, 9011134545.052086, 
7123644338.268343, 2646164210.08445, 4407427446.995375, -888196668.2563229, 7973918726.985172, -6529216482.09644, 6079069259.51853, -8415952427.784341, -6859960084.757652, -502409126.89040375} var NumberSequence4 = [...]float64{9241165993.258648, -9423768405.578083, 3280085607.6687145, -5253703037.682413, 3858507441.2785892, 9896256282.896187, -9439606732.236805, 3082628799.5320206, 9453124863.59945, 9928066165.458393, 1135071669.4712334, 6380353457.986282, 8329064041.853199, 2382910730.445751, -8478491750.445316, 9607469190.690144, 5417691217.440792, -9698248424.421888, -3933774735.280322, -5984555343.381466} var NumberSequenceMerged3And4Sorted = [...]float64{-9698248424.421888, -9439606732.236805, -9423768405.578083, -8478491750.445316, -8415952427.784341, -6859960084.757652, -6529216482.09644, -6367787695.054295, -5984555343.381466, -5253703037.682413, -3933774735.280322, -3878285748.0708866, -1528465047.6118088, -888196668.2563229, -502409126.89040375, -373958476.80486107, 1135071669.4712334, 2382910730.445751, 2646164210.08445, 3082628799.5320206, 3280085607.6687145, 3858507441.2785892, 4407427446.995375, 5417691217.440792, 5986097626.907181, 6015943293.071386, 6079069259.51853, 6380353457.986282, 6813992306.719868, 7123644338.268343, 7584260716.494843, 7973918726.985172, 8329064041.853199, 8674121166.062424, 9011134545.052086, 9241165993.258648, 9453124863.59945, 9607469190.690144, 9896256282.896187, 9928066165.458393} func TestEnqueueDequeueMin(t *testing.T) { heap := NewFloatFibHeap() for i := 0; i < len(NumberSequence1); i++ { heap.Enqueue(NumberSequence1[i]) } var min *Entry var err error for heap.Size() > 0 { min, err = heap.DequeueMin() require.NoError(t, err) if heap.Size() == 199 { assert.Equal(t, Seq1FirstMinimum, min.Priority) } if heap.Size() == 197 { assert.Equal(t, Seq1ThirdMinimum, min.Priority) } if heap.Size() == 195 { assert.Equal(t, Seq1FifthMinimum, min.Priority) } if heap.Size() == 0 { assert.Equal(t, Seq1LastMinimum, min.Priority) } } } 
func TestFibHeap_Enqueue_Min(t *testing.T) { heap := NewFloatFibHeap() for i := 0; i < len(NumberSequence1); i++ { heap.Enqueue(NumberSequence1[i]) } min, err := heap.Min() require.NoError(t, err) assert.Equal(t, Seq1FirstMinimum, min.Priority) } func TestFibHeap_Min_EmptyHeap(t *testing.T) { heap := NewFloatFibHeap() heap.Enqueue(0) min, err := heap.DequeueMin() require.NoError(t, err) // Heap should be empty at this point min, err = heap.Min() assert.EqualError(t, err, "Trying to get minimum element of empty heap") assert.Nil(t, min) } func TestFibHeap_DequeueMin_EmptyHeap(t *testing.T) { heap := NewFloatFibHeap() min, err := heap.DequeueMin() assert.IsType(t, EmptyHeapError(""), err) assert.EqualError(t, err, "Cannot dequeue minimum of empty heap") assert.Nil(t, min) } func TestEnqueueDecreaseKey(t *testing.T) { heap := NewFloatFibHeap() var e1, e2, e3 *Entry for i := 0; i < len(NumberSequence2); i++ { if NumberSequence2[i] == Seq2DecreaseKey1Orig { e1 = heap.Enqueue(NumberSequence2[i]) } else if NumberSequence2[i] == Seq2DecreaseKey2Orig { e2 = heap.Enqueue(NumberSequence2[i]) } else if NumberSequence2[i] == Seq2DecreaseKey3Orig { e3 = heap.Enqueue(NumberSequence2[i]) } else { heap.Enqueue(NumberSequence2[i]) } } require.NotNil(t, e1) require.NotNil(t, e2) require.NotNil(t, e3) _, err := heap.DecreaseKey(e1, Seq2DecreaseKey1Trgt) require.NoError(t, err) _, err = heap.DecreaseKey(e2, Seq2DecreaseKey2Trgt) require.NoError(t, err) _, err = heap.DecreaseKey(e3, Seq2DecreaseKey3Trgt) require.NoError(t, err) var min *Entry for i := 0; i < len(NumberSequence2Sorted); i++ { min, err = heap.DequeueMin() require.NoError(t, err) assert.Equal(t, NumberSequence2Sorted[i], min.Priority) } } func TestFibHeap_DecreaseKey_EmptyHeap(t *testing.T) { heap := NewFloatFibHeap() elem := heap.Enqueue(15) heap.DequeueMin() // Heap should be empty at this point min, err := heap.DecreaseKey(elem, 0) assert.IsType(t, EmptyHeapError(""), err) assert.EqualError(t, err, "Cannot decrease key 
in an empty heap") assert.Nil(t, min) } func TestFibHeap_DecreaseKey_NilNode(t *testing.T) { heap := NewFloatFibHeap() heap.Enqueue(1) min, err := heap.DecreaseKey(nil, 0) assert.IsType(t, NilError(""), err) assert.EqualError(t, err, "Cannot decrease key: given node is nil") assert.Nil(t, min) } func TestFibHeap_DecreaseKey_LargerNewPriority(t *testing.T) { heap := NewFloatFibHeap() node := heap.Enqueue(1) min, err := heap.DecreaseKey(node, 20) assert.EqualError(t, err, "The given new priority: 20, is larger than or equal to the old: 1") assert.Nil(t, min) } func TestEnqueueDelete(t *testing.T) { heap := NewFloatFibHeap() var e1, e2, e3 *Entry for i := 0; i < len(NumberSequence2); i++ { if NumberSequence2[i] == Seq2DecreaseKey1Orig { e1 = heap.Enqueue(NumberSequence2[i]) } else if NumberSequence2[i] == Seq2DecreaseKey2Orig { e2 = heap.Enqueue(NumberSequence2[i]) } else if NumberSequence2[i] == Seq2DecreaseKey3Orig { e3 = heap.Enqueue(NumberSequence2[i]) } else { heap.Enqueue(NumberSequence2[i]) } } assert.NotNil(t, e1) assert.NotNil(t, e2) assert.NotNil(t, e3) var err error err = heap.Delete(e1) require.NoError(t, err) err = heap.Delete(e2) require.NoError(t, err) err = heap.Delete(e3) require.NoError(t, err) var min *Entry for i := 0; i < len(NumberSequence2Deleted3ElemSorted); i++ { min, err = heap.DequeueMin() require.NoError(t, err) assert.Equal(t, NumberSequence2Deleted3ElemSorted[i], min.Priority) } } func TestFibHeap_Delete_EmptyHeap(t *testing.T) { heap := NewFloatFibHeap() elem := heap.Enqueue(15) heap.DequeueMin() // Heap should be empty at this point err := heap.Delete(elem) assert.IsType(t, EmptyHeapError(""), err) assert.EqualError(t, err, "Cannot delete element from an empty heap") } func TestFibHeap_Delete_NilNode(t *testing.T) { heap := NewFloatFibHeap() heap.Enqueue(1) err := heap.Delete(nil) assert.IsType(t, NilError(""), err) assert.EqualError(t, err, "Cannot delete node: given node is nil") } func TestMerge(t *testing.T) { heap1 := 
NewFloatFibHeap() for i := 0; i < len(NumberSequence3); i++ { heap1.Enqueue(NumberSequence3[i]) } heap2 := NewFloatFibHeap() for i := 0; i < len(NumberSequence4); i++ { heap1.Enqueue(NumberSequence4[i]) } heap, err := heap1.Merge(&heap2) require.NoError(t, err) var min *Entry for i := 0; i < len(NumberSequenceMerged3And4Sorted); i++ { min, err = heap.DequeueMin() require.NoError(t, err) assert.Equal(t, NumberSequenceMerged3And4Sorted[i], min.Priority) } } func TestFibHeap_Merge_NilHeap(t *testing.T) { var heap FloatingFibonacciHeap heap = NewFloatFibHeap() newHeap, err := heap.Merge(nil) assert.IsType(t, NilError(""), err) assert.EqualError(t, err, "One of the heaps to merge is nil. Cannot merge") assert.Equal(t, newHeap, FloatingFibonacciHeap{}) } // *************** // BENCHMARK TESTS // *************** /* Since the e.g. Enqeue operation is constant time, when go benchmark increases N, the prep time will increase linearly, but the actual operation we want to measure will always take the same, constant amount of time. This means that on some machines, Go Bench could try to exponentially increase N in order to decrease noise in the measurement, but it will get more and more noise. This can cause a system to run out of RAM. So be careful if you have a fast PC. I have removed the b.ResetTimer on constant-time functions to avoid this negative-feedback loop. 
*/ // Runs in O(1) time func BenchmarkFibHeap_Enqueue(b *testing.B) { heap := NewFloatFibHeap() for i := 0; i < b.N; i++ { heap.Enqueue(2 * 1E10 * (rand.Float64() - 0.5)) } } // Runs in O(log(N)) time func BenchmarkFibHeap_DequeueMin(b *testing.B) { heap := NewFloatFibHeap() N := 1000000 slice := make([]float64, 0, N) for i := 0; i < N; i++ { slice = append(slice, 2*1E10*(rand.Float64()-0.5)) heap.Enqueue(slice[i]) } b.ResetTimer() for i := 0; i < b.N; i++ { heap.DequeueMin() } } // Runs in O(1) amortized time func BenchmarkFibHeap_DecreaseKey(b *testing.B) { heap := NewFloatFibHeap() N := 10000000 sliceFlt := make([]float64, 0, N) sliceE := make([]*Entry, 0, N) for i := 0; i < N; i++ { sliceFlt = append(sliceFlt, 2*1E10*(float64(i)-0.5)) sliceE = append(sliceE, heap.Enqueue(sliceFlt[i])) } b.ResetTimer() offset := float64(2) for i := 0; i < b.N; i++ { // Change offset if b.N larger than N if i%N == 0 && i > 0 { offset *= float64(i / N) } // Shift-decrease keys heap.DecreaseKey(sliceE[i%N], sliceFlt[i%N]-offset) } } // Runs in O(log(N)) time func BenchmarkFibHeap_Delete(b *testing.B) { heap := NewFloatFibHeap() N := 1000000 sliceFlt := make([]float64, 0, N) sliceE := make([]*Entry, 0, N) for i := 0; i < N; i++ { sliceFlt = append(sliceFlt, 2*1E10*(float64(i)-0.5)) sliceE = append(sliceE, heap.Enqueue(sliceFlt[i])) } // Delete runs in log(N) time // so safe to reset timer here b.ResetTimer() for i := 0; i < b.N; i++ { err := heap.Delete(sliceE[i]) assert.NoError(b, err) } } // Runs in O(1) time func BenchmarkFibHeap_Merge(b *testing.B) { heap1 := NewFloatFibHeap() heap2 := NewFloatFibHeap() for i := 0; i < b.N; i++ { heap1.Enqueue(2 * 1E10 * (rand.Float64() - 0.5)) heap2.Enqueue(2 * 1E10 * (rand.Float64() - 0.5)) _, err := heap1.Merge(&heap2) assert.NoError(b, err) } } ================================================ FILE: futures/futures.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 
// Completer is the receive-only channel a future reads its result from.
type Completer <-chan interface{}

// Future represents an object that can be used to perform asynchronous
// tasks. It is completed exactly once, either with the item received on
// its Completer or with a timeout error, and every caller of GetResult
// observes that same outcome regardless of how many callers there are.
type Future struct {
	// triggered records completion separately from item, because a nil
	// item is a valid result.
	triggered bool
	item      interface{}
	err       error
	lock      sync.Mutex     // guards triggered, item and err while they are written
	wg        sync.WaitGroup // released once the result has been stored
}

// GetResult returns the stored result right away when there is one, and
// otherwise blocks until the future has been completed.
func (f *Future) GetResult() (interface{}, error) {
	if !f.HasResult() {
		f.wg.Wait()
	}
	return f.item, f.err
}

// HasResult will return true iff the result exists
func (f *Future) HasResult() bool {
	f.lock.Lock()
	defer f.lock.Unlock()
	return f.triggered
}

// setItem stores the outcome under the lock and then releases every
// goroutine blocked in GetResult.
func (f *Future) setItem(item interface{}, err error) {
	f.lock.Lock()
	f.item, f.err, f.triggered = item, err, true
	f.lock.Unlock()
	f.wg.Done()
}

// listenForResult completes f with the first value received on ch, or with
// a timeout error when nothing arrives within timeout. It signals wg as
// soon as it starts running.
func listenForResult(f *Future, ch Completer, timeout time.Duration, wg *sync.WaitGroup) {
	wg.Done()

	timer := time.NewTimer(timeout)
	select {
	case <-timer.C:
		f.setItem(nil, fmt.Errorf(`timeout after %f seconds`, timeout.Seconds()))
	case received := <-ch:
		f.setItem(received, nil)
		timer.Stop() // we want to trigger GC of this timer as soon as it's no longer needed
	}
}

// New is the constructor to generate a new future. Send the completed item
// on the completer channel and any listeners will get notified. If the
// timeout is hit before an item arrives, any listeners will get passed an
// error. New returns only after the listening goroutine has started.
func New(completer Completer, timeout time.Duration) *Future {
	f := new(Future)
	f.wg.Add(1)

	var started sync.WaitGroup
	started.Add(1)
	go listenForResult(f, completer, timeout, &started)
	started.Wait()

	return f
}
*/ package futures import ( "sync" "testing" "time" "github.com/stretchr/testify/assert" ) func TestWaitOnGetResult(t *testing.T) { completer := make(chan interface{}) f := New(completer, time.Duration(30*time.Minute)) var result interface{} var err error var wg sync.WaitGroup wg.Add(1) go func() { result, err = f.GetResult() wg.Done() }() completer <- `test` wg.Wait() assert.Nil(t, err) assert.Equal(t, `test`, result) // ensure we don't get paused on the next iteration. result, err = f.GetResult() assert.Equal(t, `test`, result) assert.Nil(t, err) } func TestHasResult(t *testing.T) { completer := make(chan interface{}) f := New(completer, time.Duration(30*time.Minute)) assert.False(t, f.HasResult()) var wg sync.WaitGroup wg.Add(1) go func() { f.GetResult() wg.Done() }() completer <- `test` wg.Wait() assert.True(t, f.HasResult()) } func TestTimeout(t *testing.T) { completer := make(chan interface{}) f := New(completer, time.Duration(0)) result, err := f.GetResult() assert.Nil(t, result) assert.NotNil(t, err) } func BenchmarkFuture(b *testing.B) { completer := make(chan interface{}) timeout := time.Duration(30 * time.Minute) var wg sync.WaitGroup b.ResetTimer() for i := 0; i < b.N; i++ { wg.Add(1) f := New(completer, timeout) go func() { f.GetResult() wg.Done() }() completer <- `test` wg.Wait() } } ================================================ FILE: futures/selectable.go ================================================ /* Copyright 2016 Workiva, LLC Copyright 2016 Sokolov Yura aka funny_falcon Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// ErrFutureCanceled signals that the future was canceled by a call to
// `f.Cancel()`.
var ErrFutureCanceled = errors.New("future canceled")

// Selectable is a future that exposes a channel for use in an external
// `select`. Any number of listeners may wait for the result, either with
// `f.GetResult()` or by receiving from `f.WaitChan()`, which is closed once
// the future is fulfilled.
// Selectable contains a sync.Mutex, so it must not be moved or copied.
type Selectable struct {
	m      sync.Mutex
	val    interface{}
	err    error
	wait   chan struct{}
	filled uint32
}

// NewSelectable returns a new selectable future.
// Note: this function exists for backward compatibility. A Selectable may
// also be allocated directly on the stack or embedded in a larger structure.
func NewSelectable() *Selectable {
	return new(Selectable)
}

// wchan returns the wait channel, creating it under the lock on first use.
func (f *Selectable) wchan() <-chan struct{} {
	f.m.Lock()
	defer f.m.Unlock()

	if f.wait == nil {
		f.wait = make(chan struct{})
	}
	return f.wait
}

// WaitChan returns a channel that is closed when the future is fulfilled.
func (f *Selectable) WaitChan() <-chan struct{} {
	if atomic.LoadUint32(&f.filled) != 1 {
		return f.wchan()
	}
	// Already fulfilled: hand out the shared, pre-closed channel.
	return closed
}

// GetResult waits for the future to be fulfilled and returns the value and
// error that were set first.
func (f *Selectable) GetResult() (interface{}, error) {
	if unfilled := atomic.LoadUint32(&f.filled) == 0; unfilled {
		<-f.wchan()
	}
	return f.val, f.err
}

// Fill stores v and e in the future unless it has already been fulfilled, in
// which case the call changes nothing. It returns the error currently held by
// the future: e on the first call, the originally stored error afterwards.
func (f *Selectable) Fill(v interface{}, e error) error {
	f.m.Lock()
	defer f.m.Unlock()

	if f.filled != 0 {
		return f.err
	}

	f.val, f.err = v, e
	atomic.StoreUint32(&f.filled, 1)

	// Wake any listeners, then point the future at the shared closed channel.
	if pending := f.wait; pending != nil {
		close(pending)
	}
	f.wait = closed

	return f.err
}

// SetValue is an alias for Fill(v, nil).
func (f *Selectable) SetValue(v interface{}) error {
	return f.Fill(v, nil)
}

// SetError is an alias for Fill(nil, e).
func (f *Selectable) SetError(e error) {
	f.Fill(nil, e)
}

// Cancel is an alias for SetError(ErrFutureCanceled).
func (f *Selectable) Cancel() {
	f.SetError(ErrFutureCanceled)
}

// closed is a shared channel, closed during package initialisation, that is
// handed out for futures which are already fulfilled.
var closed = make(chan struct{})

func init() {
	close(closed)
}
result, err = f.GetResult() assert.Equal(t, `test`, result) assert.Nil(t, err) } func TestSelectableSetError(t *testing.T) { f := NewSelectable() select { case <-f.WaitChan(): case <-time.After(0): f.SetError(fmt.Errorf("timeout")) } result, err := f.GetResult() assert.Nil(t, result) assert.NotNil(t, err) } func BenchmarkSelectable(b *testing.B) { timeout := time.After(30 * time.Minute) var wg sync.WaitGroup b.ResetTimer() for i := 0; i < b.N; i++ { wg.Add(1) f := NewSelectable() go func() { select { case <-f.WaitChan(): case <-timeout: f.SetError(fmt.Errorf("timeout")) } wg.Done() }() f.SetValue(`test`) wg.Wait() } } ================================================ FILE: go.mod ================================================ module github.com/Workiva/go-datastructures go 1.15 require ( github.com/stretchr/testify v1.7.0 github.com/tinylib/msgp v1.1.5 ) ================================================ FILE: go.sum ================================================ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/philhofer/fwd v1.1.1 h1:GdGcTjf5RNAxwS4QLsiMzJYj5KEvPJD3Abr261yRQXQ= github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tinylib/msgp v1.1.5 h1:2gXmtWueD2HefZHQe1QOy9HVzmFrLOVvsXwXBQ0ayy0= github.com/tinylib/msgp v1.1.5/go.mod h1:eQsjooMTnV42mHu917E26IogZ2930nFyBQdofk10Udg= 
github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31/go.mod h1:onvgF043R+lC5RZ8IT9rBXDaEDnpnw/Cl+HFiw+v/7Q= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201022035929-9cf592e881e9/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/xerrors 
v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= ================================================ FILE: graph/simple.go ================================================ /* Copyright 2017 Julian Griggs Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* Package graph provides graph implementations. Currently, this includes an undirected simple graph. */ package graph import ( "errors" "sync" ) var ( // ErrVertexNotFound is returned when an operation is requested on a // non-existent vertex. ErrVertexNotFound = errors.New("vertex not found") // ErrSelfLoop is returned when an operation tries to create a disallowed // self loop. ErrSelfLoop = errors.New("self loops not permitted") // ErrParallelEdge is returned when an operation tries to create a // disallowed parallel edge. 
ErrParallelEdge = errors.New("parallel edges are not permitted") ) // SimpleGraph is a mutable, non-persistent undirected graph. // Parallel edges and self-loops are not permitted. // Additional description: https://en.wikipedia.org/wiki/Graph_(discrete_mathematics)#Simple_graph type SimpleGraph struct { mutex sync.RWMutex adjacencyList map[interface{}]map[interface{}]struct{} v, e int } // V returns the number of vertices in the SimpleGraph func (g *SimpleGraph) V() int { g.mutex.RLock() defer g.mutex.RUnlock() return g.v } // E returns the number of edges in the SimpleGraph func (g *SimpleGraph) E() int { g.mutex.RLock() defer g.mutex.RUnlock() return g.e } // AddEdge will create an edge between vertices v and w func (g *SimpleGraph) AddEdge(v, w interface{}) error { g.mutex.Lock() defer g.mutex.Unlock() if v == w { return ErrSelfLoop } g.addVertex(v) g.addVertex(w) if _, ok := g.adjacencyList[v][w]; ok { return ErrParallelEdge } g.adjacencyList[v][w] = struct{}{} g.adjacencyList[w][v] = struct{}{} g.e++ return nil } // Adj returns the list of all vertices connected to v func (g *SimpleGraph) Adj(v interface{}) ([]interface{}, error) { g.mutex.RLock() defer g.mutex.RUnlock() deg, err := g.Degree(v) if err != nil { return nil, ErrVertexNotFound } adj := make([]interface{}, deg) i := 0 for key := range g.adjacencyList[v] { adj[i] = key i++ } return adj, nil } // Degree returns the number of vertices connected to v func (g *SimpleGraph) Degree(v interface{}) (int, error) { g.mutex.RLock() defer g.mutex.RUnlock() val, ok := g.adjacencyList[v] if !ok { return 0, ErrVertexNotFound } return len(val), nil } func (g *SimpleGraph) addVertex(v interface{}) { mm, ok := g.adjacencyList[v] if !ok { mm = make(map[interface{}]struct{}) g.adjacencyList[v] = mm g.v++ } } // NewSimpleGraph creates and returns a SimpleGraph func NewSimpleGraph() *SimpleGraph { return &SimpleGraph{ adjacencyList: make(map[interface{}]map[interface{}]struct{}), v: 0, e: 0, } } 
================================================ FILE: graph/simple_test.go ================================================ /* Copyright 2017 Julian Griggs Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package graph import ( "testing" "github.com/stretchr/testify/assert" ) func TestV(t *testing.T) { assert := assert.New(t) sgraph := NewSimpleGraph() assert.Equal(0, sgraph.V()) sgraph.AddEdge("A", "B") assert.Equal(2, sgraph.V()) sgraph.AddEdge("B", "C") assert.Equal(3, sgraph.V()) sgraph.AddEdge("A", "C") assert.Equal(3, sgraph.V()) // Parallel edges not allowed sgraph.AddEdge("A", "C") assert.Equal(3, sgraph.V()) sgraph.AddEdge("C", "A") assert.Equal(3, sgraph.V()) // Self loops not allowed sgraph.AddEdge("C", "C") assert.Equal(3, sgraph.V()) sgraph.AddEdge("D", "D") assert.Equal(3, sgraph.V()) } func TestE(t *testing.T) { assert := assert.New(t) sgraph := NewSimpleGraph() assert.Equal(0, sgraph.E()) sgraph.AddEdge("A", "B") assert.Equal(1, sgraph.E()) sgraph.AddEdge("B", "C") assert.Equal(2, sgraph.E()) sgraph.AddEdge("A", "C") assert.Equal(3, sgraph.E()) // Parallel edges not allowed sgraph.AddEdge("A", "C") assert.Equal(3, sgraph.E()) sgraph.AddEdge("C", "A") assert.Equal(3, sgraph.E()) // Self loops not allowed so no edges added sgraph.AddEdge("C", "C") assert.Equal(3, sgraph.E()) sgraph.AddEdge("D", "D") assert.Equal(3, sgraph.E()) } func TestDegree(t *testing.T) { assert := assert.New(t) sgraph := NewSimpleGraph() // No edges added so degree is 0 v, err := sgraph.Degree("A") assert.Zero(v) 
assert.Error(err) // One edge added sgraph.AddEdge("A", "B") v, err = sgraph.Degree("A") assert.Equal(1, v) assert.Nil(err) // Self loops are not allowed sgraph.AddEdge("A", "A") v, err = sgraph.Degree("A") assert.Equal(1, v) assert.Nil(err) // Parallel edges are not allowed sgraph.AddEdge("A", "B") v, err = sgraph.Degree("A") assert.Equal(1, v) assert.Nil(err) sgraph.AddEdge("B", "A") v, err = sgraph.Degree("A") assert.Equal(1, v) assert.Nil(err) v, err = sgraph.Degree("B") assert.Equal(1, v) assert.Nil(err) sgraph.AddEdge("C", "D") sgraph.AddEdge("A", "C") sgraph.AddEdge("E", "F") sgraph.AddEdge("E", "G") sgraph.AddEdge("H", "G") v, err = sgraph.Degree("A") assert.Equal(2, v) assert.Nil(err) v, err = sgraph.Degree("B") assert.Equal(1, v) assert.Nil(err) v, err = sgraph.Degree("C") assert.Equal(2, v) assert.Nil(err) v, err = sgraph.Degree("D") assert.Equal(1, v) assert.Nil(err) v, err = sgraph.Degree("E") assert.Equal(2, v) assert.Nil(err) v, err = sgraph.Degree("G") assert.Equal(2, v) assert.Nil(err) } func TestAddEdge(t *testing.T) { assert := assert.New(t) sgraph := NewSimpleGraph() err := sgraph.AddEdge("A", "B") assert.Nil(err) err = sgraph.AddEdge("A", "B") assert.Error(err) err = sgraph.AddEdge("B", "A") assert.Error(err) err = sgraph.AddEdge("A", "A") assert.Error(err) err = sgraph.AddEdge("C", "C") assert.Error(err) err = sgraph.AddEdge("B", "C") assert.Nil(err) } func TestAdj(t *testing.T) { assert := assert.New(t) sgraph := NewSimpleGraph() v, err := sgraph.Adj("A") assert.Zero(v) assert.Error(err) // Self loops not allowed sgraph.AddEdge("A", "A") v, err = sgraph.Adj("A") assert.Zero(v) assert.Error(err) sgraph.AddEdge("A", "B") v, err = sgraph.Adj("A") assert.Equal(1, len(v)) assert.Nil(err) assert.Equal("B", v[0]) v, err = sgraph.Adj("B") assert.Equal(1, len(v)) assert.Nil(err) assert.Equal("A", v[0]) // Parallel Edges not allowed sgraph.AddEdge("A", "B") sgraph.AddEdge("B", "A") v, err = sgraph.Adj("B") assert.Equal(1, len(v)) assert.Nil(err) 
assert.Equal("A", v[0]) sgraph.AddEdge("C", "D") sgraph.AddEdge("A", "C") sgraph.AddEdge("E", "F") sgraph.AddEdge("E", "G") sgraph.AddEdge("H", "G") v, err = sgraph.Adj("A") assert.Equal(2, len(v)) assert.Nil(err) assert.Contains(v, "B") assert.Contains(v, "C") assert.NotContains(v, "A") assert.NotContains(v, "D") v, err = sgraph.Adj("B") assert.Equal(1, len(v)) assert.Nil(err) assert.Contains(v, "A") assert.NotContains(v, "B") assert.NotContains(v, "C") assert.NotContains(v, "D") v, err = sgraph.Adj("C") assert.Equal(2, len(v)) assert.Nil(err) assert.Contains(v, "A") assert.Contains(v, "D") assert.NotContains(v, "B") assert.NotContains(v, "C") v, err = sgraph.Adj("E") assert.Equal(2, len(v)) assert.Nil(err) assert.Contains(v, "F") assert.Contains(v, "G") assert.NotContains(v, "A") v, err = sgraph.Adj("G") assert.Equal(2, len(v)) assert.Nil(err) assert.Contains(v, "E") assert.Contains(v, "H") assert.NotContains(v, "A") } ================================================ FILE: hashmap/fastinteger/hash.go ================================================ package fastinteger // hash will convert the uint64 key into a hash based on Murmur3's 64-bit // integer finalizer. 
// hash converts the uint64 key into a hash using Murmur3's 64-bit integer
// finalizer: three xor-shift steps by 33 bits, interleaved with two
// multiplications by fixed odd constants.
// Details here: https://code.google.com/p/smhasher/wiki/MurmurHash3
func hash(key uint64) uint64 {
	const (
		mul1 = 0xff51afd7ed558ccd
		mul2 = 0xc4ceb9fe1a85ec53
	)

	key = (key ^ key>>33) * mul1
	key = (key ^ key>>33) * mul2
	return key ^ key>>33
}
// // Current benchmarks on identical machine against native Go implementation: // BenchmarkInsert-8 10000 131258 ns/op // BenchmarkGoMapInsert-8 10000 208787 ns/op // BenchmarkExists-8 100000 15820 ns/op // BenchmarkGoMapExists-8 100000 16394 ns/op // BenchmarkDelete-8 100000 17909 ns/op // BenchmarkGoDelete-8 30000 49376 ns/op // BenchmarkInsertWithExpand-8 20000 90301 ns/op // BenchmarkGoInsertWithExpand-8 10000 142088 ns/op // // // This performance could be further enhanced by using a // better probing technique. package fastinteger const ratio = .75 // ratio sets the capacity the hashmap has to be at before it expands // roundUp takes a uint64 greater than 0 and rounds it up to the next // power of 2. func roundUp(v uint64) uint64 { v-- v |= v >> 1 v |= v >> 2 v |= v >> 4 v |= v >> 8 v |= v >> 16 v |= v >> 32 v++ return v } type packet struct { key, value uint64 } type packets []*packet func (packets packets) find(key uint64) uint64 { h := hash(key) i := h & (uint64(len(packets)) - 1) for packets[i] != nil && packets[i].key != key { i = (i + 1) & (uint64(len(packets)) - 1) } return i } func (packets packets) set(packet *packet) { i := packets.find(packet.key) if packets[i] == nil { packets[i] = packet return } packets[i].value = packet.value } func (packets packets) get(key uint64) (uint64, bool) { i := packets.find(key) if packets[i] == nil { return 0, false } return packets[i].value, true } func (packets packets) delete(key uint64) bool { i := packets.find(key) if packets[i] == nil { return false } packets[i] = nil i = (i + 1) & (uint64(len(packets)) - 1) for packets[i] != nil { p := packets[i] packets[i] = nil packets.set(p) i = (i + 1) & (uint64(len(packets)) - 1) } return true } func (packets packets) exists(key uint64) bool { i := packets.find(key) return packets[i] != nil // technically, they can store nil } // FastIntegerHashMap is a simple hashmap to be used with // integer only keys. 
It supports few operations, and is designed // primarily for cases where the consumer needs a very simple // datastructure to set and check for existence of integer // keys over a sparse range. type FastIntegerHashMap struct { count uint64 packets packets } // rebuild is an expensive operation which requires us to iterate // over the current bucket and rehash the keys for insertion into // the new bucket. The new bucket is twice as large as the old // bucket by default. func (fi *FastIntegerHashMap) rebuild() { packets := make(packets, roundUp(uint64(len(fi.packets))+1)) for _, packet := range fi.packets { if packet == nil { continue } packets.set(packet) } fi.packets = packets } // Get returns an item from the map if it exists. Otherwise, // returns false for the second argument. func (fi *FastIntegerHashMap) Get(key uint64) (uint64, bool) { return fi.packets.get(key) } // Set will set the provided key with the provided value. func (fi *FastIntegerHashMap) Set(key, value uint64) { if float64(fi.count+1)/float64(len(fi.packets)) > ratio { fi.rebuild() } fi.packets.set(&packet{key: key, value: value}) fi.count++ } // Exists will return a bool indicating if the provided key // exists in the map. func (fi *FastIntegerHashMap) Exists(key uint64) bool { return fi.packets.exists(key) } // Delete will remove the provided key from the hashmap. If // the key cannot be found, this is a no-op. func (fi *FastIntegerHashMap) Delete(key uint64) { if fi.packets.delete(key) { fi.count-- } } // Len returns the number of items in the hashmap. func (fi *FastIntegerHashMap) Len() uint64 { return fi.count } // Cap returns the capacity of the hashmap. func (fi *FastIntegerHashMap) Cap() uint64 { return uint64(len(fi.packets)) } // New returns a new FastIntegerHashMap with a bucket size specified // by hint. 
func New(hint uint64) *FastIntegerHashMap { if hint == 0 { hint = 16 } hint = roundUp(hint) return &FastIntegerHashMap{ count: 0, packets: make(packets, hint), } } ================================================ FILE: hashmap/fastinteger/hashmap_test.go ================================================ package fastinteger import ( "math/rand" "testing" "time" "github.com/stretchr/testify/assert" ) func generateKeys(num int) []uint64 { r := rand.New(rand.NewSource(time.Now().UnixNano())) keys := make([]uint64, 0, num) for i := 0; i < num; i++ { key := uint64(r.Int63()) keys = append(keys, key) } return keys } func TestRoundUp(t *testing.T) { result := roundUp(21) assert.Equal(t, uint64(32), result) result = roundUp(uint64(1<<31) - 234) assert.Equal(t, uint64(1<<31), result) result = roundUp(uint64(1<<63) - 324) assert.Equal(t, uint64(1<<63), result) } func TestInsert(t *testing.T) { hm := New(10) hm.Set(5, 5) assert.True(t, hm.Exists(5)) value, ok := hm.Get(5) assert.Equal(t, uint64(5), value) assert.True(t, ok) assert.Equal(t, uint64(16), hm.Cap()) } func TestInsertOverwrite(t *testing.T) { hm := New(10) hm.Set(5, 5) hm.Set(5, 10) assert.True(t, hm.Exists(5)) value, ok := hm.Get(5) assert.Equal(t, uint64(10), value) assert.True(t, ok) } func TestGet(t *testing.T) { hm := New(10) value, ok := hm.Get(5) assert.False(t, ok) assert.Equal(t, uint64(0), value) } func TestMultipleInserts(t *testing.T) { hm := New(10) hm.Set(5, 5) hm.Set(6, 6) assert.True(t, hm.Exists(6)) value, ok := hm.Get(6) assert.True(t, ok) assert.Equal(t, uint64(6), value) } func TestRebuild(t *testing.T) { numItems := uint64(100) hm := New(10) for i := uint64(0); i < numItems; i++ { hm.Set(i, i) } for i := uint64(0); i < numItems; i++ { value, _ := hm.Get(i) assert.Equal(t, i, value) } } func TestDelete(t *testing.T) { hm := New(10) hm.Set(5, 5) hm.Set(6, 6) hm.Delete(5) assert.Equal(t, uint64(1), hm.Len()) assert.False(t, hm.Exists(5)) hm.Delete(6) assert.Equal(t, uint64(0), hm.Len()) 
assert.False(t, hm.Exists(6)) } func TestDeleteAll(t *testing.T) { numItems := uint64(100) hm := New(10) for i := uint64(0); i < numItems; i++ { hm.Set(i, i) } for i := uint64(0); i < numItems; i++ { hm.Delete(i) assert.False(t, hm.Exists(i)) } } func TestDeleteCollision(t *testing.T) { // 1, 27, 42 all hash to the same value using our hash function % 32 if hash(1)%32 != 12 || hash(27)%32 != 12 || hash(42)%32 != 12 { t.Error("test values don't hash to the same value") } m := New(32) m.Set(1, 1) m.Set(27, 27) m.Set(42, 42) m.Delete(27) value, ok := m.Get(42) assert.True(t, ok) assert.Equal(t, uint64(42), value) } func BenchmarkInsert(b *testing.B) { numItems := uint64(1000) keys := generateKeys(int(numItems)) b.ResetTimer() for i := 0; i < b.N; i++ { hm := New(numItems * 2) // so we don't rebuild for _, k := range keys { hm.Set(k, k) } } } func BenchmarkGoMapInsert(b *testing.B) { numItems := uint64(1000) keys := generateKeys(int(numItems)) b.ResetTimer() for i := 0; i < b.N; i++ { hm := make(map[uint64]uint64, numItems*2) // so we don't rebuild for _, k := range keys { hm[k] = k } } } func BenchmarkExists(b *testing.B) { numItems := uint64(1000) keys := generateKeys(int(numItems)) hm := New(numItems * 2) // so we don't rebuild for _, key := range keys { hm.Set(key, key) } b.ResetTimer() for i := 0; i < b.N; i++ { for _, key := range keys { hm.Exists(key) } } } func BenchmarkGoMapExists(b *testing.B) { numItems := uint64(1000) keys := generateKeys(int(numItems)) hm := make(map[uint64]uint64, numItems*2) // so we don't rebuild for _, key := range keys { hm[key] = key } b.ResetTimer() var ok bool for i := 0; i < b.N; i++ { for _, key := range keys { _, ok = hm[key] // or the compiler complains } } b.StopTimer() if ok { // or the compiler complains } } func BenchmarkDelete(b *testing.B) { numItems := uint64(1000) hms := make([]*FastIntegerHashMap, 0, b.N) for i := 0; i < b.N; i++ { hm := New(numItems * 2) for j := uint64(0); j < numItems; j++ { hm.Set(j, j) } hms = 
// BenchmarkGoDelete is the native-map baseline for BenchmarkDelete. It builds
// one fully populated map per iteration before the timer is reset, so only
// the deletions are timed.
func BenchmarkGoDelete(b *testing.B) {
	const numItems = uint64(1000)

	hms := make([]map[uint64]uint64, b.N)
	for i := range hms {
		hm := make(map[uint64]uint64, numItems*2)
		for j := uint64(0); j < numItems; j++ {
			hm[j] = j
		}
		hms[i] = hm
	}

	b.ResetTimer()

	for _, hm := range hms {
		for j := uint64(0); j < numItems; j++ {
			delete(hm, j)
		}
	}
}
// PersistentList is an immutable, persistent linked list. Operations that
// change the contents return a new list rather than modifying the receiver.
type PersistentList interface {
	// Head returns the head of the list. The bool will be false if the list is
	// empty.
	Head() (interface{}, bool)

	// Tail returns the tail of the list, i.e. the list without its head. The
	// bool will be false if the list is empty.
	Tail() (PersistentList, bool)

	// IsEmpty indicates if the list is empty.
	IsEmpty() bool

	// Length returns the number of items in the list.
	Length() uint

	// Add will add the item to the front of the list, returning the new list.
	Add(head interface{}) PersistentList

	// Insert will insert the item at the given position, returning the new
	// list or an error if the position is invalid.
	Insert(val interface{}, pos uint) (PersistentList, error)

	// Get returns the item at the given position. The bool will be false if
	// the position is invalid.
	Get(pos uint) (interface{}, bool)

	// Remove will remove the item at the given position, returning the new
	// list or an error if the position is invalid.
	Remove(pos uint) (PersistentList, error)

	// Find applies the predicate function to the list and returns the first
	// item which matches. The bool will be false if nothing matches.
	Find(func(interface{}) bool) (interface{}, bool)

	// FindIndex applies the predicate function to the list and returns the
	// index of the first item which matches or -1 if there is no match.
	FindIndex(func(interface{}) bool) int

	// Map applies the function to each entry in the list and returns the
	// resulting slice.
	Map(func(interface{}) interface{}) []interface{}
}
// Insert places val into the empty list. Position 0 is the only valid
// position, producing a single-element list; any other position yields
// ErrEmptyList.
func (e *emptyList) Insert(val interface{}, pos uint) (PersistentList, error) {
	if pos != 0 {
		// Nothing exists to insert after.
		return nil, ErrEmptyList
	}
	return e.Add(val), nil
}
// Length returns the number of items in the list. It walks the chain of
// tails iteratively, counting one node per step until an empty tail is
// reached.
func (l *list) Length() uint {
	// A *list always holds at least its own head.
	count := uint(1)
	for node := l; ; count++ {
		rest, _ := node.Tail()
		if rest.IsEmpty() {
			break
		}
		// Every non-empty tail produced by this package is a *list.
		node = rest.(*list)
	}
	return count
}
func (l *list) FindIndex(pred func(interface{}) bool) int { curr := l idx := 0 for { if pred(curr.head) { return idx } tail, _ := curr.Tail() if tail.IsEmpty() { return -1 } curr = tail.(*list) idx += 1 } } // Map applies the function to each entry in the list and returns the resulting // slice. func (l *list) Map(f func(interface{}) interface{}) []interface{} { return append(l.tail.Map(f), f(l.head)) } ================================================ FILE: list/persistent_test.go ================================================ /* Copyright 2015 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package list import ( "testing" "github.com/stretchr/testify/assert" ) func TestEmptyList(t *testing.T) { assert := assert.New(t) head, ok := Empty.Head() assert.Nil(head) assert.False(ok) tail, ok := Empty.Tail() assert.Nil(tail) assert.False(ok) assert.True(Empty.IsEmpty()) } func TestAdd(t *testing.T) { assert := assert.New(t) l1 := Empty.Add(1) // l1: [1] assert.False(l1.IsEmpty()) head, ok := l1.Head() assert.True(ok) assert.Equal(1, head) tail, ok := l1.Tail() assert.True(ok) assert.Equal(Empty, tail) l1 = l1.Add(2) // l1: [2, 1] head, ok = l1.Head() assert.True(ok) assert.Equal(2, head) tail, ok = l1.Tail() assert.True(ok) head, ok = tail.Head() assert.True(ok) assert.Equal(1, head) l2, err := l1.Insert("a", 1) assert.Nil(err) // l1: [2, 1] // l2: [2, "a", 1] head, ok = l1.Head() assert.True(ok) assert.Equal(2, head) tail, ok = l1.Tail() assert.True(ok) head, ok = tail.Head() assert.True(ok) assert.Equal(1, head) head, ok = l2.Head() assert.True(ok) assert.Equal(2, head) tail, ok = l2.Tail() assert.True(ok) head, ok = tail.Head() assert.True(ok) assert.Equal("a", head) tail, ok = tail.Tail() assert.True(ok) head, ok = tail.Head() assert.True(ok) assert.Equal(1, head) } func TestInsertAndGet(t *testing.T) { assert := assert.New(t) _, err := Empty.Insert(1, 5) assert.Error(err) l, err := Empty.Insert(1, 0) assert.Nil(err) // [1] item, ok := l.Get(0) assert.True(ok) assert.Equal(1, item) l, err = l.Insert(2, 0) assert.Nil(err) // [2, 1] item, ok = l.Get(0) assert.True(ok) assert.Equal(2, item) item, ok = l.Get(1) assert.True(ok) assert.Equal(1, item) _, ok = l.Get(2) assert.False(ok) l, err = l.Insert("a", 3) assert.Nil(l) assert.Error(err) } func TestRemove(t *testing.T) { assert := assert.New(t) l, err := Empty.Remove(0) assert.Nil(l) assert.Error(err) l = Empty.Add(1) l = l.Add(2) l = l.Add(3) // [3, 2, 1] l1, err := l.Remove(3) assert.Nil(l1) assert.Error(err) l2, err := l.Remove(0) // l: [3, 2, 1] // l2: [2, 1] assert.Nil(err) head, ok := l.Head() 
assert.True(ok) assert.Equal(3, head) tail, ok := l.Tail() assert.True(ok) head, ok = tail.Head() assert.True(ok) assert.Equal(2, head) tail, ok = tail.Tail() assert.True(ok) head, ok = tail.Head() assert.True(ok) assert.Equal(1, head) assert.Nil(err) head, ok = l2.Head() assert.True(ok) assert.Equal(2, head) tail, ok = l2.Tail() assert.True(ok) head, ok = tail.Head() assert.True(ok) assert.Equal(1, head) l2, err = l.Remove(1) // l: [3, 2, 1] // l2: [3, 1] assert.Nil(err) head, ok = l.Head() assert.True(ok) assert.Equal(3, head) tail, ok = l.Tail() assert.True(ok) head, ok = tail.Head() assert.True(ok) assert.Equal(2, head) tail, ok = tail.Tail() assert.True(ok) head, ok = tail.Head() assert.True(ok) assert.Equal(1, head) assert.Nil(err) head, ok = l2.Head() assert.True(ok) assert.Equal(3, head) tail, ok = l2.Tail() assert.True(ok) head, ok = tail.Head() assert.True(ok) assert.Equal(1, head) l2, err = l.Remove(2) // l: [3, 2, 1] // l2: [3, 2] assert.Nil(err) head, ok = l.Head() assert.True(ok) assert.Equal(3, head) tail, ok = l.Tail() assert.True(ok) head, ok = tail.Head() assert.True(ok) assert.Equal(2, head) tail, ok = tail.Tail() assert.True(ok) head, ok = tail.Head() assert.True(ok) assert.Equal(1, head) assert.Nil(err) head, ok = l2.Head() assert.True(ok) assert.Equal(3, head) tail, ok = l2.Tail() assert.True(ok) head, ok = tail.Head() assert.True(ok) assert.Equal(2, head) } func TestFind(t *testing.T) { assert := assert.New(t) pred := func(item interface{}) bool { return item == 1 } found, ok := Empty.Find(pred) assert.Nil(found) assert.False(ok) l := Empty.Add("blah").Add("bleh") found, ok = l.Find(pred) assert.Nil(found) assert.False(ok) l = l.Add(1).Add("foo") found, ok = l.Find(pred) assert.Equal(1, found) assert.True(ok) } func TestFindIndex(t *testing.T) { assert := assert.New(t) pred := func(item interface{}) bool { return item == 1 } idx := Empty.FindIndex(pred) assert.Equal(-1, idx) l := Empty.Add("blah").Add("bleh") idx = l.FindIndex(pred) 
assert.Equal(-1, idx) l = l.Add(1).Add("foo") idx = l.FindIndex(pred) assert.Equal(1, idx) } func TestLength(t *testing.T) { assert := assert.New(t) assert.Equal(uint(0), Empty.Length()) l := Empty.Add("foo") assert.Equal(uint(1), l.Length()) l = l.Add("bar").Add("baz") assert.Equal(uint(3), l.Length()) } func TestMap(t *testing.T) { assert := assert.New(t) f := func(x interface{}) interface{} { return x.(int) * x.(int) } assert.Nil(Empty.Map(f)) l := Empty.Add(1).Add(2).Add(3).Add(4) assert.Equal([]interface{}{1, 4, 9, 16}, l.Map(f)) } ================================================ FILE: mock/batcher.go ================================================ /* Copyright 2015 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package mock import ( "github.com/stretchr/testify/mock" "github.com/Workiva/go-datastructures/batcher" ) var _ batcher.Batcher = new(Batcher) type Batcher struct { mock.Mock PutChan chan bool } func (m *Batcher) Put(items interface{}) error { args := m.Called(items) if m.PutChan != nil { m.PutChan <- true } return args.Error(0) } func (m *Batcher) Get() ([]interface{}, error) { args := m.Called() return args.Get(0).([]interface{}), args.Error(1) } func (m *Batcher) Flush() error { args := m.Called() return args.Error(0) } func (m *Batcher) Dispose() { m.Called() } func (m *Batcher) IsDisposed() bool { args := m.Called() return args.Bool(0) } ================================================ FILE: mock/rangetree.go ================================================ package mock import ( "github.com/stretchr/testify/mock" "github.com/Workiva/go-datastructures/rangetree" ) type RangeTree struct { mock.Mock } var _ rangetree.RangeTree = new(RangeTree) func (m *RangeTree) Add(entries ...rangetree.Entry) rangetree.Entries { args := m.Called(entries) ifc := args.Get(0) if ifc == nil { return nil } return ifc.(rangetree.Entries) } func (m *RangeTree) Len() uint64 { return m.Called().Get(0).(uint64) } func (m *RangeTree) Delete(entries ...rangetree.Entry) rangetree.Entries { return m.Called(entries).Get(0).(rangetree.Entries) } func (m *RangeTree) Query(interval rangetree.Interval) rangetree.Entries { args := m.Called(interval) ifc := args.Get(0) if ifc == nil { return nil } return ifc.(rangetree.Entries) } func (m *RangeTree) InsertAtDimension(dimension uint64, index, number int64) (rangetree.Entries, rangetree.Entries) { args := m.Called(dimension, index, number) return args.Get(0).(rangetree.Entries), args.Get(1).(rangetree.Entries) } func (m *RangeTree) Apply(interval rangetree.Interval, fn func(rangetree.Entry) bool) { m.Called(interval, fn) } func (m *RangeTree) Get(entries ...rangetree.Entry) rangetree.Entries { ifc := m.Called(entries).Get(0) if ifc == nil { return 
// rotate rotates/flips the quadrant-local coordinates pointed to by x and y
// in place. Nothing happens unless ry is 0; in that case the point is first
// mirrored within a square of side n when rx is 1, and x and y are then
// exchanged.
func rotate(n, rx, ry int32, x, y *int32) {
	if ry != 0 {
		return
	}

	if rx == 1 {
		limit := n - 1
		*x = limit - *x
		*y = limit - *y
	}

	// transpose the point
	*x, *y = *y, *x
}
func Encode(x, y int32) int64 { var rx, ry int32 var d int64 for s := int32(n / 2); s > 0; s /= 2 { rx = boolToInt(x&s > 0) ry = boolToInt(y&s > 0) d += int64(int64(s) * int64(s) * int64(((3 * rx) ^ ry))) rotate(s, rx, ry, &x, &y) } return d } // Decode will decode the provided Hilbert distance into a corresponding // x and y value, respectively. func Decode(h int64) (int32, int32) { var ry, rx int64 var x, y int32 t := h for s := int64(1); s < int64(n); s *= 2 { rx = 1 & (t / 2) ry = 1 & (t ^ rx) rotate(int32(s), int32(rx), int32(ry), &x, &y) x += int32(s * rx) y += int32(s * ry) t /= 4 } return x, y } ================================================ FILE: numerics/hilbert/hilbert_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
// vertexProbabilityBundle pairs a candidate guess vertex with the
// probability most recently computed for it.
type vertexProbabilityBundle struct {
	// probability is the variable variance probability assigned by
	// pbs.calculateProbabilities; it is the zero value until that has run.
	probability float64
	// vertex is the candidate guess this probability belongs to.
	vertex *nmVertex
}
// // VVP is defined as: // 1/((2*pi)^(1/2)*sigma)*(1-e^(-dmin^2/2*sigma^2)) // where dmin = euclidean distance between this vertex and the best guess // and sigma = (3*(m^(1/n)))^-1 // func calculateVVP(guess, vertex *nmVertex, sigma float64) float64 { distance := -guess.euclideanDistance(vertex) lhs := 1 / (math.Sqrt(2*math.Pi) * sigma) rhs := 1 - math.Exp(math.Pow(distance, 2)/(2*math.Pow(sigma, 2))) return rhs * lhs } // calculateSigma will calculate sigma based on the provided information. // Typically, sigma will decrease as the number of sampled points // increases. // // sigma = (3*(m^(1/n)))^-1 // func calculateSigma(dimensions, guesses int) float64 { return math.Pow(3*math.Pow(float64(guesses), 1/float64(dimensions)), -1) } func (pbs pbs) calculateProbabilities(bestGuess *nmVertex, sigma float64) { for _, v := range pbs { v.probability = calculateVVP(bestGuess, v.vertex, sigma) } } func (pbs pbs) sort() { sort.Sort(pbs) } func (pbs pbs) Less(i, j int) bool { return pbs[i].probability < pbs[j].probability } func (pbs pbs) Swap(i, j int) { pbs[i], pbs[j] = pbs[j], pbs[i] } func (pbs pbs) Len() int { return len(pbs) } // results stores the results of previous iterations of the // nelder-mead algorithm type results struct { // vertices are the results generated by the algorithm vertices vertices // config is useful for examining target config NelderMeadConfiguration // pbs contains the randomly generated guess vertices pbs pbs } // search will search this list of results based on order, order // being defined in the NelderMeadConfiguration, that is a defined // target will be treated func (results *results) search(result *nmVertex) int { return sort.Search(len(results.vertices), func(i int) bool { return !results.vertices[i].less(results.config, result) }) } func (results *results) exists(result *nmVertex, hint int) bool { if hint < 0 { hint = results.search(result) } // maximum hint here should be len(results.vertices) if hint > 0 && 
results.vertices[hint-1].approximatelyEqualToVertex(result) { return true } // -1 here because if hint == len(vertices) we would've already // checked the last value in the previous conditional if hint < len(results.vertices)-1 && results.vertices[hint].approximatelyEqualToVertex(result) { return true } return false } func (results *results) insert(vertex *nmVertex) { i := results.search(vertex) if results.exists(vertex, i) { return } if i == len(results.vertices) { results.vertices = append(results.vertices, vertex) return } results.vertices = append(results.vertices, nil) copy(results.vertices[i+1:], results.vertices[i:]) results.vertices[i] = vertex } func (results *results) grab(num int) vertices { vs := make(vertices, 0, num) // first, copy what you want to the list to return // not returning a sub-slice as we're about to mutate // the original slice for i := 0; i < num; i++ { vs = append(vs, results.pbs[i].vertex) } // now we overwrite the vertices that we are taking // from the beginning copy(results.pbs, results.pbs[num:]) length := len(results.pbs) - num // this next part is required for the GC for i := length; i < len(results.pbs); i++ { results.pbs[i] = nil } // and finally set the new slice as a subslice results.pbs = results.pbs[:length] return vs } // reSort will re-sort the list of possible guess vertices // based upon the latest calculated result. It was also // add this result to the list of results. func (results *results) reSort(vertex *nmVertex) { results.insert(vertex) bestGuess := results.vertices[0] sigma := calculateSigma(len(results.config.Vars), len(results.vertices)) results.pbs.calculateProbabilities(bestGuess, sigma) results.pbs.sort() } func newResults(guess *nmVertex, config NelderMeadConfiguration, num int) *results { vertices := make(vertices, 0, num+1) vertices = append(vertices, guess) vertices = append(vertices, generateRandomVerticesFromGuess(guess, num)...) 
bundles := make(pbs, 0, len(vertices)) for _, v := range vertices { bundles = append(bundles, &vertexProbabilityBundle{vertex: v}) } return &results{ pbs: bundles, config: config, } } ================================================ FILE: numerics/optimization/nelder_mead.go ================================================ package optimization import ( "fmt" "math" "math/rand" "sort" "time" ) const ( alpha = 1 // reflection, must be > 0 beta = 2 // expansion, must be > 1 gamma = .5 // contraction, 0 < gamma < 1 sigma = .5 // shrink, 0 < sigma < 1 delta = .0001 // going to use this to determine convergence maxRuns = 130 maxIterations = 5 // maxIterations defines the number of restarts that should // occur when attempting to find a global critical point ) var ( min = math.Inf(-1) max = math.Inf(1) ) // generateRandomVerticesFromGuess will generate num number of vertices // with random func generateRandomVerticesFromGuess(guess *nmVertex, num int) vertices { // summed allows us to prevent duplicate guesses, checking // all previous guesses for every guess created would be too // time consuming so we take an indexed shortcut here. summed // is a map of a sum of the vars to the vertices that have an // identical sum. In this way, we can sum the vars of a new guess // and check only a small subset of previous guesses to determine // if this is an identical guess. 
summed := make(map[float64]vertices, num) dimensions := len(guess.vars) vs := make(vertices, 0, num) i := 0 r := rand.New(rand.NewSource(time.Now().UnixNano())) Guess: for i < num { sum := float64(0) vars := make([]float64, 0, dimensions) for j := 0; j < dimensions; j++ { v := r.Float64() * 1000 // we do a separate random check here to determine // sign so we don't end up with all high v's one sign // and low v's another if r.Float64() > .5 { v = -v } sum += v vars = append(vars, v) } guess := &nmVertex{ vars: vars, } if vs, ok := summed[sum]; !ok { vs = make(vertices, 0, dimensions) // dimensions is really just a guess, no real way of knowing what this is vs = append(vs, guess) summed[sum] = vs } else { for _, vertex := range vs { // if we've already guessed this, try the loop again if guess.equalToVertex(vertex) { continue Guess } } vs = append(vs, guess) } vs = append(vs, guess) i++ } return vs } func isInf(num float64) bool { return math.IsInf(num, -1) || math.IsInf(num, 1) } func findMin(vertices ...*nmVertex) *nmVertex { min := vertices[0] for _, v := range vertices[1:] { if v.distance < min.distance { min = v } } return min } // findMidpoint will find the midpoint of the provided vertices // and return a new vertex. func findMidpoint(vertices ...*nmVertex) *nmVertex { num := len(vertices) // this is what we divide by vars := make([]float64, 0, num) for i := 0; i < num; i++ { sum := float64(0) for _, v := range vertices { sum += v.vars[i] } vars = append(vars, sum/float64(num)) } return &nmVertex{ vars: vars, } } // determineDistance will determine the distance between the value // and the target. If the target is positive or negative infinity, // (ie find max or min), this is clamped to max or min float64. 
func determineDistance(value, target float64) float64 {
	// An infinite target (find max / find min) is replaced by the largest
	// finite float64 of the same sign before measuring the distance.
	switch {
	case math.IsInf(target, 1): // positive infinity
		target = math.MaxFloat64
	case math.IsInf(target, -1): // negative infinity
		target = -math.MaxFloat64
	}

	return math.Abs(target - value)
}
// nmVertex is a single point considered by the Nelder-Mead search together
// with the outcome of evaluating the configured function at that point.
type nmVertex struct {
	// vars indicates the values used to calculate this vertex.
	vars []float64
	// distance is the distance between this vertex's result and the
	// desired target value. This metric has little meaning if the
	// desired value is +- inf.
	//
	// result is the calculated result of this vertex, i.e. the value
	// returned by the configured function for vars. This can be used
	// to measure distance or as a metric to compare two vertices if
	// the desired result is a min/max.
	distance, result float64
	// good indicates if the calculated values here
	// are within all constraints, as reported by the configured
	// function. This is intended to always be true if this vertex
	// is in a list of vertices.
	good bool
}
func (nm *nmVertex) less(config NelderMeadConfiguration, other *nmVertex) bool { if config.Target == min { // looking for a min return nm.result < other.result } if config.Target == max { // looking for a max return nm.result > other.result } return nm.distance < other.distance } func (nm *nmVertex) equal(config NelderMeadConfiguration, other *nmVertex) bool { if isInf(config.Target) { // if we are looking for a min or max, we compare result return nm.result == other.result } // otherwise, we compare distances return nm.distance == other.distance } // euclideanDistance determines the euclidean distance between two points. func (nm *nmVertex) euclideanDistance(other *nmVertex) float64 { sum := float64(0) // first we want to sum all the distances between the points for i, otherPoint := range other.vars { // distance between points is defined by (qi-ri)^2 sum += math.Pow(otherPoint-nm.vars[i], 2) } return math.Sqrt(sum) } // equalToVertex will compare this vertex to the provided vertex // to determine if the two vertices are actually identical (that is, // they fall on the same point). func (nm *nmVertex) equalToVertex(other *nmVertex) bool { for i, n := range nm.vars { if n != other.vars[i] { return false } } return true } // approximatelyEqualToVertex returns a bool indicating if the // *result* of this vertex is approximately equal to the vertex // provided. Approximately is 2 * delta as the algorithm may // cease within a delta distance of the true value, so we may // end up with a result that's 2*delta away if we came from // the other direction. func (nm *nmVertex) approximatelyEqualToVertex(other *nmVertex) bool { return math.Abs(nm.result-other.result) < 2*delta } type nelderMead struct { config NelderMeadConfiguration results *results } // evaluateWithConstraints will safely evaluate the vertex while // conforming to any imposed restraints. 
If a constraint is found, // this method will backtrack the vertex as described here: // http://www.iccm-central.org/Proceedings/ICCM16proceedings/contents/pdf/MonK/MoKA1-04ge_ghiasimh224461p.pdf // This should work with even non-linear constraints, but it is up to // the consumer to check these constraints. func (nm *nelderMead) evaluateWithConstraints(vertices vertices, vertex *nmVertex) *nmVertex { vertex.evaluate(nm.config) return vertex if vertex.good { return vertex } best := vertices[0] for i := 0; i < 5; i++ { vertex = best.add((vertex.subtract(best).multiply(alpha))) if vertex.good { return vertex } } return best } // reflect will find the reflection point between the two best guesses // with the provided midpoint. func (nm *nelderMead) reflect(vertices vertices, midpoint *nmVertex) *nmVertex { toScalar := midpoint.subtract(nm.lastVertex(vertices)) toScalar = toScalar.multiply(alpha) toScalar = midpoint.add(toScalar) return nm.evaluateWithConstraints(vertices, toScalar) } func (nm *nelderMead) expand(vertices vertices, midpoint, reflection *nmVertex) *nmVertex { toScalar := reflection.subtract(midpoint) toScalar = toScalar.multiply(beta) toScalar = midpoint.add(toScalar) return nm.evaluateWithConstraints(vertices, toScalar) } // lastDimensionVertex returns the vertex that is represented by the // last dimension, effectively, second to last in the list of // vertices. func (nm *nelderMead) lastDimensionVertex(vertices vertices) *nmVertex { return vertices[len(vertices)-2] } // lastVertex returns the last vertex in the list of vertices. // It's important to remember that this vertex represents the // number of dimensions + 1. 
func (nm *nelderMead) lastVertex(vertices vertices) *nmVertex {
	return vertices[len(vertices)-1]
}

// outsideContract builds a candidate on the reflection side of the
// midpoint: the difference between the reflection and the midpoint is
// scaled by gamma and added to the midpoint. The candidate is evaluated
// before being returned.
func (nm *nelderMead) outsideContract(vertices vertices, midpoint, reflection *nmVertex) *nmVertex {
	toScalar := reflection.subtract(midpoint)
	toScalar = toScalar.multiply(gamma)
	toScalar = midpoint.add(toScalar)
	return nm.evaluateWithConstraints(vertices, toScalar)
}

// insideContract builds a candidate on the opposite side of the midpoint
// from the reflection: the difference between the reflection and the
// midpoint is scaled by gamma and subtracted from the midpoint. The
// candidate is evaluated before being returned.
func (nm *nelderMead) insideContract(vertices vertices, midpoint, reflection *nmVertex) *nmVertex {
	toScalar := reflection.subtract(midpoint)
	toScalar = toScalar.multiply(gamma)
	toScalar = midpoint.subtract(toScalar)
	return nm.evaluateWithConstraints(vertices, toScalar)
}

// shrink pulls every vertex except the first toward the first one. Each
// vertex is replaced, in place, by the first vertex plus sigma times its
// offset from the first vertex. The replacements are not evaluated here.
func (nm *nelderMead) shrink(vertices vertices) {
	one := vertices[0]
	for i := 1; i < len(vertices); i++ {
		toScalar := vertices[i].subtract(one)
		toScalar = toScalar.multiply(sigma)
		vertices[i] = one.add(toScalar)
	}
}

// checkIteration checks some key values to determine if
// iteration should be complete.  Returns false if iteration
// should be terminated and true if iteration should continue.
func (nm *nelderMead) checkIteration(vertices vertices) bool {
	// this will never be true for +/- inf
	if math.Abs(vertices[0].result-nm.config.Target) < delta {
		return false
	}

	best := vertices[0]
	// here we are checking distance convergence.  If any vertex's
	// distance still differs from the best vertex's distance by delta
	// or more we keep iterating; otherwise we fall through to the
	// positional check below.  This can only be performed on
	// convergence checks, not for finding min/max.
	if !isInf(nm.config.Target) {
		for _, v := range vertices[1:] {
			if math.Abs(best.distance-v.distance) >= delta {
				return true
			}
		}
	}

	// next we want to check to see if the changes in our polytopes
	// dip below some threshold.  That is, we want to look at the
	// euclidean distances between the best guess and all the other
	// guesses to see if they are converged upon some point.  If
	// all of the vertices have converged close enough, it may be
	// worth it to cease iteration.
	for _, v := range vertices[1:] {
		if best.euclideanDistance(v) >= delta {
			return true
		}
	}

	return false
}

// evaluate runs the optimization. It repeatedly grabs a simplex of
// len(config.Vars)+1 vertices from the results, refines it with
// reflection, expansion, contraction and shrink steps for at most maxRuns
// rounds, then re-sorts the results around the best vertex found and
// grabs a fresh simplex, up to maxIterations times. If the initial guess
// does not evaluate as good, that guess is inserted into the results and
// the method returns without iterating.
func (nm *nelderMead) evaluate() {
	vertices := nm.results.grab(len(nm.config.Vars) + 1)
	// if the initial guess provided is not good, then
	// we are going to die early, leave it up to the user
	// to create a good first guess.
	vertices[0].evaluate(nm.config)
	if !vertices[0].good {
		nm.results.insert(vertices[0])
		return
	}

	// the outer loop controls how hard we try to find
	// a global critical point
	for i := 0; i < maxIterations; i++ {
		// the inner loop controls the degenerate case where
		// we can't converge to a critical point
		for j := 0; j < maxRuns; j++ {
			// TODO: optimize this to prevent duplicate evaluations.
			vertices.evaluate(nm.config)
			best := vertices[0]
			if !nm.checkIteration(vertices) {
				break
			}

			// the midpoint is taken over every vertex except the last
			midpoint := findMidpoint(vertices[:len(vertices)-1]...)
			// we are guaranteed to have two points here
			reflection := nm.reflect(vertices, midpoint)
			// we could not find a reflection that met constraints, the
			// best guess is the best guess.
			if reflection == best {
				break
			}

			// in this case, quality-wise, we are between the best
			// and second to best points
			if reflection.less(nm.config, nm.lastDimensionVertex(vertices)) &&
				!vertices[0].less(nm.config, reflection) {

				vertices[len(vertices)-1] = reflection
			}

			// midpoint is closer than our previous best guess
			if reflection.less(nm.config, vertices[0]) {
				expanded := nm.expand(vertices, midpoint, reflection)
				// we could not expand a valid guess, best is the best guess
				if expanded == best {
					break
				}

				// we only need to expand here
				if expanded.less(nm.config, reflection) {
					vertices[len(vertices)-1] = expanded
				} else {
					vertices[len(vertices)-1] = reflection
				}
				continue
			}

			// reflection is a bad guess, let's try to contract both
			// inside and outside and see if we can find a better value
			if reflection.less(nm.config, nm.lastVertex(vertices)) {
				oc := nm.outsideContract(vertices, midpoint, reflection)
				if oc == best {
					break
				}

				if oc.less(nm.config, reflection) || oc.equal(nm.config, reflection) {
					vertices[len(vertices)-1] = oc
					continue
				}
			} else if !reflection.less(nm.config, nm.lastVertex(vertices)) {
				ic := nm.insideContract(vertices, midpoint, reflection)
				if ic == best {
					break
				}

				if ic.less(nm.config, nm.lastVertex(vertices)) {
					vertices[len(vertices)-1] = ic
					continue
				}
			}

			// we could not guess a better value than vertices[0], so
			// let's converge the other guesses to our best guess.
			nm.shrink(vertices)
		}
		// record the best vertex of this pass and start the next pass
		// from a freshly grabbed simplex
		nm.results.reSort(vertices[0])
		vertices = nm.results.grab(len(nm.config.Vars) + 1)
	}
}

// newNelderMead builds the optimizer state for config, seeding the
// results with a vertex made from the caller's initial guess.
func newNelderMead(config NelderMeadConfiguration) *nelderMead {
	v := &nmVertex{vars: config.Vars} // construct initial vertex with first guess
	results := newResults(v, config, 1000) // 1000 represents 1000 initial vertex guesses
	return &nelderMead{
		config:  config,
		results: results,
	}
}

// NelderMead takes a configuration and returns a list
// of floats that can be plugged into the provided function
// to converge at the target value.
func NelderMead(config NelderMeadConfiguration) []float64 { nm := newNelderMead(config) nm.evaluate() return nm.results.vertices[0].vars } ================================================ FILE: numerics/optimization/nelder_mead_test.go ================================================ package optimization import ( "math" "testing" "github.com/stretchr/testify/assert" ) func TestNelderMead(t *testing.T) { fn := func(vars []float64) (float64, bool) { return vars[0] * vars[1], true } config := NelderMeadConfiguration{ Target: float64(9), Fn: fn, Vars: []float64{2, 4}, } result, _ := fn(NelderMead(config)) assert.True(t, math.Abs(result-config.Target) <= .01) } func TestNelderMeadPolynomial(t *testing.T) { fn := func(vars []float64) (float64, bool) { // x^2-4x+y^2-y-xy, solution is (3, 2) return math.Pow(vars[0], 2) - 4*vars[0] + math.Pow(vars[1], 2) - vars[1] - vars[0]*vars[1], true } config := NelderMeadConfiguration{ Target: float64(-100), Fn: fn, Vars: []float64{-10, 10}, } result := NelderMead(config) calced, _ := fn(result) assert.True(t, math.Abs(7-math.Abs(calced)) <= .01) assert.True(t, math.Abs(3-result[0]) <= .1) assert.True(t, math.Abs(2-result[1]) <= .1) } func TestNelderMeadPolynomialMin(t *testing.T) { fn := func(vars []float64) (float64, bool) { // x^2-4x+y^2-y-xy, solution is (3, 2) return math.Pow(vars[0], 2) - 4*vars[0] + math.Pow(vars[1], 2) - vars[1] - vars[0]*vars[1], true } config := NelderMeadConfiguration{ Target: math.Inf(-1), Fn: fn, Vars: []float64{-10, 10}, } result := NelderMead(config) calced, _ := fn(result) assert.True(t, math.Abs(7-math.Abs(calced)) <= .01) assert.True(t, math.Abs(3-result[0]) <= .01) assert.True(t, math.Abs(2-result[1]) <= .01) } func TestNelderMeadPolynomialMax(t *testing.T) { fn := func(vars []float64) (float64, bool) { // 3+sin(x)+2cos(y)^2, the min on this equation is 2 and the max is 6 return 3 + math.Sin(vars[0]) + 2*math.Pow(math.Cos(vars[1]), 2), true } config := NelderMeadConfiguration{ Target: math.Inf(1), 
Fn: fn, Vars: []float64{-5, 5}, } result := NelderMead(config) calced, _ := fn(result) assert.True(t, math.Abs(6-math.Abs(calced)) <= .01) } func TestNelderMeadConstrained(t *testing.T) { fn := func(vars []float64) (float64, bool) { if vars[0] < 1 || vars[1] < 1 { return 0, false } return math.Pow(vars[0], 2) - 4*vars[0] + math.Pow(vars[1], 2) - vars[1] - vars[0]*vars[1], true } // by default, converging on this point with the initial // guess of (6, 3) will converge to (~.46, ~4.75). The // fn has the added constraint that no guesses may be below // 1. This should now converge to a point (~8.28, ~4.93). config := NelderMeadConfiguration{ Target: float64(14), Fn: fn, Vars: []float64{6, 3}, } result := NelderMead(config) calced, _ := fn(result) assert.True(t, math.Abs(14-math.Abs(calced)) <= .01) assert.True(t, result[0] >= 1) assert.True(t, result[1] >= 1) fn = func(vars []float64) (float64, bool) { if vars[0] < 6 || vars[0] > 8 { return 0, false } if vars[1] < 0 || vars[1] > 2 { return 0, false } return math.Pow(vars[0], 2) - 4*vars[0] + math.Pow(vars[1], 2) - vars[1] - vars[0]*vars[1], true } config = NelderMeadConfiguration{ Target: float64(14), Fn: fn, Vars: []float64{6, .5}, } result = NelderMead(config) calced, _ = fn(result) // there are two local min here assert.True(t, math.Abs(14-math.Abs(calced)) <= .01 || math.Abs(8.75-math.Abs(calced)) <= .01) assert.True(t, result[0] >= 6 && result[0] <= 8) assert.True(t, result[1] >= 0 && result[1] <= 2) } func TestNelderMeadConstrainedBadGuess(t *testing.T) { fn := func(vars []float64) (float64, bool) { if vars[0] < 1 || vars[1] < 1 { return 0, false } return math.Pow(vars[0], 2) - 4*vars[0] + math.Pow(vars[1], 2) - vars[1] - vars[0]*vars[1], true } // this is a bad guess, as in the initial guess doesn't // match the constraints. In that case, we return the guessed // values. 
config := NelderMeadConfiguration{ Target: float64(14), Fn: fn, Vars: []float64{0, 3}, } result := NelderMead(config) assert.Equal(t, float64(0), result[0]) assert.Equal(t, float64(3), result[1]) } // Commenting this function out for now as it's entirely // probabilistic. Realistically, we can only say that we'll // find the local vs global min/max some percentage of the time // and that percentage depends entirely on the function. // This is here for debugging purposes. /* func TestNelderMeadFindGlobal(t *testing.T) { fn := func(vars []float64) (float64, bool) { if vars[0] < -4 || vars[0] > 2 { return 0, false } // x3 + 3x2 − 2x + 1 over [-4, 2] has a global maximum at x = 2 return math.Pow(vars[0], 3) + 3*math.Pow(vars[0], 2) - 2*vars[0] + 1, true } config := NelderMeadConfiguration{ Target: math.Inf(1), Fn: fn, Vars: []float64{1.5}, } result := NelderMead(config) calced, _ := fn(result) wc, _ := fn([]float64{2}) t.Logf(`RESULT: %+v, CALCED: %+v, WC: %+v`, result, calced, wc) t.Fail() }*/ ================================================ FILE: queue/error.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package queue import "errors" var ( // ErrDisposed is returned when an operation is performed on a disposed // queue. ErrDisposed = errors.New(`queue: disposed`) // ErrTimeout is returned when an applicable queue operation times out. 
ErrTimeout = errors.New(`queue: poll timed out`) // ErrEmptyQueue is returned when an non-applicable queue operation was called // due to the queue's empty item state ErrEmptyQueue = errors.New(`queue: empty queue`) ) ================================================ FILE: queue/mock_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package queue type mockItem int func (mi mockItem) Compare(other Item) int { omi := other.(mockItem) if mi > omi { return 1 } else if mi == omi { return 0 } return -1 } ================================================ FILE: queue/priority_queue.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* The priority queue is almost a spitting image of the logic used for a regular queue. In order to keep the logic fast, this code is repeated instead of using casts to cast to interface{} back and forth. 
If Go had inheritance and generics, this problem would be easier to solve. */ package queue import "sync" // Item is an item that can be added to the priority queue. type Item interface { // Compare returns a bool that can be used to determine // ordering in the priority queue. Assuming the queue // is in ascending order, this should return > logic. // Return 1 to indicate this object is greater than the // the other logic, 0 to indicate equality, and -1 to indicate // less than other. Compare(other Item) int } type priorityItems []Item func (items *priorityItems) swap(i, j int) { (*items)[i], (*items)[j] = (*items)[j], (*items)[i] } func (items *priorityItems) pop() Item { size := len(*items) // Move last leaf to root, and 'pop' the last item. items.swap(size-1, 0) item := (*items)[size-1] // Item to return. (*items)[size-1], *items = nil, (*items)[:size-1] // 'Bubble down' to restore heap property. index := 0 childL, childR := 2*index+1, 2*index+2 for len(*items) > childL { child := childL if len(*items) > childR && (*items)[childR].Compare((*items)[childL]) < 0 { child = childR } if (*items)[child].Compare((*items)[index]) < 0 { items.swap(index, child) index = child childL, childR = 2*index+1, 2*index+2 } else { break } } return item } func (items *priorityItems) get(number int) []Item { returnItems := make([]Item, 0, number) for i := 0; i < number; i++ { if len(*items) == 0 { break } returnItems = append(returnItems, items.pop()) } return returnItems } func (items *priorityItems) push(item Item) { // Stick the item as the end of the last level. *items = append(*items, item) // 'Bubble up' to restore heap property. index := len(*items) - 1 parent := int((index - 1) / 2) for parent >= 0 && (*items)[parent].Compare(item) > 0 { items.swap(index, parent) index = parent parent = int((index - 1) / 2) } } // PriorityQueue is similar to queue except that it takes // items that implement the Item interface and adds them // to the queue in priority order. 
type PriorityQueue struct { waiters waiters items priorityItems itemMap map[Item]struct{} lock sync.Mutex disposeLock sync.Mutex disposed bool allowDuplicates bool } // Put adds items to the queue. func (pq *PriorityQueue) Put(items ...Item) error { if len(items) == 0 { return nil } pq.lock.Lock() defer pq.lock.Unlock() if pq.disposed { return ErrDisposed } for _, item := range items { if pq.allowDuplicates { pq.items.push(item) } else if _, ok := pq.itemMap[item]; !ok { pq.itemMap[item] = struct{}{} pq.items.push(item) } } for { sema := pq.waiters.get() if sema == nil { break } sema.response.Add(1) sema.ready <- true sema.response.Wait() if len(pq.items) == 0 { break } } return nil } // Get retrieves items from the queue. If the queue is empty, // this call blocks until the next item is added to the queue. This // will attempt to retrieve number of items. func (pq *PriorityQueue) Get(number int) ([]Item, error) { if number < 1 { return nil, nil } pq.lock.Lock() if pq.disposed { pq.lock.Unlock() return nil, ErrDisposed } var items []Item // Remove references to popped items. deleteItems := func(items []Item) { for _, item := range items { delete(pq.itemMap, item) } } if len(pq.items) == 0 { sema := newSema() pq.waiters.put(sema) pq.lock.Unlock() <-sema.ready if pq.Disposed() { return nil, ErrDisposed } items = pq.items.get(number) if !pq.allowDuplicates { deleteItems(items) } sema.response.Done() return items, nil } items = pq.items.get(number) deleteItems(items) pq.lock.Unlock() return items, nil } // Peek will look at the next item without removing it from the queue. func (pq *PriorityQueue) Peek() Item { pq.lock.Lock() defer pq.lock.Unlock() if len(pq.items) > 0 { return pq.items[0] } return nil } // Empty returns a bool indicating if there are any items left // in the queue. func (pq *PriorityQueue) Empty() bool { pq.lock.Lock() defer pq.lock.Unlock() return len(pq.items) == 0 } // Len returns a number indicating how many items are in the queue. 
func (pq *PriorityQueue) Len() int { pq.lock.Lock() defer pq.lock.Unlock() return len(pq.items) } // Disposed returns a bool indicating if this queue has been disposed. func (pq *PriorityQueue) Disposed() bool { pq.disposeLock.Lock() defer pq.disposeLock.Unlock() return pq.disposed } // Dispose will prevent any further reads/writes to this queue // and frees available resources. func (pq *PriorityQueue) Dispose() { pq.lock.Lock() defer pq.lock.Unlock() pq.disposeLock.Lock() defer pq.disposeLock.Unlock() pq.disposed = true for _, waiter := range pq.waiters { waiter.response.Add(1) waiter.ready <- true } pq.items = nil pq.waiters = nil } // NewPriorityQueue is the constructor for a priority queue. func NewPriorityQueue(hint int, allowDuplicates bool) *PriorityQueue { return &PriorityQueue{ items: make(priorityItems, 0, hint), itemMap: make(map[Item]struct{}, hint), allowDuplicates: allowDuplicates, } } ================================================ FILE: queue/priority_queue_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package queue import ( "sync" "testing" "github.com/stretchr/testify/assert" ) func TestPriorityPut(t *testing.T) { q := NewPriorityQueue(1, false) q.Put(mockItem(2)) assert.Len(t, q.items, 1) assert.Equal(t, mockItem(2), q.items[0]) q.Put(mockItem(1)) if !assert.Len(t, q.items, 2) { return } assert.Equal(t, mockItem(1), q.items[0]) assert.Equal(t, mockItem(2), q.items[1]) } func TestPriorityGet(t *testing.T) { q := NewPriorityQueue(1, false) q.Put(mockItem(2)) result, err := q.Get(2) if !assert.Nil(t, err) { return } if !assert.Len(t, result, 1) { return } assert.Equal(t, mockItem(2), result[0]) assert.Len(t, q.items, 0) q.Put(mockItem(2)) q.Put(mockItem(1)) result, err = q.Get(1) if !assert.Nil(t, err) { return } if !assert.Len(t, result, 1) { return } assert.Equal(t, mockItem(1), result[0]) assert.Len(t, q.items, 1) result, err = q.Get(2) if !assert.Nil(t, err) { return } if !assert.Len(t, result, 1) { return } assert.Equal(t, mockItem(2), result[0]) } func TestAddEmptyPriorityPut(t *testing.T) { q := NewPriorityQueue(1, false) q.Put() assert.Len(t, q.items, 0) } func TestPriorityGetNonPositiveNumber(t *testing.T) { q := NewPriorityQueue(1, false) q.Put(mockItem(1)) result, err := q.Get(0) if !assert.Nil(t, err) { return } assert.Len(t, result, 0) result, err = q.Get(-1) if !assert.Nil(t, err) { return } assert.Len(t, result, 0) } func TestPriorityEmpty(t *testing.T) { q := NewPriorityQueue(1, false) assert.True(t, q.Empty()) q.Put(mockItem(1)) assert.False(t, q.Empty()) } func TestPriorityGetEmpty(t *testing.T) { q := NewPriorityQueue(1, false) go func() { q.Put(mockItem(1)) }() result, err := q.Get(1) if !assert.Nil(t, err) { return } if !assert.Len(t, result, 1) { return } assert.Equal(t, mockItem(1), result[0]) } func TestMultiplePriorityGetEmpty(t *testing.T) { q := NewPriorityQueue(1, false) var wg sync.WaitGroup wg.Add(2) results := make([][]Item, 2) go func() { wg.Done() local, _ := q.Get(1) results[0] = local wg.Done() }() go func() { wg.Done() 
local, _ := q.Get(1) results[1] = local wg.Done() }() wg.Wait() wg.Add(2) q.Put(mockItem(1), mockItem(3), mockItem(2)) wg.Wait() if !assert.Len(t, results[0], 1) || !assert.Len(t, results[1], 1) { return } assert.True( t, (results[0][0] == mockItem(1) && results[1][0] == mockItem(2)) || results[0][0] == mockItem(2) && results[1][0] == mockItem(1), ) } func TestEmptyPriorityGetWithDispose(t *testing.T) { q := NewPriorityQueue(1, false) var wg sync.WaitGroup wg.Add(1) var err error go func() { wg.Done() _, err = q.Get(1) wg.Done() }() wg.Wait() wg.Add(1) q.Dispose() wg.Wait() assert.IsType(t, ErrDisposed, err) } func TestPriorityGetPutDisposed(t *testing.T) { q := NewPriorityQueue(1, false) q.Dispose() _, err := q.Get(1) assert.IsType(t, ErrDisposed, err) err = q.Put(mockItem(1)) assert.IsType(t, ErrDisposed, err) } func BenchmarkPriorityQueue(b *testing.B) { q := NewPriorityQueue(b.N, false) var wg sync.WaitGroup wg.Add(1) i := 0 go func() { for { q.Get(1) i++ if i == b.N { wg.Done() break } } }() for i := 0; i < b.N; i++ { q.Put(mockItem(i)) } wg.Wait() } func TestPriorityPeek(t *testing.T) { q := NewPriorityQueue(1, false) q.Put(mockItem(1)) assert.Equal(t, mockItem(1), q.Peek()) } func TestInsertDuplicate(t *testing.T) { q := NewPriorityQueue(1, false) q.Put(mockItem(1)) q.Put(mockItem(1)) assert.Equal(t, 1, q.Len()) } func TestAllowDuplicates(t *testing.T) { q := NewPriorityQueue(2, true) q.Put(mockItem(1)) q.Put(mockItem(1)) assert.Equal(t, 2, q.Len()) } ================================================ FILE: queue/queue.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* Package queue includes a regular queue and a priority queue. These queues rely on waitgroups to pause listening threads on empty queues until a message is received. If any thread calls Dispose on the queue, any listeners are immediately returned with an error. Any subsequent put to the queue will return an error as opposed to panicking as with channels. Queues will grow with unbounded behavior as opposed to channels which can be buffered but will pause while a thread attempts to put to a full channel. Recently added is a lockless ring buffer using the same basic C design as found here: http://www.1024cores.net/home/lock-free-algorithms/queues/bounded-mpmc-queue Modified for use with Go with the addition of some dispose semantics providing the capability to release blocked threads. This works for both puts and gets, either will return an error if they are blocked and the buffer is disposed. This could serve as a signal to kill a goroutine. All threadsafety is achieved using CAS operations, making this buffer pretty quick. Benchmarks: BenchmarkPriorityQueue-8 2000000 782 ns/op BenchmarkQueue-8 2000000 671 ns/op BenchmarkChannel-8 1000000 2083 ns/op BenchmarkQueuePut-8 20000 84299 ns/op BenchmarkQueueGet-8 20000 80753 ns/op BenchmarkExecuteInParallel-8 20000 68891 ns/op BenchmarkRBLifeCycle-8 10000000 177 ns/op BenchmarkRBPut-8 30000000 58.1 ns/op BenchmarkRBGet-8 50000000 26.8 ns/op TODO: We really need a Fibonacci heap for the priority queue. TODO: Unify the types of queue to the same interface. 
*/ package queue import ( "runtime" "sync" "sync/atomic" "time" ) type waiters []*sema func (w *waiters) get() *sema { if len(*w) == 0 { return nil } sema := (*w)[0] copy((*w)[0:], (*w)[1:]) (*w)[len(*w)-1] = nil // or the zero value of T *w = (*w)[:len(*w)-1] return sema } func (w *waiters) put(sema *sema) { *w = append(*w, sema) } func (w *waiters) remove(sema *sema) { if len(*w) == 0 { return } // build new slice, copy all except sema ws := *w newWs := make(waiters, 0, len(*w)) for i := range ws { if ws[i] != sema { newWs = append(newWs, ws[i]) } } *w = newWs } type items []interface{} func (items *items) get(number int64) []interface{} { returnItems := make([]interface{}, 0, number) index := int64(0) for i := int64(0); i < number; i++ { if i >= int64(len(*items)) { break } returnItems = append(returnItems, (*items)[i]) (*items)[i] = nil index++ } *items = (*items)[index:] return returnItems } func (items *items) peek() (interface{}, bool) { length := len(*items) if length == 0 { return nil, false } return (*items)[0], true } func (items *items) getUntil(checker func(item interface{}) bool) []interface{} { length := len(*items) if len(*items) == 0 { // returning nil here actually wraps that nil in a list // of interfaces... thanks go return []interface{}{} } returnItems := make([]interface{}, 0, length) index := -1 for i, item := range *items { if !checker(item) { break } returnItems = append(returnItems, item) index = i (*items)[i] = nil // prevent memory leak } *items = (*items)[index+1:] return returnItems } type sema struct { ready chan bool response *sync.WaitGroup } func newSema() *sema { return &sema{ ready: make(chan bool, 1), response: &sync.WaitGroup{}, } } // Queue is the struct responsible for tracking the state // of the queue. type Queue struct { waiters waiters items items lock sync.Mutex disposed bool } // Put will add the specified items to the queue. 
func (q *Queue) Put(items ...interface{}) error {
	if len(items) == 0 {
		return nil
	}

	q.lock.Lock()

	if q.disposed {
		q.lock.Unlock()
		return ErrDisposed
	}

	q.items = append(q.items, items...)
	// Wake blocked readers one at a time, oldest first, until there are
	// no waiters or no items left.
	for {
		sema := q.waiters.get()
		if sema == nil {
			break
		}
		sema.response.Add(1)
		select {
		case sema.ready <- true:
			// The reader takes its items while this call still holds
			// the lock; wait for it to report back.
			sema.response.Wait()
		default:
			// This semaphore timed out.
		}
		if len(q.items) == 0 {
			break
		}
	}

	q.lock.Unlock()
	return nil
}

// Get retrieves items from the queue.  If there are some items in the
// queue, get will return a number UP TO the number passed in as a
// parameter.  If no items are in the queue, this method will pause
// until items are added to the queue.
func (q *Queue) Get(number int64) ([]interface{}, error) {
	return q.Poll(number, 0)
}

// Poll retrieves items from the queue.  If there are some items in the queue,
// Poll will return a number UP TO the number passed in as a parameter.  If no
// items are in the queue, this method will pause until items are added to the
// queue or the provided timeout is reached.  A non-positive timeout will block
// until items are added.  If a timeout occurs, ErrTimeout is returned.
// A number below one yields an empty slice and nil error; ErrDisposed is
// returned if the queue is disposed before or while waiting.
func (q *Queue) Poll(number int64, timeout time.Duration) ([]interface{}, error) {
	if number < 1 {
		// thanks again go
		return []interface{}{}, nil
	}

	q.lock.Lock()

	if q.disposed {
		q.lock.Unlock()
		return nil, ErrDisposed
	}

	var items []interface{}

	if len(q.items) == 0 {
		sema := newSema()
		q.waiters.put(sema)
		q.lock.Unlock()

		// A nil channel never becomes ready, so a non-positive timeout
		// leaves only the ready case selectable.
		var timeoutC <-chan time.Time
		if timeout > 0 {
			timeoutC = time.After(timeout)
		}
		select {
		case <-sema.ready:
			// we are now inside the put's lock
			if q.disposed {
				return nil, ErrDisposed
			}
			items = q.items.get(number)
			sema.response.Done()
			return items, nil
		case <-timeoutC:
			// cleanup the sema that was added to waiters
			select {
			case sema.ready <- true:
				// we called this before Put() could
				// Remove sema from waiters.
				q.lock.Lock()
				q.waiters.remove(sema)
				q.lock.Unlock()
			default:
				// Put() got it already, we need to call Done() so Put() can move on
				sema.response.Done()
			}
			return nil, ErrTimeout
		}
	}

	items = q.items.get(number)
	q.lock.Unlock()
	return items, nil
}

// Peek returns the first item in the queue by value
// without modifying the queue.  ErrDisposed is returned for a disposed
// queue and ErrEmptyQueue for an empty one.
func (q *Queue) Peek() (interface{}, error) {
	q.lock.Lock()
	defer q.lock.Unlock()

	if q.disposed {
		return nil, ErrDisposed
	}

	peekItem, ok := q.items.peek()
	if !ok {
		return nil, ErrEmptyQueue
	}

	return peekItem, nil
}

// TakeUntil takes a function and returns a list of items that
// match the checker until the checker returns false.  This does not
// wait if there are no items in the queue.  A nil checker yields a nil
// slice and nil error.
func (q *Queue) TakeUntil(checker func(item interface{}) bool) ([]interface{}, error) {
	if checker == nil {
		return nil, nil
	}

	q.lock.Lock()

	if q.disposed {
		q.lock.Unlock()
		return nil, ErrDisposed
	}

	result := q.items.getUntil(checker)
	q.lock.Unlock()
	return result, nil
}

// Empty returns a bool indicating if this queue is empty.
func (q *Queue) Empty() bool {
	q.lock.Lock()
	defer q.lock.Unlock()

	return len(q.items) == 0
}

// Len returns the number of items in this queue.
func (q *Queue) Len() int64 {
	q.lock.Lock()
	defer q.lock.Unlock()

	return int64(len(q.items))
}

// Disposed returns a bool indicating if this queue
// has had disposed called on it.
func (q *Queue) Disposed() bool {
	q.lock.Lock()
	defer q.lock.Unlock()

	return q.disposed
}

// Dispose will dispose of this queue and returns
// the items disposed. Any subsequent calls to Get
// or Put will return an error.
func (q *Queue) Dispose() []interface{} { q.lock.Lock() defer q.lock.Unlock() q.disposed = true for _, waiter := range q.waiters { waiter.response.Add(1) select { case waiter.ready <- true: // release Poll immediately default: // ignore if it's a timeout or in the get } } disposedItems := q.items q.items = nil q.waiters = nil return disposedItems } // New is a constructor for a new threadsafe queue. func New(hint int64) *Queue { return &Queue{ items: make([]interface{}, 0, hint), } } // ExecuteInParallel will (in parallel) call the provided function // with each item in the queue until the queue is exhausted. When the queue // is exhausted execution is complete and all goroutines will be killed. // This means that the queue will be disposed so cannot be used again. func ExecuteInParallel(q *Queue, fn func(interface{})) { if q == nil { return } q.lock.Lock() // so no one touches anything in the middle // of this process todo, done := uint64(len(q.items)), int64(-1) // this is important or we might face an infinite loop if todo == 0 { return } numCPU := 1 if runtime.NumCPU() > 1 { numCPU = runtime.NumCPU() - 1 } var wg sync.WaitGroup wg.Add(numCPU) items := q.items for i := 0; i < numCPU; i++ { go func() { for { index := atomic.AddInt64(&done, 1) if index >= int64(todo) { wg.Done() break } fn(items[index]) items[index] = 0 } }() } wg.Wait() q.lock.Unlock() q.Dispose() } ================================================ FILE: queue/queue_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and
 limitations under the License.
*/

package queue

import (
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// TestPut checks that each Put is visible through Len and is handed back
// by the following Get, leaving the queue empty.
func TestPut(t *testing.T) {
	q := New(10)

	q.Put(`test`)
	assert.Equal(t, int64(1), q.Len())

	results, err := q.Get(1)
	assert.Nil(t, err)

	result := results[0]
	assert.Equal(t, `test`, result)
	assert.True(t, q.Empty())

	q.Put(`test2`)
	assert.Equal(t, int64(1), q.Len())

	results, err = q.Get(1)
	assert.Nil(t, err)

	result = results[0]
	assert.Equal(t, `test2`, result)
	assert.True(t, q.Empty())
}

// TestGet checks that Get returns at most the requested number of items,
// in insertion order.
func TestGet(t *testing.T) {
	q := New(10)

	q.Put(`test`)
	result, err := q.Get(2)
	if !assert.Nil(t, err) {
		return
	}

	assert.Len(t, result, 1)
	assert.Equal(t, `test`, result[0])
	assert.Equal(t, int64(0), q.Len())

	q.Put(`1`)
	q.Put(`2`)

	result, err = q.Get(1)
	if !assert.Nil(t, err) {
		return
	}

	assert.Len(t, result, 1)
	assert.Equal(t, `1`, result[0])
	assert.Equal(t, int64(1), q.Len())

	result, err = q.Get(2)
	if !assert.Nil(t, err) {
		return
	}

	assert.Equal(t, `2`, result[0])
}

// TestPoll mirrors TestGet through Poll and additionally checks that an
// empty queue yields ErrTimeout after roughly the requested timeout.
func TestPoll(t *testing.T) {
	q := New(10)

	// should be able to Poll() before anything is present, without breaking future Puts
	q.Poll(1, time.Millisecond)

	q.Put(`test`)
	result, err := q.Poll(2, 0)
	if !assert.Nil(t, err) {
		return
	}

	assert.Len(t, result, 1)
	assert.Equal(t, `test`, result[0])
	assert.Equal(t, int64(0), q.Len())

	q.Put(`1`)
	q.Put(`2`)

	result, err = q.Poll(1, time.Millisecond)
	if !assert.Nil(t, err) {
		return
	}

	assert.Len(t, result, 1)
	assert.Equal(t, `1`, result[0])
	assert.Equal(t, int64(1), q.Len())

	result, err = q.Poll(2, time.Millisecond)
	if !assert.Nil(t, err) {
		return
	}

	assert.Equal(t, `2`, result[0])

	before := time.Now()
	_, err = q.Poll(1, 5*time.Millisecond)
	// This delta is normally 1-3 ms but running tests in CI with -race causes
	// this to run much slower. For now, just bump up the threshold.
	assert.InDelta(t, 5, time.Since(before).Seconds()*1000, 10)
	assert.Equal(t, ErrTimeout, err)
}

// TestPollNoMemoryLeak checks that a timed-out Poll does not leave its
// waiter registered on the queue.
func TestPollNoMemoryLeak(t *testing.T) {
	q := New(0)

	assert.Len(t, q.waiters, 0)

	for i := 0; i < 10; i++ {
		// Poll() should cleanup waiters after timeout
		q.Poll(1, time.Nanosecond)
		assert.Len(t, q.waiters, 0)
	}
}

// TestAddEmptyPut checks that Put with no arguments adds nothing.
func TestAddEmptyPut(t *testing.T) {
	q := New(10)

	q.Put()

	if q.Len() != 0 {
		t.Errorf(`Expected len: %d, received: %d`, 0, q.Len())
	}
}

// TestGetNonPositiveNumber checks that Get(0) returns no items and no error.
func TestGetNonPositiveNumber(t *testing.T) {
	q := New(10)

	q.Put(`test`)
	result, err := q.Get(0)
	if !assert.Nil(t, err) {
		return
	}

	if len(result) != 0 {
		t.Errorf(`Expected len: %d, received: %d`, 0, len(result))
	}
}

// TestEmpty checks Empty before and after a Put.
func TestEmpty(t *testing.T) {
	q := New(10)

	if !q.Empty() {
		t.Errorf(`Expected empty queue.`)
	}

	q.Put(`test`)
	if q.Empty() {
		t.Errorf(`Expected non-empty queue.`)
	}
}

// TestGetEmpty checks that a Get racing with a concurrent Put still
// receives the item.
func TestGetEmpty(t *testing.T) {
	q := New(10)

	go func() {
		q.Put(`a`)
	}()

	result, err := q.Get(2)
	if !assert.Nil(t, err) {
		return
	}

	assert.Len(t, result, 1)
	assert.Equal(t, `a`, result[0])
}

// TestMultipleGetEmpty starts two getters and checks that a single Put of
// three items gives each of them exactly one of the first two items.
func TestMultipleGetEmpty(t *testing.T) {
	q := New(10)
	var wg sync.WaitGroup
	wg.Add(2)
	results := make([][]interface{}, 2)

	go func() {
		wg.Done()
		local, err := q.Get(1)
		assert.Nil(t, err)
		results[0] = local
		wg.Done()
	}()

	go func() {
		wg.Done()
		local, err := q.Get(1)
		assert.Nil(t, err)
		results[1] = local
		wg.Done()
	}()

	wg.Wait()
	wg.Add(2)

	q.Put(`a`, `b`, `c`)
	wg.Wait()

	if assert.Len(t, results[0], 1) && assert.Len(t, results[1], 1) {
		assert.True(t, (results[0][0] == `a` && results[1][0] == `b`) ||
			(results[0][0] == `b` && results[1][0] == `a`),
			`The array should be a, b or b, a`)
	}
}

// TestDispose checks the value returned by Dispose for an empty queue, a
// non-empty queue and an already disposed queue.
func TestDispose(t *testing.T) {
	// when the queue is empty
	q := New(10)
	itemsDisposed := q.Dispose()

	assert.Empty(t, itemsDisposed)

	// when the queue is not empty
	q = New(10)
	q.Put(`1`)
	itemsDisposed = q.Dispose()

	expected := []interface{}{`1`}
	assert.Equal(t, expected, itemsDisposed)

	// when the queue has been disposed
	itemsDisposed = q.Dispose()
	assert.Nil(t, itemsDisposed)
}

// TestEmptyGetWithDispose checks that Dispose releases a pending Get with
// ErrDisposed.
func TestEmptyGetWithDispose(t *testing.T) {
	q := New(10)
	var wg sync.WaitGroup
	wg.Add(1)

	var err error

	go func() {
		wg.Done()
		_, err = q.Get(1)
		wg.Done()
	}()

	wg.Wait()
	wg.Add(1)

	q.Dispose()

	wg.Wait()

	assert.IsType(t, ErrDisposed, err)
}

// TestDisposeAfterEmptyPoll checks that Dispose does not hang after a Poll
// has timed out and that later Polls report ErrDisposed.
func TestDisposeAfterEmptyPoll(t *testing.T) {
	q := New(10)

	_, err := q.Poll(1, time.Millisecond)
	assert.IsType(t, ErrTimeout, err)

	// it should not hang
	q.Dispose()

	_, err = q.Poll(1, time.Millisecond)
	assert.IsType(t, ErrDisposed, err)
}

// TestGetPutDisposed checks that Get and Put fail on a disposed queue.
func TestGetPutDisposed(t *testing.T) {
	q := New(10)

	q.Dispose()

	_, err := q.Get(1)
	assert.IsType(t, ErrDisposed, err)

	err = q.Put(`a`)
	assert.IsType(t, ErrDisposed, err)
}

// BenchmarkQueue measures b.N Puts consumed one at a time by a single
// getter goroutine.
func BenchmarkQueue(b *testing.B) {
	q := New(int64(b.N))
	var wg sync.WaitGroup
	wg.Add(1)
	i := 0

	go func() {
		for {
			q.Get(1)
			i++
			if i == b.N {
				wg.Done()
				break
			}
		}
	}()

	for i := 0; i < b.N; i++ {
		q.Put(`a`)
	}

	wg.Wait()
}

// BenchmarkChannel runs the same producer/consumer pattern as
// BenchmarkQueue over a buffered channel for comparison.
func BenchmarkChannel(b *testing.B) {
	ch := make(chan interface{}, 1)
	var wg sync.WaitGroup
	wg.Add(1)
	i := 0

	go func() {
		for {
			<-ch
			i++
			if i == b.N {
				wg.Done()
				break
			}
		}
	}()

	for i := 0; i < b.N; i++ {
		ch <- `a`
	}

	wg.Wait()
}

// TestPeek checks that Peek returns the head item without removing it.
func TestPeek(t *testing.T) {
	q := New(10)
	q.Put(`a`)
	q.Put(`b`)
	q.Put(`c`)
	peekResult, err := q.Peek()
	peekExpected := `a`
	assert.Nil(t, err)
	assert.Equal(t, q.Len(), int64(3))
	assert.Equal(t, peekExpected, peekResult)

	popResult, err := q.Get(1)
	assert.Nil(t, err)
	assert.Equal(t, peekResult, popResult[0])
	assert.Equal(t, q.Len(), int64(2))
}

// TestPeekOnDisposedQueue checks that Peek fails with ErrDisposed once the
// queue has been disposed.
func TestPeekOnDisposedQueue(t *testing.T) {
	q := New(10)
	q.Dispose()
	result, err := q.Peek()

	assert.Nil(t, result)
	assert.IsType(t, ErrDisposed, err)
}

// TestTakeUntil checks that TakeUntil returns the leading items for which
// the checker is true.
func TestTakeUntil(t *testing.T) {
	q := New(10)
	q.Put(`a`, `b`, `c`)
	result, err := q.TakeUntil(func(item interface{}) bool {
		return item != `c`
	})

	if !assert.Nil(t, err) {
		return
	}

	expected := []interface{}{`a`, `b`}
	assert.Equal(t, expected, result)
}

// TestTakeUntilEmptyQueue checks that TakeUntil on an empty queue returns
// an empty, non-nil slice.
func TestTakeUntilEmptyQueue(t *testing.T) {
	q := New(10)
	result, err := q.TakeUntil(func(item interface{}) bool {
		return item != `c`
	})

	if !assert.Nil(t, err) {
		return
	}

	expected := []interface{}{}
	assert.Equal(t, expected, result)
}

// TestTakeUntilThenGet checks that items not taken by TakeUntil remain
// available to Get.
func TestTakeUntilThenGet(t *testing.T) {
	q := New(10)
	q.Put(`a`, `b`, `c`)
	takeItems, _ := q.TakeUntil(func(item interface{}) bool {
		return item != `c`
	})

	restItems, _ := q.Get(3)
	assert.Equal(t, []interface{}{`a`, `b`}, takeItems)
	assert.Equal(t, []interface{}{`c`}, restItems)
}

// TestTakeUntilNoMatches checks that nothing is removed when the checker
// rejects the first item.
func TestTakeUntilNoMatches(t *testing.T) {
	q := New(10)
	q.Put(`a`, `b`, `c`)
	takeItems, _ := q.TakeUntil(func(item interface{}) bool {
		return item != `a`
	})

	restItems, _ := q.Get(3)
	assert.Equal(t, []interface{}{}, takeItems)
	assert.Equal(t, []interface{}{`a`, `b`, `c`}, restItems)
}

// TestTakeUntilOnDisposedQueue checks that TakeUntil fails with ErrDisposed
// once the queue has been disposed.
func TestTakeUntilOnDisposedQueue(t *testing.T) {
	q := New(10)
	q.Dispose()
	result, err := q.TakeUntil(func(item interface{}) bool {
		return true
	})

	assert.Nil(t, result)
	assert.IsType(t, ErrDisposed, err)
}

// TestWaiters exercises put, remove and get on the waiters list directly.
func TestWaiters(t *testing.T) {
	s1, s2, s3, s4 := newSema(), newSema(), newSema(), newSema()

	w := waiters{}
	assert.Len(t, w, 0)

	//
	// test put()
	w.put(s1)
	assert.Equal(t, waiters{s1}, w)

	w.put(s2)
	w.put(s3)
	w.put(s4)
	assert.Equal(t, waiters{s1, s2, s3, s4}, w)

	//
	// test remove()
	//
	// remove from middle
	w.remove(s2)
	assert.Equal(t, waiters{s1, s3, s4}, w)

	// remove non-existing element
	w.remove(s2)
	assert.Equal(t, waiters{s1, s3, s4}, w)

	// remove from beginning
	w.remove(s1)
	assert.Equal(t, waiters{s3, s4}, w)

	// remove from end
	w.remove(s4)
	assert.Equal(t, waiters{s3}, w)

	// remove last element
	w.remove(s3)
	assert.Empty(t, w)

	// remove non-existing element
	w.remove(s3)
	assert.Empty(t, w)

	//
	// test get()
	//
	// start with 3 elements in list
	w.put(s1)
	w.put(s2)
	w.put(s3)
	assert.Equal(t, waiters{s1, s2, s3}, w)

	// get() returns each item in insertion order
	assert.Equal(t, s1, w.get())
	assert.Equal(t, s2, w.get())
	w.put(s4) // interleave a put(), item should go to the end
	assert.Equal(t, s3, w.get())
	assert.Equal(t, s4, w.get())
	assert.Empty(t, w)
	assert.Nil(t, w.get())
}

// TestExecuteInParallel checks that fn is called once per queued item and
// that the queue ends up disposed.
func TestExecuteInParallel(t *testing.T) {
	q := New(10)
	for i := 0; i < 10; i++ {
		q.Put(i)
	}

	numCalls := uint64(0)

	ExecuteInParallel(q, func(item interface{}) {
		t.Logf("ExecuteInParallel called us with %+v", item)
		atomic.AddUint64(&numCalls, 1)
	})

	assert.Equal(t, uint64(10), numCalls)
	assert.True(t, q.Disposed())
}

// TestExecuteInParallelEmptyQueue checks that an empty queue returns
// without invoking fn.
func TestExecuteInParallelEmptyQueue(t *testing.T) {
	q := New(1)

	// basically just ensuring we don't deadlock here
	ExecuteInParallel(q, func(interface{}) {
		t.Fail()
	})
}

// BenchmarkQueuePut measures 1000 single-item Puts per fresh queue.
func BenchmarkQueuePut(b *testing.B) {
	numItems := int64(1000)

	qs := make([]*Queue, 0, b.N)

	for i := 0; i < b.N; i++ {
		q := New(10)
		qs = append(qs, q)
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		q := qs[i]
		for j := int64(0); j < numItems; j++ {
			q.Put(j)
		}
	}
}

// BenchmarkQueueGet measures draining a pre-filled queue with Get(1).
func BenchmarkQueueGet(b *testing.B) {
	numItems := int64(1000)

	qs := make([]*Queue, 0, b.N)

	for i := 0; i < b.N; i++ {
		q := New(numItems)
		for j := int64(0); j < numItems; j++ {
			q.Put(j)
		}
		qs = append(qs, q)
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		q := qs[i]
		for j := int64(0); j < numItems; j++ {
			q.Get(1)
		}
	}
}

// BenchmarkQueuePoll measures draining a pre-filled queue with Poll.
func BenchmarkQueuePoll(b *testing.B) {
	numItems := int64(1000)

	qs := make([]*Queue, 0, b.N)

	for i := 0; i < b.N; i++ {
		q := New(numItems)
		for j := int64(0); j < numItems; j++ {
			q.Put(j)
		}
		qs = append(qs, q)
	}

	b.ResetTimer()

	for _, q := range qs {
		for j := int64(0); j < numItems; j++ {
			q.Poll(1, time.Millisecond)
		}
	}
}

// BenchmarkExecuteInParallel measures ExecuteInParallel over pre-filled
// queues of 1000 int64 items each.
func BenchmarkExecuteInParallel(b *testing.B) {
	numItems := int64(1000)

	qs := make([]*Queue, 0, b.N)

	for i := 0; i < b.N; i++ {
		q := New(numItems)
		for j := int64(0); j < numItems; j++ {
			q.Put(j)
		}
		qs = append(qs, q)
	}

	var counter int64
	fn := func(ifc interface{}) {
		c := ifc.(int64)
		atomic.AddInt64(&counter, c)
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		q := qs[i]
		ExecuteInParallel(q, fn)
	}
}

================================================
FILE: queue/ring.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package queue import ( "runtime" "sync/atomic" "time" ) // roundUp takes a uint64 greater than 0 and rounds it up to the next // power of 2. func roundUp(v uint64) uint64 { v-- v |= v >> 1 v |= v >> 2 v |= v >> 4 v |= v >> 8 v |= v >> 16 v |= v >> 32 v++ return v } type node struct { position uint64 data interface{} } type nodes []node // RingBuffer is a MPMC buffer that achieves threadsafety with CAS operations // only. A put on full or get on empty call will block until an item // is put or retrieved. Calling Dispose on the RingBuffer will unblock // any blocked threads with an error. This buffer is similar to the buffer // described here: http://www.1024cores.net/home/lock-free-algorithms/queues/bounded-mpmc-queue // with some minor additions. type RingBuffer struct { _padding0 [8]uint64 queue uint64 _padding1 [8]uint64 dequeue uint64 _padding2 [8]uint64 mask, disposed uint64 _padding3 [8]uint64 nodes nodes } func (rb *RingBuffer) init(size uint64) { size = roundUp(size) rb.nodes = make(nodes, size) for i := uint64(0); i < size; i++ { rb.nodes[i] = node{position: i} } rb.mask = size - 1 // so we don't have to do this with every put/get operation } // Put adds the provided item to the queue. If the queue is full, this // call will block until an item is added to the queue or Dispose is called // on the queue. An error will be returned if the queue is disposed. func (rb *RingBuffer) Put(item interface{}) error { _, err := rb.put(item, false) return err } // Offer adds the provided item to the queue if there is space. If the queue // is full, this call will return false. 
An error will be returned if the // queue is disposed. func (rb *RingBuffer) Offer(item interface{}) (bool, error) { return rb.put(item, true) } func (rb *RingBuffer) put(item interface{}, offer bool) (bool, error) { var n *node pos := atomic.LoadUint64(&rb.queue) L: for { if atomic.LoadUint64(&rb.disposed) == 1 { return false, ErrDisposed } n = &rb.nodes[pos&rb.mask] seq := atomic.LoadUint64(&n.position) switch dif := seq - pos; { case dif == 0: if atomic.CompareAndSwapUint64(&rb.queue, pos, pos+1) { break L } case dif < 0: panic(`Ring buffer in a compromised state during a put operation.`) default: pos = atomic.LoadUint64(&rb.queue) } if offer { return false, nil } runtime.Gosched() // free up the cpu before the next iteration } n.data = item atomic.StoreUint64(&n.position, pos+1) return true, nil } // Get will return the next item in the queue. This call will block // if the queue is empty. This call will unblock when an item is added // to the queue or Dispose is called on the queue. An error will be returned // if the queue is disposed. func (rb *RingBuffer) Get() (interface{}, error) { return rb.Poll(0) } // Poll will return the next item in the queue. This call will block // if the queue is empty. This call will unblock when an item is added // to the queue, Dispose is called on the queue, or the timeout is reached. An // error will be returned if the queue is disposed or a timeout occurs. A // non-positive timeout will block indefinitely. 
func (rb *RingBuffer) Poll(timeout time.Duration) (interface{}, error) { var ( n *node pos = atomic.LoadUint64(&rb.dequeue) start time.Time ) if timeout > 0 { start = time.Now() } L: for { if atomic.LoadUint64(&rb.disposed) == 1 { return nil, ErrDisposed } n = &rb.nodes[pos&rb.mask] seq := atomic.LoadUint64(&n.position) switch dif := seq - (pos + 1); { case dif == 0: if atomic.CompareAndSwapUint64(&rb.dequeue, pos, pos+1) { break L } case dif < 0: panic(`Ring buffer in compromised state during a get operation.`) default: pos = atomic.LoadUint64(&rb.dequeue) } if timeout > 0 && time.Since(start) >= timeout { return nil, ErrTimeout } runtime.Gosched() // free up the cpu before the next iteration } data := n.data n.data = nil atomic.StoreUint64(&n.position, pos+rb.mask+1) return data, nil } // Len returns the number of items in the queue. func (rb *RingBuffer) Len() uint64 { return atomic.LoadUint64(&rb.queue) - atomic.LoadUint64(&rb.dequeue) } // Cap returns the capacity of this ring buffer. func (rb *RingBuffer) Cap() uint64 { return uint64(len(rb.nodes)) } // Dispose will dispose of this queue and free any blocked threads // in the Put and/or Get methods. Calling those methods on a disposed // queue will return an error. func (rb *RingBuffer) Dispose() { atomic.CompareAndSwapUint64(&rb.disposed, 0, 1) } // IsDisposed will return a bool indicating if this queue has been // disposed. func (rb *RingBuffer) IsDisposed() bool { return atomic.LoadUint64(&rb.disposed) == 1 } // NewRingBuffer will allocate, initialize, and return a ring buffer // with the specified size. func NewRingBuffer(size uint64) *RingBuffer { rb := &RingBuffer{} rb.init(size) return rb } ================================================ FILE: queue/ring_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package queue import ( "sync" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" ) func TestRingInsert(t *testing.T) { rb := NewRingBuffer(5) assert.Equal(t, uint64(8), rb.Cap()) err := rb.Put(5) if !assert.Nil(t, err) { return } result, err := rb.Get() if !assert.Nil(t, err) { return } assert.Equal(t, 5, result) } func TestRingMultipleInserts(t *testing.T) { rb := NewRingBuffer(5) err := rb.Put(1) if !assert.Nil(t, err) { return } err = rb.Put(2) if !assert.Nil(t, err) { return } result, err := rb.Get() if !assert.Nil(t, err) { return } assert.Equal(t, 1, result) result, err = rb.Get() if assert.Nil(t, err) { return } assert.Equal(t, 2, result) } func TestIntertwinedGetAndPut(t *testing.T) { rb := NewRingBuffer(5) err := rb.Put(1) if !assert.Nil(t, err) { return } result, err := rb.Get() if !assert.Nil(t, err) { return } assert.Equal(t, 1, result) err = rb.Put(2) if !assert.Nil(t, err) { return } result, err = rb.Get() if !assert.Nil(t, err) { return } assert.Equal(t, 2, result) } func TestPutToFull(t *testing.T) { rb := NewRingBuffer(3) for i := 0; i < 4; i++ { err := rb.Put(i) if !assert.Nil(t, err) { return } } var wg sync.WaitGroup wg.Add(2) go func() { err := rb.Put(4) assert.Nil(t, err) wg.Done() }() go func() { defer wg.Done() result, err := rb.Get() if !assert.Nil(t, err) { return } assert.Equal(t, 0, result) }() wg.Wait() } func TestOffer(t *testing.T) { rb := NewRingBuffer(2) ok, err := rb.Offer("foo") assert.True(t, ok) assert.Nil(t, err) ok, err = rb.Offer("bar") assert.True(t, ok) assert.Nil(t, err) ok, err = rb.Offer("baz") assert.False(t, ok) 
assert.Nil(t, err) item, err := rb.Get() assert.Nil(t, err) assert.Equal(t, "foo", item) item, err = rb.Get() assert.Nil(t, err) assert.Equal(t, "bar", item) } func TestRingGetEmpty(t *testing.T) { rb := NewRingBuffer(3) var wg sync.WaitGroup wg.Add(1) // want to kick off this consumer to ensure it blocks go func() { wg.Done() result, err := rb.Get() assert.Nil(t, err) assert.Equal(t, 0, result) wg.Done() }() wg.Wait() wg.Add(2) go func() { defer wg.Done() err := rb.Put(0) assert.Nil(t, err) }() wg.Wait() } func TestRingPollEmpty(t *testing.T) { rb := NewRingBuffer(3) _, err := rb.Poll(1) assert.Equal(t, ErrTimeout, err) } func TestRingPoll(t *testing.T) { rb := NewRingBuffer(10) // should be able to Poll() before anything is present, without breaking future Puts rb.Poll(time.Millisecond) rb.Put(`test`) result, err := rb.Poll(0) if !assert.Nil(t, err) { return } assert.Equal(t, `test`, result) assert.Equal(t, uint64(0), rb.Len()) rb.Put(`1`) rb.Put(`2`) result, err = rb.Poll(time.Millisecond) if !assert.Nil(t, err) { return } assert.Equal(t, `1`, result) assert.Equal(t, uint64(1), rb.Len()) result, err = rb.Poll(time.Millisecond) if !assert.Nil(t, err) { return } assert.Equal(t, `2`, result) before := time.Now() _, err = rb.Poll(5 * time.Millisecond) // This delta is normally 1-3 ms but running tests in CI with -race causes // this to run much slower. For now, just bump up the threshold. 
assert.InDelta(t, 5, time.Since(before).Seconds()*1000, 10) assert.Equal(t, ErrTimeout, err) } func TestRingLen(t *testing.T) { rb := NewRingBuffer(4) assert.Equal(t, uint64(0), rb.Len()) rb.Put(1) assert.Equal(t, uint64(1), rb.Len()) rb.Get() assert.Equal(t, uint64(0), rb.Len()) for i := 0; i < 4; i++ { rb.Put(1) } assert.Equal(t, uint64(4), rb.Len()) rb.Get() assert.Equal(t, uint64(3), rb.Len()) } func TestDisposeOnGet(t *testing.T) { numThreads := 8 var wg sync.WaitGroup wg.Add(numThreads) rb := NewRingBuffer(4) var spunUp sync.WaitGroup spunUp.Add(numThreads) for i := 0; i < numThreads; i++ { go func() { spunUp.Done() defer wg.Done() _, err := rb.Get() assert.NotNil(t, err) }() } spunUp.Wait() rb.Dispose() wg.Wait() assert.True(t, rb.IsDisposed()) } func TestDisposeOnPut(t *testing.T) { numThreads := 8 var wg sync.WaitGroup wg.Add(numThreads) rb := NewRingBuffer(4) var spunUp sync.WaitGroup spunUp.Add(numThreads) // fill up the queue for i := 0; i < 4; i++ { rb.Put(i) } // it's now full for i := 0; i < numThreads; i++ { go func(i int) { spunUp.Done() defer wg.Done() err := rb.Put(i) assert.NotNil(t, err) }(i) } spunUp.Wait() rb.Dispose() wg.Wait() assert.True(t, rb.IsDisposed()) } func BenchmarkRBLifeCycle(b *testing.B) { rb := NewRingBuffer(64) counter := uint64(0) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() for { _, err := rb.Get() assert.Nil(b, err) if atomic.AddUint64(&counter, 1) == uint64(b.N) { return } } }() b.ResetTimer() for i := 0; i < b.N; i++ { rb.Put(i) } wg.Wait() } func BenchmarkRBLifeCycleContention(b *testing.B) { rb := NewRingBuffer(64) var wwg sync.WaitGroup var rwg sync.WaitGroup wwg.Add(10) rwg.Add(10) for i := 0; i < 10; i++ { go func() { for { _, err := rb.Get() if err == ErrDisposed { rwg.Done() return } else { assert.Nil(b, err) } } }() } b.ResetTimer() for i := 0; i < 10; i++ { go func() { for j := 0; j < b.N; j++ { rb.Put(j) } wwg.Done() }() } wwg.Wait() rb.Dispose() rwg.Wait() } func BenchmarkRBPut(b *testing.B) { rb 
:= NewRingBuffer(uint64(b.N)) b.ResetTimer() for i := 0; i < b.N; i++ { ok, err := rb.Offer(i) if !ok { b.Fail() } if err != nil { b.Log(err) b.Fail() } } } func BenchmarkRBGet(b *testing.B) { rb := NewRingBuffer(uint64(b.N)) for i := 0; i < b.N; i++ { rb.Offer(i) } b.ResetTimer() for i := 0; i < b.N; i++ { rb.Get() } } func BenchmarkRBAllocation(b *testing.B) { for i := 0; i < b.N; i++ { NewRingBuffer(1024) } } ================================================ FILE: rangetree/entries.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package rangetree import "sync" var entriesPool = sync.Pool{ New: func() interface{} { return make(Entries, 0, 10) }, } // Entries is a typed list of Entry that can be reused if Dispose // is called. type Entries []Entry // Dispose will free the resources consumed by this list and // allow the list to be reused. func (entries *Entries) Dispose() { for i := 0; i < len(*entries); i++ { (*entries)[i] = nil } *entries = (*entries)[:0] entriesPool.Put(*entries) } // NewEntries will return a reused list of entries. func NewEntries() Entries { return entriesPool.Get().(Entries) } ================================================ FILE: rangetree/entries_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
*/

package rangetree

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestDisposeEntries checks that Dispose leaves the list with length 0.
func TestDisposeEntries(t *testing.T) {
	entries := NewEntries()
	entries = append(entries, constructMockEntry(0, 0))

	entries.Dispose()

	assert.Len(t, entries, 0)
}

================================================
FILE: rangetree/error.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
*/

package rangetree

import "fmt"

// NoEntriesError is returned from an operation that requires
// existing entries when none are found.
type NoEntriesError struct{}

// Error implements the error interface with a fixed message.
func (nee NoEntriesError) Error() string {
	return `No entries in this tree.`
}

// OutOfDimensionError is returned when a requested operation
// doesn't meet dimensional requirements.
type OutOfDimensionError struct { provided, max uint64 } func (oode OutOfDimensionError) Error() string { return fmt.Sprintf(`Provided dimension: %d is greater than max dimension: %d`, oode.provided, oode.max, ) } ================================================ FILE: rangetree/immutable.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package rangetree import "github.com/Workiva/go-datastructures/slice" type immutableRangeTree struct { number uint64 top orderedNodes dimensions uint64 } func newCache(dimensions uint64) []slice.Int64Slice { cache := make([]slice.Int64Slice, 0, dimensions-1) for i := uint64(0); i < dimensions; i++ { cache = append(cache, slice.Int64Slice{}) } return cache } func (irt *immutableRangeTree) needNextDimension() bool { return irt.dimensions > 1 } func (irt *immutableRangeTree) add(nodes *orderedNodes, cache []slice.Int64Slice, entry Entry, added *uint64) { var node *node list := nodes for i := uint64(1); i <= irt.dimensions; i++ { if isLastDimension(irt.dimensions, i) { if i != 1 && !cache[i-1].Exists(node.value) { nodes := make(orderedNodes, len(*list)) copy(nodes, *list) list = &nodes cache[i-1].Insert(node.value) } newNode := newNode(entry.ValueAtDimension(i), entry, false) overwritten := list.add(newNode) if overwritten == nil { *added++ } if node != nil { node.orderedNodes = *list } break } if i != 1 && !cache[i-1].Exists(node.value) { nodes := make(orderedNodes, len(*list)) copy(nodes, 
*list) list = &nodes cache[i-1].Insert(node.value) node.orderedNodes = *list } node, _ = list.getOrAdd(entry, i, irt.dimensions) list = &node.orderedNodes } } // Add will add the provided entries into the tree and return // a new tree with those entries added. func (irt *immutableRangeTree) Add(entries ...Entry) *immutableRangeTree { if len(entries) == 0 { return irt } cache := newCache(irt.dimensions) top := make(orderedNodes, len(irt.top)) copy(top, irt.top) added := uint64(0) for _, entry := range entries { irt.add(&top, cache, entry, &added) } tree := newImmutableRangeTree(irt.dimensions) tree.top = top tree.number = irt.number + added return tree } // InsertAtDimension will increment items at and above the given index // by the number provided. Provide a negative number to to decrement. // Returned are two lists and the modified tree. The first list is a // list of entries that were moved. The second is a list entries that // were deleted. These lists are exclusive. func (irt *immutableRangeTree) InsertAtDimension(dimension uint64, index, number int64) (*immutableRangeTree, Entries, Entries) { if dimension > irt.dimensions || number == 0 { return irt, nil, nil } modified, deleted := make(Entries, 0, 100), make(Entries, 0, 100) tree := newImmutableRangeTree(irt.dimensions) tree.top = irt.top.immutableInsert( dimension, 1, irt.dimensions, index, number, &modified, &deleted, ) tree.number = irt.number - uint64(len(deleted)) return tree, modified, deleted } type immutableNodeBundle struct { list *orderedNodes index int previousNode *node newNode *node } func (irt *immutableRangeTree) Delete(entries ...Entry) *immutableRangeTree { cache := newCache(irt.dimensions) top := make(orderedNodes, len(irt.top)) copy(top, irt.top) deleted := uint64(0) for _, entry := range entries { irt.delete(&top, cache, entry, &deleted) } tree := newImmutableRangeTree(irt.dimensions) tree.top = top tree.number = irt.number - deleted return tree } func (irt *immutableRangeTree) delete(top 
*orderedNodes, cache []slice.Int64Slice, entry Entry, deleted *uint64) { path := make([]*immutableNodeBundle, 0, 5) var index int var n *node var local *node list := top for i := uint64(1); i <= irt.dimensions; i++ { value := entry.ValueAtDimension(i) local, index = list.get(value) if local == nil { // there's nothing to delete return } nb := &immutableNodeBundle{ list: list, index: index, previousNode: n, } path = append(path, nb) n = local list = &n.orderedNodes } *deleted++ for i := len(path) - 1; i >= 0; i-- { nb := path[i] if nb.previousNode != nil { nodes := make(orderedNodes, len(*nb.list)) copy(nodes, *nb.list) nb.list = &nodes if len(*nb.list) == 1 { continue } nn := newNode( nb.previousNode.value, nb.previousNode.entry, !isLastDimension(irt.dimensions, uint64(i)+1), ) nn.orderedNodes = nodes path[i-1].newNode = nn } } for _, nb := range path { if nb.newNode == nil { nb.list.deleteAt(nb.index) } else { (*nb.list)[nb.index] = nb.newNode } } } func (irt *immutableRangeTree) apply(list orderedNodes, interval Interval, dimension uint64, fn func(*node) bool) bool { low, high := interval.LowAtDimension(dimension), interval.HighAtDimension(dimension) if isLastDimension(irt.dimensions, dimension) { if !list.apply(low, high, fn) { return false } } else { if !list.apply(low, high, func(n *node) bool { if !irt.apply(n.orderedNodes, interval, dimension+1, fn) { return false } return true }) { return false } return true } return true } // Query will return an ordered list of results in the given // interval. 
func (irt *immutableRangeTree) Query(interval Interval) Entries { entries := NewEntries() irt.apply(irt.top, interval, 1, func(n *node) bool { entries = append(entries, n.entry) return true }) return entries } func (irt *immutableRangeTree) get(entry Entry) Entry { on := irt.top for i := uint64(1); i <= irt.dimensions; i++ { n, _ := on.get(entry.ValueAtDimension(i)) if n == nil { return nil } if i == irt.dimensions { return n.entry } on = n.orderedNodes } return nil } // Get returns any entries that exist at the addresses provided by the // given entries. Entries are returned in the order in which they are // received. If an entry cannot be found, a nil is returned in its // place. func (irt *immutableRangeTree) Get(entries ...Entry) Entries { result := make(Entries, 0, len(entries)) for _, entry := range entries { result = append(result, irt.get(entry)) } return result } // Len returns the number of items in this tree. func (irt *immutableRangeTree) Len() uint64 { return irt.number } func newImmutableRangeTree(dimensions uint64) *immutableRangeTree { return &immutableRangeTree{ dimensions: dimensions, } } ================================================ FILE: rangetree/immutable_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

package rangetree

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// Adding to a one-dimensional tree yields a new tree holding the entry
// while the tree it was derived from stays empty.
func TestImmutableSingleDimensionAdd(t *testing.T) {
	tree := newImmutableRangeTree(1)
	entry := constructMockEntry(0, int64(0), int64(0))
	tree2 := tree.Add(entry)

	result := tree.Query(
		constructMockInterval(dimension{0, 10}, dimension{0, 10}),
	)
	assert.Len(t, result, 0)

	result = tree2.Query(
		constructMockInterval(dimension{0, 10}, dimension{0, 10}),
	)
	assert.Equal(t, Entries{entry}, result)
}

// Each successive Add produces its own snapshot; earlier snapshots keep
// their contents and length.
func TestImmutableSingleDimensionMultipleAdds(t *testing.T) {
	tree := newImmutableRangeTree(1)
	e1 := constructMockEntry(0, int64(0), int64(0))
	e2 := constructMockEntry(0, int64(1), int64(1))
	e3 := constructMockEntry(0, int64(2), int64(2))

	tree1 := tree.Add(e1)
	tree2 := tree1.Add(e2)
	tree3 := tree2.Add(e3)

	iv := constructMockInterval(dimension{0, 10}, dimension{0, 10})

	result := tree1.Query(iv)
	assert.Equal(t, Entries{e1}, result)
	assert.Equal(t, uint64(1), tree1.Len())

	result = tree2.Query(iv)
	assert.Equal(t, Entries{e1, e2}, result)
	assert.Equal(t, uint64(2), tree2.Len())

	result = tree3.Query(iv)
	assert.Equal(t, Entries{e1, e2, e3}, result)
	assert.Equal(t, uint64(3), tree3.Len())
}

// Several entries passed to a single Add call all end up in the result.
func TestImmutableSingleDimensionBulkAdd(t *testing.T) {
	tree := newImmutableRangeTree(1)
	e1 := constructMockEntry(0, int64(0), int64(0))
	e2 := constructMockEntry(0, int64(1), int64(1))
	e3 := constructMockEntry(0, int64(2), int64(2))
	entries := Entries{e1, e2, e3}

	tree1 := tree.Add(entries...)

	result := tree1.Query(constructMockInterval(dimension{0, 10}, dimension{0, 10}))
	assert.Equal(t, entries, result)
	assert.Equal(t, uint64(3), tree1.Len())
}

// Two-dimensional counterpart of TestImmutableSingleDimensionAdd.
func TestImmutableMultiDimensionAdd(t *testing.T) {
	tree := newImmutableRangeTree(2)
	entry := constructMockEntry(0, int64(0), int64(0))
	tree2 := tree.Add(entry)

	result := tree.Query(
		constructMockInterval(dimension{0, 10}, dimension{0, 10}),
	)
	assert.Len(t, result, 0)

	result = tree2.Query(
		constructMockInterval(dimension{0, 10}, dimension{0, 10}),
	)
	assert.Equal(t, Entries{entry}, result)
}

// Two-dimensional counterpart of TestImmutableSingleDimensionMultipleAdds.
func TestImmutableMultiDimensionMultipleAdds(t *testing.T) {
	tree := newImmutableRangeTree(2)
	e1 := constructMockEntry(0, int64(0), int64(0))
	e2 := constructMockEntry(0, int64(1), int64(1))
	e3 := constructMockEntry(0, int64(2), int64(2))

	tree1 := tree.Add(e1)
	tree2 := tree1.Add(e2)
	tree3 := tree2.Add(e3)

	iv := constructMockInterval(dimension{0, 10}, dimension{0, 10})

	result := tree1.Query(iv)
	assert.Equal(t, Entries{e1}, result)
	assert.Equal(t, uint64(1), tree1.Len())

	result = tree2.Query(iv)
	assert.Equal(t, Entries{e1, e2}, result)
	assert.Equal(t, uint64(2), tree2.Len())

	result = tree3.Query(iv)
	assert.Equal(t, Entries{e1, e2, e3}, result)
	assert.Equal(t, uint64(3), tree3.Len())
}

// Two-dimensional counterpart of TestImmutableSingleDimensionBulkAdd.
func TestImmutableMultiDimensionBulkAdd(t *testing.T) {
	tree := newImmutableRangeTree(2)
	e1 := constructMockEntry(0, int64(0), int64(0))
	e2 := constructMockEntry(0, int64(1), int64(1))
	e3 := constructMockEntry(0, int64(2), int64(2))
	entries := Entries{e1, e2, e3}

	tree1 := tree.Add(entries...)

	result := tree1.Query(constructMockInterval(dimension{0, 10}, dimension{0, 10}))
	assert.Equal(t, entries, result)
	assert.Equal(t, uint64(3), tree1.Len())
}

// Measures adding 1000 entries one at a time, producing a new tree per Add.
func BenchmarkImmutableMultiDimensionInserts(b *testing.B) {
	numItems := int64(1000)

	entries := make(Entries, 0, numItems)
	for i := int64(0); i < numItems; i++ {
		e := constructMockEntry(uint64(i), i, i)
		entries = append(entries, e)
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		tree := newImmutableRangeTree(2)
		for _, e := range entries {
			tree = tree.Add(e)
		}
	}
}

// Measures adding 100000 entries to the immutable tree in a single Add call.
func BenchmarkImmutableMultiDimensionBulkInsert(b *testing.B) {
	numItems := int64(100000)

	entries := make(Entries, 0, numItems)
	for i := int64(0); i < numItems; i++ {
		e := constructMockEntry(uint64(i), i, i)
		entries = append(entries, e)
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		tree := newImmutableRangeTree(2)
		tree.Add(entries...)
	}
}

// Same workload as BenchmarkImmutableMultiDimensionBulkInsert, run
// against the tree returned by newOrderedTree for comparison.
func BenchmarkMultiDimensionBulkInsert(b *testing.B) {
	numItems := int64(100000)

	entries := make(Entries, 0, numItems)
	for i := int64(0); i < numItems; i++ {
		e := constructMockEntry(uint64(i), i, i)
		entries = append(entries, e)
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		tree := newOrderedTree(2)
		tree.Add(entries...)
	}
}

// Deleting the only entry of a one-dimensional tree leaves nothing to query.
func TestImmutableSingleDimensionDelete(t *testing.T) {
	tree := newImmutableRangeTree(1)
	entry := constructMockEntry(0, int64(0), int64(0))
	tree2 := tree.Add(entry)
	tree3 := tree2.Delete(entry)

	iv := constructMockInterval(dimension{0, 10}, dimension{0, 10})

	result := tree3.Query(iv)
	assert.Len(t, result, 0)
}

// Successive deletes each yield a smaller snapshot, the source tree is
// left intact, and deleting a missing entry yields an equal tree.
func TestImmutableSingleDimensionMultipleDeletes(t *testing.T) {
	tree := newImmutableRangeTree(1)
	e1 := constructMockEntry(0, int64(0), int64(0))
	e2 := constructMockEntry(0, int64(1), int64(1))
	e3 := constructMockEntry(0, int64(2), int64(2))

	tree1 := tree.Add(e1)
	tree2 := tree1.Add(e2)
	tree3 := tree2.Add(e3)

	iv := constructMockInterval(dimension{0, 10}, dimension{0, 10})

	tree4 := tree3.Delete(e3)
	result := tree4.Query(iv)
	assert.Equal(t, Entries{e1, e2}, result)
	assert.Equal(t, uint64(2), tree4.Len())

	tree5 := tree4.Delete(e2)
	result = tree5.Query(iv)
	assert.Equal(t, Entries{e1}, result)
	assert.Equal(t, uint64(1), tree5.Len())

	tree6 := tree5.Delete(e1)
	result = tree6.Query(iv)
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(0), tree6.Len())

	// the tree the deletes started from is untouched
	result = tree3.Query(iv)
	assert.Equal(t, Entries{e1, e2, e3}, result)
	assert.Equal(t, uint64(3), tree3.Len())

	// deleting an entry that is not stored changes nothing
	tree7 := tree3.Delete(constructMockEntry(0, int64(3), int64(3)))
	assert.Equal(t, tree3, tree7)
}

// Several entries can be removed by a single Delete call.
func TestImmutableSingleDimensionBulkDeletes(t *testing.T) {
	tree := newImmutableRangeTree(1)
	e1 := constructMockEntry(0, int64(0), int64(0))
	e2 := constructMockEntry(0, int64(1), int64(1))
	e3 := constructMockEntry(0, int64(2), int64(2))

	tree1 := tree.Add(e1, e2, e3)
	tree2 := tree1.Delete(e2, e3)

	iv := constructMockInterval(dimension{0, 10}, dimension{0, 10})

	result := tree2.Query(iv)
	assert.Equal(t, Entries{e1}, result)
	assert.Equal(t, uint64(1), tree2.Len())

	tree3 := tree2.Delete(e1)

	result = tree3.Query(iv)
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(0), tree3.Len())
}

// Two-dimensional counterpart of TestImmutableSingleDimensionDelete.
func TestImmutableMultiDimensionDelete(t *testing.T) {
	tree := newImmutableRangeTree(2)
	entry := constructMockEntry(0, int64(0), int64(0))
	tree2 := tree.Add(entry)
	tree3 := tree2.Delete(entry)

	iv := constructMockInterval(dimension{0, 10}, dimension{0, 10})

	result := tree3.Query(iv)
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(0), tree3.Len())
}

// Two-dimensional counterpart of TestImmutableSingleDimensionMultipleDeletes.
func TestImmutableMultiDimensionMultipleDeletes(t *testing.T) {
	tree := newImmutableRangeTree(2)
	e1 := constructMockEntry(0, int64(0), int64(0))
	e2 := constructMockEntry(0, int64(1), int64(1))
	e3 := constructMockEntry(0, int64(2), int64(2))

	tree1 := tree.Add(e1)
	tree2 := tree1.Add(e2)
	tree3 := tree2.Add(e3)

	iv := constructMockInterval(dimension{0, 10}, dimension{0, 10})

	tree4 := tree3.Delete(e3)
	result := tree4.Query(iv)
	assert.Equal(t, Entries{e1, e2}, result)
	assert.Equal(t, uint64(2), tree4.Len())

	tree5 := tree4.Delete(e2)
	result = tree5.Query(iv)
	assert.Equal(t, Entries{e1}, result)
	assert.Equal(t, uint64(1), tree5.Len())

	tree6 := tree5.Delete(e1)
	result = tree6.Query(iv)
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(0), tree6.Len())

	// the tree the deletes started from is untouched
	result = tree3.Query(iv)
	assert.Equal(t, Entries{e1, e2, e3}, result)
	assert.Equal(t, uint64(3), tree3.Len())

	// deleting an entry that is not stored changes nothing
	tree7 := tree3.Delete(constructMockEntry(0, int64(3), int64(3)))
	assert.Equal(t, tree3, tree7)
}

// Two-dimensional counterpart of TestImmutableSingleDimensionBulkDeletes.
func TestImmutableMultiDimensionBulkDeletes(t *testing.T) {
	tree := newImmutableRangeTree(2)
	e1 := constructMockEntry(0, int64(0), int64(0))
	e2 := constructMockEntry(0, int64(1), int64(1))
	e3 := constructMockEntry(0, int64(2), int64(2))

	tree1 := tree.Add(e1, e2, e3)
	tree2 := tree1.Delete(e2, e3)

	iv := constructMockInterval(dimension{0, 10}, dimension{0, 10})

	result := tree2.Query(iv)
	assert.Equal(t, Entries{e1}, result)
	assert.Equal(t, uint64(1), tree2.Len())

	tree3 := tree2.Delete(e1)

	result = tree3.Query(iv)
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(0), tree3.Len())
}

// constructMultiDimensionalImmutableTree builds a two-dimensional tree
// holding the entries (i, i) for i in [0, number) and returns it together
// with those entries in insertion order.
func constructMultiDimensionalImmutableTree(number int64) (*immutableRangeTree, Entries) {
	tree := newImmutableRangeTree(2)
	entries := make(Entries, 0, number)
	for i := int64(0); i < number; i++ {
		entries = append(entries, constructMockEntry(uint64(i), i, i))
	}

	return tree.Add(entries...), entries
}

// Shifting dimension 1 by +1 from index 1 moves the second entry in the
// new tree only.
func TestImmutableInsertPositiveIndexFirstDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalImmutableTree(2)

	tree1, modified, deleted := tree.InsertAtDimension(1, 1, 1)
	assert.Len(t, deleted, 0)
	assert.Equal(t, entries[1:], modified)

	result := tree1.Query(constructMockInterval(dimension{2, 10}, dimension{1, 10}))
	assert.Equal(t, entries[1:], result)

	result = tree.Query(constructMockInterval(dimension{2, 10}, dimension{0, 10}))
	assert.Len(t, result, 0)
}

// Shifting dimension 2 by +1 from index 1 moves entries in the new tree;
// the original tree still answers with its old positions.
func TestImmutableInsertPositiveIndexSecondDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalImmutableTree(3)

	tree1, modified, deleted := tree.InsertAtDimension(2, 1, 1)
	assert.Len(t, deleted, 0)
	assert.Equal(t, entries[1:], modified)

	result := tree1.Query(constructMockInterval(dimension{1, 10}, dimension{2, 10}))
	assert.Equal(t, entries[1:], result)

	result = tree.Query(constructMockInterval(dimension{1, 10}, dimension{2, 10}))
	assert.Equal(t, entries[2:], result)
}

// An index beyond every stored value in dimension 1 moves nothing.
func TestImmutableInsertPositiveIndexOutOfBoundsFirstDimension(t *testing.T) {
	tree, _ := constructMultiDimensionalImmutableTree(3)

	tree1, modified, deleted := tree.InsertAtDimension(1, 4, 1)
	assert.Len(t, modified, 0)
	assert.Len(t, deleted, 0)

	assert.Equal(t, tree, tree1)
}

// An index beyond every stored value in dimension 2 moves nothing.
func TestImmutableInsertPositiveIndexOutOfBoundsSecondDimension(t *testing.T) {
	tree, _ := constructMultiDimensionalImmutableTree(3)

	tree1, modified, deleted := tree.InsertAtDimension(2, 4, 1)
	assert.Len(t, modified, 0)
	assert.Len(t, deleted, 0)

	assert.Equal(t, tree, tree1)
}

// Shifting dimension 1 by +2 from index 1.
func TestImmutableInsertMultiplePositiveIndexFirstDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalImmutableTree(3)

	tree1, modified, deleted := tree.InsertAtDimension(1, 1, 2)
	assert.Len(t, deleted, 0)
	assert.Equal(t, entries[1:], modified)

	result := tree1.Query(constructMockInterval(dimension{3, 10}, dimension{1, 10}))
	assert.Equal(t, entries[1:], result)

	result = tree.Query(constructMockInterval(dimension{3, 10}, dimension{1, 10}))
	assert.Len(t, result, 0)
}

// Shifting dimension 2 by +2 from index 1.
func TestImmutableInsertMultiplePositiveIndexSecondDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalImmutableTree(3)

	tree1, modified, deleted := tree.InsertAtDimension(2, 1, 2)
	assert.Len(t, deleted, 0)
	assert.Equal(t, entries[1:], modified)

	result := tree1.Query(constructMockInterval(dimension{1, 10}, dimension{3, 10}))
	assert.Equal(t, entries[1:], result)

	result = tree.Query(constructMockInterval(dimension{1, 10}, dimension{3, 10}))
	assert.Len(t, result, 0)
}

// Shifting dimension 1 by -1 from index 1 deletes the entry that would
// fall below the index and moves the rest; the source tree is unchanged.
func TestImmutableInsertNegativeIndexFirstDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalImmutableTree(3)

	tree1, modified, deleted := tree.InsertAtDimension(1, 1, -1)
	assert.Equal(t, entries[1:2], deleted)
	assert.Equal(t, entries[2:], modified)

	result := tree1.Query(constructMockInterval(dimension{1, 10}, dimension{1, 10}))
	assert.Equal(t, entries[2:], result)

	result = tree1.Query(constructMockInterval(dimension{2, 10}, dimension{1, 10}))
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(2), tree1.Len())

	result = tree.Query(constructMockInterval(dimension{2, 10}, dimension{1, 10}))
	assert.Equal(t, entries[2:], result)
	assert.Equal(t, uint64(3), tree.Len())
}

// Second-dimension counterpart of TestImmutableInsertNegativeIndexFirstDimension.
func TestImmutableInsertNegativeIndexSecondDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalImmutableTree(3)

	tree1, modified, deleted := tree.InsertAtDimension(2, 1, -1)
	assert.Equal(t, entries[1:2], deleted)
	assert.Equal(t, entries[2:], modified)

	result := tree1.Query(constructMockInterval(dimension{1, 10}, dimension{1, 10}))
	assert.Equal(t, entries[2:], result)

	result = tree1.Query(constructMockInterval(dimension{1, 10}, dimension{2, 10}))
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(2), tree1.Len())

	result = tree.Query(constructMockInterval(dimension{1, 10}, dimension{2, 10}))
	assert.Equal(t, entries[2:], result)
	assert.Equal(t, uint64(3), tree.Len())
}

// A negative shift starting beyond every stored value in dimension 1
// changes nothing.
func TestImmutableInsertNegativeIndexOutOfBoundsFirstDimension(t *testing.T) {
	tree, _ := constructMultiDimensionalImmutableTree(3)

	tree1, modified, deleted := tree.InsertAtDimension(1, 4, -1)
	assert.Len(t, modified, 0)
	assert.Len(t, deleted, 0)

	assert.Equal(t, tree, tree1)
}

// A negative shift starting beyond every stored value in dimension 2
// changes nothing.
func TestImmutableInsertNegativeIndexOutOfBoundsSecondDimension(t *testing.T) {
	tree, _ := constructMultiDimensionalImmutableTree(3)

	tree1, modified, deleted := tree.InsertAtDimension(2, 4, -1)
	assert.Len(t, modified, 0)
	assert.Len(t, deleted, 0)

	assert.Equal(t, tree, tree1)
}

// Shifting dimension 1 by -2 from index 1 deletes both entries at or
// above the index.
func TestImmutableInsertMultipleNegativeIndexFirstDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalImmutableTree(3)

	tree1, modified, deleted := tree.InsertAtDimension(1, 1, -2)
	assert.Equal(t, entries[1:], deleted)
	assert.Len(t, modified, 0)

	result := tree1.Query(constructMockInterval(dimension{1, 10}, dimension{1, 10}))
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(1), tree1.Len())

	result = tree.Query(constructMockInterval(dimension{1, 10}, dimension{1, 10}))
	assert.Equal(t, entries[1:], result)
}

// Second-dimension counterpart of
// TestImmutableInsertMultipleNegativeIndexFirstDimension.
func TestImmutableInsertMultipleNegativeIndexSecondDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalImmutableTree(3)

	tree1, modified, deleted := tree.InsertAtDimension(2, 1, -2)
	assert.Equal(t, entries[1:], deleted)
	assert.Len(t, modified, 0)

	result := tree1.Query(constructMockInterval(dimension{1, 10}, dimension{1, 10}))
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(1), tree1.Len())

	result = tree.Query(constructMockInterval(dimension{1, 10}, dimension{1, 10}))
	assert.Equal(t, entries[1:], result)
}

// A dimension larger than the tree's dimension count is rejected.
func TestImmutableInsertInvalidDimension(t *testing.T) {
	tree, _ := constructMultiDimensionalImmutableTree(3)

	tree1, modified, deleted := tree.InsertAtDimension(3, 1, -1)
	assert.Len(t, modified, 0)
	assert.Len(t, deleted, 0)

	assert.Equal(t, tree, tree1)
}

// A shift of zero is a no-op.
func TestImmutableInsertInvalidNumber(t *testing.T) {
	tree, _ := constructMultiDimensionalImmutableTree(3)

	tree1, modified, deleted := tree.InsertAtDimension(1, 1, 0)
	assert.Len(t, modified, 0)
	assert.Len(t, deleted, 0)

	assert.Equal(t, tree, tree1)
}

// Get returns stored entries in request order and nil for a missing address.
func TestImmutableGet(t *testing.T) {
	tree, entries := constructMultiDimensionalImmutableTree(2)

	result := tree.Get(entries...)
	assert.Equal(t, entries, result)

	result = tree.Get(constructMockEntry(10000, 5000, 5000))
	assert.Equal(t, Entries{nil}, result)
}

// Measures shifting every entry of a 100000-entry tree along dimension 1.
func BenchmarkImmutableInsertFirstDimension(b *testing.B) {
	numItems := int64(100000)

	tree, _ := constructMultiDimensionalImmutableTree(numItems)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		tree.InsertAtDimension(1, 0, 1)
	}
}

// Measures shifting every entry of a 100000-entry tree along dimension 2.
func BenchmarkImmutableInsertSecondDimension(b *testing.B) {
	numItems := int64(100000)

	tree, _ := constructMultiDimensionalImmutableTree(numItems)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		tree.InsertAtDimension(2, 0, 1)
	}
}

================================================
FILE: rangetree/interface.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

/*
Package rangetree is designed to store n-dimensional data in an easy-to-query
way. Given this package's primary use as representing cartesian data, this
information is represented by int64s at n-dimensions. This implementation
is not actually a tree but a sparse n-dimensional list. This package also
includes two implementations of this sparse list, one mutable (and not threadsafe)
and another that is immutable copy-on-write which is threadsafe.
The mutable version is obviously faster but will likely have write contention
for any consumer that needs a threadsafe rangetree.

TODO: unify both implementations with the same interface.
*/
package rangetree

// Entry defines items that can be added to the rangetree.
type Entry interface {
	// ValueAtDimension returns the value of this entry
	// at the specified dimension.
	ValueAtDimension(dimension uint64) int64
}

// Interval describes the methods required to query the rangetree.  Note that
// all ranges are inclusive.
type Interval interface {
	// LowAtDimension returns an integer representing the lower bound
	// at the requested dimension.
	LowAtDimension(dimension uint64) int64
	// HighAtDimension returns an integer representing the higher bound
	// at the requested dimension.
	HighAtDimension(dimension uint64) int64
}

// RangeTree describes the methods available to the rangetree.
type RangeTree interface {
	// Add will add the provided entries to the tree.  Any entries that
	// were overwritten will be returned in the order in which they
	// were overwritten.  If an entry's addition does not overwrite, a nil
	// is returned for that entry's index in the provided cells.
	Add(entries ...Entry) Entries
	// Len returns the number of entries in the tree.
	Len() uint64
	// Delete will remove the provided entries from the tree.
	// Any entries that were deleted will be returned in the order in
	// which they were deleted.  If an entry does not exist to be deleted,
	// a nil is returned for that entry's index in the provided cells.
	Delete(entries ...Entry) Entries
	// Query will return a list of entries that fall within
	// the provided interval.  The values at dimensions are inclusive.
	Query(interval Interval) Entries
	// Apply will call the provided function with each entry that exists
	// within the provided range, in order.  Return false at any time to
	// cancel iteration.  Altering the entry in such a way that its location
	// changes will result in undefined behavior.
	Apply(interval Interval, fn func(Entry) bool)
	// Get returns any entries that exist at the addresses provided by the
	// given entries.  Entries are returned in the order in which they are
	// received.  If an entry cannot be found, a nil is returned in its
	// place.
	Get(entries ...Entry) Entries
	// InsertAtDimension will increment items at and above the given index
	// by the number provided.  Provide a negative number to decrement.
	// Returned are two lists.  The first list is a list of entries that
	// were moved.  The second is a list of entries that were deleted.  These
	// lists are exclusive.
	InsertAtDimension(dimension uint64, index, number int64) (Entries, Entries)
}

================================================
FILE: rangetree/mock_test.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ package rangetree type mockEntry struct { id uint64 dimensions []int64 } func (me *mockEntry) ID() uint64 { return me.id } func (me *mockEntry) ValueAtDimension(dimension uint64) int64 { return me.dimensions[dimension-1] } func constructMockEntry(id uint64, values ...int64) *mockEntry { return &mockEntry{ id: id, dimensions: values, } } type dimension struct { low, high int64 } type mockInterval struct { dimensions []dimension } func (mi *mockInterval) LowAtDimension(dimension uint64) int64 { return mi.dimensions[dimension-1].low } func (mi *mockInterval) HighAtDimension(dimension uint64) int64 { return mi.dimensions[dimension-1].high } func constructMockInterval(dimensions ...dimension) *mockInterval { return &mockInterval{ dimensions: dimensions, } } ================================================ FILE: rangetree/node.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package rangetree type nodes []*node type node struct { value int64 entry Entry orderedNodes orderedNodes } func newNode(value int64, entry Entry, needNextDimension bool) *node { n := &node{} n.value = value if needNextDimension { n.orderedNodes = make(orderedNodes, 0, 10) } else { n.entry = entry } return n } ================================================ FILE: rangetree/ordered.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package rangetree import "sort" // orderedNodes represents an ordered list of points living // at the last dimension. No duplicates can be inserted here. type orderedNodes nodes func (nodes orderedNodes) search(value int64) int { return sort.Search( len(nodes), func(i int) bool { return nodes[i].value >= value }, ) } // addAt will add the provided node at the provided index. Returns // a node if one was overwritten. 
func (nodes *orderedNodes) addAt(i int, node *node) *node { if i == len(*nodes) { *nodes = append(*nodes, node) return nil } if (*nodes)[i].value == node.value { overwritten := (*nodes)[i] // this is a duplicate, there can't be a duplicate // point in the last dimension (*nodes)[i] = node return overwritten } *nodes = append(*nodes, nil) copy((*nodes)[i+1:], (*nodes)[i:]) (*nodes)[i] = node return nil } func (nodes *orderedNodes) add(node *node) *node { i := nodes.search(node.value) return nodes.addAt(i, node) } func (nodes *orderedNodes) deleteAt(i int) *node { if i >= len(*nodes) { // no matching found return nil } deleted := (*nodes)[i] copy((*nodes)[i:], (*nodes)[i+1:]) (*nodes)[len(*nodes)-1] = nil *nodes = (*nodes)[:len(*nodes)-1] return deleted } func (nodes *orderedNodes) delete(value int64) *node { i := nodes.search(value) if (*nodes)[i].value != value || i == len(*nodes) { return nil } return nodes.deleteAt(i) } func (nodes orderedNodes) apply(low, high int64, fn func(*node) bool) bool { index := nodes.search(low) if index == len(nodes) { return true } for ; index < len(nodes); index++ { if nodes[index].value > high { break } if !fn(nodes[index]) { return false } } return true } func (nodes orderedNodes) get(value int64) (*node, int) { i := nodes.search(value) if i == len(nodes) { return nil, i } if nodes[i].value == value { return nodes[i], i } return nil, i } func (nodes *orderedNodes) getOrAdd(entry Entry, dimension, lastDimension uint64) (*node, bool) { isLastDimension := isLastDimension(lastDimension, dimension) value := entry.ValueAtDimension(dimension) i := nodes.search(value) if i == len(*nodes) { node := newNode(value, entry, !isLastDimension) *nodes = append(*nodes, node) return node, true } if (*nodes)[i].value == value { return (*nodes)[i], false } node := newNode(value, entry, !isLastDimension) *nodes = append(*nodes, nil) copy((*nodes)[i+1:], (*nodes)[i:]) (*nodes)[i] = node return node, true } func (nodes orderedNodes) flatten(entries 
*Entries) { for _, node := range nodes { if node.orderedNodes != nil { node.orderedNodes.flatten(entries) } else { *entries = append(*entries, node.entry) } } } func (nodes *orderedNodes) insert(insertDimension, dimension, maxDimension uint64, index, number int64, modified, deleted *Entries) { lastDimension := isLastDimension(maxDimension, dimension) if insertDimension == dimension { i := nodes.search(index) var toDelete []int for j := i; j < len(*nodes); j++ { (*nodes)[j].value += number if (*nodes)[j].value < index { toDelete = append(toDelete, j) if lastDimension { *deleted = append(*deleted, (*nodes)[j].entry) } else { (*nodes)[j].orderedNodes.flatten(deleted) } continue } if lastDimension { *modified = append(*modified, (*nodes)[j].entry) } else { (*nodes)[j].orderedNodes.flatten(modified) } } for i, index := range toDelete { nodes.deleteAt(index - i) } return } for _, node := range *nodes { node.orderedNodes.insert( insertDimension, dimension+1, maxDimension, index, number, modified, deleted, ) } } func (nodes orderedNodes) immutableInsert(insertDimension, dimension, maxDimension uint64, index, number int64, modified, deleted *Entries) orderedNodes { lastDimension := isLastDimension(maxDimension, dimension) cp := make(orderedNodes, len(nodes)) copy(cp, nodes) if insertDimension == dimension { i := cp.search(index) var toDelete []int for j := i; j < len(cp); j++ { nn := newNode(cp[j].value+number, cp[j].entry, !lastDimension) nn.orderedNodes = cp[j].orderedNodes cp[j] = nn if cp[j].value < index { toDelete = append(toDelete, j) if lastDimension { *deleted = append(*deleted, cp[j].entry) } else { cp[j].orderedNodes.flatten(deleted) } continue } if lastDimension { *modified = append(*modified, cp[j].entry) } else { cp[j].orderedNodes.flatten(modified) } } for _, index := range toDelete { cp.deleteAt(index) } return cp } for i := 0; i < len(cp); i++ { oldNode := nodes[i] nn := newNode(oldNode.value, oldNode.entry, !lastDimension) nn.orderedNodes = 
oldNode.orderedNodes.immutableInsert( insertDimension, dimension+1, maxDimension, index, number, modified, deleted, ) cp[i] = nn } return cp } ================================================ FILE: rangetree/ordered_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package rangetree import ( "testing" "github.com/stretchr/testify/assert" ) func TestOrderedAdd(t *testing.T) { nodes := make(orderedNodes, 0) n1 := newNode(4, constructMockEntry(1, 4), false) n2 := newNode(1, constructMockEntry(2, 1), false) overwritten := nodes.add(n1) assert.Nil(t, overwritten) overwritten = nodes.add(n2) assert.Nil(t, overwritten) assert.Equal(t, orderedNodes{n2, n1}, nodes) n3 := newNode(4, constructMockEntry(1, 4), false) overwritten = nodes.add(n3) assert.True(t, n1 == overwritten) assert.Equal(t, orderedNodes{n2, n3}, nodes) } func TestOrderedDelete(t *testing.T) { nodes := make(orderedNodes, 0) n1 := newNode(4, constructMockEntry(1, 4), false) n2 := newNode(1, constructMockEntry(2, 1), false) nodes.add(n1) nodes.add(n2) deleted := nodes.delete(n2.value) assert.Equal(t, orderedNodes{n1}, nodes) assert.Equal(t, n2, deleted) missingValue := int64(3) deleted = nodes.delete(missingValue) assert.Equal(t, orderedNodes{n1}, nodes) assert.Nil(t, deleted) deleted = nodes.delete(n1.value) assert.Empty(t, nodes) assert.Equal(t, n1, deleted) } func TestApply(t *testing.T) { ns := make(orderedNodes, 0) n1 := newNode(4, constructMockEntry(1, 4), 
false) n2 := newNode(1, constructMockEntry(2, 1), false) ns.add(n1) ns.add(n2) results := make(nodes, 0, 2) ns.apply(1, 1, func(n *node) bool { results = append(results, n) return true }) assert.Equal(t, nodes{n2}, results) results = results[:0] ns.apply(0, 0, func(n *node) bool { results = append(results, n) return true }) assert.Len(t, results, 0) results = results[:0] ns.apply(2, 3, func(n *node) bool { results = append(results, n) return true }) assert.Len(t, results, 0) results = results[:0] ns.apply(4, 5, func(n *node) bool { results = append(results, n) return true }) assert.Equal(t, nodes{n1}, results) results = results[:0] ns.apply(0, 5, func(n *node) bool { results = append(results, n) return true }) assert.Equal(t, nodes{n2, n1}, results) results = results[:0] ns.apply(5, 10, func(n *node) bool { results = append(results, n) return true }) assert.Len(t, results, 0) results = results[:0] ns.apply(0, 100, func(n *node) bool { results = append(results, n) return false }) assert.Equal(t, nodes{n2}, results) } func TestInsertDelete(t *testing.T) { ns := make(orderedNodes, 0) n1 := newNode(4, constructMockEntry(1, 4), false) n2 := newNode(1, constructMockEntry(2, 1), false) n3 := newNode(2, constructMockEntry(3, 2), false) ns.add(n1) ns.add(n2) ns.add(n3) modified := make(Entries, 0, 1) deleted := make(Entries, 0, 1) ns.insert(2, 2, 2, 0, -5, &modified, &deleted) assert.Len(t, ns, 0) assert.Equal(t, Entries{n2.entry, n3.entry, n1.entry}, deleted) } func BenchmarkPrepend(b *testing.B) { numItems := 100000 ns := make(orderedNodes, 0, numItems) for i := b.N; i < b.N+numItems; i++ { ns.add(newNode(int64(i), constructMockEntry(uint64(i), int64(i)), false)) } b.ResetTimer() for i := 0; i < b.N; i++ { ns.add(newNode(int64(i), constructMockEntry(uint64(i), int64(i)), false)) } } ================================================ FILE: rangetree/orderedtree.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache 
License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package rangetree func isLastDimension(value, test uint64) bool { return test >= value } type nodeBundle struct { list *orderedNodes index int } type orderedTree struct { top orderedNodes number uint64 dimensions uint64 path []*nodeBundle } func (ot *orderedTree) resetPath() { ot.path = ot.path[:0] } func (ot *orderedTree) needNextDimension() bool { return ot.dimensions > 1 } // add will add the provided entry to the rangetree and return an // entry if one was overwritten. func (ot *orderedTree) add(entry Entry) *node { var node *node list := &ot.top for i := uint64(1); i <= ot.dimensions; i++ { if isLastDimension(ot.dimensions, i) { overwritten := list.add( newNode(entry.ValueAtDimension(i), entry, false), ) if overwritten == nil { ot.number++ } return overwritten } node, _ = list.getOrAdd(entry, i, ot.dimensions) list = &node.orderedNodes } return nil } // Add will add the provided entries to the tree. This method // returns a list of entries that were overwritten in the order // in which entries were received. If an entry doesn't overwrite // anything, a nil will be returned for that entry in the returned // slice. 
func (ot *orderedTree) Add(entries ...Entry) Entries { if len(entries) == 0 { return nil } overwrittens := make(Entries, len(entries)) for i, entry := range entries { if entry == nil { continue } overwritten := ot.add(entry) if overwritten != nil { overwrittens[i] = overwritten.entry } } return overwrittens } func (ot *orderedTree) delete(entry Entry) *node { ot.resetPath() var index int var node *node list := &ot.top for i := uint64(1); i <= ot.dimensions; i++ { value := entry.ValueAtDimension(i) node, index = list.get(value) if node == nil { // there's nothing to delete return nil } nb := &nodeBundle{list: list, index: index} ot.path = append(ot.path, nb) list = &node.orderedNodes } ot.number-- for i := len(ot.path) - 1; i >= 0; i-- { nb := ot.path[i] nb.list.deleteAt(nb.index) if len(*nb.list) > 0 { break } } return node } func (ot *orderedTree) get(entry Entry) Entry { on := ot.top for i := uint64(1); i <= ot.dimensions; i++ { n, _ := on.get(entry.ValueAtDimension(i)) if n == nil { return nil } if i == ot.dimensions { return n.entry } on = n.orderedNodes } return nil } // Get returns any entries that exist at the addresses provided by the // given entries. Entries are returned in the order in which they are // received. If an entry cannot be found, a nil is returned in its // place. func (ot *orderedTree) Get(entries ...Entry) Entries { result := make(Entries, 0, len(entries)) for _, entry := range entries { result = append(result, ot.get(entry)) } return result } // Delete will remove the provided entries from the tree. // Any entries that were deleted will be returned in the order in // which they were deleted. If an entry does not exist to be deleted, // a nil is returned for that entry's index in the provided cells. 
func (ot *orderedTree) Delete(entries ...Entry) Entries { if len(entries) == 0 { return nil } deletedEntries := make(Entries, len(entries)) for i, entry := range entries { if entry == nil { continue } deleted := ot.delete(entry) if deleted != nil { deletedEntries[i] = deleted.entry } } return deletedEntries } // Len returns the number of items in the tree. func (ot *orderedTree) Len() uint64 { return ot.number } func (ot *orderedTree) apply(list orderedNodes, interval Interval, dimension uint64, fn func(*node) bool) bool { low, high := interval.LowAtDimension(dimension), interval.HighAtDimension(dimension) if isLastDimension(ot.dimensions, dimension) { if !list.apply(low, high, fn) { return false } } else { if !list.apply(low, high, func(n *node) bool { if !ot.apply(n.orderedNodes, interval, dimension+1, fn) { return false } return true }) { return false } return true } return true } // Apply will call (in order) the provided function to every // entry that falls within the provided interval. Any alteration // the the entry that would result in different answers to the // interface methods results in undefined behavior. func (ot *orderedTree) Apply(interval Interval, fn func(Entry) bool) { ot.apply(ot.top, interval, 1, func(n *node) bool { return fn(n.entry) }) } // Query will return an ordered list of results in the given // interval. func (ot *orderedTree) Query(interval Interval) Entries { entries := NewEntries() ot.apply(ot.top, interval, 1, func(n *node) bool { entries = append(entries, n.entry) return true }) return entries } // InsertAtDimension will increment items at and above the given index // by the number provided. Provide a negative number to to decrement. // Returned are two lists. The first list is a list of entries that // were moved. The second is a list entries that were deleted. These // lists are exclusive. func (ot *orderedTree) InsertAtDimension(dimension uint64, index, number int64) (Entries, Entries) { // TODO: perhaps return an error here? 
if dimension > ot.dimensions || number == 0 { return nil, nil } modified := make(Entries, 0, 100) deleted := make(Entries, 0, 100) ot.top.insert(dimension, 1, ot.dimensions, index, number, &modified, &deleted, ) ot.number -= uint64(len(deleted)) return modified, deleted } func newOrderedTree(dimensions uint64) *orderedTree { return &orderedTree{ dimensions: dimensions, path: make([]*nodeBundle, 0, dimensions), } } // New is the constructor to create a new rangetree with // the provided number of dimensions. func New(dimensions uint64) RangeTree { return newOrderedTree(dimensions) } ================================================ FILE: rangetree/orderedtree_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package rangetree import ( "math/rand" "testing" "github.com/stretchr/testify/assert" ) func constructMultiDimensionalOrderedTree(number uint64) ( *orderedTree, Entries) { tree := newOrderedTree(2) entries := make(Entries, 0, number) for i := uint64(0); i < number; i++ { entries = append(entries, constructMockEntry(i, int64(i), int64(i))) } tree.Add(entries...) 
return tree, entries
}

// TestOTRootAddMultipleDimensions adds a single entry to a two-dimensional
// tree and checks both the reported length and that a query built from
// dimension{0, 0} in both dimensions returns that entry.
func TestOTRootAddMultipleDimensions(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(1)

	assert.Equal(t, uint64(1), tree.Len())

	result := tree.Query(constructMockInterval(dimension{0, 0}, dimension{0, 0}))
	assert.Equal(t, Entries{entries[0]}, result)
}

// TestOTMultipleAddMultipleDimensions loads four diagonal entries
// (both coordinates equal to the index) and queries a series of intervals,
// checking the exact entries returned for each, including intervals that
// match in only one of the two dimensions and must therefore be empty.
func TestOTMultipleAddMultipleDimensions(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(4)

	assert.Equal(t, uint64(4), tree.Len())

	result := tree.Query(constructMockInterval(dimension{0, 0}, dimension{0, 0}))
	assert.Equal(t, Entries{entries[0]}, result)

	result = tree.Query(constructMockInterval(dimension{3, 4}, dimension{3, 4}))
	assert.Equal(t, Entries{entries[3]}, result)

	result = tree.Query(constructMockInterval(dimension{0, 4}, dimension{0, 4}))
	assert.Equal(t, entries, result)

	result = tree.Query(constructMockInterval(dimension{1, 2}, dimension{1, 2}))
	assert.Equal(t, Entries{entries[1], entries[2]}, result)

	result = tree.Query(constructMockInterval(dimension{0, 2}, dimension{10, 20}))
	assert.Len(t, result, 0)

	result = tree.Query(constructMockInterval(dimension{10, 20}, dimension{0, 2}))
	assert.Len(t, result, 0)

	result = tree.Query(constructMockInterval(dimension{0, 1}, dimension{0, 0}))
	assert.Equal(t, Entries{entries[0]}, result)

	result = tree.Query(constructMockInterval(dimension{0, 0}, dimension{0, 1}))
	assert.Equal(t, Entries{entries[0]}, result)
}

// TestOTAddInOrderMultiDimensions adds ten entries in ascending order and
// expects a covering query to return all of them in insertion order.
func TestOTAddInOrderMultiDimensions(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(10)

	result := tree.Query(constructMockInterval(dimension{0, 10}, dimension{0, 10}))
	assert.Equal(t, uint64(10), tree.Len())
	assert.Len(t, result, 10)
	assert.Equal(t, entries, result)
}

// TestOTAddReverseOrderMultiDimensions adds the entries 10 down to 1 one at
// a time and checks that all ten are counted and found by a covering query.
func TestOTAddReverseOrderMultiDimensions(t *testing.T) {
	tree := newOrderedTree(2)

	for i := uint64(10); i > 0; i-- {
		tree.Add(constructMockEntry(i, int64(i), int64(i)))
	}

	result := tree.Query(constructMockInterval(dimension{0, 11}, dimension{0, 11}))
	assert.Len(t, result, 10)
	assert.Equal(t, uint64(10),
tree.Len())
}

// TestOTAddRandomOrderMultiDimensions adds five entries in a shuffled
// order and checks that all five are counted and returned by a covering
// query.
func TestOTAddRandomOrderMultiDimensions(t *testing.T) {
	tree := newOrderedTree(2)

	starts := []uint64{0, 4, 2, 1, 3}
	for _, start := range starts {
		tree.Add(constructMockEntry(start, int64(start), int64(start)))
	}

	result := tree.Query(constructMockInterval(dimension{0, 5}, dimension{0, 5}))
	assert.Len(t, result, 5)
	assert.Equal(t, uint64(5), tree.Len())
}

// TestOTAddLargeNumbersMultiDimension adds 1000 diagonal entries one call
// at a time and checks the length and the size of a covering query.
func TestOTAddLargeNumbersMultiDimension(t *testing.T) {
	numItems := uint64(1000)
	tree := newOrderedTree(2)

	for i := uint64(0); i < numItems; i++ {
		tree.Add(constructMockEntry(i, int64(i), int64(i)))
	}

	result := tree.Query(
		constructMockInterval(
			dimension{0, int64(numItems)},
			dimension{0, int64(numItems)},
		),
	)
	assert.Equal(t, numItems, tree.Len())
	assert.Len(t, result, int(numItems))
}

// TestOTAddReturnsOverwritten checks the return value of Add: nil slots for
// fresh entries, and the previous entry when a new entry is added at
// coordinates that are already occupied. The entry count must not change on
// an overwrite.
func TestOTAddReturnsOverwritten(t *testing.T) {
	tree := newOrderedTree(2)

	starts := []uint64{0, 4, 2, 1, 3}
	entries := make(Entries, 0, len(starts))
	for _, start := range starts {
		entries = append(entries, constructMockEntry(start, int64(start), int64(start)))
	}

	overwritten := tree.Add(entries...)
assert.Equal(t, Entries{nil, nil, nil, nil, nil}, overwritten)

	// Re-add at the coordinates of entries[2] under a different id.
	oldEntry := entries[2]
	newEntry := constructMockEntry(10, oldEntry.ValueAtDimension(1), oldEntry.ValueAtDimension(2))
	overwritten = tree.Add(newEntry)
	assert.Equal(t, Entries{oldEntry}, overwritten)

	result := tree.Query(constructMockInterval(dimension{0, 5}, dimension{0, 5}))
	assert.Len(t, result, 5)
	assert.Equal(t, uint64(5), tree.Len())
}

// BenchmarkOTAddItemsMultiDimensions times single-entry Add calls on a
// two-dimensional tree. The b.N entries are generated up front, each using
// one rand.Int63 value for both coordinates.
func BenchmarkOTAddItemsMultiDimensions(b *testing.B) {
	numItems := b.N
	entries := make(Entries, 0, numItems)

	for i := uint64(0); i < uint64(numItems); i++ {
		value := rand.Int63()
		entries = append(entries, constructMockEntry(i, value, value))
	}

	rt := newOrderedTree(2)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		rt.Add(entries[i%numItems])
	}
}

// BenchmarkOTQueryItemsMultiDimensions times a query whose interval spans
// 0..numItems in both dimensions against a tree preloaded with 1000
// diagonal entries.
func BenchmarkOTQueryItemsMultiDimensions(b *testing.B) {
	numItems := uint64(1000)
	entries := make(Entries, 0, numItems)

	for i := uint64(0); i < numItems; i++ {
		entries = append(entries, constructMockEntry(i, int64(i), int64(i)))
	}

	tree := newOrderedTree(2)
	tree.Add(entries...)
	iv := constructMockInterval(
		dimension{0, int64(numItems)},
		dimension{0, int64(numItems)},
	)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		tree.Query(iv)
	}
}

// TestOTRootDeleteMultiDimensions deletes the only entry of a tree and
// expects the tree to be empty afterwards.
func TestOTRootDeleteMultiDimensions(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(1)
	tree.Delete(entries...)
assert.Equal(t, uint64(0), tree.Len())

	result := tree.Query(constructMockInterval(dimension{0, 100}, dimension{0, 100}))
	assert.Len(t, result, 0)
}

// TestOTDeleteMultiDimensions removes one of four entries and checks the
// length plus the results of several queries over the remaining entries.
func TestOTDeleteMultiDimensions(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(4)

	tree.Delete(entries[2])

	assert.Equal(t, uint64(3), tree.Len())

	result := tree.Query(constructMockInterval(dimension{0, 4}, dimension{0, 4}))
	assert.Equal(t, Entries{entries[0], entries[1], entries[3]}, result)

	result = tree.Query(constructMockInterval(dimension{3, 4}, dimension{3, 4}))
	assert.Equal(t, Entries{entries[3]}, result)

	result = tree.Query(constructMockInterval(dimension{0, 2}, dimension{0, 2}))
	assert.Equal(t, Entries{entries[0], entries[1]}, result)
}

// TestOTDeleteInOrderMultiDimensions deletes one entry from a tree that was
// filled in ascending order and checks it no longer shows up in a covering
// query.
func TestOTDeleteInOrderMultiDimensions(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(10)

	tree.Delete(entries[5])

	result := tree.Query(constructMockInterval(dimension{0, 10}, dimension{0, 10}))
	assert.Len(t, result, 9)
	assert.Equal(t, uint64(9), tree.Len())

	assert.NotContains(t, result, entries[5])
}

// TestOTDeleteReverseOrderMultiDimensions is the same check for a tree
// whose entries were supplied in descending order (10 down to 1).
func TestOTDeleteReverseOrderMultiDimensions(t *testing.T) {
	tree := newOrderedTree(2)

	entries := NewEntries()
	for i := uint64(10); i > 0; i-- {
		entries = append(entries, constructMockEntry(i, int64(i), int64(i)))
	}

	tree.Add(entries...)

	tree.Delete(entries[5])

	result := tree.Query(constructMockInterval(dimension{0, 11}, dimension{0, 11}))
	assert.Len(t, result, 9)
	assert.Equal(t, uint64(9), tree.Len())

	assert.NotContains(t, result, entries[5])
}

// TestOTDeleteRandomOrderMultiDimensions is the same check for entries
// supplied in a shuffled order.
func TestOTDeleteRandomOrderMultiDimensions(t *testing.T) {
	tree := newOrderedTree(2)

	entries := NewEntries()
	starts := []uint64{0, 4, 2, 1, 3}
	for _, start := range starts {
		entries = append(entries, constructMockEntry(start, int64(start), int64(start)))
	}

	tree.Add(entries...)
tree.Delete(entries[2])

	result := tree.Query(constructMockInterval(dimension{0, 11}, dimension{0, 11}))
	assert.Len(t, result, 4)
	assert.Equal(t, uint64(4), tree.Len())

	assert.NotContains(t, result, entries[2])
}

// TestOTDeleteEmptyTreeMultiDimensions deletes from an empty tree and
// expects the length to remain zero.
func TestOTDeleteEmptyTreeMultiDimensions(t *testing.T) {
	tree := newOrderedTree(2)

	tree.Delete(constructMockEntry(0, 0, 0))

	assert.Equal(t, uint64(0), tree.Len())
}

// TestOTDeleteReturnsDeleted checks the return value of Delete: the removed
// entry in the slot of an entry that existed and nil in the slot of one
// that did not.
func TestOTDeleteReturnsDeleted(t *testing.T) {
	tree := newOrderedTree(2)

	entries := NewEntries()
	starts := []uint64{0, 4, 2, 1, 3}
	for _, start := range starts {
		entries = append(entries, constructMockEntry(start, int64(start), int64(start)))
	}

	tree.Add(entries...)

	deleted := tree.Delete(entries[2], constructMockEntry(10, 10, 10))

	assert.Equal(t, Entries{entries[2], nil}, deleted)

	result := tree.Query(constructMockInterval(dimension{0, 11}, dimension{0, 11}))
	assert.Len(t, result, 4)
	assert.Equal(t, uint64(4), tree.Len())

	assert.NotContains(t, result, entries[2])
}

// BenchmarkOTDeleteItemsMultiDimensions times deleting all 1000 entries
// from a tree. One fully populated tree per iteration is built before the
// timer is reset so that only Delete is measured.
func BenchmarkOTDeleteItemsMultiDimensions(b *testing.B) {
	numItems := uint64(1000)
	entries := make(Entries, 0, numItems)

	for i := uint64(0); i < numItems; i++ {
		entries = append(entries, constructMockEntry(i, int64(i), int64(i)))
	}

	trees := make([]*orderedTree, 0, b.N)
	for i := 0; i < b.N; i++ {
		tree := newOrderedTree(2)
		tree.Add(entries...)
		trees = append(trees, tree)
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		trees[i].Delete(entries...)
}
}

// TestOverwrites adds an entry at (10, 10), then adds a second entry at the
// same coordinates and expects Add to hand back the first one.
func TestOverwrites(t *testing.T) {
	tree, _ := constructMultiDimensionalOrderedTree(1)

	entry := constructMockEntry(10, 10, 10)
	overwritten := tree.Add(entry)

	assert.Equal(t, Entries{nil}, overwritten)

	results := tree.Query(constructMockInterval(dimension{10, 11}, dimension{10, 11}))
	assert.Equal(t, Entries{entry}, results)
	assert.Equal(t, uint64(2), tree.Len())

	newEntry := constructMockEntry(10, 10, 10)
	overwritten = tree.Add(newEntry)
	assert.Equal(t, Entries{entry}, overwritten)
}

// TestGet checks that Get returns stored entries in argument order and a
// nil slot for coordinates that hold nothing.
func TestGet(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(2)

	result := tree.Get(entries...)
	assert.Equal(t, entries, result)

	result = tree.Get(constructMockEntry(10000, 5000, 5000))
	assert.Equal(t, Entries{nil}, result)
}

// TestTreeApply checks that Apply visits every entry inside the interval,
// in order, while the callback keeps returning true.
func TestTreeApply(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(2)

	result := make(Entries, 0, len(entries))

	tree.Apply(constructMockInterval(dimension{0, 100}, dimension{0, 100}),
		func(e Entry) bool {
			result = append(result, e)
			return true
		},
	)

	assert.Equal(t, entries, result)
}

// TestApplyWithBail checks that Apply stops after the first entry when the
// callback returns false.
func TestApplyWithBail(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(2)

	result := make(Entries, 0, 1)

	tree.Apply(constructMockInterval(dimension{0, 100}, dimension{0, 100}),
		func(e Entry) bool {
			result = append(result, e)
			return false
		},
	)

	assert.Equal(t, entries[:1], result)
}

// BenchmarkApply times Apply with a no-op callback over an interval
// spanning 0..numItems in both dimensions of a 1000-entry tree.
func BenchmarkApply(b *testing.B) {
	numItems := 1000

	tree, _ := constructMultiDimensionalOrderedTree(uint64(numItems))

	iv := constructMockInterval(
		dimension{0, int64(numItems)}, dimension{0, int64(numItems)},
	)
	fn := func(Entry) bool { return true }

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		tree.Apply(iv, fn)
	}
}

// TestInsertPositiveIndexFirstDimension calls InsertAtDimension(1, 1, 1)
// on a two-entry tree and expects no deletions, entries[1:] reported as
// modified, and entries[1:] returned by a query whose first-dimension
// bounds start at 2.
func TestInsertPositiveIndexFirstDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(2)

	modified, deleted := tree.InsertAtDimension(1, 1, 1)
	assert.Len(t, deleted, 0)
	assert.Equal(t, entries[1:], modified)

	result := tree.Query(constructMockInterval(dimension{2, 10}, dimension{1, 10}))
	assert.Equal(t,
entries[1:], result)
}

// TestInsertPositiveIndexSecondDimension is the second-dimension variant:
// InsertAtDimension(2, 1, 1) on three entries reports entries[1:] as
// modified and they are found with second-dimension bounds starting at 2.
func TestInsertPositiveIndexSecondDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(3)

	modified, deleted := tree.InsertAtDimension(2, 1, 1)
	assert.Len(t, deleted, 0)
	assert.Equal(t, entries[1:], modified)

	result := tree.Query(constructMockInterval(dimension{1, 10}, dimension{2, 10}))
	assert.Equal(t, entries[1:], result)
}

// TestInsertPositiveIndexOutOfBoundsFirstDimension inserts at index 4,
// beyond all three stored values, and expects nothing to be modified or
// deleted.
func TestInsertPositiveIndexOutOfBoundsFirstDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(3)

	modified, deleted := tree.InsertAtDimension(1, 4, 1)
	assert.Len(t, modified, 0)
	assert.Len(t, deleted, 0)

	result := tree.Query(constructMockInterval(dimension{0, 10}, dimension{0, 10}))
	assert.Equal(t, entries, result)
}

// TestInsertPositiveIndexOutOfBoundsSecondDimension is the same check for
// the second dimension.
func TestInsertPositiveIndexOutOfBoundsSecondDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(3)

	modified, deleted := tree.InsertAtDimension(2, 4, 1)
	assert.Len(t, modified, 0)
	assert.Len(t, deleted, 0)

	result := tree.Query(constructMockInterval(dimension{0, 10}, dimension{0, 10}))
	assert.Equal(t, entries, result)
}

// TestInsertMultiplePositiveIndexFirstDimension inserts with number 2 in
// the first dimension and expects entries[1:] to be found with
// first-dimension bounds starting at 3.
func TestInsertMultiplePositiveIndexFirstDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(3)

	modified, deleted := tree.InsertAtDimension(1, 1, 2)
	assert.Len(t, deleted, 0)
	assert.Equal(t, entries[1:], modified)

	result := tree.Query(constructMockInterval(dimension{3, 10}, dimension{1, 10}))
	assert.Equal(t, entries[1:], result)
}

// TestInsertMultiplePositiveIndexSecondDimension is the same check for the
// second dimension.
func TestInsertMultiplePositiveIndexSecondDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(3)

	modified, deleted := tree.InsertAtDimension(2, 1, 2)
	assert.Len(t, deleted, 0)
	assert.Equal(t, entries[1:], modified)

	result := tree.Query(constructMockInterval(dimension{1, 10}, dimension{3, 10}))
	assert.Equal(t, entries[1:], result)
}

// TestInsertNegativeIndexFirstDimension calls InsertAtDimension(1, 1, -1)
// and expects entries[1:2] to be deleted, entries[2:] to be modified, and
// the length to drop to 2.
func TestInsertNegativeIndexFirstDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(3)

	modified, deleted := tree.InsertAtDimension(1, 1, -1)
	assert.Equal(t, entries[1:2],
deleted)
	assert.Equal(t, entries[2:], modified)

	result := tree.Query(constructMockInterval(dimension{1, 10}, dimension{1, 10}))
	assert.Equal(t, entries[2:], result)

	result = tree.Query(constructMockInterval(dimension{2, 10}, dimension{1, 10}))
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(2), tree.Len())
}

// TestInsertNegativeIndexSecondDimension is the second-dimension variant
// of the negative insert: entries[1:2] is deleted, entries[2:] is modified
// and the length drops to 2.
func TestInsertNegativeIndexSecondDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(3)

	modified, deleted := tree.InsertAtDimension(2, 1, -1)
	assert.Equal(t, entries[1:2], deleted)
	assert.Equal(t, entries[2:], modified)

	result := tree.Query(constructMockInterval(dimension{1, 10}, dimension{1, 10}))
	assert.Equal(t, entries[2:], result)

	result = tree.Query(constructMockInterval(dimension{1, 10}, dimension{2, 10}))
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(2), tree.Len())
}

// TestInsertNegativeIndexOutOfBoundsFirstDimension applies a negative
// insert at index 4, beyond all three stored values, and expects the tree
// to be left untouched.
func TestInsertNegativeIndexOutOfBoundsFirstDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(3)

	modified, deleted := tree.InsertAtDimension(1, 4, -1)
	assert.Len(t, modified, 0)
	assert.Len(t, deleted, 0)

	result := tree.Query(constructMockInterval(dimension{0, 10}, dimension{0, 10}))
	assert.Equal(t, entries, result)
	assert.Equal(t, uint64(3), tree.Len())
}

// TestInsertNegativeIndexOutOfBoundsSecondDimension is the same check for
// the second dimension.
func TestInsertNegativeIndexOutOfBoundsSecondDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(3)

	modified, deleted := tree.InsertAtDimension(2, 4, -1)
	assert.Len(t, modified, 0)
	assert.Len(t, deleted, 0)

	result := tree.Query(constructMockInterval(dimension{0, 10}, dimension{0, 10}))
	assert.Equal(t, entries, result)
	assert.Equal(t, uint64(3), tree.Len())
}

// TestInsertMultipleNegativeIndexFirstDimension calls
// InsertAtDimension(1, 1, -2) and expects entries[1:] to be deleted,
// nothing to be modified, and a single entry to remain.
func TestInsertMultipleNegativeIndexFirstDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(3)

	modified, deleted := tree.InsertAtDimension(1, 1, -2)
	assert.Equal(t, entries[1:], deleted)
	assert.Len(t, modified, 0)

	result := tree.Query(constructMockInterval(dimension{1, 10}, dimension{1, 10}))
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(1), tree.Len())
}

func
TestInsertMultipleNegativeIndexSecondDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(3)

	// Same expectations as the first-dimension variant, applied to
	// dimension 2: entries[1:] deleted, nothing modified, one entry left.
	modified, deleted := tree.InsertAtDimension(2, 1, -2)
	assert.Equal(t, entries[1:], deleted)
	assert.Len(t, modified, 0)

	result := tree.Query(constructMockInterval(dimension{1, 10}, dimension{1, 10}))
	assert.Len(t, result, 0)
	assert.Equal(t, uint64(1), tree.Len())
}

// TestInsertInvalidDimension passes dimension 3 to a tree that only has two
// dimensions and expects the call to change nothing.
func TestInsertInvalidDimension(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(3)

	modified, deleted := tree.InsertAtDimension(3, 1, -1)
	assert.Len(t, modified, 0)
	assert.Len(t, deleted, 0)

	result := tree.Query(constructMockInterval(dimension{0, 10}, dimension{0, 10}))
	assert.Equal(t, entries, result)
}

// TestInsertInvalidNumber passes a number of 0 and expects the call to
// change nothing.
func TestInsertInvalidNumber(t *testing.T) {
	tree, entries := constructMultiDimensionalOrderedTree(3)

	modified, deleted := tree.InsertAtDimension(1, 1, 0)
	assert.Len(t, modified, 0)
	assert.Len(t, deleted, 0)

	result := tree.Query(constructMockInterval(dimension{0, 10}, dimension{0, 10}))
	assert.Equal(t, entries, result)
}

// BenchmarkInsertFirstDimension times InsertAtDimension(1, 0, 1) on a tree
// preloaded with 100000 diagonal entries.
func BenchmarkInsertFirstDimension(b *testing.B) {
	numItems := uint64(100000)

	tree, _ := constructMultiDimensionalOrderedTree(numItems)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		tree.InsertAtDimension(1, 0, 1)
	}
}

// BenchmarkInsertSecondDimension times InsertAtDimension(2, 0, 1) on the
// same kind of tree.
func BenchmarkInsertSecondDimension(b *testing.B) {
	numItems := uint64(100000)

	tree, _ := constructMultiDimensionalOrderedTree(numItems)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		tree.InsertAtDimension(2, 0, 1)
	}
}

// BenchmarkDeleteFirstDimension times InsertAtDimension(1, 0, -1), i.e. a
// negative insert at index 0 of the first dimension.
func BenchmarkDeleteFirstDimension(b *testing.B) {
	numItems := uint64(100000)

	tree, _ := constructMultiDimensionalOrderedTree(numItems)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		tree.InsertAtDimension(1, 0, -1)
	}
}

// BenchmarkDeleteSecondDimension times InsertAtDimension(2, 0, -1).
func BenchmarkDeleteSecondDimension(b *testing.B) {
	numItems := uint64(100000)

	tree, _ := constructMultiDimensionalOrderedTree(numItems)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		tree.InsertAtDimension(2, 0, -1)
	}
}

// BenchmarkGetMultiDimensions times single-entry Get calls against a
// two-dimensional tree loaded from a 10000 x 100 grid of coordinates.
func BenchmarkGetMultiDimensions(b *testing.B) {
	numItemsX := 10000
	numItemsY := 100

	tree :=
newOrderedTree(2)
	entries := make(Entries, 0, numItemsY*numItemsX)

	// Build the grid: i is the first coordinate, j the second.
	for i := 0; i < numItemsX; i++ {
		for j := 0; j < numItemsY; j++ {
			e := constructMockEntry(uint64(j*numItemsY+i), int64(i), int64(j))
			entries = append(entries, e)
		}
	}

	tree.Add(entries...)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		tree.Get(entries[i%len(entries)])
	}
}

================================================
FILE: rangetree/skiplist/mock_test.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package skiplist

// mockEntry is a test entry that stores one coordinate per dimension.
type mockEntry struct {
	values []int64
}

// ValueAtDimension returns the coordinate stored for the given dimension.
// The dimension is used directly as a zero-based index into values, so an
// out-of-range dimension panics.
func (me *mockEntry) ValueAtDimension(dimension uint64) int64 {
	return me.values[dimension]
}

// newMockEntry builds a mockEntry from the given coordinates; the i-th
// argument becomes the value at dimension i.
func newMockEntry(values ...int64) *mockEntry {
	return &mockEntry{values: values}
}

// mockInterval is a test interval holding a low and a high bound per
// dimension.
type mockInterval struct {
	lows, highs []int64
}

// LowAtDimension returns the low bound stored at the zero-based dimension.
func (mi *mockInterval) LowAtDimension(dimension uint64) int64 {
	return mi.lows[dimension]
}

// HighAtDimension returns the high bound stored at the zero-based
// dimension.
func (mi *mockInterval) HighAtDimension(dimension uint64) int64 {
	return mi.highs[dimension]
}

// newMockInterval builds a mockInterval from per-dimension low and high
// bounds. The slices are stored as given, not copied.
func newMockInterval(lows, highs []int64) *mockInterval {
	return &mockInterval{
		lows:  lows,
		highs: highs,
	}
}

================================================
FILE: rangetree/skiplist/skiplist.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* Package skiplist implements an n-dimensional rangetree based on a skip list. This should be faster than a straight slice implementation as memcopy is avoided. Time complexities revolve around the ability to quickly find items in the n-dimensional skiplist. That time can be defined by the number of items in any dimension. Let N1, N2,... Nn define the number of dimensions. Performance characteristics: Space: O(n) Search: O(log N1 + log N2 + ...log Nn) = O(log N1*N2*...Nn) Insert: O(log N1 + log N2 + ...log Nn) = O(log N1*N2*...Nn) Delete: O(log N1 + log N2 + ...log Nn) = O(log N1*N2*...Nn) */ package skiplist import ( "github.com/Workiva/go-datastructures/common" "github.com/Workiva/go-datastructures/rangetree" "github.com/Workiva/go-datastructures/slice/skip" ) // keyed is required as in the rangetree code we often want to compare // two different types of bundles and this allows us to do so without // checking for each one. type keyed interface { key() uint64 } type skipEntry uint64 // Compare is required by the Comparator interface. func (se skipEntry) Compare(other common.Comparator) int { otherSe := other.(skipEntry) if se == otherSe { return 0 } if se > otherSe { return 1 } return -1 } func (se skipEntry) key() uint64 { return uint64(se) } // isLastDimension simply returns dimension == lastDimension-1. // This panics if dimension >= lastDimension. 
func isLastDimension(dimension, lastDimension uint64) bool { if dimension >= lastDimension { // useful in testing and denotes a serious problem panic(`Dimension is greater than possible dimensions.`) } return dimension == lastDimension-1 } // needsDeletion returns a bool indicating if the provided value // needs to be deleted based on the provided index and number. func needsDeletion(value, index, number int64) bool { if number > 0 { return false } number = -number // get the magnitude offset := value - index return offset >= 0 && offset < number } // dimensionalBundle is an intermediate holder up to the last // dimension and represents a wrapper around a skiplist. type dimensionalBundle struct { id uint64 sl *skip.SkipList } // Compare returns a value indicating the relative relationship and the // provided bundle. func (db *dimensionalBundle) Compare(e common.Comparator) int { keyed := e.(keyed) if db.id == keyed.key() { return 0 } if db.id > keyed.key() { return 1 } return -1 } // key returns the key for this bundle. func (db *dimensionalBundle) key() uint64 { return db.id } // lastBundle represents a bundle living at the last dimension // of the tree. type lastBundle struct { id uint64 entry rangetree.Entry } // Compare returns a value indicating the relative relationship and the // provided bundle. func (lb *lastBundle) Compare(e common.Comparator) int { keyed := e.(keyed) if lb.id == keyed.key() { return 0 } if lb.id > keyed.key() { return 1 } return -1 } // Key returns the key for this bundle. 
func (lb *lastBundle) key() uint64 { return lb.id } type skipListRT struct { top *skip.SkipList dimensions, number uint64 } func (rt *skipListRT) init(dimensions uint64) { rt.dimensions = dimensions rt.top = skip.New(uint64(0)) } func (rt *skipListRT) add(entry rangetree.Entry) rangetree.Entry { var ( value int64 e common.Comparator sl = rt.top db *dimensionalBundle lb *lastBundle ) for i := uint64(0); i < rt.dimensions; i++ { value = entry.ValueAtDimension(i) e = sl.Get(skipEntry(value))[0] if isLastDimension(i, rt.dimensions) { if e != nil { // this is an overwrite lb = e.(*lastBundle) oldEntry := lb.entry lb.entry = entry return oldEntry } // need to add new sl entry lb = &lastBundle{id: uint64(value), entry: entry} rt.number++ sl.Insert(lb) return nil } if e == nil { // we need the intermediate dimension db = &dimensionalBundle{id: uint64(value), sl: skip.New(uint64(0))} sl.Insert(db) } else { db = e.(*dimensionalBundle) } sl = db.sl } panic(`Ran out of dimensions before for loop completed.`) } // Add will add the provided entries to the tree. This method // returns a list of entries that were overwritten in the order // in which entries were received. If an entry doesn't overwrite // anything, a nil will be returned for that entry in the returned // slice. func (rt *skipListRT) Add(entries ...rangetree.Entry) rangetree.Entries { overwritten := make(rangetree.Entries, len(entries)) for i, e := range entries { overwritten[i] = rt.add(e) } return overwritten } func (rt *skipListRT) get(entry rangetree.Entry) rangetree.Entry { var ( sl = rt.top e common.Comparator value uint64 ) for i := uint64(0); i < rt.dimensions; i++ { value = uint64(entry.ValueAtDimension(i)) e = sl.Get(skipEntry(value))[0] if e == nil { return nil } if isLastDimension(i, rt.dimensions) { return e.(*lastBundle).entry } sl = e.(*dimensionalBundle).sl } panic(`Reached past for loop without finding last dimension.`) } // Get will return any rangetree.Entries matching the provided entries. 
// Similar in functionality to a key lookup, this returns nil for any // entry that could not be found. func (rt *skipListRT) Get(entries ...rangetree.Entry) rangetree.Entries { results := make(rangetree.Entries, 0, len(entries)) for _, e := range entries { results = append(results, rt.get(e)) } return results } // Len returns the number of entries in the tree. func (rt *skipListRT) Len() uint64 { return rt.number } // deleteRecursive is used by the delete logic. The recursion depth // only goes as far as the number of dimensions, so this shouldn't be an // issue. func (rt *skipListRT) deleteRecursive(sl *skip.SkipList, dimension uint64, entry rangetree.Entry) rangetree.Entry { value := entry.ValueAtDimension(dimension) if isLastDimension(dimension, rt.dimensions) { entries := sl.Delete(skipEntry(value)) if entries[0] == nil { return nil } rt.number-- return entries[0].(*lastBundle).entry } db, ok := sl.Get(skipEntry(value))[0].(*dimensionalBundle) if !ok { // value was not found return nil } result := rt.deleteRecursive(db.sl, dimension+1, entry) if result == nil { return nil } if db.sl.Len() == 0 { sl.Delete(db) } return result } func (rt *skipListRT) delete(entry rangetree.Entry) rangetree.Entry { return rt.deleteRecursive(rt.top, 0, entry) } // Delete will remove the provided entries from the tree. // Any entries that were deleted will be returned in the order in // which they were deleted. If an entry does not exist to be deleted, // a nil is returned for that entry's index in the provided cells. 
func (rt *skipListRT) Delete(entries ...rangetree.Entry) rangetree.Entries { deletedEntries := make(rangetree.Entries, len(entries)) for i, e := range entries { deletedEntries[i] = rt.delete(e) } return deletedEntries } func (rt *skipListRT) apply(sl *skip.SkipList, dimension uint64, interval rangetree.Interval, fn func(rangetree.Entry) bool) bool { lowValue, highValue := interval.LowAtDimension(dimension), interval.HighAtDimension(dimension) var e common.Comparator for iter := sl.Iter(skipEntry(lowValue)); iter.Next(); { e = iter.Value() if int64(e.(keyed).key()) >= highValue { break } if isLastDimension(dimension, rt.dimensions) { if !fn(e.(*lastBundle).entry) { return false } } else { if !rt.apply(e.(*dimensionalBundle).sl, dimension+1, interval, fn) { return false } } } return true } // Apply will call the provided function with each entry that exists // within the provided range, in order. Return false at any time to // cancel iteration. Altering the entry in such a way that its location // changes will result in undefined behavior. func (rt *skipListRT) Apply(interval rangetree.Interval, fn func(rangetree.Entry) bool) { rt.apply(rt.top, 0, interval, fn) } // Query will return a list of entries that fall within // the provided interval. 
func (rt *skipListRT) Query(interval rangetree.Interval) rangetree.Entries { entries := make(rangetree.Entries, 0, 100) rt.apply(rt.top, 0, interval, func(e rangetree.Entry) bool { entries = append(entries, e) return true }) return entries } func (rt *skipListRT) flatten(sl *skip.SkipList, dimension uint64, entries *rangetree.Entries) { lastDimension := isLastDimension(dimension, rt.dimensions) for iter := sl.Iter(skipEntry(0)); iter.Next(); { if lastDimension { *entries = append(*entries, iter.Value().(*lastBundle).entry) } else { rt.flatten(iter.Value().(*dimensionalBundle).sl, dimension+1, entries) } } } func (rt *skipListRT) insert(sl *skip.SkipList, dimension, insertDimension uint64, index, number int64, deleted, affected *rangetree.Entries) { var e common.Comparator lastDimension := isLastDimension(dimension, rt.dimensions) affectedDimension := dimension == insertDimension var iter skip.Iterator if dimension == insertDimension { iter = sl.Iter(skipEntry(index)) } else { iter = sl.Iter(skipEntry(0)) } var toDelete common.Comparators if number < 0 { toDelete = make(common.Comparators, 0, 100) } for iter.Next() { e = iter.Value() if !affectedDimension { rt.insert(e.(*dimensionalBundle).sl, dimension+1, insertDimension, index, number, deleted, affected, ) continue } if needsDeletion(int64(e.(keyed).key()), index, number) { toDelete = append(toDelete, e) continue } if lastDimension { e.(*lastBundle).id += uint64(number) *affected = append(*affected, e.(*lastBundle).entry) } else { e.(*dimensionalBundle).id += uint64(number) rt.flatten(e.(*dimensionalBundle).sl, dimension+1, affected) } } if len(toDelete) > 0 { for _, e := range toDelete { if lastDimension { *deleted = append(*deleted, e.(*lastBundle).entry) } else { rt.flatten(e.(*dimensionalBundle).sl, dimension+1, deleted) } } sl.Delete(toDelete...) } } // InsertAtDimension will increment items at and above the given index // by the number provided. Provide a negative number to to decrement. 
// Returned are two lists. The first list is a list of entries that // were moved. The second is a list entries that were deleted. These // lists are exclusive. func (rt *skipListRT) InsertAtDimension(dimension uint64, index, number int64) (rangetree.Entries, rangetree.Entries) { if dimension >= rt.dimensions || number == 0 { return rangetree.Entries{}, rangetree.Entries{} } affected := make(rangetree.Entries, 0, 100) var deleted rangetree.Entries if number < 0 { deleted = make(rangetree.Entries, 0, 100) } rt.insert(rt.top, 0, dimension, index, number, &deleted, &affected) rt.number -= uint64(len(deleted)) return affected, deleted } func new(dimensions uint64) *skipListRT { sl := &skipListRT{} sl.init(dimensions) return sl } // New will allocate, initialize, and return a new rangetree.RangeTree // with the provided number of dimensions. func New(dimensions uint64) rangetree.RangeTree { return new(dimensions) } ================================================ FILE: rangetree/skiplist/skiplist_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

package skiplist

import (
	"math"
	"math/rand"
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/Workiva/go-datastructures/rangetree"
)

// generateMultiDimensionalEntries returns num two-dimensional mock
// entries whose coordinates are (i, i) for i in [0, num).
func generateMultiDimensionalEntries(num int) rangetree.Entries {
	entries := make(rangetree.Entries, 0, num)
	for i := 0; i < num; i++ {
		entries = append(entries, newMockEntry(int64(i), int64(i)))
	}

	return entries
}

// generateRandomMultiDimensionalEntries returns num two-dimensional
// mock entries; each entry uses one rand.Int63 value for both
// coordinates.
func generateRandomMultiDimensionalEntries(num int) rangetree.Entries {
	entries := make(rangetree.Entries, 0, num)
	for i := 0; i < num; i++ {
		value := rand.Int63()
		entries = append(entries, newMockEntry(value, value))
	}

	return entries
}

// TestRTSingleDimensionAdd adds two fresh entries to a one-dimensional
// tree and expects nothing overwritten, a length of two, and both
// entries retrievable.
func TestRTSingleDimensionAdd(t *testing.T) {
	rt := new(1)
	m1 := newMockEntry(3)
	m2 := newMockEntry(5)

	overwritten := rt.Add(m1, m2)
	assert.Equal(t, rangetree.Entries{nil, nil}, overwritten)
	assert.Equal(t, uint64(2), rt.Len())
	assert.Equal(t, rangetree.Entries{m1, m2}, rt.Get(m1, m2))
}

// TestRTMultiDimensionAdd is the two-dimensional counterpart of
// TestRTSingleDimensionAdd.
func TestRTMultiDimensionAdd(t *testing.T) {
	rt := new(2)
	m1 := newMockEntry(3, 5)
	m2 := newMockEntry(4, 6)

	overwritten := rt.Add(m1, m2)
	assert.Equal(t, rangetree.Entries{nil, nil}, overwritten)
	assert.Equal(t, uint64(2), rt.Len())
	assert.Equal(t, rangetree.Entries{m1, m2}, rt.Get(m1, m2))
}

// TestRTSingleDimensionOverwrite adds two entries at the same
// coordinate and expects the second Add to report the first entry as
// overwritten while the length stays at one.
func TestRTSingleDimensionOverwrite(t *testing.T) {
	rt := new(1)
	m1 := newMockEntry(5)
	m2 := newMockEntry(5)

	overwritten := rt.Add(m1)
	assert.Equal(t, rangetree.Entries{nil}, overwritten)
	assert.Equal(t, uint64(1), rt.Len())

	overwritten = rt.Add(m2)
	assert.Equal(t, rangetree.Entries{m1}, overwritten)
	assert.Equal(t, uint64(1), rt.Len())
	assert.Equal(t, rangetree.Entries{m2}, rt.Get(m2))
}

// TestRTMultiDimensionOverwrite is the two-dimensional counterpart of
// TestRTSingleDimensionOverwrite.
func TestRTMultiDimensionOverwrite(t *testing.T) {
	rt := new(2)
	m1 := newMockEntry(5, 6)
	m2 := newMockEntry(5, 6)

	overwritten := rt.Add(m1)
	assert.Equal(t, rangetree.Entries{nil}, overwritten)
	assert.Equal(t, uint64(1), rt.Len())

	overwritten = rt.Add(m2)
	assert.Equal(t, rangetree.Entries{m1}, overwritten)
	assert.Equal(t, uint64(1), rt.Len())
	assert.Equal(t, rangetree.Entries{m2}, rt.Get(m2))
}

// TestRTSingleDimensionDelete deletes two stored entries and one that
// was never added; the result is expected to hold nil in the slot of
// the missing entry.
func TestRTSingleDimensionDelete(t *testing.T) {
	rt := new(1)
	m1 := newMockEntry(5)
	m2 := newMockEntry(2)
	rt.Add(m1, m2)

	deleted := rt.Delete(m1, newMockEntry(10), m2)
	assert.Equal(t, rangetree.Entries{m1, nil, m2}, deleted)
	assert.Equal(t, uint64(0), rt.Len())
	assert.Equal(t, rangetree.Entries{nil, nil}, rt.Get(m1, m2))
}

// TestRTMultiDimensionDelete is the two-dimensional counterpart of
// TestRTSingleDimensionDelete.
func TestRTMultiDimensionDelete(t *testing.T) {
	rt := new(2)
	m1 := newMockEntry(3, 5)
	m2 := newMockEntry(4, 6)
	rt.Add(m1, m2)

	deleted := rt.Delete(m1, newMockEntry(10, 10), m2)
	assert.Equal(t, rangetree.Entries{m1, nil, m2}, deleted)
	assert.Equal(t, uint64(0), rt.Len())
	assert.Equal(t, rangetree.Entries{nil, nil}, rt.Get(m1, m2))
}

// TestRTSingleDimensionQuery checks Query against several intervals.
// The expectations show the low bound is inclusive and the high bound
// exclusive (e.g. [0, 3) does not return the entry at 3).
func TestRTSingleDimensionQuery(t *testing.T) {
	rt := new(1)
	m1 := newMockEntry(3)
	m2 := newMockEntry(6)
	m3 := newMockEntry(9)
	rt.Add(m1, m2, m3)

	result := rt.Query(newMockInterval([]int64{1}, []int64{7}))
	assert.Equal(t, rangetree.Entries{m1, m2}, result)

	result = rt.Query(newMockInterval([]int64{6}, []int64{10}))
	assert.Equal(t, rangetree.Entries{m2, m3}, result)

	result = rt.Query(newMockInterval([]int64{9}, []int64{11}))
	assert.Equal(t, rangetree.Entries{m3}, result)

	result = rt.Query(newMockInterval([]int64{0}, []int64{3}))
	assert.Len(t, result, 0)

	result = rt.Query(newMockInterval([]int64{10}, []int64{13}))
	assert.Len(t, result, 0)
}

// TestRTMultiDimensionQuery is the two-dimensional counterpart of
// TestRTSingleDimensionQuery and additionally covers intervals whose
// bounds differ per dimension.
func TestRTMultiDimensionQuery(t *testing.T) {
	rt := new(2)
	m1 := newMockEntry(3, 3)
	m2 := newMockEntry(6, 6)
	m3 := newMockEntry(9, 9)
	rt.Add(m1, m2, m3)

	result := rt.Query(newMockInterval([]int64{1, 1}, []int64{7, 7}))
	assert.Equal(t, rangetree.Entries{m1, m2}, result)

	result = rt.Query(newMockInterval([]int64{6, 6}, []int64{10, 10}))
	assert.Equal(t, rangetree.Entries{m2, m3}, result)

	result = rt.Query(newMockInterval([]int64{9, 9}, []int64{11, 11}))
	assert.Equal(t, rangetree.Entries{m3}, result)

	result = rt.Query(newMockInterval([]int64{0, 0}, []int64{3, 3}))
	assert.Len(t, result, 0)

	result = rt.Query(newMockInterval([]int64{10, 10}, []int64{13, 13}))
	assert.Len(t, result, 0)

	// NOTE(review): this repeats the [0,0)-(3,3) query made above.
	result = rt.Query(newMockInterval([]int64{0, 0}, []int64{3, 3}))
	assert.Len(t, result, 0)

	result = rt.Query(newMockInterval([]int64{6, 1}, []int64{7, 6}))
	assert.Len(t, result, 0)

	result = rt.Query(newMockInterval([]int64{0, 0}, []int64{7, 4}))
	assert.Equal(t, rangetree.Entries{m1}, result)
}

// TestRTSingleDimensionInsert shifts everything from index 0 up by one
// and expects each entry to be found at its old coordinate plus one.
func TestRTSingleDimensionInsert(t *testing.T) {
	rt := new(1)
	m1 := newMockEntry(3)
	m2 := newMockEntry(6)
	m3 := newMockEntry(9)
	rt.Add(m1, m2, m3)

	affected, deleted := rt.InsertAtDimension(0, 0, 1)
	assert.Equal(t, rangetree.Entries{m1, m2, m3}, affected)
	assert.Len(t, deleted, 0)
	assert.Equal(t, uint64(3), rt.Len())
	assert.Equal(t, rangetree.Entries{nil, nil, nil}, rt.Get(m1, m2, m3))

	e1 := newMockEntry(4)
	e2 := newMockEntry(7)
	e3 := newMockEntry(10)
	assert.Equal(t, rangetree.Entries{m1, m2, m3}, rt.Get(e1, e2, e3))
}

// TestRTSingleDimensionInsertNegative shifts by -2 from index 6 and
// expects the entry at 6 to be deleted and the entry at 9 to move to 7.
func TestRTSingleDimensionInsertNegative(t *testing.T) {
	rt := new(1)
	m1 := newMockEntry(3)
	m2 := newMockEntry(6)
	m3 := newMockEntry(9)
	rt.Add(m1, m2, m3)

	affected, deleted := rt.InsertAtDimension(0, 6, -2)
	assert.Equal(t, rangetree.Entries{m3}, affected)
	assert.Equal(t, rangetree.Entries{m2}, deleted)
	assert.Equal(t, uint64(2), rt.Len())
	assert.Equal(t, rangetree.Entries{m1, nil}, rt.Get(m1, m2))

	e2 := newMockEntry(4)
	e3 := newMockEntry(7)
	assert.Equal(t, rangetree.Entries{nil, m3}, rt.Get(e2, e3))
}

// TestRTMultiDimensionInsert shifts the second dimension by +2 from
// index 4 in a two-dimensional tree.
func TestRTMultiDimensionInsert(t *testing.T) {
	rt := new(2)
	m1 := newMockEntry(3, 3)
	m2 := newMockEntry(6, 6)
	m3 := newMockEntry(9, 9)
	rt.Add(m1, m2, m3)

	affected, deleted := rt.InsertAtDimension(1, 4, 2)
	assert.Equal(t, rangetree.Entries{m2, m3}, affected)
	assert.Len(t, deleted, 0)
	assert.Equal(t, uint64(3), rt.Len())

	e2 := newMockEntry(6, 8)
	e3 := newMockEntry(9, 11)
	assert.Equal(t, rangetree.Entries{m1, nil, nil}, rt.Get(m1, m2, m3))
	assert.Equal(t, rangetree.Entries{m2, m3}, rt.Get(e2, e3))
}

// TestRTMultiDimensionInsertNegative shifts the second dimension by -2
// from index 6 in a two-dimensional tree.
func TestRTMultiDimensionInsertNegative(t *testing.T) {
	rt := new(2)
	m1 := newMockEntry(3, 3)
	m2 := newMockEntry(6, 6)
	m3 := newMockEntry(9, 9)
	rt.Add(m1, m2, m3)

	affected, deleted := rt.InsertAtDimension(1, 6, -2)
	assert.Equal(t, rangetree.Entries{m3}, affected)
	assert.Equal(t, rangetree.Entries{m2}, deleted)
	assert.Equal(t, uint64(2), rt.Len())
	assert.Equal(t, rangetree.Entries{m1, nil, nil}, rt.Get(m1, m2, m3))

	e2 := newMockEntry(6, 4)
	e3 := newMockEntry(9, 7)
	assert.Equal(t, rangetree.Entries{nil, m3}, rt.Get(e2, e3))
}

// TestRTInsertInZeroDimensionMultiDimensionList shifts the first
// dimension by +2 from index 4 in a two-dimensional tree.
func TestRTInsertInZeroDimensionMultiDimensionList(t *testing.T) {
	rt := new(2)
	m1 := newMockEntry(3, 3)
	m2 := newMockEntry(6, 6)
	m3 := newMockEntry(9, 9)
	rt.Add(m1, m2, m3)

	affected, deleted := rt.InsertAtDimension(0, 4, 2)
	assert.Equal(t, rangetree.Entries{m2, m3}, affected)
	assert.Len(t, deleted, 0)
	assert.Equal(t, uint64(3), rt.Len())
	assert.Equal(t, rangetree.Entries{m1, nil, nil}, rt.Get(m1, m2, m3))

	e2 := newMockEntry(8, 6)
	e3 := newMockEntry(11, 9)
	assert.Equal(t, rangetree.Entries{m2, m3}, rt.Get(e2, e3))
}

// TestRTInsertNegativeInZeroDimensionMultiDimensionList shifts the
// first dimension by -2 from index 6 in a two-dimensional tree.
func TestRTInsertNegativeInZeroDimensionMultiDimensionList(t *testing.T) {
	rt := new(2)
	m1 := newMockEntry(3, 3)
	m2 := newMockEntry(6, 6)
	m3 := newMockEntry(9, 9)
	rt.Add(m1, m2, m3)

	affected, deleted := rt.InsertAtDimension(0, 6, -2)
	assert.Equal(t, rangetree.Entries{m3}, affected)
	assert.Equal(t, rangetree.Entries{m2}, deleted)
	assert.Equal(t, uint64(2), rt.Len())
	assert.Equal(t, rangetree.Entries{m1, nil, nil}, rt.Get(m1, m2, m3))

	e2 := newMockEntry(4, 6)
	e3 := newMockEntry(7, 9)
	assert.Equal(t, rangetree.Entries{nil, m3}, rt.Get(e2, e3))
}

// TestRTInsertBeyondDimension expects InsertAtDimension to be a no-op
// when the dimension is not one the tree has.
func TestRTInsertBeyondDimension(t *testing.T) {
	rt := new(2)
	m1 := newMockEntry(3, 3)
	rt.Add(m1)

	affected, deleted := rt.InsertAtDimension(4, 0, 1)
	assert.Len(t, affected, 0)
	assert.Len(t, deleted, 0)
	assert.Equal(t, rangetree.Entries{m1}, rt.Get(m1))
}

// TestRTInsertZero expects InsertAtDimension to be a no-op when the
// shift amount is zero.
func TestRTInsertZero(t *testing.T) {
	rt := new(2)
	m1 := newMockEntry(3, 3)
	rt.Add(m1)

	affected, deleted := rt.InsertAtDimension(1, 0, 0)
	assert.Len(t, affected, 0)
	assert.Len(t, deleted, 0)
	assert.Equal(t, rangetree.Entries{m1}, rt.Get(m1))
}

// BenchmarkMultiDimensionInsert times adding sequential entries in
// ascending order.
func BenchmarkMultiDimensionInsert(b *testing.B) {
	numItems := b.N
	rt := new(2)
	entries := generateMultiDimensionalEntries(numItems)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		rt.Add(entries[i%numItems])
	}
}

// BenchmarkMultiDimensionInsertReverse times adding sequential entries
// in descending order.
func BenchmarkMultiDimensionInsertReverse(b *testing.B) {
	numItems := b.N
	rt := new(2)
	entries := generateMultiDimensionalEntries(numItems)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		index := numItems - (i % numItems) - 1
		rt.Add(entries[index])
	}
}

// BenchmarkMultiDimensionRandomInsert times adding randomly placed
// entries.
func BenchmarkMultiDimensionRandomInsert(b *testing.B) {
	numItems := b.N
	rt := new(2)
	entries := generateRandomMultiDimensionalEntries(numItems)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		rt.Add(entries[i%numItems])
	}
}

// BenchmarkMultiDimensionalGet times single-entry lookups on a
// pre-populated tree.
func BenchmarkMultiDimensionalGet(b *testing.B) {
	numItems := b.N
	rt := new(2)
	entries := generateRandomMultiDimensionalEntries(numItems)
	rt.Add(entries...)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		rt.Get(entries[i%numItems])
	}
}

// BenchmarkMultiDimensionDelete times single-entry deletes on a
// pre-populated tree.
func BenchmarkMultiDimensionDelete(b *testing.B) {
	numItems := b.N
	rt := new(2)
	entries := generateRandomMultiDimensionalEntries(numItems)
	rt.Add(entries...)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		rt.Delete(entries[i%numItems])
	}
}

// BenchmarkMultiDimensionQuery times a query spanning
// [0, math.MaxInt64) in both dimensions and then checks the size of
// the last result.
func BenchmarkMultiDimensionQuery(b *testing.B) {
	numItems := b.N
	rt := new(2)
	entries := generateRandomMultiDimensionalEntries(numItems)
	rt.Add(entries...)
	iv := newMockInterval([]int64{0, 0}, []int64{math.MaxInt64, math.MaxInt64})
	var result rangetree.Entries

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		result = rt.Query(iv)
	}

	assert.Len(b, result, numItems)
}

// BenchmarkMultiDimensionInsertAtZeroDimension times a +1 shift of the
// first dimension from index 0.
func BenchmarkMultiDimensionInsertAtZeroDimension(b *testing.B) {
	numItems := b.N
	rt := new(2)
	entries := generateRandomMultiDimensionalEntries(numItems)
	rt.Add(entries...)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		rt.InsertAtDimension(0, 0, 1)
	}
}

// BenchmarkMultiDimensionInsertNegativeAtZeroDimension times a -1
// shift of the first dimension from index 0.
func BenchmarkMultiDimensionInsertNegativeAtZeroDimension(b *testing.B) {
	numItems := b.N
	rt := new(2)
	entries := generateRandomMultiDimensionalEntries(numItems)
	rt.Add(entries...)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		rt.InsertAtDimension(0, 0, -1)
	}
}

================================================
FILE: rtree/hilbert/action.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package hilbert

import (
	"runtime"
	"sync"
	"sync/atomic"

	"github.com/Workiva/go-datastructures/rtree"
)

// actions is a list of queued tree operations.
type actions []action

// action is one queued tree operation (get, add or remove) together
// with the bookkeeping the tree needs to execute and complete it.
type action interface {
	// operation reports which kind of operation this is.
	operation() operation
	// keys returns the hilbert keys of the action; every
	// implementation in this file returns nil.
	keys() hilberts
	// rects returns the bundles the action operates on.
	rects() []*hilbertBundle
	// complete signals the waiter that the action has finished.
	complete()
	// addNode records the node resolved for the rectangle at the
	// given position.
	addNode(int64, *node)
	// nodes returns the nodes recorded through addNode.
	nodes() []*node
}

// getAction is a queued search for everything intersecting lookup;
// the outcome is stored in result.
type getAction struct {
	result    rtree.Rectangles
	completer *sync.WaitGroup
	lookup    *rectangle
}

// complete releases whoever is waiting on the action's WaitGroup.
func (ga *getAction) complete() {
	ga.completer.Done()
}

// operation identifies this action as a get.
func (ga *getAction) operation() operation {
	return get
}

// keys returns nil; a get carries no hilbert keys.
func (ga *getAction) keys() hilberts {
	return nil
}

// addNode is a no-op for gets.
func (ga *getAction) addNode(i int64, n *node) {
	return // not necessary for gets
}

// nodes returns nil; a get resolves no target nodes.
func (ga *getAction) nodes() []*node {
	return nil
}

// rects returns a single empty bundle, so the action reports exactly
// one unit of work.
func (ga *getAction) rects() []*hilbertBundle {
	return []*hilbertBundle{&hilbertBundle{}}
}

// newGetAction builds a get for rect with its WaitGroup counter
// already set to one.
func newGetAction(rect rtree.Rectangle) *getAction {
	r := newRectangeFromRect(rect)
	ga := &getAction{
		completer: new(sync.WaitGroup),
		lookup:    r,
	}
	ga.completer.Add(1)
	return ga
}

// insertAction is a queued insertion of rs; ns holds, per rectangle,
// the node recorded for it via addNode.
type insertAction struct {
	rs        []*hilbertBundle
	completer *sync.WaitGroup
	ns        []*node
}

// complete releases whoever is waiting on the action's WaitGroup.
func (ia *insertAction) complete() {
	ia.completer.Done()
}

// operation identifies this action as an add.
func (ia *insertAction) operation() operation {
	return add
}

// keys returns nil; the hilbert values live on the bundles in rs.
func (ia *insertAction) keys() hilberts {
	return nil
}

// addNode stores n as the target node for the rectangle at index i.
func (ia *insertAction) addNode(i int64, n *node) {
	ia.ns[i] = n
}

// nodes returns the per-rectangle target nodes.
func (ia *insertAction) nodes() []*node {
	return ia.ns
}

// rects returns the bundles to insert.
func (ia *insertAction) rects() []*hilbertBundle {
	return ia.rs
}

// newInsertAction bundles rects with their hilbert values, allocates
// one node slot per rectangle, and sets the WaitGroup counter to one.
func newInsertAction(rects rtree.Rectangles) *insertAction {
	ia := &insertAction{
		rs:        bundlesFromRects(rects...),
		completer: new(sync.WaitGroup),
		ns:        make([]*node, len(rects)),
	}
	ia.completer.Add(1)
	return ia
}

// removeAction reuses insertAction's bookkeeping and only overrides
// the reported operation.
type removeAction struct {
	*insertAction
}

// operation identifies this action as a remove.
func (ra *removeAction) operation() operation {
	return remove
}

// newRemoveAction builds a remove for rects.
func newRemoveAction(rects rtree.Rectangles) *removeAction {
	return &removeAction{
		newInsertAction(rects),
	}
}

// minUint64 returns the smallest of choices. It indexes choices[0]
// unconditionally, so it panics when called with no arguments.
func minUint64(choices ...uint64) uint64 {
	min := choices[0]
	for i := 1; i < len(choices); i++ {
		if choices[i] < min {
			min = choices[i]
		}
	}

	return min
}

// interfaces is a list of arbitrary values handed to the executors
// below.
type interfaces []interface{}

// executeInterfacesInParallel calls fn once for every element of ifs,
// spreading the calls over up to runtime.NumCPU()-1 goroutines (at
// least one, and never more than len(ifs)), and returns once all calls
// have finished.
func executeInterfacesInParallel(ifs interfaces, fn func(interface{})) {
	if len(ifs) == 0 {
		return
	}

	// done is the shared cursor; starting at -1 makes the first
	// atomic increment yield index 0.
	done := int64(-1)
	numCPU := uint64(runtime.NumCPU())
	if numCPU > 1 {
		numCPU--
	}

	numCPU = minUint64(numCPU, uint64(len(ifs)))

	var wg sync.WaitGroup
	wg.Add(int(numCPU))

	for i := uint64(0); i < numCPU; i++ {
		go func() {
			defer wg.Done()

			for {
				// Claim the next unprocessed index.
				i := atomic.AddInt64(&done, 1)
				if i >= int64(len(ifs)) {
					return
				}

				fn(ifs[i])
			}
		}()
	}

	wg.Wait()
}

// executeInterfacesInSerial calls fn for every element of ifs, in
// order, on the calling goroutine.
func executeInterfacesInSerial(ifs interfaces, fn func(interface{})) {
	if len(ifs) == 0 {
		return
	}

	for _, ifc := range ifs {
		fn(ifc)
	}
}

================================================
FILE: rtree/hilbert/hilbert.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License. */ package hilbert import ( "runtime" "sync" h "github.com/Workiva/go-datastructures/numerics/hilbert" "github.com/Workiva/go-datastructures/rtree" ) func getCenter(rect rtree.Rectangle) (int32, int32) { xlow, ylow := rect.LowerLeft() xhigh, yhigh := rect.UpperRight() return (xhigh + xlow) / 2, (yhigh + ylow) / 2 } type hilbertBundle struct { hilbert hilbert rect rtree.Rectangle } func bundlesFromRects(rects ...rtree.Rectangle) []*hilbertBundle { chunks := chunkRectangles(rects, int64(runtime.NumCPU())) bundleChunks := make([][]*hilbertBundle, len(chunks)) var wg sync.WaitGroup wg.Add(len(chunks)) for i := 0; i < runtime.NumCPU(); i++ { if len(chunks[i]) == 0 { bundleChunks[i] = []*hilbertBundle{} wg.Done() continue } go func(i int) { bundles := make([]*hilbertBundle, 0, len(chunks[i])) for _, r := range chunks[i] { h := h.Encode(getCenter(r)) bundles = append(bundles, &hilbertBundle{hilbert(h), r}) } bundleChunks[i] = bundles wg.Done() }(i) } wg.Wait() bundles := make([]*hilbertBundle, 0, len(rects)) for _, bc := range bundleChunks { bundles = append(bundles, bc...) } return bundles } // chunkRectangles takes a slice of rtree.Rectangle values and chunks it into `numParts` subslices. func chunkRectangles(slice rtree.Rectangles, numParts int64) []rtree.Rectangles { parts := make([]rtree.Rectangles, numParts) for i := int64(0); i < numParts; i++ { parts[i] = slice[i*int64(len(slice))/numParts : (i+1)*int64(len(slice))/numParts] } return parts } ================================================ FILE: rtree/hilbert/mock_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package hilbert type mockRectangle struct { xlow, ylow, xhigh, yhigh int32 } func (mr *mockRectangle) LowerLeft() (int32, int32) { return mr.xlow, mr.ylow } func (mr *mockRectangle) UpperRight() (int32, int32) { return mr.xhigh, mr.yhigh } func newMockRectangle(xlow, ylow, xhigh, yhigh int32) *mockRectangle { return &mockRectangle{ xlow: xlow, ylow: ylow, xhigh: xhigh, yhigh: yhigh, } } ================================================ FILE: rtree/hilbert/node.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package hilbert import ( "sort" "github.com/Workiva/go-datastructures/rtree" ) type hilbert int64 type hilberts []hilbert func getParent(parent *node, key hilbert, r1 rtree.Rectangle) *node { var n *node for parent != nil && !parent.isLeaf { n = parent.searchNode(key) parent = n } if parent != nil && r1 != nil { // must be leaf and we need exact match // we are safe to travel to the right i := parent.search(key) for parent.keys.byPosition(i) == key { if equal(parent.nodes.list[i], r1) { break } i++ if i == parent.keys.len() { if parent.right == nil { // we are far to the right break } if parent.right.keys.byPosition(0) != key { break } parent = parent.right i = 0 } } } return parent } type nodes struct { list rtree.Rectangles } func (ns *nodes) push(n rtree.Rectangle) { ns.list = append(ns.list, n) } func (ns *nodes) splitAt(i, capacity uint64) (*nodes, *nodes) { i++ right := make(rtree.Rectangles, uint64(len(ns.list))-i, capacity) copy(right, ns.list[i:]) for j := i; j < uint64(len(ns.list)); j++ { ns.list[j] = nil } ns.list = ns.list[:i] return ns, &nodes{list: right} } func (ns *nodes) byPosition(pos uint64) *node { if pos >= uint64(len(ns.list)) { return nil } return ns.list[pos].(*node) } func (ns *nodes) insertAt(i uint64, n rtree.Rectangle) { ns.list = append(ns.list, nil) copy(ns.list[i+1:], ns.list[i:]) ns.list[i] = n } func (ns *nodes) replaceAt(i uint64, n rtree.Rectangle) { ns.list[i] = n } func (ns *nodes) len() uint64 { return uint64(len(ns.list)) } func (ns *nodes) deleteAt(i uint64) { copy(ns.list[i:], ns.list[i+1:]) ns.list = ns.list[:len(ns.list)-1] } func newNodes(size uint64) *nodes { return &nodes{ list: make(rtree.Rectangles, 0, size), } } type keys struct { list hilberts } func (ks *keys) splitAt(i, capacity uint64) (*keys, *keys) { i++ right := make(hilberts, uint64(len(ks.list))-i, capacity) copy(right, ks.list[i:]) ks.list = ks.list[:i] return ks, &keys{list: right} } func (ks *keys) len() uint64 { return uint64(len(ks.list)) } func (ks 
*keys) byPosition(i uint64) hilbert { if i >= uint64(len(ks.list)) { return -1 } return ks.list[i] } func (ks *keys) deleteAt(i uint64) { copy(ks.list[i:], ks.list[i+1:]) ks.list = ks.list[:len(ks.list)-1] } func (ks *keys) delete(k hilbert) hilbert { i := ks.search(k) if i >= uint64(len(ks.list)) { return -1 } if ks.list[i] != k { return -1 } old := ks.list[i] ks.deleteAt(i) return old } func (ks *keys) search(key hilbert) uint64 { i := sort.Search(len(ks.list), func(i int) bool { return ks.list[i] >= key }) return uint64(i) } func (ks *keys) insert(key hilbert) (hilbert, uint64) { i := ks.search(key) if i == uint64(len(ks.list)) { ks.list = append(ks.list, key) return -1, i } var old hilbert if ks.list[i] == key { old = ks.list[i] ks.list[i] = key } else { ks.insertAt(i, key) } return old, i } func (ks *keys) last() hilbert { return ks.list[len(ks.list)-1] } func (ks *keys) insertAt(i uint64, k hilbert) { ks.list = append(ks.list, -1) copy(ks.list[i+1:], ks.list[i:]) ks.list[i] = k } func (ks *keys) withPosition(k hilbert) (hilbert, uint64) { i := ks.search(k) if i == uint64(len(ks.list)) { return -1, i } if ks.list[i] == k { return ks.list[i], i } return -1, i } func newKeys(size uint64) *keys { return &keys{ list: make(hilberts, 0, size), } } type node struct { keys *keys nodes *nodes isLeaf bool parent, right *node mbr *rectangle maxHilbert hilbert } func (n *node) insert(kb *keyBundle) rtree.Rectangle { i := n.keys.search(kb.key) if n.isLeaf { // we can have multiple keys with the same hilbert number for i < n.keys.len() && n.keys.list[i] == kb.key { if equal(n.nodes.list[i], kb.left) { old := n.nodes.list[i] n.nodes.list[i] = kb.left return old } i++ } } if i == n.keys.len() { n.maxHilbert = kb.key } n.keys.insertAt(i, kb.key) if n.isLeaf { n.nodes.insertAt(i, kb.left) } else { if n.nodes.len() == 0 { n.nodes.push(kb.left) n.nodes.push(kb.right) } else { n.nodes.replaceAt(i, kb.left) n.nodes.insertAt(i+1, kb.right) } n.mbr.adjust(kb.left) 
n.mbr.adjust(kb.right) if kb.right.(*node).maxHilbert > n.maxHilbert { n.maxHilbert = kb.right.(*node).maxHilbert } } return nil } func (n *node) delete(kb *keyBundle) rtree.Rectangle { i := n.keys.search(kb.key) if n.keys.byPosition(i) != kb.key { // hilbert value not found return nil } if !equal(n.nodes.list[i], kb.left) { return nil } old := n.nodes.list[i] n.keys.deleteAt(i) n.nodes.deleteAt(i) return old } func (n *node) LowerLeft() (int32, int32) { return n.mbr.xlow, n.mbr.ylow } func (n *node) UpperRight() (int32, int32) { return n.mbr.xhigh, n.mbr.yhigh } func (n *node) needsSplit(ary uint64) bool { return n.keys.len() >= ary } func (n *node) splitLeaf(i, capacity uint64) (hilbert, *node, *node) { key := n.keys.byPosition(i) _, rightKeys := n.keys.splitAt(i, capacity) _, rightNodes := n.nodes.splitAt(i, capacity) nn := &node{ keys: rightKeys, nodes: rightNodes, isLeaf: true, right: n.right, parent: n.parent, } n.right = nn nn.mbr = newRectangleFromRects(rightNodes.list) n.mbr = newRectangleFromRects(n.nodes.list) nn.maxHilbert = rightKeys.last() n.maxHilbert = n.keys.last() return key, n, nn } func (n *node) splitInternal(i, capacity uint64) (hilbert, *node, *node) { key := n.keys.byPosition(i) n.keys.delete(key) _, rightKeys := n.keys.splitAt(i-1, capacity) _, rightNodes := n.nodes.splitAt(i, capacity) nn := newNode(false, rightKeys, rightNodes) for _, n := range rightNodes.list { n.(*node).parent = nn } nn.mbr = newRectangleFromRects(rightNodes.list) n.mbr = newRectangleFromRects(n.nodes.list) nn.maxHilbert = nn.keys.last() n.maxHilbert = n.keys.last() return key, n, nn } func (n *node) split(i, capacity uint64) (hilbert, *node, *node) { if n.isLeaf { return n.splitLeaf(i, capacity) } return n.splitInternal(i, capacity) } func (n *node) search(key hilbert) uint64 { return n.keys.search(key) } func (n *node) searchNode(key hilbert) *node { i := n.search(key) return n.nodes.byPosition(uint64(i)) } func (n *node) searchRects(r *rectangle) rtree.Rectangles { 
rects := make(rtree.Rectangles, 0, n.nodes.len()) for _, child := range n.nodes.list { if intersect(r, child) { rects = append(rects, child) } } return rects } func (n *node) key() hilbert { return n.keys.last() } func newNode(isLeaf bool, keys *keys, ns *nodes) *node { return &node{ isLeaf: isLeaf, keys: keys, nodes: ns, } } ================================================ FILE: rtree/hilbert/rectangle.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package hilbert import "github.com/Workiva/go-datastructures/rtree" type rectangle struct { xlow, xhigh, ylow, yhigh int32 } func (r *rectangle) adjust(rect rtree.Rectangle) { x, y := rect.LowerLeft() if x < r.xlow { r.xlow = x } if y < r.ylow { r.ylow = y } x, y = rect.UpperRight() if x > r.xhigh { r.xhigh = x } if y > r.yhigh { r.yhigh = y } } func equal(r1, r2 rtree.Rectangle) bool { xlow1, ylow1 := r1.LowerLeft() xhigh2, yhigh2 := r2.UpperRight() xhigh1, yhigh1 := r1.UpperRight() xlow2, ylow2 := r2.LowerLeft() return xlow1 == xlow2 && xhigh1 == xhigh2 && ylow1 == ylow2 && yhigh1 == yhigh2 } func intersect(rect1 *rectangle, rect2 rtree.Rectangle) bool { xhigh2, yhigh2 := rect2.UpperRight() xlow2, ylow2 := rect2.LowerLeft() return xhigh2 >= rect1.xlow && xlow2 <= rect1.xhigh && yhigh2 >= rect1.ylow && ylow2 <= rect1.yhigh } func newRectangeFromRect(rect rtree.Rectangle) *rectangle { r := &rectangle{} x, y := rect.LowerLeft() r.xlow = x r.ylow = y x, y = rect.UpperRight() r.xhigh = x r.yhigh = y return r } func newRectangleFromRects(rects rtree.Rectangles) *rectangle { if len(rects) == 0 { panic(`Cannot construct rectangle with no dimensions.`) } xlow, ylow := rects[0].LowerLeft() xhigh, yhigh := rects[0].UpperRight() r := &rectangle{ xlow: xlow, xhigh: xhigh, ylow: ylow, yhigh: yhigh, } for i := 1; i < len(rects); i++ { r.adjust(rects[i]) } return r } ================================================ FILE: rtree/hilbert/tree.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and
limitations under the License.
*/

/*
Package hilbert implements a Hilbert R-tree based on PALM principles
to improve multithreaded performance. This package is not quite complete
and some optimization and delete codes remain to be completed.

This serves as a potential replacement for the interval tree and
rangetree.

Benchmarks:
BenchmarkBulkAddPoints-8	     500	   2589270 ns/op
BenchmarkBulkUpdatePoints-8	    2000	   1212641 ns/op
BenchmarkPointInsertion-8	  200000	      9135 ns/op
BenchmarkQueryPoints-8	  500000	      3122 ns/op
*/
package hilbert

import (
	"runtime"
	"sync"
	"sync/atomic"

	"github.com/Workiva/go-datastructures/queue"
	"github.com/Workiva/go-datastructures/rtree"
)

// operation identifies the kind of work an action performs.
type operation int

const (
	get operation = iota
	add
	remove
)

const multiThreadAt = 1000 // number of keys before we multithread lookups

// keyBundle carries one key and the rectangle(s) that accompany it
// through a mutation: for leaf work only left is set; after a split
// left and right are the two resulting nodes.
type keyBundle struct {
	key         hilbert
	left, right rtree.Rectangle
}

// tree is the Hilbert R-tree. The blank [8]uint64 fields are padding
// placed between the fields that are updated atomically.
type tree struct {
	root *node
	_    [8]uint64
	// number is the count of stored rectangles; updated atomically.
	number          uint64
	_               [8]uint64
	ary, bufferSize uint64
	// actions queues work that arrives while a batch is running.
	actions *queue.RingBuffer
	// cache holds the batch drained from actions for the current run.
	cache []interface{}
	_     [8]uint64
	// disposed is set to 1 by Dispose.
	// NOTE(review): nothing in this file reads it — confirm whether a check is missing.
	disposed uint64
	_        [8]uint64
	// running is 1 while a batch is being processed, 0 otherwise.
	running uint64
}

// checkAndRun submits action (which may be nil) and, if no batch is
// currently running, starts one. When the queue already has pending
// work the action is queued and a batch of up to bufferSize actions is
// drained and run on a new goroutine. When the queue is empty and the
// tree is idle the action is executed directly on the caller's
// goroutine. When the tree is busy the action is queued and another
// attempt is made.
func (tree *tree) checkAndRun(action action) {
	if tree.actions.Len() > 0 {
		if action != nil {
			tree.actions.Put(action)
		}
		// Only the caller that flips running from 0 to 1 may drain.
		if atomic.CompareAndSwapUint64(&tree.running, 0, 1) {
			var a interface{}
			var err error
			for tree.actions.Len() > 0 {
				a, err = tree.actions.Get()
				if err != nil {
					// NOTE(review): returns with running still set to 1 — confirm this is
					// only reachable after Dispose.
					return
				}
				tree.cache = append(tree.cache, a)
				if uint64(len(tree.cache)) >= tree.bufferSize {
					break
				}
			}

			go tree.operationRunner(tree.cache, true)
		}
	} else if action != nil {
		if atomic.CompareAndSwapUint64(&tree.running, 0, 1) {
			switch action.operation() {
			case get:
				ga := action.(*getAction)
				result := tree.search(ga.lookup)
				ga.result = result
				action.complete()
				tree.reset()
			case add, remove:
				// NOTE(review): keys() returns nil for every action type defined in
				// action.go, so this threshold appears never to be met and the serial
				// branch is always taken — confirm whether rects() was intended.
				if len(action.keys()) > multiThreadAt {
					tree.operationRunner(interfaces{action}, true)
				} else {
					tree.operationRunner(interfaces{action}, false)
				}
			}
		} else {
			tree.actions.Put(action)
			tree.checkAndRun(nil)
		}
	}
}

// init prepares an empty tree: an empty leaf root with a zero-valued
// bounding rectangle, a batch cache, and an action queue, the latter
// two sized by bufferSize.
func (tree *tree) init(bufferSize, ary uint64) {
	tree.bufferSize = bufferSize
	tree.ary = ary
	tree.cache = make([]interface{}, 0, bufferSize)
	tree.root = newNode(true, newKeys(ary), newNodes(ary))
	tree.root.mbr = &rectangle{}
	tree.actions = queue.NewRingBuffer(tree.bufferSize)
}

// operationRunner executes one batch: it resolves target nodes for all
// actions, applies the resulting adds and deletes, completes the
// mutating actions, and finally resets the tree so the next batch can
// start.
func (tree *tree) operationRunner(xns interfaces, threaded bool) {
	writeOperations, deleteOperations, toComplete := tree.fetchKeys(xns, threaded)
	tree.recursiveMutate(writeOperations, deleteOperations, false, threaded)
	for _, a := range toComplete {
		a.complete()
	}

	tree.reset()
}

// fetchKeys resolves the target node of every rectangle in the batch
// and groups the work per node. Gets are completed here; adds and
// removes are returned in toComplete so they can be completed after
// the mutation has been applied.
func (tree *tree) fetchKeys(xns interfaces, inParallel bool) (map[*node][]*keyBundle, map[*node][]*keyBundle, actions) {
	if inParallel {
		tree.fetchKeysInParallel(xns)
	} else {
		tree.fetchKeysInSerial(xns)
	}

	writeOperations := make(map[*node][]*keyBundle)
	deleteOperations := make(map[*node][]*keyBundle)
	toComplete := make(actions, 0, len(xns)/2)
	for _, ifc := range xns {
		action := ifc.(action)
		switch action.operation() {
		case add:
			for i, n := range action.nodes() {
				writeOperations[n] = append(writeOperations[n], &keyBundle{key: action.rects()[i].hilbert, left: action.rects()[i].rect})
			}
			toComplete = append(toComplete, action)
		case remove:
			for i, n := range action.nodes() {
				deleteOperations[n] = append(deleteOperations[n], &keyBundle{key: action.rects()[i].hilbert, left: action.rects()[i].rect})
			}
			toComplete = append(toComplete, action)
		case get:
			action.complete()
		}
	}

	return writeOperations, deleteOperations, toComplete
}

// fetchKeysInSerial resolves, on the calling goroutine, the target
// node for each rectangle of every add/remove action and runs the
// search for every get action.
func (tree *tree) fetchKeysInSerial(xns interfaces) {
	for _, ifc := range xns {
		action := ifc.(action)
		switch action.operation() {
		case add, remove:
			for i, key := range action.rects() {
				n := getParent(tree.root, key.hilbert, key.rect)
				action.addNode(int64(i), n)
			}
		case get:
			ga := action.(*getAction)
			rects := tree.search(ga.lookup)
			ga.result = rects
		}
	}
}

// reset clears the batch cache, marks the tree idle, and immediately
// checks whether more queued work is waiting.
func (tree *tree) reset() {
	// Nil the slots so the finished actions can be collected.
	for i := range tree.cache {
		tree.cache[i] = nil
	}

	tree.cache = tree.cache[:0]
	atomic.StoreUint64(&tree.running, 0)
	tree.checkAndRun(nil)
}

// fetchKeysInParallel does the same work as fetchKeysInSerial but
// spreads it over runtime.NumCPU()-1 goroutines (at least one).
// Workers share a cursor over the actions (forCache.i) and, per
// action, an atomic counter (forCache.js) handing out rectangle
// indices.
func (tree *tree) fetchKeysInParallel(xns []interface{}) {
	var forCache struct {
		i      int64
		buffer [8]uint64 // different cache lines
		js     []int64
	}

	// Counters start at -1 so the first atomic add yields index 0.
	for j := 0; j < len(xns); j++ {
		forCache.js = append(forCache.js, -1)
	}

	numCPU := runtime.NumCPU()
	if numCPU > 1 {
		numCPU--
	}
	var wg sync.WaitGroup
	wg.Add(numCPU)

	for k := 0; k < numCPU; k++ {
		go func() {
			for {
				index := atomic.LoadInt64(&forCache.i)
				if index >= int64(len(xns)) {
					break
				}
				action := xns[index].(action)

				j := atomic.AddInt64(&forCache.js[index], 1)
				if j > int64(len(action.rects())) { // someone else is updating i
					continue
				} else if j == int64(len(action.rects())) {
					// This worker drew the first index past the end, so it
					// is the one that advances the shared cursor.
					atomic.StoreInt64(&forCache.i, index+1)
					continue
				}

				switch action.operation() {
				case add, remove:
					hb := action.rects()[j]
					n := getParent(tree.root, hb.hilbert, hb.rect)
					action.addNode(j, n)
				case get:
					ga := action.(*getAction)
					result := tree.search(ga.lookup)
					ga.result = result
				}
			}
			wg.Done()
		}()
	}
	wg.Wait()
}

// splitNode splits n, if it needs it, as many times as required and
// appends each separating key to keys and each (left, right) pair to
// nodes. Both halves get parent as their parent.
func (tree *tree) splitNode(n, parent *node, nodes *[]*node, keys *hilberts) {
	if !n.needsSplit(tree.ary) {
		return
	}

	length := n.keys.len()
	splitAt := tree.ary - 1

	// Split from the right-hand end towards the left in steps of ary-1.
	for i := splitAt; i < length; i += splitAt {
		offset := length - i
		k, left, right := n.split(offset, tree.ary)
		left.right = right
		*keys = append(*keys, k)
		*nodes = append(*nodes, left, right)
		left.parent = parent
		right.parent = parent
	}
}

// applyNode applies the deletes and then the adds to n and keeps the
// tree's element count in step: one down per successful delete, one up
// per leaf insert that did not replace an existing rectangle.
func (tree *tree) applyNode(n *node, adds, deletes []*keyBundle) {
	for _, kb := range deletes {
		if n.keys.len() == 0 {
			break
		}

		deleted := n.delete(kb)
		if deleted != nil {
			// Adding ^uint64(0) is the atomic way to subtract one.
			atomic.AddUint64(&tree.number, ^uint64(0))
		}
	}

	for _, kb := range adds {
		old := n.insert(kb)
		if n.isLeaf && old == nil {
			atomic.AddUint64(&tree.number, 1)
		}
	}
}

// recursiveMutate applies one layer of adds/deletes and then recurses
// with the parent-level inserts produced by any splits, until a layer
// produces no further work. When a parentless node is touched a fresh
// dummy root is prepared so a split of the old root has a parent to
// insert into; on the following layer that parent becomes the root.
func (tree *tree) recursiveMutate(adds, deletes map[*node][]*keyBundle, setRoot, inParallel bool) {
	if len(adds) == 0 && len(deletes) == 0 {
		return
	}

	if setRoot && len(adds) > 1 {
		panic(`SHOULD ONLY HAVE ONE ROOT`)
	}

	// Collect every node that has work, without duplicates.
	ifs := make(interfaces, 0, len(adds))
	for n := range adds {
		if n.parent == nil {
			setRoot = true
		}
		ifs = append(ifs, n)
	}

	for n := range deletes {
		if n.parent == nil {
			setRoot = true
		}

		if _, ok := adds[n]; !ok {
			ifs = append(ifs, n)
		}
	}

	var dummyRoot *node
	if setRoot {
		dummyRoot = &node{
			keys:  newKeys(tree.ary),
			nodes: newNodes(tree.ary),
			mbr:   &rectangle{},
		}
	}

	// write guards nextLayerWrite, which the callbacks below may fill
	// concurrently.
	var write sync.Mutex
	nextLayerWrite := make(map[*node][]*keyBundle)
	nextLayerDelete := make(map[*node][]*keyBundle)

	var mutate func(interfaces, func(interface{}))
	if inParallel {
		mutate = executeInterfacesInParallel
	} else {
		mutate = executeInterfacesInSerial
	}

	mutate(ifs, func(ifc interface{}) {
		n := ifc.(*node)
		adds := adds[n]
		deletes := deletes[n]

		if len(adds) == 0 && len(deletes) == 0 {
			return
		}

		// NOTE(review): tree.root and setRoot are written here without holding
		// write; confirm this is safe when mutate runs callbacks in parallel.
		if setRoot {
			tree.root = n
		}

		parent := n.parent
		if parent == nil {
			parent = dummyRoot
			setRoot = true
		}

		tree.applyNode(n, adds, deletes)

		if n.needsSplit(tree.ary) {
			keys := make(hilberts, 0, n.keys.len())
			nodes := make([]*node, 0, n.nodes.len())
			tree.splitNode(n, parent, &nodes, &keys)
			write.Lock()
			// splitNode appends two nodes per key, hence the i*2 indexing.
			for i, k := range keys {
				nextLayerWrite[parent] = append(nextLayerWrite[parent], &keyBundle{key: k, left: nodes[i*2], right: nodes[i*2+1]})
			}
			write.Unlock()
		}
	})

	tree.recursiveMutate(nextLayerWrite, nextLayerDelete, setRoot, inParallel)
}

// Insert will add the provided keys to the tree.
// It blocks until the insertion has been applied.
func (tree *tree) Insert(rects ...rtree.Rectangle) {
	ia := newInsertAction(rects)
	tree.checkAndRun(ia)
	ia.completer.Wait()
}

// Delete will remove the provided keys from the tree. If no
// matching key is found, this is a no-op. It blocks until the
// removal has been applied.
func (tree *tree) Delete(rects ...rtree.Rectangle) {
	ra := newRemoveAction(rects)
	tree.checkAndRun(ra)
	ra.completer.Wait()
}

// search returns every stored rectangle that intersects r. It walks
// the tree breadth-first, using whs as the work queue: intersecting
// child nodes are expanded and anything that is not a *node is a
// result.
func (tree *tree) search(r *rectangle) rtree.Rectangles {
	if tree.root == nil {
		return rtree.Rectangles{}
	}

	result := make(rtree.Rectangles, 0, 10)
	whs := tree.root.searchRects(r)
	for len(whs) > 0 {
		wh := whs[0]
		if n, ok := wh.(*node); ok {
			whs = append(whs, n.searchRects(r)...)
		} else {
			result = append(result, wh)
		}
		whs = whs[1:]
	}

	return result
}

// Search will return a list of rectangles that intersect the provided
// rectangle. It blocks until the lookup has run.
func (tree *tree) Search(rect rtree.Rectangle) rtree.Rectangles {
	ga := newGetAction(rect)
	tree.checkAndRun(ga)
	ga.completer.Wait()
	return ga.result
}

// Len returns the number of items in the tree.
func (tree *tree) Len() uint64 {
	return atomic.LoadUint64(&tree.number)
}

// Dispose will clean up any resources used by this tree. This
// must be called to prevent a memory leak.
func (tree *tree) Dispose() {
	tree.actions.Dispose()
	atomic.StoreUint64(&tree.disposed, 1)
}

// newTree returns an initialized tree with the given action-buffer
// size and node arity.
func newTree(bufferSize, ary uint64) *tree {
	tree := &tree{}
	tree.init(bufferSize, ary)
	return tree
}

// New will construct a new Hilbert R-Tree and return it.
func New(bufferSize, ary uint64) rtree.RTree {
	return newTree(bufferSize, ary)
}

================================================
FILE: rtree/hilbert/tree_test.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ package hilbert import ( "log" "math" "math/rand" "os" "testing" "github.com/stretchr/testify/assert" "github.com/Workiva/go-datastructures/rtree" ) func getConsoleLogger() *log.Logger { return log.New(os.Stderr, "", log.LstdFlags) } func (n *node) print(log *log.Logger) { log.Printf(`NODE: %+v, MBR: %+v, %p`, n, n.mbr, n) if n.isLeaf { for i, wh := range n.nodes.list { xlow, ylow := wh.LowerLeft() xhigh, yhigh := wh.UpperRight() log.Printf(`KEY: %+v, XLOW: %+v, YLOW: %+v, XHIGH: %+v, YHIGH: %+v`, n.keys.list[i], xlow, ylow, xhigh, yhigh) } } else { for _, wh := range n.nodes.list { wh.(*node).print(log) } } } func (t *tree) print(log *log.Logger) { log.Println(`PRINTING TREE`) if t.root == nil { log.Println(`EMPTY TREE.`) return } t.root.print(log) } func constructMockPoints(num int) rtree.Rectangles { rects := make(rtree.Rectangles, 0, num) for i := int32(0); i < int32(num); i++ { rects = append(rects, newMockRectangle(i, i, i, i)) } return rects } func constructRandomMockPoints(num int) rtree.Rectangles { rects := make(rtree.Rectangles, 0, num) for i := 0; i < num; i++ { r := rand.Int31() rects = append(rects, newMockRectangle(r, r, r, r)) } return rects } func constructInfiniteRect() rtree.Rectangle { return newMockRectangle(0, 0, math.MaxInt32, math.MaxInt32) } func TestSimpleInsert(t *testing.T) { r1 := newMockRectangle(0, 0, 10, 10) tree := newTree(3, 3) tree.Insert(r1) assert.Equal(t, uint64(1), tree.Len()) q := newMockRectangle(5, 5, 15, 15) result := tree.Search(q) assert.Equal(t, rtree.Rectangles{r1}, result) } func TestSimpleDelete(t *testing.T) { r1 := newMockRectangle(0, 0, 10, 10) tree := newTree(3, 3) tree.Insert(r1) tree.Delete(r1) assert.Equal(t, uint64(0), tree.Len()) q := newMockRectangle(5, 5, 15, 15) result := tree.Search(q) assert.Len(t, result, 0) } func TestDeleteIdenticalHilbergNumber(t *testing.T) { r1 := newMockRectangle(0, 0, 20, 20) r2 := newMockRectangle(5, 5, 15, 15) tree := newTree(3, 3) tree.Insert(r1) tree.Delete(r2) 
assert.Equal(t, uint64(1), tree.Len()) result := tree.Search(r2) assert.Equal(t, rtree.Rectangles{r1}, result) tree.Delete(r1) assert.Equal(t, uint64(0), tree.Len()) result = tree.Search(r1) assert.Len(t, result, 0) } func TestDeleteAll(t *testing.T) { points := constructRandomMockPoints(3) tree := newTree(3, 3) tree.Insert(points...) assert.Equal(t, uint64(len(points)), tree.Len()) tree.Delete(points...) assert.Equal(t, uint64(0), tree.Len()) result := tree.Search(constructInfiniteRect()) assert.Len(t, result, 0) } func TestTwoInsert(t *testing.T) { r1 := newMockRectangle(0, 0, 10, 10) r2 := newMockRectangle(5, 5, 15, 15) tree := newTree(3, 3) tree.Insert(r1, r2) assert.Equal(t, uint64(2), tree.Len()) q := newMockRectangle(0, 0, 20, 20) result := tree.Search(q) assert.Equal(t, rtree.Rectangles{r1, r2}, result) q = newMockRectangle(0, 0, 4, 4) result = tree.Search(q) assert.Equal(t, rtree.Rectangles{r1}, result) q = newMockRectangle(11, 11, 20, 20) result = tree.Search(q) assert.Equal(t, rtree.Rectangles{r2}, result) } func TestInsertCausesRootSplitOddAry(t *testing.T) { r1 := newMockRectangle(0, 0, 10, 10) r2 := newMockRectangle(5, 5, 15, 15) r3 := newMockRectangle(10, 10, 20, 20) tree := newTree(3, 3) tree.Insert(r1, r2, r3) assert.Equal(t, uint64(3), tree.Len()) q := newMockRectangle(0, 0, 20, 20) result := tree.Search(q) assert.Contains(t, result, r1) assert.Contains(t, result, r2) assert.Contains(t, result, r3) } func TestInsertCausesRootSplitEvenAry(t *testing.T) { r1 := newMockRectangle(0, 0, 10, 10) r2 := newMockRectangle(5, 5, 15, 15) r3 := newMockRectangle(10, 10, 20, 20) r4 := newMockRectangle(15, 15, 25, 25) tree := newTree(4, 4) tree.Insert(r1, r2, r3, r4) assert.Equal(t, uint64(4), tree.Len()) q := newMockRectangle(0, 0, 25, 25) result := tree.Search(q) assert.Contains(t, result, r1) assert.Contains(t, result, r2) assert.Contains(t, result, r3) assert.Contains(t, result, r4) } func TestQueryWithLine(t *testing.T) { r1 := newMockRectangle(0, 0, 10, 10) 
r2 := newMockRectangle(5, 5, 15, 15) tree := newTree(3, 3) tree.Insert(r1, r2) // vertical line at x=5 q := newMockRectangle(5, 0, 5, 10) result := tree.Search(q) assert.Equal(t, rtree.Rectangles{r1, r2}, result) // horizontal line at y=5 q = newMockRectangle(0, 5, 10, 5) result = tree.Search(q) assert.Equal(t, rtree.Rectangles{r1, r2}, result) // vertical line at x=15 q = newMockRectangle(15, 0, 15, 20) result = tree.Search(q) assert.Equal(t, rtree.Rectangles{r2}, result) // horizontal line at y=15 q = newMockRectangle(0, 15, 20, 15) result = tree.Search(q) assert.Equal(t, rtree.Rectangles{r2}, result) // vertical line on the y-axis q = newMockRectangle(0, 0, 0, 10) result = tree.Search(q) assert.Equal(t, rtree.Rectangles{r1}, result) // horizontal line on the x-axis q = newMockRectangle(0, 0, 10, 0) result = tree.Search(q) assert.Equal(t, rtree.Rectangles{r1}, result) // vertical line at x=20 q = newMockRectangle(20, 0, 20, 20) result = tree.Search(q) assert.Equal(t, rtree.Rectangles{}, result) // horizontal line at y=20 q = newMockRectangle(0, 20, 20, 20) result = tree.Search(q) assert.Equal(t, rtree.Rectangles{}, result) } func TestQueryForPoint(t *testing.T) { r1 := newMockRectangle(5, 5, 5, 5) // (5, 5) r2 := newMockRectangle(10, 10, 10, 10) // (10, 10) tree := newTree(3, 3) tree.Insert(r1, r2) q := newMockRectangle(0, 0, 5, 5) result := tree.Search(q) assert.Equal(t, rtree.Rectangles{r1}, result) q = newMockRectangle(0, 0, 20, 20) result = tree.Search(q) assert.Contains(t, result, r1) assert.Contains(t, result, r2) q = newMockRectangle(6, 6, 20, 20) result = tree.Search(q) assert.Equal(t, rtree.Rectangles{r2}, result) q = newMockRectangle(20, 20, 30, 30) result = tree.Search(q) assert.Equal(t, rtree.Rectangles{}, result) } func TestMultipleInsertsCauseInternalSplitOddAry(t *testing.T) { points := constructMockPoints(100) tree := newTree(3, 3) tree.Insert(points...) 
assert.Equal(t, uint64(len(points)), tree.Len()) q := newMockRectangle(0, 0, int32(len(points)), int32(len(points))) result := tree.Search(q) succeeded := true for _, p := range points { if !assert.Contains(t, result, p) { succeeded = false } } if !succeeded { tree.print(getConsoleLogger()) } } func TestMultipleInsertsCauseInternalSplitOddAryRandomPoints(t *testing.T) { points := constructRandomMockPoints(100) tree := newTree(3, 3) tree.Insert(points...) assert.Equal(t, uint64(len(points)), tree.Len()) q := newMockRectangle(0, 0, math.MaxInt32, math.MaxInt32) result := tree.Search(q) succeeded := true for _, p := range points { if !assert.Contains(t, result, p) { succeeded = false } } if !succeeded { tree.print(getConsoleLogger()) } } func TestMultipleInsertsCauseInternalSplitEvenAry(t *testing.T) { points := constructMockPoints(100) tree := newTree(4, 4) tree.Insert(points...) assert.Equal(t, uint64(len(points)), tree.Len()) q := newMockRectangle(0, 0, math.MaxInt32, math.MaxInt32) result := tree.Search(q) succeeded := true for _, p := range points { if !assert.Contains(t, result, p) { succeeded = false } } if !succeeded { tree.print(getConsoleLogger()) } } func TestMultipleInsertsCauseInternalSplitEvenAryRandomOrder(t *testing.T) { points := constructRandomMockPoints(100) tree := newTree(4, 4) tree.Insert(points...) 
assert.Equal(t, uint64(len(points)), tree.Len()) q := newMockRectangle(0, 0, math.MaxInt32, math.MaxInt32) result := tree.Search(q) succeeded := true for _, p := range points { if !assert.Contains(t, result, p) { succeeded = false } } if !succeeded { tree.print(getConsoleLogger()) } } func TestInsertDuplicateHilbert(t *testing.T) { r1 := newMockRectangle(0, 0, 20, 20) r2 := newMockRectangle(1, 1, 19, 19) r3 := newMockRectangle(2, 2, 18, 18) r4 := newMockRectangle(3, 3, 17, 17) tree := newTree(3, 3) tree.Insert(r1) tree.Insert(r2) tree.Insert(r3) tree.Insert(r4) assert.Equal(t, uint64(4), tree.Len()) q := newMockRectangle(0, 0, 30, 30) result := tree.Search(q) assert.Len(t, result, 4) assert.Contains(t, result, r1) assert.Contains(t, result, r2) assert.Contains(t, result, r3) assert.Contains(t, result, r4) } func TestDeleteAllDuplicateHilbert(t *testing.T) { r1 := newMockRectangle(0, 0, 20, 20) r2 := newMockRectangle(1, 1, 19, 19) r3 := newMockRectangle(2, 2, 18, 18) r4 := newMockRectangle(3, 3, 17, 17) tree := newTree(3, 3) tree.Insert(r1) tree.Insert(r2) tree.Insert(r3) tree.Insert(r4) tree.Delete(r1, r2, r3, r4) assert.Equal(t, uint64(0), tree.Len()) result := tree.Search(constructInfiniteRect()) assert.Len(t, result, 0) } func TestInsertDuplicateRect(t *testing.T) { r1 := newMockRectangle(0, 0, 20, 20) r2 := newMockRectangle(0, 0, 20, 20) tree := newTree(3, 3) tree.Insert(r1) tree.Insert(r2) assert.Equal(t, uint64(1), tree.Len()) result := tree.Search(constructInfiniteRect()) assert.Equal(t, rtree.Rectangles{r2}, result) } func BenchmarkBulkAddPoints(b *testing.B) { numItems := 1000 points := constructMockPoints(numItems) b.ResetTimer() for i := 0; i < b.N; i++ { tree := newTree(8, 8) tree.Insert(points...) } } func BenchmarkBulkUpdatePoints(b *testing.B) { numItems := 1000 points := constructMockPoints(numItems) tree := newTree(8, 8) tree.Insert(points...) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Insert(points...) 
} } func BenchmarkPointInsertion(b *testing.B) { numItems := b.N points := constructMockPoints(numItems) tree := newTree(8, 8) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Insert(points[i%numItems]) } } func BenchmarkQueryPoints(b *testing.B) { numItems := b.N points := constructMockPoints(numItems) tree := newTree(8, 8) tree.Insert(points...) b.ResetTimer() for i := int32(0); i < int32(b.N); i++ { tree.Search(newMockRectangle(i, i, i+10, i+10)) } } func BenchmarkQueryBulkPoints(b *testing.B) { numItems := b.N points := constructMockPoints(numItems) tree := newTree(8, 8) tree.Insert(points...) b.ResetTimer() for i := int32(0); i < int32(b.N); i++ { tree.Search(newMockRectangle(i, i, int32(numItems), int32(numItems))) } } func BenchmarkDelete(b *testing.B) { numItems := b.N points := constructMockPoints(numItems) tree := newTree(8, 8) tree.Insert(points...) b.ResetTimer() for i := 0; i < b.N; i++ { tree.Delete(points[i%numItems]) } } ================================================ FILE: rtree/interface.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package rtree // Rectangles is a typed list of Rectangle. type Rectangles []Rectangle // Rectangle describes a two-dimensional bound. type Rectangle interface { // LowerLeft describes the lower left coordinate of this rectangle. LowerLeft() (int32, int32) // UpperRight describes the upper right coordinate of this rectangle. 
UpperRight() (int32, int32) } // RTree defines an object that can be returned from any subpackage // of this package. type RTree interface { // Search will perform an intersection search of the given // rectangle and return any rectangles that intersect. Search(Rectangle) Rectangles // Len returns in the number of items in the RTree. Len() uint64 // Dispose will clean up any objects used by the RTree. Dispose() // Delete will remove the provided rectangles from the RTree. Delete(...Rectangle) // Insert will add the provided rectangles to the RTree. Insert(...Rectangle) } ================================================ FILE: set/dict.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* Package set is a simple unordered set implemented with a map. This set is threadsafe which decreases performance. TODO: Actually write custom hashmap using the hash/fnv hasher. TODO: Our Set implementation Could be further optimized by getting the uintptr of the generic interface{} used and using that as the key; Golang maps handle uintptr much better than the generic interface{} key. */ package set import "sync" var pool = sync.Pool{} // Set is an implementation of ISet using the builtin map type. Set is threadsafe. type Set struct { items map[interface{}]struct{} lock sync.RWMutex flattened []interface{} } // Add will add the provided items to the set. 
func (set *Set) Add(items ...interface{}) { set.lock.Lock() defer set.lock.Unlock() set.flattened = nil for _, item := range items { set.items[item] = struct{}{} } } // Remove will remove the given items from the set. func (set *Set) Remove(items ...interface{}) { set.lock.Lock() defer set.lock.Unlock() set.flattened = nil for _, item := range items { delete(set.items, item) } } // Exists returns a bool indicating if the given item exists in the set. func (set *Set) Exists(item interface{}) bool { set.lock.RLock() _, ok := set.items[item] set.lock.RUnlock() return ok } // Flatten will return a list of the items in the set. func (set *Set) Flatten() []interface{} { set.lock.Lock() defer set.lock.Unlock() if set.flattened != nil { return set.flattened } set.flattened = make([]interface{}, 0, len(set.items)) for item := range set.items { set.flattened = append(set.flattened, item) } return set.flattened } // Len returns the number of items in the set. func (set *Set) Len() int64 { set.lock.RLock() size := int64(len(set.items)) set.lock.RUnlock() return size } // Clear will remove all items from the set. func (set *Set) Clear() { set.lock.Lock() set.items = map[interface{}]struct{}{} set.lock.Unlock() } // All returns a bool indicating if all of the supplied items exist in the set. func (set *Set) All(items ...interface{}) bool { set.lock.RLock() defer set.lock.RUnlock() for _, item := range items { if _, ok := set.items[item]; !ok { return false } } return true } // Dispose will add this set back into the pool. func (set *Set) Dispose() { set.lock.Lock() defer set.lock.Unlock() for k := range set.items { delete(set.items, k) } //this is so we don't hang onto any references for i := 0; i < len(set.flattened); i++ { set.flattened[i] = nil } set.flattened = set.flattened[:0] pool.Put(set) } // New is the constructor for sets. It will pull from a reuseable memory pool if it can. // Takes a list of items to initialize the set with. 
func New(items ...interface{}) *Set { set := pool.Get().(*Set) for _, item := range items { set.items[item] = struct{}{} } if len(items) > 0 { set.flattened = nil } return set } func init() { pool.New = func() interface{} { return &Set{ items: make(map[interface{}]struct{}, 10), } } } ================================================ FILE: set/dict_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package set import ( "reflect" "strconv" "testing" ) func TestAddDuplicateItem(t *testing.T) { set := New() set.Add(`test`) set.Add(`test`) if !reflect.DeepEqual([]interface{}{`test`}, set.Flatten()) { t.Errorf(`Incorrect result returned: %+v`, set.Flatten()) } } func TestAddItems(t *testing.T) { set := New() set.Add(`test`) set.Add(`test1`) firstSeen := false secondSeen := false // order is not guaranteed for _, item := range set.Flatten() { if item.(string) == `test` { firstSeen = true } else if item.(string) == `test1` { secondSeen = true } } if !firstSeen || !secondSeen { t.Errorf(`Not all items seen in set.`) } } func TestRemove(t *testing.T) { set := New() set.Add(`test`) set.Remove(`test`) if !reflect.DeepEqual([]interface{}{}, set.Flatten()) { t.Errorf(`Incorrect result returned: %+v`, set.Flatten()) } } func TestExists(t *testing.T) { set := New() set.Add(`test`) if !set.Exists(`test`) { t.Errorf(`Correct existence not determined`) } if set.Exists(`test1`) { t.Errorf(`Correct nonexistence not determined.`) } } func 
TestExists_WithNewItems(t *testing.T) { set := New(`test`, `test1`) if !set.Exists(`test`) { t.Errorf(`Correct existence not determined`) } if !set.Exists(`test1`) { t.Errorf(`Correct existence not determined`) } if set.Exists(`test2`) { t.Errorf(`Correct nonexistence not determined.`) } } func TestLen(t *testing.T) { set := New() set.Add(`test`) if set.Len() != 1 { t.Errorf(`Expected len: %d, received: %d`, 1, set.Len()) } set.Add(`test1`) if set.Len() != 2 { t.Errorf(`Expected len: %d, received: %d`, 2, set.Len()) } } func TestFlattenCaches(t *testing.T) { set := New() item := `test` set.Add(item) set.Flatten() if len(set.flattened) != 1 { t.Errorf(`Expected len: %d, received: %d`, 1, len(set.flattened)) } } func TestFlattenCaches_CacheReturn(t *testing.T) { set := New() item := `test` set.Add(item) flatten1 := set.Flatten() flatten2 := set.Flatten() if !reflect.DeepEqual(flatten1, flatten2) { t.Errorf(`Flatten cache is not the same as original result. Got %+v, expected %+v`, flatten2, flatten1) } } func TestAddClearsCache(t *testing.T) { set := New() item := `test` set.Add(item) set.Flatten() set.Add(item) if len(set.flattened) != 0 { t.Errorf(`Expected len: %d, received: %d`, 0, len(set.flattened)) } item = `test2` set.Add(item) if set.flattened != nil { t.Errorf(`Cache not cleared.`) } } func TestDeleteClearsCache(t *testing.T) { set := New() item := `test` set.Add(item) set.Flatten() set.Remove(item) if set.flattened != nil { t.Errorf(`Cache not cleared.`) } } func TestAll(t *testing.T) { set := New() item := `test` set.Add(item) result := set.All(item) if !result { t.Errorf(`Expected true.`) } itemTwo := `test1` result = set.All(item, itemTwo) if result { t.Errorf(`Expected false.`) } } func TestClear(t *testing.T) { set := New() set.Add(`test`) set.Clear() if set.Len() != 0 { t.Errorf(`Expected len: %d, received: %d`, 0, set.Len()) } } func BenchmarkFlatten(b *testing.B) { set := New() for i := 0; i < 50; i++ { item := strconv.Itoa(i) set.Add(item) } 
b.ResetTimer() for i := 0; i < b.N; i++ { set.Flatten() } } func BenchmarkLen(b *testing.B) { set := New() for i := 0; i < 50; i++ { item := strconv.Itoa(i) set.Add(item) } b.ResetTimer() for i := 0; i < b.N; i++ { set.Len() } } func BenchmarkExists(b *testing.B) { set := New() set.Add(1) b.ResetTimer() for i := 0; i < b.N; i++ { set.Exists(1) } } func BenchmarkClear(b *testing.B) { set := New() for i := 0; i < b.N; i++ { set.Clear() } } ================================================ FILE: slice/int64.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* Package Int64 simply adds an Int64-typed version of the standard library's sort/IntSlice implementation. Also added is an Insert method. */ package slice import "sort" // Int64Slice is a slice that fulfills the sort.Interface interface. type Int64Slice []int64 // Len returns the len of this slice. Required by sort.Interface. func (s Int64Slice) Len() int { return len(s) } // Less returns a bool indicating if the value at position i // is less than at position j. Required by sort.Interface. func (s Int64Slice) Less(i, j int) bool { return s[i] < s[j] } // Search will search this slice and return an index that corresponds // to the lowest position of that value. You'll need to check // separately if the value at that position is equal to x. The // behavior of this method is undefinited if the slice is not sorted. 
func (s Int64Slice) Search(x int64) int { return sort.Search(len(s), func(i int) bool { return s[i] >= x }) } // Sort will in-place sort this list of int64s. func (s Int64Slice) Sort() { sort.Sort(s) } // Swap will swap the elements at positions i and j. This is required // by sort.Interface. func (s Int64Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } // Exists returns a bool indicating if the provided value exists // in this list. This has undefined behavior if the list is not // sorted. func (s Int64Slice) Exists(x int64) bool { i := s.Search(x) if i == len(s) { return false } return s[i] == x } // Insert will insert x into the sorted position in this list // and return a list with the value added. If this slice has not // been sorted Insert's behavior is undefined. func (s Int64Slice) Insert(x int64) Int64Slice { i := s.Search(x) if i == len(s) { return append(s, x) } if s[i] == x { return s } s = append(s, 0) copy(s[i+1:], s[i:]) s[i] = x return s } ================================================ FILE: slice/int64_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package slice import ( "testing" "github.com/stretchr/testify/assert" ) func TestSort(t *testing.T) { s := Int64Slice{3, 6, 1, 0, -1} s.Sort() assert.Equal(t, Int64Slice{-1, 0, 1, 3, 6}, s) } func TestSearch(t *testing.T) { s := Int64Slice{1, 3, 6} assert.Equal(t, 1, s.Search(3)) assert.Equal(t, 1, s.Search(2)) assert.Equal(t, 3, s.Search(7)) } func TestExists(t *testing.T) { s := Int64Slice{1, 3, 6} assert.True(t, s.Exists(3)) assert.False(t, s.Exists(4)) } func TestInsert(t *testing.T) { s := Int64Slice{1, 3, 6} s = s.Insert(2) assert.Equal(t, Int64Slice{1, 2, 3, 6}, s) s = s.Insert(7) assert.Equal(t, Int64Slice{1, 2, 3, 6, 7}, s) } ================================================ FILE: slice/skip/interface.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package skip import "github.com/Workiva/go-datastructures/common" // Iterator defines an interface that allows a consumer to iterate // all results of a query. All values will be visited in-order. type Iterator interface { // Next returns a bool indicating if there is future value // in the iterator and moves the iterator to that value. Next() bool // Value returns a Comparator representing the iterator's current // position. If there is no value, this returns nil. Value() common.Comparator // exhaust is a helper method that will iterate this iterator // to completion and return a list of resulting Entries // in order. 
exhaust() common.Comparators } ================================================ FILE: slice/skip/iterator.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package skip import "github.com/Workiva/go-datastructures/common" const iteratorExhausted = -2 // iterator represents an object that can be iterated. It will // return false on Next and nil on Value if there are no further // values to be iterated. type iterator struct { first bool n *node } // Next returns a bool indicating if there are any further values // in this iterator. func (iter *iterator) Next() bool { if iter.first { iter.first = false return iter.n != nil } if iter.n == nil { return false } iter.n = iter.n.forward[0] return iter.n != nil } // Value returns a Comparator representing the iterator's present // position in the query. Returns nil if no values remain to iterate. func (iter *iterator) Value() common.Comparator { if iter.n == nil { return nil } return iter.n.entry } // exhaust is a helper method to exhaust this iterator and return // all remaining entries. func (iter *iterator) exhaust() common.Comparators { entries := make(common.Comparators, 0, 10) for i := iter; i.Next(); { entries = append(entries, i.Value()) } return entries } // nilIterator returns an iterator that will always return false // for Next and nil for Value. 
func nilIterator() *iterator { return &iterator{} } ================================================ FILE: slice/skip/iterator_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package skip import ( "testing" "github.com/stretchr/testify/assert" ) func TestIterate(t *testing.T) { e1 := newMockEntry(1) n1 := newNode(e1, 8) iter := &iterator{ n: n1, first: true, } assert.True(t, iter.Next()) assert.Equal(t, e1, iter.Value()) assert.False(t, iter.Next()) assert.Nil(t, iter.Value()) e2 := newMockEntry(2) n2 := newNode(e2, 8) n1.forward[0] = n2 iter = &iterator{ n: n1, first: true, } assert.True(t, iter.Next()) assert.Equal(t, e1, iter.Value()) assert.True(t, iter.Next()) assert.Equal(t, e2, iter.Value()) assert.False(t, iter.Next()) assert.Nil(t, iter.Value()) iter = nilIterator() assert.False(t, iter.Next()) assert.Nil(t, iter.Value()) } ================================================ FILE: slice/skip/mock_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ package skip import ( "github.com/stretchr/testify/mock" "github.com/Workiva/go-datastructures/common" ) type mockEntry uint64 func (me mockEntry) Compare(other common.Comparator) int { otherU := other.(mockEntry) if me == otherU { return 0 } if me > otherU { return 1 } return -1 } func newMockEntry(key uint64) mockEntry { return mockEntry(key) } type mockIterator struct { mock.Mock } func (mi *mockIterator) Next() bool { args := mi.Called() return args.Bool(0) } func (mi *mockIterator) Value() common.Comparator { args := mi.Called() result, ok := args.Get(0).(common.Comparator) if !ok { return nil } return result } func (mi *mockIterator) exhaust() common.Comparators { return nil } ================================================ FILE: slice/skip/node.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package skip import "github.com/Workiva/go-datastructures/common" type widths []uint64 type nodes []*node type node struct { // forward denotes the forward pointing pointers in this // node. forward nodes // widths keeps track of the distance between this pointer // and the forward pointers so we can access skip list // values by position in logarithmic time. widths widths // entry is the associated value with this node. 
entry common.Comparator } func (n *node) Compare(e common.Comparator) int { return n.entry.Compare(e) } // newNode will allocate and return a new node with the entry // provided. maxLevels will determine the length of the forward // pointer list associated with this node. func newNode(cmp common.Comparator, maxLevels uint8) *node { return &node{ entry: cmp, forward: make(nodes, maxLevels), widths: make(widths, maxLevels), } } ================================================ FILE: slice/skip/skip.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* Package skip defines a skiplist datastructure. That is, a data structure that probabilistically determines relationships between keys. By doing so, it becomes easier to program than a binary search tree but maintains similar speeds. Performance characteristics: Insert: O(log n) Search: O(log n) Delete: O(log n) Space: O(n) Recently added is the capability to address, insert, and replace an entry by position. This capability is achieved by saving the width of the "gap" between two nodes. Searching for an item by position is very similar to searching by value in that the same basic algorithm is used but we are searching for width instead of value. Because this avoids the overhead associated with Golang interfaces, operations by position are about twice as fast as operations by value. Time complexities listed below. 
// lockedSource is an implementation of rand.Source that is safe for
// concurrent use by multiple goroutines: every call to the wrapped source
// is made while holding mu. The code is modeled after
// https://golang.org/src/math/rand/rand.go.
type lockedSource struct {
	mu  sync.Mutex
	src rand.Source
}

// Int63 implements the rand.Source interface.
func (ls *lockedSource) Int63() int64 {
	ls.mu.Lock()
	// Deferred so the mutex is released even if the wrapped source panics;
	// otherwise every later caller would block forever.
	defer ls.mu.Unlock()
	return ls.src.Int63()
}

// Seed implements the rand.Source interface.
func (ls *lockedSource) Seed(seed int64) {
	ls.mu.Lock()
	defer ls.mu.Unlock()
	ls.src.Seed(seed)
}
It // is seeded with unix nanosecond when this line is executed at runtime, // and only executed once ensuring all random numbers come from the same // randomly seeded generator. var generator = rand.New(&lockedSource{src: rand.NewSource(time.Now().UnixNano())}) func generateLevel(maxLevel uint8) uint8 { var level uint8 for level = uint8(1); level < maxLevel-1; level++ { if generator.Float64() >= p { return level } } return level } func insertNode(sl *SkipList, n *node, cmp common.Comparator, pos uint64, cache nodes, posCache widths, allowDuplicate bool) common.Comparator { if !allowDuplicate && n != nil && n.Compare(cmp) == 0 { // a simple update in this case oldEntry := n.entry n.entry = cmp return oldEntry } atomic.AddUint64(&sl.num, 1) nodeLevel := generateLevel(sl.maxLevel) if nodeLevel > sl.level { for i := sl.level; i < nodeLevel; i++ { cache[i] = sl.head } sl.level = nodeLevel } nn := newNode(cmp, nodeLevel) for i := uint8(0); i < nodeLevel; i++ { nn.forward[i] = cache[i].forward[i] cache[i].forward[i] = nn formerWidth := cache[i].widths[i] if formerWidth == 0 { nn.widths[i] = 0 } else { nn.widths[i] = posCache[i] + formerWidth + 1 - pos } if cache[i].forward[i] != nil { cache[i].widths[i] = pos - posCache[i] } } for i := nodeLevel; i < sl.level; i++ { if cache[i].forward[i] == nil { continue } cache[i].widths[i]++ } return nil } func splitAt(sl *SkipList, index uint64) (*SkipList, *SkipList) { right := &SkipList{} right.maxLevel = sl.maxLevel right.level = sl.level right.cache = make(nodes, sl.maxLevel) right.posCache = make(widths, sl.maxLevel) right.head = newNode(nil, sl.maxLevel) sl.searchByPosition(index, sl.cache, sl.posCache) // populate the cache that needs updating for i := uint8(0); i <= sl.level; i++ { right.head.forward[i] = sl.cache[i].forward[i] if sl.cache[i].forward[i] != nil { right.head.widths[i] = sl.cache[i].widths[i] - (index - sl.posCache[i]) } sl.cache[i].widths[i] = 0 sl.cache[i].forward[i] = nil } right.num = sl.Len() - index // 
right is not in user's hands yet atomic.AddUint64(&sl.num, -right.num) sl.resetMaxLevel() right.resetMaxLevel() return sl, right } // Skip list is a datastructure that probabalistically determines // relationships between nodes. This results in a structure // that performs similarly to a BST but is much easier to build // from a programmatic perspective (no rotations). type SkipList struct { maxLevel, level uint8 head *node num uint64 // a list of nodes that can be reused, should reduce // the number of allocations in the insert/delete case. cache nodes posCache widths } // init will initialize this skiplist. The parameter is expected // to be of some uint type which will set this skiplist's maximum // level. func (sl *SkipList) init(ifc interface{}) { switch ifc.(type) { case uint8: sl.maxLevel = 8 case uint16: sl.maxLevel = 16 case uint32: sl.maxLevel = 32 case uint64, uint: sl.maxLevel = 64 } sl.cache = make(nodes, sl.maxLevel) sl.posCache = make(widths, sl.maxLevel) sl.head = newNode(nil, sl.maxLevel) } func (sl *SkipList) search(cmp common.Comparator, update nodes, widths widths) (*node, uint64) { if sl.Len() == 0 { // nothing in the list return nil, 1 } var pos uint64 = 0 var offset uint8 var alreadyChecked *node n := sl.head for i := uint8(0); i <= sl.level; i++ { offset = sl.level - i for n.forward[offset] != nil && n.forward[offset] != alreadyChecked && n.forward[offset].Compare(cmp) < 0 { pos += n.widths[offset] n = n.forward[offset] } alreadyChecked = n if update != nil { update[offset] = n widths[offset] = pos } } return n.forward[0], pos + 1 } func (sl *SkipList) resetMaxLevel() { if sl.level < 1 { sl.level = 1 return } for sl.head.forward[sl.level-1] == nil && sl.level > 1 { sl.level-- } } func (sl *SkipList) searchByPosition(position uint64, update nodes, widths widths) (*node, uint64) { if sl.Len() == 0 { // nothing in the list return nil, 1 } if position > sl.Len() { return nil, 1 } var pos uint64 = 0 var offset uint8 n := sl.head for i := 
uint8(0); i <= sl.level; i++ { offset = sl.level - i for n.forward[offset] != nil && pos+n.widths[offset] <= position { pos += n.widths[offset] n = n.forward[offset] } if update != nil { update[offset] = n widths[offset] = pos } } return n, pos + 1 } // Get will retrieve values associated with the keys provided. If an // associated value could not be found, a nil is returned in its place. // This is an O(log n) operation. func (sl *SkipList) Get(comparators ...common.Comparator) common.Comparators { result := make(common.Comparators, 0, len(comparators)) var n *node for _, cmp := range comparators { n, _ = sl.search(cmp, nil, nil) if n != nil && n.Compare(cmp) == 0 { result = append(result, n.entry) } else { result = append(result, nil) } } return result } // GetWithPosition will retrieve the value with the provided key and // return the position of that value within the list. Returns nil, 0 // if an associated value could not be found. func (sl *SkipList) GetWithPosition(cmp common.Comparator) (common.Comparator, uint64) { n, pos := sl.search(cmp, nil, nil) if n == nil { return nil, 0 } return n.entry, pos - 1 } // ByPosition returns the Comparator at the given position. func (sl *SkipList) ByPosition(position uint64) common.Comparator { n, _ := sl.searchByPosition(position+1, nil, nil) if n == nil { return nil } return n.entry } func (sl *SkipList) insert(cmp common.Comparator) common.Comparator { n, pos := sl.search(cmp, sl.cache, sl.posCache) return insertNode(sl, n, cmp, pos, sl.cache, sl.posCache, false) } // Insert will insert the provided comparators into the list. Returned // is a list of comparators that were overwritten. This is expected to // be an O(log n) operation. 
func (sl *SkipList) Insert(comparators ...common.Comparator) common.Comparators { overwritten := make(common.Comparators, 0, len(comparators)) for _, cmp := range comparators { overwritten = append(overwritten, sl.insert(cmp)) } return overwritten } func (sl *SkipList) insertAtPosition(position uint64, cmp common.Comparator) { if position > sl.Len() { position = sl.Len() } n, pos := sl.searchByPosition(position, sl.cache, sl.posCache) insertNode(sl, n, cmp, pos, sl.cache, sl.posCache, true) } // InsertAtPosition will insert the provided Comparator at the provided position. // If position is greater than the length of the skiplist, the Comparator // is appended. This method bypasses order checks and checks for // duplicates so use with caution. func (sl *SkipList) InsertAtPosition(position uint64, cmp common.Comparator) { sl.insertAtPosition(position, cmp) } func (sl *SkipList) replaceAtPosition(position uint64, cmp common.Comparator) { n, _ := sl.searchByPosition(position+1, nil, nil) if n == nil { return } n.entry = cmp } // Replace at position will replace the Comparator at the provided position // with the provided Comparator. If the provided position does not exist, // this operation is a no-op. 
func (sl *SkipList) ReplaceAtPosition(position uint64, cmp common.Comparator) { sl.replaceAtPosition(position, cmp) } func (sl *SkipList) delete(cmp common.Comparator) common.Comparator { n, _ := sl.search(cmp, sl.cache, sl.posCache) if n == nil || n.Compare(cmp) != 0 { return nil } atomic.AddUint64(&sl.num, ^uint64(0)) // decrement for i := uint8(0); i <= sl.level; i++ { if sl.cache[i].forward[i] != n { if sl.cache[i].forward[i] != nil { sl.cache[i].widths[i]-- } continue } sl.cache[i].widths[i] += n.widths[i] - 1 sl.cache[i].forward[i] = n.forward[i] } for sl.level > 1 && sl.head.forward[sl.level-1] == nil { sl.head.widths[sl.level] = 0 sl.level-- } return n.entry } // Delete will remove the provided keys from the skiplist and return // a list of in-order Comparators that were deleted. This is a no-op if // an associated key could not be found. This is an O(log n) operation. func (sl *SkipList) Delete(comparators ...common.Comparator) common.Comparators { deleted := make(common.Comparators, 0, len(comparators)) for _, cmp := range comparators { deleted = append(deleted, sl.delete(cmp)) } return deleted } // Len returns the number of items in this skiplist. func (sl *SkipList) Len() uint64 { return atomic.LoadUint64(&sl.num) } func (sl *SkipList) iterAtPosition(pos uint64) *iterator { n, _ := sl.searchByPosition(pos, nil, nil) if n == nil || n.entry == nil { return nilIterator() } return &iterator{ first: true, n: n, } } // IterAtPosition is the sister method to Iter only the user defines // a position in the skiplist to begin iteration instead of a value. func (sl *SkipList) IterAtPosition(pos uint64) Iterator { return sl.iterAtPosition(pos + 1) } func (sl *SkipList) iter(cmp common.Comparator) *iterator { n, _ := sl.search(cmp, nil, nil) if n == nil { return nilIterator() } return &iterator{ first: true, n: n, } } // Iter will return an iterator that can be used to iterate // over all the values with a key equal to or greater than // the key provided. 
// SplitAt will split the current skiplist into two lists. The first
// skiplist returned is the "left" list and the second is the "right."
// The index defines the last item in the left list. If index is at or
// beyond the last position of this list, only the left skiplist is
// returned and the right will be nil. This is a mutable operation and
// modifies the content of this list.
func (sl *SkipList) SplitAt(index uint64) (*SkipList, *SkipList) {
	length := sl.Len()
	// Compare against length-1 instead of incrementing index first:
	// index+1 wraps around to 0 for the largest uint64, which would send
	// every item to the right list instead of leaving the list whole.
	if length == 0 || index >= length-1 {
		return sl, nil
	}
	return splitAt(sl, index+1) // +1 converts the 0-based index
}
*/ package skip import ( "math/rand" "testing" "github.com/stretchr/testify/assert" "github.com/Workiva/go-datastructures/common" ) func generateMockEntries(num int) common.Comparators { entries := make(common.Comparators, 0, num) for i := uint64(0); i < uint64(num); i++ { entries = append(entries, newMockEntry(i)) } return entries } func generateRandomMockEntries(num int) common.Comparators { entries := make(common.Comparators, 0, num) for i := 0; i < num; i++ { entries = append(entries, newMockEntry(uint64(rand.Int()))) } return entries } func TestInsertByPosition(t *testing.T) { m1 := newMockEntry(5) m2 := newMockEntry(6) m3 := newMockEntry(2) sl := New(uint8(0)) sl.InsertAtPosition(2, m1) sl.InsertAtPosition(0, m2) sl.InsertAtPosition(0, m3) assert.Equal(t, m3, sl.ByPosition(0)) assert.Equal(t, m2, sl.ByPosition(1)) assert.Equal(t, m1, sl.ByPosition(2)) assert.Nil(t, sl.ByPosition(3)) } func TestGetByPosition(t *testing.T) { m1 := newMockEntry(5) m2 := newMockEntry(6) sl := New(uint8(0)) sl.Insert(m1, m2) assert.Equal(t, m1, sl.ByPosition(0)) assert.Equal(t, m2, sl.ByPosition(1)) assert.Nil(t, sl.ByPosition(2)) } func TestSplitAt(t *testing.T) { m1 := newMockEntry(3) m2 := newMockEntry(5) m3 := newMockEntry(7) sl := New(uint8(0)) sl.Insert(m1, m2, m3) left, right := sl.SplitAt(1) assert.Equal(t, uint64(2), left.Len()) assert.Equal(t, uint64(1), right.Len()) assert.Equal(t, common.Comparators{m1, m2, nil}, left.Get(m1, m2, m3)) assert.Equal(t, common.Comparators{nil, nil, m3}, right.Get(m1, m2, m3)) assert.Equal(t, m1, left.ByPosition(0)) assert.Equal(t, m2, left.ByPosition(1)) assert.Equal(t, m3, right.ByPosition(0)) assert.Equal(t, nil, left.ByPosition(2)) assert.Equal(t, nil, right.ByPosition(1)) } func TestSplitLargeSkipList(t *testing.T) { entries := generateMockEntries(100) leftEntries := entries[:50] rightEntries := entries[50:] sl := New(uint64(0)) sl.Insert(entries...) 
left, right := sl.SplitAt(49) assert.Equal(t, uint64(50), left.Len()) for _, le := range leftEntries { result, index := left.GetWithPosition(le) assert.Equal(t, le, result) assert.Equal(t, le, left.ByPosition(index)) } assert.Equal(t, uint64(50), right.Len()) for _, re := range rightEntries { result, index := right.GetWithPosition(re) assert.Equal(t, re, result) assert.Equal(t, re, right.ByPosition(index)) } } func TestSplitLargeSkipListOddNumber(t *testing.T) { entries := generateMockEntries(99) leftEntries := entries[:50] rightEntries := entries[50:] sl := New(uint64(0)) sl.Insert(entries...) left, right := sl.SplitAt(49) assert.Equal(t, uint64(50), left.Len()) for _, le := range leftEntries { result, index := left.GetWithPosition(le) assert.Equal(t, le, result) assert.Equal(t, le, left.ByPosition(index)) } assert.Equal(t, uint64(49), right.Len()) for _, re := range rightEntries { result, index := right.GetWithPosition(re) assert.Equal(t, re, result) assert.Equal(t, re, right.ByPosition(index)) } } func TestSplitAtSkipListLength(t *testing.T) { entries := generateMockEntries(5) sl := New(uint64(0)) sl.Insert(entries...) left, right := sl.SplitAt(4) assert.Equal(t, sl, left) assert.Nil(t, right) } func TestGetWithPosition(t *testing.T) { m1 := newMockEntry(5) m2 := newMockEntry(6) sl := New(uint8(0)) sl.Insert(m1, m2) e, pos := sl.GetWithPosition(m1) assert.Equal(t, m1, e) assert.Equal(t, uint64(0), pos) e, pos = sl.GetWithPosition(m2) assert.Equal(t, m2, e) assert.Equal(t, uint64(1), pos) } func TestReplaceAtPosition(t *testing.T) { m1 := newMockEntry(5) m2 := newMockEntry(6) sl := New(uint8(0)) sl.Insert(m1, m2) m3 := newMockEntry(9) sl.ReplaceAtPosition(0, m3) assert.Equal(t, m3, sl.ByPosition(0)) assert.Equal(t, m2, sl.ByPosition(1)) } func TestInsertRandomGetByPosition(t *testing.T) { entries := generateRandomMockEntries(100) sl := New(uint64(0)) sl.Insert(entries...) 
for _, e := range entries { _, pos := sl.GetWithPosition(e) assert.Equal(t, e, sl.ByPosition(pos)) } } func TestGetManyByPosition(t *testing.T) { entries := generateMockEntries(10) sl := New(uint64(0)) sl.Insert(entries...) for i, e := range entries { assert.Equal(t, e, sl.ByPosition(uint64(i))) } } func TestGetPositionAfterDelete(t *testing.T) { m1 := newMockEntry(5) m2 := newMockEntry(6) sl := New(uint8(0)) sl.Insert(m1, m2) sl.Delete(m1) assert.Equal(t, m2, sl.ByPosition(0)) assert.Nil(t, sl.ByPosition(1)) sl.Delete(m2) assert.Nil(t, sl.ByPosition(0)) assert.Nil(t, sl.ByPosition(1)) } func TestGetPositionBulkDelete(t *testing.T) { es := generateMockEntries(20) e1 := es[:10] e2 := es[10:] sl := New(uint64(0)) sl.Insert(e1...) sl.Insert(e2...) for _, e := range e1 { sl.Delete(e) } for i, e := range e2 { assert.Equal(t, e, sl.ByPosition(uint64(i))) } } func TestSimpleInsert(t *testing.T) { m1 := newMockEntry(5) m2 := newMockEntry(6) sl := New(uint8(0)) overwritten := sl.Insert(m1) assert.Equal(t, common.Comparators{m1}, sl.Get(m1)) assert.Equal(t, uint64(1), sl.Len()) assert.Equal(t, common.Comparators{nil}, overwritten) assert.Equal(t, common.Comparators{nil}, sl.Get(mockEntry(1))) overwritten = sl.Insert(m2) assert.Equal(t, common.Comparators{m2}, sl.Get(m2)) assert.Equal(t, common.Comparators{nil}, sl.Get(mockEntry(7))) assert.Equal(t, uint64(2), sl.Len()) assert.Equal(t, common.Comparators{nil}, overwritten) } func TestSimpleOverwrite(t *testing.T) { m1 := newMockEntry(5) m2 := newMockEntry(5) sl := New(uint8(0)) overwritten := sl.Insert(m1) assert.Equal(t, common.Comparators{nil}, overwritten) assert.Equal(t, uint64(1), sl.Len()) overwritten = sl.Insert(m2) assert.Equal(t, common.Comparators{m1}, overwritten) assert.Equal(t, uint64(1), sl.Len()) } func TestInsertOutOfOrder(t *testing.T) { m1 := newMockEntry(6) m2 := newMockEntry(5) sl := New(uint8(0)) overwritten := sl.Insert(m1, m2) assert.Equal(t, common.Comparators{nil, nil}, overwritten) assert.Equal(t, 
common.Comparators{m1, m2}, sl.Get(m1, m2)) } func TestSimpleDelete(t *testing.T) { m1 := newMockEntry(5) sl := New(uint8(0)) sl.Insert(m1) deleted := sl.Delete(m1) assert.Equal(t, common.Comparators{m1}, deleted) assert.Equal(t, uint64(0), sl.Len()) assert.Equal(t, common.Comparators{nil}, sl.Get(m1)) deleted = sl.Delete(m1) assert.Equal(t, common.Comparators{nil}, deleted) } func TestDeleteAll(t *testing.T) { m1 := newMockEntry(5) m2 := newMockEntry(6) sl := New(uint8(0)) sl.Insert(m1, m2) deleted := sl.Delete(m1, m2) assert.Equal(t, common.Comparators{m1, m2}, deleted) assert.Equal(t, uint64(0), sl.Len()) assert.Equal(t, common.Comparators{nil, nil}, sl.Get(m1, m2)) } func TestIter(t *testing.T) { sl := New(uint8(0)) m1 := newMockEntry(5) m2 := newMockEntry(10) sl.Insert(m1, m2) iter := sl.Iter(mockEntry(0)) assert.Equal(t, common.Comparators{m1, m2}, iter.exhaust()) iter = sl.Iter(mockEntry(5)) assert.Equal(t, common.Comparators{m1, m2}, iter.exhaust()) iter = sl.Iter(mockEntry(6)) assert.Equal(t, common.Comparators{m2}, iter.exhaust()) iter = sl.Iter(mockEntry(10)) assert.Equal(t, common.Comparators{m2}, iter.exhaust()) iter = sl.Iter(mockEntry(11)) assert.Equal(t, common.Comparators{}, iter.exhaust()) } func TestIterAtPosition(t *testing.T) { sl := New(uint8(0)) m1 := newMockEntry(5) m2 := newMockEntry(10) sl.Insert(m1, m2) iter := sl.IterAtPosition(0) assert.Equal(t, common.Comparators{m1, m2}, iter.exhaust()) iter = sl.IterAtPosition(1) assert.Equal(t, common.Comparators{m2}, iter.exhaust()) iter = sl.IterAtPosition(2) assert.Equal(t, common.Comparators{}, iter.exhaust()) } func BenchmarkInsert(b *testing.B) { numItems := b.N sl := New(uint64(0)) entries := generateMockEntries(numItems) b.ResetTimer() for i := 0; i < b.N; i++ { sl.Insert(entries[i%numItems]) } } func BenchmarkGet(b *testing.B) { numItems := b.N sl := New(uint64(0)) entries := generateMockEntries(numItems) sl.Insert(entries...) 
b.ResetTimer() for i := 0; i < b.N; i++ { sl.Get(entries[i%numItems]) } } func BenchmarkDelete(b *testing.B) { numItems := b.N sl := New(uint64(0)) entries := generateMockEntries(numItems) sl.Insert(entries...) b.ResetTimer() for i := 0; i < b.N; i++ { sl.Delete(entries[i]) } } func BenchmarkPrepend(b *testing.B) { numItems := b.N sl := New(uint64(0)) entries := make(common.Comparators, 0, numItems) for i := b.N; i < b.N+numItems; i++ { entries = append(entries, newMockEntry(uint64(i))) } sl.Insert(entries...) b.ResetTimer() for i := 0; i < b.N; i++ { sl.Insert(newMockEntry(uint64(i))) } } func BenchmarkByPosition(b *testing.B) { numItems := b.N sl := New(uint64(0)) entries := generateMockEntries(numItems) sl.Insert(entries...) b.ResetTimer() for i := 0; i < b.N; i++ { sl.ByPosition(uint64(i % numItems)) } } func BenchmarkInsertAtPosition(b *testing.B) { numItems := b.N sl := New(uint64(0)) entries := generateRandomMockEntries(numItems) b.ResetTimer() for i := 0; i < b.N; i++ { sl.InsertAtPosition(0, entries[i%numItems]) } } ================================================ FILE: sort/interface.go ================================================ package merge // Comparators defines a typed list of type Comparator. type Comparators []Comparator // Less returns a bool indicating if the comparator at index i // is less than the comparator at index j. func (c Comparators) Less(i, j int) bool { return c[i].Compare(c[j]) < 0 } // Len returns an int indicating the length of this list // of comparators. func (c Comparators) Len() int { return len(c) } // Swap swaps the values at positions i and j. func (c Comparators) Swap(i, j int) { c[j], c[i] = c[i], c[j] } // Comparator defines items that can be sorted. It contains // a single method allowing the compare logic to compare one // comparator to another. type Comparator interface { // Compare will return a value indicating how this comparator // compares with the provided comparator. 
A negative number // indicates this comparator is less than the provided comparator, // a 0 indicates equality, and a positive number indicates this // comparator is greater than the provided comparator. Compare(Comparator) int } ================================================ FILE: sort/sort.go ================================================ package merge import ( "runtime" "sort" "sync" ) func sortBucket(comparators Comparators) { sort.Sort(comparators) } func copyChunk(chunk []Comparators) []Comparators { cp := make([]Comparators, len(chunk)) copy(cp, chunk) return cp } // MultithreadedSortComparators will take a list of comparators // and sort it using as many threads as are available. The list // is split into buckets for a bucket sort and then recursively // merged using SymMerge. func MultithreadedSortComparators(comparators Comparators) Comparators { toBeSorted := make(Comparators, len(comparators)) copy(toBeSorted, comparators) var wg sync.WaitGroup numCPU := int64(runtime.NumCPU()) if numCPU == 1 { // single core machine numCPU = 2 } else { // otherwise this algo only works with a power of two numCPU = int64(prevPowerOfTwo(uint64(numCPU))) } chunks := chunk(toBeSorted, numCPU) wg.Add(len(chunks)) for i := 0; i < len(chunks); i++ { go func(i int) { sortBucket(chunks[i]) wg.Done() }(i) } wg.Wait() todo := make([]Comparators, len(chunks)/2) for { todo = todo[:len(chunks)/2] wg.Add(len(chunks) / 2) for i := 0; i < len(chunks); i += 2 { go func(i int) { todo[i/2] = SymMerge(chunks[i], chunks[i+1]) wg.Done() }(i) } wg.Wait() chunks = copyChunk(todo) if len(chunks) == 1 { break } } return chunks[0] } func chunk(comparators Comparators, numParts int64) []Comparators { parts := make([]Comparators, numParts) for i := int64(0); i < numParts; i++ { parts[i] = comparators[i*int64(len(comparators))/numParts : (i+1)*int64(len(comparators))/numParts] } return parts } func prevPowerOfTwo(x uint64) uint64 { x = x | (x >> 1) x = x | (x >> 2) x = x | (x >> 4) x = x | (x >> 
8) x = x | (x >> 16) x = x | (x >> 32) return x - (x >> 1) } ================================================ FILE: sort/sort_test.go ================================================ package merge import ( "testing" "github.com/stretchr/testify/assert" ) func TestMultiThreadedSortEvenNumber(t *testing.T) { comparators := constructOrderedMockComparators(10) comparators = reverseComparators(comparators) result := MultithreadedSortComparators(comparators) comparators = reverseComparators(comparators) assert.Equal(t, comparators, result) } func TestMultiThreadedSortOddNumber(t *testing.T) { comparators := constructOrderedMockComparators(9) comparators = reverseComparators(comparators) result := MultithreadedSortComparators(comparators) comparators = reverseComparators(comparators) assert.Equal(t, comparators, result) } func BenchmarkMultiThreadedSort(b *testing.B) { numCells := 100000 comparators := constructOrderedMockComparators(numCells) comparators = reverseComparators(comparators) b.ResetTimer() for i := 0; i < b.N; i++ { MultithreadedSortComparators(comparators) } } ================================================ FILE: sort/symmerge.go ================================================ package merge import ( "math" "sync" ) // symSearch is like symBinarySearch but operates // on two sorted lists instead of a sorted list and an index. // It's duplication of code but you buy performance. func symSearch(u, w Comparators) int { start, stop, p := 0, len(u), len(w)-1 for start < stop { mid := (start + stop) / 2 if u[mid].Compare(w[p-mid]) <= 0 { start = mid + 1 } else { stop = mid } } return start } // swap will swap positions of the two lists from index // to the end of the list. It expects that these lists // are the same size or one different. func swap(u, w Comparators, index int) { for i := index; i < len(u); i++ { u[i], w[i-index] = w[i-index], u[i] } } // decomposeForSymMerge pulls an active site out of the list // of length in size. 
W becomes the active site for future sym // merges and v1, v2 are decomposed and split among the other // list to be merged and w. func decomposeForSymMerge(length int, comparators Comparators) (v1 Comparators, w Comparators, v2 Comparators) { if length >= len(comparators) { panic(`INCORRECT PARAMS FOR SYM MERGE.`) } overhang := (len(comparators) - length) / 2 v1 = comparators[:overhang] w = comparators[overhang : overhang+length] v2 = comparators[overhang+length:] return } // symBinarySearch will perform a binary search between the provided // indices and find the index at which a rotation should occur. func symBinarySearch(u Comparators, start, stop, total int) int { for start < stop { mid := (start + stop) / 2 if u[mid].Compare(u[total-mid]) <= 0 { start = mid + 1 } else { stop = mid } } return start } // symSwap will perform a rotation or swap between the provided // indices. Again, there is duplication here with swap, but // we are buying performance. func symSwap(u Comparators, start1, start2, end int) { for i := 0; i < end; i++ { u[start1+i], u[start2+i] = u[start2+i], u[start1+i] } } // symRotate determines the indices to use in a symSwap and // performs the swap. func symRotate(u Comparators, start1, start2, end int) { i := start2 - start1 if i == 0 { return } j := end - start2 if j == 0 { return } if i == j { symSwap(u, start1, start2, i) return } p := start1 + i for i != j { if i > j { symSwap(u, p-i, p, j) i -= j } else { symSwap(u, p-i, p+j-i, i) j -= i } } symSwap(u, p-i, p, i) } // symMerge is the recursive and internal form of SymMerge. 
func symMerge(u Comparators, start1, start2, last int) { if start1 < start2 && start2 < last { mid := (start1 + last) / 2 n := mid + start2 var start int if start2 > mid { start = symBinarySearch(u, n-last, mid, n-1) } else { start = symBinarySearch(u, start1, start2, n-1) } end := n - start symRotate(u, start, start2, end) symMerge(u, start1, start, mid) symMerge(u, mid, end, last) } } // SymMerge will perform a symmetrical merge of the two provided // lists. It is expected that these lists are pre-sorted. Failure // to do so will result in undefined behavior. This function does // make use of goroutines, so multithreading can aid merge time. // This makes M*log(N/M+1) comparisons where M is the length // of the shorter list and N is the length of the longer list. func SymMerge(u, w Comparators) Comparators { lenU, lenW := len(u), len(w) if lenU == 0 { return w } if lenW == 0 { return u } diff := lenU - lenW if math.Abs(float64(diff)) > 1 { u1, w1, u2, w2 := prepareForSymMerge(u, w) lenU1 := len(u1) lenU2 := len(u2) u = append(u1, w1...) w = append(u2, w2...) var wg sync.WaitGroup wg.Add(2) go func() { symMerge(u, 0, lenU1, len(u)) wg.Done() }() go func() { symMerge(w, 0, lenU2, len(w)) wg.Done() }() wg.Wait() u = append(u, w...) return u } u = append(u, w...) symMerge(u, 0, lenU, len(u)) return u } // prepareForSymMerge performs a symmetrical decomposition on two // lists of different sizes. It breaks apart the longer list into // an active site (equal to the size of the shorter list) and performs // a symmetrical rotation with the active site and the shorter list. // The two stubs are then split between the active site and shorter list // ensuring two equally sized lists where every value in u' is less // than w'. func prepareForSymMerge(u, w Comparators) (u1, w1, u2, w2 Comparators) { if u.Len() > w.Len() { u, w = w, u } v1, w, v2 := decomposeForSymMerge(len(u), w) i := symSearch(u, w) u1 = make(Comparators, i) copy(u1, u[:i]) w1 = append(v1, w[:len(w)-i]...) 
u2 = make(Comparators, len(u)-i) copy(u2, u[i:]) w2 = append(w[len(w)-i:], v2...) return } ================================================ FILE: sort/symmerge_test.go ================================================ package merge import ( "testing" "github.com/stretchr/testify/assert" ) type mockComparator int func (mc mockComparator) Compare(other Comparator) int { if mc == other.(mockComparator) { return 0 } if mc > other.(mockComparator) { return 1 } return -1 } func constructMockComparators(values ...int) Comparators { comparators := make(Comparators, 0, len(values)) for _, v := range values { comparators = append(comparators, mockComparator(v)) } return comparators } func constructOrderedMockComparators(upTo int) Comparators { comparators := make(Comparators, 0, upTo) for i := 0; i < upTo; i++ { comparators = append(comparators, mockComparator(i)) } return comparators } func reverseComparators(comparators Comparators) Comparators { for i := 0; i < len(comparators); i++ { li := len(comparators) - i - 1 comparators[i], comparators[li] = comparators[li], comparators[i] } return comparators } func TestDecomposeForSymMergeOddNumber(t *testing.T) { comparators := constructOrderedMockComparators(7) v1, w, v2 := decomposeForSymMerge(3, comparators) assert.Equal(t, comparators[:2], v1) assert.Equal(t, comparators[2:5], w) assert.Equal(t, comparators[5:], v2) } func TestDecomposeForSymMergeEvenNumber(t *testing.T) { comparators := constructOrderedMockComparators(8) v1, w, v2 := decomposeForSymMerge(5, comparators) assert.Equal(t, comparators[:1], v1) assert.Equal(t, comparators[1:6], w) assert.Equal(t, comparators[6:], v2) } func TestNearCompleteDecomposeForSymMerge(t *testing.T) { comparators := constructOrderedMockComparators(8) v1, w, v2 := decomposeForSymMerge(7, comparators) assert.Len(t, v1, 0) assert.Equal(t, comparators[:7], w) assert.Equal(t, comparators[7:], v2) } func TestDecomposePanicsWithWrongLength(t *testing.T) { comparators := 
constructOrderedMockComparators(8) assert.Panics(t, func() { decomposeForSymMerge(8, comparators) }) } func TestSymSearch(t *testing.T) { u := constructMockComparators(1, 5, 7, 9) w := constructMockComparators(1, 3, 9, 10) result := symSearch(u, w) assert.Equal(t, 2, result) u = constructMockComparators(1, 5, 7) w = constructMockComparators(1, 3, 9) result = symSearch(u, w) assert.Equal(t, 1, result) } func TestSwap(t *testing.T) { u := constructMockComparators(1, 5, 7, 9) w := constructMockComparators(2, 8, 11, 13) u1 := constructMockComparators(1, 5, 2, 8) w1 := constructMockComparators(7, 9, 11, 13) swap(u, w, 2) assert.Equal(t, u1, u) assert.Equal(t, w1, w) } func TestSymMergeSmallLists(t *testing.T) { u := constructMockComparators(1, 5) w := constructMockComparators(2, 8) expected := constructMockComparators(1, 2, 5, 8) u = SymMerge(u, w) assert.Equal(t, expected, u) } func TestSymMergeAlreadySorted(t *testing.T) { u := constructMockComparators(1, 5) w := constructMockComparators(6, 7) expected := constructMockComparators(1, 5, 6, 7) u = SymMerge(u, w) assert.Equal(t, expected, u) } func TestSymMergeAlreadySortedReverseOrder(t *testing.T) { u := constructMockComparators(6, 7) w := constructMockComparators(1, 5) expected := constructMockComparators(1, 5, 6, 7) u = SymMerge(u, w) assert.Equal(t, expected, u) } func TestSymMergeUnevenLists(t *testing.T) { u := constructMockComparators(1, 3, 7) w := constructMockComparators(2, 4) expected := constructMockComparators(1, 2, 3, 4, 7) u = SymMerge(u, w) assert.Equal(t, expected, u) } func TestSymMergeUnevenListsWrongOrder(t *testing.T) { u := constructMockComparators(2, 4) w := constructMockComparators(1, 3, 7) expected := constructMockComparators(1, 2, 3, 4, 7) u = SymMerge(u, w) assert.Equal(t, expected, u) } func TestMergeVeryUnevenLists(t *testing.T) { u := constructMockComparators(1, 3, 7, 12, 15) w := constructMockComparators(2, 4) expected := constructMockComparators(1, 2, 3, 4, 7, 12, 15) u = SymMerge(u, w) 
assert.Equal(t, expected, u) } func TestMergeVeryUnevenListsWrongOrder(t *testing.T) { u := constructMockComparators(2, 4) w := constructMockComparators(1, 3, 7, 12, 15) expected := constructMockComparators(1, 2, 3, 4, 7, 12, 15) u = SymMerge(u, w) assert.Equal(t, expected, u) } func TestMergeVeryUnevenListsAlreadySorted(t *testing.T) { u := constructMockComparators(2, 4) w := constructMockComparators(5, 7, 9, 10, 11, 12) expected := constructMockComparators(2, 4, 5, 7, 9, 10, 11, 12) u = SymMerge(u, w) assert.Equal(t, expected, u) } func TestMergeVeryUnevenListsAlreadySortedWrongOrder(t *testing.T) { w := constructMockComparators(2, 4) u := constructMockComparators(5, 7, 9, 10, 11, 12) expected := constructMockComparators(2, 4, 5, 7, 9, 10, 11, 12) u = SymMerge(u, w) assert.Equal(t, expected, u) } func TestMergeVeryUnevenListIsSubset(t *testing.T) { u := constructMockComparators(2, 4) w := constructMockComparators(1, 3, 5, 7, 9) expected := constructMockComparators(1, 2, 3, 4, 5, 7, 9) u = SymMerge(u, w) assert.Equal(t, expected, u) } func TestMergeVeryUnevenListIsSubsetReverseOrder(t *testing.T) { w := constructMockComparators(2, 4) u := constructMockComparators(1, 3, 5, 7, 9) expected := constructMockComparators(1, 2, 3, 4, 5, 7, 9) u = SymMerge(u, w) assert.Equal(t, expected, u) } func TestMergeUnevenOneListIsOne(t *testing.T) { u := constructMockComparators(1) w := constructMockComparators(0, 3, 5, 7, 9) expected := constructMockComparators(0, 1, 3, 5, 7, 9) u = SymMerge(u, w) assert.Equal(t, expected, u) } func TestMergeEmptyList(t *testing.T) { u := constructMockComparators(1, 3, 5) expected := constructMockComparators(1, 3, 5) u = SymMerge(u, nil) assert.Equal(t, expected, u) } ================================================ FILE: threadsafe/err/error.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the 
// Error wraps a single error value behind a read/write lock so that it
// can be stored and read from multiple goroutines.  The zero value holds
// a nil error and is ready to use.
type Error struct {
	lock sync.RWMutex
	err  error
}

// Set replaces the stored error with the provided value, which may be
// nil.  The write lock is held for the duration of the assignment.
func (e *Error) Set(err error) {
	e.lock.Lock()
	e.err = err
	e.lock.Unlock()
}

// Get returns the error currently stored in this structure, or nil if
// none has been set.  Only the read lock is taken.
func (e *Error) Get() error {
	e.lock.RLock()
	current := e.err
	e.lock.RUnlock()
	return current
}

// New constructs an empty Error whose value can be set and retrieved in
// a threadsafe manner.
func New() *Error {
	return new(Error)
}
// Immutable represents an immutable AVL tree.  This is achieved
// by branch copying: operations that change the tree work on a copy of
// the root and copy the nodes along the path they modify.
type Immutable struct {
	// root is the top of the tree; it is nil while the tree is empty.
	root *node
	// number is the count of entries in the tree, as reported by Len.
	number uint64
	// dummy is a scratch node that insert resets and uses as a stand-in
	// parent of the root.
	dummy node // helper for inserts.
}
func (immutable *Immutable) copy() *Immutable { var root *node if immutable.root != nil { root = immutable.root.copy() } cp := &Immutable{ root: root, number: immutable.number, dummy: *newNode(nil), } return cp } func (immutable *Immutable) resetDummy() { immutable.dummy.children[0], immutable.dummy.children[1] = nil, nil immutable.dummy.balance = 0 } func (immutable *Immutable) init() { immutable.dummy = node{ children: [2]*node{}, } } func (immutable *Immutable) get(entry Entry) Entry { n := immutable.root var result int for n != nil { switch result = n.entry.Compare(entry); { case result == 0: return n.entry case result > 0: n = n.children[0] case result < 0: n = n.children[1] } } return nil } // Get will get the provided Entries from the tree. If no matching // Entry is found, a nil is returned in its place. func (immutable *Immutable) Get(entries ...Entry) Entries { result := make(Entries, 0, len(entries)) for _, e := range entries { result = append(result, immutable.get(e)) } return result } // Len returns the number of items in this immutable. func (immutable *Immutable) Len() uint64 { return immutable.number } func (immutable *Immutable) insert(entry Entry) Entry { // TODO: check cache to see if a node has already been copied. if immutable.root == nil { immutable.root = newNode(entry) immutable.number++ return nil } immutable.resetDummy() var ( dummy = immutable.dummy p, s, q *node dir, normalized int helper = &dummy ) // set this AFTER clearing dummy helper.children[1] = immutable.root // we'll go ahead and copy on the way down as we'll need to branch // copy no matter what. 
for s, p = helper.children[1], helper.children[1]; ; { dir = p.entry.Compare(entry) normalized = normalizeComparison(dir) if dir > 0 { // go left if p.children[0] != nil { q = p.children[0].copy() p.children[0] = q } else { q = nil } } else if dir < 0 { // go right if p.children[1] != nil { q = p.children[1].copy() p.children[1] = q } else { q = nil } } else { // equality oldEntry := p.entry p.entry = entry return oldEntry } if q == nil { break } if q.balance != 0 { helper = p s = q } p = q } immutable.number++ q = newNode(entry) p.children[normalized] = q immutable.root = dummy.children[1] for p = s; p != q; p = p.children[normalized] { normalized = normalizeComparison(p.entry.Compare(entry)) if normalized == 0 { p.balance-- } else { p.balance++ } } q = s if math.Abs(float64(s.balance)) > 1 { normalized = normalizeComparison(s.entry.Compare(entry)) s = insertBalance(s, normalized) } if q == dummy.children[1] { immutable.root = s } else { helper.children[intFromBool(helper.children[1] == q)] = s } return nil } // Insert will add the provided entries into the tree and return the new // state. Also returned is a list of Entries that were overwritten. If // nothing was overwritten for an Entry, a nil is returned in its place. func (immutable *Immutable) Insert(entries ...Entry) (*Immutable, Entries) { if len(entries) == 0 { return immutable, Entries{} } overwritten := make(Entries, 0, len(entries)) cp := immutable.copy() for _, e := range entries { overwritten = append(overwritten, cp.insert(e)) } return cp, overwritten } func (immutable *Immutable) delete(entry Entry) Entry { // TODO: reuse cache and dirs, check cache to see if nodes // really need to be copied. if immutable.root == nil { // easy case, nothing to remove return nil } var ( // we are going to make a list here representing our stack. // This means we don't have to copy if a value wasn't found. 
cache = make(nodes, 64) it, p, q *node top, done, dir, normalized int dirs = make([]int, 64) oldEntry Entry ) it = immutable.root for { if it == nil { return nil } dir = it.entry.Compare(entry) if dir == 0 { break } normalized = normalizeComparison(dir) dirs[top] = normalized cache[top] = it top++ it = it.children[normalized] } immutable.number-- oldEntry = it.entry // we need to return this // we need to make a branch copy now for i := 0; i < top; i++ { // first item will be root p = cache[i] if p.children[dirs[i]] != nil { q = p.children[dirs[i]].copy() p.children[dirs[i]] = q if i != top-1 { cache[i+1] = q } } } it = it.copy() // the node we found needs to be copied oldTop := top if it.children[0] == nil || it.children[1] == nil { // need to set children on parent, splicing out dir = intFromBool(it.children[0] == nil) if top != 0 { cache[top-1].children[dirs[top-1]] = it.children[dir] } else { immutable.root = it.children[dir] } } else { // climb up and set heirs heir := it.children[1] dirs[top] = 1 cache[top] = it top++ for heir.children[0] != nil { dirs[top] = 0 cache[top] = heir top++ heir = heir.children[0] } it.entry = heir.entry if oldTop != 0 { cache[oldTop-1].children[dirs[oldTop-1]] = it } else { immutable.root = it } cache[top-1].children[intFromBool(cache[top-1] == it)] = heir.children[1] } for top-1 >= 0 && done == 0 { top-- // set bounded balance if dirs[top] != 0 { cache[top].balance-- } else { cache[top].balance++ } if math.Abs(float64(cache[top].balance)) == 1 { break } else if math.Abs(float64(cache[top].balance)) > 1 { // any rotations done here cache[top] = removeBalance(cache[top], dirs[top], &done) if top != 0 { cache[top-1].children[dirs[top-1]] = cache[top] } else { immutable.root = cache[0] } } } return oldEntry } // Delete will remove the provided entries from this AVL tree and // return a new tree and any entries removed. If an entry could not // be found, nil is returned in its place. 
func (immutable *Immutable) Delete(entries ...Entry) (*Immutable, Entries) { if len(entries) == 0 { return immutable, Entries{} } deleted := make(Entries, 0, len(entries)) cp := immutable.copy() for _, e := range entries { deleted = append(deleted, cp.delete(e)) } return cp, deleted } func insertBalance(root *node, dir int) *node { n := root.children[dir] var bal int8 if dir == 0 { bal = -1 } else { bal = 1 } if n.balance == bal { root.balance, n.balance = 0, 0 root = rotate(root, takeOpposite(dir)) } else { adjustBalance(root, dir, int(bal)) root = doubleRotate(root, takeOpposite(dir)) } return root } func removeBalance(root *node, dir int, done *int) *node { n := root.children[takeOpposite(dir)].copy() root.children[takeOpposite(dir)] = n var bal int8 if dir == 0 { bal = -1 } else { bal = 1 } if n.balance == -bal { root.balance, n.balance = 0, 0 root = rotate(root, dir) } else if n.balance == bal { adjustBalance(root, takeOpposite(dir), int(-bal)) root = doubleRotate(root, dir) } else { root.balance = -bal n.balance = bal root = rotate(root, dir) *done = 1 } return root } func intFromBool(value bool) int { if value { return 1 } return 0 } func takeOpposite(value int) int { return 1 - value } func adjustBalance(root *node, dir, bal int) { n := root.children[dir] nn := n.children[takeOpposite(dir)] if nn.balance == 0 { root.balance, n.balance = 0, 0 } else if int(nn.balance) == bal { root.balance = int8(-bal) n.balance = 0 } else { root.balance = 0 n.balance = int8(bal) } nn.balance = 0 } func rotate(parent *node, dir int) *node { otherDir := takeOpposite(dir) child := parent.children[otherDir] parent.children[otherDir] = child.children[dir] child.children[dir] = parent return child } func doubleRotate(parent *node, dir int) *node { otherDir := takeOpposite(dir) parent.children[otherDir] = rotate(parent.children[otherDir], otherDir) return rotate(parent, dir) } // normalizeComparison converts the value returned from Entry.Compare // into a direction, ie, left or 
// normalizeComparison maps the result of Entry.Compare onto a child
// index: a positive comparison selects the left child (0), a negative
// comparison selects the right child (1).  Equality has no direction and
// is reported as -1.
func normalizeComparison(i int) int {
	switch {
	case i > 0:
		return 0
	case i < 0:
		return 1
	default:
		return -1
	}
}
assert.Equal(t, Entries{nil, nil, nil}, overwritten) assert.Equal(t, uint64(0), i1.Len()) assert.Equal(t, uint64(3), i2.Len()) assert.Equal(t, Entries{m1, m2, m3}, i2.Get(m1, m2, m3)) assert.Equal(t, Entries{nil, nil, nil}, i1.Get(m1, m2, m3)) m4 := mockEntry(15) m5 := mockEntry(20) i3, overwritten := i2.Insert(m4, m5) assert.Equal(t, Entries{nil, nil}, overwritten) assert.Equal(t, uint64(5), i3.Len()) assert.Equal(t, uint64(3), i2.Len()) assert.Equal(t, Entries{nil, nil}, i2.Get(m4, m5)) assert.Equal(t, Entries{m4, m5}, i3.Get(m4, m5)) } func TestAVLInsertRightLeaningDoubleRotation(t *testing.T) { i1 := NewImmutable() m1 := mockEntry(1) m2 := mockEntry(10) m3 := mockEntry(5) i2, overwritten := i1.Insert(m1, m2, m3) assert.Equal(t, uint64(3), i2.Len()) assert.Equal(t, Entries{nil, nil, nil}, overwritten) assert.Equal(t, Entries{nil, nil, nil}, i1.Get(m1, m2, m3)) assert.Equal(t, Entries{m1, m2, m3}, i2.Get(m1, m2, m3)) m4 := mockEntry(20) m5 := mockEntry(15) i3, overwritten := i2.Insert(m4, m5) assert.Equal(t, Entries{nil, nil}, overwritten) assert.Equal(t, uint64(5), i3.Len()) assert.Equal(t, uint64(3), i2.Len()) assert.Equal(t, Entries{nil, nil}, i2.Get(m4, m5)) assert.Equal(t, Entries{m4, m5}, i3.Get(m4, m5)) } func TestAVLInsertLeftLeaning(t *testing.T) { i1 := NewImmutable() m1 := mockEntry(20) m2 := mockEntry(15) m3 := mockEntry(10) i2, overwritten := i1.Insert(m1, m2, m3) assert.Equal(t, Entries{nil, nil, nil}, overwritten) assert.Equal(t, uint64(0), i1.Len()) assert.Equal(t, uint64(3), i2.Len()) assert.Equal(t, Entries{nil, nil, nil}, i1.Get(m1, m2, m3)) assert.Equal(t, Entries{m1, m2, m3}, i2.Get(m1, m2, m3)) m4 := mockEntry(5) m5 := mockEntry(1) i3, overwritten := i2.Insert(m4, m5) assert.Equal(t, Entries{nil, nil}, overwritten) assert.Equal(t, uint64(3), i2.Len()) assert.Equal(t, uint64(5), i3.Len()) assert.Equal(t, Entries{nil, nil}, i2.Get(m4, m5)) assert.Equal(t, Entries{m4, m5}, i3.Get(m4, m5)) } func TestAVLInsertLeftLeaningDoubleRotation(t 
*testing.T) { i1 := NewImmutable() m1 := mockEntry(20) m2 := mockEntry(10) m3 := mockEntry(15) i2, overwritten := i1.Insert(m1, m2, m3) assert.Equal(t, Entries{nil, nil, nil}, overwritten) assert.Equal(t, uint64(0), i1.Len()) assert.Equal(t, uint64(3), i2.Len()) assert.Equal(t, Entries{nil, nil, nil}, i1.Get(m1, m2, m3)) assert.Equal(t, Entries{m1, m2, m3}, i2.Get(m1, m2, m3)) m4 := mockEntry(1) m5 := mockEntry(5) i3, overwritten := i2.Insert(m4, m5) assert.Equal(t, Entries{nil, nil}, overwritten) assert.Equal(t, uint64(3), i2.Len()) assert.Equal(t, uint64(5), i3.Len()) assert.Equal(t, Entries{nil, nil}, i2.Get(m4, m5)) assert.Equal(t, Entries{m4, m5}, i3.Get(m4, m5)) assert.Equal(t, Entries{m1, m2, m3}, i3.Get(m1, m2, m3)) } func TestAVLInsertOverwrite(t *testing.T) { i1 := NewImmutable() m1 := mockEntry(20) m2 := mockEntry(10) m3 := mockEntry(15) i2, _ := i1.Insert(m1, m2, m3) m4 := mockEntry(15) i3, overwritten := i2.Insert(m4) assert.Equal(t, Entries{m3}, overwritten) assert.Equal(t, uint64(3), i2.Len()) assert.Equal(t, uint64(3), i3.Len()) assert.Equal(t, Entries{m4}, i3.Get(m4)) assert.Equal(t, Entries{m3}, i2.Get(m3)) } func TestAVLSimpleDelete(t *testing.T) { i1 := NewImmutable() m1 := mockEntry(10) m2 := mockEntry(15) m3 := mockEntry(20) i2, _ := i1.Insert(m1, m2, m3) i3, deleted := i2.Delete(m3) assert.Equal(t, Entries{m3}, deleted) assert.Equal(t, uint64(3), i2.Len()) assert.Equal(t, uint64(2), i3.Len()) assert.Equal(t, Entries{m1, m2, m3}, i2.Get(m1, m2, m3)) assert.Equal(t, Entries{m1, m2, nil}, i3.Get(m1, m2, m3)) i4, deleted := i3.Delete(m2) assert.Equal(t, Entries{m2}, deleted) assert.Equal(t, uint64(2), i3.Len()) assert.Equal(t, uint64(1), i4.Len()) assert.Equal(t, Entries{m1, m2, nil}, i3.Get(m1, m2, m3)) assert.Equal(t, Entries{m1, nil, nil}, i4.Get(m1, m2, m3)) i5, deleted := i4.Delete(m1) assert.Equal(t, Entries{m1}, deleted) assert.Equal(t, uint64(0), i5.Len()) assert.Equal(t, uint64(1), i4.Len()) assert.Equal(t, Entries{m1, nil, nil}, 
i4.Get(m1, m2, m3)) assert.Equal(t, Entries{nil, nil, nil}, i5.Get(m1, m2, m3)) } func TestAVLDeleteWithRotation(t *testing.T) { i1 := NewImmutable() m1 := mockEntry(1) m2 := mockEntry(5) m3 := mockEntry(10) m4 := mockEntry(15) m5 := mockEntry(20) i2, _ := i1.Insert(m1, m2, m3, m4, m5) assert.Equal(t, uint64(5), i2.Len()) i3, deleted := i2.Delete(m1) assert.Equal(t, uint64(4), i3.Len()) assert.Equal(t, Entries{m1}, deleted) assert.Equal(t, Entries{m1, m2, m3, m4, m5}, i2.Get(m1, m2, m3, m4, m5)) assert.Equal(t, Entries{nil, m2, m3, m4, m5}, i3.Get(m1, m2, m3, m4, m5)) } func TestAVLDeleteWithDoubleRotation(t *testing.T) { i1 := NewImmutable() m1 := mockEntry(1) m2 := mockEntry(5) m3 := mockEntry(10) m4 := mockEntry(15) i2, _ := i1.Insert(m2, m1, m3, m4) assert.Equal(t, uint64(4), i2.Len()) i3, deleted := i2.Delete(m1) assert.Equal(t, Entries{m1}, deleted) assert.Equal(t, uint64(3), i3.Len()) assert.Equal(t, Entries{m1, m2, m3, m4}, i2.Get(m1, m2, m3, m4)) assert.Equal(t, Entries{nil, m2, m3, m4}, i3.Get(m1, m2, m3, m4)) } func TestAVLDeleteAll(t *testing.T) { i1 := NewImmutable() m1 := mockEntry(1) m2 := mockEntry(5) m3 := mockEntry(10) m4 := mockEntry(15) i2, _ := i1.Insert(m2, m1, m3, m4) assert.Equal(t, uint64(4), i2.Len()) i3, deleted := i2.Delete(m1, m2, m3, m4) assert.Equal(t, Entries{m1, m2, m3, m4}, deleted) assert.Equal(t, uint64(0), i3.Len()) assert.Equal(t, Entries{nil, nil, nil, nil}, i3.Get(m1, m2, m3, m4)) assert.Equal(t, Entries{m1, m2, m3, m4}, i2.Get(m1, m2, m3, m4)) } func TestAVLDeleteNotLeaf(t *testing.T) { i1 := NewImmutable() m1 := mockEntry(1) m2 := mockEntry(5) m3 := mockEntry(10) m4 := mockEntry(15) i2, _ := i1.Insert(m2, m1, m3, m4) i3, deleted := i2.Delete(m3) assert.Equal(t, Entries{m3}, deleted) assert.Equal(t, uint64(3), i3.Len()) } func TestAVLBulkDeleteAll(t *testing.T) { i1 := NewImmutable() entries := generateMockEntries(5) i2, _ := i1.Insert(entries...) i3, deleted := i2.Delete(entries...) 
assert.Equal(t, entries, deleted) assert.Equal(t, uint64(0), i3.Len()) i3, deleted = i2.Delete(entries...) assert.Equal(t, entries, deleted) assert.Equal(t, uint64(0), i3.Len()) } func TestAVLDeleteReplay(t *testing.T) { i1 := NewImmutable() m1 := mockEntry(1) m2 := mockEntry(5) m3 := mockEntry(10) m4 := mockEntry(15) i2, _ := i1.Insert(m2, m1, m3, m4) i3, deleted := i2.Delete(m3) assert.Equal(t, uint64(3), i3.Len()) assert.Equal(t, Entries{m3}, deleted) assert.Equal(t, uint64(4), i2.Len()) i3, deleted = i2.Delete(m3) assert.Equal(t, uint64(3), i3.Len()) assert.Equal(t, Entries{m3}, deleted) assert.Equal(t, uint64(4), i2.Len()) } func TestAVLFails(t *testing.T) { keys := []mockEntry{ mockEntry(0), mockEntry(1), mockEntry(3), mockEntry(4), mockEntry(5), mockEntry(6), mockEntry(7), mockEntry(2), } i1 := NewImmutable() for _, k := range keys { i1, _ = i1.Insert(k) } for _, k := range keys { var deleted Entries i1, deleted = i1.Delete(k) assert.Equal(t, Entries{k}, deleted) } } func BenchmarkImmutableInsert(b *testing.B) { numItems := b.N sl := NewImmutable() entries := generateMockEntries(numItems) sl, _ = sl.Insert(entries...) b.ResetTimer() for i := 0; i < b.N; i++ { sl, _ = sl.Insert(entries[i%numItems]) } } func BenchmarkImmutableGet(b *testing.B) { numItems := b.N sl := NewImmutable() entries := generateMockEntries(numItems) sl, _ = sl.Insert(entries...) b.ResetTimer() for i := 0; i < b.N; i++ { sl.Get(entries[i%numItems]) } } func BenchmarkImmutableBulkInsert(b *testing.B) { numItems := b.N sl := NewImmutable() entries := generateMockEntries(numItems) b.ResetTimer() for i := 0; i < b.N; i++ { sl.Insert(entries...) } } func BenchmarkImmutableDelete(b *testing.B) { numItems := b.N sl := NewImmutable() entries := generateMockEntries(numItems) sl, _ = sl.Insert(entries...) 
// Entry represents all items that can be placed into the AVL tree.
// They must implement a Compare method that can be used to determine
// the Entry's correct place in the tree.  Any object can implement
// Compare.
type Entry interface {
	// Compare should return a value indicating the relationship
	// of this Entry to the provided Entry.  A -1 means this entry
	// is less than, 0 means equality, and 1 means greater than.
	// Entries that compare as equal are treated as the same key by
	// the tree: inserting one overwrites the other.
	Compare(Entry) int
}
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package avl type mockEntry int func (me mockEntry) Compare(other Entry) int { otherMe := other.(mockEntry) if me > otherMe { return 1 } if me < otherMe { return -1 } return 0 } ================================================ FILE: tree/avl/node.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package avl type nodes []*node func (ns nodes) reset() { for i := range ns { ns[i] = nil } } type node struct { balance int8 // bounded, |balance| should be <= 1 children [2]*node entry Entry } // copy returns a copy of this node with pointers to the original // children. func (n *node) copy() *node { return &node{ balance: n.balance, children: [2]*node{n.children[0], n.children[1]}, entry: n.entry, } } // newNode returns a new node for the provided entry. A nil // entry is used to represent the dummy node. 
func newNode(entry Entry) *node { return &node{ entry: entry, children: [2]*node{}, } } ================================================ FILE: trie/ctrie/ctrie.go ================================================ /* Copyright 2015 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* Package ctrie provides an implementation of the Ctrie data structure, which is a concurrent, lock-free hash trie. This data structure was originally presented in the paper Concurrent Tries with Efficient Non-Blocking Snapshots: https://axel22.github.io/resources/docs/ctries-snapshot.pdf */ package ctrie import ( "bytes" "errors" "hash" "hash/fnv" "sync/atomic" "unsafe" "github.com/Workiva/go-datastructures/list" ) const ( // w controls the number of branches at a node (2^w branches). w = 5 // exp2 is 2^w, which is the hashcode space. exp2 = 32 ) // HashFactory returns a new Hash32 used to hash keys. type HashFactory func() hash.Hash32 func defaultHashFactory() hash.Hash32 { return fnv.New32a() } // Ctrie is a concurrent, lock-free hash trie. By default, keys are hashed // using FNV-1a unless a HashFactory is provided to New. type Ctrie struct { root *iNode readOnly bool hashFactory HashFactory } // generation demarcates Ctrie snapshots. We use a heap-allocated reference // instead of an integer to avoid integer overflows. Struct must have a field // on it since two distinct zero-size variables may have the same address in // memory. type generation struct{ _ int } // iNode is an indirection node. 
I-nodes remain present in the Ctrie even as // nodes above and below change. Thread-safety is achieved in part by // performing CAS operations on the I-node instead of the internal node array. type iNode struct { main *mainNode gen *generation // rdcss is set during an RDCSS operation. The I-node is actually a wrapper // around the descriptor in this case so that a single type is used during // CAS operations on the root. rdcss *rdcssDescriptor } // copyToGen returns a copy of this I-node copied to the given generation. func (i *iNode) copyToGen(gen *generation, ctrie *Ctrie) *iNode { nin := &iNode{gen: gen} main := gcasRead(i, ctrie) atomic.StorePointer( (*unsafe.Pointer)(unsafe.Pointer(&nin.main)), unsafe.Pointer(main)) return nin } // mainNode is either a cNode, tNode, lNode, or failed node which makes up an // I-node. type mainNode struct { cNode *cNode tNode *tNode lNode *lNode failed *mainNode // prev is set as a failed main node when we attempt to CAS and the // I-node's generation does not match the root generation. This signals // that the GCAS failed and the I-node's main node must be set back to the // previous value. prev *mainNode } // cNode is an internal main node containing a bitmap and the array with // references to branch nodes. A branch node is either another I-node or a // singleton S-node. type cNode struct { bmp uint32 array []branch gen *generation } // newMainNode is a recursive constructor which creates a new mainNode. This // mainNode will consist of cNodes as long as the hashcode chunks of the two // keys are equal at the given level. If the level exceeds 2^w, an lNode is // created. func newMainNode(x *sNode, xhc uint32, y *sNode, yhc uint32, lev uint, gen *generation) *mainNode { if lev < exp2 { xidx := (xhc >> lev) & 0x1f yidx := (yhc >> lev) & 0x1f bmp := uint32((1 << xidx) | (1 << yidx)) if xidx == yidx { // Recurse when indexes are equal. 
main := newMainNode(x, xhc, y, yhc, lev+w, gen) iNode := &iNode{main: main, gen: gen} return &mainNode{cNode: &cNode{bmp, []branch{iNode}, gen}} } if xidx < yidx { return &mainNode{cNode: &cNode{bmp, []branch{x, y}, gen}} } return &mainNode{cNode: &cNode{bmp, []branch{y, x}, gen}} } l := list.Empty.Add(x).Add(y) return &mainNode{lNode: &lNode{l}} } // inserted returns a copy of this cNode with the new entry at the given // position. func (c *cNode) inserted(pos, flag uint32, br branch, gen *generation) *cNode { length := uint32(len(c.array)) bmp := c.bmp array := make([]branch, length+1) copy(array, c.array) array[pos] = br for i, x := pos, uint32(0); x < length-pos; i++ { array[i+1] = c.array[i] x++ } ncn := &cNode{bmp: bmp | flag, array: array, gen: gen} return ncn } // updated returns a copy of this cNode with the entry at the given index // updated. func (c *cNode) updated(pos uint32, br branch, gen *generation) *cNode { array := make([]branch, len(c.array)) copy(array, c.array) array[pos] = br ncn := &cNode{bmp: c.bmp, array: array, gen: gen} return ncn } // removed returns a copy of this cNode with the entry at the given index // removed. func (c *cNode) removed(pos, flag uint32, gen *generation) *cNode { length := uint32(len(c.array)) bmp := c.bmp array := make([]branch, length-1) for i := uint32(0); i < pos; i++ { array[i] = c.array[i] } for i, x := pos, uint32(0); x < length-pos-1; i++ { array[i] = c.array[i+1] x++ } ncn := &cNode{bmp: bmp ^ flag, array: array, gen: gen} return ncn } // renewed returns a copy of this cNode with the I-nodes below it copied to the // given generation. func (c *cNode) renewed(gen *generation, ctrie *Ctrie) *cNode { array := make([]branch, len(c.array)) for i, br := range c.array { switch t := br.(type) { case *iNode: array[i] = t.copyToGen(gen, ctrie) default: array[i] = br } } return &cNode{bmp: c.bmp, array: array, gen: gen} } // tNode is tomb node which is a special node used to ensure proper ordering // during removals. 
type tNode struct { *sNode } // untombed returns the S-node contained by the T-node. func (t *tNode) untombed() *sNode { return &sNode{&Entry{Key: t.Key, hash: t.hash, Value: t.Value}} } // lNode is a list node which is a leaf node used to handle hashcode // collisions by keeping such keys in a persistent list. type lNode struct { list.PersistentList } // entry returns the first S-node contained in the L-node. func (l *lNode) entry() *sNode { head, _ := l.Head() return head.(*sNode) } // lookup returns the value at the given entry in the L-node or returns false // if it's not contained. func (l *lNode) lookup(e *Entry) (interface{}, bool) { found, ok := l.Find(func(sn interface{}) bool { return bytes.Equal(e.Key, sn.(*sNode).Key) }) if !ok { return nil, false } return found.(*sNode).Value, true } // inserted creates a new L-node with the added entry. func (l *lNode) inserted(entry *Entry) *lNode { return &lNode{l.removed(entry).Add(&sNode{entry})} } // removed creates a new L-node with the entry removed. func (l *lNode) removed(e *Entry) *lNode { idx := l.FindIndex(func(sn interface{}) bool { return bytes.Equal(e.Key, sn.(*sNode).Key) }) if idx < 0 { return l } nl, _ := l.Remove(uint(idx)) return &lNode{nl} } // length returns the L-node list length. func (l *lNode) length() uint { return l.Length() } // branch is either an iNode or sNode. type branch interface{} // Entry contains a Ctrie key-value pair. type Entry struct { Key []byte Value interface{} hash uint32 } // sNode is a singleton node which contains a single key and value. type sNode struct { *Entry } // New creates an empty Ctrie which uses the provided HashFactory for key // hashing. If nil is passed in, it will default to FNV-1a hashing. 
func New(hashFactory HashFactory) *Ctrie { if hashFactory == nil { hashFactory = defaultHashFactory } root := &iNode{main: &mainNode{cNode: &cNode{}}} return newCtrie(root, hashFactory, false) } func newCtrie(root *iNode, hashFactory HashFactory, readOnly bool) *Ctrie { return &Ctrie{ root: root, hashFactory: hashFactory, readOnly: readOnly, } } // Insert adds the key-value pair to the Ctrie, replacing the existing value if // the key already exists. func (c *Ctrie) Insert(key []byte, value interface{}) { c.assertReadWrite() c.insert(&Entry{ Key: key, Value: value, hash: c.hash(key), }) } // Lookup returns the value for the associated key or returns false if the key // doesn't exist. func (c *Ctrie) Lookup(key []byte) (interface{}, bool) { return c.lookup(&Entry{Key: key, hash: c.hash(key)}) } // Remove deletes the value for the associated key, returning true if it was // removed or false if the entry doesn't exist. func (c *Ctrie) Remove(key []byte) (interface{}, bool) { c.assertReadWrite() return c.remove(&Entry{Key: key, hash: c.hash(key)}) } // Snapshot returns a stable, point-in-time snapshot of the Ctrie. If the Ctrie // is read-only, the returned Ctrie will also be read-only. func (c *Ctrie) Snapshot() *Ctrie { return c.snapshot(c.readOnly) } // ReadOnlySnapshot returns a stable, point-in-time snapshot of the Ctrie which // is read-only. Write operations on a read-only snapshot will panic. func (c *Ctrie) ReadOnlySnapshot() *Ctrie { return c.snapshot(true) } // snapshot wraps up the CAS logic to make a snapshot or a read-only snapshot. func (c *Ctrie) snapshot(readOnly bool) *Ctrie { if readOnly && c.readOnly { return c } for { root := c.readRoot() main := gcasRead(root, c) if c.rdcssRoot(root, main, root.copyToGen(&generation{}, c)) { if readOnly { // For a read-only snapshot, we can share the old generation // root. return newCtrie(root, c.hashFactory, readOnly) } // For a read-write snapshot, we need to take a copy of the root // in the new generation. 
return newCtrie(c.readRoot().copyToGen(&generation{}, c), c.hashFactory, readOnly) } } } // Clear removes all keys from the Ctrie. func (c *Ctrie) Clear() { for { root := c.readRoot() gen := &generation{} newRoot := &iNode{ main: &mainNode{cNode: &cNode{array: make([]branch, 0), gen: gen}}, gen: gen, } if c.rdcssRoot(root, gcasRead(root, c), newRoot) { return } } } // Iterator returns a channel which yields the Entries of the Ctrie. If a // cancel channel is provided, closing it will terminate and close the iterator // channel. Note that if a cancel channel is not used and not every entry is // read from the iterator, a goroutine will leak. func (c *Ctrie) Iterator(cancel <-chan struct{}) <-chan *Entry { ch := make(chan *Entry) snapshot := c.ReadOnlySnapshot() go func() { snapshot.traverse(snapshot.readRoot(), ch, cancel) close(ch) }() return ch } // Size returns the number of keys in the Ctrie. func (c *Ctrie) Size() uint { // TODO: The size operation can be optimized further by caching the size // information in main nodes of a read-only Ctrie – this reduces the // amortized complexity of the size operation to O(1) because the size // computation is amortized across the update operations that occurred // since the last snapshot. 
size := uint(0) for _ = range c.Iterator(nil) { size++ } return size } var errCanceled = errors.New("canceled") func (c *Ctrie) traverse(i *iNode, ch chan<- *Entry, cancel <-chan struct{}) error { main := gcasRead(i, c) switch { case main.cNode != nil: for _, br := range main.cNode.array { switch b := br.(type) { case *iNode: if err := c.traverse(b, ch, cancel); err != nil { return err } case *sNode: select { case ch <- b.Entry: case <-cancel: return errCanceled } } } case main.lNode != nil: for _, e := range main.lNode.Map(func(sn interface{}) interface{} { return sn.(*sNode).Entry }) { select { case ch <- e.(*Entry): case <-cancel: return errCanceled } } case main.tNode != nil: select { case ch <- main.tNode.Entry: case <-cancel: return errCanceled } } return nil } func (c *Ctrie) assertReadWrite() { if c.readOnly { panic("Cannot modify read-only snapshot") } } func (c *Ctrie) insert(entry *Entry) { root := c.readRoot() if !c.iinsert(root, entry, 0, nil, root.gen) { c.insert(entry) } } func (c *Ctrie) lookup(entry *Entry) (interface{}, bool) { root := c.readRoot() result, exists, ok := c.ilookup(root, entry, 0, nil, root.gen) for !ok { return c.lookup(entry) } return result, exists } func (c *Ctrie) remove(entry *Entry) (interface{}, bool) { root := c.readRoot() result, exists, ok := c.iremove(root, entry, 0, nil, root.gen) for !ok { return c.remove(entry) } return result, exists } func (c *Ctrie) hash(k []byte) uint32 { hasher := c.hashFactory() hasher.Write(k) return hasher.Sum32() } // iinsert attempts to insert the entry into the Ctrie. If false is returned, // the operation should be retried. func (c *Ctrie) iinsert(i *iNode, entry *Entry, lev uint, parent *iNode, startGen *generation) bool { // Linearization point. main := gcasRead(i, c) switch { case main.cNode != nil: cn := main.cNode flag, pos := flagPos(entry.hash, lev, cn.bmp) if cn.bmp&flag == 0 { // If the relevant bit is not in the bitmap, then a copy of the // cNode with the new entry is created. 
The linearization point is // a successful CAS. rn := cn if cn.gen != i.gen { rn = cn.renewed(i.gen, c) } ncn := &mainNode{cNode: rn.inserted(pos, flag, &sNode{entry}, i.gen)} return gcas(i, main, ncn, c) } // If the relevant bit is present in the bitmap, then its corresponding // branch is read from the array. branch := cn.array[pos] switch branch.(type) { case *iNode: // If the branch is an I-node, then iinsert is called recursively. in := branch.(*iNode) if startGen == in.gen { return c.iinsert(in, entry, lev+w, i, startGen) } if gcas(i, main, &mainNode{cNode: cn.renewed(startGen, c)}, c) { return c.iinsert(i, entry, lev, parent, startGen) } return false case *sNode: sn := branch.(*sNode) if !bytes.Equal(sn.Key, entry.Key) { // If the branch is an S-node and its key is not equal to the // key being inserted, then the Ctrie has to be extended with // an additional level. The C-node is replaced with its updated // version, created using the updated function that adds a new // I-node at the respective position. The new Inode has its // main node pointing to a C-node with both keys. The // linearization point is a successful CAS. rn := cn if cn.gen != i.gen { rn = cn.renewed(i.gen, c) } nsn := &sNode{entry} nin := &iNode{main: newMainNode(sn, sn.hash, nsn, nsn.hash, lev+w, i.gen), gen: i.gen} ncn := &mainNode{cNode: rn.updated(pos, nin, i.gen)} return gcas(i, main, ncn, c) } // If the key in the S-node is equal to the key being inserted, // then the C-node is replaced with its updated version with a new // S-node. The linearization point is a successful CAS. ncn := &mainNode{cNode: cn.updated(pos, &sNode{entry}, i.gen)} return gcas(i, main, ncn, c) default: panic("Ctrie is in an invalid state") } case main.tNode != nil: clean(parent, lev-w, c) return false case main.lNode != nil: nln := &mainNode{lNode: main.lNode.inserted(entry)} return gcas(i, main, nln, c) default: panic("Ctrie is in an invalid state") } } // ilookup attempts to fetch the entry from the Ctrie. 
The first two return // values are the entry value and whether or not the entry was contained in the // Ctrie. The last bool indicates if the operation succeeded. False means it // should be retried. func (c *Ctrie) ilookup(i *iNode, entry *Entry, lev uint, parent *iNode, startGen *generation) (interface{}, bool, bool) { // Linearization point. main := gcasRead(i, c) switch { case main.cNode != nil: cn := main.cNode flag, pos := flagPos(entry.hash, lev, cn.bmp) if cn.bmp&flag == 0 { // If the bitmap does not contain the relevant bit, a key with the // required hashcode prefix is not present in the trie. return nil, false, true } // Otherwise, the relevant branch at index pos is read from the array. branch := cn.array[pos] switch branch.(type) { case *iNode: // If the branch is an I-node, the ilookup procedure is called // recursively at the next level. in := branch.(*iNode) if c.readOnly || startGen == in.gen { return c.ilookup(in, entry, lev+w, i, startGen) } if gcas(i, main, &mainNode{cNode: cn.renewed(startGen, c)}, c) { return c.ilookup(i, entry, lev, parent, startGen) } return nil, false, false case *sNode: // If the branch is an S-node, then the key within the S-node is // compared with the key being searched – these two keys have the // same hashcode prefixes, but they need not be equal. If they are // equal, the corresponding value from the S-node is // returned and a NOTFOUND value otherwise. sn := branch.(*sNode) if bytes.Equal(sn.Key, entry.Key) { return sn.Value, true, true } return nil, false, true default: panic("Ctrie is in an invalid state") } case main.tNode != nil: return cleanReadOnly(main.tNode, lev, parent, c, entry) case main.lNode != nil: // Hash collisions are handled using L-nodes, which are essentially // persistent linked lists. val, ok := main.lNode.lookup(entry) return val, ok, true default: panic("Ctrie is in an invalid state") } } // iremove attempts to remove the entry from the Ctrie. 
The first two return // values are the entry value and whether or not the entry was contained in the // Ctrie. The last bool indicates if the operation succeeded. False means it // should be retried. func (c *Ctrie) iremove(i *iNode, entry *Entry, lev uint, parent *iNode, startGen *generation) (interface{}, bool, bool) { // Linearization point. main := gcasRead(i, c) switch { case main.cNode != nil: cn := main.cNode flag, pos := flagPos(entry.hash, lev, cn.bmp) if cn.bmp&flag == 0 { // If the bitmap does not contain the relevant bit, a key with the // required hashcode prefix is not present in the trie. return nil, false, true } // Otherwise, the relevant branch at index pos is read from the array. branch := cn.array[pos] switch branch.(type) { case *iNode: // If the branch is an I-node, the iremove procedure is called // recursively at the next level. in := branch.(*iNode) if startGen == in.gen { return c.iremove(in, entry, lev+w, i, startGen) } if gcas(i, main, &mainNode{cNode: cn.renewed(startGen, c)}, c) { return c.iremove(i, entry, lev, parent, startGen) } return nil, false, false case *sNode: // If the branch is an S-node, its key is compared against the key // being removed. sn := branch.(*sNode) if !bytes.Equal(sn.Key, entry.Key) { // If the keys are not equal, the NOTFOUND value is returned. return nil, false, true } // If the keys are equal, a copy of the current node without the // S-node is created. The contraction of the copy is then created // using the toContracted procedure. 
A successful CAS will // substitute the old C-node with the copied C-node, thus removing // the S-node with the given key from the trie – this is the // linearization point ncn := cn.removed(pos, flag, i.gen) cntr := toContracted(ncn, lev) if gcas(i, main, cntr, c) { if parent != nil { main = gcasRead(i, c) if main.tNode != nil { cleanParent(parent, i, entry.hash, lev-w, c, startGen) } } return sn.Value, true, true } return nil, false, false default: panic("Ctrie is in an invalid state") } case main.tNode != nil: clean(parent, lev-w, c) return nil, false, false case main.lNode != nil: nln := &mainNode{lNode: main.lNode.removed(entry)} if nln.lNode.length() == 1 { nln = entomb(nln.lNode.entry()) } if gcas(i, main, nln, c) { val, ok := main.lNode.lookup(entry) return val, ok, true } return nil, false, true default: panic("Ctrie is in an invalid state") } } // toContracted ensures that every I-node except the root points to a C-node // with at least one branch. If a given C-Node has only a single S-node below // it and is not at the root level, a T-node which wraps the S-node is // returned. func toContracted(cn *cNode, lev uint) *mainNode { if lev > 0 && len(cn.array) == 1 { branch := cn.array[0] switch branch.(type) { case *sNode: return entomb(branch.(*sNode)) default: return &mainNode{cNode: cn} } } return &mainNode{cNode: cn} } // toCompressed compacts the C-node as a performance optimization. 
func toCompressed(cn *cNode, lev uint) *mainNode { tmpArray := make([]branch, len(cn.array)) for i, sub := range cn.array { switch sub.(type) { case *iNode: inode := sub.(*iNode) mainPtr := (*unsafe.Pointer)(unsafe.Pointer(&inode.main)) main := (*mainNode)(atomic.LoadPointer(mainPtr)) tmpArray[i] = resurrect(inode, main) case *sNode: tmpArray[i] = sub default: panic("Ctrie is in an invalid state") } } return toContracted(&cNode{bmp: cn.bmp, array: tmpArray}, lev) } func entomb(m *sNode) *mainNode { return &mainNode{tNode: &tNode{m}} } func resurrect(iNode *iNode, main *mainNode) branch { if main.tNode != nil { return main.tNode.untombed() } return iNode } func clean(i *iNode, lev uint, ctrie *Ctrie) bool { main := gcasRead(i, ctrie) if main.cNode != nil { return gcas(i, main, toCompressed(main.cNode, lev), ctrie) } return true } func cleanReadOnly(tn *tNode, lev uint, p *iNode, ctrie *Ctrie, entry *Entry) (val interface{}, exists bool, ok bool) { if !ctrie.readOnly { clean(p, lev-5, ctrie) return nil, false, false } if tn.hash == entry.hash && bytes.Equal(tn.Key, entry.Key) { return tn.Value, true, true } return nil, false, true } func cleanParent(p, i *iNode, hc uint32, lev uint, ctrie *Ctrie, startGen *generation) { var ( mainPtr = (*unsafe.Pointer)(unsafe.Pointer(&i.main)) main = (*mainNode)(atomic.LoadPointer(mainPtr)) pMainPtr = (*unsafe.Pointer)(unsafe.Pointer(&p.main)) pMain = (*mainNode)(atomic.LoadPointer(pMainPtr)) ) if pMain.cNode != nil { flag, pos := flagPos(hc, lev, pMain.cNode.bmp) if pMain.cNode.bmp&flag != 0 { sub := pMain.cNode.array[pos] if sub == i && main.tNode != nil { ncn := pMain.cNode.updated(pos, resurrect(i, main), i.gen) if !gcas(p, pMain, toContracted(ncn, lev), ctrie) && ctrie.readRoot().gen == startGen { cleanParent(p, i, hc, lev, ctrie, startGen) } } } } } func flagPos(hashcode uint32, lev uint, bmp uint32) (uint32, uint32) { idx := (hashcode >> lev) & 0x1f flag := uint32(1) << uint32(idx) mask := uint32(flag - 1) pos := bitCount(bmp 
& mask) return flag, pos } func bitCount(x uint32) uint32 { x -= (x >> 1) & 0x55555555 x = ((x >> 2) & 0x33333333) + (x & 0x33333333) x = ((x >> 4) + x) & 0x0f0f0f0f x *= 0x01010101 return x >> 24 } // gcas is a generation-compare-and-swap which has semantics similar to RDCSS, // but it does not create the intermediate object except in the case of // failures that occur due to the snapshot being taken. This ensures that the // write occurs only if the Ctrie root generation has remained the same in // addition to the I-node having the expected value. func gcas(in *iNode, old, n *mainNode, ct *Ctrie) bool { prevPtr := (*unsafe.Pointer)(unsafe.Pointer(&n.prev)) atomic.StorePointer(prevPtr, unsafe.Pointer(old)) if atomic.CompareAndSwapPointer( (*unsafe.Pointer)(unsafe.Pointer(&in.main)), unsafe.Pointer(old), unsafe.Pointer(n)) { gcasComplete(in, n, ct) return atomic.LoadPointer(prevPtr) == nil } return false } // gcasRead performs a GCAS-linearizable read of the I-node's main node. func gcasRead(in *iNode, ctrie *Ctrie) *mainNode { m := (*mainNode)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&in.main)))) prev := (*mainNode)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&m.prev)))) if prev == nil { return m } return gcasComplete(in, m, ctrie) } // gcasComplete commits the GCAS operation. func gcasComplete(i *iNode, m *mainNode, ctrie *Ctrie) *mainNode { for { if m == nil { return nil } prev := (*mainNode)(atomic.LoadPointer( (*unsafe.Pointer)(unsafe.Pointer(&m.prev)))) root := ctrie.rdcssReadRoot(true) if prev == nil { return m } if prev.failed != nil { // Signals GCAS failure. Swap old value back into I-node. fn := prev.failed if atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&i.main)), unsafe.Pointer(m), unsafe.Pointer(fn)) { return fn } m = (*mainNode)(atomic.LoadPointer( (*unsafe.Pointer)(unsafe.Pointer(&i.main)))) continue } if root.gen == i.gen && !ctrie.readOnly { // Commit GCAS. 
if atomic.CompareAndSwapPointer( (*unsafe.Pointer)(unsafe.Pointer(&m.prev)), unsafe.Pointer(prev), nil) { return m } continue } // Generations did not match. Store failed node on prev to signal // I-node's main node must be set back to the previous value. atomic.CompareAndSwapPointer( (*unsafe.Pointer)(unsafe.Pointer(&m.prev)), unsafe.Pointer(prev), unsafe.Pointer(&mainNode{failed: prev})) m = (*mainNode)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&i.main)))) return gcasComplete(i, m, ctrie) } } // rdcssDescriptor is an intermediate struct which communicates the intent to // replace the value in an I-node and check that the root's generation has not // changed before committing to the new value. type rdcssDescriptor struct { old *iNode expected *mainNode nv *iNode committed int32 } // readRoot performs a linearizable read of the Ctrie root. This operation is // prioritized so that if another thread performs a GCAS on the root, a // deadlock does not occur. func (c *Ctrie) readRoot() *iNode { return c.rdcssReadRoot(false) } // rdcssReadRoot performs a RDCSS-linearizable read of the Ctrie root with the // given priority. func (c *Ctrie) rdcssReadRoot(abort bool) *iNode { r := (*iNode)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&c.root)))) if r.rdcss != nil { return c.rdcssComplete(abort) } return r } // rdcssRoot performs a RDCSS on the Ctrie root. This is used to create a // snapshot of the Ctrie by copying the root I-node and setting it to a new // generation. func (c *Ctrie) rdcssRoot(old *iNode, expected *mainNode, nv *iNode) bool { desc := &iNode{ rdcss: &rdcssDescriptor{ old: old, expected: expected, nv: nv, }, } if c.casRoot(old, desc) { c.rdcssComplete(false) return atomic.LoadInt32(&desc.rdcss.committed) == 1 } return false } // rdcssComplete commits the RDCSS operation. 
func (c *Ctrie) rdcssComplete(abort bool) *iNode {
	for {
		r := (*iNode)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&c.root))))
		if r.rdcss == nil {
			// No descriptor installed: r is the settled root.
			return r
		}

		var (
			desc = r.rdcss
			ov   = desc.old
			exp  = desc.expected
			nv   = desc.nv
		)

		if abort {
			// Roll the root back to the value it had before the descriptor
			// was installed; loop again if another thread got there first.
			if c.casRoot(r, ov) {
				return ov
			}
			continue
		}

		// The swap may only commit if the old root still holds the main
		// node that was expected when the descriptor was created.
		oldeMain := gcasRead(ov, c)
		if oldeMain == exp {
			// Commit the RDCSS.
			if c.casRoot(r, nv) {
				// Publish the outcome for rdcssRoot, which reads this flag.
				atomic.StoreInt32(&desc.committed, 1)
				return nv
			}
			continue
		}
		// The expectation no longer holds: restore the old root.
		if c.casRoot(r, ov) {
			return ov
		}
		continue
	}
}

// casRoot performs a CAS on the Ctrie root. It panics if the Ctrie is a
// read-only snapshot and reports whether the swap from ov to nv took place.
func (c *Ctrie) casRoot(ov, nv *iNode) bool {
	c.assertReadWrite()
	return atomic.CompareAndSwapPointer(
		(*unsafe.Pointer)(unsafe.Pointer(&c.root)), unsafe.Pointer(ov), unsafe.Pointer(nv))
}

================================================
FILE: trie/ctrie/ctrie_test.go
================================================
/*
Copyright 2015 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ package ctrie import ( "hash" "hash/fnv" "strconv" "sync" "testing" "time" "github.com/stretchr/testify/assert" ) func TestCtrie(t *testing.T) { assert := assert.New(t) ctrie := New(nil) _, ok := ctrie.Lookup([]byte("foo")) assert.False(ok) ctrie.Insert([]byte("foo"), "bar") val, ok := ctrie.Lookup([]byte("foo")) assert.True(ok) assert.Equal("bar", val) ctrie.Insert([]byte("fooooo"), "baz") val, ok = ctrie.Lookup([]byte("foo")) assert.True(ok) assert.Equal("bar", val) val, ok = ctrie.Lookup([]byte("fooooo")) assert.True(ok) assert.Equal("baz", val) for i := 0; i < 100; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), "blah") } for i := 0; i < 100; i++ { val, ok = ctrie.Lookup([]byte(strconv.Itoa(i))) assert.True(ok) assert.Equal("blah", val) } val, ok = ctrie.Lookup([]byte("foo")) assert.True(ok) assert.Equal("bar", val) ctrie.Insert([]byte("foo"), "qux") val, ok = ctrie.Lookup([]byte("foo")) assert.True(ok) assert.Equal("qux", val) val, ok = ctrie.Remove([]byte("foo")) assert.True(ok) assert.Equal("qux", val) _, ok = ctrie.Remove([]byte("foo")) assert.False(ok) val, ok = ctrie.Remove([]byte("fooooo")) assert.True(ok) assert.Equal("baz", val) for i := 0; i < 100; i++ { ctrie.Remove([]byte(strconv.Itoa(i))) } } type mockHash32 struct { hash.Hash32 } func (m *mockHash32) Sum32() uint32 { return 0 } func mockHashFactory() hash.Hash32 { return &mockHash32{fnv.New32a()} } func TestInsertLNode(t *testing.T) { assert := assert.New(t) ctrie := New(mockHashFactory) for i := 0; i < 10; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } for i := 0; i < 10; i++ { val, ok := ctrie.Lookup([]byte(strconv.Itoa(i))) assert.True(ok) assert.Equal(i, val) } _, ok := ctrie.Lookup([]byte("11")) assert.False(ok) for i := 0; i < 10; i++ { val, ok := ctrie.Remove([]byte(strconv.Itoa(i))) assert.True(ok) assert.Equal(i, val) } } func TestInsertTNode(t *testing.T) { assert := assert.New(t) ctrie := New(nil) for i := 0; i < 10000; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } for i := 0; i < 
5000; i++ { ctrie.Remove([]byte(strconv.Itoa(i))) } for i := 0; i < 10000; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } for i := 0; i < 10000; i++ { val, ok := ctrie.Lookup([]byte(strconv.Itoa(i))) assert.True(ok) assert.Equal(i, val) } } func TestConcurrency(t *testing.T) { assert := assert.New(t) ctrie := New(nil) var wg sync.WaitGroup wg.Add(2) go func() { for i := 0; i < 10000; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } wg.Done() }() go func() { for i := 0; i < 10000; i++ { val, ok := ctrie.Lookup([]byte(strconv.Itoa(i))) if ok { assert.Equal(i, val) } } wg.Done() }() for i := 0; i < 10000; i++ { time.Sleep(5) ctrie.Remove([]byte(strconv.Itoa(i))) } wg.Wait() } func TestConcurrency2(t *testing.T) { assert := assert.New(t) ctrie := New(nil) var wg sync.WaitGroup wg.Add(4) go func() { for i := 0; i < 10000; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } wg.Done() }() go func() { for i := 0; i < 10000; i++ { val, ok := ctrie.Lookup([]byte(strconv.Itoa(i))) if ok { assert.Equal(i, val) } } wg.Done() }() go func() { for i := 0; i < 10000; i++ { ctrie.Snapshot() } wg.Done() }() go func() { for i := 0; i < 10000; i++ { ctrie.ReadOnlySnapshot() } wg.Done() }() wg.Wait() assert.Equal(uint(10000), ctrie.Size()) } func TestSnapshot(t *testing.T) { assert := assert.New(t) ctrie := New(nil) for i := 0; i < 100; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } snapshot := ctrie.Snapshot() // Ensure snapshot contains expected keys. for i := 0; i < 100; i++ { val, ok := snapshot.Lookup([]byte(strconv.Itoa(i))) assert.True(ok) assert.Equal(i, val) } // Now remove the values from the original. for i := 0; i < 100; i++ { ctrie.Remove([]byte(strconv.Itoa(i))) } // Ensure snapshot was unaffected by removals. for i := 0; i < 100; i++ { val, ok := snapshot.Lookup([]byte(strconv.Itoa(i))) assert.True(ok) assert.Equal(i, val) } // New Ctrie and snapshot. 
ctrie = New(nil) for i := 0; i < 100; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } snapshot = ctrie.Snapshot() // Ensure snapshot is mutable. for i := 0; i < 100; i++ { snapshot.Remove([]byte(strconv.Itoa(i))) } snapshot.Insert([]byte("bat"), "man") for i := 0; i < 100; i++ { _, ok := snapshot.Lookup([]byte(strconv.Itoa(i))) assert.False(ok) } val, ok := snapshot.Lookup([]byte("bat")) assert.True(ok) assert.Equal("man", val) // Ensure original Ctrie was unaffected. for i := 0; i < 100; i++ { val, ok := ctrie.Lookup([]byte(strconv.Itoa(i))) assert.True(ok) assert.Equal(i, val) } _, ok = ctrie.Lookup([]byte("bat")) assert.False(ok) // Ensure snapshots-of-snapshots work as expected. snapshot2 := snapshot.Snapshot() for i := 0; i < 100; i++ { _, ok := snapshot2.Lookup([]byte(strconv.Itoa(i))) assert.False(ok) } val, ok = snapshot2.Lookup([]byte("bat")) assert.True(ok) assert.Equal("man", val) snapshot2.Remove([]byte("bat")) _, ok = snapshot2.Lookup([]byte("bat")) assert.False(ok) val, ok = snapshot.Lookup([]byte("bat")) assert.True(ok) assert.Equal("man", val) } func TestReadOnlySnapshot(t *testing.T) { assert := assert.New(t) ctrie := New(nil) for i := 0; i < 100; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } snapshot := ctrie.ReadOnlySnapshot() // Ensure snapshot contains expected keys. for i := 0; i < 100; i++ { val, ok := snapshot.Lookup([]byte(strconv.Itoa(i))) assert.True(ok) assert.Equal(i, val) } for i := 0; i < 50; i++ { ctrie.Remove([]byte(strconv.Itoa(i))) } // Ensure snapshot was unaffected by removals. for i := 0; i < 100; i++ { val, ok := snapshot.Lookup([]byte(strconv.Itoa(i))) assert.True(ok) assert.Equal(i, val) } // Ensure read-only snapshots panic on writes. func() { defer func() { assert.NotNil(recover()) }() snapshot.Remove([]byte("blah")) }() // Ensure snapshots-of-snapshots work as expected. 
snapshot2 := snapshot.Snapshot() for i := 50; i < 100; i++ { ctrie.Remove([]byte(strconv.Itoa(i))) } for i := 0; i < 100; i++ { val, ok := snapshot2.Lookup([]byte(strconv.Itoa(i))) assert.True(ok) assert.Equal(i, val) } // Ensure snapshots of read-only snapshots panic on writes. func() { defer func() { assert.NotNil(recover()) }() snapshot2.Remove([]byte("blah")) }() } func TestIterator(t *testing.T) { assert := assert.New(t) ctrie := New(nil) for i := 0; i < 10; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } expected := map[string]int{ "0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, } count := 0 for entry := range ctrie.Iterator(nil) { exp, ok := expected[string(entry.Key)] if assert.True(ok) { assert.Equal(exp, entry.Value) } count++ } assert.Equal(len(expected), count) // Closing cancel channel should close iterator channel. cancel := make(chan struct{}) iter := ctrie.Iterator(cancel) entry := <-iter exp, ok := expected[string(entry.Key)] if assert.True(ok) { assert.Equal(exp, entry.Value) } close(cancel) // Drain anything already put on the channel. Since select chooses a // pseudo-random case, we must attempt to drain for every item. for _ = range expected { <-iter } _, ok = <-iter assert.False(ok) } // TestIteratorCoversTNodes reproduces the scenario of a bug where tNodes weren't being traversed. func TestIteratorCoversTNodes(t *testing.T) { assert := assert.New(t) ctrie := New(mockHashFactory) // Add a pair of keys that collide (because we're using the mock hash). ctrie.Insert([]byte("a"), true) ctrie.Insert([]byte("b"), true) // Delete one key, leaving exactly one sNode in the cNode. This will // trigger creation of a tNode. 
ctrie.Remove([]byte("b")) seenKeys := map[string]bool{} for entry := range ctrie.Iterator(nil) { seenKeys[string(entry.Key)] = true } assert.Contains(seenKeys, "a", "Iterator did not return 'a'.") assert.Len(seenKeys, 1) } func TestSize(t *testing.T) { ctrie := New(nil) for i := 0; i < 10; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } assert.Equal(t, uint(10), ctrie.Size()) } func TestClear(t *testing.T) { assert := assert.New(t) ctrie := New(nil) for i := 0; i < 10; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } assert.Equal(uint(10), ctrie.Size()) snapshot := ctrie.Snapshot() ctrie.Clear() assert.Equal(uint(0), ctrie.Size()) assert.Equal(uint(10), snapshot.Size()) } type fakehash struct{} func (h *fakehash) Sum32() uint32 { return 42 } func (h *fakehash) Sum(b []byte) []byte { return nil } func (h *fakehash) Size() int { return 0 } func (h *fakehash) BlockSize() int { return 0 } func (h *fakehash) Reset() { } func (h *fakehash) Write(b []byte) (int, error) { return 0, nil } func factory() hash.Hash32 { return &fakehash{} } func TestHashCollision(t *testing.T) { trie := New(factory) trie.Insert([]byte("foobar"), 1) trie.Insert([]byte("zogzog"), 2) trie.Insert([]byte("foobar"), 3) val, exists := trie.Lookup([]byte("foobar")) assert.True(t, exists) assert.Equal(t, 3, val) trie.Remove([]byte("foobar")) _, exists = trie.Lookup([]byte("foobar")) assert.False(t, exists) } func BenchmarkInsert(b *testing.B) { ctrie := New(nil) b.ResetTimer() for i := 0; i < b.N; i++ { ctrie.Insert([]byte("foo"), 0) } } func BenchmarkLookup(b *testing.B) { numItems := 1000 ctrie := New(nil) for i := 0; i < numItems; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } key := []byte(strconv.Itoa(numItems / 2)) b.ResetTimer() for i := 0; i < b.N; i++ { ctrie.Lookup(key) } } func BenchmarkRemove(b *testing.B) { numItems := 1000 ctrie := New(nil) for i := 0; i < numItems; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } key := []byte(strconv.Itoa(numItems / 2)) b.ResetTimer() for i := 0; i < 
b.N; i++ { ctrie.Remove(key) } } func BenchmarkSnapshot(b *testing.B) { numItems := 1000 ctrie := New(nil) for i := 0; i < numItems; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } b.ResetTimer() for i := 0; i < b.N; i++ { ctrie.Snapshot() } } func BenchmarkReadOnlySnapshot(b *testing.B) { numItems := 1000 ctrie := New(nil) for i := 0; i < numItems; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) } b.ResetTimer() for i := 0; i < b.N; i++ { ctrie.ReadOnlySnapshot() } } ================================================ FILE: trie/dtrie/dtrie.go ================================================ /* Copyright (c) 2016, Theodore Butler All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/ // Package dtrie provides an implementation of the dtrie data structure, which // is a persistent hash trie that dynamically expands or shrinks to provide // efficient memory allocation. This data structure is based on the papers // Ideal Hash Trees by Phil Bagwell and Optimizing Hash-Array Mapped Tries for // Fast and Lean Immutable JVM Collections by Michael J. Steindorfer and // Jurgen J. Vinju package dtrie // Dtrie is a persistent hash trie that dynamically expands or shrinks // to provide efficient memory allocation. type Dtrie struct { root *node hasher func(v interface{}) uint32 } type entry struct { hash uint32 key interface{} value interface{} } func (e *entry) KeyHash() uint32 { return e.hash } func (e *entry) Key() interface{} { return e.key } func (e *entry) Value() interface{} { return e.value } // New creates an empty DTrie with the given hashing function. // If nil is passed in, the default hashing function will be used. func New(hasher func(v interface{}) uint32) *Dtrie { if hasher == nil { hasher = defaultHasher } return &Dtrie{ root: emptyNode(0, 32), hasher: hasher, } } // Size returns the number of entries in the Dtrie. func (d *Dtrie) Size() (size int) { for _ = range iterate(d.root, nil) { size++ } return size } // Get returns the value for the associated key or returns nil if the // key does not exist. func (d *Dtrie) Get(key interface{}) interface{} { node := get(d.root, d.hasher(key), key) if node != nil { return node.Value() } return nil } // Insert adds a key value pair to the Dtrie, replacing the existing value if // the key already exists and returns the resulting Dtrie. func (d *Dtrie) Insert(key, value interface{}) *Dtrie { root := insert(d.root, &entry{d.hasher(key), key, value}) return &Dtrie{root, d.hasher} } // Remove deletes the value for the associated key if it exists and returns // the resulting Dtrie. 
func (d *Dtrie) Remove(key interface{}) *Dtrie { root := remove(d.root, d.hasher(key), key) return &Dtrie{root, d.hasher} } // Iterator returns a read-only channel of Entries from the Dtrie. If a stop // channel is provided, closing it will terminate and close the iterator // channel. Note that if a cancel channel is not used and not every entry is // read from the iterator, a goroutine will leak. func (d *Dtrie) Iterator(stop <-chan struct{}) <-chan Entry { return iterate(d.root, stop) } ================================================ FILE: trie/dtrie/dtrie_test.go ================================================ /* Copyright (c) 2016, Theodore Butler All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/ package dtrie import ( "testing" "github.com/stretchr/testify/assert" ) func TestDefaultHasher(t *testing.T) { assert.Equal(t, defaultHasher(map[int]string{11234: "foo"}), defaultHasher(map[int]string{11234: "foo"})) assert.NotEqual(t, defaultHasher("foo"), defaultHasher("bar")) } func collisionHash(key interface{}) uint32 { return uint32(0xffffffff) // for testing collisions } func TestInsert(t *testing.T) { insertTest(t, defaultHasher, 10000) insertTest(t, collisionHash, 1000) } func insertTest(t *testing.T, hashfunc func(interface{}) uint32, count int) *node { n := emptyNode(0, 32) for i := 0; i < count; i++ { n = insert(n, &entry{hashfunc(i), i, i}) } return n } func TestGet(t *testing.T) { getTest(t, defaultHasher, 10000) getTest(t, collisionHash, 1000) } func getTest(t *testing.T, hashfunc func(interface{}) uint32, count int) { n := insertTest(t, hashfunc, count) for i := 0; i < count; i++ { x := get(n, hashfunc(i), i) assert.Equal(t, i, x.Value()) } } func TestRemove(t *testing.T) { removeTest(t, defaultHasher, 10000) removeTest(t, collisionHash, 1000) } func removeTest(t *testing.T, hashfunc func(interface{}) uint32, count int) { n := insertTest(t, hashfunc, count) for i := 0; i < count; i++ { n = remove(n, hashfunc(i), i) } for _, e := range n.entries { if e != nil { t.Fatal("final node is not empty") } } } func TestUpdate(t *testing.T) { updateTest(t, defaultHasher, 10000) updateTest(t, collisionHash, 1000) } func updateTest(t *testing.T, hashfunc func(interface{}) uint32, count int) { n := insertTest(t, hashfunc, count) for i := 0; i < count; i++ { n = insert(n, &entry{hashfunc(i), i, -i}) } } func TestIterate(t *testing.T) { n := insertTest(t, defaultHasher, 10000) echan := iterate(n, nil) c := 0 for _ = range echan { c++ } assert.Equal(t, 10000, c) // test with stop chan c = 0 stop := make(chan struct{}) echan = iterate(n, stop) for _ = range echan { c++ if c == 100 { close(stop) break } } assert.True(t, c == 100) // test with collisions n = 
insertTest(t, collisionHash, 1000) c = 0 echan = iterate(n, nil) for _ = range echan { c++ } assert.Equal(t, 1000, c) } func TestSize(t *testing.T) { n := insertTest(t, defaultHasher, 10000) d := &Dtrie{n, defaultHasher} assert.Equal(t, 10000, d.Size()) } func BenchmarkInsert(b *testing.B) { b.ReportAllocs() n := emptyNode(0, 32) b.ResetTimer() for i := b.N; i > 0; i-- { n = insert(n, &entry{defaultHasher(i), i, i}) } } func BenchmarkGet(b *testing.B) { b.ReportAllocs() n := insertTest(nil, defaultHasher, b.N) b.ResetTimer() for i := b.N; i > 0; i-- { get(n, defaultHasher(i), i) } } func BenchmarkRemove(b *testing.B) { b.ReportAllocs() n := insertTest(nil, defaultHasher, b.N) b.ResetTimer() for i := b.N; i > 0; i-- { n = remove(n, defaultHasher(i), i) } } func BenchmarkUpdate(b *testing.B) { b.ReportAllocs() n := insertTest(nil, defaultHasher, b.N) b.ResetTimer() for i := b.N; i > 0; i-- { n = insert(n, &entry{defaultHasher(i), i, -i}) } } ================================================ FILE: trie/dtrie/node.go ================================================ /* Copyright (c) 2016, Theodore Butler All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ package dtrie import ( "fmt" "sync" "github.com/Workiva/go-datastructures/bitarray" ) type node struct { entries []Entry nodeMap bitarray.Bitmap32 dataMap bitarray.Bitmap32 level uint8 // level starts at 0 } func (n *node) KeyHash() uint32 { return 0 } func (n *node) Key() interface{} { return nil } func (n *node) Value() interface{} { return nil } func (n *node) String() string { return fmt.Sprint(n.entries) } type collisionNode struct { entries []Entry } func (n *collisionNode) KeyHash() uint32 { return 0 } func (n *collisionNode) Key() interface{} { return nil } func (n *collisionNode) Value() interface{} { return nil } func (n *collisionNode) String() string { return fmt.Sprintf("%v", len(n.entries), n.entries) } // Entry defines anything held within the data structure type Entry interface { KeyHash() uint32 Key() interface{} Value() interface{} } func emptyNode(level uint8, capacity int) *node { return &node{entries: make([]Entry, capacity), level: level} } func insert(n *node, entry Entry) *node { index := uint(mask(entry.KeyHash(), n.level)) newNode := n if newNode.level == 6 { // handle hash collisions on 6th level if newNode.entries[index] == nil { newNode.entries[index] = entry newNode.dataMap = newNode.dataMap.SetBit(index) return newNode } if newNode.dataMap.GetBit(index) { if newNode.entries[index].Key() == entry.Key() { newNode.entries[index] = entry return newNode } cNode := &collisionNode{entries: make([]Entry, 2)} cNode.entries[0] = 
newNode.entries[index] cNode.entries[1] = entry newNode.entries[index] = cNode newNode.dataMap = newNode.dataMap.ClearBit(index) return newNode } cNode := newNode.entries[index].(*collisionNode) cNode.entries = append(cNode.entries, entry) return newNode } if !newNode.dataMap.GetBit(index) && !newNode.nodeMap.GetBit(index) { // insert directly newNode.entries[index] = entry newNode.dataMap = newNode.dataMap.SetBit(index) return newNode } if newNode.nodeMap.GetBit(index) { // insert into sub-node newNode.entries[index] = insert(newNode.entries[index].(*node), entry) return newNode } if newNode.entries[index].Key() == entry.Key() { newNode.entries[index] = entry return newNode } // create new node with the new and existing entries var subNode *node if newNode.level == 5 { // only 2 bits left at level 6 (4 possible indices) subNode = emptyNode(newNode.level+1, 4) } else { subNode = emptyNode(newNode.level+1, 32) } subNode = insert(subNode, newNode.entries[index]) subNode = insert(subNode, entry) newNode.dataMap = newNode.dataMap.ClearBit(index) newNode.nodeMap = newNode.nodeMap.SetBit(index) newNode.entries[index] = subNode return newNode } // returns nil if not found func get(n *node, keyHash uint32, key interface{}) Entry { index := uint(mask(keyHash, n.level)) if n.dataMap.GetBit(index) { return n.entries[index] } if n.nodeMap.GetBit(index) { return get(n.entries[index].(*node), keyHash, key) } if n.level == 6 { // get from collisionNode if n.entries[index] == nil { return nil } cNode := n.entries[index].(*collisionNode) for _, e := range cNode.entries { if e.Key() == key { return e } } } return nil } func remove(n *node, keyHash uint32, key interface{}) *node { index := uint(mask(keyHash, n.level)) newNode := n if n.dataMap.GetBit(index) { newNode.entries[index] = nil newNode.dataMap = newNode.dataMap.ClearBit(index) return newNode } if n.nodeMap.GetBit(index) { subNode := newNode.entries[index].(*node) subNode = remove(subNode, keyHash, key) // compress if only 1 
entry exists in sub-node if subNode.nodeMap.PopCount() == 0 && subNode.dataMap.PopCount() == 1 { var e Entry for i := uint(0); i < 32; i++ { if subNode.dataMap.GetBit(i) { e = subNode.entries[i] break } } newNode.entries[index] = e newNode.nodeMap = newNode.nodeMap.ClearBit(index) newNode.dataMap = newNode.dataMap.SetBit(index) } newNode.entries[index] = subNode return newNode } if n.level == 6 { // delete from collisionNode cNode := newNode.entries[index].(*collisionNode) for i, e := range cNode.entries { if e.Key() == key { cNode.entries = append(cNode.entries[:i], cNode.entries[i+1:]...) break } } // compress if only 1 entry exists in collisionNode if len(cNode.entries) == 1 { newNode.entries[index] = cNode.entries[0] newNode.dataMap = newNode.dataMap.SetBit(index) } return newNode } return n } func iterate(n *node, stop <-chan struct{}) <-chan Entry { out := make(chan Entry) go func() { defer close(out) pushEntries(n, stop, out) }() return out } func pushEntries(n *node, stop <-chan struct{}, out chan Entry) { var wg sync.WaitGroup for i, e := range n.entries { select { case <-stop: return default: index := uint(i) switch { case n.dataMap.GetBit(index): out <- e case n.nodeMap.GetBit(index): wg.Add(1) go func() { defer wg.Done() pushEntries(e.(*node), stop, out) }() wg.Wait() case n.level == 6 && e != nil: for _, ce := range n.entries[index].(*collisionNode).entries { select { case <-stop: return default: out <- ce } } } } } } ================================================ FILE: trie/dtrie/util.go ================================================ /* Copyright (c) 2016, Theodore Butler All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
// mask extracts the 5-bit slot index that hash selects at the given trie
// level: level 0 uses the lowest five bits, each further level the next
// five. From level 6 on fewer than five bits of a 32-bit hash remain.
func mask(hash uint32, level uint8) uint32 {
	shift := 5 * level
	return (hash >> shift) & 0x01f
}

// defaultHasher is the hashing function used when New is given nil.
// Integer and floating-point values are converted straight to uint32; every
// other value is hashed with 32-bit FNV-1a over its Go-syntax (%#v)
// representation.
func defaultHasher(value interface{}) uint32 {
	switch num := value.(type) {
	case uint8:
		return uint32(num)
	case uint16:
		return uint32(num)
	case uint32:
		return num
	case uint64:
		return uint32(num)
	case int8:
		return uint32(num)
	case int16:
		return uint32(num)
	case int32:
		return uint32(num)
	case int64:
		return uint32(num)
	case uint:
		return uint32(num)
	case int:
		return uint32(num)
	case uintptr:
		return uint32(num)
	case float32:
		return uint32(num)
	case float64:
		return uint32(num)
	}

	digest := fnv.New32a()
	// Formatting directly into the hash feeds it the same bytes as hashing
	// the Sprintf result would.
	fmt.Fprintf(digest, "%#v", value)
	return digest.Sum32()
}
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package xfast // Entries is a typed list of Entry interfaces. type Entries []Entry // Iterator will iterate of the results of a query. type Iterator struct { n *node first bool } // Next will return a bool indicating if another value exists // in the iterator. func (iter *Iterator) Next() bool { if iter.first { iter.first = false return iter.n != nil } iter.n = iter.n.children[1] return iter.n != nil } // Value will return the Entry representing the iterator's current position. // If no Entry exists at the present condition, the iterator is // exhausted and this method will return nil. func (iter *Iterator) Value() Entry { if iter.n == nil { return nil } return iter.n.entry } // exhaust is a helper function that will exhaust this iterator // and return a list of entries. This is for internal use only. func (iter *Iterator) exhaust() Entries { entries := make(Entries, 0, 100) for it := iter; it.Next(); { entries = append(entries, it.Value()) } return entries } ================================================ FILE: trie/xfast/iterator_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ package xfast import ( "testing" "github.com/stretchr/testify/assert" ) func TestIterator(t *testing.T) { iter := &Iterator{ first: true, } assert.False(t, iter.Next()) assert.Nil(t, iter.Value()) e1 := newMockEntry(5) n1 := newNode(nil, e1) iter = &Iterator{ first: true, n: n1, } assert.True(t, iter.Next()) assert.Equal(t, e1, iter.Value()) assert.False(t, iter.Next()) assert.Nil(t, iter.Value()) e2 := newMockEntry(10) n2 := newNode(nil, e2) n1.children[1] = n2 iter = &Iterator{ first: true, n: n1, } assert.True(t, iter.Next()) assert.True(t, iter.Next()) assert.Equal(t, e2, iter.Value()) assert.False(t, iter.Next()) assert.Nil(t, iter.Value()) } ================================================ FILE: trie/xfast/mock_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package xfast import "github.com/stretchr/testify/mock" type mockEntry struct { mock.Mock } func (me *mockEntry) Key() uint64 { args := me.Called() return args.Get(0).(uint64) } func newMockEntry(key uint64) *mockEntry { me := new(mockEntry) me.On(`Key`).Return(key) return me } ================================================ FILE: trie/xfast/xfast.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* Package xfast provides access to a sorted tree that treats integers as if they were words of m bits, where m can be 8, 16, 32, or 64. The advantage to storing integers as a trie of words is that operations can be performed in constant time depending on the size of the universe and not on the number of items in the trie. The time complexity is as follows: Space: O(n log M) Insert: O(log M) Delete: O(log M) Search: O(log log M) Get: O(1) where n is the number of items in the trie and M is the size of the universe, ie, 2^63-1 for 64 bit ints. As you can see, for 64 bit ints, inserts and deletes can be performed in O(64), constant time which provides very predictable behavior in the best case. A get by key can be performed in O(1) time and searches can be performed in O(6) time for 64 bit integers. While x-tries have relatively slow insert, deletions, and consume a large amount of space, they form the top half of a y-fast trie which can insert and delete in O(log log M) time and consumes O(n) space. 
*/ package xfast import "fmt" // isInternal returns a bool indicating if the provided // node is an internal node, that is, non-leaf node. func isInternal(n *node) bool { if n == nil { return false } return n.entry == nil } // hasInternal returns a bool indicating if the provided // node has a child that is an internal node. func hasInternal(n *node) bool { return isInternal(n.children[0]) || isInternal(n.children[1]) } // isLeaf returns a bool indicating if the provided node // is a leaf node, that is, has a valid entry value. func isLeaf(n *node) bool { if n == nil { return false } return !isInternal(n) } // Entry defines items that can be inserted into the x-fast // trie. type Entry interface { // Key is the key for this entry. If the trie has been // given bit size n, only the last n bits of this key // will matter. Use a bit size of 64 to enable all // 2^64-1 keys. Key() uint64 } // masks are used to determine the prefix of any given key. The masks // are stored in a [64] array where each position of the array represents // a bitmask to the ith bit. For example, if you wanted to mask the first // bit of a 64-bit int you'd and it with masks[0]. If you wanted to mask // the first bit of an 8 bit key, you'd have to shift 56 bits to the right // and perform the mask operation. This array is immutable and should not // be changed after initialization. var masks = func() [64]uint64 { // we don't technically need the last mask, this is just to be consistent masks := [64]uint64{} mask := uint64(0) for i := uint64(0); i < 64; i++ { mask = mask | 1<<(63-i) masks[i] = mask } return masks }() // positions are similar to masks and that the positions array allows // us to determine if a node should go left or right at a specific bit // position of the key. Basically, this array stores every 2^n number // where n is in [0, 64). This array is immutable and should not be // changed after initialization. 
var positions = func() [64]uint64 { positions := [64]uint64{} for i := uint64(0); i < 64; i++ { positions[i] = uint64(1 << (63 - i)) } return positions }() type node struct { // entry will store the entry for this node. Is nil for // every internal node and non-nil for all leaves. It is // how the internal/leaf function helpers determine the // position of this node. entry Entry // children stores the left and right child of this node. // At any time, and at any layer, it's possible for a pointer // to a child to point to a leaf due to threading. children [2]*node // i hate this, but it is really the best way // to walk up successor and predecessor threads parent *node } // newNode will allocate and initialize a newNode with the provided // parent and entry. Parent should never be nil, but entry may be // if constructing an internal node. func newNode(parent *node, entry Entry) *node { return &node{ children: [2]*node{}, entry: entry, parent: parent, } } // binarySearchHashMaps will perform a binary search of the provided // maps to return a node that matches the longest prefix of the provided // key. This will return nil if a match could not be found, which would // also return layer 0. Layer information is useful when determining the // distance from the provided node to the leaves. func binarySearchHashMaps(layers []map[uint64]*node, key uint64) (int, *node) { low, high := 0, len(layers)-1 diff := 64 - len(layers) var mid int var node *node for low <= high { mid = (low + high) / 2 n, ok := layers[mid][key&masks[diff+mid]] if ok { node = n low = mid + 1 } else { high = mid - 1 } } return low, node } // whichSide returns an int representing the side on which // the node resides in its parent. NOTE: this function will panic // if the child does not within the parent. This situation should // should be caught as early as possible as if it happens data // coming from the x-fast trie cannot be trusted. 
func whichSide(n, parent *node) int { if parent.children[0] == n { return 0 } if parent.children[1] == n { return 1 } panic(fmt.Sprintf(`Node: %+v, %p not a child of: %+v, %p`, n, n, parent, parent)) } // XFastTrie is a datastructure for storing integers in a known // universe, where universe size is determined by the bit size // of the desired keys. This structure should be faster than // binary search tries for very large datasets and slower for // smaller datasets. type XFastTrie struct { // layers stores the hashmaps of the individual layers of the trie. // The hashmaps store prefixes, allowing use to do a binary search // of these maps before visiting the trie for successor/predecessor // queries. layers []map[uint64]*node // root is a pointer to the first node of the trie, which actually // adds an additional layer, ie, instead of 64 layers for a // uint64, this will cause the number of layers to be 65. root *node // num is the number of items in the trie. num uint64 // bits represents the number of bits of the keys this trie // expects. Because the time complexity of operations is // dependent upon universe size, smaller sized keys will // actually cause the trie to be faster. Diff is the difference // between the desired bit size and 64 as we have to offset // in the position and mask arrays. bits, diff uint8 // min and max index the lowest and highest seen keys respectively. // this immediately allows us to check a desired key against // constraints and allows min/max operations to be performed // in O(1) time. min, max *node } // init will initialize the XFastTrie with the provided byte-size. // I'd prefer generics here, but it is what it is. We expect uints // here when ints would perform just as well, but the public methods // on the XFastTrie all expect uint64, so we expect a uint in the // constructor for consistency's sake. 
func (xft *XFastTrie) init(intType interface{}) {
	bits := uint8(0)
	// only the dynamic type of intType matters, its value is ignored.
	switch intType.(type) {
	case uint8:
		bits = 8
	case uint16:
		bits = 16
	case uint32:
		bits = 32
	case uint, uint64:
		bits = 64
	default:
		// we'll panic with a bad value to the constructor.
		panic(`Invalid universe size provided.`)
	}

	// one hashmap per bit of the key.
	xft.layers = make([]map[uint64]*node, bits)
	xft.bits = bits
	// offset into the 64-entry position and mask arrays.
	xft.diff = 64 - bits
	for i := uint8(0); i < bits; i++ {
		xft.layers[i] = make(map[uint64]*node, 50) // we can obviously be more intelligent about this.
	}
	xft.num = 0
	xft.root = newNode(nil, nil)
}

// Exists returns a bool indicating if the provided
// key exists in the trie. This is an O(1) operation.
func (xft *XFastTrie) Exists(key uint64) bool {
	// the bottom hashmap of the trie has every entry
	// in it.
	_, ok := xft.layers[xft.bits-1][key]
	return ok
}

// Len returns the number of items in this trie. This is an
// O(1) operation.
func (xft *XFastTrie) Len() uint64 {
	return xft.num
}

// Max will return the highest keyed value in the trie, or nil
// if the trie is empty. This is an O(1) operation.
func (xft *XFastTrie) Max() Entry {
	if xft.max == nil {
		return nil
	}

	return xft.max.entry
}

// Min will return the lowest keyed value in the trie, or nil
// if the trie is empty. This is an O(1) operation.
func (xft *XFastTrie) Min() Entry {
	if xft.min == nil {
		return nil
	}

	return xft.min.entry
}

// insert will add the provided entry to the trie or overwrite the existing
// entry if it exists.
func (xft *XFastTrie) insert(entry Entry) {
	key := entry.Key() // cached so the interface method isn't called over and over
	n := xft.layers[xft.bits-1][key]
	if n != nil {
		// key already present: swap the entry in place, the
		// structure of the trie does not change.
		n.entry = entry
		return
	}

	// we need to find a predecessor or successor if it exists
	// to help us set threads later in this method.
	var predecessor, successor *node
	if xft.min != nil && key < xft.min.entry.Key() {
		successor = xft.min
	} else {
		successor = xft.successor(key)
	}

	// we only need to search for the predecessor if successor is nil;
	// otherwise the successor's left thread gives us the predecessor
	// (if one exists) further down in this method.
	if successor == nil {
		if xft.max != nil && key > xft.max.entry.Key() {
			predecessor = xft.max
		} else {
			predecessor = xft.predecessor(key)
		}
	}

	// find the deepest root with a matching prefix, this should
	// save us some time, assuming the hashmap has perfect hashing.
	layer, root := binarySearchHashMaps(xft.layers, key)
	if root == nil {
		n = xft.root
		layer = 0
	} else {
		n = root
	}

	var leftOrRight uint64
	// from the existing node, create new nodes.
	for i := uint8(layer); i < xft.bits; i++ { // on 0th, this will be root
		// find out if we need to go left or right
		leftOrRight = (key & positions[xft.diff+i]) >> (xft.bits - 1 - i)
		// a leaf in a child slot is only a thread, so it gets
		// replaced by a real child just like an empty slot.
		if n.children[leftOrRight] == nil || isLeaf(n.children[leftOrRight]) {
			var nn *node
			if i < xft.bits-1 {
				nn = newNode(n, nil)
			} else {
				// bottom layer: this is the new leaf.
				nn = newNode(n, entry)
				xft.num++
			}

			n.children[leftOrRight] = nn
			xft.layers[i][key&masks[xft.diff+i]] = nn // prefix for this layer
		}

		n = n.children[leftOrRight]
	}

	// we need to put the new node where it belongs in the doubly-linked
	// list comprised of all the leaves.
	if successor != nil { // we have to walk predecessor and successor threads
		predecessor = successor.children[0]
		if predecessor != nil {
			predecessor.children[1] = n
			n.children[0] = predecessor
		}
		n.children[1] = successor
		successor.children[0] = n
	} else if predecessor != nil {
		n.children[0] = predecessor
		predecessor.children[1] = n
	}

	// walk up the successor if it exists to set that branch's new
	// predecessor.
	if successor != nil {
		xft.walkUpSuccessor(root, n, successor)
	}

	// walk up the predecessor if it exists to set that branch's
	// new successor.
	if predecessor != nil {
		xft.walkUpPredecessor(root, n, predecessor)
	}

	// finally, walk up our own branch to set both successors
	// and predecessors.
	xft.walkUpNode(root, n, predecessor, successor)

	// and then do a final check against the min/max indices.
	if xft.max == nil || key > xft.max.entry.Key() {
		xft.max = n
	}

	if xft.min == nil || key < xft.min.entry.Key() {
		xft.min = n
	}
}

// walkUpSuccessor will walk up the successor branch setting
// the predecessor where possible. The walk stops when it reaches
// the provided root (the common ancestor of successor and node)
// or runs out of parents.
func (xft *XFastTrie) walkUpSuccessor(root, node, successor *node) {
	n := successor.parent
	for n != nil && n != root {
		// we don't really want to overwrite existing internal nodes,
		// or where the child is a leaf that is the successor
		if !isInternal(n.children[0]) && n.children[0] != successor {
			n.children[0] = node
		}
		n = n.parent
	}
}

// walkUpPredecessor will walk up the predecessor branch setting
// the successor where possible. The walk stops when it reaches
// the provided root (the common ancestor of predecessor and node)
// or runs out of parents.
func (xft *XFastTrie) walkUpPredecessor(root, node, predecessor *node) {
	n := predecessor.parent
	for n != nil && n != root {
		// mirror image of walkUpSuccessor: leave internal children
		// and the predecessor leaf itself untouched.
		if !isInternal(n.children[1]) && n.children[1] != predecessor {
			n.children[1] = node
		}
		n = n.parent
	}
}

// walkUpNode will walk up the newly created branch and set predecessor
// and successor where possible. If predecessor or successor are nil,
// this will set nil where possible.
func (xft *XFastTrie) walkUpNode(root, node, predecessor, successor *node) {
	n := node.parent
	for n != nil && n != root {
		// a slot is rewritten only if it holds neither an internal
		// node nor the new leaf itself.
		if !isInternal(n.children[1]) && n.children[1] != successor && n.children[1] != node {
			n.children[1] = successor
		}
		if !isInternal(n.children[0]) && n.children[0] != predecessor && n.children[0] != node {
			n.children[0] = predecessor
		}
		n = n.parent
	}
}

// Insert will insert the provided entries into the trie. Any entry
// with an existing key will cause an overwrite. This is an O(log M)
// operation, for each entry.
func (xft *XFastTrie) Insert(entries ...Entry) {
	for _, e := range entries {
		xft.insert(e)
	}
}

// delete removes the leaf for the provided key, prunes every
// ancestor that is left without an internal child, and re-links
// the threads of the neighbouring leaves. It is a no-op if the
// key is not in the trie.
func (xft *XFastTrie) delete(key uint64) {
	n := xft.layers[xft.bits-1][key]
	if n == nil { // there's no matching k, v pair
		return
	}

	// remember the neighbours in the leaf list before unlinking.
	successor, predecessor := n.children[1], n.children[0]

	// i counts how many layers above the leaf layer we are.
	i := uint8(1)
	delete(xft.layers[xft.bits-1], key)
	leftOrRight := whichSide(n, n.parent)
	n.parent.children[leftOrRight] = nil
	n.children[0], n.children[1] = nil, nil
	n = n.parent
	// a neighbouring leaf that shares the deleted leaf's parent keeps
	// that parent alive.
	hasImmediateSibling := false
	if successor != nil && successor.parent == n {
		hasImmediateSibling = true
	}
	if predecessor != nil && predecessor.parent == n {
		hasImmediateSibling = true
	}

	// this loop will kill any nodes that no longer link to internal
	// nodes
	for n != nil && n.parent != nil {
		// if we have an internal node remaining we should abort
		// now as no further node will be removed. We should also
		// abort if the first parent of the deleted leaf still holds
		// the predecessor or successor, ie, the leaf had an
		// immediate sibling.
		if hasInternal(n) || (i == 1 && hasImmediateSibling) {
			n = n.parent // we had one side deleted, need to set the other
			break
		}
		leftOrRight = whichSide(n, n.parent)
		n.parent.children[leftOrRight] = nil
		n.children[0], n.children[1] = nil, nil
		// drop this node's prefix from the hashmap of its layer.
		delete(xft.layers[xft.bits-i-1], key&masks[len(masks)-1-int(i)])
		n = n.parent
		i++
	}

	// we need to check now and update threads, both in the leaves
	// and in their branches
	if predecessor != nil {
		predecessor.children[1] = successor
		xft.walkUpPredecessor(n, successor, predecessor)
	}

	if successor != nil {
		successor.children[0] = predecessor
		xft.walkUpSuccessor(n, predecessor, successor)
	}

	// check max/min indices
	if xft.max.entry.Key() == key {
		xft.max = predecessor
	}

	if xft.min.entry.Key() == key {
		xft.min = successor
	}

	// decrement number of nodes
	xft.num--
}

// Delete will delete the provided keys from the trie. If an entry
// associated with a provided key cannot be found, that deletion is
// a no-op. Each deletion is an O(log M) operation.
func (xft *XFastTrie) Delete(keys ...uint64) {
	for i := range keys {
		xft.delete(keys[i])
	}
}

// predecessor will find the node equal to or immediately less
// than the provided key. It returns nil if no such node exists.
func (xft *XFastTrie) predecessor(key uint64) *node {
	if xft.root == nil || xft.max == nil { // no predecessor if no nodes
		return nil
	}

	// the cached bounds answer out-of-range keys immediately.
	switch {
	case key >= xft.max.entry.Key():
		return xft.max
	case key < xft.min.entry.Key():
		return nil
	}

	// exact match in the bottom layer.
	if leaf := xft.layers[xft.bits-1][key]; leaf != nil {
		return leaf
	}

	// otherwise start from the deepest node sharing a prefix with key.
	depth, anchor := binarySearchHashMaps(xft.layers, key)
	if anchor == nil {
		if depth > 1 {
			return nil
		}
		anchor = xft.root
	}

	left, right := anchor.children[0], anchor.children[1]
	if isInternal(left) && isLeaf(right) {
		// the right slot is a thread to the successor, whose left
		// neighbour is the node we want.
		return right.children[0]
	}
	return left
}

// successor will find the node equal to or immediately more
// than the provided key. It returns nil if no such node exists.
func (xft *XFastTrie) successor(key uint64) *node {
	if xft.root == nil || xft.min == nil { // no successor if no nodes
		return nil
	}

	// the cached bounds answer out-of-range keys immediately.
	switch {
	case key <= xft.min.entry.Key():
		return xft.min
	case key > xft.max.entry.Key():
		return nil
	}

	// exact match in the bottom layer.
	if leaf := xft.layers[xft.bits-1][key]; leaf != nil {
		return leaf
	}

	// otherwise start from the deepest node sharing a prefix with key.
	depth, anchor := binarySearchHashMaps(xft.layers, key)
	if anchor == nil {
		if depth > 1 {
			return nil
		}
		anchor = xft.root
	}

	left, right := anchor.children[0], anchor.children[1]
	if isInternal(right) && isLeaf(left) {
		// the left slot is a thread to the predecessor, whose right
		// neighbour is the node we want.
		return left.children[1]
	}
	return right
}

// Successor will return an Entry which matches the provided
// key or its immediate successor. Will return nil if a successor
// does not exist. This is an O(log log M) operation.
func (xft *XFastTrie) Successor(key uint64) Entry {
	if n := xft.successor(key); n != nil {
		return n.entry
	}

	return nil
}

// Predecessor will return an Entry which matches the provided
// key or its immediate predecessor. Will return nil if a predecessor
// does not exist. This is an O(log log M) operation.
func (xft *XFastTrie) Predecessor(key uint64) Entry { n := xft.predecessor(key) if n == nil { return nil } return n.entry } // Iter will return an iterator that will iterate over all values // equal to or immediately greater than the provided key. Iterator // will iterate successor relationships. func (xft *XFastTrie) Iter(key uint64) *Iterator { return &Iterator{ n: xft.successor(key), first: true, } } // Get will return a value in the trie associated with the provided // key if it exists. Returns nil if the key does not exist. This // is expected to take O(1) time. func (xft *XFastTrie) Get(key uint64) Entry { // only have to check the last hashmap for the provided // key. n := xft.layers[xft.bits-1][key] if n == nil { return nil } return n.entry } // New will construct a new X-Fast Trie with the given "size," // that is the size of the universe of the trie. This expects // a uint of some sort, ie, uint8, uint16, etc. The size of the // universe will be 2^n-1 and will affect the speed of all operations. // IFC MUST be a uint type. func New(ifc interface{}) *XFastTrie { xft := &XFastTrie{} xft.init(ifc) return xft } ================================================ FILE: trie/xfast/xfast_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

package xfast

import (
	"fmt"
	"math"
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/Workiva/go-datastructures/slice"
)

// checkTrie runs every structural check below against the trie.
func checkTrie(t *testing.T, xft *XFastTrie) {
	checkSuccessor(t, xft)
	checkPredecessor(t, xft)
	checkNodes(t, xft)
}

// checkSuccessor walks the leaf list from min to max. For every leaf
// it asserts that the successor links back to it, then climbs the
// leaf's branch asserting that right-hand threads point at the
// successor.
func checkSuccessor(t *testing.T, xft *XFastTrie) {
	n := xft.min
	var side int
	var successor *node
	for n != nil {
		successor = n.children[1]
		hasSuccesor := successor != nil
		immediateSuccessor := false
		if hasSuccesor {
			assert.Equal(t, n, successor.children[0])
			if n.parent == successor.parent {
				immediateSuccessor = true
			}
		}
		for n.parent != nil {
			side = whichSide(n, n.parent)
			// both children internal: nothing threaded at or
			// above this point for this leaf.
			if isInternal(n.parent.children[1]) && isInternal(n.parent.children[0]) {
				break
			}
			if immediateSuccessor && n.parent == successor.parent {
				assert.Equal(t, successor, n.parent.children[1])
				break
			}
			if side == 0 && !isInternal(n.parent.children[1]) && hasSuccesor {
				assert.Equal(t, successor, n.parent.children[1])
			}
			n = n.parent
		}
		n = successor
	}
}

// checkPredecessor is the mirror image of checkSuccessor: it walks
// the leaf list from max to min and checks left-hand threads.
func checkPredecessor(t *testing.T, xft *XFastTrie) {
	n := xft.max
	var side int
	var predecessor *node
	for n != nil {
		predecessor = n.children[0]
		hasPredecessor := predecessor != nil
		immediatePredecessor := false
		if hasPredecessor {
			assert.Equal(t, n, predecessor.children[1])
			if n.parent == predecessor.parent {
				immediatePredecessor = true
			}
		}
		for n.parent != nil {
			side = whichSide(n, n.parent)
			if isInternal(n.parent.children[0]) && isInternal(n.parent.children[1]) {
				break
			}
			if immediatePredecessor && n.parent == predecessor.parent {
				assert.Equal(t, predecessor, n.parent.children[0])
				break
			}
			if side == 1 && !isInternal(n.parent.children[0]) && hasPredecessor {
				assert.Equal(t, predecessor, n.parent.children[0])
			}
			n = n.parent
		}
		n = predecessor
	}
}

// checkNodes walks the leaf list, checks every leaf with checkNode
// and asserts that the number of leaves matches Len().
func checkNodes(t *testing.T, xft *XFastTrie) {
	count := uint64(0)
	n := xft.min
	for n != nil {
		count++
		checkNode(t, xft, n)
		n = n.children[1]
	}

	assert.Equal(t, count, xft.Len())
}

// checkNode asserts that the leaf has an entry and that the path
// from the root down to it matches the bits of its key.
func checkNode(t *testing.T, xft *XFastTrie, n *node) {
	if n.entry == nil {
		assert.Fail(t, `Expected non-nil entry`)
		return
	}
	key := n.entry.Key()
	bits := make([]int, 0, xft.bits)
	for i := uint8(0); i < xft.bits; i++ {
		leftOrRight := (key & positions[xft.diff+i]) >> (xft.bits - 1 - i)
		bits = append(bits, int(leftOrRight))
	}

	checkPattern(t, n, bits)
}

// dumpNode logs the node and each of its ancestors.
func dumpNode(t *testing.T, n *node) {
	for n != nil {
		t.Logf(`NODE: %+v, %p`, n, n)
		n = n.parent
	}
}

// checkPattern climbs from the node to the root asserting that the
// side taken at each level matches pattern (read from the end), and
// that the number of levels equals len(pattern).
func checkPattern(t *testing.T, n *node, pattern []int) {
	i := len(pattern) - 1
	bottomNode := n
	for n.parent != nil {
		if !assert.False(t, i < 0, fmt.Sprintf(`Too many parents. NODE: %+v, PATTERN: %+v`, bottomNode, pattern)) {
			dumpNode(t, bottomNode)
			break // so we don't panic on the next line
		}
		assert.Equal(t, pattern[i], whichSide(n, n.parent))
		i--
		n = n.parent
	}

	assert.Equal(t, -1, i)
}

func TestEmptyMinMax(t *testing.T) {
	xft := New(uint8(0))

	assert.Nil(t, xft.Min())
	assert.Nil(t, xft.Max())
}

func TestMask(t *testing.T) {
	assert.Equal(t, uint64(math.MaxUint64), masks[63])
}

func TestInsert(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(5)
	xft.Insert(e1)

	assert.True(t, xft.Exists(5))
	assert.Equal(t, e1, xft.Min())
	assert.Equal(t, e1, xft.Max())
	checkTrie(t, xft)

	e2 := newMockEntry(20)
	xft.Insert(e2)

	assert.True(t, xft.Exists(20))
	assert.Equal(t, uint64(2), xft.Len())
	assert.Equal(t, e1, xft.Min())
	assert.Equal(t, e2, xft.Max())
	checkTrie(t, xft)
}

func TestGet(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(5)
	xft.Insert(e1)

	assert.Equal(t, e1, xft.Get(5))
	assert.Nil(t, xft.Get(6))
}

func TestInsertOverwrite(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(5)
	xft.Insert(e1)

	e2 := newMockEntry(5)
	xft.Insert(e2)
	checkTrie(t, xft)

	iter := xft.Iter(5)
	assert.Equal(t, Entries{e2}, iter.exhaust())
}

func TestInsertBetween(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(10)
	xft.Insert(e1)

	assert.True(t, xft.Exists(10))
	assert.Equal(t, e1, xft.Min())
	assert.Equal(t, e1, xft.Max())
	checkTrie(t, xft)

	e2 := newMockEntry(20)
	xft.Insert(e2)
	checkTrie(t, xft)

	assert.True(t, xft.Exists(20))
	assert.Equal(t, uint64(2), xft.Len())
	assert.Equal(t, e1, xft.Min())
	assert.Equal(t, e2, xft.Max())
	assert.Equal(t, e2, xft.Successor(15))

	e3 := newMockEntry(15)
	xft.Insert(e3)

	assert.True(t, xft.Exists(15))
	assert.Equal(t, uint64(3), xft.Len())
	assert.Equal(t, e1, xft.Min())
	assert.Equal(t, e2, xft.Max())
	checkTrie(t, xft)

	iter := xft.Iter(0)
	entries := iter.exhaust()
	assert.Equal(t, Entries{e1, e3, e2}, entries)

	iter = xft.Iter(11)
	entries = iter.exhaust()
	assert.Equal(t, Entries{e3, e2}, entries)

	iter = xft.Iter(16)
	entries = iter.exhaust()
	assert.Equal(t, Entries{e2}, entries)
}

func TestSuccessorDoesNotExist(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(5)
	xft.Insert(e1)

	result := xft.Successor(6)
	assert.Nil(t, result)
}

func TestSuccessorIsExactValue(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(5)
	xft.Insert(e1)

	result := xft.Successor(5)
	assert.Equal(t, e1, result)
}

func TestSuccessorGreaterThanKey(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(math.MaxUint8)
	xft.Insert(e1)

	result := xft.Successor(5)
	assert.Equal(t, e1, result)
}

func TestSuccessorCloseToKey(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(10)
	xft.Insert(e1)

	result := xft.Successor(5)
	assert.Equal(t, e1, result)
}

func TestSuccessorBetweenTwoKeys(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(10)
	xft.Insert(e1)

	e2 := newMockEntry(20)
	xft.Insert(e2)

	for i := uint64(11); i < 20; i++ {
		result := xft.Successor(i)
		assert.Equal(t, e2, result)
	}

	for i := uint64(21); i < 100; i++ {
		result := xft.Successor(i)
		assert.Nil(t, result)
	}
}

func TestPredecessorDoesNotExist(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(5)
	xft.Insert(e1)

	result := xft.Predecessor(4)
	assert.Nil(t, result)
}

func TestPredecessorIsExactValue(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(5)
	xft.Insert(e1)

	result := xft.Predecessor(5)
	assert.Equal(t, e1, result)
}

func TestPredecessorLessThanKey(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(0)
	xft.Insert(e1)

	result := xft.Predecessor(math.MaxUint64)
	assert.Equal(t, e1, result)
}

func TestPredecessorCloseToKey(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(5)
	xft.Insert(e1)

	result := xft.Predecessor(10)
	assert.Equal(t, e1, result)
}

func TestPredecessorBetweenTwoKeys(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(10)
	xft.Insert(e1)

	e2 := newMockEntry(20)
	xft.Insert(e2)

	for i := uint64(11); i < 20; i++ {
		result := xft.Predecessor(i)
		assert.Equal(t, e1, result)
	}

	for i := uint64(0); i < 10; i++ {
		result := xft.Predecessor(i)
		assert.Nil(t, result)
	}
}

func TestInsertPredecessor(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(10)
	xft.Insert(e1)

	e2 := newMockEntry(5)
	xft.Insert(e2)
	checkTrie(t, xft)

	assert.Equal(t, e2, xft.Min())
	assert.Equal(t, e1, xft.Max())

	iter := xft.Iter(2)
	assert.Equal(t, Entries{e2, e1}, iter.exhaust())

	iter = xft.Iter(5)
	assert.Equal(t, Entries{e2, e1}, iter.exhaust())

	iter = xft.Iter(6)
	assert.Equal(t, Entries{e1}, iter.exhaust())

	iter = xft.Iter(11)
	assert.Equal(t, Entries{}, iter.exhaust())
}

// TestDeleteOnlyBranch deletes the single entry and expects every
// layer map and both root children to be emptied.
func TestDeleteOnlyBranch(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(10)
	xft.Insert(e1)

	xft.Delete(10)
	checkTrie(t, xft)
	assert.Equal(t, uint64(0), xft.Len())
	assert.Nil(t, xft.Min())
	assert.Nil(t, xft.Max())

	for _, hm := range xft.layers {
		assert.Len(t, hm, 0)
	}

	assert.NotNil(t, xft.root)
	assert.Nil(t, xft.root.children[0])
	assert.Nil(t, xft.root.children[1])

	iter := xft.Iter(0)
	assert.False(t, iter.Next())
}

func TestDeleteLargeBranch(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(0)
	e2 := newMockEntry(math.MaxUint8)

	xft.Insert(e1, e2)
	checkTrie(t, xft)

	xft.Delete(0)
	assert.Equal(t, e2, xft.Min())
	assert.Equal(t, e2, xft.Max())
	checkTrie(t, xft)
	assert.Nil(t, xft.root.children[0])

	n := xft.max
	for n != nil {
		assert.Nil(t, n.children[0])
		n = n.parent
	}

	iter := xft.Iter(0)
	assert.Equal(t, Entries{e2}, iter.exhaust())
}

func TestDeleteLateBranching(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(0)
	e2 := newMockEntry(1)

	xft.Insert(e1, e2)
	checkTrie(t, xft)

	xft.Delete(1)
	assert.Equal(t, e1, xft.Min())
	assert.Equal(t, e1, xft.Max())
	checkTrie(t, xft)

	n := xft.min
	for n != nil {
		assert.Nil(t, n.children[1])
		n = n.parent
	}

	iter := xft.Iter(0)
	assert.Equal(t, Entries{e1}, iter.exhaust())
}

func TestDeleteLateBranchingMin(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(0)
	e2 := newMockEntry(1)

	xft.Insert(e1, e2)
	checkTrie(t, xft)

	xft.Delete(0)
	assert.Equal(t, e2, xft.Min())
	assert.Equal(t, e2, xft.Max())
	checkTrie(t, xft)

	assert.Nil(t, xft.min.children[0])
	n := xft.min.parent
	assert.Nil(t, n.children[0])
	n = n.parent

	for n != nil {
		assert.Nil(t, n.children[1])
		n = n.parent
	}

	iter := xft.Iter(0)
	assert.Equal(t, Entries{e2}, iter.exhaust())
}

func TestDeleteMiddleBranch(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(0)
	e2 := newMockEntry(math.MaxUint8)
	e3 := newMockEntry(64) // [0, 1, 0, 0, 0, 0, 0, 0]

	xft.Insert(e1, e2, e3)
	checkTrie(t, xft)

	xft.Delete(64)
	assert.Equal(t, e1, xft.Min())
	assert.Equal(t, e2, xft.Max())
	checkTrie(t, xft)

	iter := xft.Iter(0)
	assert.Equal(t, Entries{e1, e2}, iter.exhaust())
}

func TestDeleteMiddleBranchOtherSide(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(0)
	e2 := newMockEntry(math.MaxUint8)
	e3 := newMockEntry(128) // [1, 0, 0, 0, 0, 0, 0, 0]

	xft.Insert(e1, e2, e3)
	checkTrie(t, xft)

	xft.Delete(128)
	assert.Equal(t, e1, xft.Min())
	assert.Equal(t, e2, xft.Max())
	checkTrie(t, xft)

	iter := xft.Iter(0)
	assert.Equal(t, Entries{e1, e2}, iter.exhaust())
}

func TestDeleteNotFound(t *testing.T) {
	xft := New(uint8(0))
	e1 := newMockEntry(64)
	xft.Insert(e1)
	checkTrie(t, xft)

	xft.Delete(128)
	assert.Equal(t, e1, xft.Max())
	assert.Equal(t, e1, xft.Min())
	checkTrie(t, xft)
}

func BenchmarkSuccessor(b *testing.B) {
	numItems := 10000
	xft := New(uint64(0))

	for i := uint64(0); i < uint64(numItems); i++ {
		xft.Insert(newMockEntry(i))
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		xft.Successor(uint64(i))
	}
}

func BenchmarkDelete(b *testing.B) {
	// one single-entry trie per iteration, built outside the timer.
	xs := make([]*XFastTrie, 0, b.N)

	for i := 0; i < b.N; i++ {
		x := New(uint8(0))
		x.Insert(newMockEntry(0))
		xs = append(xs, x)
	}

	// this is actually a pretty bad case for the x-fast
	// trie as the entire branch will have to be walked.
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		xs[i].Delete(0)
	}
}

func BenchmarkInsert(b *testing.B) {
	for i := 0; i < b.N; i++ {
		xft := New(uint64(0))
		e := newMockEntry(uint64(i))
		xft.Insert(e)
	}
}

// benchmarked against a flat list
func BenchmarkListInsert(b *testing.B) {
	numItems := 100000

	s := make(slice.Int64Slice, 0, numItems)
	for j := int64(0); j < int64(numItems); j++ {
		s = append(s, j)
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		s.Insert(int64(i))
	}
}

func BenchmarkListSearch(b *testing.B) {
	numItems := 1000000

	s := make(slice.Int64Slice, 0, numItems)
	for j := int64(0); j < int64(numItems); j++ {
		s = append(s, j)
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		s.Search(int64(i))
	}
}

================================================
FILE: trie/yfast/entries.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package yfast

import "sort"

// entriesWrapper pairs an ordered list of entries with a single key.
type entriesWrapper struct {
	key     uint64
	entries Entries
}

// Key will return the key of the highest entry in this list.
// This is required by the x-fast trie Entry interface. This
// returns 0 if this list is empty.
func (ew *entriesWrapper) Key() uint64 { return ew.key } // Entries is a typed list of Entry. The size of entries // will be limited to 1/2log M to 2log M where M is the size // of the universe. type Entries []Entry // search will perform a sort package search on this list // of entries and return an index indicating position. // If the returned index is >= len(entries) then a suitable // position could not be found. The index does not guarantee // equality, just indicates where the key would be inserted. func (entries Entries) search(key uint64) int { return sort.Search(len(entries), func(i int) bool { return entries[i].Key() >= key }) } // insert will insert the provided entry into this list of // entries. Returned is an entry if an entry already exists // for the provided key. If nothing is overwritten, Entry // will be nil. func (entries *Entries) insert(entry Entry) Entry { i := entries.search(entry.Key()) if i == len(*entries) { *entries = append(*entries, entry) return nil } if (*entries)[i].Key() == entry.Key() { oldEntry := (*entries)[i] (*entries)[i] = entry return oldEntry } (*entries) = append(*entries, nil) copy((*entries)[i+1:], (*entries)[i:]) (*entries)[i] = entry return nil } // delete will remove the provided key from this list of entries. // Returned is a deleted Entry. This will be nil if the key // cannot be found. func (entries *Entries) delete(key uint64) Entry { i := entries.search(key) if i == len(*entries) { // key not found return nil } if (*entries)[i].Key() != key { return nil } oldEntry := (*entries)[i] copy((*entries)[i:], (*entries)[i+1:]) (*entries)[len(*entries)-1] = nil // GC *entries = (*entries)[:len(*entries)-1] return oldEntry } // max returns the value of the highest key in this list // of entries. The bool indicates if it's a valid key, that // is if there is more than zero entries in this list. 
func (entries Entries) max() (uint64, bool) { if len(entries) == 0 { return 0, false } return entries[len(entries)-1].Key(), true } // get will perform a lookup over this list of entries // and return an Entry if it exists. Returns nil if the // entry does not exist. func (entries Entries) get(key uint64) Entry { i := entries.search(key) if i == len(entries) { return nil } if entries[i].Key() == key { return entries[i] } return nil } // successor will return the first entry that has a key // greater than or equal to provided key. Also returned // is the index of the find. Returns nil, -1 if a successor does // not exist. func (entries Entries) successor(key uint64) (Entry, int) { i := entries.search(key) if i == len(entries) { return nil, -1 } return entries[i], i } // predecessor will return the first entry that has a key // less than or equal to the provided key. Also returned // is the index of the find. Returns nil, -1 if a predecessor // does not exist. func (entries Entries) predecessor(key uint64) (Entry, int) { if len(entries) == 0 { return nil, -1 } i := entries.search(key) if i == len(entries) { return entries[i-1], i - 1 } if entries[i].Key() == key { return entries[i], i } i-- if i < 0 { return nil, -1 } return entries[i], i } ================================================ FILE: trie/yfast/entries_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/

package yfast

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestEntriesInsert checks that inserts keep the list ordered by key.
func TestEntriesInsert(t *testing.T) {
	es := Entries{}

	e1 := newMockEntry(5)
	e2 := newMockEntry(1)

	es.insert(e1)
	es.insert(e2)

	assert.Equal(t, Entries{e2, e1}, es)

	e3 := newMockEntry(3)
	es.insert(e3)

	assert.Equal(t, Entries{e2, e3, e1}, es)
}

// TestEntriesDelete checks that deletes shrink the list down to empty.
func TestEntriesDelete(t *testing.T) {
	es := Entries{}

	e1 := newMockEntry(5)
	e2 := newMockEntry(1)
	es.insert(e1)
	es.insert(e2)

	es.delete(5)
	assert.Equal(t, Entries{e2}, es)

	es.delete(1)
	assert.Equal(t, Entries{}, es)
}

// TestEntriesMax checks max on an empty list and after each insert.
func TestEntriesMax(t *testing.T) {
	es := Entries{}
	max, ok := es.max()
	assert.Equal(t, uint64(0), max)
	assert.False(t, ok)

	e2 := newMockEntry(1)
	es.insert(e2)
	max, ok = es.max()
	assert.Equal(t, uint64(1), max)
	assert.True(t, ok)

	e1 := newMockEntry(5)
	es.insert(e1)
	max, ok = es.max()
	assert.Equal(t, uint64(5), max)
	assert.True(t, ok)
}

// TestEntriesGet checks lookups of present and absent keys.
func TestEntriesGet(t *testing.T) {
	es := Entries{}

	e1 := newMockEntry(5)
	e2 := newMockEntry(1)
	es.insert(e1)
	es.insert(e2)

	result := es.get(5)
	assert.Equal(t, e1, result)

	result = es.get(1)
	assert.Equal(t, e2, result)

	result = es.get(10)
	assert.Nil(t, result)
}

// TestEntriesSuccessor checks successor on an empty list, below,
// between, at and above the stored keys.
func TestEntriesSuccessor(t *testing.T) {
	es := Entries{}

	successor, i := es.successor(5)
	assert.Equal(t, -1, i)
	assert.Nil(t, successor)

	e1 := newMockEntry(5)
	e2 := newMockEntry(1)
	es.insert(e1)
	es.insert(e2)

	successor, i = es.successor(0)
	assert.Equal(t, 0, i)
	assert.Equal(t, e2, successor)

	successor, i = es.successor(2)
	assert.Equal(t, 1, i)
	assert.Equal(t, e1, successor)

	successor, i = es.successor(5)
	assert.Equal(t, 1, i)
	assert.Equal(t, e1, successor)

	successor, i = es.successor(10)
	assert.Equal(t, -1, i)
	assert.Nil(t, successor)
}

// TestEntriesPredecessor checks predecessor on an empty list, below,
// between, at and above the stored keys.
func TestEntriesPredecessor(t *testing.T) {
	es := Entries{}

	predecessor, i := es.predecessor(5)
	assert.Equal(t, -1, i)
	assert.Nil(t, predecessor)

	e1 := newMockEntry(5)
	e2 := newMockEntry(1)
	es.insert(e1)
	es.insert(e2)

	predecessor, i = es.predecessor(0)
	assert.Equal(t, -1, i)
	assert.Nil(t, predecessor)

	predecessor, i = es.predecessor(2)
	assert.Equal(t, 0, i)
	assert.Equal(t, e2, predecessor)

	predecessor, i = es.predecessor(5)
	assert.Equal(t, 1, i)
	assert.Equal(t, e1, predecessor)

	predecessor, i = es.predecessor(10)
	assert.Equal(t, 1, i)
	assert.Equal(t, e1, predecessor)
}

================================================
FILE: trie/yfast/interface.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package yfast

// Entry defines items that can be inserted into the y-fast
// trie.
type Entry interface {
	// Key is the key for this entry. If the trie has been
	// given bit size n, only the last n bits of this key
	// will matter. Use a bit size of 64 to enable all
	// 2^64-1 keys.
	Key() uint64
}

================================================
FILE: trie/yfast/iterator.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ package yfast import "github.com/Workiva/go-datastructures/trie/xfast" // iteratorExhausted is a magic value for an index to tell us // that the iterator has been exhausted. const iteratorExhausted = -2 // iterExhausted is a helper function to tell us if an iterator // has been exhausted. func iterExhausted(iter *Iterator) bool { return iter.index == iteratorExhausted } // Iterator will iterate of the results of a query. type Iterator struct { xfastIterator *xfast.Iterator index int entries *entriesWrapper } // Next will return a bool indicating if another value exists // in the iterator. func (iter *Iterator) Next() bool { if iterExhausted(iter) { return false } iter.index++ if iter.index >= len(iter.entries.entries) { next := iter.xfastIterator.Next() if !next { iter.index = iteratorExhausted return false } var ok bool iter.entries, ok = iter.xfastIterator.Value().(*entriesWrapper) if !ok { iter.index = iteratorExhausted return false } iter.index = 0 } return true } // Value will return the Entry representing the iterator's current position. // If no Entry exists at the present condition, the iterator is // exhausted and this method will return nil. func (iter *Iterator) Value() Entry { if iterExhausted(iter) { return nil } if iter.entries == nil || iter.index < 0 || iter.index >= len(iter.entries.entries) { return nil } return iter.entries.entries[iter.index] } // exhaust is a helper function that will exhaust this iterator // and return a list of entries. This is for internal use only. func (iter *Iterator) exhaust() Entries { entries := make(Entries, 0, 100) for it := iter; it.Next(); { entries = append(entries, it.Value()) } return entries } // nilIterator is an iterator that will always return false // from Next() and nil for Value(). 
func nilIterator() *Iterator { return &Iterator{ index: iteratorExhausted, } } ================================================ FILE: trie/yfast/mock_test.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package yfast type mockEntry struct { // not going to use mock here as it skews benchmarks key uint64 } func (me *mockEntry) Key() uint64 { return me.key } func newMockEntry(key uint64) *mockEntry { return &mockEntry{key} } ================================================ FILE: trie/yfast/yfast.go ================================================ /* Copyright 2014 Workiva, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ /* Package yfast implements a y-fast trie. Instead of a red-black BBST for the leaves, this implementation uses a simple ordered list. This package should have searches that are as performant as the x-fast trie while having faster inserts/deletes and linear space consumption. 
Performance characteristics: Space: O(n) Get: O(log log M) Search: O(log log M) Insert: O(log log M) Delete: O(log log M) where n is the number of items in the trie and M is the size of the universe, ie, 2^m where m is the number of bits in the specified key size. This particular implementation also uses fixed bucket sizes as this should aid in multithreading these functions for performance optimization. */ package yfast import "github.com/Workiva/go-datastructures/trie/xfast" // YFastTrie implements all the methods available to the y-fast // trie datastructure. The top half is composed of an x-fast trie // while the leaves are composed of ordered lists of type Entry, // an interface found in this package. type YFastTrie struct { num uint64 xfast *xfast.XFastTrie bits uint8 } func (yfast *YFastTrie) init(intType interface{}) { switch intType.(type) { case uint8: yfast.bits = 8 case uint16: yfast.bits = 16 case uint32: yfast.bits = 32 case uint, uint64: yfast.bits = 64 default: // we'll panic with a bad value to the constructor. panic(`Invalid universe size provided.`) } yfast.xfast = xfast.New(intType) } // getBucketKey finds the largest possible value in this key's bucket. // This is the representative value for the entry in the x-fast trie. 
func (yfast *YFastTrie) getBucketKey(key uint64) uint64 { i := key/uint64(yfast.bits) + 1 return uint64(yfast.bits)*i - 1 } func (yfast *YFastTrie) insert(entry Entry) Entry { // first, we need to determine if we have a node in the x-trie // that already matches for the key bundleKey := yfast.getBucketKey(entry.Key()) bundle := yfast.xfast.Get(bundleKey) if bundle != nil { overwritten := bundle.(*entriesWrapper).entries.insert(entry) if overwritten == nil { yfast.num++ return nil } return overwritten } yfast.num++ entries := make(Entries, 0, yfast.bits) entries.insert(entry) ew := &entriesWrapper{ key: bundleKey, entries: entries, } yfast.xfast.Insert(ew) return nil } // Insert will insert the provided entries into the y-fast trie // and return a list of entries that were overwritten. func (yfast *YFastTrie) Insert(entries ...Entry) Entries { overwritten := make(Entries, 0, len(entries)) for _, e := range entries { overwritten = append(overwritten, yfast.insert(e)) } return overwritten } func (yfast *YFastTrie) delete(key uint64) Entry { bundleKey := yfast.getBucketKey(key) bundle := yfast.xfast.Get(bundleKey) if bundle == nil { return nil } ew := bundle.(*entriesWrapper) entry := ew.entries.delete(key) if entry == nil { return nil } yfast.num-- if len(ew.entries) == 0 { yfast.xfast.Delete(bundleKey) } return entry } // Delete will delete the provided keys from the y-fast trie // and return a list of entries that were deleted. 
func (yfast *YFastTrie) Delete(keys ...uint64) Entries { entries := make(Entries, 0, len(keys)) for _, key := range keys { entries = append(entries, yfast.delete(key)) } return entries } func (yfast *YFastTrie) get(key uint64) Entry { bundleKey := yfast.getBucketKey(key) bundle := yfast.xfast.Get(bundleKey) if bundle == nil { return nil } entry := bundle.(*entriesWrapper).entries.get(key) if entry == nil { // go interfaces :( return nil } return entry } // Get will look for the provided key in the y-fast trie and return // the associated value if it is found. If it is not found, this // method returns nil. func (yfast *YFastTrie) Get(key uint64) Entry { entry := yfast.get(key) if entry == nil { return nil } return entry } // Len returns the number of items in the y-fast trie. func (yfast *YFastTrie) Len() uint64 { return yfast.num } func (yfast *YFastTrie) successor(key uint64) Entry { bundle := yfast.xfast.Successor(key) if bundle == nil { return nil } entry, _ := bundle.(*entriesWrapper).entries.successor(key) if entry == nil { return nil } return entry } // Successor returns an Entry with a key equal to or immediately // greater than the provided key. If such an Entry does not exist // this returns nil. 
func (yfast *YFastTrie) Successor(key uint64) Entry { entry := yfast.successor(key) if entry == nil { return nil } return entry } func (yfast *YFastTrie) predecessor(key uint64) Entry { // harder case because our representative value in the // x-fast trie is the a max bundleKey := yfast.getBucketKey(key) bundle := yfast.xfast.Predecessor(bundleKey) if bundle == nil { return nil } ew := bundle.(*entriesWrapper) entry, _ := ew.entries.predecessor(key) if entry != nil { return entry } // it's possible we do exist somewhere earlier in the x-fast trie bundle = yfast.xfast.Predecessor(bundleKey - 1) if bundle == nil { return nil } ew = bundle.(*entriesWrapper) entry, _ = ew.entries.predecessor(key) if entry == nil { return nil } return entry } // Predecessor returns an Entry with a key equal to or immediately // preceding than the provided key. If such an Entry does not exist // this returns nil. func (yfast *YFastTrie) Predecessor(key uint64) Entry { entry := yfast.predecessor(key) if entry == nil { return nil } return entry } func (yfast *YFastTrie) iter(key uint64) *Iterator { xfastIter := yfast.xfast.Iter(key) xfastIter.Next() bundle := xfastIter.Value() if bundle == nil { return nilIterator() } i := bundle.(*entriesWrapper).entries.search(key) return &Iterator{ index: i - 1, xfastIterator: xfastIter, entries: bundle.(*entriesWrapper), } } // Iter will return an iterator that will iterate across all values // that start or immediately proceed the provided key. Iteration // happens in ascending order. func (yfast *YFastTrie) Iter(key uint64) *Iterator { return yfast.iter(key) } // New constructs, initializes, and returns a new y-fast trie. // Provided should be a uint type that specifies the number // of bits in the desired universe. This will affect the time // complexity of all lookup and mutate operations. 
func New(ifc interface{}) *YFastTrie {
	yfast := &YFastTrie{}
	// init panics when ifc is not one of the supported unsigned types.
	yfast.init(ifc)
	return yfast
}


================================================
FILE: trie/yfast/yfast_test.go
================================================
/*
Copyright 2014 Workiva, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package yfast

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// generateEntries returns num mock entries with the consecutive
// keys 0 through num-1, in ascending order.
func generateEntries(num int) Entries {
	entries := make(Entries, 0, num)
	for i := uint64(0); i < uint64(num); i++ {
		entries = append(entries, newMockEntry(i))
	}

	return entries
}

// TestTrieSimpleInsert inserts keys that fall into two different
// buckets of an 8-bit trie (3 and 7 share one, 8 starts the next)
// and checks lookups, a miss, and the resulting length.
func TestTrieSimpleInsert(t *testing.T) {
	yfast := New(uint8(0))

	e1 := newMockEntry(3)
	e2 := newMockEntry(7)
	e3 := newMockEntry(8)

	yfast.Insert(e1, e2, e3)

	result := yfast.get(3)
	assert.Equal(t, e1, result)

	result = yfast.get(7)
	assert.Equal(t, e2, result)

	result = yfast.get(8)
	assert.Equal(t, e3, result)

	// a key that was never inserted
	result = yfast.get(250)
	assert.Nil(t, result)

	assert.Equal(t, uint64(3), yfast.Len())
}

// TestTrieOverwriteInsert checks that inserting a second entry under
// an existing key replaces the first one without growing the trie.
func TestTrieOverwriteInsert(t *testing.T) {
	yfast := New(uint8(0))

	e1 := newMockEntry(3)
	e2 := newMockEntry(3)

	yfast.Insert(e1)
	yfast.Insert(e2)

	assert.Equal(t, e2, yfast.Get(3))
	assert.Equal(t, uint64(1), yfast.Len())
}

// TestTrieDelete removes the entries one at a time, checking the
// returned entry, the subsequent miss and the length after each step,
// and finally deletes a key that is not present.
func TestTrieDelete(t *testing.T) {
	yfast := New(uint8(0))

	e1 := newMockEntry(3)
	e2 := newMockEntry(7)
	e3 := newMockEntry(8)

	yfast.Insert(e1, e2, e3)

	result := yfast.Delete(3)
	assert.Equal(t, Entries{e1}, result)
	assert.Nil(t, yfast.Get(3))
	assert.Equal(t, uint64(2), yfast.Len())

	result = yfast.Delete(7)
	assert.Equal(t, Entries{e2}, result)
	assert.Nil(t, yfast.Get(7))
	assert.Equal(t, uint64(1), yfast.Len())

	result = yfast.Delete(8)
	assert.Equal(t, Entries{e3}, result)
	assert.Nil(t, yfast.Get(8))
	assert.Equal(t, uint64(0), yfast.Len())

	// deleting a missing key yields a nil slot and leaves the count alone
	result = yfast.Delete(5)
	assert.Equal(t, Entries{nil}, result)
	assert.Equal(t, uint64(0), yfast.Len())
}

// TestTrieSuccessor checks successor queries for exact hits, keys
// between entries, and keys beyond the largest entry.
//
// NOTE(review): no query here targets a bucket that exists but holds
// only keys smaller than the query (for example a trie holding 3 and
// 13 queried with 5) — consider adding such a case.
func TestTrieSuccessor(t *testing.T) {
	yfast := New(uint8(0))

	e3 := newMockEntry(13)
	yfast.Insert(e3)

	successor := yfast.Successor(0)
	assert.Equal(t, e3, successor)

	e1 := newMockEntry(3)
	e2 := newMockEntry(7)

	yfast.Insert(e1, e2)

	successor = yfast.Successor(0)
	assert.Equal(t, e1, successor)

	successor = yfast.Successor(3)
	assert.Equal(t, e1, successor)

	successor = yfast.Successor(4)
	assert.Equal(t, e2, successor)

	successor = yfast.Successor(8)
	assert.Equal(t, e3, successor)

	// nothing at or above these keys
	successor = yfast.Successor(14)
	assert.Nil(t, successor)

	successor = yfast.Successor(100)
	assert.Nil(t, successor)
}

// TestTriePredecessor checks predecessor queries on an empty trie,
// for exact hits, for keys whose own bucket holds only larger
// entries (11 with 12 in the same bucket), and for keys below the
// smallest entry.
func TestTriePredecessor(t *testing.T) {
	yfast := New(uint8(0))

	predecessor := yfast.Predecessor(5)
	assert.Nil(t, predecessor)

	e1 := newMockEntry(5)
	yfast.Insert(e1)

	predecessor = yfast.Predecessor(13)
	assert.Equal(t, e1, predecessor)

	e2 := newMockEntry(12)
	yfast.Insert(e2)

	// 11 shares a bucket with 12, so the answer lives in the earlier bucket
	predecessor = yfast.Predecessor(11)
	assert.Equal(t, e1, predecessor)

	predecessor = yfast.Predecessor(5)
	assert.Equal(t, e1, predecessor)

	predecessor = yfast.Predecessor(4)
	assert.Nil(t, predecessor)

	predecessor = yfast.Predecessor(100)
	assert.Equal(t, e2, predecessor)
}

// TestTrieIterator drains iterators started at various keys and
// compares the collected entries, including starts on an empty trie,
// between entries, and past the largest entry.
func TestTrieIterator(t *testing.T) {
	yfast := New(uint8(0))

	iter := yfast.Iter(5)
	assert.Equal(t, Entries{}, iter.exhaust())

	e1 := newMockEntry(5)
	yfast.Insert(e1)

	iter = yfast.Iter(5)
	assert.Equal(t, Entries{e1}, iter.exhaust())

	e2 := newMockEntry(12)
	yfast.Insert(e2)

	iter = yfast.Iter(5)
	assert.Equal(t, Entries{e1, e2}, iter.exhaust())

	// starts inside the first bucket but past its only entry
	iter = yfast.Iter(6)
	assert.Equal(t, Entries{e2}, iter.exhaust())

	e3 := newMockEntry(6)
	yfast.Insert(e3)

	iter = yfast.Iter(7)
	assert.Equal(t, Entries{e2}, iter.exhaust())

	iter = yfast.Iter(0)
	assert.Equal(t, Entries{e1, e3, e2}, iter.exhaust())

	iter = yfast.Iter(13)
	assert.Equal(t, Entries{}, iter.exhaust())
}

// BenchmarkInsert measures inserting b.N entries with consecutive
// keys, one call per entry; entry generation is excluded from timing.
func BenchmarkInsert(b *testing.B) {
	yfast := New(uint64(0))
	entries := generateEntries(b.N)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		yfast.Insert(entries[i])
	}
}

// BenchmarkGet measures repeated lookups of one fixed key (the
// middle one) in a 32-bit trie holding 1000 entries.
func BenchmarkGet(b *testing.B) {
	numItems := 1000
	entries := generateEntries(numItems)

	yfast := New(uint32(0))
	yfast.Insert(entries...)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		yfast.Get(uint64(numItems / 2))
	}
}

// BenchmarkDelete measures deleting b.N consecutive keys from a trie
// that was pre-filled with exactly those keys.
func BenchmarkDelete(b *testing.B) {
	entries := generateEntries(b.N)
	yfast := New(uint64(0))
	yfast.Insert(entries...)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		yfast.Delete(uint64(i))
	}
}

// BenchmarkSuccessor measures successor queries for keys 0..b.N-1
// against 100000 consecutive entries whose keys start at b.N/2.
func BenchmarkSuccessor(b *testing.B) {
	numItems := 100000

	entries := make(Entries, 0, numItems)
	for i := uint64(0); i < uint64(numItems); i++ {
		entries = append(entries, newMockEntry(i+uint64(b.N/2)))
	}

	yfast := New(uint64(0))
	yfast.Insert(entries...)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		yfast.Successor(uint64(i))
	}
}

// BenchmarkPredecessor measures predecessor queries for keys
// 0..b.N-1 against 100000 consecutive entries whose keys start at
// b.N/2.
func BenchmarkPredecessor(b *testing.B) {
	numItems := 100000

	entries := make(Entries, 0, numItems)
	for i := uint64(0); i < uint64(numItems); i++ {
		entries = append(entries, newMockEntry(i+uint64(b.N/2)))
	}

	yfast := New(uint64(0))
	yfast.Insert(entries...)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		yfast.Predecessor(uint64(i))
	}
}

// BenchmarkIterator measures a full iteration, from key 0, over a
// trie holding 1000 consecutive entries.
func BenchmarkIterator(b *testing.B) {
	numItems := 1000
	entries := generateEntries(numItems)

	yfast := New(uint64(0))
	yfast.Insert(entries...)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		for iter := yfast.Iter(0); iter.Next(); {
			iter.Value()
		}
	}
}