Repository: greensky00/skiplist
Branch: master
Commit: 7f4420803885
Files: 16
Total size: 164.3 KB

Directory structure:
gitextract_yurccgfy/
├── LICENSE
├── Makefile
├── README.md
├── debug/
│   └── skiplist_debug.h
├── examples/
│   ├── cpp_map_example.cc
│   ├── cpp_set_example.cc
│   └── pure_c_example.c
├── include/
│   ├── skiplist.h
│   ├── sl_map.h
│   └── sl_set.h
├── src/
│   └── skiplist.cc
└── tests/
    ├── container_test.cc
    ├── mt_test.cc
    ├── skiplist_test.cc
    ├── stl_map_compare.cc
    └── test_common.h


================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2017 Jung-Sang Ahn

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: Makefile
================================================
# Builds the skiplist library (shared + static), its tests and examples.

LDFLAGS = -pthread
CFLAGS = \
    -g -D_GNU_SOURCE \
    -I. -I./src -I./debug -I./include -I./examples -I./tests \
    -fPIC \

CFLAGS += -Wall
CFLAGS += -O3
#CFLAGS += -fsanitize=address -fuse-ld=gold

CXXFLAGS = $(CFLAGS) \
    --std=c++11 \


SKIPLIST = src/skiplist.o

SHARED_LIB = libskiplist.so
STATIC_LIB = libskiplist.a

TEST = \
    tests/skiplist_test.o \
    $(STATIC_LIB) \

MT_TEST = \
    tests/mt_test.o \
    $(STATIC_LIB) \

STL_MAP_COMPARE = \
    tests/stl_map_compare.o \
    $(STATIC_LIB) \

CONTAINER_TEST = \
    tests/container_test.o \
    $(STATIC_LIB) \

PURE_C_EXAMPLE = \
    examples/pure_c_example.o \
    $(STATIC_LIB) \

CPP_MAP_EXAMPLE = \
    examples/cpp_map_example.o \
    $(STATIC_LIB) \

CPP_SET_EXAMPLE = \
    examples/cpp_set_example.o \
    $(STATIC_LIB) \

PROGRAMS = \
    tests/skiplist_test \
    tests/mt_test \
    tests/container_test \
    tests/stl_map_compare \
    examples/pure_c_example \
    examples/cpp_set_example \
    examples/cpp_map_example \
    libskiplist.so \
    libskiplist.a \

# `all` and `clean` never produce files with those names.
.PHONY: all clean

all: $(PROGRAMS)

# Fixed: was `$(LDBFALGS)` (misspelled, always empty), so -pthread was
# never passed when linking the shared library.
libskiplist.so: $(SKIPLIST)
	$(CXX) $(CXXFLAGS) -shared $(LDFLAGS) -o $(SHARED_LIB) $(SKIPLIST)

libskiplist.a: $(SKIPLIST)
	ar rcs $(STATIC_LIB) $(SKIPLIST)

tests/skiplist_test: $(TEST)
	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

tests/mt_test: $(MT_TEST)
	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

tests/container_test: $(CONTAINER_TEST)
	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

tests/stl_map_compare: $(STL_MAP_COMPARE)
	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

examples/pure_c_example: $(PURE_C_EXAMPLE)
	$(CC) $(CFLAGS) $^ -o $@ $(LDFLAGS)

examples/cpp_map_example: $(CPP_MAP_EXAMPLE)
	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

examples/cpp_set_example: $(CPP_SET_EXAMPLE)
	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

clean:
	rm -rf $(PROGRAMS) ./*.o ./*.so ./*/*.o ./*/*.so

================================================
FILE: README.md
================================================
Skiplist
--------
A generic [Skiplist](https://en.wikipedia.org/wiki/Skip_list) container C implementation, lock-free for both multiple readers and writers. It can be used as a set or a map, containing any type of data.
It uses STL atomic variables when it is compiled with a C++ compiler and `_STL_ATOMIC` is defined (which is the default on macOS); otherwise it uses built-in GCC atomic operations, which also allows it to be compiled with a pure C compiler.

This repository also includes STL-style lock-free `set` and `map` containers built on top of the Skiplist implementation.


Author
------
Jung-Sang Ahn


Build
-----
```sh
$ make
```


How to use
----------
Copy the [`skiplist.cc`](src/skiplist.cc) file and the [`include`](include) files into your source repository, or link against the library file (`libskiplist.so` or `libskiplist.a`).

* Pure C
    * [examples/pure_c_example.c](examples/pure_c_example.c)

* C++ (STL-style `set` and `map`)
    * [examples/cpp_set_example.cc](examples/cpp_set_example.cc)
    * [examples/cpp_map_example.cc](examples/cpp_map_example.cc)


Benchmark results
-----------------
* Skiplist vs. STL set + STL mutex
* Single writer and multiple readers
* Randomly insert and read 100K integers

![alt text](https://github.com/greensky00/skiplist/blob/master/docs/swmr_graph.png "Throughput")

================================================
FILE: debug/skiplist_debug.h
================================================
/**
 * Copyright (C) 2017-present Jung-Sang Ahn
 * All rights reserved.
 *
 * https://github.com/greensky00
 *
 * Skiplist
 * Version: 0.2.5
 *
 * Permission is hereby granted, free of charge, to any person
 * obtaining a copy of this software and associated documentation
 * files (the "Software"), to deal in the Software without
 * restriction, including without limitation the rights to use,
 * copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following
 * conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
* * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR * OTHER DEALINGS IN THE SOFTWARE. */ #pragma once #include #include #include "skiplist.h" struct dbg_node { skiplist_node snode; int value; }; #if __SL_DEBUG >= 1 #undef __SLD_ASSERT #undef __SLD_ #define __SLD_ASSERT(cond) assert(cond) #define __SLD_(b) b #endif #if __SL_DEBUG >= 2 #undef __SLD_P #define __SLD_P(args...) printf(args) #endif #if __SL_DEBUG >= 3 #undef __SLD_RT_INS #undef __SLD_NC_INS #undef __SLD_RT_RMV #undef __SLD_NC_RMV #undef __SLD_BM #define __SLD_RT_INS(e, n, t, c) __sld_rt_ins(e, n, t, c) #define __SLD_NC_INS(n, nn, t, c) __sld_nc_ins(n, nn, t, c) #define __SLD_RT_RMV(e, n, t, c) __sld_rt_rmv(e, n, t, c) #define __SLD_NC_RMV(n, nn, t, c) __sld_nc_rmv(n, nn, t, c) #define __SLD_BM(n) __sld_bm(n) #endif #if __SL_DEBUG >= 4 #error "unknown debugging level" #endif inline void __sld_rt_ins(int error_code, skiplist_node *node, int top_layer, int cur_layer) { dbg_node *ddd = _get_entry(node, dbg_node, snode); printf("[INS] retry (code %d) " "%p (top %d, cur %d) %d\n", error_code, node, top_layer, cur_layer, ddd->value); } inline void __sld_nc_ins(skiplist_node *node, skiplist_node *next_node, int top_layer, int cur_layer) { dbg_node *ddd = _get_entry(node, dbg_node, snode); dbg_node *ddd_next = _get_entry(next_node, dbg_node, snode); printf("[INS] next node changed, " "%p %p (top %d, cur %d) %d %d\n", node, next_node, top_layer, cur_layer, ddd->value, ddd_next->value); } inline void __sld_rt_rmv(int error_code, skiplist_node *node, int top_layer, int cur_layer) { dbg_node *ddd = _get_entry(node, dbg_node, 
snode); printf("[RMV] retry (code %d) " "%p (top %d, cur %d) %d\n", error_code, node, top_layer, cur_layer, ddd->value); } inline void __sld_nc_rmv(skiplist_node *node, skiplist_node *next_node, int top_layer, int cur_layer) { dbg_node *ddd = _get_entry(node, dbg_node, snode); dbg_node *ddd_next = _get_entry(next_node, dbg_node, snode); printf("[RMV] next node changed, " "%p %p (top %d, cur %d) %d %d\n", node, next_node, top_layer, cur_layer, ddd->value, ddd_next->value); } inline void __sld_bm(skiplist_node *node) { dbg_node *ddd = _get_entry(node, dbg_node, snode); printf("[RMV] node is being modified %d\n", ddd->value); } ================================================ FILE: examples/cpp_map_example.cc ================================================ #include "sl_map.h" #include int main() { // sl_map: Busy-waiting implementation. // erase() API may be blocked by concurrent // operations dealing with iterator on the // same key. // // sl_map_gc: Lazy reclaiming implementation. // erase() API will not be blocked by // any concurrent operations, but may // consume more memory. // sl_map slist; sl_map_gc slist; // << Insertion >> // Insert 3 KV pairs: {0, 0}, {1, 10}, {2, 20}. for (int i=0; i<3; ++i) { slist.insert(std::make_pair(i, i*10)); } // << Point lookup >> for (int i=0; i<3; ++i) { auto itr = slist.find(i); if (itr == slist.end()) continue; // Not found. printf("[point lookup] key: %d, value: %d\n", itr->first, itr->second); // Note: In `sl_map`, while `itr` is alive and holding a node // in skiplist, other thread cannot erase and free the node. // Same as `shared_ptr`, `itr` will automatically release // the node when it is not referred anymore. // But if you want to release the node before that, // you can do it as follows: // itr = slist.end(); } // << Erase >> // Erase the KV pair for key 1: {1, 10}. 
slist.erase(1); // << Iteration >> for (auto& entry: slist) { printf("[iteration] key: %d, value: %d\n", entry.first, entry.second); } return 0; } ================================================ FILE: examples/cpp_set_example.cc ================================================ #include "sl_set.h" #include int main() { // sl_set: Busy-waiting implementation. // erase() API may be blocked by concurrent // operations dealing with iterator on the // same key. // // sl_set_gc: Lazy reclaiming implementation. // erase() API will not be blocked by // any concurrent operations, but may // consume more memory. // sl_set slist; sl_set_gc slist; // << Insertion >> // Insert 3 integers: 0, 1, and 2. for (int i=0; i<3; ++i) { slist.insert(i); } // << Point lookup >> for (int i=0; i<3; ++i) { auto itr = slist.find(i); if (itr == slist.end()) continue; // Not found. printf("[point lookup] %d\n", *itr); // Note: In `sl_set`, while `itr` is alive and holding a node // in skiplist, other thread cannot erase and free the node. // Same as `shared_ptr`, `itr` will automatically release // the node when it is not referred anymore. // But if you want to release the node before that, // you can do it as follows: // itr = slist.end(); } // << Erase >> // Erase 1. slist.erase(1); // << Iteration >> for (auto& entry: slist) { printf("[iteration] %d\n", entry); } return 0; } ================================================ FILE: examples/pure_c_example.c ================================================ #include "skiplist.h" #include #include // Define a node that contains key and value pair. struct my_node { // Metadata for skiplist node. skiplist_node snode; // My data here: {int, int} pair. int key; int value; }; // Define a comparison function for `my_node`. static int my_cmp(skiplist_node* a, skiplist_node* b, void* aux) { // Get `my_node` from skiplist node `a` and `b`. 
struct my_node *aa, *bb; aa = _get_entry(a, struct my_node, snode); bb = _get_entry(b, struct my_node, snode); // aa < bb: return neg // aa == bb: return 0 // aa > bb: return pos if (aa->key < bb->key) return -1; if (aa->key > bb->key) return 1; return 0; } int main() { skiplist_raw slist; // Initialize skiplist. skiplist_init(&slist, my_cmp); // << Insertion >> // Allocate & insert 3 KV pairs: {0, 0}, {1, 10}, {2, 20}. struct my_node* nodes[3]; for (int i=0; i<3; ++i) { // Allocate memory. nodes[i] = (struct my_node*)malloc(sizeof(struct my_node)); // Initialize node. skiplist_init_node(&nodes[i]->snode); // Assign key and value. nodes[i]->key = i; nodes[i]->value = i*10; // Insert into skiplist. skiplist_insert(&slist, &nodes[i]->snode); } // << Point lookup >> for (int i=0; i<3; ++i) { // Define a query. struct my_node query; query.key = i; // Find a skiplist node `cursor`. skiplist_node* cursor = skiplist_find(&slist, &query.snode); // If `cursor` is NULL, key doesn't exist. if (!cursor) continue; // Get `my_node` from `cursor`. // Note: found->snode == *cursor struct my_node* found = _get_entry(cursor, struct my_node, snode); printf("[point lookup] key: %d, value: %d\n", found->key, found->value); // Release `cursor` (== &found->snode). // Other thread cannot free `cursor` until `cursor` is released. skiplist_release_node(cursor); } // << Erase >> // Erase the KV pair for key 1: {1, 10}. { // Define a query. struct my_node query; query.key = 1; // Find a skiplist node `cursor`. skiplist_node* cursor = skiplist_find(&slist, &query.snode); // Get `my_node` from `cursor`. // Note: found->snode == *cursor struct my_node* found = _get_entry(cursor, struct my_node, snode); printf("[erase] key: %d, value: %d\n", found->key, found->value); // Detach `found` from skiplist. skiplist_erase_node(&slist, &found->snode); // Release `found`, to free its memory. skiplist_release_node(&found->snode); // Free `found` after it becomes safe. 
skiplist_wait_for_free(&found->snode); skiplist_free_node(&found->snode); free(found); } // << Iteration >> { // Get the first cursor. skiplist_node* cursor = skiplist_begin(&slist); while (cursor) { // Get `entry` from `cursor`. // Note: entry->snode == *cursor struct my_node* entry = _get_entry(cursor, struct my_node, snode); printf("[iteration] key: %d, value: %d\n", entry->key, entry->value); // Get next `cursor`. cursor = skiplist_next(&slist, cursor); // Release `entry`. skiplist_release_node(&entry->snode); } } // << Destroy >> { // Iterate and free all nodes. skiplist_node* cursor = skiplist_begin(&slist); while (cursor) { struct my_node* entry = _get_entry(cursor, struct my_node, snode); printf("[destroy] key: %d, value: %d\n", entry->key, entry->value); // Get next `cursor`. cursor = skiplist_next(&slist, cursor); // Detach `entry` from skiplist. skiplist_erase_node(&slist, &entry->snode); // Release `entry`, to free its memory. skiplist_release_node(&entry->snode); // Free `entry` after it becomes safe. skiplist_wait_for_free(&entry->snode); skiplist_free_node(&entry->snode); free(entry); } } // Free skiplist. skiplist_free(&slist); return 0; } ================================================ FILE: include/skiplist.h ================================================ /** * Copyright (C) 2017-present Jung-Sang Ahn * All rights reserved. 
* * https://github.com/greensky00 * * Skiplist * Version: 0.2.9 * * Permission is hereby granted, free of charge, to any person * obtaining a copy of this software and associated documentation * files (the "Software"), to deal in the Software without * restriction, including without limitation the rights to use, * copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the * Software is furnished to do so, subject to the following * conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR * OTHER DEALINGS IN THE SOFTWARE. 
*/ #ifndef _JSAHN_SKIPLIST_H #define _JSAHN_SKIPLIST_H (1) #include #include #define SKIPLIST_MAX_LAYER (64) struct _skiplist_node; //#define _STL_ATOMIC (1) #ifdef __APPLE__ #define _STL_ATOMIC (1) #endif #if defined(_STL_ATOMIC) && defined(__cplusplus) #include typedef std::atomic<_skiplist_node*> atm_node_ptr; typedef std::atomic atm_bool; typedef std::atomic atm_uint8_t; typedef std::atomic atm_uint16_t; typedef std::atomic atm_uint32_t; #else typedef struct _skiplist_node* atm_node_ptr; typedef uint8_t atm_bool; typedef uint8_t atm_uint8_t; typedef uint16_t atm_uint16_t; typedef uint32_t atm_uint32_t; #endif #ifdef __cplusplus extern "C" { #endif typedef struct _skiplist_node { atm_node_ptr *next; atm_bool is_fully_linked; atm_bool being_modified; atm_bool removed; uint8_t top_layer; // 0: bottom atm_uint16_t ref_count; atm_uint32_t accessing_next; } skiplist_node; // *a < *b : return neg // *a == *b : return 0 // *a > *b : return pos typedef int skiplist_cmp_t(skiplist_node *a, skiplist_node *b, void *aux); typedef struct { size_t fanout; size_t maxLayer; void *aux; } skiplist_raw_config; typedef struct { skiplist_node head; skiplist_node tail; skiplist_cmp_t *cmp_func; void *aux; atm_uint32_t num_entries; atm_uint32_t* layer_entries; atm_uint8_t top_layer; uint8_t fanout; uint8_t max_layer; } skiplist_raw; #ifndef _get_entry #define _get_entry(ELEM, STRUCT, MEMBER) \ ((STRUCT *) ((uint8_t *) (ELEM) - offsetof (STRUCT, MEMBER))) #endif void skiplist_init(skiplist_raw* slist, skiplist_cmp_t* cmp_func); void skiplist_free(skiplist_raw* slist); void skiplist_init_node(skiplist_node* node); void skiplist_free_node(skiplist_node* node); size_t skiplist_get_size(skiplist_raw* slist); skiplist_raw_config skiplist_get_default_config(); skiplist_raw_config skiplist_get_config(skiplist_raw* slist); void skiplist_set_config(skiplist_raw* slist, skiplist_raw_config config); int skiplist_insert(skiplist_raw* slist, skiplist_node* node); int 
skiplist_insert_nodup(skiplist_raw *slist, skiplist_node *node); skiplist_node* skiplist_find(skiplist_raw* slist, skiplist_node* query); skiplist_node* skiplist_find_smaller_or_equal(skiplist_raw* slist, skiplist_node* query); skiplist_node* skiplist_find_greater_or_equal(skiplist_raw* slist, skiplist_node* query); int skiplist_erase_node_passive(skiplist_raw* slist, skiplist_node* node); int skiplist_erase_node(skiplist_raw *slist, skiplist_node *node); int skiplist_erase(skiplist_raw* slist, skiplist_node* query); int skiplist_is_valid_node(skiplist_node* node); int skiplist_is_safe_to_free(skiplist_node* node); void skiplist_wait_for_free(skiplist_node* node); void skiplist_grab_node(skiplist_node* node); void skiplist_release_node(skiplist_node* node); skiplist_node* skiplist_next(skiplist_raw* slist, skiplist_node* node); skiplist_node* skiplist_prev(skiplist_raw* slist, skiplist_node* node); skiplist_node* skiplist_begin(skiplist_raw* slist); skiplist_node* skiplist_end(skiplist_raw* slist); #ifdef __cplusplus } #endif #endif // _JSAHN_SKIPLIST_H ================================================ FILE: include/sl_map.h ================================================ /** * Copyright (C) 2017-present Jung-Sang Ahn * All rights reserved. * * https://github.com/greensky00 * * Skiplist map container * Version: 0.2.0 * * Permission is hereby granted, free of charge, to any person * obtaining a copy of this software and associated documentation * files (the "Software"), to deal in the Software without * restriction, including without limitation the rights to use, * copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the * Software is furnished to do so, subject to the following * conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. 
* * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR * OTHER DEALINGS IN THE SOFTWARE. */ #pragma once #include "skiplist.h" #include #include #include #include #include #include template struct map_node { map_node() { skiplist_init_node(&snode); } ~map_node() { skiplist_free_node(&snode); } static int cmp(skiplist_node* a, skiplist_node* b, void* aux) { map_node *aa, *bb; aa = _get_entry(a, map_node, snode); bb = _get_entry(b, map_node, snode); if (aa->kv.first < bb->kv.first) return -1; if (aa->kv.first > bb->kv.first) return 1; return 0; } skiplist_node snode; std::pair kv; }; template class sl_map; template class sl_map_gc; template class map_iterator { friend class sl_map; friend class sl_map_gc; private: using T = std::pair; using Node = map_node; public: map_iterator() : slist(nullptr), cursor(nullptr) {} map_iterator(map_iterator&& src) : slist(src.slist), cursor(src.cursor) { // Mimic perfect forwarding. src.slist = nullptr; src.cursor = nullptr; } ~map_iterator() { if (cursor) skiplist_release_node(cursor); } void operator=(const map_iterator& src) { // This reference counting is similar to that of shared_ptr. 
skiplist_node* tmp = cursor; if (src.cursor) skiplist_grab_node(src.cursor); cursor = src.cursor; if (tmp) skiplist_release_node(tmp); } bool operator==(const map_iterator& src) const { return (cursor == src.cursor); } bool operator!=(const map_iterator& src) const { return !operator==(src); } T* operator->() const { Node* node = _get_entry(cursor, Node, snode); return &node->kv; } T& operator*() const { Node* node = _get_entry(cursor, Node, snode); return node->kv; } // ++A map_iterator& operator++() { if (!slist || !cursor) { cursor = nullptr; return *this; } skiplist_node* next = skiplist_next(slist, cursor); skiplist_release_node(cursor); cursor = next; return *this; } // A++ map_iterator& operator++(int) { return operator++(); } // --A map_iterator& operator--() { if (!slist || !cursor) { cursor = nullptr; return *this; } skiplist_node* prev = skiplist_prev(slist, cursor); skiplist_release_node(cursor); cursor = prev; return *this; } // A-- map_iterator operator--(int) { return operator--(); } private: map_iterator(skiplist_raw* _slist, skiplist_node* _cursor) : slist(_slist), cursor(_cursor) {} skiplist_raw* slist; skiplist_node* cursor; }; template class sl_map { private: using T = std::pair; using Node = map_node; public: using iterator = map_iterator; using reverse_iterator = map_iterator; sl_map() { skiplist_init(&slist, Node::cmp); } virtual ~sl_map() { skiplist_node* cursor = skiplist_begin(&slist); while (cursor) { Node* node = _get_entry(cursor, Node, snode); cursor = skiplist_next(&slist, cursor); // Don't need to care about release. 
delete node; } skiplist_free(&slist); } bool empty() { skiplist_node* cursor = skiplist_begin(&slist); if (cursor) { skiplist_release_node(cursor); return false; } return true; } size_t size() { return skiplist_get_size(&slist); } std::pair insert(const std::pair& kv) { do { Node* node = new Node(); node->kv = kv; int rc = skiplist_insert_nodup(&slist, &node->snode); if (rc == 0) { skiplist_grab_node(&node->snode); return std::pair ( iterator(&slist, &node->snode), true ); } delete node; Node query; query.kv.first = kv.first; skiplist_node* cursor = skiplist_find(&slist, &query.snode); if (cursor) { return std::pair ( iterator(&slist, cursor), false ); } } while (true); // NOTE: Should not reach here. return std::pair(iterator(), false); } iterator find(K key) { Node query; query.kv.first = key; skiplist_node* cursor = skiplist_find(&slist, &query.snode); return iterator(&slist, cursor); } virtual iterator erase(iterator& position) { skiplist_node* cursor = position.cursor; skiplist_node* next = skiplist_next(&slist, cursor); skiplist_erase_node(&slist, cursor); skiplist_release_node(cursor); skiplist_wait_for_free(cursor); Node* node = _get_entry(cursor, Node, snode); delete node; position.cursor = nullptr; return iterator(&slist, next); } virtual size_t erase(const K& key) { size_t count = 0; Node query; query.kv.first = key; skiplist_node* cursor = skiplist_find(&slist, &query.snode); while (cursor) { Node* node = _get_entry(cursor, Node, snode); if (node->kv.first != key) break; cursor = skiplist_next(&slist, cursor); skiplist_erase_node(&slist, &node->snode); skiplist_release_node(&node->snode); skiplist_wait_for_free(&node->snode); delete node; } if (cursor) skiplist_release_node(cursor); return count; } iterator begin() { skiplist_node* cursor = skiplist_begin(&slist); return iterator(&slist, cursor); } iterator end() { return iterator(); } reverse_iterator rbegin() { skiplist_node* cursor = skiplist_end(&slist); return reverse_iterator(&slist, cursor); } 
reverse_iterator rend() { return reverse_iterator(); } protected: skiplist_raw slist; }; template class sl_map_gc : public sl_map { private: using T = std::pair; using Node = map_node; public: using iterator = map_iterator; using reverse_iterator = map_iterator; sl_map_gc() : sl_map() , gcVector( std::max( (size_t)4, (size_t)std::thread::hardware_concurrency() ) ) { for (auto& entry: gcVector) { entry = new std::atomic(nullptr); } } ~sl_map_gc() { execGc(); for (std::atomic*& a_node: gcVector) { Node* node = a_node->load(); delete node; delete a_node; } } iterator erase(iterator& position) { skiplist_node* cursor = position.cursor; skiplist_node* next = skiplist_next(&this->slist, cursor); skiplist_erase_node(&this->slist, cursor); Node* node = _get_entry(cursor, Node, snode); gcPush(node); skiplist_release_node(cursor); execGc(); position.cursor = nullptr; return iterator(&this->slist, next); } size_t erase(const K& key) { size_t count = 0; Node query; query.kv.first = key; skiplist_node* cursor = skiplist_find(&this->slist, &query.snode); while (cursor) { Node* node = _get_entry(cursor, Node, snode); if (node->kv.first != key) break; cursor = skiplist_next(&this->slist, cursor); skiplist_erase_node(&this->slist, &node->snode); gcPush(node); skiplist_release_node(&node->snode); } if (cursor) skiplist_release_node(cursor); execGc(); return count; } private: void gcPush(Node* node) { size_t v_len = gcVector.size(); do { size_t rr = std::rand() % v_len; for (size_t ii = rr; ii < rr + v_len; ++ii) { std::atomic& a_node = *gcVector[ii % v_len]; Node* exp = nullptr; if ( a_node.compare_exchange_strong ( exp, node, std::memory_order_relaxed ) ) { return; } } std::this_thread::yield(); execGc(); } while (true); } void execGc() { std::unique_lock l(gcVectorLock, std::try_to_lock); if (!l.owns_lock()) return; size_t v_len = gcVector.size(); for (size_t ii = 0; ii < v_len; ++ii) { std::atomic& a_node = *gcVector[ii]; Node* node = a_node.load(); if (!node) continue; if 
(skiplist_is_safe_to_free(&node->snode)) { Node* exp = node; Node* val = nullptr; a_node.compare_exchange_strong ( exp, val, std::memory_order_relaxed ); delete node; } } } std::mutex gcVectorLock; std::vector< std::atomic* > gcVector; }; ================================================ FILE: include/sl_set.h ================================================ /** * Copyright (C) 2017-present Jung-Sang Ahn * All rights reserved. * * https://github.com/greensky00 * * Skiplist set container * Version: 0.2.0 * * Permission is hereby granted, free of charge, to any person * obtaining a copy of this software and associated documentation * files (the "Software"), to deal in the Software without * restriction, including without limitation the rights to use, * copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the * Software is furnished to do so, subject to the following * conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR * OTHER DEALINGS IN THE SOFTWARE. 
*/ #pragma once #include "skiplist.h" #include #include #include #include #include #include template struct set_node { set_node() { skiplist_init_node(&snode); } ~set_node() { skiplist_free_node(&snode); } static int cmp(skiplist_node* a, skiplist_node* b, void* aux) { set_node *aa, *bb; aa = _get_entry(a, set_node, snode); bb = _get_entry(b, set_node, snode); if (aa->key < bb->key) return -1; if (aa->key > bb->key) return 1; return 0; } skiplist_node snode; K key; }; template class sl_set; template class sl_set_gc; template class set_iterator { friend class sl_set; friend class sl_set_gc; public: using Node = set_node; set_iterator() : slist(nullptr), cursor(nullptr) {} set_iterator(set_iterator&& src) : slist(src.slist), cursor(src.cursor) { // Mimic perfect forwarding. src.slist = nullptr; src.cursor = nullptr; } ~set_iterator() { if (cursor) skiplist_release_node(cursor); } void operator=(const set_iterator& src) { // This reference counting is similar to that of shared_ptr. skiplist_node* tmp = cursor; if (src.cursor) skiplist_grab_node(src.cursor); cursor = src.cursor; if (tmp) skiplist_release_node(tmp); } bool operator==(const set_iterator& src) const { return (cursor == src.cursor); } bool operator!=(const set_iterator& src) const { return (cursor != src.cursor); } K& operator*() const { Node* node = _get_entry(cursor, Node, snode); return node->key; } // ++A set_iterator& operator++() { if (!slist || !cursor) { cursor = nullptr; return *this; } skiplist_node* next = skiplist_next(slist, cursor); skiplist_release_node(cursor); cursor = next; return *this; } // A++ set_iterator& operator++(int) { return operator++(); } // --A set_iterator& operator--() { if (!slist || !cursor) { cursor = nullptr; return *this; } skiplist_node* prev = skiplist_prev(slist, cursor); skiplist_release_node(cursor); cursor = prev; return *this; } // A-- set_iterator operator--(int) { return operator--(); } private: set_iterator(skiplist_raw* _slist, skiplist_node* _cursor) : 
slist(_slist), cursor(_cursor) {} skiplist_raw* slist; skiplist_node* cursor; }; template class sl_set { private: using Node = set_node; public: using iterator = set_iterator; using reverse_iterator = set_iterator; sl_set() { skiplist_init(&slist, Node::cmp); } virtual ~sl_set() { skiplist_node* cursor = skiplist_begin(&slist); while (cursor) { Node* node = _get_entry(cursor, Node, snode); cursor = skiplist_next(&slist, cursor); // Don't need to care about release. delete node; } skiplist_free(&slist); } bool empty() { skiplist_node* cursor = skiplist_begin(&slist); if (cursor) { skiplist_release_node(cursor); return false; } return true; } size_t size() { return skiplist_get_size(&slist); } std::pair insert(const K& key) { do { Node* node = new Node(); node->key = key; int rc = skiplist_insert_nodup(&slist, &node->snode); if (rc == 0) { skiplist_grab_node(&node->snode); return std::pair ( iterator(&slist, &node->snode), true ); } delete node; Node query; query.key = key; skiplist_node* cursor = skiplist_find(&slist, &query.snode); if (cursor) { return std::pair ( iterator(&slist, cursor), false ); } } while (true); // NOTE: Should not reach here. 
return std::pair(iterator(), false); } iterator find(const K& key) { Node query; query.key = key; skiplist_node* cursor = skiplist_find(&slist, &query.snode); return iterator(&slist, cursor); } virtual iterator erase(iterator& position) { skiplist_node* cursor = position.cursor; skiplist_node* next = skiplist_next(&slist, cursor); skiplist_erase_node(&slist, cursor); skiplist_release_node(cursor); skiplist_wait_for_free(cursor); Node* node = _get_entry(cursor, Node, snode); delete node; position.cursor = nullptr; return iterator(&slist, next); } virtual size_t erase(const K& key) { size_t count = 0; Node query; query.key = key; skiplist_node* cursor = skiplist_find(&slist, &query.snode); while (cursor) { Node* node = _get_entry(cursor, Node, snode); if (node->key != key) break; cursor = skiplist_next(&slist, cursor); skiplist_erase_node(&slist, &node->snode); skiplist_release_node(&node->snode); skiplist_wait_for_free(&node->snode); delete node; } if (cursor) skiplist_release_node(cursor); return count; } iterator begin() { skiplist_node* cursor = skiplist_begin(&slist); return iterator(&slist, cursor); } iterator end() { return iterator(); } reverse_iterator rbegin() { skiplist_node* cursor = skiplist_end(&slist); return reverse_iterator(&slist, cursor); } reverse_iterator rend() { return reverse_iterator(); } protected: skiplist_raw slist; }; template class sl_set_gc : public sl_set { private: using Node = set_node; public: using iterator = set_iterator; using reverse_iterator = set_iterator; sl_set_gc() : sl_set() , gcVector( std::max( (size_t)16, (size_t)std::thread::hardware_concurrency() ) ) { for (auto& entry: gcVector) { entry = new std::atomic(nullptr); } } ~sl_set_gc() { execGc(); for (std::atomic*& a_node: gcVector) { Node* node = a_node->load(); delete node; delete a_node; } } iterator erase(iterator& position) { skiplist_node* cursor = position.cursor; skiplist_node* next = skiplist_next(&this->slist, cursor); skiplist_erase_node(&this->slist, cursor); 
Node* node = _get_entry(cursor, Node, snode); gcPush(node); skiplist_release_node(cursor); execGc(); position.cursor = nullptr; return iterator(&this->slist, next); } size_t erase(const K& key) { size_t count = 0; Node query; query.key = key; skiplist_node* cursor = skiplist_find(&this->slist, &query.snode); while (cursor) { Node* node = _get_entry(cursor, Node, snode); if (node->key != key) break; cursor = skiplist_next(&this->slist, cursor); skiplist_erase_node(&this->slist, &node->snode); gcPush(node); skiplist_release_node(&node->snode); } if (cursor) skiplist_release_node(cursor); execGc(); return count; } private: void gcPush(Node* node) { size_t v_len = gcVector.size(); do { size_t rr = std::rand() % v_len; for (size_t ii = rr; ii < rr + v_len; ++ii) { std::atomic& a_node = *gcVector[ii % v_len]; Node* exp = nullptr; if ( a_node.compare_exchange_strong ( exp, node, std::memory_order_relaxed ) ) { return; } } std::this_thread::yield(); execGc(); } while (true); } void execGc() { std::unique_lock l(gcVectorLock, std::try_to_lock); if (!l.owns_lock()) return; size_t v_len = gcVector.size(); for (size_t ii = 0; ii < v_len; ++ii) { std::atomic& a_node = *gcVector[ii]; Node* node = a_node.load(); if (!node) continue; if (skiplist_is_safe_to_free(&node->snode)) { Node* exp = node; Node* val = nullptr; a_node.compare_exchange_strong ( exp, val, std::memory_order_relaxed ); delete node; } } } std::mutex gcVectorLock; std::vector< std::atomic* > gcVector; }; ================================================ FILE: src/skiplist.cc ================================================ /** * Copyright (C) 2017-present Jung-Sang Ahn * All rights reserved. 
* * https://github.com/greensky00 * * Skiplist * Version: 0.2.9 * * Permission is hereby granted, free of charge, to any person * obtaining a copy of this software and associated documentation * files (the "Software"), to deal in the Software without * restriction, including without limitation the rights to use, * copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the * Software is furnished to do so, subject to the following * conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR * OTHER DEALINGS IN THE SOFTWARE. */ #include "skiplist.h" #include #define __SLD_RT_INS(e, n, t, c) #define __SLD_NC_INS(n, nn, t, c) #define __SLD_RT_RMV(e, n, t, c) #define __SLD_NC_RMV(n, nn, t, c) #define __SLD_BM(n) #define __SLD_ASSERT(cond) #define __SLD_P(args...) #define __SLD_(b) //#define __SL_DEBUG (1) #ifdef __SL_DEBUG #ifndef __cplusplus #error "Debug mode is available with C++ compiler only." 
#endif #include "skiplist_debug.h" #endif #define __SL_YIELD (1) #ifdef __SL_YIELD #ifdef __cplusplus #include #define YIELD() std::this_thread::yield() #else #include #define YIELD() sched_yield() #endif #else #define YIELD() #endif #if defined(_STL_ATOMIC) && defined(__cplusplus) // C++ (STL) atomic operations #define MOR std::memory_order_relaxed #define ATM_GET(var) (var).load(MOR) #define ATM_LOAD(var, val) (val) = (var).load(MOR) #define ATM_STORE(var, val) (var).store((val), MOR) #define ATM_CAS(var, exp, val) (var).compare_exchange_weak((exp), (val)) #define ATM_FETCH_ADD(var, val) (var).fetch_add(val, MOR) #define ATM_FETCH_SUB(var, val) (var).fetch_sub(val, MOR) #define ALLOC_(type, var, count) (var) = new type[count] #define FREE_(var) delete[] (var) #else // C-style atomic operations #ifndef __cplusplus typedef uint8_t bool; #ifndef true #define true 1 #endif #ifndef false #define false 0 #endif #endif #ifndef __cplusplus #define thread_local /*_Thread_local*/ #endif #define MOR __ATOMIC_RELAXED #define ATM_GET(var) (var) #define ATM_LOAD(var, val) __atomic_load(&(var), &(val), MOR) #define ATM_STORE(var, val) __atomic_store(&(var), &(val), MOR) #define ATM_CAS(var, exp, val) \ __atomic_compare_exchange(&(var), &(exp), &(val), 1, MOR, MOR) #define ATM_FETCH_ADD(var, val) __atomic_fetch_add(&(var), (val), MOR) #define ATM_FETCH_SUB(var, val) __atomic_fetch_sub(&(var), (val), MOR) #define ALLOC_(type, var, count) \ (var) = (type*)calloc(count, sizeof(type)) #define FREE_(var) free(var) #endif static inline void _sl_node_init(skiplist_node *node, size_t top_layer) { if (top_layer > UINT8_MAX) top_layer = UINT8_MAX; __SLD_ASSERT(node->is_fully_linked == false); __SLD_ASSERT(node->being_modified == false); bool bool_val = false; ATM_STORE(node->is_fully_linked, bool_val); ATM_STORE(node->being_modified, bool_val); ATM_STORE(node->removed, bool_val); if (node->top_layer != top_layer || node->next == NULL) { node->top_layer = top_layer; if (node->next) 
FREE_(node->next); ALLOC_(atm_node_ptr, node->next, top_layer+1); } } void skiplist_init(skiplist_raw *slist, skiplist_cmp_t *cmp_func) { slist->cmp_func = NULL; slist->aux = NULL; // fanout 4 + layer 12: 4^12 ~= upto 17M items under O(lg n) complexity. // for +17M items, complexity will grow linearly: O(k lg n). slist->fanout = 4; slist->max_layer = 12; slist->num_entries = 0; ALLOC_(atm_uint32_t, slist->layer_entries, slist->max_layer); slist->top_layer = 0; skiplist_init_node(&slist->head); skiplist_init_node(&slist->tail); _sl_node_init(&slist->head, slist->max_layer); _sl_node_init(&slist->tail, slist->max_layer); size_t layer; for (layer = 0; layer < slist->max_layer; ++layer) { slist->head.next[layer] = &slist->tail; slist->tail.next[layer] = NULL; } bool bool_val = true; ATM_STORE(slist->head.is_fully_linked, bool_val); ATM_STORE(slist->tail.is_fully_linked, bool_val); slist->cmp_func = cmp_func; } void skiplist_free(skiplist_raw *slist) { skiplist_free_node(&slist->head); skiplist_free_node(&slist->tail); FREE_(slist->layer_entries); slist->layer_entries = NULL; slist->aux = NULL; slist->cmp_func = NULL; } void skiplist_init_node(skiplist_node *node) { node->next = NULL; bool bool_false = false; ATM_STORE(node->is_fully_linked, bool_false); ATM_STORE(node->being_modified, bool_false); ATM_STORE(node->removed, bool_false); node->accessing_next = 0; node->top_layer = 0; node->ref_count = 0; } void skiplist_free_node(skiplist_node *node) { FREE_(node->next); node->next = NULL; } size_t skiplist_get_size(skiplist_raw* slist) { uint32_t val; ATM_LOAD(slist->num_entries, val); return val; } skiplist_raw_config skiplist_get_default_config() { skiplist_raw_config ret; ret.fanout = 4; ret.maxLayer = 12; ret.aux = NULL; return ret; } skiplist_raw_config skiplist_get_config(skiplist_raw *slist) { skiplist_raw_config ret; ret.fanout = slist->fanout; ret.maxLayer = slist->max_layer; ret.aux = slist->aux; return ret; } void skiplist_set_config(skiplist_raw *slist, 
skiplist_raw_config config) { slist->fanout = config.fanout; slist->max_layer = config.maxLayer; if (slist->layer_entries) FREE_(slist->layer_entries); ALLOC_(atm_uint32_t, slist->layer_entries, slist->max_layer); slist->aux = config.aux; } static inline int _sl_cmp(skiplist_raw *slist, skiplist_node *a, skiplist_node *b) { if (a == b) return 0; if (a == &slist->head || b == &slist->tail) return -1; if (a == &slist->tail || b == &slist->head) return 1; return slist->cmp_func(a, b, slist->aux); } static inline bool _sl_valid_node(skiplist_node *node) { bool is_fully_linked = false; ATM_LOAD(node->is_fully_linked, is_fully_linked); return is_fully_linked; } static inline void _sl_read_lock_an(skiplist_node* node) { for(;;) { // Wait for active writer to release the lock uint32_t accessing_next = 0; ATM_LOAD(node->accessing_next, accessing_next); while (accessing_next & 0xfff00000) { YIELD(); ATM_LOAD(node->accessing_next, accessing_next); } ATM_FETCH_ADD(node->accessing_next, 0x1); ATM_LOAD(node->accessing_next, accessing_next); if ((accessing_next & 0xfff00000) == 0) { return; } ATM_FETCH_SUB(node->accessing_next, 0x1); } } static inline void _sl_read_unlock_an(skiplist_node* node) { ATM_FETCH_SUB(node->accessing_next, 0x1); } static inline void _sl_write_lock_an(skiplist_node* node) { for(;;) { // Wait for active writer to release the lock uint32_t accessing_next = 0; ATM_LOAD(node->accessing_next, accessing_next); while (accessing_next & 0xfff00000) { YIELD(); ATM_LOAD(node->accessing_next, accessing_next); } ATM_FETCH_ADD(node->accessing_next, 0x100000); ATM_LOAD(node->accessing_next, accessing_next); if((accessing_next & 0xfff00000) == 0x100000) { // Wait until there's no more readers while (accessing_next & 0x000fffff) { YIELD(); ATM_LOAD(node->accessing_next, accessing_next); } return; } ATM_FETCH_SUB(node->accessing_next, 0x100000); } } static inline void _sl_write_unlock_an(skiplist_node* node) { ATM_FETCH_SUB(node->accessing_next, 0x100000); } // Note: it 
increases the `ref_count` of returned node. // Caller is responsible to decrease it. static inline skiplist_node* _sl_next(skiplist_raw* slist, skiplist_node* cur_node, int layer, skiplist_node* node_to_find, bool* found) { skiplist_node *next_node = NULL; // Turn on `accessing_next`: // now `cur_node` is not removable from skiplist, // which means that `cur_node->next` will be consistent // until clearing `accessing_next`. _sl_read_lock_an(cur_node); { if (!_sl_valid_node(cur_node)) { _sl_read_unlock_an(cur_node); return NULL; } ATM_LOAD(cur_node->next[layer], next_node); // Increase ref count of `next_node`: // now `next_node` is not destroyable. // << Remaining issue >> // 1) initially: A -> B // 2) T1: call _sl_next(A): // A.accessing_next := true; // next_node := B; // ----- context switch happens here ----- // 3) T2: insert C: // A -> C -> B // 4) T2: and then erase B, and free B. // A -> C B(freed) // ----- context switch back again ----- // 5) T1: try to do something with B, // but crash happens. // // ... maybe resolved using RW spinlock (Aug 21, 2017). 
__SLD_ASSERT(next_node); ATM_FETCH_ADD(next_node->ref_count, 1); __SLD_ASSERT(next_node->top_layer >= layer); } _sl_read_unlock_an(cur_node); size_t num_nodes = 0; skiplist_node* nodes[256]; while ( (next_node && !_sl_valid_node(next_node)) || next_node == node_to_find ) { if (found && node_to_find == next_node) *found = true; skiplist_node* temp = next_node; _sl_read_lock_an(temp); { __SLD_ASSERT(next_node); if (!_sl_valid_node(temp)) { _sl_read_unlock_an(temp); ATM_FETCH_SUB(temp->ref_count, 1); next_node = NULL; break; } ATM_LOAD(temp->next[layer], next_node); ATM_FETCH_ADD(next_node->ref_count, 1); nodes[num_nodes++] = temp; __SLD_ASSERT(next_node->top_layer >= layer); } _sl_read_unlock_an(temp); } for (size_t ii=0; iiref_count, 1); } return next_node; } static inline size_t _sl_decide_top_layer(skiplist_raw *slist) { size_t layer = 0; while (layer+1 < slist->max_layer) { // coin filp if (rand() % slist->fanout == 0) { // grow: 1/fanout probability layer++; } else { // stop: 1 - 1/fanout probability break; } } return layer; } static inline void _sl_clr_flags(skiplist_node** node_arr, int start_layer, int top_layer) { int layer; for (layer = start_layer; layer <= top_layer; ++layer) { if ( layer == top_layer || node_arr[layer] != node_arr[layer+1] ) { bool exp = true; bool bool_false = false; if (!ATM_CAS(node_arr[layer]->being_modified, exp, bool_false)) { __SLD_ASSERT(0); } } } } static inline bool _sl_valid_prev_next(skiplist_node *prev, skiplist_node *next) { return _sl_valid_node(prev) && _sl_valid_node(next); } static inline int _skiplist_insert(skiplist_raw *slist, skiplist_node *node, bool no_dup) { __SLD_( thread_local std::thread::id tid = std::this_thread::get_id(); thread_local size_t tid_hash = std::hash{}(tid) % 256; (void)tid_hash; ) int top_layer = _sl_decide_top_layer(slist); bool bool_true = true; // init node before insertion _sl_node_init(node, top_layer); _sl_write_lock_an(node); skiplist_node* prevs[SKIPLIST_MAX_LAYER]; skiplist_node* 
nexts[SKIPLIST_MAX_LAYER]; __SLD_P("%02x ins %p begin\n", (int)tid_hash, node); insert_retry: // in pure C, a label can only be part of a stmt. (void)top_layer; int cmp = 0, cur_layer = 0, layer; skiplist_node *cur_node = &slist->head; ATM_FETCH_ADD(cur_node->ref_count, 1); __SLD_(size_t nh = 0); __SLD_(thread_local skiplist_node* history[1024]; (void)history); int sl_top_layer = slist->top_layer; if (top_layer > sl_top_layer) sl_top_layer = top_layer; for (cur_layer = sl_top_layer; cur_layer >= 0; --cur_layer) { do { __SLD_( history[nh++] = cur_node ); skiplist_node *next_node = _sl_next(slist, cur_node, cur_layer, NULL, NULL); if (!next_node) { _sl_clr_flags(prevs, cur_layer+1, top_layer); ATM_FETCH_SUB(cur_node->ref_count, 1); YIELD(); goto insert_retry; } cmp = _sl_cmp(slist, node, next_node); if (cmp > 0) { // cur_node < next_node < node // => move to next node skiplist_node* temp = cur_node; cur_node = next_node; ATM_FETCH_SUB(temp->ref_count, 1); continue; } else { // otherwise: cur_node < node <= next_node ATM_FETCH_SUB(next_node->ref_count, 1); } if (no_dup && cmp == 0) { // Duplicate key is not allowed. _sl_clr_flags(prevs, cur_layer+1, top_layer); ATM_FETCH_SUB(cur_node->ref_count, 1); return -1; } if (cur_layer <= top_layer) { prevs[cur_layer] = cur_node; nexts[cur_layer] = next_node; // both 'prev' and 'next' should be fully linked before // insertion, and no other thread should not modify 'prev' // at the same time. 
int error_code = 0; int locked_layer = cur_layer + 1; // check if prev node is duplicated with upper layer if (cur_layer < top_layer && prevs[cur_layer] == prevs[cur_layer+1]) { // duplicate // => which means that 'being_modified' flag is already true // => do nothing } else { bool expected = false; if (ATM_CAS(prevs[cur_layer]->being_modified, expected, bool_true)) { locked_layer = cur_layer; } else { error_code = -1; } } if (error_code == 0 && !_sl_valid_prev_next(prevs[cur_layer], nexts[cur_layer])) { error_code = -2; } if (error_code != 0) { __SLD_RT_INS(error_code, node, top_layer, cur_layer); _sl_clr_flags(prevs, locked_layer, top_layer); ATM_FETCH_SUB(cur_node->ref_count, 1); YIELD(); goto insert_retry; } // set current node's pointers ATM_STORE(node->next[cur_layer], nexts[cur_layer]); // check if `cur_node->next` has been changed from `next_node`. skiplist_node* next_node_again = _sl_next(slist, cur_node, cur_layer, NULL, NULL); ATM_FETCH_SUB(next_node_again->ref_count, 1); if (next_node_again != next_node) { __SLD_NC_INS(cur_node, next_node, top_layer, cur_layer); // clear including the current layer // as we already set modification flag above. _sl_clr_flags(prevs, cur_layer, top_layer); ATM_FETCH_SUB(cur_node->ref_count, 1); YIELD(); goto insert_retry; } } if (cur_layer) { // non-bottom layer => go down break; } // bottom layer => insertion succeeded // change prev/next nodes' prev/next pointers from 0 ~ top_layer for (layer = 0; layer <= top_layer; ++layer) { // `accessing_next` works as a spin-lock. 
_sl_write_lock_an(prevs[layer]); skiplist_node* exp = nexts[layer]; if ( !ATM_CAS(prevs[layer]->next[layer], exp, node) ) { __SLD_P("%02x ASSERT ins %p[%d] -> %p (expected %p)\n", (int)tid_hash, prevs[layer], cur_layer, ATM_GET(prevs[layer]->next[layer]), nexts[layer] ); __SLD_ASSERT(0); } __SLD_P("%02x ins %p[%d] -> %p -> %p\n", (int)tid_hash, prevs[layer], layer, node, ATM_GET(node->next[layer]) ); _sl_write_unlock_an(prevs[layer]); } // now this node is fully linked ATM_STORE(node->is_fully_linked, bool_true); // allow removing next nodes _sl_write_unlock_an(node); __SLD_P("%02x ins %p done\n", (int)tid_hash, node); ATM_FETCH_ADD(slist->num_entries, 1); ATM_FETCH_ADD(slist->layer_entries[node->top_layer], 1); for (int ii=slist->max_layer-1; ii>=0; --ii) { if (slist->layer_entries[ii] > 0) { slist->top_layer = ii; break; } } // modification is done for all layers _sl_clr_flags(prevs, 0, top_layer); ATM_FETCH_SUB(cur_node->ref_count, 1); return 0; } while (cur_node != &slist->tail); } return 0; } int skiplist_insert(skiplist_raw *slist, skiplist_node *node) { return _skiplist_insert(slist, node, false); } int skiplist_insert_nodup(skiplist_raw *slist, skiplist_node *node) { return _skiplist_insert(slist, node, true); } typedef enum { SM = -2, SMEQ = -1, EQ = 0, GTEQ = 1, GT = 2 } _sl_find_mode; // Note: it increases the `ref_count` of returned node. // Caller is responsible to decrease it. 
static inline skiplist_node* _sl_find(skiplist_raw *slist, skiplist_node *query, _sl_find_mode mode) { // mode: // SM -2: smaller // SMEQ -1: smaller or equal // EQ 0: equal // GTEQ 1: greater or equal // GT 2: greater find_retry: (void)mode; int cmp = 0; int cur_layer = 0; skiplist_node *cur_node = &slist->head; ATM_FETCH_ADD(cur_node->ref_count, 1); __SLD_(size_t nh = 0); __SLD_(thread_local skiplist_node* history[1024]; (void)history); uint8_t sl_top_layer = slist->top_layer; for (cur_layer = sl_top_layer; cur_layer >= 0; --cur_layer) { do { __SLD_(history[nh++] = cur_node); skiplist_node *next_node = _sl_next(slist, cur_node, cur_layer, NULL, NULL); if (!next_node) { ATM_FETCH_SUB(cur_node->ref_count, 1); YIELD(); goto find_retry; } cmp = _sl_cmp(slist, query, next_node); if (cmp > 0) { // cur_node < next_node < query // => move to next node skiplist_node* temp = cur_node; cur_node = next_node; ATM_FETCH_SUB(temp->ref_count, 1); continue; } else if (-1 <= mode && mode <= 1 && cmp == 0) { // cur_node < query == next_node .. 
return ATM_FETCH_SUB(cur_node->ref_count, 1); return next_node; } // otherwise: cur_node < query < next_node if (cur_layer) { // non-bottom layer => go down ATM_FETCH_SUB(next_node->ref_count, 1); break; } // bottom layer if (mode < 0 && cur_node != &slist->head) { // smaller mode ATM_FETCH_SUB(next_node->ref_count, 1); return cur_node; } else if (mode > 0 && next_node != &slist->tail) { // greater mode ATM_FETCH_SUB(cur_node->ref_count, 1); return next_node; } // otherwise: exact match mode OR not found ATM_FETCH_SUB(cur_node->ref_count, 1); ATM_FETCH_SUB(next_node->ref_count, 1); return NULL; } while (cur_node != &slist->tail); } return NULL; } skiplist_node* skiplist_find(skiplist_raw *slist, skiplist_node *query) { return _sl_find(slist, query, EQ); } skiplist_node* skiplist_find_smaller_or_equal(skiplist_raw *slist, skiplist_node *query) { return _sl_find(slist, query, SMEQ); } skiplist_node* skiplist_find_greater_or_equal(skiplist_raw *slist, skiplist_node *query) { return _sl_find(slist, query, GTEQ); } int skiplist_erase_node_passive(skiplist_raw *slist, skiplist_node *node) { __SLD_( thread_local std::thread::id tid = std::this_thread::get_id(); thread_local size_t tid_hash = std::hash{}(tid) % 256; (void)tid_hash; ) int top_layer = node->top_layer; bool bool_true = true, bool_false = false; bool removed = false; bool is_fully_linked = false; ATM_LOAD(node->removed, removed); if (removed) { // already removed return -1; } skiplist_node* prevs[SKIPLIST_MAX_LAYER]; skiplist_node* nexts[SKIPLIST_MAX_LAYER]; bool expected = false; if (!ATM_CAS(node->being_modified, expected, bool_true)) { // already being modified .. cannot work on this node for now. __SLD_BM(node); return -2; } // set removed flag first, so that reader cannot read this node. ATM_STORE(node->removed, bool_true); __SLD_P("%02x rmv %p begin\n", (int)tid_hash, node); erase_node_retry: ATM_LOAD(node->is_fully_linked, is_fully_linked); if (!is_fully_linked) { // already unlinked .. 
remove is done by other thread ATM_STORE(node->removed, bool_false); ATM_STORE(node->being_modified, bool_false); return -3; } int cmp = 0; bool found_node_to_erase = false; (void)found_node_to_erase; skiplist_node *cur_node = &slist->head; ATM_FETCH_ADD(cur_node->ref_count, 1); __SLD_(size_t nh = 0); __SLD_(thread_local skiplist_node* history[1024]; (void)history); int cur_layer = slist->top_layer; for (; cur_layer >= 0; --cur_layer) { do { __SLD_( history[nh++] = cur_node ); bool node_found = false; skiplist_node *next_node = _sl_next(slist, cur_node, cur_layer, node, &node_found); if (!next_node) { _sl_clr_flags(prevs, cur_layer+1, top_layer); ATM_FETCH_SUB(cur_node->ref_count, 1); YIELD(); goto erase_node_retry; } // Note: unlike insert(), we should find exact position of `node`. cmp = _sl_cmp(slist, node, next_node); if (cmp > 0 || (cur_layer <= top_layer && !node_found) ) { // cur_node <= next_node < node // => move to next node skiplist_node* temp = cur_node; cur_node = next_node; __SLD_( if (cmp > 0) { int cmp2 = _sl_cmp(slist, cur_node, node); if (cmp2 > 0) { // node < cur_node <= next_node: not found. _sl_clr_flags(prevs, cur_layer+1, top_layer); ATM_FETCH_SUB(temp->ref_count, 1); ATM_FETCH_SUB(next_node->ref_count, 1); __SLD_ASSERT(0); } } ) ATM_FETCH_SUB(temp->ref_count, 1); continue; } else { // otherwise: cur_node <= node <= next_node ATM_FETCH_SUB(next_node->ref_count, 1); } if (cur_layer <= top_layer) { prevs[cur_layer] = cur_node; // note: 'next_node' and 'node' should not be the same, // as 'removed' flag is already set. __SLD_ASSERT(next_node != node); nexts[cur_layer] = next_node; // check if prev node duplicates with upper layer int error_code = 0; int locked_layer = cur_layer + 1; if (cur_layer < top_layer && prevs[cur_layer] == prevs[cur_layer+1]) { // duplicate with upper layer // => which means that 'being_modified' flag is already true // => do nothing. 
} else { expected = false; if (ATM_CAS(prevs[cur_layer]->being_modified, expected, bool_true)) { locked_layer = cur_layer; } else { error_code = -1; } } if (error_code == 0 && !_sl_valid_prev_next(prevs[cur_layer], nexts[cur_layer])) { error_code = -2; } if (error_code != 0) { __SLD_RT_RMV(error_code, node, top_layer, cur_layer); _sl_clr_flags(prevs, locked_layer, top_layer); ATM_FETCH_SUB(cur_node->ref_count, 1); YIELD(); goto erase_node_retry; } skiplist_node* next_node_again = _sl_next(slist, cur_node, cur_layer, node, NULL); ATM_FETCH_SUB(next_node_again->ref_count, 1); if (next_node_again != nexts[cur_layer]) { // `next` pointer has been changed, retry. __SLD_NC_RMV(cur_node, nexts[cur_layer], top_layer, cur_layer); _sl_clr_flags(prevs, cur_layer, top_layer); ATM_FETCH_SUB(cur_node->ref_count, 1); YIELD(); goto erase_node_retry; } } if (cur_layer == 0) found_node_to_erase = true; // go down break; } while (cur_node != &slist->tail); } // Not exist in the skiplist, should not happen. __SLD_ASSERT(found_node_to_erase); // bottom layer => removal succeeded. 
// mark this node unlinked _sl_write_lock_an(node); { ATM_STORE(node->is_fully_linked, bool_false); } _sl_write_unlock_an(node); // change prev nodes' next pointer from 0 ~ top_layer for (cur_layer = 0; cur_layer <= top_layer; ++cur_layer) { _sl_write_lock_an(prevs[cur_layer]); skiplist_node* exp = node; __SLD_ASSERT(exp != nexts[cur_layer]); __SLD_ASSERT(nexts[cur_layer]->is_fully_linked); if ( !ATM_CAS(prevs[cur_layer]->next[cur_layer], exp, nexts[cur_layer]) ) { __SLD_P("%02x ASSERT rmv %p[%d] -> %p (node %p)\n", (int)tid_hash, prevs[cur_layer], cur_layer, ATM_GET(prevs[cur_layer]->next[cur_layer]), node ); __SLD_ASSERT(0); } __SLD_ASSERT(nexts[cur_layer]->top_layer >= cur_layer); __SLD_P("%02x rmv %p[%d] -> %p (node %p)\n", (int)tid_hash, prevs[cur_layer], cur_layer, nexts[cur_layer], node); _sl_write_unlock_an(prevs[cur_layer]); } __SLD_P("%02x rmv %p done\n", (int)tid_hash, node); ATM_FETCH_SUB(slist->num_entries, 1); ATM_FETCH_SUB(slist->layer_entries[node->top_layer], 1); for (int ii=slist->max_layer-1; ii>=0; --ii) { if (slist->layer_entries[ii] > 0) { slist->top_layer = ii; break; } } // modification is done for all layers _sl_clr_flags(prevs, 0, top_layer); ATM_FETCH_SUB(cur_node->ref_count, 1); ATM_STORE(node->being_modified, bool_false); return 0; } int skiplist_erase_node(skiplist_raw *slist, skiplist_node *node) { int ret = 0; do { ret = skiplist_erase_node_passive(slist, node); // if ret == -2, other thread is accessing the same node // at the same time. try again. } while (ret == -2); return ret; } int skiplist_erase(skiplist_raw *slist, skiplist_node *query) { skiplist_node *found = skiplist_find(slist, query); if (!found) { // key not found return -4; } int ret = 0; do { ret = skiplist_erase_node_passive(slist, found); // if ret == -2, other thread is accessing the same node // at the same time. try again. 
} while (ret == -2); ATM_FETCH_SUB(found->ref_count, 1); return ret; } int skiplist_is_valid_node(skiplist_node* node) { return _sl_valid_node(node); } int skiplist_is_safe_to_free(skiplist_node* node) { if (node->accessing_next) return 0; if (node->being_modified) return 0; if (!node->removed) return 0; uint16_t ref_count = 0; ATM_LOAD(node->ref_count, ref_count); if (ref_count) return 0; return 1; } void skiplist_wait_for_free(skiplist_node* node) { while (!skiplist_is_safe_to_free(node)) { YIELD(); } } void skiplist_grab_node(skiplist_node* node) { ATM_FETCH_ADD(node->ref_count, 1); } void skiplist_release_node(skiplist_node* node) { __SLD_ASSERT(node->ref_count); ATM_FETCH_SUB(node->ref_count, 1); } skiplist_node* skiplist_next(skiplist_raw *slist, skiplist_node *node) { // << Issue >> // If `node` is already removed and its next node is also removed // and then released, the link update will not be applied to `node` // as it is already unrechable from skiplist. `node` still points to // the released node so that `_sl_next(node)` may return corrupted // memory region. // // 0) initial: // A -> B -> C -> D // // 1) B is `node`, which is removed but not yet released: // B --+-> C -> D // | // A --+ // // 2) remove C, and then release: // B -> !C! +-> D // | // A --------+ // // 3) skiplist_next(B): // will fetch C, which is already released so that // may contain garbage data. // // In this case, start over from the top layer, // to find valid link (same as in prev()). 
skiplist_node *next = _sl_next(slist, node, 0, NULL, NULL); if (!next) next = _sl_find(slist, node, GT); if (next == &slist->tail) return NULL; return next; } skiplist_node* skiplist_prev(skiplist_raw *slist, skiplist_node *node) { skiplist_node *prev = _sl_find(slist, node, SM); if (prev == &slist->head) return NULL; return prev; } skiplist_node* skiplist_begin(skiplist_raw *slist) { skiplist_node *next = NULL; while (!next) { next = _sl_next(slist, &slist->head, 0, NULL, NULL); } if (next == &slist->tail) return NULL; return next; } skiplist_node* skiplist_end(skiplist_raw *slist) { return skiplist_prev(slist, &slist->tail); } ================================================ FILE: tests/container_test.cc ================================================ #include "sl_map.h" #include "sl_set.h" #include "test_common.h" #include #include #include #include #include int _map_basic(sl_map& sl) { for (int i=0; i<10; ++i) { auto itr = sl.insert( std::make_pair(i, i*10) ); CHK_TRUE(itr.second); CHK_EQ(i, itr.first->first); CHK_EQ(i*10, itr.first->second); } // Duplicate insert should not be allowed. 
for (int i=0; i<10; ++i) { auto itr = sl.insert( std::make_pair(i, i*20) ); CHK_FALSE(itr.second); CHK_EQ(i, itr.first->first); CHK_EQ(i*10, itr.first->second); } for (int i=0; i<10; ++i) { auto entry = sl.find(i); if (entry != sl.end()) { CHK_EQ(i, entry->first); CHK_EQ(i*10, entry->second); } } auto ii = sl.find(5); sl.erase(ii); int count = 0; for (auto& entry: sl) { CHK_EQ(count, entry.first); CHK_EQ(count*10, entry.second); if (count == 4) count += 2; else count++; } ii = sl.begin(); while (ii != sl.end()) { if (ii->first % 2 == 0) { ii = sl.erase(ii); } else { ii++; } } count = 1; for (auto& entry: sl) { CHK_EQ(count, entry.first); CHK_EQ(count*10, entry.second); if (count == 3) count += 4; else count += 2; } CHK_EQ(4, (int)sl.size()); return 0; } int map_basic_wait() { sl_map sl_wait; return _map_basic(sl_wait); } int map_basic_gc() { sl_map_gc sl_gc; return _map_basic(sl_gc); } int _set_basic(sl_set& sl) { for (int i=0; i<10; ++i) { auto itr = sl.insert(i); CHK_TRUE(itr.second); CHK_EQ(i, *itr.first); } // Duplicate insert should not be allowed. 
for (int i=0; i<10; ++i) { auto itr = sl.insert(i); CHK_FALSE(itr.second); CHK_EQ(i, *itr.first); } for (int i=0; i<10; ++i) { auto entry = sl.find(i); if (entry != sl.end()) { CHK_EQ(i, *entry); } } auto ii = sl.find(5); sl.erase(ii); int count = 0; for (auto& entry: sl) { int val = entry; CHK_EQ(count, val); if (count == 4) count += 2; else count++; } ii = sl.begin(); while (ii != sl.end()) { int val = *ii; if (val % 2 == 0) { ii = sl.erase(ii); } else { ii++; } } count = 1; for (auto& entry: sl) { int val = entry; CHK_EQ(count, val); if (count == 3) count += 4; else count += 2; } CHK_EQ(4, (int)sl.size()); return 0; } int set_basic_wait() { sl_set sl_wait; return _set_basic(sl_wait); } int set_basic_gc() { sl_set_gc sl_gc; return _set_basic(sl_gc); } int map_self_refer_test() { sl_map_gc sl; for (int i=0; i<10; ++i) { auto itr = sl.insert( std::make_pair(i, i*10) ); CHK_TRUE(itr.second); CHK_EQ(i, itr.first->first); CHK_EQ(i*10, itr.first->second); } for(;;) { auto entry = sl.begin(); if (entry == sl.end()) break; sl.erase(entry->first); } return 0; } int set_self_refer_test() { sl_set_gc sl; for (int i=0; i<10; ++i) { auto itr = sl.insert(i); CHK_TRUE(itr.second); CHK_EQ(i, *itr.first); } for(;;) { auto entry = sl.begin(); if (entry == sl.end()) break; sl.erase(*entry); } return 0; } int main(int argc, char** argv) { TestSuite tt(argc, argv); tt.doTest("container map test (busy wait)", map_basic_wait); tt.doTest("container map test (lazy gc)", map_basic_gc); tt.doTest("container map self refer test", map_self_refer_test); tt.doTest("container set test (busy wait)", set_basic_wait); tt.doTest("container set test (lazy gc)", set_basic_gc); tt.doTest("container set self refer test", set_self_refer_test); return 0; } ================================================ FILE: tests/mt_test.cc ================================================ #include "skiplist.h" #include "test_common.h" #include #include #include #include #include struct TestNode { TestNode() : value(0) 
{ skiplist_init_node(&snode); } ~TestNode() { skiplist_free_node(&snode); } static int cmp(skiplist_node* a, skiplist_node* b, void* aux) { TestNode *aa, *bb; aa = _get_entry(a, TestNode, snode); bb = _get_entry(b, TestNode, snode); if (aa->value < bb->value) return -1; if (aa->value > bb->value) return 1; return 0; } skiplist_node snode; int value; }; struct ThreadArgs { skiplist_raw* slist; int max_num; int duration_ms; int ret; }; int _itr_thread(ThreadArgs* args) { TestSuite::Timer timer(args->duration_ms); do { int num_walks = 10; int count = 0; int r = rand() % args->max_num; TestNode query; query.value = r; skiplist_node* cursor = skiplist_find(args->slist, &query.snode); while (cursor) { TestNode* node = _get_entry(cursor, TestNode, snode); cursor = skiplist_next(args->slist, cursor); usleep(10); (void)node; skiplist_release_node(&node->snode); if (++count > num_walks) break; } if (cursor) skiplist_release_node(cursor); } while (!timer.timeover()); return 0; } void itr_thread(ThreadArgs* args) { args->ret = _itr_thread(args); } int _writer_thread(ThreadArgs* args) { TestSuite::Timer timer(args->duration_ms); do { int r; TestNode* node; TestNode query; skiplist_node* cursor; r = rand() % args->max_num; query.value = r; cursor = skiplist_find(args->slist, &query.snode); if (!cursor) { node = new TestNode(); node->value = r; skiplist_insert(args->slist, &node->snode); } else { skiplist_release_node(cursor); } r = rand() % args->max_num; query.value = r; cursor = skiplist_find(args->slist, &query.snode); if (cursor) { node = _get_entry(cursor, TestNode, snode); skiplist_erase_node(args->slist, &node->snode); if (node->snode.being_modified || !node->snode.removed) printf("%d %d\n", (int)node->snode.being_modified, (int)node->snode.removed); skiplist_release_node(&node->snode); skiplist_wait_for_free(&node->snode); delete node; } } while (!timer.timeover()); uint64_t c_check = 0; skiplist_node* cursor = skiplist_begin(args->slist); while (cursor) { skiplist_node* 
temp_node = cursor; cursor = skiplist_next(args->slist, cursor); skiplist_release_node(temp_node); c_check++; } if (cursor) skiplist_release_node(cursor); CHK_EQ(c_check, skiplist_get_size(args->slist)); return 0; } void writer_thread(ThreadArgs* args) { args->ret = _writer_thread(args); } int itr_write_erase() { std::thread iterator; std::thread writer; int num = 30000; skiplist_raw slist; skiplist_init(&slist, TestNode::cmp); TestNode* node[num]; for (int ii=0; iivalue = ii; skiplist_insert(&slist, &node[ii]->snode); } CHK_EQ(num, (int)skiplist_get_size(&slist)); for (int ii=0; iisnode.ref_count); } for (int ii=0; iisnode); delete node[ii]; } for (int ii=1; iisnode.ref_count); } ThreadArgs args_itr; ThreadArgs args_writer; args_itr.slist = &slist; args_itr.duration_ms = 1000; args_itr.max_num = num; args_itr.ret = 0; args_writer = args_itr; iterator = std::thread(itr_thread, &args_itr); writer = std::thread(writer_thread, &args_writer); iterator.join(); writer.join(); CHK_EQ(0, args_itr.ret); CHK_EQ(0, args_writer.ret); skiplist_node* cursor = skiplist_begin(&slist); while(cursor) { TestNode* cur_node = _get_entry(cursor, TestNode, snode); if (cur_node->snode.ref_count != 1) printf("%d %d\n", (int)cur_node->value, (int)cur_node->snode.ref_count); CHK_EQ(1, cur_node->snode.ref_count); cursor = skiplist_next(&slist, cursor); skiplist_release_node(&cur_node->snode); delete cur_node; } if (cursor) skiplist_release_node(cursor); skiplist_free(&slist); return 0; } int itr_erase_deterministic() { int num = 10; skiplist_raw slist; skiplist_init(&slist, TestNode::cmp); TestNode* node[num]; for (int ii=0; iivalue = ii; skiplist_insert(&slist, &node[ii]->snode); } skiplist_node* cursor = skiplist_begin(&slist); while (cursor) { TestNode* cur_node = _get_entry(cursor, TestNode, snode); if (cur_node->value == 2) { skiplist_erase_node(&slist, &cur_node->snode); CHK_NOT(skiplist_is_safe_to_free(&cur_node->snode)); } cursor = skiplist_next(&slist, cursor); 
skiplist_release_node(&cur_node->snode); if (cur_node->value == 2) { CHK_OK(skiplist_is_safe_to_free(&cur_node->snode)); delete cur_node; } } if (cursor) skiplist_release_node(cursor); int count = 0; cursor = skiplist_begin(&slist); while (cursor) { TestNode* cur_node = _get_entry(cursor, TestNode, snode); cursor = skiplist_next(&slist, cursor); skiplist_release_node(&cur_node->snode); count++; delete cur_node; } if (cursor) skiplist_release_node(cursor); CHK_EQ(num-1, count); skiplist_free(&slist); return 0; } int main(int argc, char** argv) { TestSuite tt(argc, argv); tt.doTest("iterator write erase test", itr_write_erase); tt.doTest("iterator write erase deterministic test", itr_erase_deterministic); return 0; } ================================================ FILE: tests/skiplist_test.cc ================================================ /** * Copyright (C) 2017-present Jung-Sang Ahn * All rights reserved. * * Permission is hereby granted, free of charge, to any person * obtaining a copy of this software and associated documentation * files (the "Software"), to deal in the Software without * restriction, including without limitation the rights to use, * copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the * Software is furnished to do so, subject to the following * conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR * OTHER DEALINGS IN THE SOFTWARE. 
*/ /* tests/skiplist_test.cc: tests for the raw skiplist C API (skiplist_raw, skiplist_node). NOTE(review): in this extracted copy the line breaks were collapsed, the header names of the bare include lines are missing, and other spans that began with a less-than sign (template arguments, for-loop bounds and the statements after them) appear to have been stripped, so this text does not compile as-is. Code edits here should be made against the upstream source file rather than this dump. */ #include "skiplist.h" #include "test_common.h" #include #include #include #include #include #include #include #include /* IntNode: test payload holding an int value plus an embedded skiplist_node hook; the constructor and destructor call skiplist_init_node and skiplist_free_node on that hook. */ struct IntNode { IntNode() { skiplist_init_node(&snode); } ~IntNode() { skiplist_free_node(&snode); } skiplist_node snode; int value; }; /* Comparator passed to skiplist_init: recovers the enclosing IntNode of each hook via _get_entry and orders by value, returning -1, 0 or 1. The aux argument is unused. */ int _cmp_IntNode(skiplist_node *a, skiplist_node *b, void *aux) { IntNode *aa, *bb; aa = _get_entry(a, IntNode, snode); bb = _get_entry(b, IntNode, snode); if (aa->value < bb->value) { return -1; } else if (aa->value == bb->value) { return 0; } else { return 1; } } /* Single-threaded test on n = 16 keys: checks forward and backward iteration order, removes the even numbers, then checks that exactly the odd values 1, 3, 5, ... remain in both directions. Returns 0 on success; each CHK_ macro returns -1 from this function on failure. */ int basic_insert_and_erase() { skiplist_raw list; skiplist_init(&list, _cmp_IntNode); int i, j, temp; int n = 16; std::vector key(n); std::vector arr(n); // key assign for (i=0; ivalue); cur = skiplist_next(&list, cur); count++; } CHK_EQ(n, count); // backward iteration count = n; cur = skiplist_end(&list); while (cur) { count--; IntNode *node = _get_entry(cur, IntNode, snode); CHK_EQ(count, node->value); cur = skiplist_prev(&list, cur); } CHK_EQ(0, count); // remove even numbers IntNode query; for (i=0; ivalue); cur = skiplist_next(&list, cur); count++; } CHK_EQ(n/2, count); // backward iteration count = n/2; cur = skiplist_end(&list); while (cur) { count--; IntNode *node = _get_entry(cur, IntNode, snode); CHK_EQ(count*2 + 1, node->value); cur = skiplist_prev(&list, cur); } CHK_EQ(0, count); skiplist_free(&list); return 0; } /* Timing test on n = 10000 keys covering exact-match lookups (existing and non-existing keys), skiplist_find_smaller_or_equal and skiplist_find_greater_or_equal; queries below the smallest key and above the largest key must return NULL. Elapsed times are reported through TestSuite::appendResultMessage. */ int find_test() { TestSuite::Timer tt; TestSuite::appendResultMessage("\n"); double elapsed_sec = 1; char msg[1024]; skiplist_raw list; skiplist_init(&list, _cmp_IntNode); int i, j, temp; int n = 10000; std::vector key(n); std::vector arr(n); // assign for (i=0; ivalue); } 
elapsed_sec = tt.getTimeUs() / 1000000.0; sprintf(msg, "find (smaller key) done: %.4f (%.1f ops/sec)\n", elapsed_sec, n / elapsed_sec); TestSuite::appendResultMessage(msg); // ==== find greater key tt.reset(); for (i=0; ivalue); } // greater than greatest key query.value = i*10; ret = skiplist_find_greater_or_equal(&list, &query.snode); CHK_NULL(ret); elapsed_sec = tt.getTimeUs() / 1000000.0; sprintf(msg, "find (greater key) done: %.4f (%.1f ops/sec)\n", elapsed_sec, n / elapsed_sec); TestSuite::appendResultMessage(msg); // ==== find non-existing key tt.reset(); for (i=0; i* stl_set; std::mutex *lock; bool use_skiplist; std::vector* key; std::vector* arr; int range_begin; int range_end; double elapsed_sec; }; /* Test configuration passed by value to the concurrent tests: key count, per-role thread counts, the random_order flag, and whether the skiplist or the mutex-protected std::set is exercised. */ struct test_args { int n_keys; int n_writers; int n_erasers; int n_readers; bool random_order; bool use_skiplist; }; /* Thread body (the void pointer is a thread_args pointer): inserts arr[range_begin..range_end], inclusive, into the skiplist, or inserts their values into stl_set while holding the lock; stores the elapsed seconds in elapsed_sec and returns 0. */ int writer_thread(void *voidargs) { thread_args *args = (thread_args*)voidargs; TestSuite::Timer tt; int i; for (i=args->range_begin; i<=args->range_end; ++i) { IntNode& node = args->arr->at(i); if (args->use_skiplist) { skiplist_insert(args->list, &node.snode); } else { args->lock->lock(); args->stl_set->insert(node.value); args->lock->unlock(); } } args->elapsed_sec = tt.getTimeUs() / 1000000.0; return 0; } /* Thread body: erases the keys key[range_begin..range_end], inclusive, from the skiplist (return value ignored) or from stl_set while holding the lock; stores the elapsed seconds and returns 0. */ int eraser_thread(void *voidargs) { thread_args *args = (thread_args*)voidargs; TestSuite::Timer tt; int i; for (i=args->range_begin; i<=args->range_end; ++i) { IntNode query; if (args->use_skiplist) { query.value = args->key->at(i); int ret = skiplist_erase(args->list, &query.snode); (void)ret; } else { args->lock->lock(); args->stl_set->erase(args->key->at(i)); args->lock->unlock(); } } args->elapsed_sec = tt.getTimeUs() / 1000000.0; return 0; } /* Thread body: looks up key[range_begin..range_end], inclusive, in the skiplist or in stl_set (under the lock) and discards the results; stores the elapsed seconds and returns 0. */ int reader_thread(void *voidargs) { thread_args *args = (thread_args*)voidargs; TestSuite::Timer tt; int i; for (i=args->range_begin; i<=args->range_end; ++i) { IntNode query; if (args->use_skiplist) { query.value = args->key->at(i); skiplist_node *ret = skiplist_find(args->list, &query.snode); /* NOTE(review): the node returned by skiplist_find is never released here, whereas the mt_test.cc writer calls skiplist_release_node on every non-NULL skiplist_find result; if find takes a reference, each hit leaks one. Confirm against skiplist.h. */ 
(void)ret; } else { args->lock->lock(); auto ret = args->stl_set->find(args->key->at(i)); (void)ret; args->lock->unlock(); } } args->elapsed_sec = tt.getTimeUs() / 1000000.0; return 0; } /* Multi-writer insertion test: worker threads (their spawn loop is lost in this copy) populate the list and throughput is reported; in skiplist mode a forward iteration must then yield 0, 1, ..., n-1 with no gaps, where n = t_args.n_keys. Returns 0 on success or -1 via the CHK_ macros. */ int concurrent_write_test(test_args t_args) { char msg[1024]; TestSuite::appendResultMessage("\n"); skiplist_raw list; std::mutex lock; std::set stl_set; skiplist_init(&list, _cmp_IntNode); int i, j, temp; int n = t_args.n_keys; std::vector key(n); std::vector arr(n); // assign for (i=0; i t_holder(n_threads); std::vector args(n_threads); for (i=0; ijoin(); delete t_holder[i]; } double elapsed_sec = tt.getTimeUs() / 1000000.0; sprintf(msg, "insert %.4f (%d threads, %.1f ops/sec)\n", elapsed_sec, n_threads, n/elapsed_sec); TestSuite::appendResultMessage(msg); /* NOTE(review): this early return skips skiplist_free(&list) in std::set mode. */ if (!t_args.use_skiplist) return 0; // integrity check (forward iteration, skiplist only) tt.reset(); int count = 0; bool corruption = false; /* NOTE(review): the cursor is advanced with skiplist_next but never passed to skiplist_release_node, unlike the iteration loops in mt_test.cc; confirm whether skiplist_begin and skiplist_next take references. */ skiplist_node *cur = skiplist_begin(&list); while (cur) { IntNode *node = _get_entry(cur, IntNode, snode); if (node->value != count) { skiplist_node *missing = &arr[count].snode; sprintf(msg, "idx %d is missing, %lx\n", count, (uint64_t)missing); TestSuite::appendResultMessage(msg); /* Diagnostic path only. NOTE(review): prev or next can presumably be NULL at the list ends, yet the _get_entry results below are dereferenced unchecked; the %lx conversion used with a uint64_t cast above is also not portable (PRIx64 or %p would be). */ skiplist_node *prev = skiplist_prev(&list, missing); skiplist_node *next = skiplist_next(&list, missing); IntNode *prev_node = _get_entry(prev, IntNode, snode); IntNode *next_node = _get_entry(next, IntNode, snode); sprintf(msg, "%d %d\n", prev_node->value, next_node->value); TestSuite::appendResultMessage(msg); count = node->value; corruption = true; } CHK_EQ(count, node->value); cur = skiplist_next(&list, cur); count++; } CHK_EQ(n, count); CHK_NOT(corruption); elapsed_sec = tt.getTimeUs() / 1000000.0; sprintf(msg, "iteration %.4f (%.1f ops/sec)\n", elapsed_sec, n/elapsed_sec); TestSuite::appendResultMessage(msg); skiplist_free(&list); return 0; } /* Mixed insert and erase test. Per the comments inside it: the list starts as 0, 10, 20, ...; writer threads add 5, 15, 25, ... while eraser threads erase 0, 10, 20, ...; in skiplist mode the final list must be exactly 5, 15, 25, ... */ int concurrent_write_erase_test(struct test_args t_args) { char msg[1024]; TestSuite::appendResultMessage("\n"); skiplist_raw list; std::mutex lock; std::set 
stl_set; skiplist_init(&list, _cmp_IntNode); int i, j, temp; int n = t_args.n_keys; std::vector key_add(n); std::vector key_del(n); std::vector arr_add(n); std::vector arr_del(n); std::vector arr_add_dbgref(n); // initial list state: 0, 10, 20, ... // => writer threads: adding 5, 15, 25, ... // => eraser threads: erasing 0, 10, 20, ... // final list state: 5, 15, 25, ... // initial insert for (i=0; i t_holder_add(n_threads_add); std::vector args_add(n_threads_add); for (i=0; i t_holder_del(n_threads_del); std::vector args_del(n_threads_del); for (i=0; ijoin(); delete t_holder_add[i]; } for (i=0; ijoin(); delete t_holder_del[i]; } double max_seconds_add = 0; double max_seconds_del = 0; for (i=0; i max_seconds_add) { max_seconds_add = args_add[i].elapsed_sec; } } for (i=0; i max_seconds_del) { max_seconds_del = args_del[i].elapsed_sec; } } sprintf(msg, "insertion %.4f (%d threads, %.1f ops/sec)\n", max_seconds_add, n_threads_add, n / max_seconds_add); TestSuite::appendResultMessage(msg); sprintf(msg, "deletion %.4f (%d threads, %.1f ops/sec)\n", max_seconds_del, n_threads_del, n / max_seconds_del); TestSuite::appendResultMessage(msg); sprintf(msg, "mutation total %.4f (%d threads, %.1f ops/sec)\n", std::max(max_seconds_add, max_seconds_del), n_threads_add + n_threads_del, (n*2) / std::max(max_seconds_add, max_seconds_del)); TestSuite::appendResultMessage(msg); /* NOTE(review): this early return skips skiplist_free(&list) in std::set mode. */ if (!t_args.use_skiplist) { return 0; } // integrity check (forward iteration, skiplist only) std::chrono::time_point start, end; std::chrono::duration elapsed_seconds; start = std::chrono::system_clock::now(); int count = 0; bool corruption = false; skiplist_node *cur = skiplist_begin(&list); std::vector dbg_node(n); std::vector dbg_int(n); /* NOTE(review): dbg_node and dbg_int are written at index count before the final CHK_EQ(n, count) check, so a list holding more than n nodes would write past their end. */ while (cur) { IntNode *node = _get_entry(cur, IntNode, snode); dbg_node[count] = cur; dbg_int[count] = node; // 5, 15, 25, 35 ... 
int idx = count * 10 + 5; if (node->value != idx) { skiplist_node *missing = &arr_add_dbgref[count]->snode; sprintf(msg, "count %d, idx %d is missing %lx\n", count, idx, (uint64_t)missing); TestSuite::appendResultMessage(msg); /* Diagnostic path only. NOTE(review): prev or next can presumably be NULL at the list ends, yet the _get_entry results below are dereferenced unchecked; %lx with a uint64_t cast above is not portable. */ skiplist_node *prev = skiplist_prev(&list, missing); skiplist_node *next = skiplist_next(&list, missing); IntNode *prev_node = _get_entry(prev, IntNode, snode); IntNode *next_node = _get_entry(next, IntNode, snode); sprintf(msg, "%d %d\n", prev_node->value, next_node->value); TestSuite::appendResultMessage(msg); corruption = true; } CHK_EQ(idx, node->value); cur = skiplist_next(&list, cur); count++; } CHK_EQ(n, count); CHK_NOT(corruption); end = std::chrono::system_clock::now(); elapsed_seconds = end-start; sprintf(msg, "iteration %.4f (%.1f ops/sec)\n", elapsed_seconds.count(), n/elapsed_seconds.count()); TestSuite::appendResultMessage(msg); skiplist_free(&list); return 0; } /* Concurrent insert and lookup test. Per the comments inside it: the list starts as 0, 10, 20, ...; writer threads add 5, 15, 25, ... while reader threads look up 0, 10, 20, ...; insertion and retrieval throughput are reported separately (no final integrity check is visible in this copy). */ int concurrent_write_read_test(struct test_args t_args) { char msg[1024]; TestSuite::appendResultMessage("\n"); skiplist_raw list; std::mutex lock; std::set stl_set; skiplist_init(&list, _cmp_IntNode); int i, j, temp; int n = t_args.n_keys; std::vector key_add(n); std::vector key_read(n); std::vector arr_add(n); std::vector arr_find(n); std::vector arr_add_dbgref(n); // initial list state: 0, 10, 20, ... // => writer threads: adding 5, 15, 25, ... // => reader threads: reading 0, 10, 20, ... // final list state: 0, 5, 10, 15, 20, ... 
// initial insert for (i=0; i t_holder_add(n_threads_add); std::vector t_holder_find(n_threads_find); std::vector args_add(n_threads_add); std::vector args_find(n_threads_find); if (n_threads_add) { int n_keys_per_thread_add = n / n_threads_add; for (i=0; ijoin(); delete t_holder_add[i]; } for (i=0; ijoin(); delete t_holder_find[i]; } if (n_threads_add) { double max_seconds_add = 0; for (i=0; i max_seconds_add) { max_seconds_add = args_add[i].elapsed_sec; } } sprintf(msg, "insertion %.4f (%d threads, %.1f ops/sec)\n", max_seconds_add, n_threads_add, n / max_seconds_add); TestSuite::appendResultMessage(msg); } if (n_threads_find) { double max_seconds_find = 0; for (i=0; i max_seconds_find) { max_seconds_find = args_find[i].elapsed_sec; } } sprintf(msg, "retrieval %.4f (%d threads, %.1f ops/sec)\n", max_seconds_find, n_threads_find, n / max_seconds_find); TestSuite::appendResultMessage(msg); } skiplist_free(&list); return 0; } /* Entry point: seeds rand() with 0xabcd and runs the tests with 40000 keys and thread mixes of 8 writers, 4 writers plus 4 erasers, and 1 writer plus 7 readers. NOTE(review): args.n_erasers and args.n_readers are still uninitialized when args is first copied into concurrent_write_test; initializing every test_args field up front would avoid copying indeterminate values. */ int main(int argc, char** argv) { TestSuite ts(argc, argv); srand(0xabcd); struct test_args args; args.n_keys = 40000; args.random_order = true; args.use_skiplist = true; //ts.options.printTestMessage = true; ts.doTest("basic insert and erase", basic_insert_and_erase); ts.doTest("find test", find_test); args.n_writers = 8; ts.doTest("concurrent write test", concurrent_write_test, args); args.n_writers = 4; args.n_erasers = 4; ts.doTest("concurrent write erase test", concurrent_write_erase_test, args); args.n_writers = 1; args.n_readers = 7; ts.doTest("concurrent write read test", concurrent_write_read_test, args); return 0; } ================================================ FILE: tests/stl_map_compare.cc ================================================ #include "sl_map.h" #include "test_common.h" #include #include #include #include #include /* tests/stl_map_compare.cc: throughput comparison of sl_map_gc against std::map. thread_args holds the configuration and counters of one benchmark thread; mode selects the skiplist map (SKIPLIST), a std::map guarded by the lock (MAP_MUTEX), or a std::map accessed without locking (MAP_ONLY). */ struct thread_args { thread_args() : mode(SKIPLIST), num(0), id(0), modulo(0), duration_ms(0), op_count(0), temp(0), sl(nullptr), stdmap(nullptr), lock(nullptr) {} enum Mode { SKIPLIST = 0, MAP_MUTEX = 1, 
MAP_ONLY = 2 } mode; int num; int id; int modulo; int duration_ms; int op_count; volatile uint64_t temp; sl_map* sl; std::map* stdmap; std::mutex* lock; }; /* CPU busy-work for the reader threads: for each ii in 2..max_prime that divides the running value, divides it out once and counts it. NOTE(review): despite the name it neither tests ii for primality nor counts repeated factors; callers apparently use the result only as a sink, adding it to the volatile temp member. */ size_t num_primes(uint64_t number, size_t max_prime) { size_t ret = 0; for (size_t ii=2; ii<=max_prime; ++ii) { if (number % ii == 0) { number /= ii; ret++; } } return ret; } /* Reader thread: until duration_ms elapses, picks a random key r in 0..num-1, finds it, and walks forward over at most max_walks (3) entries, feeding each mapped value to num_primes and accumulating into temp. NOTE(review): op_count grows by max_walks per iteration even when fewer entries were visited (a missed find visits none), so walks would be the accurate figure; rand() is not required to be thread-safe; and the MAP_ONLY branch reads the std::map concurrently with unlocked writers, a data race (main keeps mode 2 commented out). */ void reader(thread_args* args) { TestSuite::Timer timer(args->duration_ms); while (!timer.timeover()) { int r = rand() % args->num; int max_walks = 3; int walks = 0; if (args->mode == thread_args::SKIPLIST) { auto itr = args->sl->find(r); while (itr != args->sl->end()) { uint64_t number = itr->second; itr++; args->temp += num_primes(number, 10000); if (++walks >= max_walks) break; } } else if (args->mode == thread_args::MAP_MUTEX) { std::lock_guard l(*args->lock); auto itr = args->stdmap->find(r); while (itr != args->stdmap->end()) { uint64_t number = itr->second; itr++; args->temp += num_primes(number, 10000); if (++walks >= max_walks) break; } } else { auto itr = args->stdmap->find(r); while (itr != args->stdmap->end()) { uint64_t number = itr->second; itr++; args->temp += num_primes(number, 10000); if (++walks >= max_walks) break; } } args->op_count += max_walks; } } /* Writer thread: until duration_ms elapses, picks a random key whose remainder modulo the modulo field equals id (so different writers touch disjoint keys) and toggles it, inserting the pair (r, r) if absent or erasing it if present; op_count counts iterations. NOTE(review): the MAP_ONLY branch mutates the shared std::map without any synchronization. */ void writer(thread_args* args) { TestSuite::Timer timer(args->duration_ms); while (!timer.timeover()) { int r = rand() % (args->num / args->modulo); r *= args->modulo; r += args->id; if (args->mode == thread_args::SKIPLIST) { auto itr = args->sl->find(r); if (itr == args->sl->end()) { args->sl->insert(std::make_pair(r, r)); } else { args->sl->erase(itr); } } else if (args->mode == thread_args::MAP_MUTEX) { std::lock_guard l(*args->lock); auto itr = args->stdmap->find(r); if (itr == args->stdmap->end()) { args->stdmap->insert(std::make_pair(r, r)); } else { args->stdmap->erase(itr); } } else { auto itr = args->stdmap->find(r); if (itr == args->stdmap->end()) { args->stdmap->insert(std::make_pair(r, r)); } else { args->stdmap->erase(itr); } } args->op_count++; } } /* Benchmark driver for one mode (presumably mapped onto thread_args::Mode; the cast is lost in this copy): runs 4 reader and 4 writer threads for 5000 ms over keys below 10000000 against a shared sl_map_gc and std::map, then accumulates r_total and w_total (the join and reporting code is lost in this copy). NOTE(review): r_args, readers, w_args and writers are declared with non-const int sizes, i.e. variable-length arrays, a compiler extension rather than standard C++; declaring the counts constexpr would make them standard. */ int 
concurrent_test(int mode) { sl_map_gc sl; std::map stdmap; std::mutex lock; int num = 10000000; int duration_ms = 5000; int num_readers = 4; thread_args r_args[num_readers]; std::thread readers[num_readers]; for (int i=0; i(mode); r_args[i].num = num; r_args[i].duration_ms = duration_ms; r_args[i].sl = &sl; r_args[i].stdmap = &stdmap; r_args[i].lock = &lock; readers[i] = std::thread(reader, &r_args[i]); } int num_writers = 4; thread_args w_args[num_writers]; std::thread writers[num_writers]; for (int i=0; i(mode); w_args[i].num = num; w_args[i].id = i; w_args[i].modulo = num_writers; w_args[i].duration_ms = duration_ms; w_args[i].sl = &sl; w_args[i].stdmap = &stdmap; w_args[i].lock = &lock; writers[i] = std::thread(writer, &w_args[i]); } int r_total = 0, w_total = 0; for (int i=0; i params = {0, 1/*, 2*/}; tt.options.printTestMessage = true; tt.doTest("concurrent access comparison test", concurrent_test, TestRange(params)); return 0; } ================================================ FILE: tests/test_common.h ================================================ /** * Copyright (C) 2017-present Jung-Sang Ahn * All rights reserved. * * https://github.com/greensky00 * * Test Suite * Version: 0.1.65 * * Permission is hereby granted, free of charge, to any person * obtaining a copy of this software and associated documentation * files (the "Software"), to deal in the Software without * restriction, including without limitation the rights to use, * copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the * Software is furnished to do so, subject to the following * conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software.
* * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR * OTHER DEALINGS IN THE SOFTWARE. */ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #ifndef _CLM_DEFINED #define _CLM_DEFINED (1) #ifdef TESTSUITE_NO_COLOR #define _CLM_D_GRAY "" #define _CLM_GREEN "" #define _CLM_B_GREEN "" #define _CLM_RED "" #define _CLM_B_RED "" #define _CLM_BROWN "" #define _CLM_B_BROWN "" #define _CLM_BLUE "" #define _CLM_B_BLUE "" #define _CLM_MAGENTA "" #define _CLM_B_MAGENTA "" #define _CLM_CYAN "" #define _CLM_END "" #define _CLM_WHITE_FG_RED_BG "" #else #define _CLM_D_GRAY "\033[1;30m" #define _CLM_GREEN "\033[32m" #define _CLM_B_GREEN "\033[1;32m" #define _CLM_RED "\033[31m" #define _CLM_B_RED "\033[1;31m" #define _CLM_BROWN "\033[33m" #define _CLM_B_BROWN "\033[1;33m" #define _CLM_BLUE "\033[34m" #define _CLM_B_BLUE "\033[1;34m" #define _CLM_MAGENTA "\033[35m" #define _CLM_B_MAGENTA "\033[1;35m" #define _CLM_CYAN "\033[36m" #define _CLM_B_GREY "\033[1;37m" #define _CLM_END "\033[0m" #define _CLM_WHITE_FG_RED_BG "\033[37;41m" #endif #define _CL_D_GRAY(str) _CLM_D_GRAY str _CLM_END #define _CL_GREEN(str) _CLM_GREEN str _CLM_END #define _CL_RED(str) _CLM_RED str _CLM_END #define _CL_B_RED(str) _CLM_B_RED str _CLM_END #define _CL_MAGENTA(str) _CLM_MAGENTA str _CLM_END #define _CL_BROWN(str) _CLM_BROWN str _CLM_END #define _CL_B_BROWN(str) _CLM_B_BROWN str _CLM_END #define _CL_B_BLUE(str) _CLM_B_BLUE str _CLM_END #define _CL_B_MAGENTA(str) _CLM_B_MAGENTA str 
_CLM_END #define _CL_CYAN(str) _CLM_CYAN str _CLM_END #define _CL_B_GRAY(str) _CLM_B_GREY str _CLM_END #define _CL_WHITE_FG_RED_BG(str) _CLM_WHITE_FG_RED_BG str _CLM_END #endif #define __COUT_STACK_INFO__ \ std::endl \ << " time: " << _CLM_D_GRAY << \ TestSuite::getTimeString() << _CLM_END << "\n" \ << " thread: " << _CLM_BROWN \ << std::hex << std::setw(4) << std::setfill('0') << \ (std::hash{}( std::this_thread::get_id() ) & 0xffff) \ << std::dec << _CLM_END << "\n" \ << " in: " << _CLM_CYAN << __func__ << "()" _CLM_END << "\n" \ << " at: " << _CLM_GREEN << __FILE__ << _CLM_END ":" \ << _CLM_B_MAGENTA << __LINE__ << _CLM_END << "\n" \ // exp_value == value #define CHK_EQ(exp_value, value) \ { \ auto _ev = (exp_value); \ decltype(_ev) _v = (decltype(_ev))(value); \ if (_ev != _v) { \ std::cout \ << __COUT_STACK_INFO__ \ << " value of: " _CLM_B_BLUE #value _CLM_END "\n" \ << " expected: " _CLM_B_GREEN << _ev << _CLM_END "\n" \ << " actual: " _CLM_B_RED << _v << _CLM_END "\n"; \ TestSuite::failHandler(); \ return -1; \ } \ } // exp_value != value #define CHK_NEQ(exp_value, value) \ { \ auto _ev = (exp_value); \ decltype(_ev) _v = (decltype(_ev))(value); \ if (_ev == _v) { \ std::cout \ << __COUT_STACK_INFO__ \ << " value of: " _CLM_B_BLUE #value _CLM_END "\n" \ << " expected: not " _CLM_B_GREEN << _ev << _CLM_END "\n" \ << " actual: " _CLM_B_RED << _v << _CLM_END "\n"; \ TestSuite::failHandler(); \ return -1; \ } \ } // value == true #define CHK_OK(value) \ if (!(value)) { \ std::cout \ << __COUT_STACK_INFO__ \ << " value of: " _CLM_B_BLUE #value _CLM_END "\n" \ << " expected: " _CLM_B_GREEN << "true" << _CLM_END "\n" \ << " actual: " _CLM_B_RED << "false" << _CLM_END "\n"; \ TestSuite::failHandler(); \ return -1; \ } #define CHK_TRUE(value) CHK_OK(value) // value == false #define CHK_NOT(value) \ if (value) { \ std::cout \ << __COUT_STACK_INFO__ \ << " value of: " _CLM_B_BLUE #value _CLM_END "\n" \ << " expected: " _CLM_B_GREEN << "false" << _CLM_END "\n" \ << " 
actual: " _CLM_B_RED << "true" << _CLM_END "\n"; \ TestSuite::failHandler(); \ return -1; \ } #define CHK_FALSE(value) CHK_NOT(value) // value == NULL #define CHK_NULL(value) \ { \ auto _v = (value); \ if (_v) { \ std::cout \ << __COUT_STACK_INFO__ \ << " value of: " _CLM_B_BLUE #value _CLM_END "\n" \ << " expected: " _CLM_B_GREEN << "NULL" << _CLM_END "\n"; \ printf(" actual: " _CLM_B_RED "%p" _CLM_END "\n", _v); \ TestSuite::failHandler(); \ return -1; \ } \ } // value != NULL #define CHK_NONNULL(value) \ if (!(value)) { \ std::cout \ << __COUT_STACK_INFO__ \ << " value of: " _CLM_B_BLUE #value _CLM_END "\n" \ << " expected: " _CLM_B_GREEN << "non-NULL" << _CLM_END "\n" \ << " actual: " _CLM_B_RED << "NULL" << _CLM_END "\n"; \ TestSuite::failHandler(); \ return -1; \ } // value == 0 #define CHK_Z(value) \ { \ auto _v = (value); \ if ((0) != _v) { \ std::cout \ << __COUT_STACK_INFO__ \ << " value of: " _CLM_B_BLUE #value _CLM_END "\n" \ << " expected: " _CLM_B_GREEN << "0" << _CLM_END "\n" \ << " actual: " _CLM_B_RED << _v << _CLM_END "\n"; \ TestSuite::failHandler(); \ return -1; \ } \ } // smaller < greater #define CHK_SM(smaller, greater) \ { \ auto _sm = (smaller); \ decltype(_sm) _gt = (decltype(_sm))(greater); \ if (!(_sm < _gt)) { \ std::cout \ << __COUT_STACK_INFO__ \ << " expected: " \ << _CLM_B_BLUE #smaller " < " #greater _CLM_END "\n" \ << " value of " \ << _CLM_B_GREEN #smaller _CLM_END ": " \ << _CLM_B_RED << _sm << _CLM_END "\n" \ << " value of " \ << _CLM_B_GREEN #greater _CLM_END ": " \ << _CLM_B_RED << _gt << _CLM_END "\n"; \ TestSuite::failHandler(); \ return -1; \ } \ } // smaller <= greater #define CHK_SMEQ(smaller , greater) \ { \ auto _sm = (smaller); \ decltype(_sm) _gt = (decltype(_sm))(greater); \ if (!(_sm <= _gt)) { \ std::cout \ << __COUT_STACK_INFO__ \ << " expected: " \ << _CLM_B_BLUE #smaller " <= " #greater _CLM_END "\n" \ << " value of " \ << _CLM_B_GREEN #smaller _CLM_END ": " \ << _CLM_B_RED << _sm << _CLM_END "\n" \ << " value 
of " \ << _CLM_B_GREEN #greater _CLM_END ": " \ << _CLM_B_RED << _gt << _CLM_END "\n"; \ TestSuite::failHandler(); \ return -1; \ } \ } // greater > smaller #define CHK_GT(greater, smaller) \ { \ auto _sm = (smaller); \ decltype(_sm) _gt = (decltype(_sm))(greater); \ if (!(_gt > _sm)) { \ std::cout \ << __COUT_STACK_INFO__ \ << " expected: " \ << _CLM_B_BLUE #greater " > " #smaller _CLM_END "\n" \ << " value of " \ << _CLM_B_GREEN #greater _CLM_END ": " \ << _CLM_B_RED << _gt << _CLM_END "\n" \ << " value of " \ << _CLM_B_GREEN #smaller _CLM_END ": " \ << _CLM_B_RED << _sm << _CLM_END "\n"; \ TestSuite::failHandler(); \ return -1; \ } \ } // greater >= smaller #define CHK_GTEQ(greater, smaller) \ { \ auto _sm = (smaller); \ decltype(_sm) _gt = (decltype(_sm))(greater); \ if (!(_gt >= _sm)) { \ std::cout \ << __COUT_STACK_INFO__ \ << " expected: " \ << _CLM_B_BLUE #greater " >= " #smaller _CLM_END "\n" \ << " value of " \ << _CLM_B_GREEN #greater _CLM_END ": " \ << _CLM_B_RED << _gt << _CLM_END "\n" \ << " value of " \ << _CLM_B_GREEN #smaller _CLM_END ": " \ << _CLM_B_RED << _sm << _CLM_END "\n"; \ TestSuite::failHandler(); \ return -1; \ } \ } using test_func = std::function; class TestArgsBase; using test_func_args = std::function; class TestSuite; class TestArgsBase { public: virtual ~TestArgsBase() { } void setCallback(std::string test_name, test_func_args func, TestSuite* test_instance) { testName = test_name; testFunction = func; testInstance = test_instance; } void testAll() { testAllInternal(0); } virtual void setParam(size_t param_no, size_t param_idx) = 0; virtual size_t getNumSteps(size_t param_no) = 0; virtual size_t getNumParams() = 0; virtual std::string toString() = 0; private: inline void testAllInternal(size_t depth); std::string testName; test_func_args testFunction; TestSuite* testInstance; }; class TestArgsWrapper { public: TestArgsWrapper(TestArgsBase* _test_args) : test_args(_test_args) {} ~TestArgsWrapper() { delete test_args; } TestArgsBase* 
getArgs() const { return test_args; } operator TestArgsBase*() const { return getArgs(); } private: TestArgsBase* test_args; }; enum class StepType { LINEAR, EXPONENTIAL }; template class TestRange { public: TestRange() : type(RangeType::NONE), begin(), end(), step() {} // Constructor for given values TestRange(const std::vector& _array) : type(RangeType::ARRAY), array(_array) , begin(), end(), step() {} // Constructor for regular steps TestRange(T _begin, T _end, T _step, StepType _type) : begin(_begin), end(_end), step(_step) { if (_type == StepType::LINEAR) { type = RangeType::LINEAR; } else { type = RangeType::EXPONENTIAL; } } T getEntry(size_t idx) { if (type == RangeType::ARRAY) { return array[idx]; } else if (type == RangeType::LINEAR) { return begin + step * idx; } else if (type == RangeType::EXPONENTIAL) { ssize_t _begin = begin; ssize_t _step = step; ssize_t _ret = _begin * std::pow(_step, idx); return (T)(_ret); } return begin; } size_t getSteps() { if (type == RangeType::ARRAY) { return array.size(); } else if (type == RangeType::LINEAR) { return ((end - begin) / step) + 1; } else if (type == RangeType::EXPONENTIAL) { ssize_t coe = ((ssize_t)end) / ((ssize_t)begin); double steps_double = (double)std::log(coe) / std::log(step); return steps_double + 1; } return 0; } private: enum class RangeType { NONE, ARRAY, LINEAR, EXPONENTIAL }; RangeType type; std::vector array; T begin; T end; T step; }; struct TestOptions { TestOptions() : printTestMessage(false) , abortOnFailure(false) , preserveTestFiles(false) {} bool printTestMessage; bool abortOnFailure; bool preserveTestFiles; }; class TestSuite { friend TestArgsBase; private: static std::mutex& getResMsgLock() { static std::mutex res_msg_lock; return res_msg_lock; } static std::string& getResMsg() { static std::string res_msg; return res_msg; } static std::string& getInfoMsg() { thread_local std::string info_msg; return info_msg; } static std::string& getTestName() { static std::string test_name; return 
test_name; } static TestSuite*& getCurTest() { static TestSuite* cur_test; return cur_test; } public: static bool& globalMsgFlag() { static bool global_msg_flag = false; return global_msg_flag; } static std::string getCurrentTestName() { return getTestName(); } static bool isMsgAllowed() { TestSuite* cur_test = TestSuite::getCurTest(); if ( cur_test && (cur_test->options.printTestMessage || cur_test->displayMsg) && !cur_test->suppressMsg ) { return true; } if (globalMsgFlag()) return true; return false; } static void setInfo(const char* format, ...) { thread_local char info_buf[4096]; size_t len = 0; va_list args; va_start(args, format); len += vsnprintf(info_buf + len, 4096 - len, format, args); va_end(args); getInfoMsg() = info_buf; } static void clearInfo() { getInfoMsg().clear(); } static void failHandler() { if (!getInfoMsg().empty()) { std::cout << " info: " << getInfoMsg() << std::endl; } } static void usage(int argc, char** argv) { printf("\n"); printf("Usage: %s [-f ] [-r ] [-p]\n", argv[0]); printf("\n"); printf(" -f, --filter\n"); printf(" Run specific tests matching the given keyword.\n"); printf(" -r, --range\n"); printf(" Run TestRange-based tests using given parameter value.\n"); printf(" -p, --preserve\n"); printf(" Do not clean up test files.\n"); printf(" --abort-on-failure\n"); printf(" Immediately abort the test if failure happens.\n"); printf(" --suppress-msg\n"); printf(" Suppress test messages.\n"); printf(" --display-msg\n"); printf(" Display test messages.\n"); printf("\n"); } static std::string usToString(uint64_t us) { std::stringstream ss; if (us < 1000) { // us ss << std::fixed << std::setprecision(0) << us << " us"; } else if (us < 1000000) { // ms double tmp = static_cast(us / 1000.0); ss << std::fixed << std::setprecision(1) << tmp << " ms"; } else if (us < (uint64_t)600 * 1000000) { // second: 1 s -- 600 s (10 mins) double tmp = static_cast(us / 1000000.0); ss << std::fixed << std::setprecision(1) << tmp << " s"; } else { // minute 
double tmp = static_cast(us / 60.0 / 1000000.0); ss << std::fixed << std::setprecision(0) << tmp << " m"; } return ss.str(); } static std::string countToString(uint64_t count) { std::stringstream ss; if (count < 1000) { ss << count; } else if (count < 1000000) { // K double tmp = static_cast(count / 1000.0); ss << std::fixed << std::setprecision(1) << tmp << "K"; } else if (count < (uint64_t)1000000000) { // M double tmp = static_cast(count / 1000000.0); ss << std::fixed << std::setprecision(1) << tmp << "M"; } else { // B double tmp = static_cast(count / 1000000000.0); ss << std::fixed << std::setprecision(1) << tmp << "B"; } return ss.str(); } static std::string sizeToString(uint64_t size) { std::stringstream ss; if (size < 1024) { ss << size << " B"; } else if (size < 1024*1024) { // K double tmp = static_cast(size / 1024.0); ss << std::fixed << std::setprecision(1) << tmp << " KiB"; } else if (size < (uint64_t)1024*1024*1024) { // M double tmp = static_cast(size / 1024.0 / 1024.0); ss << std::fixed << std::setprecision(1) << tmp << " MiB"; } else { // B double tmp = static_cast(size / 1024.0 / 1024.0 / 1024.0); ss << std::fixed << std::setprecision(1) << tmp << " GiB"; } return ss.str(); } private: struct TimeInfo { TimeInfo(std::tm* src) : year(src->tm_year + 1900) , month(src->tm_mon + 1) , day(src->tm_mday) , hour(src->tm_hour) , min(src->tm_min) , sec(src->tm_sec) , msec(0) , usec(0) {} TimeInfo(std::chrono::system_clock::time_point now) { std::time_t raw_time = std::chrono::system_clock::to_time_t(now); std::tm* lt_tm = std::localtime(&raw_time); year = lt_tm->tm_year + 1900; month = lt_tm->tm_mon + 1; day = lt_tm->tm_mday; hour = lt_tm->tm_hour; min = lt_tm->tm_min; sec = lt_tm->tm_sec; size_t us_epoch = std::chrono::duration_cast < std::chrono::microseconds > ( now.time_since_epoch() ).count(); msec = (us_epoch / 1000) % 1000; usec = us_epoch % 1000; } int year; int month; int day; int hour; int min; int sec; int msec; int usec; }; public: TestSuite(int 
argc = 0, char **argv = nullptr) : cntPass(0) , cntFail(0) , useGivenRange(false) , preserveTestFiles(false) , forceAbortOnFailure(false) , suppressMsg(false) , displayMsg(false) , givenRange(0) , startTimeGlobal(std::chrono::system_clock::now()) { for (int ii=1; ii cur_time = std::chrono::system_clock::now();; std::chrono::duration elapsed = cur_time - startTimeGlobal; std::string time_str = usToString(elapsed.count() * 1000000); printf(_CL_GREEN("%zu") " tests passed", cntPass); if (cntFail) { printf(", " _CL_RED("%zu") " tests failed", cntFail); } printf(" out of " _CL_CYAN("%zu") " (" _CL_BROWN("%s") ")\n", cntPass+cntFail, time_str.c_str()); } // === Helper functions ==================================== static std::string getTestFileName(const std::string& prefix) { TimeInfo lt(std::chrono::system_clock::now()); (void)lt; char time_char[64]; sprintf(time_char, "%04d%02d%02d_%02d%02d%02d", lt.year, lt.month, lt.day, lt.hour, lt.min, lt.sec); std::string ret = prefix; ret += "_"; ret += time_char; return ret; } static std::string getTimeString() { TimeInfo lt(std::chrono::system_clock::now()); char time_char[64]; sprintf(time_char, "%04d-%02d-%02d %02d:%02d:%02d.%03d%03d", lt.year, lt.month, lt.day, lt.hour, lt.min, lt.sec, lt.msec, lt.usec); return time_char; } static std::string getTimeStringShort() { TimeInfo lt(std::chrono::system_clock::now()); char time_char[64]; sprintf(time_char, "%02d:%02d.%03d %03d", lt.min, lt.sec, lt.msec, lt.usec); return time_char; } static std::string getTimeStringPlain() { TimeInfo lt(std::chrono::system_clock::now()); char time_char[64]; sprintf(time_char, "%02d%02d_%02d%02d%02d", lt.month, lt.day, lt.hour, lt.min, lt.sec); return time_char; } static int mkdir(const std::string& path) { struct stat st; if (stat(path.c_str(), &st) != 0) { return ::mkdir(path.c_str(), 0755); } return 0; } static int copyfile(const std::string& src, const std::string& dst) { std::string cmd = "cp -R " + src + " " + dst; int rc = 
::system(cmd.c_str()); return rc; } static int remove(const std::string& path) { int rc = ::remove(path.c_str()); return rc; } enum TestPosition { BEGINNING_OF_TEST = 0, MIDDLE_OF_TEST = 1, END_OF_TEST = 2, }; static void clearTestFile( const std::string& prefix, TestPosition test_pos = MIDDLE_OF_TEST ) { TestSuite*& cur_test = TestSuite::getCurTest(); if ( test_pos == END_OF_TEST && ( cur_test->preserveTestFiles || cur_test->options.preserveTestFiles ) ) return; int r; std::string command = "rm -rf "; command += prefix; command += "*"; r = system(command.c_str()); (void)r; } static void setResultMessage(const std::string& msg) { TestSuite::getResMsg() = msg; } static void appendResultMessage(const std::string& msg) { std::lock_guard l(TestSuite::getResMsgLock()); TestSuite::getResMsg() += msg; } static size_t _msg(const char* format, ...) { size_t cur_len = 0; TestSuite* cur_test = TestSuite::getCurTest(); if ( ( cur_test && (cur_test->options.printTestMessage || cur_test->displayMsg) && !cur_test->suppressMsg ) || globalMsgFlag() ) { va_list args; va_start(args, format); cur_len += vprintf(format, args); va_end(args); } return cur_len; } static size_t _msgt(const char* format, ...) 
{ size_t cur_len = 0; TestSuite* cur_test = TestSuite::getCurTest(); if ( ( cur_test && (cur_test->options.printTestMessage || cur_test->displayMsg) && !cur_test->suppressMsg ) || globalMsgFlag() ) { std::cout << _CLM_D_GRAY << getTimeStringShort() << _CLM_END << "] "; va_list args; va_start(args, format); cur_len += vprintf(format, args); va_end(args); } return cur_len; } class Msg { public: Msg() {} template inline Msg& operator<<(const T& data) { if (TestSuite::isMsgAllowed()) { std::cout << data; } return *this; } using MyCout = std::basic_ostream< char, std::char_traits >; typedef MyCout& (*EndlFunc)(MyCout&); Msg& operator<<(EndlFunc func) { if (TestSuite::isMsgAllowed()) { func(std::cout); } return *this; } }; static void sleep_us(size_t us, const std::string& msg = std::string()) { if (!msg.empty()) TestSuite::_msg("%s (%zu us)\n", msg.c_str(), us); std::this_thread::sleep_for(std::chrono::microseconds(us)); } static void sleep_ms(size_t ms, const std::string& msg = std::string()) { if (!msg.empty()) TestSuite::_msg("%s (%zu ms)\n", msg.c_str(), ms); std::this_thread::sleep_for(std::chrono::milliseconds(ms)); } static void sleep_sec(size_t sec, const std::string& msg = std::string()) { if (!msg.empty()) TestSuite::_msg("%s (%zu s)\n", msg.c_str(), sec); std::this_thread::sleep_for(std::chrono::seconds(sec)); } static std::string lzStr(size_t digit, uint64_t num) { std::stringstream ss; ss << std::setw(digit) << std::setfill('0') << std::to_string(num); return ss.str(); } static double calcThroughput(uint64_t ops, uint64_t elapsed_us) { return ops * 1000000.0 / elapsed_us; } static std::string throughputStr(uint64_t ops, uint64_t elapsed_us) { return countToString(ops * 1000000.0 / elapsed_us); } static std::string sizeThroughputStr(uint64_t size_byte, uint64_t elapsed_us) { return sizeToString(size_byte * 1000000.0 / elapsed_us); } // === Timer things ==================================== class Timer { public: Timer() : duration_ms(0) { reset(); } 
Timer(size_t _duration_ms) : duration_ms(_duration_ms) { reset(); } inline bool timeout() { return timeover(); } bool timeover() { auto cur = std::chrono::system_clock::now(); std::chrono::duration elapsed = cur - start; if (duration_ms < elapsed.count() * 1000) return true; return false; } uint64_t getTimeSec() { auto cur = std::chrono::system_clock::now(); std::chrono::duration elapsed = cur - start; return (uint64_t)(elapsed.count()); } uint64_t getTimeMs() { auto cur = std::chrono::system_clock::now(); std::chrono::duration elapsed = cur - start; return (uint64_t)(elapsed.count() * 1000); } uint64_t getTimeUs() { auto cur = std::chrono::system_clock::now(); std::chrono::duration elapsed = cur - start; return (uint64_t)(elapsed.count() * 1000000); } void reset() { start = std::chrono::system_clock::now(); } void resetSec(size_t _duration_sec) { duration_ms = _duration_sec * 1000; reset(); } void resetMs(size_t _duration_ms) { duration_ms = _duration_ms; reset(); } private: std::chrono::time_point start; size_t duration_ms; }; // === Workload generator things ==================================== class WorkloadGenerator { public: WorkloadGenerator(double ops_per_sec = 0.0, uint64_t max_ops_per_batch = 0) : opsPerSec(ops_per_sec) , maxOpsPerBatch(max_ops_per_batch) , numOpsDone(0) { reset(); } void reset() { start = std::chrono::system_clock::now(); numOpsDone = 0; } size_t getNumOpsToDo() { if (opsPerSec <= 0) return 0; auto cur = std::chrono::system_clock::now(); std::chrono::duration elapsed = cur - start; double exp = opsPerSec * elapsed.count(); if (numOpsDone < exp) { if (maxOpsPerBatch) { return std::min(maxOpsPerBatch, (uint64_t)exp - numOpsDone); } return (uint64_t)exp - numOpsDone; } return 0; } void addNumOpsDone(size_t num) { numOpsDone += num; } private: std::chrono::time_point start; double opsPerSec; uint64_t maxOpsPerBatch; uint64_t numOpsDone; }; // === Progress things ================================== // Progress that knows the maximum value. 
class Progress { public: Progress(uint64_t _num, const std::string& _comment = std::string(), const std::string& _unit = std::string()) : curValue(0) , num(_num) , timer(0) , lastPrintTimeUs(timer.getTimeUs()) , comment(_comment) , unit(_unit) {} void update(uint64_t cur) { curValue = cur; uint64_t curTimeUs = timer.getTimeUs(); if (curTimeUs - lastPrintTimeUs > 50000 || cur == 0 || curValue >= num) { // Print every 0.05 sec (20 Hz). lastPrintTimeUs = curTimeUs; std::string _comment = (comment.empty()) ? "" : comment + ": "; std::string _unit = (unit.empty()) ? "" : unit + " "; _msg("\r%s%ld/%ld %s(%.1f%%)", _comment.c_str(), curValue, num, _unit.c_str(), (double)curValue*100/num); fflush(stdout); } if (curValue >= num) { _msg("\n"); fflush(stdout); } } void done() { if (curValue < num) update(num); } private: uint64_t curValue; uint64_t num; Timer timer; uint64_t lastPrintTimeUs; std::string comment; std::string unit; }; // Progress that doesn't know the maximum value. class UnknownProgress { public: UnknownProgress(const std::string& _comment = std::string(), const std::string& _unit = std::string()) : curValue(0) , timer(0) , lastPrintTimeUs(timer.getTimeUs()) , comment(_comment) , unit(_unit) {} void update(uint64_t cur) { curValue = cur; uint64_t curTimeUs = timer.getTimeUs(); if ( curTimeUs - lastPrintTimeUs > 50000 || cur == 0 ) { // Print every 0.05 sec (20 Hz). lastPrintTimeUs = curTimeUs; std::string _comment = (comment.empty()) ? "" : comment + ": "; std::string _unit = (unit.empty()) ? 
"" : unit + " "; _msg("\r%s%ld %s", _comment.c_str(), curValue, _unit.c_str()); fflush(stdout); } } void done() { _msg("\n"); fflush(stdout); } private: uint64_t curValue; Timer timer; uint64_t lastPrintTimeUs; std::string comment; std::string unit; }; // === Displayer things ================================== class Displayer { public: Displayer(size_t num_raws, size_t num_cols) : numRaws(num_raws) , numCols(num_cols) , colWidth(num_cols, 20) , context(num_raws, std::vector(num_cols)) {} void init() { for (size_t ii=0; ii& src) { size_t num_src = src.size(); if (!num_src) return; for (size_t ii=0; ii= numRaws || col_idx >= numCols) return; thread_local char info_buf[32]; size_t len = 0; va_list args; va_start(args, format); len += vsnprintf(info_buf + len, 20 - len, format, args); va_end(args); context[raw_idx][col_idx] = info_buf; } void print() { _msg("\033[%zuA", numRaws); for (size_t ii=0; ii colWidth; std::vector< std::vector< std::string > > context; }; // === Gc things ==================================== template class GcVar { public: GcVar(T& _src, T2 _to) : src(_src), to(_to) {} ~GcVar() { // GC by value. src = to; } private: T& src; T2 to; }; class GcFunc { public: GcFunc(std::function _func) : func(_func) {} ~GcFunc() { // GC by function. func(); } private: std::function func; }; // === Thread things ==================================== struct ThreadArgs { /* Opaque. 
*/ }; using ThreadFunc = std::function< int(ThreadArgs*) >; using ThreadExitHandler = std::function< void(ThreadArgs*) >; private: struct ThreadInternalArgs { ThreadInternalArgs() : userArgs(nullptr), func(nullptr), rc(0) {} ThreadArgs* userArgs; ThreadFunc func; int rc; }; public: struct ThreadHolder { ThreadHolder() : tid(nullptr), handler(nullptr) {} ThreadHolder(std::thread* _tid, ThreadExitHandler _handler) : tid(_tid), handler(_handler) {} ThreadHolder(ThreadArgs* u_args, ThreadFunc t_func, ThreadExitHandler t_handler) : tid(nullptr), handler(nullptr) { spawn(u_args, t_func, t_handler); } ~ThreadHolder() { join(true); } void spawn(ThreadArgs* u_args, ThreadFunc t_func, ThreadExitHandler t_handler) { if (tid) return; handler = t_handler; args.userArgs = u_args; args.func = t_func; tid = new std::thread(spawnThread, &args); } void join(bool force = false) { if (!tid) return; if (tid->joinable()) { if (force) { // Force kill. handler(args.userArgs); } tid->join(); } delete tid; tid = nullptr; } int getResult() const { return args.rc; } std::thread* tid; ThreadExitHandler handler; ThreadInternalArgs args; }; // === doTest things ==================================== // 1) Without parameter. void doTest( const std::string& test_name, test_func func ) { if (!matchFilter(test_name)) return; readyTest(test_name); TestSuite::getResMsg() = ""; TestSuite::getInfoMsg() = ""; TestSuite::getCurTest() = this; int ret = func(); reportTestResult(test_name, ret); } // 2) Ranged parameter. template void doTest( std::string test_name, F func, TestRange range ) { if (!matchFilter(test_name)) return; size_t n = (useGivenRange) ? 1 : range.getSteps(); size_t i; for (i=0; i void doTest( const std::string& test_name, F func, T1 arg1, T2... 
args ) { if (!matchFilter(test_name)) return; readyTest(test_name); TestSuite::getResMsg() = ""; TestSuite::getInfoMsg() = ""; TestSuite::getCurTest() = this; int ret = func(arg1, args...); reportTestResult(test_name, ret); } // 4) Multi composite parameters. template void doTest( const std::string& test_name, F func, TestArgsWrapper& args_wrapper ) { if (!matchFilter(test_name)) return; TestArgsBase* args = args_wrapper.getArgs(); args->setCallback(test_name, func, this); args->testAll(); } TestOptions options; private: void doTestCB( const std::string& test_name, test_func_args func, TestArgsBase* args ) { readyTest(test_name); TestSuite::getResMsg() = ""; TestSuite::getInfoMsg() = ""; TestSuite::getCurTest() = this; int ret = func(args); reportTestResult(test_name, ret); } static void spawnThread(ThreadInternalArgs* args) { args->rc = args->func(args->userArgs); } bool matchFilter(const std::string& test_name) { if (!filter.empty() && test_name.find(filter) == std::string::npos) { // Doesn't match with the given filter. return false; } return true; } void readyTest(const std::string& test_name) { printf("[ " "...." " ] %s\n", test_name.c_str()); if ( (options.printTestMessage || displayMsg) && !suppressMsg ) { printf(_CL_D_GRAY(" === TEST MESSAGE (BEGIN) ===\n")); } fflush(stdout); getTestName() = test_name; startTimeLocal = std::chrono::system_clock::now(); } void reportTestResult(const std::string& test_name, int result) { std::chrono::time_point cur_time = std::chrono::system_clock::now();; std::chrono::duration elapsed = cur_time - startTimeLocal; std::string time_str = usToString(elapsed.count() * 1000000); char msg_buf[1024]; std::string res_msg = TestSuite::getResMsg(); sprintf(msg_buf, "%s (" _CL_BROWN("%s") ")%s%s", test_name.c_str(), time_str.c_str(), (res_msg.empty() ? 
"" : ": "), res_msg.c_str() ); if (result < 0) { printf("[ " _CL_RED("FAIL") " ] %s\n", msg_buf); cntFail++; } else { if ( (options.printTestMessage || displayMsg) && !suppressMsg ) { printf(_CL_D_GRAY(" === TEST MESSAGE (END) ===\n")); } else { // Move a line up. printf("\033[1A"); // Clear current line. printf("\r"); // And then overwrite. } printf("[ " _CL_GREEN("PASS") " ] %s\n", msg_buf); cntPass++; } if ( result != 0 && (options.abortOnFailure || forceAbortOnFailure) ) { abort(); } getTestName().clear(); } size_t cntPass; size_t cntFail; std::string filter; bool useGivenRange; bool preserveTestFiles; bool forceAbortOnFailure; bool suppressMsg; bool displayMsg; int64_t givenRange; // Start time of each test. std::chrono::time_point startTimeLocal; // Start time of the entire test suite. std::chrono::time_point startTimeGlobal; }; // ===== Functor ===== struct TestArgsSetParamFunctor { template void operator()(T* t, TestRange& r, size_t param_idx) const { *t = r.getEntry(param_idx); } }; template inline typename std::enable_if::type TestArgsSetParamScan(int, std::tuple &, std::tuple...> &, FuncT, size_t) { } template inline typename std::enable_if::type TestArgsSetParamScan(int index, std::tuple& t, std::tuple...>& r, FuncT f, size_t param_idx) { if (index == 0) f(std::get(t), std::get(r), param_idx); TestArgsSetParamScan(index-1, t, r, f, param_idx); } struct TestArgsGetNumStepsFunctor { template void operator()(T* t, TestRange& r, size_t& steps_ret) const { (void)t; steps_ret = r.getSteps(); } }; template inline typename std::enable_if::type TestArgsGetStepsScan(int, std::tuple &, std::tuple...> &, FuncT, size_t) { } template inline typename std::enable_if::type TestArgsGetStepsScan(int index, std::tuple& t, std::tuple...>& r, FuncT f, size_t& steps_ret) { if (index == 0) f(std::get(t), std::get(r), steps_ret); TestArgsGetStepsScan(index-1, t, r, f, steps_ret); } #define TEST_ARGS_CONTENTS() \ void setParam(size_t param_no, size_t param_idx) { \ 
TestArgsSetParamScan(param_no, args, ranges, \ TestArgsSetParamFunctor(), \ param_idx); } \ size_t getNumSteps(size_t param_no) { \ size_t ret = 0; \ TestArgsGetStepsScan(param_no, args, ranges, \ TestArgsGetNumStepsFunctor(), \ ret); \ return ret; } \ size_t getNumParams() { \ return std::tuple_size::value; \ } // ===== TestArgsBase ===== void TestArgsBase::testAllInternal(size_t depth) { size_t i; size_t n_params = getNumParams(); size_t n_steps = getNumSteps(depth); for (i=0; idoTestCB(test_name, testFunction, this); } } } // ===== Parameter macros ===== #define DEFINE_PARAMS_2(name, \ type1, param1, range1, \ type2, param2, range2) \ class name ## _class : public TestArgsBase { \ public: \ name ## _class() { \ args = std::make_tuple(¶m1, ¶m2); \ ranges = std::make_tuple( \ TestRangerange1, \ TestRangerange2 ); \ } \ std::string toString() { \ std::stringstream ss; \ ss << param1 << ", " << param2; \ return ss.str(); \ } \ TEST_ARGS_CONTENTS() \ type1 param1; \ type2 param2; \ private: \ std::tuple args; \ std::tuple, TestRange> ranges; \ }; #define DEFINE_PARAMS_3(name, \ type1, param1, range1, \ type2, param2, range2, \ type3, param3, range3) \ class name ## _class : public TestArgsBase { \ public: \ name ## _class() { \ args = std::make_tuple(¶m1, ¶m2, ¶m3); \ ranges = std::make_tuple( \ TestRangerange1, \ TestRangerange2, \ TestRangerange3 ); \ } \ std::string toString() { \ std::stringstream ss; \ ss << param1 << ", " << param2 << ", " << param3; \ return ss.str(); \ } \ TEST_ARGS_CONTENTS() \ type1 param1; \ type2 param2; \ type3 param3; \ private: \ std::tuple args; \ std::tuple, \ TestRange, \ TestRange> ranges; \ }; #define SET_PARAMS(name) \ TestArgsWrapper name(new name ## _class()) #define GET_PARAMS(name) \ name ## _class* name = static_cast(TEST_args_base__) #define PARAM_BASE TestArgsBase* TEST_args_base__ #define TEST_SUITE_AUTO_PREFIX __func__ #define TEST_SUITE_PREPARE_PATH(path) \ const std::string _ts_auto_prefiix_ = TEST_SUITE_AUTO_PREFIX; \ 
TestSuite::clearTestFile(_ts_auto_prefiix_); \ path = TestSuite::getTestFileName(_ts_auto_prefiix_); #define TEST_SUITE_CLEANUP_PATH() \ TestSuite::clearTestFile( _ts_auto_prefiix_, \ TestSuite::END_OF_TEST );