Repository: ShisoftResearch/bifrost Branch: develop Commit: b826201b0482 Files: 48 Total size: 628.3 KB Directory structure: gitextract_g24w9kwe/ ├── .gitignore ├── .travis.yml ├── Cargo.toml ├── LICENSE ├── MEMBERSHIP_GUIDE.md ├── README.md ├── RECOVERY_IMPROVEMENTS.md ├── SNAPSHOT_GUIDE.md ├── TYPE2_LAZY_LOADING_CONTRACT.md ├── examples/ │ └── graceful_shutdown.rs ├── src/ │ ├── conshash/ │ │ ├── mod.rs │ │ └── weights.rs │ ├── hasher/ │ │ ├── Cargo.toml │ │ └── src/ │ │ └── lib.rs │ ├── lib.rs │ ├── membership/ │ │ ├── client.rs │ │ ├── member.rs │ │ ├── mod.rs │ │ └── server.rs │ ├── plugins/ │ │ ├── Cargo.toml │ │ └── src/ │ │ └── lib.rs │ ├── proc_macro/ │ │ ├── Cargo.toml │ │ └── src/ │ │ └── lib.rs │ ├── raft/ │ │ ├── client.rs │ │ ├── disk.rs │ │ ├── mod.rs │ │ └── state_machine/ │ │ ├── callback/ │ │ │ ├── client.rs │ │ │ ├── mod.rs │ │ │ └── server.rs │ │ ├── configs.rs │ │ ├── macros.rs │ │ ├── master.rs │ │ └── mod.rs │ ├── rpc/ │ │ ├── cluster.rs │ │ ├── mod.rs │ │ └── proto.rs │ ├── tcp/ │ │ ├── client.rs │ │ ├── mod.rs │ │ ├── server.rs │ │ └── shortcut.rs │ ├── utils/ │ │ ├── bindings.rs │ │ ├── math.rs │ │ ├── mod.rs │ │ ├── serde.rs │ │ └── time.rs │ └── vector_clock/ │ └── mod.rs └── tests/ ├── graceful_shutdown_tests.rs └── single_node_recovery_test.rs ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ target Cargo.lock .idea/ bifrost.iml ================================================ FILE: .travis.yml ================================================ language: rust rust: - nightly ================================================ FILE: Cargo.toml ================================================ [package] name = "bifrost" version = "0.1.0" authors = ["Hao Shi "] edition = "2018" [lib] name = "bifrost" [dependencies] serde_cbor = "0.11.1" serde_json = "1.0.51" byteorder = 
"1" log = "*" serde = { version = "1.0", features = ["derive"] } bifrost_plugins = { path = "src/plugins" } bifrost_hasher = { path = "src/hasher" } bifrost_proc_macro = { path = "src/proc_macro" } rand = "*" lazy_static = "*" threadpool = "1" num_cpus = "1" parking_lot = {version = "*", features = ["nightly"]} thread-id = "5" tokio = { version = "1", features = ["full"] } tokio-util = {version = "0.7", features = ["full"]} tokio-stream = "0.1" bytes = "1" crc32fast = "*" futures = {version = "0.3", features = ["executor", "thread-pool"] } futures-timer = "3" async-std = "1" lightning-containers = { git = "ssh://git@192.168.10.134/shisoft-x/Lightning.git", branch = "develop" } [dev-dependencies] env_logger = "*" ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2020 Hao Shi Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: MEMBERSHIP_GUIDE.md ================================================ # Membership Guide This document explains how membership works in Bifrost and the difference between **Raft Cluster Membership** and the **Membership Service**. ## Two Types of Membership Bifrost has two distinct membership systems that serve different purposes: ### 1. Raft Cluster Membership (PERSISTED ✅) **Location**: `src/raft/state_machine/configs.rs` **Purpose**: Tracks which servers are part of the Raft consensus cluster **Persistence**: **YES** - Fully persisted to disk via: - Write-Ahead Log (WAL) - Snapshots **Members**: Raft servers that participate in consensus (leader election, log replication) **Operations**: - `new_member_(address)` - Add a Raft server to the cluster - `del_member_(address)` - Remove a Raft server from the cluster - `member_address()` - Query all Raft cluster members **Recovery**: On restart, Raft cluster membership is recovered from: 1. Latest snapshot on disk 2. WAL log replay **Why Persisted?**: Critical for Raft consensus. The cluster must know its membership to: - Calculate quorum (majority) - Elect leaders - Replicate logs correctly **Code Example**: ```rust // These members are persisted and recovered on restart service.join(&vec!["node1:5000".to_string()]).await; ``` ### 2. 
Membership Service (NOT PERSISTED ❌) **Location**: `src/membership/server.rs` **Purpose**: Tracks member groups, heartbeat status, and online/offline state **Persistence**: **NO** - Intentionally ephemeral **Members**: Applications or clients using the membership service for: - Group membership - Leader election within groups - Liveness tracking - Membership change notifications **Operations**: - `join(address)` - Join as a member - `leave(id)` - Leave the membership service - `join_group(group_name, id)` - Join a group - `leave_group(group, id)` - Leave a group - `ping(id)` - Send heartbeat **Recovery**: On restart, starts with **empty state** and rebuilds through: 1. Members calling `join()` again 2. Heartbeat `ping()` messages 3. Group operations **Why NOT Persisted?**: - Membership should reflect **current network reality** - Stale disk state would be misleading after crashes - Members must actively rejoin to prove they're alive - Groups are transient application-level constructs **Code Example**: ```rust // After restart, this state is gone - members must rejoin let client = MemberClient::new(...).await; client.join().await; // Must be called again after restart client.join_group("workers".to_string()).await; ``` ## Comparison Table | Feature | Raft Cluster Membership | Membership Service | |---------|------------------------|-------------------| | **Persisted** | ✅ Yes (WAL + Snapshot) | ❌ No (Always fresh) | | **Purpose** | Raft consensus | Application groups/heartbeats | | **Scope** | Cluster-wide | Per-service | | **Recovery** | From disk | From network rediscovery | | **State Machine ID** | `CONFIG_SM_ID` (1) | `DEFAULT_SERVICE_ID` | | **Critical for Raft** | ✅ Yes | ❌ No | | **Survives Restart** | ✅ Yes | ❌ No | ## How They Work Together ``` ┌─────────────────────────────────────────────────────────┐ │ Bifrost Cluster │ ├─────────────────────────────────────────────────────────┤ │ │ │ Raft Cluster Membership (Persisted) │ │ ┌─────────┐ ┌─────────┐ 
┌─────────┐ │ │ │ Server1 │ │ Server2 │ │ Server3 │ │ │ │ :5000 │ │ :5001 │ │ :5002 │ │ │ └────┬────┘ └────┬────┘ └────┬────┘ │ │ │ │ │ │ │ ├────────────┴────────────┤ │ │ │ Raft Consensus │ │ │ │ (Leader Election, │ │ │ │ Log Replication) │ │ │ └─────────────────────────┘ │ │ │ │ Membership Service (NOT Persisted - Fresh on restart) │ │ ┌─────────────────────────────────────────────┐ │ │ │ Members: { │ │ │ │ App1 -> online, groups: [workers] │ │ │ │ App2 -> online, groups: [workers] │ │ │ │ App3 -> offline, groups: [storage] │ │ │ │ } │ │ │ │ Groups: { │ │ │ │ workers -> leader: App1, members: [1,2] │ │ │ │ storage -> leader: None, members: [3] │ │ │ │ } │ │ │ └─────────────────────────────────────────────┘ │ │ ↑ This is cleared on restart │ └─────────────────────────────────────────────────────────┘ ``` ## Startup Sequence After Restart ### Raft Cluster Membership (Automatic) ```rust // Server restarts let raft_service = RaftService::new(options); RaftService::start(&raft_service).await; // ✅ Cluster membership automatically recovered from disk // ✅ Knows about Server1, Server2, Server3 // ✅ Can participate in consensus immediately ``` ### Membership Service (Manual Rejoin Required) ```rust // Server restarts let membership_client = MemberClient::new(...).await; // ❌ Membership service starts EMPTY // ❌ Previous groups/members are forgotten // ✅ Applications must rejoin membership_client.join().await; // Rejoin as member membership_client.join_group("workers").await; // Rejoin group membership_client.start_heartbeat(); // Start sending pings // ✅ Membership service rebuilds state from these actions ``` ## Why This Design? 
### Raft Cluster Membership: Persisted - **Safety**: Raft consensus requires consistent membership for quorum - **Correctness**: Must survive crashes to maintain cluster integrity - **Availability**: Cluster can restart without manual intervention ### Membership Service: NOT Persisted - **Freshness**: Ensures membership reflects current reality - **Simplicity**: No stale data to reconcile - **Self-Healing**: Dead members naturally disappear (no heartbeat = offline) - **Flexibility**: Applications control their own membership lifecycle ## Code Examples ### Example 1: Raft Member Survives Restart ```rust // Initial setup let raft = RaftService::new(Options { storage: Storage::DISK(disk_opts), address: "node1:5000".to_string(), service_id: DEFAULT_SERVICE_ID, }); raft.join(&vec!["node2:5000".to_string()]).await; // --- CRASH AND RESTART --- // After restart let raft = RaftService::new(Options { storage: Storage::DISK(disk_opts), // Same disk path address: "node1:5000".to_string(), service_id: DEFAULT_SERVICE_ID, }); RaftService::start(&raft).await; // ✅ Still knows about node2:5000 (recovered from disk) // ✅ Can participate in cluster immediately ``` ### Example 2: Membership Service Starts Fresh ```rust // Initial setup let member = MemberClient::new(...).await; member.join().await; member.join_group("workers").await; // --- CRASH AND RESTART --- // After restart let member = MemberClient::new(...).await; // ❌ Not in any groups // ❌ Not registered as a member // Must rejoin explicitly member.join().await; // Required member.join_group("workers").await; // Required member.start_heartbeat(); // Required ``` ## Best Practices ### For Raft Cluster Members 1. Use `Storage::DISK` for production deployments 2. Membership changes are committed via Raft consensus 3. No need to rejoin after restart ### For Membership Service Users 1. 
**Always rejoin after restart**: ```rust async fn on_startup() { member_client.join().await; for group in my_groups { member_client.join_group(group).await; } member_client.start_heartbeat(); } ``` 2. **Handle disconnections gracefully** - May need to rejoin 3. **Monitor membership changes** via subscriptions: ```rust client.on_any_member_joined(|member, version| { println!("New member joined: {:?}", member); }).await; ``` ## Summary - **Raft Cluster Membership**: Persisted for consensus correctness ✅ - **Membership Service**: NOT persisted, learns from network ❌ - Both serve different purposes and have different persistence requirements - This design ensures both safety (for Raft) and freshness (for membership) ================================================ FILE: README.md ================================================ # bifrost [![Build Status](https://travis-ci.org/ShisoftResearch/bifrost.svg?branch=master)](https://travis-ci.org/ShisoftResearch/bifrost) Pure Rust building block for distributed systems ### Objective The objective of bifrost is to build a solid foundation for distributed systems in Rust. It is similar to one of my Clojure projects, [cluster-connector](https://github.com/shisoft/cluster-connector), but no longer requires any third-party software like ZooKeeper or etcd. Bifrost will ship with its own reliable data store based on [Raft consensus algorithm](https://raft.github.io/) state machines. Users are also able to build their own reliable data structures by implementing state machine commands. 
**Bifrost is still in a very early stage of development and is not recommended for use in any kind of project until it is stabilized and fully tested** ### Progress Check List - [ ] RPC - [x] TCP Server - [x] Protocol - [x] Event driven server - [x] Sync client - [x] Async client - [X] Multiplexing pluggable services - [X] Shortcut (for both TCP and RPC APIs) - [ ] Raft (data replication) - [x] Leader election - [x] Log replication - [x] Master/subs state machine framework - [ ] State machine client - [x] Sync - [x] PubSub - [ ] Master state machine snapshot - [x] Generate - [x] Install - [ ] Generate in chunks - [ ] Install in chunks - [ ] Automation - [ ] Persistent to disk - [ ] Recover from disk - [ ] Incremental snapshot - [ ] Membership changes - [x] State machine - [x] New Member - [x] Delete Member - [x] Snapshot - [x] Recover - [X] Interfaces - [X] Update procedures - [x] Cluster bootstrap - [x] Client - [x] Command - [x] Query - [x] Concurrency - [x] Failover - [x] Membership changes - [x] Subscription - [ ] Raft Group - [ ] Tests - [x] State machine framework - [x] Leader selection - [x] Log replication - [ ] Snapshot - [ ] Membership changes - [x] New member - [x] Delete member - [ ] Safety - [ ] Stress and benchmark - [ ] Stress + Safety - [ ] Sharding - [x] Consistent hash - [ ] Reliable data store - [x] Client group membership - [x] Client group leader election - [x] Map - [ ] Set - [ ] Array - [ ] Queue - [x] Value - [x] Number - [ ] Lock - [ ] Integration (API) - [ ] gRPC - [ ] Utility - [x] [Global bindings](https://clojuredocs.org/clojure.core/binding) - [x] Consistent hashing - [x] Vector clock ================================================ FILE: RECOVERY_IMPROVEMENTS.md ================================================ # Node Recovery and Temporary Failure Handling Improvements ## Overview This document describes the improvements made to handle nodes that temporarily miss heartbeats due to being under load, ensuring they can properly 
recover and rejoin the cluster either as a leader or follower. ## Problem Statement When running under heavy load, nodes may temporarily fail to respond to heartbeats, leading to: 1. **Premature offline marking**: Nodes marked offline after a single timeout 2. **Leadership churn**: Rapid leadership changes causing instability 3. **Flapping**: Nodes bouncing between online/offline states 4. **Panic on errors**: Unwrap() calls causing crashes during transient failures ## Solutions Implemented ### 1. Grace Period with Consecutive Failure Tracking **New Configuration Constants:** ```rust static MAX_TIMEOUT: i64 = 10_000; // 10 seconds before considering potentially offline static OFFLINE_GRACE_CHECKS: u8 = 3; // Require 3 consecutive failures before marking offline static ONLINE_GRACE_CHECKS: u8 = 2; // Require 2 consecutive successes before marking online static MIN_STATE_CHANGE_INTERVAL: i64 = 5_000; // Minimum 5 seconds between state changes ``` **Benefits:** - **Resilience**: Tolerates temporary hiccups (up to 3 timeout checks × 500ms = ~1.5 seconds grace) - **Anti-flapping**: Minimum 5 second interval prevents rapid state oscillation - **Smooth recovery**: Requires 2 consecutive successful heartbeats before marking node back online ### 2. Enhanced HeartbeatStatus Tracking **New HBStatus Fields:** ```rust struct HBStatus { last_updated: i64, online: bool, consecutive_failures: u8, // Count of consecutive timeout checks consecutive_successes: u8, // Count of consecutive successful checks last_state_change: i64, // Timestamp of last state transition } ``` **Behavior:** - **Online → Offline**: Tracks consecutive timeouts, only transitions after reaching threshold AND minimum interval - **Offline → Online**: Tracks consecutive successful heartbeats, transitions after reaching threshold AND minimum interval - **Stable states**: Resets counters when nodes are consistently responsive ### 3. 
Improved Error Handling All `.unwrap()` calls replaced with proper error handling: **Fixed Functions:** - `compose_client_member`: Now returns `Option` instead of panicking - `group_leader_candidate_available`: Logs errors instead of panicking - `group_leader_candidate_unavailable`: Handles all failure cases gracefully - `notify_for_member_*`: Early returns with error logging on failures - Mutex lock failures: Gracefully handled with error logging **Result:** - No more panics during transient failures - Clear error logs for debugging - System continues operating even when individual operations fail ### 4. Leadership Transfer Grace Period When leadership transfers (e.g., during reelection): ```rust async fn transfer_leadership(&self) { // Give all online members fresh timestamps // Reset all failure/success counters // Prevents immediate timeout after leadership change } ``` **Benefits:** - New leader gets time to stabilize before checking heartbeats - Prevents cascading failures during leadership transitions - All members get a "fresh start" under new leadership ## Recovery Scenarios ### Scenario 1: Node Under Temporary Load **Timeline:** 1. Node A is leader and becomes overloaded 2. Misses heartbeat at T+10s (consecutive_failures = 1) 3. Misses heartbeat at T+10.5s (consecutive_failures = 2) 4. Misses heartbeat at T+11s (consecutive_failures = 3) 5. **Now marked offline** (after 3 consecutive failures) 6. Leadership election: Node B becomes leader 7. Node A recovers, starts sending heartbeats again 8. Receives heartbeat at T+15s (consecutive_successes = 1) 9. Receives heartbeat at T+15.5s (consecutive_successes = 2) 10. **Marked back online** (after 2 consecutive successes AND 5s minimum interval) 11. 
Node A becomes follower of Node B **Key Points:** - ~1.5 second tolerance before marking offline (3 × 500ms checks) - Minimum 5 second offline period (anti-flapping protection) - Node A does NOT automatically reclaim leadership (stability) - Node A properly syncs as follower under Node B ### Scenario 2: Brief Network Hiccup **Timeline:** 1. Node experiences single timeout (consecutive_failures = 1) 2. Next heartbeat succeeds (consecutive_failures reset to 0) 3. **Node remains online** - no state change **Key Points:** - Single hiccups don't trigger state changes - Prevents unnecessary leadership elections - Maintains cluster stability ### Scenario 3: Persistent Failure **Timeline:** 1. Node genuinely fails (hardware/crash) 2. Consecutive failures accumulate: 1, 2, 3 3. Marked offline after 3 checks 4. Leadership transfers to healthy node 5. Eventually removed from the cluster if it doesn't recover **Key Points:** - Real failures are still detected quickly (~1.5 seconds of grace checks after the 10-second `MAX_TIMEOUT` has elapsed) - System continues with remaining healthy nodes - No false positives from temporary load ## Monitoring and Observability ### New Log Messages **During failure detection:** ``` DEBUG: Member 12345 timeout check 1/3 (10500ms since last update, 2000ms since last state change) DEBUG: Member 12345 timeout check 2/3 (11000ms since last update, 2500ms since last state change) WARN: Marking member 12345 as offline after 3 consecutive timeout checks (11500ms since last update) ``` **During recovery:** ``` DEBUG: Member 12345 recovery check 1/2 (3000ms since last state change) INFO: Marking member 12345 as back online after 2 consecutive successful checks ``` **Error scenarios:** ``` ERROR: Failed to change leader for group 789 to member 12345 ERROR: Failed to find online member for group 789 after member 12345 became unavailable ERROR: Failed to compose client member 12345 for online notification ``` ## Configuration Tuning You can adjust these constants based on your needs: - **Increase `OFFLINE_GRACE_CHECKS`**: More 
tolerance for slow responses (longer detection time) - **Decrease `OFFLINE_GRACE_CHECKS`**: Faster failure detection (less tolerance) - **Increase `MIN_STATE_CHANGE_INTERVAL`**: More aggressive anti-flapping (longer recovery time) - **Decrease `MIN_STATE_CHANGE_INTERVAL`**: Faster recovery (more risk of flapping) - **Increase `MAX_TIMEOUT`**: More lenient heartbeat requirements - **Decrease `MAX_TIMEOUT`**: Stricter heartbeat requirements ## Testing Recommendations 1. **Load testing**: Verify nodes can recover under realistic load 2. **Network partition**: Test with simulated network splits 3. **Chaos testing**: Randomly kill/restart nodes to test recovery paths 4. **Long-running stability**: Monitor for log growth and state flapping ## Backward Compatibility All changes are backward compatible: - Wire protocol unchanged - State machine behavior unchanged (only timing/resilience improved) - Existing clusters will benefit immediately upon upgrade ## Performance Impact - **Minimal CPU overhead**: Simple counter increments - **Minimal memory overhead**: 10 extra bytes per member before alignment padding (two `u8` counters + one `i64` timestamp) - **Reduced network churn**: Fewer unnecessary state changes = fewer Raft log entries - **Improved stability**: Less leadership churn = better overall performance ## Future Enhancements Potential future improvements: 1. **Configurable parameters**: Make timeouts/thresholds runtime-configurable 2. **Adaptive timeouts**: Adjust based on observed network latency 3. **Priority-based leader election**: Prefer certain nodes as leaders 4. **Health scoring**: Multi-factor health beyond just heartbeats 5. **Metrics export**: Prometheus/OpenTelemetry integration for monitoring ================================================ FILE: SNAPSHOT_GUIDE.md ================================================ # Snapshot, Checkpointing, and Recovery Guide ## Overview Bifrost's Raft implementation now includes production-ready snapshot, checkpointing, and recovery functionality. 
This prevents unbounded memory growth and enables fast recovery after restarts. ## Features ✅ **Automatic Snapshot Creation**: Triggered by configurable log count thresholds ✅ **Persistent Storage**: Atomic writes with CRC32 corruption detection ✅ **Crash Recovery**: Automatically loads snapshots on restart ✅ **Log Compaction**: Removes old logs from memory after snapshots ✅ **Follower Catch-up**: Automatically sends snapshots to lagging nodes ✅ **Corruption Handling**: Graceful fallback when snapshot files are corrupted ## Quick Start ### 1. Basic Setup with Snapshots ```rust use bifrost::raft::{RaftService, Options, Storage, DEFAULT_SERVICE_ID}; use bifrost::raft::disk::DiskOptions; use bifrost::rpc::Server; #[tokio::main] async fn main() { // Create Raft service with disk storage let service = RaftService::new(Options { storage: Storage::DISK(DiskOptions::new("/var/lib/myapp/raft".to_string())), address: "127.0.0.1:5000".to_string(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&"127.0.0.1:5000"); server.register_service(&service).await; Server::listen_and_resume(&server).await; // Start automatically recovers from snapshot if it exists! RaftService::start(&service).await; // Bootstrap or join cluster service.bootstrap().await; // OR: service.join(&vec!["existing-node:5000".to_string()]).await; } ``` ### 2. 
Custom Configuration ```rust use bifrost::raft::disk::DiskOptions; let custom_opts = DiskOptions { path: "/data/raft".to_string(), take_snapshots: true, // Enable snapshots append_logs: true, // Enable log persistence trim_logs: true, // Enable log trimming snapshot_log_threshold: 5000, // Snapshot every 5000 applied logs log_compaction_threshold: 10000, // Compact when > 10000 logs }; let service = RaftService::new(Options { storage: Storage::DISK(custom_opts), address: "127.0.0.1:5000".to_string(), service_id: DEFAULT_SERVICE_ID, }); ``` ## How It Works ### Automatic Snapshot Creation **When**: After the leader applies `snapshot_log_threshold` logs since the last snapshot **What happens**: 1. Leader generates snapshot from all state machines 2. Persists snapshot to disk with CRC32 checksum 3. Updates snapshot metadata (index, term) 4. Compacts old logs from memory (if > compaction_threshold) ```rust // Triggered automatically in try_sync_log_to_followers() // After successfully committing logs to followers if should_take_snapshot() { take_snapshot(); // Generates, persists, and compacts } ``` ### Startup Recovery **When**: Every time `RaftService::start()` is called **What happens**: 1. Checks if snapshot file exists on disk 2. Validates CRC32 checksum 3. Deserializes snapshot data 4. Calls `state_machine.recover(snapshot_data)` 5. Updates indices and metadata 6. Compacts logs already covered by snapshot 7. Continues normal operation ```rust // Automatically called in RaftService::start() load_snapshot_on_startup(); ``` ### Follower Catch-up with Snapshots **When**: A follower needs logs that the leader has already compacted **Scenarios**: - New node joining the cluster - Node was offline during log compaction - Node is too slow and fell far behind **What happens**: 1. Leader detects: `follower.next_index <= leader.last_snapshot_index` 2. Leader generates snapshot from state machines 3. Leader sends via `install_snapshot` RPC 4. Follower receives snapshot 5. 
Follower recovers state machine 6. Follower persists snapshot to disk 7. Follower compacts old logs 8. Follower continues with normal log replication ```rust // In send_follower_heartbeat() if follower.next_index <= last_snapshot_index { // Follower needs compacted logs - send snapshot let snapshot = master_sm.snapshot().unwrap(); rpc.install_snapshot( term, leader_id, last_snapshot_index, last_snapshot_term, snapshot ).await; } ``` ## Implementing Snapshotable State Machines Your state machines must implement `snapshot()` and `recover()`: ```rust use bifrost::raft::state_machine::StateMachineCtl; use futures::future::BoxFuture; use serde::{Serialize, Deserialize}; #[derive(Serialize, Deserialize)] struct MyState { counter: i64, data: HashMap, } struct MyStateMachine { state: MyState, } impl StateMachineCtl for MyStateMachine { fn id(&self) -> u64 { 42 } fn snapshot(&self) -> Option<Vec<u8>> { // Serialize your entire state let data = bincode::serialize(&self.state).ok()?; Some(data) } fn recover(&mut self, data: Vec<u8>) -> BoxFuture<()> { // Deserialize and restore state if !data.is_empty() { if let Ok(state) = bincode::deserialize(&data) { self.state = state; println!("State machine recovered: counter={}", self.state.counter); } } Box::pin(async {}) } // ... command and query handlers ... 
} ``` ## File Layout When using disk storage, the following files are created: ``` /var/lib/myapp/raft/ ├── log.dat # Persisted Raft log entries ├── snapshot.dat # Latest snapshot with CRC32 └── snapshot.dat.tmp # Temporary file during writes (atomic) ``` ### Snapshot File Format ``` [4 bytes] CRC32 checksum [8 bytes] Data length [N bytes] Serialized SnapshotEntity: { last_included_index: u64, last_included_term: u64, snapshot: Vec<u8> // Serialized state machine data } ``` ## Monitoring ### Check Snapshot Status ```rust let meta = service.read_meta().await; println!("Last snapshot: index={}, term={}", meta.last_snapshot_index, meta.last_snapshot_term); let num_logs = service.num_logs().await; println!("Logs in memory: {}", num_logs); ``` ### Manually Trigger Snapshot (Advanced) ```rust // Normally automatic, but you can manually trigger: let mut meta = service.write_meta().await; service.take_snapshot(&mut meta).await; ``` ## Safety Guarantees ### 1. **Crash Safety** - Atomic writes using temp file + rename pattern - If process crashes during snapshot write, old snapshot remains intact ### 2. **Corruption Detection** - CRC32 checksum on all snapshots - Corrupted snapshots are detected and ignored - System falls back to log-based recovery ### 3. **Raft Correctness** - Snapshots track correct term and index - No safety violations from compaction - Follows Raft paper specifications ### 4. 
**Consistency** - Followers always get consistent state via snapshots - State machine recovery is deterministic - All nodes eventually converge to same state ## Configuration Recommendations ### Small Applications (< 1000 ops/sec) ```rust snapshot_log_threshold: 1000, log_compaction_threshold: 2000, ``` ### Medium Applications (1000-10000 ops/sec) ```rust snapshot_log_threshold: 5000, log_compaction_threshold: 10000, ``` ### Large Applications (> 10000 ops/sec) ```rust snapshot_log_threshold: 10000, log_compaction_threshold: 20000, ``` ### Memory-Constrained Systems ```rust snapshot_log_threshold: 500, // Snapshot more frequently log_compaction_threshold: 1000, // Compact aggressively ``` ## Troubleshooting ### Issue: Logs keep growing **Solution**: Check that `take_snapshots: true` and thresholds are set appropriately ### Issue: Snapshot file not created **Solution**: - Verify disk permissions on path - Ensure state machines implement `snapshot()` correctly - Check logs for error messages ### Issue: Follower doesn't catch up **Solution**: - Check network connectivity - Verify `install_snapshot` RPC is working - Check follower logs for error messages ### Issue: Corrupted snapshot detected **Solution**: - Delete corrupted file, server will recover from logs - Investigate disk issues - Check for process crashes during snapshot writes ## Performance Considerations ### Snapshot Creation Cost - **Time**: O(state_size) to serialize state - **Disk I/O**: One sequential write - **Memory**: Temporary copy of state during serialization ### Log Compaction Cost - **Time**: O(logs_to_remove) to filter BTreeMap - **Memory**: Immediate reduction after compaction ### Recovery Cost - **Time**: O(state_size) to deserialize snapshot + O(remaining_logs) - **Disk I/O**: One sequential read ## Testing Run all snapshot tests: ```bash cargo test --lib test_snapshot test_log_compaction test_state_machine_snapshot test_install ``` Individual tests: - `test_snapshot_write_and_read` - I/O 
functionality - `test_snapshot_corruption_detection` - CRC validation - `test_log_compaction_removes_old_logs` - Memory reduction - `test_snapshot_threshold_configuration` - Threshold logic - `test_state_machine_snapshot_and_recovery` - SM serialization - `test_install_snapshot_compacts_logs` - Follower catch-up - `snapshot_disk_persistence` - End-to-end persistence - `snapshot_persistence_and_recovery` - Full recovery cycle ## Example: Multi-Server Deployment ```rust // server1.rs (Leader) let service = RaftService::new(Options { storage: Storage::DISK(DiskOptions::new("/data/node1".to_string())), address: "10.0.1.10:5000".to_string(), service_id: DEFAULT_SERVICE_ID, }); // ... setup ... service.bootstrap().await; // server2.rs (Follower) let service = RaftService::new(Options { storage: Storage::DISK(DiskOptions::new("/data/node2".to_string())), address: "10.0.1.11:5000".to_string(), service_id: DEFAULT_SERVICE_ID, }); // ... setup ... service.join(&vec!["10.0.1.10:5000".to_string()]).await; // server3.rs (New node joining later - will get snapshot automatically!) let service = RaftService::new(Options { storage: Storage::DISK(DiskOptions::new("/data/node3".to_string())), address: "10.0.1.12:5000".to_string(), service_id: DEFAULT_SERVICE_ID, }); // ... setup ... service.join(&vec!["10.0.1.10:5000".to_string()]).await; // ✅ Will automatically receive snapshot if logs are compacted! ``` ## Summary The Raft framework now has **industrial-grade snapshot and recovery capabilities**: - ✅ Automatic snapshot creation based on thresholds - ✅ Crash-safe atomic writes to disk - ✅ Automatic recovery on restart - ✅ Log compaction to prevent memory leaks - ✅ Automatic snapshot transfer to lagging/new nodes - ✅ Corruption detection and handling - ✅ Fully tested with comprehensive test suite New nodes joining the cluster automatically receive snapshots if they're too far behind - no manual intervention needed! 
================================================ FILE: TYPE2_LAZY_LOADING_CONTRACT.md ================================================ # Type-2 Lazy Loading Contract This document defines the contract when Bifrost Type-1 does not own any Type-2 catalog or inventory metadata. ## Scope In this model: - Type-1 is only the raft group for the Type-1 plane. - Type-1 does not store which Type-2 planes should exist. - Type-2 planes are materialized only when a caller explicitly addresses a `plane_id`. - Persisted Type-2 state is recovered locally from each host's per-plane storage directory. ## What Bifrost Guarantees - `RaftService::plane(plane_id)` and `RaftService::ensure_plane(plane_id)` can lazily materialize a persisted or explicitly created Type-2 runtime. - A missing Type-2 plane does not fall back to Type-1 handlers. - `RaftService::loaded_type2_planes()` reports only the Type-2 runtimes currently materialized in memory on the local host. ## What Bifrost Does Not Guarantee - Bifrost cannot enumerate the authoritative set of Type-2 planes for a cluster. - Bifrost cannot derive the logical mapping from database, tenant, shard, or partition to `plane_id`. - Bifrost cannot infer the desired Type-2 member set unless that information is supplied during bootstrap or recovered from that plane's own persisted state. ## Required Upper-Layer Inputs If Type-1 owns no Type-2 metadata, the upper layer must provide all of the following: 1. A stable `plane_id` for every logical Type-2 plane. 2. The logical mapping from upper-layer objects to `plane_id`. 3. The initial Type-2 member addresses when a plane is first created. 4. The decision of when a plane should be opened, created, retried, or forgotten. 5. Any lifecycle versioning or generation rules if a logical plane can be recreated. ## Consequences For Nebuchadnezzar And Morpheus Nebuchadnezzar or Morpheus must act as the control plane for Type-2 discovery. In practice, that means they must: 1. 
Resolve the correct `plane_id` before calling Bifrost. 2. Provide the intended Type-2 member set during first-plane bootstrap. 3. Re-open Type-2 planes by `plane_id` during restart or recovery. 4. Treat `loaded_type2_planes()` as a local observability API, not as cluster inventory. ## Design Rule If authoritative Type-2 discovery is needed inside Bifrost, that is a different design: it requires a Type-1 plane catalog. Without that catalog, lazy loading is valid, but discovery must remain outside Type-1. ================================================ FILE: examples/graceful_shutdown.rs ================================================ /// Example demonstrating graceful shutdown of Bifrost services /// /// This example shows how to: /// 1. Start a Raft service with an RPC server /// 2. Handle shutdown signals (Ctrl+C) /// 3. Gracefully shutdown all services /// /// Run with: cargo run --example graceful_shutdown use bifrost::raft::{Options, RaftService, Storage, DEFAULT_SERVICE_ID}; use bifrost::rpc::Server; use std::sync::Arc; use tokio::signal; #[tokio::main] async fn main() { env_logger::init(); let address = "127.0.0.1:9000".to_string(); println!("Starting Bifrost services on {}...", address); // Create Raft service let raft_service = RaftService::new(Options { storage: Storage::MEMORY, address: address.clone(), service_id: DEFAULT_SERVICE_ID, }); // Create and start RPC server let server = Server::new(&address); Server::listen_and_resume(&server).await; server.register_service(&raft_service).await; // Start Raft service if RaftService::start(&raft_service, false).await { println!("Raft service started successfully"); raft_service.bootstrap().await; println!("Raft cluster bootstrapped"); } else { eprintln!("Failed to start Raft service"); return; } println!("\nServices running. 
Press Ctrl+C to trigger graceful shutdown...\n"); // Wait for Ctrl+C signal match signal::ctrl_c().await { Ok(()) => { println!("\n\nReceived Ctrl+C, initiating graceful shutdown...\n"); } Err(err) => { eprintln!("Unable to listen for shutdown signal: {}", err); return; } } // Gracefully shutdown all services println!("1. Shutting down Raft service..."); raft_service.shutdown().await; println!(" ✓ Raft service shut down"); println!("2. Shutting down RPC server..."); server.shutdown().await; println!(" ✓ RPC server shut down"); println!("\n✓ All services shut down gracefully\n"); // Give a moment for any final log messages tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; } ================================================ FILE: src/conshash/mod.rs ================================================ use futures::prelude::*; use std::collections::{BTreeMap, HashMap}; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use crate::conshash::weights::client::SMClient as WeightSMClient; use crate::conshash::weights::DEFAULT_SERVICE_ID; use crate::membership::client::{Member, ObserverClient as MembershipClient}; use crate::raft::client::{RaftClient, SubscriptionError, SubscriptionReceipt}; use crate::raft::state_machine::master::ExecError; use crate::utils::serde::serialize; use bifrost_hasher::{hash_bytes, hash_str}; use parking_lot::*; pub mod weights; #[derive(Debug)] pub enum Action { Joined, Left, } #[derive(Debug)] pub enum InitTableError { GroupNotExisted, NoWeightService(ExecError), NoWeightGroup, NoWeightInfo, Unknown, } #[derive(Debug)] pub enum CHError { WatchError(Result, ExecError>), InitTableError(InitTableError), } struct LookupTables { nodes: Vec, addrs: HashMap, } pub struct ConsistentHashing { tables: RwLock, membership: Arc, weight_sm_client: WeightSMClient, group_name: String, watchers: RwLock, &Vec) + Send + Sync>>>, update_lock: async_std::sync::Mutex<()>, version: AtomicU64, num_addrs: AtomicUsize, } impl 
ConsistentHashing { pub async fn new_with_id( id: u64, group: &str, raft_client: &Arc, membership_client: &Arc, ) -> Result, CHError> { let ch = Arc::new(ConsistentHashing { tables: RwLock::new(LookupTables { nodes: Vec::new(), addrs: HashMap::new(), }), membership: membership_client.clone(), weight_sm_client: WeightSMClient::new(id, &raft_client), group_name: group.to_string(), watchers: RwLock::new(Vec::new()), version: AtomicU64::new(0), update_lock: async_std::sync::Mutex::new(()), num_addrs: AtomicUsize::new(0), }); { let ch = ch.clone(); let res = membership_client .on_group_member_joined( move |(member, version)| { let ch = ch.clone(); server_joined(ch, member, version).boxed() }, group, ) .await; if let Ok(Ok(_)) = res { } else { return Err(CHError::WatchError(res)); } } { let ch = ch.clone(); let res = membership_client .on_group_member_online( move |(member, version)| { let ch = ch.clone(); server_joined(ch, member, version).boxed() }, group, ) .await; if let Ok(Ok(_)) = res { } else { return Err(CHError::WatchError(res)); } } { let ch = ch.clone(); let res = membership_client .on_group_member_left( move |(member, version)| { let ch = ch.clone(); server_left(ch, member, version).boxed() }, group, ) .await; if let Ok(Ok(_)) = res { } else { return Err(CHError::WatchError(res)); } } { let ch = ch.clone(); let res = membership_client .on_group_member_offline( move |(member, version)| { debug!( "Dected server member {:?} offline at version {}", member, version ); let ch = ch.clone(); server_left(ch, member, version).boxed() }, group, ) .await; if let Ok(Ok(_)) = res { } else { return Err(CHError::WatchError(res)); } } Ok(ch) } pub async fn new( group: &str, raft_client: &Arc, membership_client: &Arc, ) -> Result, CHError> { Self::new_with_id(DEFAULT_SERVICE_ID, group, raft_client, membership_client).await } pub async fn new_client( group: &str, raft_client: &Arc, membership_client: &Arc, ) -> Result, CHError> { Self::new_client_with_id(DEFAULT_SERVICE_ID, 
group, raft_client, membership_client).await } pub async fn new_client_with_id( id: u64, group: &str, raft_client: &Arc, membership_client: &Arc, ) -> Result, CHError> { match ConsistentHashing::new_with_id(id, group, raft_client, membership_client).await { Err(e) => Err(e), Ok(ch) => match ch.init_table().await { Err(e) => Err(CHError::InitTableError(e)), Ok(_) => Ok(ch.clone()), }, } } pub fn to_server_name(&self, server_id: u64) -> String { let lookup_table = self.tables.read(); trace!("Lookup table has {:?}", lookup_table.addrs); if let Some(name) = lookup_table.addrs.get(&server_id) { name.to_owned() } else { panic!("Cannot find server name for server id {}", server_id); } } pub fn to_server_name_option(&self, server_id: Option) -> Option { if let Some(sid) = server_id { let lookup_table = self.tables.read(); lookup_table.addrs.get(&sid).cloned() } else { None } } pub fn get_server_id(&self, hash: u64) -> Option { let lookup_table = self.tables.read(); let nodes = &lookup_table.nodes; let slot_count = nodes.len(); if slot_count == 0 { return None; } let result = nodes.get(self.jump_hash(slot_count, hash)); // trace!("Hash {} have been point to {:?}", hash, result); result.cloned() } pub fn jump_hash(&self, slot_count: usize, hash: u64) -> usize { let mut b: i64 = -1; let mut j: i64 = 0; let mut h = hash; while j < (slot_count as i64) { b = j; h = h.wrapping_mul(2862933555777941757).wrapping_add(1); j = (((b.wrapping_add(1)) as f64) * ((1i64 << 31) as f64) / (((h >> 33).wrapping_add(1)) as f64)) as i64; } // trace!( // "Jump hash point to index {} for {}, with slots {}", // b, // hash, // slot_count // ); b as usize } pub fn get_server(&self, hash: u64) -> Option { self.to_server_name_option(self.get_server_id(hash)) } pub fn get_server_by_string(&self, string: &String) -> Option { self.get_server(hash_str(string)) } pub fn get_server_by(&self, obj: &T) -> Option where T: serde::Serialize, { self.get_server(hash_bytes(serialize(obj).as_slice())) } pub fn 
get_server_id_by_string(&self, string: &String) -> Option { self.get_server_id(hash_str(string)) } pub fn get_server_id_by(&self, obj: &T) -> Option where T: serde::Serialize, { self.get_server_id(hash_bytes(serialize(obj).as_slice())) } pub fn rand_server(&self) -> Option { let rand = rand::random::(); self.get_server(rand) } pub fn rand_server_id(&self) -> Option { let rand = rand::random::(); self.get_server_id(rand) } pub fn nodes_count(&self) -> usize { let lookup_table = self.tables.read(); return lookup_table.nodes.len(); } pub fn server_count(&self) -> usize { return self.num_addrs.load(Ordering::Relaxed); } pub async fn set_weight(&self, server_name: &String, weight: u64) -> Result<(), ExecError> { let group_id = hash_str(&self.group_name); let server_id = hash_str(server_name); self.weight_sm_client .set_weight(&group_id, &server_id, &weight) .await } fn watch_all_actions(&self, f: F) where F: Fn(&Member, &Action, &Vec, &Vec) + 'static + Send + Sync, { let mut watchers = self.watchers.write(); watchers.push(Box::new(f)); } pub fn watch_server_nodes_range_changed(&self, server: &String, f: F) // return ranges [...,...) 
where F: Fn((usize, u32)) + 'static + Send + Sync, { let server_id = hash_str(server); let wrapper = move |_: &Member, _: &Action, nodes: &Vec, _: &Vec| { let node_len = nodes.len(); let mut weight = 0; let mut start = None; for ni in 0..node_len { let node = nodes[ni]; if node == server_id { weight += 1; if start.is_none() { start = Some(ni) } } } if let Some(start_node) = start { f((start_node, weight)); } else { warn!("No node exists for watch"); } }; self.watch_all_actions(wrapper); } pub async fn init_table(&self) -> Result<(), InitTableError> { let group_name = &self.group_name; let _lock = self.update_lock.lock().await; debug!( "Initializing table from membership group members for {}", group_name ); debug!("Get group members for {}", group_name); let group_members = self.membership.group_members(group_name, true).await; debug!("Got group members for {}", group_name); if let Ok(Some((members, version))) = group_members { let group_id = hash_str(group_name); debug!("Getting weights for {}", group_name); match self.weight_sm_client.get_weights(&group_id).await { Ok(Some(weights)) => { debug!("Group {} have {} weights", group_name, weights.len()); if let Some(min_weight) = weights.values().min() { let mut factors: BTreeMap = BTreeMap::new(); let min_weight = *min_weight as f64; for member in members.iter() { let k = member.id; let w = match weights.get(&k) { Some(w) => *w as f64, None => min_weight, }; factors.insert(k, (w / min_weight) as u32); } let factor_sum: u32 = factors.values().sum(); let mut lookup_table = self.tables.write(); lookup_table.nodes = Vec::with_capacity(factor_sum as usize); for member in members.iter() { lookup_table.addrs.insert(member.id, member.address.clone()); } for (server_id, weight) in factors.into_iter() { for _ in 0..weight { lookup_table.nodes.push(server_id); } } self.num_addrs.store(members.len(), Ordering::Relaxed); self.version.store(version, Ordering::Relaxed); Ok(()) } else { Err(InitTableError::NoWeightInfo) } } Err(e) => 
{ error!("No weright service for group {}", group_name); Err(InitTableError::NoWeightService(e)) } Ok(None) => { error!("No weight group for group {}", group_name); Err(InitTableError::NoWeightGroup) } } } else { error!( "No group {} existed in table, groups {:?}", group_name, self.membership.all_groups().await ); Err(InitTableError::GroupNotExisted) } } #[inline] pub fn membership(&self) -> &Arc { &self.membership } } async fn server_joined(ch: Arc, member: Member, version: u64) { server_changed(ch, member, Action::Joined, version).await; } async fn server_left(ch: Arc, member: Member, version: u64) { server_changed(ch, member, Action::Left, version).await; } async fn server_changed(ch: Arc, member: Member, action: Action, version: u64) { warn!( "Detected server membership change, member {:?}, action {:?}, version {}", member, action, version ); let ch_version = ch.version.load(Ordering::Relaxed); if ch_version <= version { { debug!("Obtaining conshash table write lock"); let old_nodes = (&*ch.tables.read()).nodes.clone(); debug!("Reinit conshash table"); let reinit_res = ch.init_table().await; if let Err(e) = &reinit_res { error!( "Cannot reinit table {:?}, member {:?}, action {:?}, version {} -> {}", e, member, action, ch_version, version ); } debug!("Triggering conshash watchers"); let new_nodes = (&*ch.tables.read()).nodes.clone(); for watch in ch.watchers.read().iter() { watch(&member, &action, &new_nodes, &old_nodes); } } debug!("Server change processing completed"); } else { warn!("Server membership change too old to follow, member {:?}, action {:?}, version {}, expect {}", member, action, version, ch_version); } } #[cfg(test)] mod test { use crate::conshash::weights::Weights; use crate::conshash::ConsistentHashing; use crate::membership::client::ObserverClient; use crate::membership::member::MemberService; use crate::membership::server::Membership; use crate::raft::client::RaftClient; use crate::raft::{Options, RaftService, Storage}; use 
crate::rpc::Server; use crate::utils::time::async_wait_secs; use std::collections::HashMap; use std::sync::atomic::*; use std::sync::Arc; #[tokio::test(flavor = "multi_thread")] async fn primary() { let _ = env_logger::try_init(); info!("Creating raft service"); let addr = String::from("127.0.0.1:2200"); let raft_service = RaftService::new(Options { storage: Storage::default(), address: addr.clone(), service_id: 0, }); info!("Creating server"); let server = Server::new(&addr); info!("Creating membership service"); let _membership = Membership::new(&server, &raft_service).await; server.register_service_with_id(0, &raft_service).await; Server::listen_and_resume(&server).await; RaftService::start(&raft_service, false).await; raft_service.bootstrap().await; let group_1 = String::from("test_group_1"); let group_2 = String::from("test_group_2"); let group_3 = String::from("test_group_3"); let server_1 = String::from("server1"); let server_2 = String::from("server2"); let server_3 = String::from("server3"); info!("Create raft client"); let wild_raft_client = RaftClient::new(&vec![addr.clone()], 0).await.unwrap(); info!("Create observer"); let observer_client = Arc::new(ObserverClient::new(&wild_raft_client)); info!("Create subscription"); RaftClient::prepare_subscription(&server).await; info!("New group 1"); observer_client.new_group(&group_1).await.unwrap().unwrap(); info!("New group 2"); observer_client.new_group(&group_2).await.unwrap().unwrap(); info!("New group 3"); observer_client.new_group(&group_3).await.unwrap().unwrap(); info!("New raft client for member 1"); let member1_raft_client = RaftClient::new(&vec![addr.clone()], 0).await.unwrap(); info!("New member service 1"); let member1_svr = MemberService::new(&server_1, &member1_raft_client, &raft_service).await; info!("New raft client for member 2"); let member2_raft_client = RaftClient::new(&vec![addr.clone()], 0).await.unwrap(); info!("New member service 2"); let member2_svr = MemberService::new(&server_2, 
&member2_raft_client, &raft_service).await; info!("New raft client for member 3"); let member3_raft_client = RaftClient::new(&vec![addr.clone()], 0).await.unwrap(); info!("New member service 3"); let member3_svr = MemberService::new(&server_3, &member3_raft_client, &raft_service).await; info!("Member 1 join group 1"); member1_svr.join_group(&group_1).await.unwrap(); info!("Member 1 join group 2"); member2_svr.join_group(&group_1).await.unwrap(); info!("Member 1 join group 3"); member3_svr.join_group(&group_1).await.unwrap(); info!("Member 1 join group 2"); member1_svr.join_group(&group_2).await.unwrap(); info!("Member 2 join group 2"); member2_svr.join_group(&group_2).await.unwrap(); info!("Member 1 join group 3"); member1_svr.join_group(&group_3).await.unwrap(); info!("New weight service"); Weights::new(&raft_service).await; info!("New conshash for group 1"); let ch1 = ConsistentHashing::new(&group_1, &wild_raft_client, &observer_client) .await .unwrap(); info!("New conshash for group 2"); let ch2 = ConsistentHashing::new(&group_2, &wild_raft_client, &observer_client) .await .unwrap(); info!("New conshash for group 3"); let ch3 = ConsistentHashing::new(&group_3, &wild_raft_client, &observer_client) .await .unwrap(); info!("Set server 1 in group 1 to 1"); ch1.set_weight(&server_1, 1).await.unwrap(); info!("Set server 2 in group 1 to 2"); ch1.set_weight(&server_2, 2).await.unwrap(); info!("Set server 3 in group 1 to 3"); ch1.set_weight(&server_3, 3).await.unwrap(); info!("Set server 1 in group 2 to 1"); ch2.set_weight(&server_1, 1).await.unwrap(); info!("Set server 2 in group 2 to 1"); ch2.set_weight(&server_2, 1).await.unwrap(); info!("Set server 1 in group 3 to 2"); ch3.set_weight(&server_1, 2).await.unwrap(); info!("Init table for conshash 1"); ch1.init_table().await.unwrap(); info!("Init table for conshash 2"); ch2.init_table().await.unwrap(); info!("Init table for conshash 3"); ch3.init_table().await.unwrap(); let ch1_server_node_changes_count = 
Arc::new(AtomicUsize::new(0)); let ch1_server_node_changes_count_clone = ch1_server_node_changes_count.clone(); info!("Watch node change from conshash 1"); ch1.watch_server_nodes_range_changed(&server_2, move |_| { ch1_server_node_changes_count_clone.fetch_add(1, Ordering::Relaxed); }); let ch2_server_node_changes_count = Arc::new(AtomicUsize::new(0)); let ch2_server_node_changes_count_clone = ch2_server_node_changes_count.clone(); info!("Watch node change from conshash 2"); ch2.watch_server_nodes_range_changed(&server_2, move |_| { ch2_server_node_changes_count_clone.fetch_add(1, Ordering::Relaxed); }); let ch3_server_node_changes_count = Arc::new(AtomicUsize::new(0)); let ch3_server_node_changes_count_clone = ch3_server_node_changes_count.clone(); info!("Watch node change from conshash 3"); ch3.watch_server_nodes_range_changed(&server_2, move |_| { ch3_server_node_changes_count_clone.fetch_add(1, Ordering::Relaxed); }); info!("Counting nodes for conshash 1"); assert_eq!(ch1.nodes_count(), 6); info!("Counting nodes for conshash 2"); assert_eq!(ch2.nodes_count(), 2); info!("Counting nodes for conshash 3"); assert_eq!(ch3.nodes_count(), 1); info!("Batch get server by string from conshash 1"); let mut ch_1_mapping: HashMap = HashMap::new(); let data_set_size: usize = 30000; for i in 0..data_set_size { let k = format!("k - {}", i); let server = ch1.get_server_by_string(&k).unwrap(); *ch_1_mapping.entry(server.clone()).or_insert(0) += 1; } info!("Counting distribution for conshash 1"); assert_eq!(ch_1_mapping.get(&server_1).unwrap(), &4936); assert_eq!(ch_1_mapping.get(&server_2).unwrap(), &9923); assert_eq!(ch_1_mapping.get(&server_3).unwrap(), &15141); // hard coded due to constant info!("Batch get server by string from conshash 2"); let mut ch_2_mapping: HashMap = HashMap::new(); for i in 0..data_set_size { let k = format!("k - {}", i); let server = ch2.get_server_by_string(&k).unwrap(); *ch_2_mapping.entry(server.clone()).or_insert(0) += 1; } info!("Counting 
distribution for conshash 2"); assert_eq!(ch_2_mapping.get(&server_1).unwrap(), &14967); assert_eq!(ch_2_mapping.get(&server_2).unwrap(), &15033); info!("Batch get server by string from conshash 3"); let mut ch_3_mapping: HashMap = HashMap::new(); for i in 0..data_set_size { let k = format!("k - {}", i); let server = ch3.get_server_by_string(&k).unwrap(); *ch_3_mapping.entry(server.clone()).or_insert(0) += 1; } info!("Counting distribution for conshash 3"); assert_eq!(ch_3_mapping.get(&server_1).unwrap(), &30000); info!("Close member 1"); member1_svr.close(); info!("Waiting"); for i in 0..10 { async_wait_secs().await; } let mut ch_1_mapping: HashMap = HashMap::new(); info!("Recheck get server by string for conshash 1"); for i in 0..data_set_size { let k = format!("k - {}", i); let server = ch1.get_server_by_string(&k).unwrap(); *ch_1_mapping.entry(server.clone()).or_insert(0) += 1; } info!("Recount distribution for conshash 1"); assert_eq!( ch_1_mapping.get(&server_2).unwrap() + ch_1_mapping.get(&server_3).unwrap(), data_set_size as u64 ); assert_eq!(ch_1_mapping.get(&server_2).unwrap(), &11932); assert_eq!(ch_1_mapping.get(&server_3).unwrap(), &18068); let mut ch_2_mapping: HashMap = HashMap::new(); info!("Recheck get server by string for conshash 2"); for i in 0..data_set_size { let k = format!("k - {}", i); let server = ch2.get_server_by_string(&k).unwrap(); *ch_2_mapping.entry(server.clone()).or_insert(0) += 1; } info!("Recount distribution for conshash 2"); assert_eq!( ch_2_mapping.get(&server_2).unwrap(), &(data_set_size as u64) ); info!("Cheching conshash 3 with no members"); for i in 0..data_set_size { let k = format!("k - {}", i); assert!(ch3.get_server_by_string(&k).is_none()); // no member } info!("Waiting"); async_wait_secs().await; async_wait_secs().await; info!("Testing callback counter"); assert_eq!(ch1_server_node_changes_count.load(Ordering::Relaxed), 1); assert_eq!(ch2_server_node_changes_count.load(Ordering::Relaxed), 1); 
assert_eq!(ch3_server_node_changes_count.load(Ordering::Relaxed), 0); info!("Membership tests all done !"); } } ================================================ FILE: src/conshash/weights.rs ================================================ use crate::raft::state_machine::StateMachineCtl; use crate::raft::RaftService; use bifrost_plugins::hash_ident; use futures::FutureExt; use std::collections::HashMap; use std::sync::Arc; pub static DEFAULT_SERVICE_ID: u64 = hash_ident!(BIFROST_DHT_WEIGHTS) as u64; raft_state_machine! { def cmd set_weight(group: u64, id: u64, weight: u64); def qry get_weights(group: u64) -> Option>; def qry get_weight(group: u64, id: u64) -> Option; } pub struct Weights { pub groups: HashMap>, pub id: u64, } impl StateMachineCmds for Weights { fn set_weight(&mut self, group: u64, id: u64, weight: u64) -> BoxFuture<()> { *self .groups .entry(group) .or_insert_with(|| HashMap::new()) .entry(id) .or_insert_with(|| 0) = weight; future::ready(()).boxed() } fn get_weights(&self, group: u64) -> BoxFuture>> { future::ready(match self.groups.get(&group) { Some(m) => Some(m.clone()), None => None, }) .boxed() } fn get_weight(&self, group: u64, id: u64) -> BoxFuture> { future::ready(match self.groups.get(&group) { Some(m) => match m.get(&id) { Some(w) => Some(*w), None => None, }, None => None, }) .boxed() } } impl StateMachineCtl for Weights { raft_sm_complete!(); fn id(&self) -> u64 { self.id } fn snapshot(&self) -> Vec { crate::utils::serde::serialize(&self.groups) } fn recover(&mut self, data: Vec) -> BoxFuture<()> { match crate::utils::serde::deserialize::>>(data.as_slice()) { Some(groups) => self.groups = groups, None => { error!("Failed to deserialize weights state machine snapshot. 
/// Hashes `bytes` with the standard library's `DefaultHasher` and returns the
/// 64-bit digest.
///
/// NOTE(review): the algorithm behind `DefaultHasher` is not fixed by the
/// standard library, so digests may differ between Rust releases. Confirm no
/// caller persists these values.
pub fn hash_bytes_secondary(bytes: &[u8]) -> u64 {
    let mut state = DefaultHasher::new();
    state.write(bytes);
    state.finish()
}
crate::membership::raft::client::SMClient; use crate::membership::DEFAULT_SERVICE_ID; use crate::raft::client::{RaftClient, SubscriptionError, SubscriptionReceipt}; use crate::raft::state_machine::master::ExecError; use bifrost_hasher::hash_str; use futures::future::BoxFuture; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; use std::sync::Arc; use super::server::MemberGroup; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Member { pub id: u64, pub address: String, pub online: bool, } #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Group { pub id: u64, pub name: String, pub members: u64, } pub struct MemberClient { pub id: u64, pub sm_client: Arc, } impl MemberClient { pub async fn join_group(&self, group: &String) -> Result { self.sm_client.join_group(group, &self.id).await } pub async fn leave_group(&self, group: &String) -> Result { self.sm_client.leave_group(&hash_str(group), &self.id).await } } pub struct ObserverClient { pub sm_client: Arc, } impl ObserverClient { pub fn new(raft_client: &Arc) -> ObserverClient { ObserverClient { sm_client: Arc::new(SMClient::new(DEFAULT_SERVICE_ID, &raft_client)), } } pub fn new_from_sm(sm_client: &Arc) -> ObserverClient { ObserverClient { sm_client: sm_client.clone(), } } pub async fn new_group(&self, name: &String) -> Result, ExecError> { self.sm_client.new_group(name).await } pub async fn del_group(&self, name: &String) -> Result { self.sm_client.del_group(&hash_str(name)).await } pub async fn group_leader( &self, group: &String, ) -> Result, u64)>, ExecError> { self.sm_client.group_leader(&hash_str(group)).await } pub async fn group_members( &self, group: &String, online_only: bool, ) -> Result, u64)>, ExecError> { self.sm_client .group_members(&hash_str(group), &online_only) .await } pub async fn all_members(&self, online_only: bool) -> Result<(Vec, u64), ExecError> { self.sm_client.all_members(&online_only).await } pub async fn on_group_member_offline( &self, f: F, group: &str, 
) -> Result, ExecError> where F: Fn((Member, u64)) -> BoxFuture<'static, ()> + 'static + Send + Sync, { self.sm_client .on_group_member_offline(f, &hash_str(group)) .await } pub async fn on_any_member_offline( &self, f: F, ) -> Result, ExecError> where F: Fn((Member, u64)) -> BoxFuture<'static, ()> + 'static + Send + Sync, { self.sm_client.on_any_member_offline(f).await } pub async fn on_group_member_online( &self, f: F, group: &str, ) -> Result, ExecError> where F: Fn((Member, u64)) -> BoxFuture<'static, ()> + 'static + Send + Sync, { self.sm_client .on_group_member_online(f, &hash_str(group)) .await } pub async fn on_any_member_online( &self, f: F, ) -> Result, ExecError> where F: Fn((Member, u64)) -> BoxFuture<'static, ()> + 'static + Send + Sync, { self.sm_client.on_any_member_online(f).await } pub async fn on_group_member_joined( &self, f: F, group: &str, ) -> Result, ExecError> where F: Fn((Member, u64)) -> BoxFuture<'static, ()> + 'static + Send + Sync, { self.sm_client .on_group_member_joined(f, &hash_str(group)) .await } pub async fn on_any_member_joined( &self, f: F, ) -> Result, ExecError> where F: Fn((Member, u64)) -> BoxFuture<'static, ()> + 'static + Send + Sync, { self.sm_client.on_any_member_joined(f).await } pub async fn on_group_member_left( &self, f: F, group: &str, ) -> Result, ExecError> where F: Fn((Member, u64)) -> BoxFuture<'static, ()> + 'static + Send + Sync, { self.sm_client .on_group_member_left(f, &hash_str(group)) .await } pub async fn on_any_member_left( &self, f: F, ) -> Result, ExecError> where F: Fn((Member, u64)) -> BoxFuture<'static, ()> + 'static + Send + Sync, { self.sm_client.on_any_member_left(f).await } pub async fn on_group_leader_changed( &self, f: F, group: &String, ) -> Result, ExecError> where F: Fn((Option, Option, u64)) -> BoxFuture<'static, ()> + 'static + Send + Sync, { self.sm_client .on_group_leader_changed(f, &hash_str(group)) .await } pub async fn all_groups(&self) -> Result, ExecError> { 
self.sm_client.all_groups().await } } ================================================ FILE: src/membership/member.rs ================================================ use super::client::{MemberClient, ObserverClient}; use super::heartbeat_rpc::*; use super::raft::client::SMClient; use bifrost_hasher::hash_str; use futures::prelude::*; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::{runtime, time}; use crate::membership::DEFAULT_SERVICE_ID; use crate::raft::client::RaftClient; use crate::raft::state_machine::master::ExecError; use crate::raft::RaftService; use crate::utils::time::get_time; static PING_INTERVAL: u64 = 500; pub struct MemberService { member_client: MemberClient, sm_client: Arc, raft_client: Arc, closed: AtomicBool, id: u64, } impl MemberService { pub async fn new( server_address: &String, raft_client: &Arc, raft_service: &Arc, ) -> Arc { let server_id = hash_str(server_address); let sm_client = Arc::new(SMClient::new(DEFAULT_SERVICE_ID, &raft_client)); let service = Arc::new(MemberService { sm_client: sm_client.clone(), member_client: MemberClient { id: server_id, sm_client: sm_client.clone(), }, raft_client: raft_client.clone(), closed: AtomicBool::new(false), id: server_id, }); let _join_res = sm_client.join(&server_address).await; let service_clone = service.clone(); raft_service.rt.spawn(async move { while !service_clone.closed.load(Ordering::Relaxed) { let start_time = get_time(); let rpc_client = service_clone.raft_client.current_leader_rpc_client().await; if let Ok(rpc_client) = rpc_client { let _ping_res = ImmeServiceClient::ping(DEFAULT_SERVICE_ID, &rpc_client, service_clone.id) .await; } else { error!("Cannot find RPC client for membership heartbeat to leader"); } let time_now = get_time(); let elapsed_time = time_now - start_time; trace!( "Membership ping at time {}, elapsed {}ms", time_now, elapsed_time ); if (elapsed_time as u64) < PING_INTERVAL { let wait_time = PING_INTERVAL - elapsed_time as u64; 
trace!("Waiting membership heartbeat for {}ms", wait_time); time::sleep(time::Duration::from_millis(wait_time)).await; } } debug!("Member service closed"); }); return service; } pub fn close(&self) { self.closed.store(true, Ordering::Relaxed); } pub async fn leave(&self) -> Result { self.close(); self.sm_client.leave(&self.id).await } pub async fn join_group(&self, group: &String) -> Result { self.member_client.join_group(group).await } pub async fn leave_group(&self, group: &String) -> Result { self.member_client.leave_group(group).await } pub fn client(&self) -> ObserverClient { ObserverClient::new_from_sm(&self.sm_client) } pub fn get_server_id(&self) -> u64 { self.id } } impl Drop for MemberService { fn drop(&mut self) { let sm_client = self.sm_client.clone(); let self_id = self.id; tokio::spawn(async move { sm_client.leave(&self_id).await }.boxed()); } } ================================================ FILE: src/membership/mod.rs ================================================ // Group membership manager regardless actual raft members pub mod client; pub mod member; pub mod server; use crate::membership::client::Member as ClientMember; use bifrost_plugins::hash_ident; pub static DEFAULT_SERVICE_ID: u64 = hash_ident!(BIFROST_MEMBERSHIP_SERVICE) as u64; pub mod raft { use super::server::MemberGroup; use super::*; use std::collections::BTreeMap; raft_state_machine! 
{ def cmd hb_online_changed(online: Vec, offline: Vec); def cmd join(address: String) -> Option; def cmd leave(id: u64) -> bool; def cmd join_group(group_name: String, id: u64) -> bool; def cmd leave_group(group: u64, id: u64) -> bool; def cmd new_group(name: String) -> Result; def cmd del_group(id: u64) -> bool; def qry group_leader(group: u64) -> Option<(Option, u64)>; def qry group_members (group: u64, online_only: bool) -> Option<(Vec, u64)>; def qry all_members (online_only: bool) -> (Vec, u64); def qry all_groups() -> BTreeMap; def sub on_group_member_offline(group: u64) -> (ClientMember, u64); // def sub on_any_member_offline() -> (ClientMember, u64); // def sub on_group_member_online(group: u64) -> (ClientMember, u64); // def sub on_any_member_online() -> (ClientMember, u64); // def sub on_group_member_joined(group: u64) -> (ClientMember, u64); // def sub on_any_member_joined() -> (ClientMember, u64); // def sub on_group_member_left(group: u64) -> (ClientMember, u64); // def sub on_any_member_left() -> (ClientMember, u64); // def sub on_group_leader_changed(group: u64) -> (Option, Option, u64); } } // The service only responsible for receiving heartbeat and // Updating last updated time // Expired update time will trigger timeout in the raft state machine mod heartbeat_rpc { service! 
{
        rpc ping(id: u64);
    }
}

#[cfg(test)]
mod test {
    use crate::membership::client::ObserverClient;
    use crate::membership::member::MemberService;
    use crate::membership::server::Membership;
    use crate::raft::client::RaftClient;
    use crate::raft::{Options, RaftService, Storage, DEFAULT_SERVICE_ID};
    use crate::rpc::Server;
    use crate::utils::time::async_wait_secs;
    use futures::prelude::*;
    use std::sync::atomic::*;
    use std::sync::Arc;

    // Number of members reported cluster-wide through `$svr`'s observer client.
    macro_rules! all_member_count {
        ($svr:expr, $online_only:expr) => {
            $svr.client()
                .all_members($online_only)
                .await
                .unwrap()
                .0
                .len()
        };
    }

    // Number of members reported for `$group` through `$svr`'s observer client.
    macro_rules! group_member_count {
        ($svr:expr, $group:expr, $online_only:expr) => {
            $svr.client()
                .group_members($group, $online_only)
                .await
                .unwrap()
                .unwrap()
                .0
                .len()
        };
    }

    /// Single-node scenario: three members join three groups, one stops its
    /// heartbeat, one leaves, and the member counts and subscription
    /// callbacks are checked after each step.
    #[tokio::test(flavor = "multi_thread")]
    async fn primary() {
        let _ = env_logger::builder().format_timestamp(None).try_init();
        let addr = String::from("127.0.0.1:2100");
        let raft_service = RaftService::new(Options {
            storage: Storage::default(),
            address: addr.clone(),
            service_id: DEFAULT_SERVICE_ID,
        });
        info!("Creating server");
        let server = Server::new(&addr);
        info!("Register service");
        server.register_service(&raft_service).await;
        info!("Server listen and resume");
        Server::listen_and_resume(&server).await;
        info!("Start raft service");
        RaftService::start(&raft_service, false).await;
        info!("Bootstrap raft service");
        raft_service.bootstrap().await;
        info!("Creating membership service");
        Membership::new(&server, &raft_service).await;

        let group_1 = String::from("test_group_1");
        let group_2 = String::from("test_group_2");
        let group_3 = String::from("test_group_3");

        info!("Creating raft client");
        let wild_raft_client = RaftClient::new(&vec![addr.clone()], DEFAULT_SERVICE_ID)
            .await
            .unwrap();
        info!("Create observer");
        let client = ObserverClient::new(&wild_raft_client);
        info!("Prepare subscription");
        RaftClient::prepare_subscription(&server).await;

        info!("Creating new group: {}", group_1);
        client.new_group(&group_1).await.unwrap().unwrap();
        info!("Creating new group {}", group_2);
        client.new_group(&group_2).await.unwrap().unwrap();
        info!("Creating new group {}", group_3);
        client.new_group(&group_3).await.unwrap().unwrap();

        // One counter per subscription; each callback below bumps its own counter.
        let any_member_joined_count = Arc::new(AtomicUsize::new(0));
        let any_member_left_count = Arc::new(AtomicUsize::new(0));
        let any_member_offline_count = Arc::new(AtomicUsize::new(0));
        let any_member_online_count = Arc::new(AtomicUsize::new(0));
        let group_leader_changed_count = Arc::new(AtomicUsize::new(0));
        let group_member_joined_count = Arc::new(AtomicUsize::new(0));
        let group_member_left_count = Arc::new(AtomicUsize::new(0));
        let group_member_online_count = Arc::new(AtomicUsize::new(0));
        let group_member_offline_count = Arc::new(AtomicUsize::new(0));

        // Handles moved into the callbacks.
        let any_joined_hits = any_member_joined_count.clone();
        let any_left_hits = any_member_left_count.clone();
        let any_offline_hits = any_member_offline_count.clone();
        let any_online_hits = any_member_online_count.clone();
        let leader_changed_hits = group_leader_changed_count.clone();
        let group_joined_hits = group_member_joined_count.clone();
        let group_left_hits = group_member_left_count.clone();
        let group_online_hits = group_member_online_count.clone();
        let group_offline_hits = group_member_offline_count.clone();

        info!("Subscribe on_any_member_joined");
        client
            .on_any_member_joined(move |_| {
                any_joined_hits.fetch_add(1, Ordering::Relaxed);
                future::ready(()).boxed()
            })
            .await
            .unwrap()
            .unwrap();
        info!("Subscribe on_any_member_left");
        client
            .on_any_member_left(move |_| {
                any_left_hits.fetch_add(1, Ordering::Relaxed);
                future::ready(()).boxed()
            })
            .await
            .unwrap()
            .unwrap();
        info!("Subscribe on_any_member_offline");
        client
            .on_any_member_offline(move |_| {
                any_offline_hits.fetch_add(1, Ordering::Relaxed);
                future::ready(()).boxed()
            })
            .await
            .unwrap()
            .unwrap();
        info!("Subscribe on_any_member_online");
        client
            .on_any_member_online(move |_| {
                any_online_hits.fetch_add(1, Ordering::Relaxed);
                future::ready(()).boxed()
            })
            .await
            .unwrap()
            .unwrap();
        info!("Subscribe on_group_leader_changed");
        client
            .on_group_leader_changed(
                move |_| {
                    leader_changed_hits.fetch_add(1, Ordering::Relaxed);
                    future::ready(()).boxed()
                },
                &group_1,
            )
            .await
            .unwrap()
            .unwrap();
        info!("Subscribe on_group_member_joined");
        client
            .on_group_member_joined(
                move |_| {
                    group_joined_hits.fetch_add(1, Ordering::Relaxed);
                    future::ready(()).boxed()
                },
                &group_1,
            )
            .await
            .unwrap()
            .unwrap();
        info!("Subscribe on_group_member_left");
        client
            .on_group_member_left(
                move |_| {
                    group_left_hits.fetch_add(1, Ordering::Relaxed);
                    future::ready(()).boxed()
                },
                &group_1,
            )
            .await
            .unwrap()
            .unwrap();
        info!("Subscribe on_group_member_online");
        client
            .on_group_member_online(
                move |_| {
                    group_online_hits.fetch_add(1, Ordering::Relaxed);
                    future::ready(()).boxed()
                },
                &group_1,
            )
            .await
            .unwrap()
            .unwrap();
        info!("Subscribe on_group_member_offline");
        client
            .on_group_member_offline(
                move |_| {
                    group_offline_hits.fetch_add(1, Ordering::Relaxed);
                    future::ready(()).boxed()
                },
                &group_1,
            )
            .await
            .unwrap()
            .unwrap();

        info!("New member1_raft_client");
        let member1_raft_client = RaftClient::new(&vec![addr.clone()], DEFAULT_SERVICE_ID)
            .await
            .unwrap();
        let member1_addr = String::from("server1");
        info!("New member service {}", member1_addr);
        let member1_svr =
            MemberService::new(&member1_addr, &member1_raft_client, &raft_service).await;

        info!("New member2_raft_client");
        let member2_raft_client = RaftClient::new(&vec![addr.clone()], DEFAULT_SERVICE_ID)
            .await
            .unwrap();
        let member2_addr = String::from("server2");
        info!("New member service {}", member2_addr);
        let member2_svr =
            MemberService::new(&member2_addr, &member2_raft_client, &raft_service).await;

        info!("New member3_raft_client");
        let member3_raft_client = RaftClient::new(&vec![addr.clone()], DEFAULT_SERVICE_ID)
            .await
            .unwrap();
        let member3_addr = String::from("server3");
        info!("New member service {}", member3_addr);
        let member3_svr =
            MemberService::new(&member3_addr, &member3_raft_client, &raft_service).await;

        info!("Member 1 join group 1");
        member1_svr.join_group(&group_1).await.unwrap();
        info!("Member 2 join group 1");
        member2_svr.join_group(&group_1).await.unwrap();
        info!("Member 3 join group 1");
        member3_svr.join_group(&group_1).await.unwrap();
        info!("Member 1 join group 2");
        member1_svr.join_group(&group_2).await.unwrap();
        info!("Member 2 join group 2");
        member2_svr.join_group(&group_2).await.unwrap();
        info!("Member 1 join group 3");
        member1_svr.join_group(&group_3).await.unwrap();

        info!("Checking group members after join");
        assert_eq!(all_member_count!(member1_svr, false), 3);
        assert_eq!(all_member_count!(member1_svr, true), 3);
        // (group, expected members, expected online members)
        for &(group, total, online) in
            [(&group_1, 3, 3), (&group_2, 2, 2), (&group_3, 1, 1)].iter()
        {
            assert_eq!(group_member_count!(member1_svr, group, false), total);
            assert_eq!(group_member_count!(member1_svr, group, true), online);
        }

        member1_svr.close(); // close only end the heartbeat thread
        info!("############### Waiting for membership changes ###############");
        for _ in 0..10 {
            async_wait_secs().await;
        }

        info!("*************** Checking members ***************");
        assert_eq!(all_member_count!(member1_svr, false), 3);
        assert_eq!(all_member_count!(member1_svr, true), 2);
        for &(group, total, online) in
            [(&group_1, 3, 2), (&group_2, 2, 1), (&group_3, 1, 0)].iter()
        {
            assert_eq!(group_member_count!(member1_svr, group, false), total);
            assert_eq!(group_member_count!(member1_svr, group, true), online);
        }

        member2_svr.leave().await.unwrap(); // leave will report to the raft servers to remove it from the list
        assert_eq!(all_member_count!(member1_svr, false), 2);
        assert_eq!(all_member_count!(member1_svr, true), 1);
        for &(group, total, online) in
            [(&group_1, 2, 1), (&group_2, 1, 0), (&group_3, 1, 0)].iter()
        {
            assert_eq!(group_member_count!(member1_svr, group, false), total);
            assert_eq!(group_member_count!(member1_svr, group, true), online);
        }

        async_wait_secs().await;
        info!("=========== Checking event trigger ===========");
        assert_eq!(any_member_joined_count.load(Ordering::Relaxed), 3);
        assert_eq!(any_member_left_count.load(Ordering::Relaxed), 1);
        assert_eq!(any_member_offline_count.load(Ordering::Relaxed), 1);
        assert_eq!(any_member_online_count.load(Ordering::Relaxed), 0); // no server online from offline
        assert!(group_leader_changed_count.load(Ordering::Relaxed) > 0); // Number depends on hashing
        assert_eq!(group_member_joined_count.load(Ordering::Relaxed), 3);
        // assert_eq!(group_member_left_count.load(Ordering::Relaxed), 2); // this test case is unstable
        assert_eq!(group_member_online_count.load(Ordering::Relaxed), 0);
        assert_eq!(group_member_offline_count.load(Ordering::Relaxed), 1);
    }
}

================================================
FILE: src/membership/server.rs
================================================
use super::heartbeat_rpc::*;
use super::raft::*;
use super::*;
use crate::membership::client::Member as ClientMember;
use crate::raft::state_machine::callback::server::{notify as cb_notify, SMCallback};
use crate::raft::state_machine::StateMachineCtl;
use crate::raft::{LogEntry, PlaneId, RaftMsg, RaftService, Service as raft_svr_trait};
use crate::rpc::Server;
use crate::utils::time;
use crate::utils::time::get_time;
use bifrost_hasher::hash_str;
use futures::prelude::future::*;
use futures::prelude::*;
use futures::stream::FuturesUnordered;
use lightning::map::Map;
use lightning::map::PtrHashMap;
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
use std::collections::{BTreeSet, HashSet};
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time as std_time;
use tokio::time as async_time;

static MAX_TIMEOUT: i64 = 10_000; // 10 secs timeout before considering a member potentially offline
static OFFLINE_GRACE_CHECKS: u8 = 3; // Number of consecutive timeout checks before marking offline
static ONLINE_GRACE_CHECKS: u8 = 2; // Number of consecutive successful checks before marking back online
static MIN_STATE_CHANGE_INTERVAL: i64 = 5_000; // Minimum 5 seconds between state changes (anti-flapping)

#[derive(Clone, Copy)]
struct HBStatus {
    last_updated: i64,
    online: bool,
    consecutive_failures: u8, // Count of consecutive timeout checks while supposedly online
    consecutive_successes: u8, // Count of consecutive successful checks while supposedly offline
    last_state_change: i64, // Timestamp of last online/offline state change
}

pub struct HeartbeatService {
    status: PtrHashMap,
    raft_service: Arc,
    closed: AtomicBool,
    was_leader: AtomicBool,
watcher_handle: std::sync::Mutex>>,
}

impl Service for HeartbeatService {
    /// Records a heartbeat from member `id`.
    ///
    /// An existing entry gets its `last_updated` refreshed and its failure
    /// counter cleared; while the member is marked offline its success
    /// counter is also bumped. An unknown id gets a fresh entry that starts
    /// online.
    fn ping(&self, id: u64) -> BoxFuture<()> {
        async move {
            let current_time = time::get_time();
            // Update existing status or create new one
            let old_status = self.status.get(&id);
            let new_status = if let Some(mut status) = old_status {
                let elapsed = current_time - status.last_updated;
                status.last_updated = current_time;
                // Reset failure counter on successful ping, but keep other fields
                status.consecutive_failures = 0;
                if !status.online {
                    // Member is recovering, increment success counter.
                    // Saturate instead of `+= 1`: the counter is a `u8` and an
                    // overflow would panic in debug builds or wrap to 0 in release.
                    status.consecutive_successes = status.consecutive_successes.saturating_add(1);
                }
                trace!("Updated heartbeat for member {}, elapsed {}ms", id, elapsed);
                status
            } else {
                // First time seeing this member
                trace!("First heartbeat from member {}", id);
                HBStatus {
                    online: true,
                    last_updated: current_time,
                    consecutive_failures: 0,
                    consecutive_successes: 0,
                    last_state_change: current_time,
                }
            };
            self.status.insert(id, new_status);
        }
        .boxed()
    }
}

impl HeartbeatService {
    /// Submits an `hb_online_changed` command to the raft service with the
    /// members that came back online and the members that went offline.
    async fn update_raft(&self, online: &Vec, offline: &Vec) {
        let log = commands::hb_online_changed::new(online, offline);
        // Encode to state machine command
        let (fn_id, _, data) = log.encode();
        // NOTE(review): the result of `c_command` is discarded here — confirm
        // whether a failed submission should be logged or retried.
        self.raft_service
            .c_command(
                PlaneId::type1(),
                LogEntry {
                    id: 0,
                    term: 0,
                    sm_id: DEFAULT_SERVICE_ID,
                    fn_id,
                    data,
                },
            )
            .await;
    }

    /// Refreshes the timestamp and clears the counters of every member that
    /// is currently online, giving them a grace period.
    async fn transfer_leadership(&self) {
        // Update the timestamp for every alive server to give them a grace period
        let all_entries = self.status.entries();
        let current_time = get_time();
        let mut online_count = 0;
        for (id, mut stat) in all_entries {
            if stat.online {
                stat.last_updated = current_time;
                // Reset counters to give all members a fresh start under new leader
                stat.consecutive_failures = 0;
                stat.consecutive_successes = 0;
                self.status.insert(id, stat);
                online_count += 1;
            }
        }
        info!(
            "Leadership transferred, reset heartbeat status for {} online members",
            online_count
        );
    }

    /// Signals the watcher task to stop and waits for it to finish.
    pub async fn shutdown(&self) {
        info!("Shutting down heartbeat service");
        self.closed.store(true, Ordering::Relaxed);
        // Take the handle out first so the `std::sync::Mutex` guard is
        // released before awaiting. Holding it across `.await` makes this
        // future `!Send` and blocks any other locker while the watcher stops.
        let watcher = match self.watcher_handle.lock() {
            Ok(mut guard) => guard.take(),
            Err(e) => {
                error!(
                    "Failed to acquire watcher handle lock during shutdown: {}",
                    e
                );
                None
            }
        };
        // Wait for the watcher task to complete
        if let Some(handle) = watcher {
            let _ = handle.await;
        }
    }
}

dispatch_rpc_service_functions!(HeartbeatService);
service_with_id!(HeartbeatService, DEFAULT_SERVICE_ID);

#[derive(Debug)]
struct Member {
    pub address: String,
    pub groups: HashSet,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MemberGroup {
    members: BTreeSet,
    leader: Option,
    name: String,
}

/// Membership service manages member groups and heartbeat status.
///
/// IMPORTANT: This service does NOT persist its state to disk. On each restart,
/// it starts with empty state and rebuilds membership through:
/// 1. Members sending join() commands
/// 2. Heartbeat ping() messages updating online/offline status
/// 3. Group membership operations (join_group, leave_group, etc.)
///
/// This design ensures membership always reflects current network reality,
/// not stale persisted state that might be outdated after crashes.
pub struct Membership {
    heartbeat: Arc,
    groups: BTreeMap,
    members: BTreeMap,
    callback: Option,
    version: u64,
}

impl Drop for Membership {
    /// Flags the heartbeat service as closed so the watcher loop exits.
    fn drop(&mut self) {
        self.heartbeat.closed.store(true, Ordering::Relaxed)
    }
}

impl Membership {
    /// Creates a new Membership service with fresh, empty state.
    ///
    /// The service will discover members through:
    /// - join() commands from members joining the cluster
    /// - ping() heartbeats indicating member liveness
    /// - join_group/leave_group commands for group management
    ///
    /// No state is recovered from disk - all membership is learned from the network.
pub async fn new(server: &Arc, raft_service: &Arc) { let service = Arc::new(HeartbeatService { status: PtrHashMap::with_capacity(32), closed: AtomicBool::new(false), raft_service: raft_service.clone(), was_leader: AtomicBool::new(false), watcher_handle: std::sync::Mutex::new(None), }); let service_clone = service.clone(); let service_for_task = service.clone(); let handle = raft_service.rt.spawn(async move { info!("Starting membership heartbeat watcher (fresh state, learning from network)"); while !service_for_task.closed.load(Ordering::Relaxed) { let service = &service_for_task; let start_time = get_time(); let is_leader = service.raft_service.is_leader(); let was_leader = service.was_leader.load(Ordering::Relaxed); if !was_leader && is_leader { // Transferred leader will skip checking all member timeout for once service.transfer_leadership().await } if was_leader != is_leader { service.was_leader.store(is_leader, Ordering::Relaxed); } if is_leader { trace!("Resync Membership as leader id {}", service.raft_service.id); let mut outdated_members: Vec = Vec::new(); let mut back_in_members: Vec = Vec::new(); { let all_entries = service.status.entries(); let mut members_to_update = vec![]; for (id, mut status) in all_entries { let last_updated = status.last_updated; let alive = (start_time < last_updated) || ((start_time - last_updated) < MAX_TIMEOUT); let time_since_last_change = start_time - status.last_state_change; // Finding new offline servers (with grace period) if status.online && !alive { status.consecutive_failures += 1; status.consecutive_successes = 0; // Only mark offline after multiple consecutive failures AND minimum interval if status.consecutive_failures >= OFFLINE_GRACE_CHECKS && time_since_last_change >= MIN_STATE_CHANGE_INTERVAL { warn!( "Marking member {} as offline after {} consecutive timeout checks ({}ms since last update)", id, status.consecutive_failures, start_time - last_updated ); status.online = false; status.last_state_change = 
start_time; status.consecutive_failures = 0; outdated_members.push(id); } else { debug!( "Member {} timeout check {}/{} ({}ms since last update, {}ms since last state change)", id, status.consecutive_failures, OFFLINE_GRACE_CHECKS, start_time - last_updated, time_since_last_change ); } members_to_update.push((id, status)); } // Finding new online servers (with grace period) else if !status.online && alive { status.consecutive_successes += 1; status.consecutive_failures = 0; // Only mark online after multiple consecutive successes AND minimum interval if status.consecutive_successes >= ONLINE_GRACE_CHECKS && time_since_last_change >= MIN_STATE_CHANGE_INTERVAL { info!( "Marking member {} as back online after {} consecutive successful checks", id, status.consecutive_successes ); status.online = true; status.last_state_change = start_time; status.consecutive_successes = 0; back_in_members.push(id); } else { debug!( "Member {} recovery check {}/{} ({}ms since last state change)", id, status.consecutive_successes, ONLINE_GRACE_CHECKS, time_since_last_change ); } members_to_update.push((id, status)); } // Member is consistently online or offline else if alive { // Member is online and responsive - reset counters if status.consecutive_failures > 0 || status.consecutive_successes > 0 { status.consecutive_failures = 0; status.consecutive_successes = 0; members_to_update.push((id, status)); } } } for (id, s) in members_to_update { service.status.insert(id, s); } } if back_in_members.len() + outdated_members.len() > 0 { debug!( "Update member state machine for {} online, {} offline", back_in_members.len(), outdated_members.len() ); service .update_raft(&back_in_members, &outdated_members) .await; } } let end_time = get_time(); let time_took = end_time - start_time; let interval = 500; // in ms if time_took < interval { let time_to_wait = interval - time_took; trace!( "Membership resync completed, waiting for {}ms for next resync", time_to_wait ); 
async_time::sleep(std_time::Duration::from_millis(time_to_wait as u64)).await } else { trace!( "Membership resync completed, left behine {}ms for next resync", time_took - interval ); } } info!("Membership heartbeat watcher stopped gracefully"); }); // Store the handle for graceful shutdown match service.watcher_handle.lock() { Ok(mut guard) => { *guard = Some(handle); } Err(e) => { error!( "Failed to acquire watcher handle lock during initialization: {}", e ); } } // Create membership service with EMPTY state. // It will learn all membership from the network through: // 1. join() commands as members join // 2. ping() heartbeats for liveness tracking // 3. Group operations (join_group, leave_group, etc.) let mut membership_service = Membership { heartbeat: service_clone.clone(), groups: BTreeMap::new(), // Empty groups - will be populated as groups are created members: BTreeMap::new(), // Empty members - will be populated as members join callback: None, version: 0, // Version starts at 0 }; membership_service.init_callback(raft_service).await; raft_service .register_state_machine(Box::new(membership_service)) .await; server.register_service(&service_clone).await; } async fn compose_client_member(&self, id: u64) -> Option { let member = self.members.get(&id)?; let status = self.heartbeat.status.get(&id)?; Some(ClientMember { id, address: member.address.clone(), online: status.online, }) } async fn init_callback(&mut self, raft_service: &Arc) { self.callback = Some(SMCallback::new(self.id(), raft_service.clone()).await); } async fn notify_for_member_online(&self, id: u64) { debug!("Notifying member {} online", id); let client_member = match self.compose_client_member(id).await { Some(member) => member, None => { error!( "Failed to compose client member {} for online notification", id ); return; } }; let version = self.version; cb_notify( &self.callback, commands::on_any_member_online::new(), || (client_member.clone(), version), ) .await; if let Some(ref member) = 
self.members.get(&id) { for group in &member.groups { cb_notify( &self.callback, commands::on_group_member_online::new(group), || (client_member.clone(), version), ) .await; } } } async fn notify_for_member_offline(&self, id: u64) { debug!("Notifying member {} offline", id); let client_member = match self.compose_client_member(id).await { Some(member) => member, None => { error!( "Failed to compose client member {} for offline notification", id ); return; } }; let version = self.version; cb_notify( &self.callback, commands::on_any_member_offline::new(), || (client_member.clone(), version), ) .await; if let Some(ref member) = self.members.get(&id) { for group in &member.groups { cb_notify( &self.callback, commands::on_group_member_offline::new(group), || (client_member.clone(), version), ) .await; } } } async fn notify_for_member_left(&self, id: u64) { debug!("Notifying member {} left", id); let client_member = match self.compose_client_member(id).await { Some(member) => member, None => { error!( "Failed to compose client member {} for left notification", id ); return; } }; let version = self.version; cb_notify(&self.callback, commands::on_any_member_left::new(), || { (client_member.clone(), version) }) .await; if let Some(ref member) = self.members.get(&id) { for group in &member.groups { self.notify_for_group_member_left(*group, &client_member) .await } } } async fn notify_for_group_member_left(&self, group: u64, member: &ClientMember) { debug!("Notifying member {:?} left group {}", member, group); cb_notify( &self.callback, commands::on_group_member_left::new(&group), || (member.clone(), self.version), ) .await; } async fn leave_group_(&mut self, group_id: u64, id: u64, need_notify: bool) -> bool { let mut success = false; if let Some(ref mut group) = self.groups.get_mut(&group_id) { if let Some(ref mut member) = self.members.get_mut(&id) { group.members.remove(&id); member.groups.remove(&group_id); success = true; } } if success { if need_notify { if let 
Some(client_member) = self.compose_client_member(id).await { self.notify_for_group_member_left(group_id, &client_member) .await; } else { error!( "Failed to compose client member {} for group {} leave notification", id, group_id ); } } self.group_leader_candidate_unavailable(group_id, id).await; true } else { false } } fn member_groups(&self, member: u64) -> Option> { if let Some(member) = self.members.get(&member) { Some(member.groups.clone()) } else { None } } async fn group_first_online_member_id(&self, group: u64) -> Result, ()> { if let Some(group) = self.groups.get(&group) { for member in group.members.iter() { if let Some(member_stat) = self.heartbeat.status.get(&member) { if member_stat.online { return Ok(Some(*member)); } } } Ok(None) } else { Err(()) } } async fn change_leader(&mut self, group_id: u64, new: Option) -> Result<(), ()> { let mut old: Option = None; let mut changed = false; if let Some(group) = self.groups.get_mut(&group_id) { old = group.leader; if old != new { group.leader = new; changed = true; } } if changed { let version = self.version; let old_leader = if let Some(id_opt) = old { self.compose_client_member(id_opt).await } else { None }; let new_leader = if let Some(id_opt) = new { self.compose_client_member(id_opt).await } else { None }; cb_notify( &self.callback, commands::on_group_leader_changed::new(&group_id), move || (old_leader, new_leader, version), ) .await; Ok(()) } else { Err(()) } } async fn group_leader_candidate_available(&mut self, group_id: u64, member: u64) { // if the group does not have a leader, assign the available member let mut leader_changed = false; if let Some(group) = self.groups.get_mut(&group_id) { if group.leader == None { leader_changed = true; } } if leader_changed { if let Err(_) = self.change_leader(group_id, Some(member)).await { error!( "Failed to change leader for group {} to member {}", group_id, member ); } } } async fn group_leader_candidate_unavailable(&mut self, group_id: u64, member: u64) { // 
if the group have a leader that is the same as the member, reelect let mut reelected = false; if let Some(group) = self.groups.get_mut(&group_id) { if group.leader == Some(member) { reelected = true; } } if reelected { match self.group_first_online_member_id(group_id).await { Ok(online_id) => { if let Err(_) = self.change_leader(group_id, online_id).await { error!("Failed to change leader for group {} after member {} became unavailable", group_id, member); } } Err(_) => { error!("Failed to find online member for group {} after member {} became unavailable", group_id, member); } } } } async fn leader_candidate_available(&mut self, member: u64) { if let Some(groups) = self.member_groups(member) { for group in groups { self.group_leader_candidate_available(group, member).await } } } async fn leader_candidate_unavailable(&mut self, member: u64) { if let Some(groups) = self.member_groups(member) { for group in groups { self.group_leader_candidate_unavailable(group, member).await } } } } impl StateMachineCmds for Membership { fn hb_online_changed(&mut self, online: Vec, offline: Vec) -> BoxFuture<()> { debug!( "Member status changed, back online {}, gone offline {}", online.len(), offline.len() ); async move { self.version += 1; let current_time = time::get_time(); { for id in &online { if let Some(mut stat) = self.heartbeat.status.get(&id) { stat.online = true; stat.last_state_change = current_time; // Reset counters after state change is confirmed stat.consecutive_failures = 0; stat.consecutive_successes = 0; self.heartbeat.status.insert(*id, stat); } } for id in &offline { if let Some(mut stat) = self.heartbeat.status.get(&id) { stat.online = false; stat.last_state_change = current_time; // Reset counters after state change is confirmed stat.consecutive_failures = 0; stat.consecutive_successes = 0; self.heartbeat.status.insert(*id, stat); } } } for id in online { self.notify_for_member_online(id).await; self.leader_candidate_available(id).await; } for id in offline { 
self.notify_for_member_offline(id).await;
                self.leader_candidate_unavailable(id).await;
            }
        }
        .boxed()
    }

    /// Registers the member at `address` (id = hash of the address) and
    /// resets its heartbeat status to online. Returns the id only when the
    /// member was not registered before.
    fn join(&mut self, address: String) -> BoxFuture> {
        async move {
            self.version += 1;
            let id = hash_str(&address);
            let mut joined = false;
            {
                let current_time = time::get_time();
                // `address` is not needed afterwards, so it is moved into the
                // new record instead of being cloned.
                self.members.entry(id).or_insert_with(|| {
                    joined = true;
                    Member {
                        address,
                        groups: HashSet::new(),
                    }
                });
                self.heartbeat.status.insert(
                    id,
                    HBStatus {
                        last_updated: current_time,
                        online: true,
                        consecutive_failures: 0,
                        consecutive_successes: 0,
                        last_state_change: current_time,
                    },
                );
            }
            if joined {
                match self.compose_client_member(id).await {
                    Some(composed_client_member) => {
                        cb_notify(
                            &self.callback,
                            commands::on_any_member_joined::new(),
                            || (composed_client_member, self.version),
                        )
                        .await;
                        Some(id)
                    }
                    None => {
                        error!("Failed to compose client member {} after join", id);
                        None
                    }
                }
            } else {
                None
            }
        }
        .boxed()
    }

    /// Removes member `id` from all of its groups and from the membership.
    /// Returns `false` when the member is unknown.
    fn leave(&mut self, id: u64) -> BoxFuture {
        async move {
            // Snapshot the group ids first: the loop below mutates the member.
            let groups: Vec<_> = match self.members.get(&id) {
                Some(member) => member.groups.iter().copied().collect(),
                None => return false,
            };
            self.version += 1;
            self.notify_for_member_left(id).await;
            for group_id in groups {
                self.leave_group_(group_id, id, false).await;
            }
            // leader_candidate_unavailable is not called here because it has
            // already been triggered by leave_group_ in the loop above
            self.heartbeat.status.remove(&id);
            self.members.remove(&id);
            true
        }
        .boxed()
    }

    /// Adds member `id` to the group named `group_name`, creating the group
    /// when it does not exist. Returns `false` when the member is unknown.
    fn join_group(&mut self, group_name: String, id: u64) -> BoxFuture {
        async move {
            let group_id = hash_str(&group_name);
            self.version += 1;
            let mut success = false;
            // create group if not exists
            if !self.groups.contains_key(&group_id) {
                if let Err(existing_id) = self.new_group(group_name.clone()).await {
                    debug!(
                        "Group {} already exists with id {}",
                        group_name, existing_id
                    );
                }
            }
            if let Some(group) = self.groups.get_mut(&group_id) {
                if let Some(member) = self.members.get_mut(&id) {
                    group.members.insert(id);
                    member.groups.insert(group_id);
                    success = true;
                }
            }
            if success {
                match self.compose_client_member(id).await {
                    Some(composed_member) => {
                        cb_notify(
                            &self.callback,
                            commands::on_group_member_joined::new(&group_id),
                            || (composed_member, self.version),
                        )
                        .await;
                        self.group_leader_candidate_available(group_id, id).await;
                        true
                    }
                    None => {
                        error!(
                            "Failed to compose client member {} for group {} join notification",
                            id, group_id
                        );
                        false
                    }
                }
            } else {
                false
            }
        }
        .boxed()
    }

    /// Removes member `id` from group `group_id` and notifies subscribers.
    fn leave_group(&mut self, group_id: u64, id: u64) -> BoxFuture {
        async move {
            self.version += 1;
            self.leave_group_(group_id, id, true).await
        }
        .boxed()
    }

    /// Creates an empty group named `name` (id = hash of the name).
    /// Returns `Ok(id)` when created and `Err(id)` when it already exists.
    fn new_group(&mut self, name: String) -> BoxFuture> {
        async move {
            self.version += 1;
            let id = hash_str(&name);
            let mut inserted = false;
            // `name` is not needed afterwards, so it is moved into the group.
            self.groups.entry(id).or_insert_with(|| {
                inserted = true;
                MemberGroup {
                    members: BTreeSet::new(),
                    leader: None,
                    name,
                }
            });
            if inserted {
                Ok(id)
            } else {
                Err(id)
            }
        }
        .boxed()
    }

    /// Deletes group `id` and drops it from the group set of each of its
    /// members. Returns `false` when the group does not exist.
    fn del_group(&mut self, id: u64) -> BoxFuture {
        async move {
            self.version += 1;
            // Removing the group up front hands over ownership of its member
            // set, so no clone and no second lookup are needed.
            match self.groups.remove(&id) {
                Some(group) => {
                    for member_id in group.members {
                        if let Some(member) = self.members.get_mut(&member_id) {
                            member.groups.remove(&id);
                        }
                    }
                    true
                }
                None => false,
            }
        }
        .boxed()
    }

    /// Returns the leader of `group_id` (if any) with the current version;
    /// `None` when the group does not exist.
    fn group_leader(&self, group_id: u64) -> BoxFuture, u64)>> {
        async move {
            if let Some(group) = self.groups.get(&group_id) {
                Some((
                    match group.leader {
                        Some(id) => self.compose_client_member(id).await,
                        None => None,
                    },
                    self.version,
                ))
            } else {
                None
            }
        }
        .boxed()
    }

    /// Returns the members of `group` (only online ones when `online_only`)
    /// with the current version; `None` when the group does not exist.
    fn group_members(
        &self,
        group: u64,
        online_only: bool,
    ) -> BoxFuture, u64)>> {
        async move {
            if let Some(group) = self.groups.get(&group) {
                let futs: FuturesUnordered<_> = group
                    .members
                    .iter()
                    .map(|id| self.compose_client_member(*id))
                    .collect();
                let members: Vec<_> = futs.collect().await;
                Some((
                    members
                        .into_iter()
                        .flatten()
                        .filter(|member| !online_only || member.online)
                        .collect(),
                    self.version,
                ))
            } else {
                None
            }
        }
        .boxed()
    }

    /// Returns all members (only online ones when `online_only`) with the
    /// current version.
    fn all_members(&self, online_only: bool) -> BoxFuture<(Vec, u64)> {
        async move {
            let futs: FuturesUnordered<_> = self
                .members
                .iter()
                .map(|(id, _)| self.compose_client_member(*id))
                .collect();
            let members: Vec<_> = futs.collect().await;
            (
                members
                    .into_iter()
                    .flatten()
                    .filter(|member| !online_only || member.online)
                    .collect(),
                self.version,
            )
        }
        .boxed()
    }

    /// Returns a copy of all groups keyed by group id.
    fn all_groups(&self) -> BoxFuture> {
        future::ready(self.groups.clone()).boxed()
    }
}

impl StateMachineCtl for Membership {
    raft_sm_complete!();
    fn id(&self) -> u64 {
        DEFAULT_SERVICE_ID
    }
    fn snapshot(&self) -> Vec {
        // Membership service intentionally does NOT persist its state.
        // It starts fresh on each restart and learns membership from the network
        // via heartbeats and join/leave commands.
        // This ensures membership reflects current network reality, not stale disk state.
        // NOTE(review): presumably the raft layer checks `recoverable()` before
        // calling this — confirm, since this panics if it is ever reached.
        unreachable!()
    }
    fn recover(&mut self, _: Vec) -> BoxFuture<()> {
        // Membership service does not recover from snapshots.
        // It rebuilds its state from network discovery and heartbeats.
future::ready(()).boxed()
    }
    fn recoverable(&self) -> bool {
        false
    }
}

================================================
FILE: src/plugins/Cargo.toml
================================================
[package]
name = "bifrost_plugins"
version = "0.1.0"
authors = ["Hao Shi "]

[lib]
proc-macro = true

[dependencies]
bifrost_hasher = { path = "../hasher" }
syn = "2"
quote = "1"

================================================
FILE: src/plugins/src/lib.rs
================================================
extern crate bifrost_hasher;
extern crate proc_macro;
extern crate syn;

use bifrost_hasher::hash_str;
use proc_macro::TokenStream;
use proc_macro::TokenTree;
use syn::{parse_macro_input, LitStr};

/// Expands to the `hash_str` value of a single identifier or string literal,
/// emitted as an integer literal.
///
/// # Panics
///
/// Panics (a compile error at the call site) when the input is not exactly
/// one token.
#[proc_macro]
pub fn hash_ident(item: TokenStream) -> TokenStream {
    // Iterating consumes the stream, so iterate over a copy and keep `item`
    // for the string-literal parse below.
    let tokens: Vec<_> = item.clone().into_iter().collect();
    if tokens.len() != 1 {
        panic!(
            "argument should be a single identifier, but got {} arguments {:?}",
            tokens.len(),
            tokens
        );
    }
    let text = match tokens[0] {
        TokenTree::Ident(ref ident) => ident.to_string(),
        // Anything that is not an identifier must parse as a string literal;
        // `parse_macro_input!` returns its own compile error otherwise.
        _ => parse_macro_input!(item as LitStr).value(),
    };
    // Hash the text directly; no intermediate copy of the string is needed.
    hash_str(&text)
        .to_string()
        .parse()
        .expect("a formatted integer hash is always a valid token stream")
}

================================================
FILE: src/proc_macro/Cargo.toml
================================================
[package]
name = "bifrost_proc_macro"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[lib]
proc-macro = true

[dependencies]
syn = { version = "*", features = ["extra-traits"] }
quote = "*"
proc-macro2 = "*"

================================================
FILE: src/proc_macro/src/lib.rs
================================================
use proc_macro::TokenStream;
use quote::quote;
use syn::{
    parse::{Parse, ParseBuffer, ParseStream},
    parse_macro_input,
punctuated::Punctuated, FnArg, Ident, ItemTrait, Lifetime, Pat, PatType, Result, Token, TraitItem, TraitItemFn, Type, TypeReference, TypeTuple, }; struct Args { args: Punctuated, } impl Parse for Args { fn parse(input: ParseStream) -> Result { let args = Punctuated::parse_terminated(input)?; Ok(Args { args }) } } #[proc_macro] pub fn adjust_caller_identifiers(input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as Args); let output = input .args .into_iter() .map(|arg| match arg { FnArg::Typed(pat_type) => { let pat = &*pat_type.pat; let ty = &*pat_type.ty; match (&pat, ty) { (Pat::Ident(pat_ident), Type::Reference(_)) => { let ident = &pat_ident.ident; quote! { ref #ident } } (Pat::Ident(pat_ident), Type::Group(group)) => { let ident = &pat_ident.ident; if let Type::Reference(_) = &*group.elem { quote! { ref #ident } } else { quote! { #ident } } } (Pat::Ident(pat_ident), _) => { let ident = &pat_ident.ident; quote! { #ident } } _ => panic!("Unsupported pattern!"), } } _ => panic!("Variadic arguments are not supported!"), }) .collect::>(); quote! 
{ ( #(#output),* ) }
    .into()
}

/// Re-emits a trait method with the lifetime `'a` written on every reference
/// that has none.
///
/// The input is parsed as a `TraitItemFn`. Reference-typed arguments,
/// reference types wrapped in a `Type::Group`, and a `&self` / `&mut self`
/// receiver are rewritten; lifetimes that are already present are kept.
#[proc_macro]
pub fn adjust_function_signature(input: TokenStream) -> TokenStream {
    let mut trait_fn = parse_macro_input!(input as TraitItemFn);
    for arg in trait_fn.sig.inputs.iter_mut() {
        match arg {
            FnArg::Typed(pat_type) => match &mut *pat_type.ty {
                Type::Reference(reference) => name_elided_lifetime(&mut reference.lifetime),
                Type::Group(group) => {
                    if let Type::Reference(reference) = &mut *group.elem {
                        name_elided_lifetime(&mut reference.lifetime);
                    }
                }
                _ => {}
            },
            FnArg::Receiver(receiver) => {
                if let Some((_, lifetime)) = &mut receiver.reference {
                    name_elided_lifetime(lifetime);
                }
            }
            _ => {}
        }
    }
    quote!(#trait_fn).into()
}

/// Fills an empty lifetime slot with `'a`; an explicit lifetime is left alone.
fn name_elided_lifetime(slot: &mut Option<Lifetime>) {
    if slot.is_none() {
        *slot = Some(Lifetime::new("'a", proc_macro2::Span::call_site()));
    }
}

#[proc_macro]
pub fn deref_tuple_types(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as TypeTuple);
    let transformed_types: Vec<_> = input
        .elems
        .into_iter()
        .map(|ty| match ty {
            Type::Reference(TypeReference { elem, .. }) => *elem,
            Type::Group(group) => {
                if let Type::Reference(TypeReference { elem, .. }) = *group.elem {
                    *elem
                } else {
                    Type::Group(group)
                }
            }
            other => other,
        })
        .collect();
    let tokens = quote!
{ (#(#transformed_types),*) }; tokens.into() } ================================================ FILE: src/raft/client.rs ================================================ use super::*; use crate::raft::state_machine::callback::client::SubscriptionService; use crate::raft::state_machine::callback::SubKey; use crate::raft::state_machine::configs::commands::{ del_member_ as conf_del_member, member_address as conf_member_address, new_member_ as conf_new_member, subscribe as conf_subscribe, unsubscribe as conf_unsubscribe, }; use crate::raft::state_machine::master::ExecError; use crate::raft::state_machine::StateMachineClient; use crate::rpc; use bifrost_hasher::{hash_bytes, hash_str}; use futures::future::BoxFuture; use std::clone::Clone; use std::cmp::max; use std::collections::{BTreeMap, HashMap, HashSet}; use std::iter::FromIterator; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::Duration; use tokio::time::sleep; const ORDERING: Ordering = Ordering::Relaxed; pub type Client = Arc; pub type SubscriptionReceipt = (SubKey, u64); lazy_static! 
{ pub static ref CALLBACK: RwLock>> = RwLock::new(None); } #[derive(Debug)] pub enum ClientError { LeaderIdValid, ServerUnreachable, } impl std::fmt::Display for ClientError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ClientError::LeaderIdValid => write!(f, "leader id is invalid"), ClientError::ServerUnreachable => write!(f, "seed nodes are unreachable"), } } } impl std::error::Error for ClientError {} #[derive(Debug)] pub enum SubscriptionError { RemoteError, SubServiceNotSet, CannotFindSubId, } struct PlaneClientState { pos: AtomicU64, leader_id: AtomicU64, last_log_id: AtomicU64, last_log_term: AtomicU64, } struct Members { clients: BTreeMap, id_map: HashMap, } pub trait AsRaftPlaneClient: Send + Sync { fn as_raft_plane_client(self: &Arc) -> Arc; } #[derive(Clone)] pub struct RaftPlaneClient { client: Arc, plane_id: PlaneId, } impl RaftPlaneClient { pub fn plane_id(&self) -> PlaneId { self.plane_id } pub async fn execute(&self, sm_id: u64, msg: M) -> Result where R: 'static, M: RaftMsg + 'static, { self.client .execute_on_plane(self.plane_id, sm_id, msg) .await } pub async fn subscribe( &self, sm_id: u64, msg: M, f: F, ) -> Result, ExecError> where M: RaftMsg + 'static, R: 'static + Send, F: Fn(R) -> BoxFuture<'static, ()> + 'static + Send + Sync, { self.client .subscribe_on_plane(self.plane_id, sm_id, msg, f) .await } pub async fn unsubscribe( &self, receipt: SubscriptionReceipt, ) -> Result, ExecError> { self.client.unsubscribe(receipt).await } pub async fn cluster_info(&self) -> Result { self.client.cluster_info_on_plane(self.plane_id).await } pub async fn have_state_machine(&self, sm_id: u64) -> Result { self.client .have_state_machine_on_plane(self.plane_id, sm_id) .await } } pub struct RaftClient { members: RwLock, type1_state: Arc, plane_states: RwLock>>, service_id: u64, } impl AsRaftPlaneClient for RaftClient { fn as_raft_plane_client(self: &Arc) -> Arc { self.plane(PlaneId::type1()) } } impl AsRaftPlaneClient for 
RaftPlaneClient { fn as_raft_plane_client(self: &Arc) -> Arc { self.clone() } } impl RaftClient { fn new_plane_state() -> Arc { Arc::new(PlaneClientState { pos: AtomicU64::new(rand::random::()), leader_id: AtomicU64::new(0), last_log_id: AtomicU64::new(0), last_log_term: AtomicU64::new(0), }) } pub async fn new(servers: &Vec, service_id: u64) -> Result, ClientError> { let client = RaftClient { members: RwLock::new(Members { clients: BTreeMap::new(), id_map: HashMap::new(), }), type1_state: Self::new_plane_state(), plane_states: RwLock::new(HashMap::new()), service_id, }; client.update_info(servers).await?; Ok(Arc::new(client)) } async fn plane_state(&self, plane_id: PlaneId) -> Arc { if plane_id.is_type1() { return self.type1_state.clone(); } { let states = self.plane_states.read().await; if let Some(state) = states.get(&plane_id) { return state.clone(); } } let mut states = self.plane_states.write().await; states .entry(plane_id) .or_insert_with(Self::new_plane_state) .clone() } pub fn plane(self: &Arc, plane_id: PlaneId) -> Arc { Arc::new(RaftPlaneClient { client: self.clone(), plane_id, }) } pub fn type1(self: &Arc) -> Arc { self.plane(PlaneId::type1()) } pub async fn add_root_member(&self, address: &String) -> Result { self.execute(CONFIG_SM_ID, conf_new_member::new(address)) .await } pub async fn remove_root_member(&self, address: &String) -> Result<(), ExecError> { self.execute(CONFIG_SM_ID, conf_del_member::new(address)) .await } pub async fn root_member_addresses(&self) -> Result, ExecError> { self.execute(CONFIG_SM_ID, conf_member_address::new()).await } pub async fn prepare_subscription(server: &Arc) -> Option<()> { let mut callback = CALLBACK.write().await; return if callback.is_none() { let sub_service = SubscriptionService::initialize(&server).await; *callback = Some(sub_service.clone()); Some(()) } else { None }; } async fn cluster_info<'a>( &'a self, plane_id: PlaneId, servers: &Vec, ) -> Option { debug!( "Getting server info for plane {} from {:?}", 
plane_id.raw(), servers ); let mut attempt_remains: i32 = 10; loop { debug!( "Trying to get cluster info for plane {}, attempt from {:?}...{}", plane_id.raw(), servers, attempt_remains ); let mut futs: FuturesUnordered<_> = servers .iter() .map(|server_addr| { let id = hash_str(server_addr); let server_addr = server_addr.clone(); async move { let mut members = self.members.write().await; debug!( "Checking server info for plane {} on {}", plane_id.raw(), server_addr ); if !members.clients.contains_key(&id) { debug!( "Connecting to node {} for plane {}", server_addr, plane_id.raw() ); match rpc::DEFAULT_CLIENT_POOL.get(&server_addr).await { Ok(client) => { debug!( "Added server info on {} to members for plane {}", server_addr, plane_id.raw() ); members.clients.insert( id, AsyncServiceClient::new_with_service_id( self.service_id, &client, ), ); debug!( "Member {} added for plane {}", server_addr, plane_id.raw() ); } Err(e) => { warn!( "Cannot find server info for plane {} from {}, {}", plane_id.raw(), server_addr, e ); return None; } } } debug!( "Getting server info for plane {} from {}, id {}", plane_id.raw(), server_addr, id ); let member_client = match members.clients.get(&id) { Some(client) => client, None => { debug!( "Server not found for plane {}, skip {}, id {}", plane_id.raw(), server_addr, id ); return None; } }; debug!( "Invoking server_cluster_info for plane {} on {}, id {}", plane_id.raw(), server_addr, id ); let info_res = member_client.c_server_cluster_info(plane_id).await; debug!( "Checking cluster info response for plane {} from {}", plane_id.raw(), server_addr ); return match info_res { Ok(info) => { if info.leader_id != 0 { debug!( "Found server info for plane {} with leader id {}", plane_id.raw(), info.leader_id ); Some(info) } else { debug!( "Discovered zero leader id for plane {} from {}", plane_id.raw(), server_addr ); None } } Err(e) => { debug!( "Error on getting cluster info for plane {} from {}, {:?}", plane_id.raw(), server_addr, e ); None 
} }; } }) .collect(); while let Some(res) = futs.next().await { if let Some(info) = res { return Some(info); } } if attempt_remains > 0 { // We found an uninitialized node, should try again // Random sleep debug!( "Plane {} fail attempt had zero leader id, retry...{}", plane_id.raw(), attempt_remains ); let delay_sec = 1 + (rand::random::() % 9); sleep(Duration::from_secs(delay_sec)).await; attempt_remains -= 1; continue; } else { debug!( "Continuously getting zero leader id for plane {}, give up", plane_id.raw() ); break; } } warn!( "Cannot find anything useful for plane {} from list: {:?}", plane_id.raw(), servers ); return None; } async fn update_info(&self, servers: &Vec) -> Result<(), ClientError> { debug!( "Updating cluster info for plane {} from servers: {:?}", PlaneId::type1().raw(), servers ); let cluster_info = self.cluster_info(PlaneId::type1(), servers).await; match cluster_info { Some(info) => { let mut members = self.members.write().await; let remote_members = info.members; let mut remote_ids = HashSet::with_capacity(remote_members.len()); members.id_map.clear(); for (id, addr) in remote_members { members.id_map.insert(id, addr); remote_ids.insert(id); } let mut connected_ids = HashSet::with_capacity(members.clients.len()); for id in members.clients.keys() { connected_ids.insert(*id); } let ids_to_remove = connected_ids.difference(&remote_ids); for id in ids_to_remove { warn!( "Removed server with id {} while refreshing plane {}", id, PlaneId::type1().raw() ); members.clients.remove(id); } for id in remote_ids.difference(&connected_ids) { let addr = match members.id_map.get(id) { Some(addr) => addr.clone(), None => { error!( "Cannot find address for server id {} while refreshing plane {}", id, PlaneId::type1().raw() ); continue; } }; if !members.clients.contains_key(id) { if let Ok(client) = rpc::DEFAULT_CLIENT_POOL.get(&addr).await { info!( "Having new server addr {} id {} for plane {}", addr, id, PlaneId::type1().raw() ); members.clients.insert( 
*id, AsyncServiceClient::new_with_service_id(self.service_id, &client), ); } else { error!( "Cannot connect to new server addr {}, id {} for plane {}", addr, id, PlaneId::type1().raw() ); } } } info!( "UPDATE_INFO Setting plane {} leader to {}, was {}", PlaneId::type1().raw(), info.leader_id, self.type1_state.leader_id.load(Relaxed) ); self.type1_state.leader_id.store(info.leader_id, ORDERING); swap_when_greater(&self.type1_state.last_log_id, info.last_log_id); swap_when_greater(&self.type1_state.last_log_term, info.last_log_term); Ok(()) } None => { error!( "Cannot update info for plane {}, cannot get cluster info", PlaneId::type1().raw() ); Err(ClientError::ServerUnreachable) } } } async fn update_plane_info( &self, plane_id: PlaneId, servers: &Vec, ) -> Result<(), ClientError> { if plane_id.is_type1() { return self.update_info(servers).await; } let cluster_info = self.cluster_info(plane_id, servers).await; match cluster_info { Some(info) => { let state = self.plane_state(plane_id).await; info!( "UPDATE_INFO Setting plane {} leader to {}, was {}", plane_id.raw(), info.leader_id, state.leader_id.load(Relaxed) ); state.leader_id.store(info.leader_id, ORDERING); swap_when_greater(&state.last_log_id, info.last_log_id); swap_when_greater(&state.last_log_term, info.last_log_term); Ok(()) } None => { error!( "Cannot update info for plane {}, cannot get cluster info", plane_id.raw() ); Err(ClientError::ServerUnreachable) } } } pub async fn probe_servers( servers: &Vec, server_address: &String, service_id: u64, ) -> bool { servers .iter() .map(|peer_addr| { timeout(Duration::from_secs(2), async move { if peer_addr == server_address { // Should not include the server we are running return false; } match rpc::DEFAULT_CLIENT_POOL.get(peer_addr).await { Ok(client) => ImmeServiceClient::c_ping(service_id, &client).await.is_ok(), Err(_) => false, } }) }) .collect::>() .collect::>() .await .into_iter() .any(|r| match r { Ok(true) => true, _ => false, }) } pub async fn 
execute(&self, sm_id: u64, msg: M) -> Result where R: 'static, M: RaftMsg + 'static, { self.execute_on_plane(PlaneId::type1(), sm_id, msg).await } pub async fn execute_on_plane( &self, plane_id: PlaneId, sm_id: u64, msg: M, ) -> Result where R: 'static, M: RaftMsg + 'static, { let (fn_id, op, req_data) = msg.encode(); let response = match op { OpType::QUERY => self.query_on_plane(plane_id, sm_id, fn_id, req_data).await, OpType::COMMAND | OpType::SUBSCRIBE => { self.command_on_plane(plane_id, sm_id, fn_id, req_data) .await } }; match response { Ok(data) => match data { Ok(data) => Ok(M::decode_return(&data)), Err(e) => Err(e), }, Err(e) => Err(e), } } pub async fn can_callback() -> bool { CALLBACK.read().await.is_some() } fn get_sub_key(&self, plane_id: PlaneId, sm_id: u64, msg: M) -> SubKey where M: RaftMsg + 'static, R: 'static, { let raft_sid = self.service_id; let (fn_id, pattern_id) = { let (fn_id, _, pattern_data) = msg.encode(); (fn_id, hash_bytes(pattern_data.as_slice())) }; SubKey::new(raft_sid, plane_id, sm_id, fn_id, pattern_id) } pub async fn get_callback(&self) -> Result, SubscriptionError> { match CALLBACK.read().await.clone() { None => { debug!("Subscription service not set"); Err(SubscriptionError::SubServiceNotSet) } Some(c) => Ok(c), } } pub async fn subscribe( &self, sm_id: u64, msg: M, f: F, ) -> Result, ExecError> where M: RaftMsg + 'static, R: 'static + Send, F: Fn(R) -> BoxFuture<'static, ()> + 'static + Send + Sync, { self.subscribe_on_plane(PlaneId::type1(), sm_id, msg, f) .await } pub async fn subscribe_on_plane( &self, plane_id: PlaneId, sm_id: u64, msg: M, f: F, ) -> Result, ExecError> where M: RaftMsg + 'static, R: 'static + Send, F: Fn(R) -> BoxFuture<'static, ()> + 'static + Send + Sync, { let callback = match self.get_callback().await { Ok(c) => c, Err(e) => return Ok(Err(e)), }; let key = self.get_sub_key(plane_id, sm_id, msg); let wrapper_fn = move |data: Vec| -> BoxFuture<'static, ()> { f(M::decode_return(&data)).boxed() }; let 
cluster_subs = self
            .execute_on_plane(
                plane_id,
                CONFIG_SM_ID,
                conf_subscribe::new(&key, &callback.server_address, &callback.session_id),
            )
            .await;
        match cluster_subs {
            Ok(Ok(sub_id)) => {
                let mut subs_map = callback.subs.write().await;
                let subs_lst = subs_map.entry(key).or_insert_with(|| Vec::new());
                let boxed_fn = Box::new(wrapper_fn);
                subs_lst.push((boxed_fn, sub_id));
                Ok(Ok((key, sub_id)))
            }
            Ok(Err(_)) => Ok(Err(SubscriptionError::RemoteError)),
            Err(e) => Err(e),
        }
    }

    /// Cancels the subscription identified by `receipt`.
    ///
    /// Sends the unsubscribe command to the config state machine on the
    /// receipt's plane, then removes the matching local callback.
    ///
    /// Returns `Ok(Ok(()))` when a local callback with the receipt's
    /// subscription id was removed, `Ok(Err(CannotFindSubId))` when none is
    /// registered under the receipt's key, and `Ok(Err(_))` with the
    /// `get_callback` error when the subscription service is not set.
    ///
    /// # Errors
    ///
    /// Returns the `ExecError` from executing the unsubscribe command.
    pub async fn unsubscribe(
        &self,
        receipt: SubscriptionReceipt,
    ) -> Result, ExecError> {
        let callback = match self.get_callback().await {
            Ok(callback) => callback,
            Err(e) => {
                debug!("Subscription service not set");
                return Ok(Err(e));
            }
        };
        let (key, sub_id) = receipt;
        self.execute_on_plane(key.plane_id, CONFIG_SM_ID, conf_unsubscribe::new(&sub_id))
            .await?;
        let mut subs_map = callback.subs.write().await;
        // Look the list up with `get_mut` rather than `entry`, so an unknown
        // key does not leave an empty list behind in the map.
        let removed = match subs_map.get_mut(&key) {
            Some(subs_lst) => match subs_lst.iter().position(|sub| sub.1 == sub_id) {
                Some(sub_index) => {
                    let _ = subs_lst.remove(sub_index);
                    true
                }
                None => false,
            },
            None => false,
        };
        if removed {
            Ok(Ok(()))
        } else {
            Ok(Err(SubscriptionError::CannotFindSubId))
        }
    }

    async fn query_on_plane(
        &self,
        plane_id: PlaneId,
        sm_id: u64,
        fn_id: u64,
        data: Vec,
    ) -> Result {
        let state = self.plane_state(plane_id).await;
        let mut depth = 0;
        loop {
            if depth == 0 {
                trace!(
                    "Raft client query plane_id={} sm_id {}, fn_id {}",
                    plane_id.raw(),
                    sm_id,
                    fn_id
                );
            } else {
                warn!(
                    "Retry client query plane_id={} sm_id {}, fn_id {}",
                    plane_id.raw(),
                    sm_id,
                    fn_id
                );
            }
            let pos = state.pos.fetch_add(1, ORDERING);
            let members = self.members.read().await;
            let num_members = members.clients.len();
            if num_members >= 1 {
                let node_index = pos as usize % num_members;
                let rpc_client = match members.clients.values().nth(node_index) {
                    Some(client) => client,
                    None => {
                        error!(
                            "Cannot find client for plane {} at index {} (total: {})",
plane_id.raw(), node_index, num_members ); return Err(ExecError::ServersUnreachable); } }; trace!( "Query for plane {} from node {} for sm_id {}, fn_id {}", plane_id.raw(), node_index, sm_id, fn_id ); let res = rpc_client .c_query(plane_id, &self.gen_log_entry(&state, sm_id, fn_id, &data)) .await; trace!( "Query for plane {} from node {} for sm_id {}, fn_id {} completed", plane_id.raw(), node_index, sm_id, fn_id ); match res { Ok(res) => match res { ClientQryResponse::LeftBehind { last_log_term, last_log_id, } => { debug!("Found left behind record on plane {}...{}, updating client state: server has log_id={}, term={}", plane_id.raw(), depth, last_log_id, last_log_term); // Update client state from server to avoid retry loop swap_when_greater(&state.last_log_id, last_log_id); swap_when_greater(&state.last_log_term, last_log_term); // Add a small delay to allow server to catch up if under stress if depth > 0 { sleep(Duration::from_millis(50)).await; } if depth >= num_members { error!("Too many retry on query for plane {}, num_members {}, due to left behind record {}", plane_id.raw(), num_members, depth); return Err(ExecError::TooManyRetry); } else { depth += 1; continue; } } ClientQryResponse::Success { data, last_log_term, last_log_id, } => { swap_when_greater(&state.last_log_id, last_log_id); swap_when_greater(&state.last_log_term, last_log_term); if depth > 0 { warn!("Retry successful on plane {}...{}", plane_id.raw(), depth); } trace!("Query for plane {} from node {} for sm_id {}, fn_id {}, successful at log id {}, term {}", plane_id.raw(), node_index, sm_id, fn_id, last_log_id, last_log_term); return Ok(data); } }, Err(e) => { error!( "Got unknown error on query for plane {}: {:?}, server {}", plane_id.raw(), e, rpc_client.client.address ); if depth >= num_members { return Err(ExecError::Unknown); } else { debug!("Retry query on plane {}...{}", plane_id.raw(), depth); depth += 1; continue; } } } } else { return Err(ExecError::ServersUnreachable); } } } async fn 
command_on_plane( &self, plane_id: PlaneId, sm_id: u64, fn_id: u64, data: Vec, ) -> Result { const NOT_COMMITTED_RETRY_DELAY_MS: u64 = 10; const UPDATE_INFO_RETRY_DELAY_MS: u64 = 10; let not_committed_retry_limit = std::cmp::max( 64, ((HEARTBEAT_MS * 2) / NOT_COMMITTED_RETRY_DELAY_MS as i64) as i32, ); let update_info_retry_limit = std::cmp::max( 64, ((HEARTBEAT_MS * 2) / UPDATE_INFO_RETRY_DELAY_MS as i64) as i32, ); enum FailureAction { SwitchLeader, NotCommitted, UpdateInfo, NotLeader, ShuttingDown, } let state = self.plane_state(plane_id).await; let mut leader_retry_depth = 0; let mut not_committed_depth = 0; let mut update_info_depth = 0; loop { let failure = { if leader_retry_depth > 0 { let members = self.members.read().await; let num_members = members.clients.len(); if leader_retry_depth >= max(num_members + 1, 5) { error!( "Too many retry on command for plane {}, num_members {}, due to leader retry attempts {}", plane_id.raw(), num_members, leader_retry_depth ); return Err(ExecError::TooManyRetry); }; } match self.preferred_client_on_plane(&state).await { Some((leader_id, client)) => { let cmd_res = client .c_command(plane_id, self.gen_log_entry(&state, sm_id, fn_id, &data)) .await; match cmd_res { Ok(ClientCmdResponse::Success { data, last_log_term, last_log_id, }) => { swap_when_greater(&state.last_log_id, last_log_id); swap_when_greater(&state.last_log_term, last_log_term); return Ok(data); } Ok(ClientCmdResponse::NotLeader(new_leader_id)) => { if new_leader_id == 0 || new_leader_id == leader_id { warn!( "RAFTDBG_V2 client plane_id={} notleader-zero_or_same leader_id={} suggested={} depth={} update_info_depth={} not_committed_depth={}", plane_id.raw(), leader_id, new_leader_id, leader_retry_depth, update_info_depth, not_committed_depth ); debug!( "CLIENT plane_id={}: NOT LEADER, SUGGESTION NOT USEFUL, REFRESH INFO. 
GOT: {}", plane_id.raw(), new_leader_id ); FailureAction::UpdateInfo } else { warn!( "RAFTDBG_V3 client plane_id={} notleader-redirect current_leader={} suggested_leader={} depth={}", plane_id.raw(), leader_id, new_leader_id, leader_retry_depth ); debug!( "CLIENT plane_id={}: NOT LEADER, REMOTE SUGGEST SWITCH TO {}", plane_id.raw(), new_leader_id ); info!( "CMD Setting plane {} leader to {}, was {}", plane_id.raw(), new_leader_id, state.leader_id.load(Relaxed) ); state.leader_id.store(new_leader_id, ORDERING); FailureAction::NotLeader } } Ok(ClientCmdResponse::NotCommitted { last_log_term, last_log_id, }) => { debug!( "CLIENT plane_id={}: NOT COMMITTED at leader {}, refreshing client log cursor to term {}, id {}", plane_id.raw(), leader_id, last_log_term, last_log_id ); swap_when_greater(&state.last_log_id, last_log_id); swap_when_greater(&state.last_log_term, last_log_term); FailureAction::NotCommitted } Ok(ClientCmdResponse::ShuttingDown) => FailureAction::ShuttingDown, Err(e) => { warn!( "RAFTDBG_V3 client plane_id={} transport_or_rpc_error leader_id={} depth={} error={:?}", plane_id.raw(), leader_id, leader_retry_depth, e ); debug!( "CLIENT plane_id={}: ERROR - {} - {:?}", plane_id.raw(), leader_id, e ); FailureAction::SwitchLeader // need switch server for leader } } } None => { warn!("Need update members for plane {}", plane_id.raw()); FailureAction::UpdateInfo } } }; // match failure { FailureAction::NotCommitted => { not_committed_depth += 1; update_info_depth = 0; if not_committed_depth >= not_committed_retry_limit { error!( "Too many retry on command for plane {} due to NotCommitted responses {}", plane_id.raw(), not_committed_depth ); return Err(ExecError::TooManyRetry); } debug!( "Retrying command for plane {} after NotCommitted response {}/{}", plane_id.raw(), not_committed_depth, not_committed_retry_limit ); sleep(Duration::from_millis(NOT_COMMITTED_RETRY_DELAY_MS)).await; continue; } FailureAction::UpdateInfo => { update_info_depth += 1; 
not_committed_depth = 0; warn!( "RAFTDBG_V2 client plane_id={} update_info_retry count={} depth={}", plane_id.raw(), update_info_depth, leader_retry_depth ); if update_info_depth >= update_info_retry_limit { error!( "Too many retry on command for plane {} due to cluster-info refresh attempts {}", plane_id.raw(), update_info_depth ); return Err(ExecError::TooManyRetry); } let servers = { let members = self.members.read().await; Vec::from_iter(members.id_map.values().cloned()) }; if servers.is_empty() { warn!( "Cannot refresh cluster info for plane {}: no known servers", plane_id.raw() ); return Err(ExecError::ServersUnreachable); } debug!( "Refreshing cluster info for plane {} after transient NotLeader/leader-miss {}/{} from {:?}", plane_id.raw(), update_info_depth, update_info_retry_limit, servers ); if let Err(e) = self.update_plane_info(plane_id, &servers).await { warn!( "Failed to refresh cluster info for plane {} during command retry: {:?}", plane_id.raw(), e ); } sleep(Duration::from_millis(UPDATE_INFO_RETRY_DELAY_MS)).await; continue; } FailureAction::SwitchLeader => { not_committed_depth = 0; update_info_depth = 0; leader_retry_depth += 1; warn!( "RAFTDBG_V3 client plane_id={} switch_leader depth={}", plane_id.raw(), leader_retry_depth ); debug!("Switch leader for plane {} by probing", plane_id.raw()); let members = self.members.read().await; let num_members = members.clients.len(); let leader_id = state.leader_id.load(ORDERING); let new_leader_id = match members .clients .keys() .nth(leader_retry_depth as usize % num_members) { Some(id) => *id, None => { error!( "Cannot find new leader for plane {} at index {} (total: {})", plane_id.raw(), leader_retry_depth as usize % num_members, num_members ); return Err(ExecError::ServersUnreachable); } }; let leadder_switch = state.leader_id.compare_exchange( leader_id, new_leader_id, ORDERING, Relaxed, ); info!( "SWITCH plane {} exchange leader to {}, was {:?}", plane_id.raw(), new_leader_id, leadder_switch ); debug!( 
"CLIENT plane_id={}: Switch leader {}", plane_id.raw(), new_leader_id ); } FailureAction::NotLeader => { leader_retry_depth += 1; not_committed_depth = 0; update_info_depth = 0; continue; } FailureAction::ShuttingDown => { return Err(ExecError::ShuttingDown); } } } } fn gen_log_entry( &self, state: &PlaneClientState, sm_id: u64, fn_id: u64, data: &Vec, ) -> LogEntry { LogEntry { id: state.last_log_id.load(ORDERING), term: state.last_log_term.load(ORDERING), sm_id, fn_id, data: data.clone(), } } pub fn leader_id(&self) -> u64 { self.type1_state.leader_id.load(ORDERING) } pub async fn leader_client(&self) -> Option<(u64, Client)> { self.current_leader_client_on_plane(PlaneId::type1(), &self.type1_state) .await } async fn any_known_client(&self) -> Option { let members = self.members.read().await; members.clients.values().next().cloned() } async fn preferred_client_on_plane(&self, state: &PlaneClientState) -> Option<(u64, Client)> { if let Some((leader_id, client)) = self.leader_client_on_plane(state).await { return Some((leader_id, client)); } self.any_known_client().await.map(|client| (0, client)) } async fn leader_client_on_plane(&self, state: &PlaneClientState) -> Option<(u64, Client)> { let members = self.members.read().await; let leader_id = state.leader_id.load(ORDERING); if let Some(client) = members.clients.get(&leader_id) { Some((leader_id, client.clone())) } else { None } } async fn current_leader_client_on_plane( &self, plane_id: PlaneId, state: &PlaneClientState, ) -> Option<(u64, Client)> { { let leader_client = self.leader_client_on_plane(state).await; if leader_client.is_some() { return leader_client; } } debug!( "Obtaining leader client for plane {} by updating cluster info", plane_id.raw() ); { let servers = { let members = self.members.read().await; Vec::from_iter(members.id_map.values().cloned()) }; if let Err(e) = self.update_plane_info(plane_id, &servers).await { error!( "Failed to update cluster info for plane {}: {:?}", plane_id.raw(), e ); 
return None;
            }
            let leader_id = state.leader_id.load(ORDERING);
            let members = self.members.read().await;
            if let Some(client) = members.clients.get(&leader_id) {
                debug!(
                    "Obtained leader client for plane {} with id: {}",
                    plane_id.raw(),
                    leader_id
                );
                Some((leader_id, client.clone()))
            } else {
                warn!(
                    "Cannot obtain leader client for plane {} with id {}. Having {:?}",
                    plane_id.raw(),
                    leader_id,
                    members.clients.keys().collect::>()
                );
                None
            }
        }
    }

    pub async fn current_leader_rpc_client(&self) -> Result, ()> {
        let (_, client) = self
            .current_leader_client_on_plane(PlaneId::type1(), &self.type1_state)
            .await
            .ok_or_else(|| ())?;
        Ok(client.client.clone())
    }

    pub async fn cluster_info_on_plane(
        &self,
        plane_id: PlaneId,
    ) -> Result {
        let state = self.plane_state(plane_id).await;
        let client = match self.leader_client_on_plane(&state).await {
            Some((_, client)) => client,
            None => self
                .any_known_client()
                .await
                .ok_or(ExecError::ServersUnreachable)?,
        };
        client
            .c_server_cluster_info(plane_id)
            .await
            .map_err(|_| ExecError::ServersUnreachable)
    }

    pub async fn have_state_machine_on_plane(
        &self,
        plane_id: PlaneId,
        sm_id: u64,
    ) -> Result {
        let state = self.plane_state(plane_id).await;
        let client = match self.leader_client_on_plane(&state).await {
            Some((_, client)) => client,
            None => self
                .any_known_client()
                .await
                .ok_or(ExecError::ServersUnreachable)?,
        };
        client
            .c_have_state_machine(plane_id, sm_id)
            .await
            .map_err(|_| ExecError::ServersUnreachable)
    }
}

/// Raises `atomic` to `value` when `value` is larger; otherwise leaves it
/// unchanged.
///
/// Uses a single atomic max with the file's `ORDERING`, so concurrent callers
/// can never lower the stored value.
fn swap_when_greater(atomic: &AtomicU64, value: u64) {
    atomic.fetch_max(value, ORDERING);
}

pub struct CachedStateMachine {
    server_list: Vec,
    raft_service_id: u64,
    plane_id: PlaneId,
    state_machine_id: u64,
    cache: RwLock>>,
}

impl CachedStateMachine {
    pub fn new(
        server_list: &Vec,
        raft_service_id: u64,
        plane_id: PlaneId,
        state_machine_id: u64,
    ) -> Self {
debug!( "Construct cached state machine for list {:?}, service id {}, plane {}, state machine {}", server_list, raft_service_id, plane_id.raw(), state_machine_id ); Self { server_list: server_list.clone(), raft_service_id, plane_id, state_machine_id, cache: RwLock::new(None), } } pub async fn get(&self) -> Arc { loop { { let client = self.cache.read().await; if let Some(cache) = &*client { return (*cache).clone(); } } { let mut place_holder = self.cache.write().await; if place_holder.is_none() { debug!( "Creating state machine client instance, service {}, state machine id {}", self.raft_service_id, self.state_machine_id ); let raft_client = match RaftClient::new(&self.server_list, self.raft_service_id).await { Ok(client) => client, Err(e) => { error!( "Failed to create RaftClient for service {} and sm {}: {:?}", self.raft_service_id, self.state_machine_id, e ); // Drop the lock and retry after a delay drop(place_holder); sleep(Duration::from_millis(100)).await; continue; } }; let plane_client = raft_client.plane(self.plane_id); // Create a client for the state machine on the raft service *place_holder = Some(Arc::new(T::new_instance( self.state_machine_id, &plane_client, ))) } } } } } ================================================ FILE: src/raft/disk.rs ================================================ // Now only offers log persistent use crate::raft::{LogEntry, LogsMap, Options, PlaneId, RaftMeta, SnapshotEntity, Storage}; use async_std::sync::*; use std::convert::TryInto; use std::fs::OpenOptions; use std::io; use std::io::{Read, Seek, SeekFrom}; use std::ops::Bound::*; use std::path::{Path, PathBuf}; use tokio::fs::*; use tokio::io::*; // const MAX_LOG_CAPACITY: usize = 10; #[derive(Clone)] pub struct DiskOptions { pub path: String, pub take_snapshots: bool, pub append_logs: bool, pub trim_logs: bool, // Snapshot configuration pub snapshot_log_threshold: u64, // Trigger snapshot after N logs pub log_compaction_threshold: u64, // Compact when logs exceed this } 
impl DiskOptions { pub fn new(path: String) -> Self { Self { path, take_snapshots: true, append_logs: true, trim_logs: true, snapshot_log_threshold: 1000, log_compaction_threshold: 2000, } } } pub struct StorageEntity { pub logs: Option, pub snapshot: Option, pub last_term: u64, pub base_path: PathBuf, pub plane_id: PlaneId, } pub struct DiskLogEntry { pub term: u64, pub commit_index: u64, pub last_applied: u64, pub log: LogEntry, } impl DiskLogEntry { /// Encode to deterministic binary format with CRC32 checksum. /// /// On-disk record layout (written by `append_logs`): /// [8 bytes] record length (= 4 + payload length, does NOT include these 8 bytes) /// [4 bytes] CRC32 of payload /// [N bytes] payload: /// [8 bytes] term /// [8 bytes] commit_index /// [8 bytes] last_applied /// [8 bytes] log.id /// [8 bytes] log.term /// [8 bytes] log.sm_id /// [8 bytes] log.fn_id /// [8 bytes] log.data.len() /// [M bytes] log.data /// /// `encode()` returns only the payload (the CRC and length prefix are added by the caller). pub fn encode(&self) -> Vec { let data_len = self.log.data.len(); let total_size = 8 * 8 + data_len; let mut buf = Vec::with_capacity(total_size); buf.extend_from_slice(&self.term.to_le_bytes()); buf.extend_from_slice(&self.commit_index.to_le_bytes()); buf.extend_from_slice(&self.last_applied.to_le_bytes()); buf.extend_from_slice(&self.log.id.to_le_bytes()); buf.extend_from_slice(&self.log.term.to_le_bytes()); buf.extend_from_slice(&self.log.sm_id.to_le_bytes()); buf.extend_from_slice(&self.log.fn_id.to_le_bytes()); buf.extend_from_slice(&(data_len as u64).to_le_bytes()); buf.extend_from_slice(&self.log.data); buf } /// Decode payload bytes (without the length prefix or CRC — the caller strips those). 
pub fn decode(data: &[u8]) -> io::Result { if data.len() < 64 { return Err(io::Error::new( io::ErrorKind::InvalidData, format!("DiskLogEntry too short: {} bytes", data.len()), )); } let term = u64::from_le_bytes(data[0..8].try_into().unwrap()); let commit_index = u64::from_le_bytes(data[8..16].try_into().unwrap()); let last_applied = u64::from_le_bytes(data[16..24].try_into().unwrap()); let log_id = u64::from_le_bytes(data[24..32].try_into().unwrap()); let log_term = u64::from_le_bytes(data[32..40].try_into().unwrap()); let log_sm_id = u64::from_le_bytes(data[40..48].try_into().unwrap()); let log_fn_id = u64::from_le_bytes(data[48..56].try_into().unwrap()); let data_len = u64::from_le_bytes(data[56..64].try_into().unwrap()) as usize; if data.len() < 64 + data_len { return Err(io::Error::new( io::ErrorKind::InvalidData, format!("DiskLogEntry data truncated: expected {}, got {}", 64 + data_len, data.len()), )); } let log_data = data[64..64 + data_len].to_vec(); Ok(DiskLogEntry { term, commit_index, last_applied, log: LogEntry { id: log_id, term: log_term, sm_id: log_sm_id, fn_id: log_fn_id, data: log_data, }, }) } } impl StorageEntity { pub fn new_with_options( opts: &Options, term: &mut u64, commit_index: &mut u64, last_applied: &mut u64, logs: &mut LogsMap, ) -> io::Result> { Self::new_with_options_on_plane( PlaneId::type1(), opts, term, commit_index, last_applied, logs, ) } pub fn new_with_options_on_plane( plane_id: PlaneId, opts: &Options, term: &mut u64, commit_index: &mut u64, last_applied: &mut u64, logs: &mut LogsMap, ) -> io::Result> { Ok(match &opts.storage { &Storage::DISK(ref options) => { let base_path = Path::new(&options.path); let _ = std::fs::create_dir_all(base_path); let log_path = base_path.join("log.dat"); let snapshot_path = base_path.join("snapshot.dat"); let mut open_opts = OpenOptions::new(); open_opts .write(true) .create(true) .read(true) .truncate(false); let mut last_log_id = 0; let mut storage = Self { logs: if options.append_logs { let 
mut log_file = open_opts.open(log_path.as_path())?; let mut len_buf = [0u8; 8]; let mut crc_buf = [0u8; 4]; let mut counter = 0; let mut last_valid_pos: u64 = 0; loop { let pos_before = log_file.seek(SeekFrom::Current(0)) .unwrap_or(last_valid_pos); if log_file.read_exact(&mut len_buf).is_err() { break; } let record_len = u64::from_le_bytes(len_buf); // record_len = 4 (CRC) + payload_len if record_len < 4 { warn!("WAL corrupt: invalid record length {} at pos {}, truncating", record_len, pos_before); break; } let payload_len = record_len - 4; if log_file.read_exact(&mut crc_buf).is_err() { warn!("WAL truncated: missing CRC at pos {}, truncating", pos_before); break; } let expected_crc = u32::from_le_bytes(crc_buf); let mut data_buf = vec![0u8; payload_len as usize]; if log_file.read_exact(&mut data_buf).is_err() { warn!("WAL truncated: missing payload at pos {}, truncating", pos_before); break; } let actual_crc = crc32fast::hash(&data_buf); if actual_crc != expected_crc { warn!( "WAL CRC mismatch at pos {}: expected {:#010x}, got {:#010x}, truncating", pos_before, expected_crc, actual_crc ); break; } match DiskLogEntry::decode(&data_buf) { Ok(entry) => { *term = entry.term; // Do not trust commit/last_applied embedded in WAL for SM reconstruction // We'll derive commit_index from commit.idx and force replay from last_applied=0 last_log_id = entry.log.id; logs.insert(entry.log.id, entry.log); counter += 1; last_valid_pos = log_file.seek(SeekFrom::Current(0)) .unwrap_or(last_valid_pos); } Err(e) => { warn!("WAL decode error at pos {}: {:?}, truncating", pos_before, e); break; } } } // Truncate WAL at last valid entry to remove any corrupt tail, // then seek to end so appends start at the correct position let current_len = log_file.seek(SeekFrom::End(0)).unwrap_or(last_valid_pos); if current_len > last_valid_pos { info!("WAL has corrupt tail ({} extra bytes), truncating to {}", current_len - last_valid_pos, last_valid_pos); if let Err(e) = 
log_file.set_len(last_valid_pos) { warn!("Failed to truncate WAL to {} bytes: {:?}", last_valid_pos, e); } } let _ = log_file.seek(SeekFrom::End(0)); debug!( "Recovered {} raft logs for plane {}", counter, plane_id.raw() ); Some(File::from_std(log_file)) } else { None }, snapshot: if options.take_snapshots { Some(File::from_std(open_opts.open(snapshot_path.as_path())?)) } else { None }, last_term: 0, base_path: base_path.to_path_buf(), plane_id, }; storage.last_term = last_log_id; // If commit progress side file exists, load it to ensure accurate indices // Force full replay by resetting last_applied to 0 on startup *last_applied = 0; if let Ok(Some((ci, _la))) = futures::executor::block_on(storage.read_commit_progress()) { *commit_index = ci; debug!( "Recovered commit progress for plane {}: commit_index={} (will replay to rebuild state)", plane_id.raw(), ci ); } else { // If no commit progress found, default to 0 to avoid partial state *commit_index = 0; } Some(storage) } _ => None, }) } pub async fn append_logs<'a>( &mut self, meta: &'a RwLockWriteGuard<'a, RaftMeta>, logs: &'a RwLockWriteGuard<'a, LogsMap>, ) -> io::Result<()> { if let Some(f) = &mut self.logs { let was_last_term = self.last_term; let mut counter = 0; let mut terms_appended = vec![]; let master = meta.state_machine.read().await; for (term, log) in logs.range((Excluded(self.last_term), Unbounded)) { // Skip non-recoverable state machines if !master.is_recoverable(log.sm_id) { continue; } let entry = DiskLogEntry { term: *term, commit_index: meta.commit_index, last_applied: meta.last_applied, log: log.clone(), }; let entry_data = entry.encode(); let checksum = crc32fast::hash(&entry_data); // Write: [8 bytes record_len = 4+payload_len][4 bytes CRC32][N bytes payload] let record_len = 4u64 + entry_data.len() as u64; f.write_all(&record_len.to_le_bytes()).await?; f.write_all(&checksum.to_le_bytes()).await?; f.write_all(entry_data.as_slice()).await?; self.last_term = *term; 
terms_appended.push(self.last_term); counter += 1; } if counter > 0 { f.sync_all().await?; debug!( "Appended and persisted {} logs for plane {}, was {}, appended {:?}", counter, self.plane_id.raw(), was_last_term, terms_appended ); } } Ok(()) } pub async fn post_processing<'a>( &mut self, meta: &RwLockWriteGuard<'a, RaftMeta>, logs: RwLockWriteGuard<'a, LogsMap>, ) -> io::Result<()> { // TODO: trim logs in memory // TODO: trim logs on disk self.append_logs(meta, &logs).await?; Ok(()) // let (last_log_id, _) = get_last_log_info!(self, logs); // let expecting_oldest_log = if last_log_id > MAX_LOG_CAPACITY as u64 { // last_log_id - MAX_LOG_CAPACITY as u64 // } else { // 0 // }; // let double_cap = MAX_LOG_CAPACITY << 1; // if logs.len() > double_cap && meta.last_applied > expecting_oldest_log { // debug!("trim logs"); // while logs.len() > MAX_LOG_CAPACITY { // let first_key = *logs.iter().next().unwrap().0; // logs.remove(&first_key).unwrap(); // } // if let Some(ref storage) = meta.storage { // let mut storage = storage.write().await; // let snapshot = SnapshotEntity { // term: meta.term, // commit_index: meta.commit_index, // last_applied: meta.last_applied, // snapshot: meta.state_machine.read().await.snapshot().unwrap(), // }; // storage // .snapshot // .write_all(crate::utils::serde::serialize(&snapshot).as_slice())?; // storage.snapshot.sync_all().unwrap(); // } // } // if let Some(ref storage) = meta.storage { // let mut storage = storage.write().await; // let logs_data = crate::utils::serde::serialize(&*meta.logs.read().await); // // TODO: async file system calls // storage.logs.write_all(logs_data.as_slice())?; // storage.logs.sync_all().unwrap(); // } } /// Ensure WAL file is fully synced to disk. 
pub async fn flush_wal(&mut self) -> io::Result<()> { if let Some(f) = &mut self.logs { info!( "WAL fsync for plane {}: syncing log.dat to disk", self.plane_id.raw() ); f.sync_all().await?; info!("WAL fsync for plane {}: completed", self.plane_id.raw()); } Ok(()) } /// Persist commit progress atomically to a side file (commit.idx) pub async fn write_commit_progress( &mut self, commit_index: u64, last_applied: u64, ) -> io::Result<()> { let commit_path = self.base_path.join("commit.idx"); let temp_path = self.base_path.join("commit.idx.tmp"); let mut f = File::create(&temp_path).await?; f.write_all(&commit_index.to_le_bytes()).await?; f.write_all(&last_applied.to_le_bytes()).await?; f.sync_all().await?; drop(f); std::fs::rename(&temp_path, &commit_path)?; Ok(()) } /// Read commit progress if available pub async fn read_commit_progress(&self) -> io::Result> { let commit_path = self.base_path.join("commit.idx"); if !commit_path.exists() { return Ok(None); } let mut f = File::open(&commit_path).await?; let mut buf = [0u8; 16]; if f.read_exact(&mut buf).await.is_err() { return Ok(None); } let commit_index = u64::from_le_bytes(buf[0..8].try_into().unwrap()); let last_applied = u64::from_le_bytes(buf[8..16].try_into().unwrap()); Ok(Some((commit_index, last_applied))) } /// Write snapshot to disk using atomic write pattern (temp file + rename) pub async fn write_snapshot(&mut self, snapshot: &SnapshotEntity) -> io::Result<()> { let snapshot_path = self.base_path.join("snapshot.dat"); let temp_path = self.base_path.join("snapshot.dat.tmp"); // Serialize snapshot let snapshot_data = crate::utils::serde::serialize(snapshot); // Calculate CRC32 checksum let checksum = crc32fast::hash(&snapshot_data); // Write to temp file let mut temp_file = File::create(&temp_path).await?; // Write checksum first temp_file.write_all(&checksum.to_le_bytes()).await?; // Write length temp_file .write_all(&(snapshot_data.len() as u64).to_le_bytes()) .await?; // Write data 
temp_file.write_all(&snapshot_data).await?; // Sync to disk temp_file.sync_all().await?; drop(temp_file); // Atomic rename std::fs::rename(&temp_path, &snapshot_path)?; info!( "Snapshot persisted to disk for plane {}: index={}, term={}, size={} bytes", self.plane_id.raw(), snapshot.last_included_index, snapshot.last_included_term, snapshot_data.len() ); Ok(()) } /// Read and validate snapshot from disk pub async fn read_snapshot(&self) -> io::Result> { let snapshot_path = self.base_path.join("snapshot.dat"); // Check if snapshot file exists if !snapshot_path.exists() { debug!( "No snapshot file found for plane {} at {:?}", self.plane_id.raw(), snapshot_path ); return Ok(None); } let mut file = File::open(&snapshot_path).await?; // Read checksum let mut checksum_buf = [0u8; 4]; if file.read_exact(&mut checksum_buf).await.is_err() { warn!( "Failed to read snapshot checksum for plane {}, file may be corrupted", self.plane_id.raw() ); return Ok(None); } let expected_checksum = u32::from_le_bytes(checksum_buf); // Read length let mut len_buf = [0u8; 8]; if file.read_exact(&mut len_buf).await.is_err() { warn!( "Failed to read snapshot length for plane {}, file may be corrupted", self.plane_id.raw() ); return Ok(None); } let len = u64::from_le_bytes(len_buf); // Read data let mut data_buf = vec![0u8; len as usize]; if file.read_exact(&mut data_buf).await.is_err() { warn!( "Failed to read snapshot data for plane {}, file may be corrupted", self.plane_id.raw() ); return Ok(None); } // Verify checksum let actual_checksum = crc32fast::hash(&data_buf); if actual_checksum != expected_checksum { error!( "Snapshot checksum mismatch on plane {}! Expected: {}, Got: {}. 
File is corrupted.", self.plane_id.raw(), expected_checksum, actual_checksum ); return Ok(None); } // Deserialize let snapshot = crate::utils::serde::deserialize::(&data_buf).unwrap(); info!( "Snapshot loaded from disk for plane {}: index={}, term={}, size={} bytes", self.plane_id.raw(), snapshot.last_included_index, snapshot.last_included_term, data_buf.len() ); Ok(Some(snapshot)) } } ================================================ FILE: src/raft/mod.rs ================================================ use self::state_machine::callback::server::Subscriptions; use self::state_machine::callback::SMCallback; use self::state_machine::configs::commands::new_member_; use self::state_machine::configs::{RaftMember, CONFIG_SM_ID}; use self::state_machine::master::{ExecError, ExecResult, MasterStateMachine, SubStateMachine}; use self::state_machine::OpType; use crate::raft::client::{ClientError, RaftClient}; use crate::raft::disk::*; use crate::raft::state_machine::StateMachineCtl; use crate::rpc; use crate::utils::time::get_time; use async_std::sync::*; use bifrost_hasher::hash_str; use bifrost_plugins::hash_ident; use futures::future::BoxFuture; use futures::prelude::*; use futures::stream::FuturesUnordered; use serde::{Deserialize, Serialize}; use std::cmp::{max, min}; use std::collections::Bound::{Included, Unbounded}; use std::collections::{BTreeMap, HashMap}; use std::fmt::{Display, Formatter}; use std::io; use std::path::Path; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering::Relaxed; use std::time::Duration; use tokio::runtime; use tokio::sync::{watch, Mutex as TokioMutex}; use tokio::time::*; #[macro_use] pub mod state_machine; pub mod client; pub mod disk; pub static DEFAULT_SERVICE_ID: u64 = hash_ident!(BIFROST_RAFT_DEFAULT_SERVICE) as u64; #[derive( Clone, Copy, Debug, Default, Deserialize, Eq, Hash, PartialEq, PartialOrd, Ord, Serialize, )] pub struct PlaneId(u64); #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum PlaneIdError { 
Type2PlaneMustBePositive, } impl PlaneId { pub const fn type1() -> Self { Self(0) } pub fn type2(raw: u64) -> Result { if raw == 0 { Err(PlaneIdError::Type2PlaneMustBePositive) } else { Ok(Self(raw)) } } pub const fn raw(self) -> u64 { self.0 } pub const fn is_type1(self) -> bool { self.0 == 0 } pub const fn is_type2(self) -> bool { self.0 > 0 } } impl Display for PlaneIdError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { PlaneIdError::Type2PlaneMustBePositive => { write!(f, "type-2 plane ids must be greater than zero") } } } } impl std::error::Error for PlaneIdError {} impl From for u64 { fn from(value: PlaneId) -> Self { value.raw() } } pub trait RaftMsg: Send + Sync { fn encode(self) -> (u64, OpType, Vec); fn decode_return(data: &Vec) -> R; } const CHECKER_MS: i64 = 200; const HEARTBEAT_MS: i64 = 1000; // Timeout for heartbeat task - increased to prevent timeouts under stress const HEARTBEAT_TASK_TIMEOUT_MS: i64 = 5000; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct LogEntry { pub id: u64, pub term: u64, pub sm_id: u64, pub fn_id: u64, pub data: Vec, } #[derive(Serialize, Deserialize, Debug, Clone)] pub enum ClientCmdResponse { Success { data: ExecResult, last_log_term: u64, last_log_id: u64, }, NotLeader(u64), NotCommitted { last_log_term: u64, last_log_id: u64, }, ShuttingDown, } #[derive(Serialize, Deserialize, Debug, Clone)] pub enum ClientQryResponse { Success { data: ExecResult, last_log_term: u64, last_log_id: u64, }, LeftBehind { last_log_term: u64, last_log_id: u64, }, } #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ClientClusterInfo { members: Vec<(u64, String)>, last_log_id: u64, last_log_term: u64, leader_id: u64, } #[derive(Serialize, Deserialize, Debug, Clone)] pub enum AppendEntriesResult { Ok, TermOut(u64), LogMismatch, } #[derive(Serialize, Deserialize, Clone, Debug)] pub struct SnapshotEntity { pub last_included_index: u64, pub last_included_term: u64, pub snapshot: Vec, } type LogEntries = 
Vec; type LogsMap = BTreeMap; service! { rpc append_entries(plane_id: PlaneId, term: u64, leader_id: u64, prev_log_id: u64, prev_log_term: u64, entries: &Option, leader_commit: u64) -> (u64, AppendEntriesResult); rpc request_vote(plane_id: PlaneId, term: u64, candidate_id: u64, last_log_id: u64, last_log_term: u64) -> ((u64, u64), bool); // term, voteGranted rpc install_snapshot(plane_id: PlaneId, term: u64, leader_id: u64, last_included_index: u64, last_included_term: u64, data: Vec) -> u64; rpc reelect(plane_id: PlaneId) -> bool; rpc c_command(plane_id: PlaneId, entry: LogEntry) -> ClientCmdResponse; rpc c_query(plane_id: PlaneId, entry: &LogEntry) -> ClientQryResponse; rpc c_server_cluster_info(plane_id: PlaneId) -> ClientClusterInfo; rpc c_put_offline() -> bool; rpc c_have_state_machine(plane_id: PlaneId, id: u64) -> bool; rpc c_ping(); } service_with_id!(RaftService, DEFAULT_SERVICE_ID); fn gen_rand(lower: i64, higher: i64) -> i64 { let span = (higher - lower).max(1) as u64; lower + (rand::random::() % span) as i64 } fn gen_timeout() -> i64 { gen_rand(10_000, 30_000) } struct FollowerStatus { next_index: u64, match_index: u64, } pub struct LeaderMeta { last_updated: i64, followers: HashMap>>, } impl LeaderMeta { fn new() -> LeaderMeta { LeaderMeta { last_updated: get_time(), followers: HashMap::new(), } } } pub enum Membership { Leader(RwLock), Follower, Candidate, Offline, Undefined, } #[derive(Clone, Copy, Debug, Eq, PartialEq)] enum LifecycleState { Running, Stopping, Stopped, } pub struct RaftMeta { term: u64, vote_for: Option, timeout: i64, last_checked: i64, membership: Membership, logs: Arc>, state_machine: Arc>, commit_index: u64, last_applied: u64, leader_id: u64, storage: Option>>, last_snapshot_index: u64, last_snapshot_term: u64, lifecycle: LifecycleState, } #[derive(Clone)] pub enum Storage { MEMORY, DISK(DiskOptions), } impl Storage { pub fn default() -> Storage { Storage::MEMORY } } #[derive(Clone)] pub struct Options { pub storage: Storage, pub 
address: String, pub service_id: u64, } #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct PlaneSpec { pub plane_id: PlaneId, } #[derive(Clone, Debug, Eq, PartialEq)] pub struct PlaneBootstrap { pub plane_id: PlaneId, pub seed_nodes: Vec, } #[derive(Debug)] pub enum PlaneError { PlaneNotFound(PlaneId), StorageInit(io::Error), InitializationFailed(PlaneId), } impl Display for PlaneError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { PlaneError::PlaneNotFound(plane_id) => { write!(f, "plane {} is not registered on this host", plane_id.raw()) } PlaneError::StorageInit(err) => write!(f, "failed to initialize plane storage: {err}"), PlaneError::InitializationFailed(plane_id) => { write!(f, "failed to initialize plane {}", plane_id.raw()) } } } } impl std::error::Error for PlaneError {} #[derive(Debug)] pub enum PlaneBootstrapError { Type1PlaneUnsupported, EmptySeedNodes, NoType1MembersDiscovered, LocalMemberMissing { local_address: String, }, MembershipConflict { plane_id: PlaneId, current_members: Vec, requested_members: Vec, }, MemberRegistrationRejected { address: String, }, NotLeader { plane_id: PlaneId, leader_id: u64, }, Client(ClientError), Plane(PlaneError), Exec(ExecError), } impl Display for PlaneBootstrapError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { PlaneBootstrapError::Type1PlaneUnsupported => { write!( f, "type-2 bootstrap via seed nodes is only supported for type-2 planes" ) } PlaneBootstrapError::EmptySeedNodes => { write!(f, "plane bootstrap requires at least one type-1 seed node") } PlaneBootstrapError::NoType1MembersDiscovered => { write!(f, "type-1 seed discovery returned no available servers") } PlaneBootstrapError::LocalMemberMissing { local_address } => { write!( f, "type-1 discovered members must include the local address {}", local_address ) } PlaneBootstrapError::MembershipConflict { plane_id, current_members, requested_members, } => { write!( f, "plane {} membership conflict: 
current={:?}, requested={:?}", plane_id.raw(), current_members, requested_members ) } PlaneBootstrapError::MemberRegistrationRejected { address } => { write!(f, "plane bootstrap rejected member {}", address) } PlaneBootstrapError::NotLeader { plane_id, leader_id, } => { write!( f, "node is not the leader for plane {} (leader_id={})", plane_id.raw(), leader_id ) } PlaneBootstrapError::Client(err) => { write!(f, "type-1 seed discovery failed: {err}") } PlaneBootstrapError::Plane(err) => write!(f, "{err}"), PlaneBootstrapError::Exec(err) => write!(f, "plane bootstrap command failed: {err}"), } } } impl std::error::Error for PlaneBootstrapError {} impl From for PlaneBootstrapError { fn from(value: PlaneError) -> Self { Self::Plane(value) } } impl From for PlaneBootstrapError { fn from(value: ExecError) -> Self { Self::Exec(value) } } impl From for PlaneBootstrapError { fn from(value: ClientError) -> Self { Self::Client(value) } } struct RaftPlaneRuntime { plane_id: PlaneId, meta: RwLock, is_leader: AtomicBool, checker_task: TokioMutex>>, shutdown_tx: watch::Sender, } impl RaftPlaneRuntime { fn new(opts: &Options, plane_id: PlaneId) -> Result { let mut term = 0; let mut logs = BTreeMap::new(); let mut commit_index = 0; let mut last_applied = 0; let plane_opts = options_for_plane(opts, plane_id); let storage_entity = StorageEntity::new_with_options_on_plane( plane_id, &plane_opts, &mut term, &mut commit_index, &mut last_applied, &mut logs, ) .map_err(PlaneError::StorageInit)?; let master_sm = MasterStateMachine::new_on_plane(plane_opts.service_id, plane_id); let (shutdown_tx, _shutdown_rx) = watch::channel(LifecycleState::Running); Ok(Self { plane_id, meta: RwLock::new(RaftMeta { term, vote_for: None, timeout: gen_timeout(), last_checked: get_time(), membership: Membership::Undefined, logs: Arc::new(RwLock::new(logs)), state_machine: Arc::new(RwLock::new(master_sm)), commit_index, last_applied, leader_id: 0, storage: storage_entity.map(|entity| 
Arc::new(Mutex::new(entity))), last_snapshot_index: 0, last_snapshot_term: 0, lifecycle: LifecycleState::Running, }), is_leader: AtomicBool::new(false), checker_task: TokioMutex::new(None), shutdown_tx, }) } } #[derive(Clone)] pub struct PlaneHandle { service: Arc, plane_id: PlaneId, } impl PlaneHandle { pub const fn id(&self) -> PlaneId { self.plane_id } pub async fn callback(&self, state_machine_id: u64) -> Result { SMCallback::new_on_plane(state_machine_id, self.plane_id, self.service.clone()).await } pub async fn register_state_machine( &self, state_machine: SubStateMachine, ) -> Result<(), PlaneError> { self.service .register_state_machine_on_plane(self.plane_id, state_machine) .await } pub async fn recover_after_register(&self) -> Result<(), PlaneError> { self.service .recover_after_register_on_plane(self.plane_id) .await } pub async fn cluster_info(&self) -> Result { self.service .cluster_info_on_plane_local(self.plane_id) .await } pub async fn have_state_machine(&self, sm_id: u64) -> Result { self.service .have_state_machine_on_plane_local(self.plane_id, sm_id) .await } pub async fn is_leader(&self) -> Result { self.service.is_leader_on_plane(self.plane_id).await } pub async fn flush_persistence(&self) -> Result<(), PlaneError> { self.service.flush_persistence_on_plane(self.plane_id).await } pub async fn shutdown(&self) -> Result<(), PlaneError> { self.service.shutdown_plane(self.plane_id).await } } fn options_for_plane(opts: &Options, plane_id: PlaneId) -> Options { let storage = match &opts.storage { Storage::MEMORY => Storage::MEMORY, Storage::DISK(disk_opts) => { let mut plane_disk_opts = disk_opts.clone(); if plane_id.is_type2() { plane_disk_opts.path = format!("{}/planes/{}", disk_opts.path, plane_id.raw()); } Storage::DISK(plane_disk_opts) } }; Options { storage, address: opts.address.clone(), service_id: opts.service_id, } } pub struct RaftService { meta: RwLock, planes: RwLock>>, pub id: u64, pub options: Options, pub rt: runtime::Runtime, 
// Fetches the last item yielded by `$logs.iter()` and passes it to
// `$s.get_log_info_`, evaluating to whatever that method returns.
//
// `$s`    - receiver that provides `get_log_info_`.
// `$logs` - collection whose iterator supports `next_back()`.
// NOTE(review): presumably `$logs` is a `LogsMap` guard keyed by log id, which
// would make `last_log` the entry with the highest id — confirm at call sites.
macro_rules! get_last_log_info {
    ($s: expr, $logs: expr) => {{
        // `None` when the collection is empty.
        let last_log = $logs.iter().next_back();
        $s.get_log_info_(last_log)
    }};
}
meta.last_applied = next_log_id; } } } } impl RaftService { pub const fn plane_id(&self) -> PlaneId { PlaneId::type1() } fn check_plane(&self, plane_id: PlaneId) { debug_assert_eq!(plane_id, self.plane_id()); if plane_id != self.plane_id() { warn!( "received raft request for plane {} on single-plane service {}", plane_id.raw(), self.id ); } } fn lifecycle_is_stopping(state: LifecycleState) -> bool { matches!(state, LifecycleState::Stopping | LifecycleState::Stopped) } async fn wait_for_apply_drain_for_meta( meta_lock: &RwLock, timeout_duration: Duration, ) -> bool { let deadline = Instant::now() + timeout_duration; loop { { let meta = meta_lock.read().await; if meta.commit_index == meta.last_applied { return true; } } if Instant::now() >= deadline { return false; } sleep(Duration::from_millis(50)).await; } } async fn start_managed_runtime(self: &Arc, runtime: Option>) { if let Some(runtime_ref) = runtime.as_ref() { let guard = runtime_ref.checker_task.lock().await; if guard.is_some() { return; } } else { let guard = self.checker_task.lock().await; if guard.is_some() { return; } } let server = self.clone(); let runtime_ref = runtime.clone(); let plane_id = runtime_ref .as_ref() .map(|runtime| runtime.plane_id) .unwrap_or_else(PlaneId::type1); let mut shutdown_rx = if let Some(runtime) = runtime_ref.as_ref() { runtime.shutdown_tx.subscribe() } else { server.shutdown_tx.subscribe() }; let handle = self.rt.spawn(async move { info!( "Starting Raft checker/heartbeat task for plane {}", plane_id.raw() ); loop { if Self::lifecycle_is_stopping(*shutdown_rx.borrow()) { break; } let start_time = get_time(); let expected_ends = start_time + CHECKER_MS; let heartbeat_task_continue = async { let mut meta = if let Some(runtime) = runtime_ref.as_ref() { runtime.meta.write().await } else { server.meta.write().await }; if Self::lifecycle_is_stopping(meta.lifecycle) { return false; } let current_time = get_time(); let mut is_leader = false; let action = match meta.membership { 
Membership::Leader(_) => { is_leader = true; if current_time >= meta.last_checked + HEARTBEAT_MS { CheckerAction::SendHeartbeat } else { CheckerAction::None } } Membership::Follower | Membership::Candidate => { debug_assert!(meta.timeout > 100); let timeout_time = meta.last_checked + meta.timeout; let time_remains = timeout_time - current_time; if meta.vote_for.is_none() && time_remains < 0 { CheckerAction::BecomeCandidate } else { CheckerAction::None } } Membership::Offline => CheckerAction::ExitLoop, Membership::Undefined => CheckerAction::None, }; if let Some(runtime) = runtime_ref.as_ref() { runtime.is_leader.store(is_leader, Relaxed); } else { server._is_leader.store(is_leader, Relaxed); } match action { CheckerAction::SendHeartbeat => { server .send_followers_heartbeat_on_plane( plane_id, &mut meta, None, false, ) .await; meta.last_checked = get_time(); } CheckerAction::BecomeCandidate => { let leader_flag = runtime_ref .as_ref() .map(|runtime| &runtime.is_leader) .unwrap_or(&server._is_leader); server .become_candidate_on_plane( plane_id, leader_flag, &mut meta, ) .await; } CheckerAction::ExitLoop => return false, CheckerAction::None => {} } true }; let timed_heartbeat = tokio::select! { changed = shutdown_rx.changed() => { match changed { Ok(_) if Self::lifecycle_is_stopping(*shutdown_rx.borrow()) => break, Ok(_) => continue, Err(_) => break, } } result = timeout( Duration::from_millis(HEARTBEAT_TASK_TIMEOUT_MS as u64), heartbeat_task_continue, ) => result, }; let end_time = get_time(); let time_to_sleep = expected_ends - end_time - 1; match timed_heartbeat { Err(_) => { error!( "Heartbeat cannot finish in time for {}ms on plane {}", HEARTBEAT_MS, plane_id.raw() ); } Ok(false) => break, Ok(true) => {} } if time_to_sleep > 0 { tokio::select! 
{ changed = shutdown_rx.changed() => { match changed { Ok(_) if Self::lifecycle_is_stopping(*shutdown_rx.borrow()) => break, Ok(_) => {} Err(_) => break, } } _ = sleep(Duration::from_millis(time_to_sleep as u64)) => {} } } } if let Some(runtime) = runtime_ref.as_ref() { runtime.is_leader.store(false, Relaxed); } else { server._is_leader.store(false, Relaxed); } info!( "Raft checker/heartbeat task stopped gracefully for plane {}", plane_id.raw() ); }); if let Some(runtime_ref) = runtime.as_ref() { let mut guard = runtime_ref.checker_task.lock().await; *guard = Some(handle); } else { let mut guard = self.checker_task.lock().await; *guard = Some(handle); } } async fn shutdown_managed_runtime(&self, runtime: Option>) { let plane_id = runtime .as_ref() .map(|runtime| runtime.plane_id) .unwrap_or_else(PlaneId::type1); let already_stopping = { let mut meta = if let Some(runtime) = runtime.as_ref() { runtime.meta.write().await } else { self.meta.write().await }; if meta.lifecycle != LifecycleState::Running { true } else { meta.lifecycle = LifecycleState::Stopping; meta.membership = Membership::Offline; false } }; if !already_stopping { if let Some(runtime) = runtime.as_ref() { let _ = runtime.shutdown_tx.send(LifecycleState::Stopping); } else { let _ = self.shutdown_tx.send(LifecycleState::Stopping); } } let _ = if let Some(runtime) = runtime.as_ref() { Self::wait_for_apply_drain_for_meta(&runtime.meta, Duration::from_secs(5)).await } else { Self::wait_for_apply_drain_for_meta(&self.meta, Duration::from_secs(5)).await }; let handle = if let Some(runtime) = runtime.as_ref() { let mut guard = runtime.checker_task.lock().await; guard.take() } else { let mut guard = self.checker_task.lock().await; guard.take() }; if let Some(handle) = handle { let _ = handle.await; } let _ = self.flush_persistence_on_plane(plane_id).await; { let mut meta = if let Some(runtime) = runtime.as_ref() { runtime.meta.write().await } else { self.meta.write().await }; meta.lifecycle = 
LifecycleState::Stopped; } if let Some(runtime) = runtime.as_ref() { runtime.is_leader.store(false, Relaxed); let _ = runtime.shutdown_tx.send(LifecycleState::Stopped); } else { self._is_leader.store(false, Relaxed); let _ = self.shutdown_tx.send(LifecycleState::Stopped); } } /// Public helper for applications to trigger commit replay after registering /// their state machines. This ensures snapshot replay and log apply happen /// only after SMs are ready. pub async fn recover_after_register(&self) { let mut meta = self.meta.write().await; { let mut master_sm = meta.state_machine.write().await; master_sm.recover_registered_snapshots().await; } info!( "Manual apply on plane {}: applying committed logs (commit_index={}, last_applied={})", PlaneId::type1().raw(), meta.commit_index, meta.last_applied ); check_commit(&mut meta).await; info!( "Manual apply on plane {}: applied logs up to last_applied={}", PlaneId::type1().raw(), meta.last_applied ); } } impl RaftService { fn plane_storage_path(&self, plane_id: PlaneId) -> Option { match &self.options.storage { Storage::MEMORY => None, Storage::DISK(disk_opts) => { Some(format!("{}/planes/{}", disk_opts.path, plane_id.raw())) } } } fn plane_has_persisted_state(&self, plane_id: PlaneId) -> bool { self.plane_storage_path(plane_id) .map(|path| { let base = Path::new(&path); base.join("commit.idx").exists() || base.join("log.dat").exists() || base.join("snapshot.dat").exists() }) .unwrap_or(false) } async fn load_snapshot_into_meta( plane_id: PlaneId, meta: &mut RwLockWriteGuard<'_, RaftMeta>, ) -> bool { let storage = meta.storage.clone(); if let Some(storage) = storage { let storage = storage.lock().await; match storage.read_snapshot().await { Ok(Some(snapshot)) => { info!( "Found snapshot on plane {}: index={}, term={}. 
Recovering state machine...", plane_id.raw(), snapshot.last_included_index, snapshot.last_included_term ); meta.state_machine .write() .await .recover(snapshot.snapshot.clone()) .await; meta.last_snapshot_index = snapshot.last_included_index; meta.last_snapshot_term = snapshot.last_included_term; // Snapshot restore must reset the apply cursor to the snapshot index so // post-snapshot WAL entries are replayed deterministically on startup. meta.last_applied = snapshot.last_included_index; if meta.commit_index < snapshot.last_included_index { meta.commit_index = snapshot.last_included_index; } { let mut logs = meta.logs.write().await; let before_count = logs.len(); logs.retain(|&id, _| id > snapshot.last_included_index); let after_count = logs.len(); info!( "Compacted logs on plane {} during startup: removed {} logs, {} remaining", plane_id.raw(), before_count - after_count, after_count ); } { let mut master_sm = meta.state_machine.write().await; master_sm.recover_registered_snapshots().await; } info!( "Snapshot recovery completed successfully for plane {}", plane_id.raw() ); true } Ok(None) => { debug!("No snapshot found on disk for plane {}", plane_id.raw()); false } Err(e) => { warn!( "Failed to load snapshot from disk for plane {}: {:?}. 
Starting without snapshot recovery.", plane_id.raw(), e ); false } } } else { debug!( "No storage configured, skipping snapshot recovery for plane {}", plane_id.raw() ); false } } async fn recover_config_state_from_logs( plane_id: PlaneId, meta: &mut RwLockWriteGuard<'_, RaftMeta>, ) { let committed_entries = { let logs = meta.logs.read().await; logs.range((Unbounded, Included(&meta.commit_index))) .filter_map(|(_, entry)| { if entry.sm_id == CONFIG_SM_ID { Some(entry.clone()) } else { None } }) .collect::>() }; if committed_entries.is_empty() { return; } let mut master_sm = meta.state_machine.write().await; for entry in committed_entries { if let Err(err) = master_sm.commit_cmd(&entry).await { warn!( "Failed to recover config log {} on plane {} during runtime initialization: {:?}", entry.id, plane_id.raw(), err ); } } } async fn initialize_runtime_meta( &self, plane_id: PlaneId, leader_flag: &AtomicBool, meta: &mut RwLockWriteGuard<'_, RaftMeta>, bootstrap_if_fresh: bool, ) -> Result<(), PlaneError> { leader_flag.store(false, Relaxed); meta.last_checked = get_time() + (CHECKER_MS * 10); let recovered_from_disk = Self::load_snapshot_into_meta(plane_id, meta).await; Self::recover_config_state_from_logs(plane_id, meta).await; let server_address = self.options.address.clone(); { let mut sm = meta.state_machine.write().await; let start_time = get_time(); while get_time() < start_time + 5000 { if sm.configs.new_member(server_address.clone()).await || sm.configs.member_existed(self.id) { break; } sleep(Duration::from_millis(50)).await; } if !sm.configs.member_existed(self.id) { return Err(PlaneError::InitializationFailed(plane_id)); } let num_members = sm.configs.members.len(); let has_logs = !meta.logs.read().await.is_empty(); let has_term = meta.term > 0; let should_promote = num_members == 1 && ((recovered_from_disk || has_logs || has_term) || bootstrap_if_fresh); debug!( "Plane {} initialization: recovered_from_disk={}, num_members={}, has_logs={}, has_term={}, 
/// Returns `members` in canonical form: sorted ascending with duplicate addresses removed.
///
/// Two membership lists that contain the same set of addresses canonicalize to equal vectors.
fn canonicalize_member_addresses(mut members: Vec<String>) -> Vec<String> {
    // Equal strings are indistinguishable, so an unstable sort yields the same ordering.
    members.sort_unstable();
    // After sorting, duplicates are adjacent and `dedup` removes all of them.
    members.dedup();
    members
}
} => { Ok(new_member_::decode_return(&data)) } ClientCmdResponse::Success { data: Err(err), .. } => Err(err.into()), ClientCmdResponse::NotLeader(leader_id) => Err(PlaneBootstrapError::NotLeader { plane_id, leader_id, }), ClientCmdResponse::NotCommitted { .. } => Err(ExecError::NotCommitted.into()), ClientCmdResponse::ShuttingDown => Err(ExecError::ShuttingDown.into()), } } pub async fn ensure_plane( self: &Arc, spec: PlaneSpec, ) -> Result { if let (Some(runtime), _) = self .resolve_plane_runtime(spec.plane_id, true, true) .await? { self.start_managed_runtime(Some(runtime)).await; } Ok(PlaneHandle { service: self.clone(), plane_id: spec.plane_id, }) } pub async fn plane(self: &Arc, plane_id: PlaneId) -> Result { if let (Some(runtime), _) = self.resolve_plane_runtime(plane_id, false, false).await? { self.start_managed_runtime(Some(runtime)).await; } Ok(PlaneHandle { service: self.clone(), plane_id, }) } async fn ensure_plane_membership( self: &Arc, plane_id: PlaneId, requested_members: Vec, ) -> Result { let plane = self.ensure_plane(PlaneSpec { plane_id }).await?; let current_members = self.plane_member_addresses(plane_id).await?; if current_members == requested_members { return Ok(plane); } if current_members.iter().any(|member| { !requested_members .iter() .any(|requested| requested == member) }) { return Err(PlaneBootstrapError::MembershipConflict { plane_id, current_members, requested_members, }); } for member in &requested_members { let added = self .add_plane_member_via_log(plane_id, member.clone()) .await?; if !added && !current_members.iter().any(|current| current == member) { return Err(PlaneBootstrapError::MemberRegistrationRejected { address: member.clone(), }); } } Ok(plane) } /// Materialize a type-2 plane using the current type-1 membership discovered /// from one or more root seed nodes. 
pub async fn ensure_plane_from_seeds( self: &Arc, bootstrap: PlaneBootstrap, ) -> Result { let (plane_id, seed_nodes) = self.validate_plane_bootstrap(bootstrap)?; let requested_members = self.plane_members_from_seed_nodes(seed_nodes).await?; self.ensure_plane_membership(plane_id, requested_members) .await } /// Returns only the type-2 plane runtimes that are currently materialized on this host. /// /// This is a local runtime-cache view, not an authoritative plane inventory. /// Type-1 is intentionally excluded. pub async fn loaded_type2_planes(&self) -> Vec { let planes = self.planes.read().await; planes.keys().copied().collect() } pub async fn recover_after_register_on_plane( &self, plane_id: PlaneId, ) -> Result<(), PlaneError> { if let Some(runtime) = self.runtime_for_plane(plane_id).await? { let mut meta = runtime.meta.write().await; { let mut master_sm = meta.state_machine.write().await; master_sm.recover_registered_snapshots().await; } info!( "Manual apply on plane {}: applying committed logs (commit_index={}, last_applied={})", plane_id.raw(), meta.commit_index, meta.last_applied ); check_commit(&mut meta).await; return Ok(()); } self.recover_after_register().await; Ok(()) } pub async fn register_state_machine_on_plane( &self, plane_id: PlaneId, state_machine: SubStateMachine, ) -> Result<(), PlaneError> { if let Some(runtime) = self.runtime_for_plane(plane_id).await? { let meta = runtime.meta.read().await; let mut master_sm = meta.state_machine.write().await; master_sm.register(state_machine); return Ok(()); } self.register_state_machine(state_machine).await; Ok(()) } pub(crate) async fn subscriptions_on_plane( &self, plane_id: PlaneId, ) -> Result>, PlaneError> { if let Some(runtime) = self.runtime_for_plane(plane_id).await? 
{ let meta = runtime.meta.read().await; let master_sm = meta.state_machine.read().await; return Ok(master_sm.configs.subscriptions.clone()); } let meta = self.meta.read().await; let master_sm = meta.state_machine.read().await; Ok(master_sm.configs.subscriptions.clone()) } pub async fn cluster_info_on_plane_local( &self, plane_id: PlaneId, ) -> Result { if let Some(runtime) = self.runtime_for_plane(plane_id).await? { let meta = runtime.meta.read().await; let logs = meta.logs.read().await; let sm = meta.state_machine.read().await; let members = sm .members() .iter() .map(|(id, member)| (*id, member.address.clone())) .collect::>(); let last_log = logs.iter().next_back(); let (last_log_id, last_log_term) = match last_log { Some((last_log_id, last_log_item)) => (*last_log_id, last_log_item.term), None => (0, 0), }; return Ok(ClientClusterInfo { members, last_log_id, last_log_term, leader_id: meta.leader_id, }); } Ok(self.cluster_info().await) } pub async fn have_state_machine_on_plane_local( &self, plane_id: PlaneId, id: u64, ) -> Result { if let Some(runtime) = self.runtime_for_plane(plane_id).await? { let meta = runtime.meta.read().await; let sm = meta.state_machine.read().await; return Ok(sm.has_sub(&id)); } let meta = self.meta.read().await; let sm = meta.state_machine.read().await; Ok(sm.has_sub(&id)) } pub async fn is_leader_on_plane(&self, plane_id: PlaneId) -> Result { if let Some(runtime) = self.runtime_for_plane(plane_id).await? { return Ok(runtime.is_leader.load(Relaxed)); } Ok(self.is_leader()) } pub async fn flush_persistence_on_plane(&self, plane_id: PlaneId) -> Result<(), PlaneError> { if let Some(runtime) = self.runtime_for_plane(plane_id).await? 
{ let (storage_opt, commit_index, last_applied) = { let meta = runtime.meta.read().await; (meta.storage.clone(), meta.commit_index, meta.last_applied) }; if let Some(storage_mutex) = storage_opt { let mut storage = storage_mutex.lock().await; let _ = storage.flush_wal().await; let _ = storage .write_commit_progress(commit_index, last_applied) .await; } return Ok(()); } self.flush_persistence().await; Ok(()) } pub async fn shutdown_plane(&self, plane_id: PlaneId) -> Result<(), PlaneError> { if let Some(runtime) = self.runtime_for_plane(plane_id).await? { self.shutdown_managed_runtime(Some(runtime)).await; return Ok(()); } self.shutdown().await; Ok(()) } } fn is_majority(members: u64, granted: u64) -> bool { let required = members / 2 + 1; let majority = granted >= (required); debug!( "Members {} granted {}, is majority: {}", members, granted, majority ); majority } async fn apply_committed_entry<'a>( meta: &'a RwLockWriteGuard<'a, RaftMeta>, entry: &'a LogEntry, ) -> ExecResult { meta.state_machine.write().await.commit_cmd(&entry).await } fn is_leader(meta: &RwLockWriteGuard) -> bool { match meta.membership { Membership::Leader(_) => true, _ => false, } } fn alter_term(meta: &mut RwLockWriteGuard, term: u64) { if meta.term != term { meta.term = term; meta.vote_for = None; } } fn ensure_direct_leader_term(meta: &mut RwLockWriteGuard<'_, RaftMeta>) { if meta.term == 0 { meta.term = 1; } } impl RaftService { pub fn new(opts: Options) -> Arc { let server_address = opts.address.clone(); let server_id = hash_str(&server_address); let mut term = 0; let mut logs = BTreeMap::new(); let mut commit_index = 0; let mut last_applied = 0; let storage_entity = match StorageEntity::new_with_options( &opts, &mut term, &mut commit_index, &mut last_applied, &mut logs, ) { Ok(entity) => entity, Err(e) => { panic!( "Failed to initialize storage entity: {:?}. 
Cannot proceed without storage.", e ); } }; let master_sm = MasterStateMachine::new_on_plane(opts.service_id, PlaneId::type1()); let (shutdown_tx, _shutdown_rx) = watch::channel(LifecycleState::Running); let server_obj = RaftService { meta: RwLock::new(RaftMeta { term, vote_for: None, timeout: gen_timeout(), last_checked: get_time(), membership: Membership::Undefined, logs: Arc::new(RwLock::new(logs)), state_machine: Arc::new(RwLock::new(master_sm)), commit_index, last_applied, leader_id: 0, storage: storage_entity.map(|e| Arc::new(Mutex::new(e))), last_snapshot_index: 0, last_snapshot_term: 0, lifecycle: LifecycleState::Running, }), planes: RwLock::new(BTreeMap::new()), id: server_id, options: opts, rt: runtime::Builder::new_multi_thread() .enable_all() .thread_name("raft-server") .worker_threads(12) .max_blocking_threads(num_cpus::get()) .event_interval(31) .build() .expect("Failed to build tokio runtime for Raft service"), _is_leader: AtomicBool::new(false), checker_task: TokioMutex::new(None), shutdown_tx, }; Arc::new(server_obj) } /// Load snapshot from disk and recover state machine if snapshot exists async fn load_snapshot_on_startup(&self) -> bool { // IMPORTANT: read the snapshot from disk while holding only a read lock, then // drop the read lock before acquiring the write lock. If we tried to acquire // meta.write() while still inside the `if let … = meta.read().await.storage {` // block the borrow of `storage` would keep the read guard alive and we would // deadlock waiting for our own read guard to be released. let maybe_snapshot = { let meta = self.meta.read().await; match meta.storage { Some(ref storage) => { let storage = storage.lock().await; match storage.read_snapshot().await { Ok(v) => v, Err(e) => { warn!("Failed to load snapshot from disk: {:?}. 
Starting without snapshot recovery.", e); None } } } None => { debug!("No storage configured, skipping snapshot recovery"); None } } }; // ← read guard dropped here before the write lock below if let Some(snapshot) = maybe_snapshot { info!( "Found snapshot on disk: index={}, term={}. Recovering state machine...", snapshot.last_included_index, snapshot.last_included_term ); let mut meta = self.meta.write().await; // safe: read guard already released // Recover state machine (stores sub-SM snapshots; applied when they register) meta.state_machine .write() .await .recover(snapshot.snapshot.clone()) .await; // Update snapshot metadata meta.last_snapshot_index = snapshot.last_included_index; meta.last_snapshot_term = snapshot.last_included_term; // Restore applies from the snapshot boundary every time so any // committed WAL entries after the snapshot are replayed exactly once. meta.last_applied = snapshot.last_included_index; if meta.commit_index < snapshot.last_included_index { meta.commit_index = snapshot.last_included_index; } // Compact logs: remove entries already covered by the snapshot { let mut logs = meta.logs.write().await; let before_count = logs.len(); logs.retain(|&id, _| id > snapshot.last_included_index); let after_count = logs.len(); info!( "Compacted logs on startup: removed {} logs, {} remaining", before_count - after_count, after_count ); } info!("Snapshot recovery completed successfully"); true } else { debug!("No snapshot found on disk, starting fresh"); false } } pub async fn start(server: &Arc, recover_registered: bool) -> bool { info!( "Waiting for raft server to be initialized on plane {}", PlaneId::type1().raw() ); { let mut meta = server.meta.write().await; if server .initialize_runtime_meta(PlaneId::type1(), &server._is_leader, &mut meta, false) .await .is_err() { return false; } } if recover_registered { server.recover_after_register().await; } server.start_managed_runtime(None).await; return true; } /// New server without recovery from 
registered state machine pub async fn new_server(opts: Options) -> (bool, Arc, Arc) { let address = opts.address.clone(); let svr_id = opts.service_id; let service = RaftService::new(opts); let server = Server::new(&address); Server::listen_and_resume(&server).await; server.register_service_with_id(svr_id, &service).await; (RaftService::start(&service, false).await, service, server) } pub async fn probe_and_join(&self, servers: &Vec) -> Result { debug!( "Probing and try to join servers for plane {}: {:?}", self.plane_id().raw(), servers ); let is_first_node = !RaftClient::probe_servers(servers, &self.options.address, self.options.service_id) .await; if is_first_node { debug!( "There is no live node in the server list for plane {}, will bootstrap", self.plane_id().raw() ); self.bootstrap().await; Ok(false) } else { debug!( "There are some live nodes for plane {}, will join them", self.plane_id().raw() ); self.join(servers).await } } pub async fn bootstrap(&self) { let mut meta = self.write_meta().await; let (last_log_id, _) = { let logs = meta.logs.read().await; get_last_log_info!(self, logs) }; self.become_leader(&mut meta, last_log_id).await; } pub async fn conservative_bootstrap(&self, servers: &Vec) { let meta = self.meta.read().await; debug!( "Conservative bootstrap for plane {}, checking storage", self.plane_id().raw() ); if let Some(storage) = &meta.storage { debug!( "There is storage for plane {}, checking last term", self.plane_id().raw() ); if storage.lock().await.last_term > 0 { debug!( "Plane {} has logged term, will probe and join or bootstrap", self.plane_id().raw() ); drop(meta); if let Err(e) = self.probe_and_join(servers).await { error!( "Failed to probe and join cluster during conservative bootstrap on plane {}: {:?}", self.plane_id().raw(), e ); } } else { debug!( "Log is empty for plane {}, bootstrap", self.plane_id().raw() ); drop(meta); self.bootstrap().await; } } else { debug!( "No storage for plane {}, will probe and join or bootstrap", 
self.plane_id().raw() ); drop(meta); if let Err(e) = self.probe_and_join(servers).await { error!( "Failed to probe and join cluster during conservative bootstrap on plane {}: {:?}", self.plane_id().raw(), e ); } } } pub async fn join(&self, servers: &Vec) -> Result { debug!( "Trying to join plane {} cluster with id {}", self.plane_id().raw(), self.id ); let client = RaftClient::new(servers, self.options.service_id).await; if let Ok(client) = client { debug!( "Executing in SM to create new member on plane {}: {}, {}", self.plane_id().raw(), &self.options.address, self.id ); let result = client.add_root_member(&self.options.address).await; debug!( "Getting member address for plane {}: {}", self.plane_id().raw(), self.id ); let members = client.root_member_addresses().await; debug!( "Updating local meta for plane {} by acquiring lock: {}", self.plane_id().raw(), self.id ); let mut meta = self.write_meta().await; debug!( "Local meta lock acquired for plane {}: {}", self.plane_id().raw(), self.id ); if let Ok(members) = members { debug!( "We have following members for plane {} node {}: {:?}", self.plane_id().raw(), self.id, members ); for member in members { meta.state_machine .write() .await .configs .new_member(member) .await; } } debug!( "Become follower because of join on plane {}: {}", self.plane_id().raw(), self.id ); self.become_follower(&mut meta, 0, client.leader_id()); debug!( "Resetting last checked for join on plane {}: {}", self.plane_id().raw(), self.id ); self.reset_last_checked(&mut meta); match &result { Ok(joined) => debug!( "Completed join for plane {} node {}, result {}", self.plane_id().raw(), self.id, joined ), Err(e) => debug!( "Join failed for plane {} node {}, error: {:?}", self.plane_id().raw(), self.id, e ), } result } else { Err(ExecError::CannotConstructClient) } } pub async fn leave(&self) -> bool { let members = self.cluster_info().await.members; let servers: Vec<_> = members .iter() .map(|&(_, ref address)| address.clone()) .collect(); 
debug!( "Leaving from plane {} cluster, server id {} with {} members {:?}", self.plane_id().raw(), self.id, servers.len(), servers ); if let Ok(client) = RaftClient::new(&servers, self.options.service_id).await { debug!( "Temporary client for plane {} leaving, leader: {}. Sending removal message.", self.plane_id().raw(), client.leader_id() ); match client.remove_root_member(&self.options.address).await { Ok(_) => info!( "Successfully removed member {} from plane {} cluster", self.options.address, self.plane_id().raw() ), Err(e) => { error!( "Failed to remove member {} from plane {} cluster: {:?}", self.options.address, self.plane_id().raw(), e ); return false; } } } else { error!( "Cannot obtain temporary client for leaving plane {}", self.plane_id().raw() ); return false; } let mut meta = self.write_meta().await; if is_leader(&meta) { info!( "Leader step down on plane {}: {}", self.plane_id().raw(), self.options.address ); if !self.send_followers_heartbeat(&mut meta, None, true).await { error!("Leader cannot step down on plane {}", self.plane_id().raw()); return false; } info!( "Step down heartbeat sent to followers on plane {}", self.plane_id().raw() ); let mut reelected = false; for (_id, addr) in members { if addr != self.options.address { info!( "Calling reelect on plane {} to {}", self.plane_id().raw(), addr ); match rpc::DEFAULT_CLIENT_POOL.get(&addr).await { Ok(client) => { let service = AsyncServiceClient::new(&client); match service.reelect(PlaneId::type1()).await { Ok(true) => { info!( "New leader has been elected on plane {}", self.plane_id().raw() ); reelected = true; break; // Only need one successful reelection } Ok(false) => { warn!( "Server {} cannot be elected on plane {}", addr, self.plane_id().raw() ); } Err(e) => { error!( "Server {} cannot be elected on plane {} due to comm error {:?}", addr, self.plane_id().raw(), e ); } } } Err(e) => { error!( "Cannot call reelect on plane {} to {}, error {:?}", self.plane_id().raw(), addr, e ) } } } } if 
!reelected { warn!( "No new leader has been elected on plane {}", self.plane_id().raw() ); } } meta.membership = Membership::Offline; let mut sm = meta.state_machine.write().await; sm.clear_subs(); return true; } pub async fn cluster_info(&self) -> ClientClusterInfo { let meta = self.meta.read().await; let logs = meta.logs.read().await; let sm = &meta.state_machine.read().await; let sm_members = sm.members(); let mut members = Vec::new(); for (id, member) in sm_members.iter() { members.push((*id, member.address.clone())) } let (last_log_id, last_log_term) = get_last_log_info!(self, logs); ClientClusterInfo { members, last_log_id, last_log_term, leader_id: meta.leader_id, } } pub async fn num_members(&self) -> usize { let meta = self.meta.read().await; let member_sm = meta.state_machine.read().await; let ref members = member_sm.configs.members; members.len() } pub async fn num_logs(&self) -> usize { let meta = self.meta.read().await; let logs = meta.logs.read().await; logs.len() } pub async fn last_log_id(&self) -> Option { let meta = self.meta.read().await; let logs = meta.logs.read().await; logs.keys().cloned().last() } pub async fn leader_id(&self) -> u64 { let meta = self.meta.read().await; meta.leader_id } pub async fn is_leader_for_real(&self) -> bool { let meta = self.meta.read().await; match meta.membership { Membership::Leader(_) => true, _ => false, } } pub fn is_leader(&self) -> bool { self._is_leader.load(Relaxed) } pub fn get_server_id(&self) -> u64 { self.id } /// Force a write-back and fsync of WAL and persist current commit progress. 
pub async fn flush_persistence(&self) { let (storage_opt, commit_index, last_applied) = { let meta = self.meta.read().await; (meta.storage.clone(), meta.commit_index, meta.last_applied) }; if let Some(storage_mutex) = storage_opt { let mut storage = storage_mutex.lock().await; let _ = storage.flush_wal().await; let _ = storage .write_commit_progress(commit_index, last_applied) .await; info!( "Flushed WAL and wrote commit progress: commit_index={}, last_applied={}", commit_index, last_applied ); } } pub async fn shutdown(&self) { info!( "Shutting down RaftService on plane {} at {}", self.plane_id().raw(), self.options.address ); let plane_runtimes = { let planes = self.planes.read().await; planes.values().cloned().collect::>() }; for runtime in plane_runtimes { self.shutdown_managed_runtime(Some(runtime)).await; } self.shutdown_managed_runtime(None).await; info!( "RaftService shutdown complete for plane {}", self.plane_id().raw() ); } pub async fn register_state_machine(&self, state_machine: SubStateMachine) { let meta = self.meta.read().await; let mut master_sm = meta.state_machine.write().await; master_sm.register(state_machine); } fn switch_membership(&self, meta: &mut RwLockWriteGuard, membership: Membership) { self.reset_last_checked(meta); meta.membership = membership; } fn get_log_info_(&self, log: Option<(&u64, &LogEntry)>) -> (u64, u64) { match log { Some((last_log_id, last_log_item)) => (*last_log_id, last_log_item.term), None => (0, 0), } } fn insert_leader_follower_meta( &self, leader_meta: &mut RwLockWriteGuard, last_log_id: u64, member_id: u64, ) { // the leader itself will not be consider as a follower when sending heartbeat if member_id == self.id { return; } leader_meta.followers.entry(member_id).or_insert_with(|| { Arc::new(Mutex::new(FollowerStatus { next_index: last_log_id + 1, match_index: 0, })) }); } fn reload_leader_meta( &self, member_map: &HashMap, leader_meta: &mut RwLockWriteGuard, last_log_id: u64, ) { for member in member_map.values() { 
/// Runs one election round for `plane_id` with this node as the candidate.
///
/// The caller passes the plane's meta write guard, which stays held for the whole round.
/// The function:
/// 1. resets the election timer, clears `leader_flag`, bumps the term by one, votes for
///    itself and switches the membership to `Candidate`;
/// 2. sends a vote request to every configured member concurrently (the local member is
///    counted as granted without an RPC);
/// 3. tallies responses as they arrive: it becomes leader as soon as `is_majority` is
///    satisfied, or follower as soon as a peer answers with a higher term.
///
/// If neither happens, the function returns with the node still in the `Candidate` state
/// and `vote_for` still set to itself.
async fn become_candidate_on_plane<'a>(
    &'a self,
    plane_id: PlaneId,
    leader_flag: &'a AtomicBool,
    meta: &'a mut RwLockWriteGuard<'_, RaftMeta>,
) {
    let server_id = self.id;
    debug!(
        "Plane {} server {} become candidate",
        plane_id.raw(),
        server_id
    );
    self.reset_last_checked(meta);
    leader_flag.store(false, Relaxed);
    // `alter_term` clears `vote_for` when the term changes, so the self-vote is set afterwards.
    let term = meta.term;
    alter_term(meta, term + 1);
    meta.vote_for = Some(server_id);
    self.switch_membership(meta, Membership::Candidate);
    // From here on `term` is the term this election is run for.
    let term = meta.term;
    let (last_log_id, last_log_term) = {
        let logs = meta.logs.read().await;
        get_last_log_info!(self, logs)
    };
    let (mut members_vote_response_stream, num_members) = {
        // Copy out (rpc client, id) pairs so the state machine read lock is released before
        // any request is sent.
        let members: Vec<_> = {
            let member_sm = meta.state_machine.read().await;
            let ref members = member_sm.configs.members;
            members
                .values()
                .map(|member| (member.rpc.clone(), member.id))
                .collect()
        };
        let len = members.len();
        let futs: FuturesUnordered<_> = members
            .into_iter()
            .map(|(rpc, member_id)| {
                let vote_fut = async move {
                    if member_id == server_id {
                        // The candidate's own vote needs no network round trip.
                        debug!("Plane {} member {} vote for itself", plane_id.raw(), member_id);
                        RequestVoteResponse::Granted
                    } else {
                        if let Ok(((remote_term, remote_leader_id), vote_granted)) = rpc
                            .request_vote(plane_id, term, server_id, last_log_id, last_log_term)
                            .await
                        {
                            if vote_granted {
                                debug!(
                                    "Plane {} member {} received one vote from {}",
                                    plane_id.raw(),
                                    server_id,
                                    member_id
                                );
                                RequestVoteResponse::Granted
                            } else if remote_term > term {
                                // The peer is in a newer term: this candidacy is stale.
                                debug!(
                                    "Plane {} member {} is term out, by {}. Now leader is {}, term {}",
                                    plane_id.raw(),
                                    server_id,
                                    member_id,
                                    remote_leader_id,
                                    remote_term
                                );
                                RequestVoteResponse::TermOut(remote_term, remote_leader_id)
                            } else {
                                debug!(
                                    "Plane {} member {} did not get vote from {}",
                                    plane_id.raw(),
                                    server_id,
                                    member_id
                                );
                                RequestVoteResponse::NotGranted
                            }
                        } else {
                            debug!(
                                "Plane {} member {} request vote failed from {}",
                                plane_id.raw(),
                                server_id,
                                member_id
                            );
                            RequestVoteResponse::NotGranted // default for request failure
                        }
                    }
                };
                // Each request runs as its own task on the service runtime and is given
                // 1500 ms before it is counted as unanswered.
                timeout(Duration::from_millis(1500), self.rt.spawn(vote_fut))
            })
            .collect();
        (futs, len)
    };
    let mut granted = 0;
    // Responses are handled in completion order. A timed-out request (outer `Err`) and a
    // failed task (inner `Err`) are both skipped, i.e. treated as "no vote".
    while let Some(vote_response) = members_vote_response_stream.next().await {
        if let Ok(res) = vote_response {
            // Stop tallying if the term no longer matches the one this election started with.
            if meta.term != term {
                break;
            }
            match res {
                Ok(RequestVoteResponse::TermOut(remote_term, remote_leader_id)) => {
                    self.become_follower_on_plane(
                        leader_flag,
                        meta,
                        remote_term,
                        remote_leader_id,
                    );
                    break;
                }
                Ok(RequestVoteResponse::Granted) => {
                    granted += 1;
                    debug!(
                        "Plane {} member {} received {} votes for now",
                        plane_id.raw(),
                        server_id,
                        granted
                    );
                    if is_majority(num_members as u64, granted) {
                        debug!(
                            "Plane {} member {} become leader after receiving majority votes",
                            plane_id.raw(),
                            server_id
                        );
                        self.become_leader_on_plane(leader_flag, meta, last_log_id)
                            .await;
                        break;
                    }
                }
                _ => {}
            }
        }
    }
    debug!(
        "Plane {} granted votes for {}: {}/{}",
        plane_id.raw(),
        self.id,
        granted,
        num_members
    );
    return;
}
become_leader_on_plane( &self, leader_flag: &AtomicBool, meta: &mut RwLockWriteGuard<'_, RaftMeta>, last_log_id: u64, ) { debug!( "Plane {} server {} become leader, term {}", self.plane_id().raw(), self.id, meta.term ); let leader_meta = RwLock::new(LeaderMeta::new()); { let mut guard = leader_meta.write().await; let member_sm = meta.state_machine.read().await; let ref members = member_sm.configs.members; self.reload_leader_meta(members, &mut guard, last_log_id); guard.last_updated = get_time(); } meta.leader_id = self.id; self.switch_membership(meta, Membership::Leader(leader_meta)); leader_flag.store(true, Relaxed); } async fn become_leader(&self, meta: &mut RwLockWriteGuard<'_, RaftMeta>, last_log_id: u64) { self.become_leader_on_plane(&self._is_leader, meta, last_log_id) .await; } async fn send_followers_heartbeat_on_plane<'a>( &self, plane_id: PlaneId, meta: &mut RwLockWriteGuard<'a, RaftMeta>, log_id: Option, no_delay: bool, ) -> bool { let now = get_time(); if meta.last_checked + HEARTBEAT_MS > now { if no_delay { debug!("Issuing delayed heartbeat on plane {}", plane_id.raw()); } else { debug!("Block throttled heartbeat on plane {}", plane_id.raw()); return false; } } trace!("Sending followers heartbeat on plane {}", plane_id.raw()); if let Membership::Leader(ref leader_meta) = meta.membership { let leader_id = meta.leader_id; debug_assert_eq!(self.id, leader_id); let mut heartbeat_futs = FuturesUnordered::new(); // Send out heartbeats { let leader_meta = leader_meta.read().await; let member_sm = meta.state_machine.read().await; let ref members = member_sm.configs.members; for member in members.values() { let member_id = member.id; if member_id == self.id { continue; } let follower = if let Some(follower) = leader_meta.followers.get(&member_id) { follower } else { debug!( "Plane {} follower not found, {}, {}", plane_id.raw(), member_id, leader_meta.followers.len() ); //TODO: remove after debug continue; }; // get a send follower task without await let hb_fut 
= Self::send_follower_heartbeat( plane_id, meta.commit_index, meta.term, meta.leader_id, meta.last_applied, meta.last_snapshot_index, meta.last_snapshot_term, meta.state_machine.clone(), meta.logs.clone(), follower.clone(), member.rpc.clone(), member_id, ); let heartbeat_fut = async move { (member_id, hb_fut.await) }.boxed(); let task_spawned = self.rt.spawn(heartbeat_fut); let timeout_interval = 1000; let task_with_timeout = timeout(Duration::from_millis(timeout_interval), task_spawned); heartbeat_futs.push(task_with_timeout); } } let followers = heartbeat_futs.len(); if followers <= 0 { // Early quit if no followers return true; } if let (Some(log_id), &Membership::Leader(ref leader_meta)) = (log_id, &meta.membership) { let mut updated_followers = 0; let mut higher_term = None; { let mut leader_meta = leader_meta.write().await; while let Some(heartbeat_res) = heartbeat_futs.next().await { match heartbeat_res { Ok(Ok((member_id, heartbeat_result))) => { match heartbeat_result { HeartbeatReplicationResult::Matched(last_matched_id) => { debug!( "Heartbeat response on plane {} from {} is {:?}", plane_id.raw(), member_id, last_matched_id ); if last_matched_id >= log_id { updated_followers += 1; if is_majority(followers as u64, updated_followers) { return true; } } } HeartbeatReplicationResult::TermOut { term: remote_term, leader_id: remote_leader_id, } => { higher_term = Some((remote_term, remote_leader_id)); break; } } } Ok(Err(err)) => { warn!( "Heartbeat task failed on plane {} while replicating log {}: {:?}", plane_id.raw(), log_id, err ); } Err(_) => { warn!( "Heartbeat task timed out on plane {} while replicating log {}", plane_id.raw(), log_id ); } } } leader_meta.last_updated = get_time(); } if let Some((remote_term, remote_leader_id)) = higher_term { warn!( "Plane {} stepping down after follower reported higher term {} (leader_id={})", plane_id.raw(), remote_term, remote_leader_id ); alter_term(meta, remote_term); meta.leader_id = remote_leader_id; 
self.switch_membership(meta, Membership::Follower); return false; } debug!( "Plane {} replicated log {} to {} of {} followers", plane_id.raw(), log_id, updated_followers, followers ); // is_majority(members, updated_followers) false } else { !log_id.is_some() } } else { unreachable!() } } async fn send_followers_heartbeat<'a>( &'a self, meta: &mut RwLockWriteGuard<'a, RaftMeta>, log_id: Option, no_delay: bool, ) -> bool { self.send_followers_heartbeat_on_plane(PlaneId::type1(), meta, log_id, no_delay) .await } async fn send_follower_heartbeat( plane_id: PlaneId, commit_index: u64, term: u64, leader_id: u64, last_applied: u64, last_snapshot_index: u64, last_snapshot_term: u64, master_sm: Arc>, logs: Arc>, follower: Arc>, rpc: Arc, member_id: u64, ) -> HeartbeatReplicationResult { // let commit_index = meta.commit_index; // let term = meta.term; // let leader_id = meta.leader_id; // let meta_term = meta.term; // let meta_last_applied = meta.last_applied; // let master_sm = &meta.state_machine; // let logs = &meta.logs; trace!( "Sending follower heartbeat on plane {} to {}", plane_id.raw(), member_id ); let mut follower = follower.lock().await; let logs = logs.read().await; let mut is_retry = false; loop { let entries: Option = { // extract logs to send to follower let list: LogEntries = logs .range((Included(&follower.next_index), Unbounded)) .map(|(_, entry)| entry.clone()) .collect(); //TODO: avoid clone entry if list.is_empty() { None } else { Some(list) } }; if is_retry && entries.is_none() { // break when retry and there is no entry trace!( "Stop retry on plane {} when entry is empty, {}, member id {}", plane_id.raw(), follower.next_index, member_id ); return HeartbeatReplicationResult::Matched(follower.match_index); } let last_entries_id = match &entries { // get last entry id &Some(ref entries) => { // Safe: entries is Some, so it's not empty (checked above) entries.iter().last().map(|entry| entry.id) } &None => None, }; // Check if follower needs logs that 
have been compacted (Issue 5) // If so, send snapshot instead if follower.next_index <= last_snapshot_index { debug!( "Follower {} on plane {} needs compacted logs (next_index: {} <= snapshot_index: {}), sending snapshot", member_id, plane_id.raw(), follower.next_index, last_snapshot_index ); let master_sm = master_sm.read().await; let snapshot = master_sm.snapshot(); // Use the correct last_included_term from snapshot metadata (Issue 2) if let Ok(_) = rpc .install_snapshot( plane_id, term, leader_id, last_snapshot_index, last_snapshot_term, snapshot, ) .await { follower.next_index = last_snapshot_index + 1; follower.match_index = last_snapshot_index; } return HeartbeatReplicationResult::Matched(follower.match_index); } let (follower_last_log_id, follower_last_log_term) = { // extract follower last log info // assumed log ids are sequence of integers let follower_last_log_id = if follower.next_index == 0 { 0 } else { follower.next_index - 1 }; if follower_last_log_id == 0 || logs.is_empty() { (0, 0) // 0 represents there is no logs in the leader } else { // detect cleaned logs (shouldn't happen now with snapshot check above) let first_log_id = match logs.iter().next() { Some((first_log_id, _)) => *first_log_id, None => { error!("Logs map is not empty on plane {} but iter().next() returned None - this should not happen", plane_id.raw()); return HeartbeatReplicationResult::Matched(follower.match_index); } }; if first_log_id > follower_last_log_id { debug!( "Taking snapshot on plane {} for follower {} (first_log: {} > follower_last: {})", plane_id.raw(), member_id, first_log_id, follower_last_log_id ); let master_sm = master_sm.read().await; let snapshot = master_sm.snapshot(); // Use last_applied as snapshot index, get term from the log at that index let snapshot_term = logs .get(&last_applied) .map(|e| e.term) .unwrap_or(last_snapshot_term); if let Ok(_) = rpc .install_snapshot( plane_id, term, leader_id, last_applied, snapshot_term, snapshot, ) .await { 
follower.next_index = last_applied + 1; follower.match_index = last_applied; } return HeartbeatReplicationResult::Matched(follower.match_index); } let follower_last_entry = logs.get(&follower_last_log_id); match follower_last_entry { Some(entry) => (entry.id, entry.term), None => { panic!("Cannot find old logs for follower, first_id: {}, follower_last: {}", first_log_id, follower_last_log_id); } } } }; let append_result = rpc .append_entries( plane_id, term, leader_id, follower_last_log_id, follower_last_log_term, &entries, commit_index, ) .await; match append_result { Ok((follower_term, result)) => match result { AppendEntriesResult::Ok => { trace!( "Log updated on plane {} to follower {}", plane_id.raw(), member_id ); if let Some(last_entries_id) = last_entries_id { follower.next_index = last_entries_id + 1; follower.match_index = last_entries_id; } } AppendEntriesResult::LogMismatch => { debug!( "Log mismatch on plane {} in follower {}, index {}", plane_id.raw(), member_id, follower.next_index ); if follower.next_index > 0 { follower.next_index -= 1; } else { debug!("Log mismatching index is zero on plane {}", plane_id.raw()); } } AppendEntriesResult::TermOut(actual_leader_id) => { debug!( "Follower {} rejected append on plane {} because follower_term={} leader_term={} actual_leader_id={} while leader {} was replicating from next_index {}", member_id, plane_id.raw(), follower_term, term, actual_leader_id, leader_id, follower.next_index ); return HeartbeatReplicationResult::TermOut { term: follower_term, leader_id: actual_leader_id, }; } }, Err(err) => { debug!( "Follower {} RPC append failed on plane {} from next_index {}: {:?}", member_id, plane_id.raw(), follower.next_index, err ); break; } // retry will happened in next heartbeat } is_retry = true; } HeartbeatReplicationResult::Matched(follower.match_index) } //check term number, return reject = false if server term is stale fn check_term_on_plane( &self, leader_flag: &AtomicBool, meta: &mut RwLockWriteGuard, 
remote_term: u64, leader_id: u64, ) -> bool { if remote_term > meta.term { self.become_follower_on_plane(leader_flag, meta, remote_term, leader_id) } else if remote_term < meta.term { return false; } return true; } fn reset_last_checked(&self, meta: &mut RwLockWriteGuard) { trace!( "Reset last checked. Elapsed: {}, id: {}, term: {}", get_time() - meta.last_checked, self.id, meta.term ); meta.last_checked = get_time(); meta.timeout = gen_timeout(); } async fn handle_append_entries_on_meta<'a>( &'a self, leader_flag: &'a AtomicBool, mut meta: RwLockWriteGuard<'a, RaftMeta>, term: u64, leader_id: u64, prev_log_id: u64, prev_log_term: u64, entries: &'a Option, leader_commit: u64, ) -> (u64, AppendEntriesResult) { self.reset_last_checked(&mut meta); let term_ok = self.check_term_on_plane(leader_flag, &mut meta, term, leader_id); let result = if term_ok { if let Membership::Candidate = meta.membership { debug!("SWITCH FROM CANDIDATE BACK TO FOLLOWER {}", self.id); self.become_follower_on_plane(leader_flag, &mut meta, term, leader_id); } if prev_log_id > 0 { check_commit(&mut meta).await; let mut logs = meta.logs.write().await; let contains_prev_log = logs.contains_key(&prev_log_id); let log_mismatch; if contains_prev_log { let entry = match logs.get(&prev_log_id) { Some(entry) => entry, None => { error!("Log key {} exists in contains_key but not in get() - data inconsistency", prev_log_id); return (meta.term, AppendEntriesResult::LogMismatch); } }; log_mismatch = entry.term != prev_log_term; } else { return (meta.term, AppendEntriesResult::LogMismatch); } if log_mismatch { let ids_to_del: Vec = logs .range((Included(prev_log_id), Unbounded)) .map(|(id, _)| *id) .collect(); for id in ids_to_del { logs.remove(&id); } return (meta.term, AppendEntriesResult::LogMismatch); } } let mut last_new_entry = std::u64::MAX; { let mut logs = meta.logs.write().await; if let Some(ref entries) = entries { for entry in entries { let entry_id = entry.id; 
logs.entry(entry_id).or_insert(entry.clone()); last_new_entry = max(last_new_entry, entry_id); } } else if !logs.is_empty() { last_new_entry = match logs.values().last() { Some(entry) => entry.id, None => { error!("Logs map is not empty but values().last() returned None - this should not happen"); std::u64::MAX } }; } if let Err(e) = self.logs_post_processing(&meta, logs).await { error!("Failed to persist logs during append_entries: {:?}", e); } } if leader_commit > meta.commit_index { meta.commit_index = min(leader_commit, last_new_entry); check_commit(&mut meta).await; } (meta.term, AppendEntriesResult::Ok) } else { (meta.term, AppendEntriesResult::TermOut(meta.leader_id)) }; self.reset_last_checked(&mut meta); result } async fn handle_request_vote_on_meta<'a>( &'a self, mut meta: RwLockWriteGuard<'a, RaftMeta>, term: u64, candidate_id: u64, last_log_id: u64, last_log_term: u64, ) -> ((u64, u64), bool) { let vote_for = meta.vote_for; let mut vote_granted = false; if term > meta.term { check_commit(&mut meta).await; let logs = meta.logs.read().await; let conf_sm = &meta.state_machine.read().await.configs; let candidate_valid = conf_sm.member_existed(candidate_id); let can_vote = vote_for.map_or(true, |voted_for| voted_for == candidate_id); if can_vote && candidate_valid { let (last_id, last_term) = get_last_log_info!(self, logs); if last_log_id >= last_id && last_log_term >= last_term { vote_granted = true; } } } if vote_granted { meta.vote_for = Some(candidate_id); } ((meta.term, meta.leader_id), vote_granted) } async fn handle_install_snapshot_on_meta<'a>( &'a self, leader_flag: &'a AtomicBool, mut meta: RwLockWriteGuard<'a, RaftMeta>, term: u64, leader_id: u64, last_included_index: u64, last_included_term: u64, data: Vec, ) -> u64 { let term_ok = self.check_term_on_plane(leader_flag, &mut meta, term, leader_id); if term_ok { check_commit(&mut meta).await; } meta.state_machine.write().await.recover(data.clone()).await; meta.last_snapshot_index = 
last_included_index; meta.last_snapshot_term = last_included_term; meta.commit_index = last_included_index; meta.last_applied = last_included_index; { let mut logs = meta.logs.write().await; logs.retain(|&id, _| id > last_included_index); } if let Some(ref storage) = meta.storage { let snapshot_entity = SnapshotEntity { last_included_index, last_included_term, snapshot: data, }; let mut storage = storage.lock().await; if let Err(e) = storage.write_snapshot(&snapshot_entity).await { error!("Failed to persist snapshot to disk: {:?}", e); } } self.reset_last_checked(&mut meta); meta.term } async fn handle_client_command_on_meta<'a>( &'a self, plane_id: PlaneId, leader_flag: &'a AtomicBool, mut meta: RwLockWriteGuard<'a, RaftMeta>, mut entry: LogEntry, ) -> ClientCmdResponse { if Self::lifecycle_is_stopping(meta.lifecycle) { return ClientCmdResponse::ShuttingDown; } if !is_leader(&meta) { let member_count = { let member_sm = meta.state_machine.read().await; member_sm.configs.members.len() }; if member_count == 1 && meta.leader_id == self.id { let last_log_id = { let logs = meta.logs.read().await; let (last_log_id, _last_log_term) = get_last_log_info!(self, logs); last_log_id }; ensure_direct_leader_term(&mut meta); self.become_leader_on_plane(leader_flag, &mut meta, last_log_id) .await; } } if !is_leader(&meta) { return if meta.leader_id == self.id { ClientCmdResponse::NotLeader(0) } else { ClientCmdResponse::NotLeader(meta.leader_id) }; } let existing_pending_entry = if entry.id > meta.commit_index { let logs = meta.logs.read().await; match logs.get(&entry.id) { Some(existing) if existing.term == entry.term && existing.sm_id == entry.sm_id && existing.fn_id == entry.fn_id && existing.data == entry.data => { Some((existing.id, existing.term)) } _ => None, } } else { None }; let (new_log_id, new_log_term) = if let Some((existing_log_id, existing_log_term)) = existing_pending_entry { (existing_log_id, existing_log_term) } else { self.leader_append_log(&meta, &mut 
entry).await }; entry.id = new_log_id; entry.term = new_log_term; let data = match entry.sm_id { CONFIG_SM_ID => Some( self.try_sync_config_to_followers_on_plane(plane_id, meta, &entry, new_log_id) .await, ), _ => { self.try_sync_log_to_followers_on_plane(plane_id, meta, &entry, new_log_id) .await } }; if let Some(data) = data { ClientCmdResponse::Success { data, last_log_id: new_log_id, last_log_term: new_log_term, } } else { ClientCmdResponse::NotCommitted { last_log_id: new_log_id, last_log_term: new_log_term, } } } async fn handle_client_query_on_meta<'a>( &'a self, meta: RwLockReadGuard<'a, RaftMeta>, entry: &'a LogEntry, ) -> ClientQryResponse { let logs = meta.logs.read().await; let (last_log_id, last_log_term) = get_last_log_info!(self, logs); if entry.term > last_log_term || entry.id > last_log_id { ClientQryResponse::LeftBehind { last_log_term, last_log_id, } } else { let qry_res = meta.state_machine.read().await.exec_qry(entry).await; ClientQryResponse::Success { data: qry_res, last_log_id, last_log_term, } } } async fn leader_append_log<'a>( &'a self, meta: &'a RwLockWriteGuard<'a, RaftMeta>, entry: &mut LogEntry, ) -> (u64, u64) { let mut logs = meta.logs.write().await; let (last_log_id, _last_log_term) = get_last_log_info!(self, logs); let new_log_id = last_log_id + 1; let new_log_term = meta.term; entry.term = new_log_term; entry.id = new_log_id; logs.insert(entry.id, entry.clone()); // Strict write-ahead: persist to WAL before any application can observe/commit if let Err(e) = self.logs_post_processing(meta, logs).await { error!( "Failed to persist log entry {} to storage on plane {}: {:?}", new_log_id, self.plane_id().raw(), e ); // Note: We still return the log ID/term even if persistence failed // The caller should handle this appropriately } (new_log_id, new_log_term) } async fn logs_post_processing<'a>( &'a self, meta: &'a RwLockWriteGuard<'a, RaftMeta>, logs: RwLockWriteGuard<'a, LogsMap>, ) -> io::Result<()> { if let Some(storage_mutex) = 
&meta.storage { let mut storage = storage_mutex.lock().await; storage.post_processing(meta, logs).await?; } Ok(()) } async fn try_sync_log_to_followers_on_plane<'a>( &'a self, plane_id: PlaneId, mut meta: RwLockWriteGuard<'a, RaftMeta>, entry: &LogEntry, new_log_id: u64, ) -> Option { debug!("Sync logs to followers on plane {}", plane_id.raw()); if self .send_followers_heartbeat_on_plane(plane_id, &mut meta, Some(new_log_id), true) .await { // Strict write-ahead: ensure persistence reflects this index before applying if let Some(storage_mutex) = &meta.storage { let mut storage = storage_mutex.lock().await; info!( "Strict WA plane {}: flushing WAL before commit at log_id={} (term={})", plane_id.raw(), new_log_id, entry.term ); let _ = storage.flush_wal().await; info!( "Strict WA plane {}: WAL fsync completed before commit at log_id={}", plane_id.raw(), new_log_id ); } meta.commit_index = new_log_id; info!( "Strict WA plane {}: applying entry at log_id={} (commit_index={})", plane_id.raw(), new_log_id, meta.commit_index ); let result = apply_committed_entry(&mut meta, entry).await; info!( "Strict WA plane {}: apply completed at log_id={} (result={:?})", plane_id.raw(), new_log_id, result ); // Mark applied and persist commit progress atomically after apply meta.last_applied = new_log_id; if let Some(storage_mutex) = &meta.storage { let mut storage = storage_mutex.lock().await; info!( "Strict WA plane {}: writing commit progress (commit_index={}, last_applied={})", plane_id.raw(), meta.commit_index, meta.last_applied ); let _ = storage .write_commit_progress(meta.commit_index, meta.last_applied) .await; info!( "Strict WA plane {}: commit progress persisted", plane_id.raw() ); } // Check if we should take a snapshot after committing let num_logs = meta.logs.read().await.len(); if self.should_take_snapshot(&meta, num_logs) { self.take_snapshot(&mut meta).await; } Some(result) } else { None } } async fn try_sync_config_to_followers_on_plane<'a>( &'a self, plane_id: 
PlaneId, mut meta: RwLockWriteGuard<'a, RaftMeta>, entry: &LogEntry, new_log_id: u64, ) -> ExecResult { // this will force followers to commit the changes debug!("Sync config to followers on plane {}", plane_id.raw()); meta.commit_index = new_log_id; let data = apply_committed_entry(&meta, &entry).await; if let Membership::Leader(ref leader_meta) = meta.membership { let mut leader_meta = leader_meta.write().await; let member_sm = meta.state_machine.read().await; let ref members = member_sm.configs.members; self.reload_leader_meta(members, &mut leader_meta, new_log_id); } self.send_followers_heartbeat_on_plane(plane_id, &mut meta, Some(new_log_id), true) .await; data } /// Check if we should take a snapshot based on configuration thresholds fn should_take_snapshot( &self, meta: &RwLockWriteGuard<'_, RaftMeta>, _num_logs: usize, ) -> bool { // Only leaders should automatically create snapshots if !is_leader(meta) { return false; } // Check if storage is configured for snapshots if meta.storage.is_none() { return false; } // Get snapshot threshold from storage options if let Storage::DISK(ref opts) = self.options.storage { let logs_since_snapshot = if meta.last_snapshot_index > 0 { meta.last_applied.saturating_sub(meta.last_snapshot_index) } else { meta.last_applied }; // Trigger snapshot if we've applied enough logs since last snapshot if logs_since_snapshot >= opts.snapshot_log_threshold { debug!( "Snapshot threshold reached on plane {}: {} logs since last snapshot (threshold: {})", self.plane_id().raw(), logs_since_snapshot, opts.snapshot_log_threshold ); return true; } } false } /// Create and persist a snapshot async fn take_snapshot(&self, meta: &mut RwLockWriteGuard<'_, RaftMeta>) { info!( "Taking snapshot on plane {} at index={}, term={}", self.plane_id().raw(), meta.last_applied, meta.term ); // Generate snapshot from state machine (Master SM decides which subs are recoverable) let snapshot_data = { let sm = meta.state_machine.read().await; sm.snapshot() }; 
// Get the term of the log at last_applied index let last_included_term = { let logs = meta.logs.read().await; logs.get(&meta.last_applied) .map(|e| e.term) .unwrap_or(meta.last_snapshot_term) }; // Create snapshot entity let snapshot_entity = SnapshotEntity { last_included_index: meta.last_applied, last_included_term, snapshot: snapshot_data, }; // Persist snapshot to disk if let Some(ref storage) = meta.storage { let storage_clone = storage.clone(); let mut storage_guard = storage_clone.lock().await; match storage_guard.write_snapshot(&snapshot_entity).await { Ok(_) => { // Update snapshot metadata FIRST so compaction can use it meta.last_snapshot_index = snapshot_entity.last_included_index; meta.last_snapshot_term = snapshot_entity.last_included_term; info!( "Snapshot created successfully on plane {} at index={}, term={}", self.plane_id().raw(), meta.last_snapshot_index, meta.last_snapshot_term ); // Now compact logs (reads meta.last_snapshot_index) self.compact_logs_after_snapshot(meta, storage_guard).await; } Err(e) => { error!( "Failed to persist snapshot on plane {}: {:?}", self.plane_id().raw(), e ); } } } } /// Compact logs after a snapshot has been created async fn compact_logs_after_snapshot( &self, meta: &RwLockWriteGuard<'_, RaftMeta>, mut _storage: async_std::sync::MutexGuard<'_, disk::StorageEntity>, ) { if let Storage::DISK(ref opts) = self.options.storage { let snapshot_index = meta.last_snapshot_index; let compaction_threshold = opts.log_compaction_threshold; let mut logs = meta.logs.write().await; let before_count = logs.len(); debug!( "Compaction check on plane {}: {} logs, threshold: {}, snapshot_index: {}", self.plane_id().raw(), before_count, compaction_threshold, snapshot_index ); // Only compact if we exceed the compaction threshold if before_count as u64 > compaction_threshold { // Keep logs after last_snapshot_index logs.retain(|&id, _| id > snapshot_index); let after_count = logs.len(); info!( "Compacted {} logs on plane {} (from {} to 
{}), keeping logs after index {}", before_count - after_count, self.plane_id().raw(), before_count, after_count, snapshot_index ); } else { info!( "Skipping log compaction on plane {}: {} logs <= threshold {}", self.plane_id().raw(), before_count, compaction_threshold ); } } else { debug!( "Not using disk storage on plane {}, skipping compaction", self.plane_id().raw() ); } } } impl Service for RaftService { fn append_entries<'a>( &'a self, plane_id: PlaneId, term: u64, leader_id: u64, prev_log_id: u64, prev_log_term: u64, entries: &'a Option, leader_commit: u64, ) -> BoxFuture<'a, (u64, AppendEntriesResult)> { async move { match self.resolve_plane_runtime(plane_id, true, false).await { Ok((Some(runtime), _)) => { let meta = runtime.meta.write().await; self.handle_append_entries_on_meta( &runtime.is_leader, meta, term, leader_id, prev_log_id, prev_log_term, entries, leader_commit, ) .await } Ok((None, _)) => { let meta = self.write_meta().await; self.handle_append_entries_on_meta( &self._is_leader, meta, term, leader_id, prev_log_id, prev_log_term, entries, leader_commit, ) .await } Err(err) => { warn!( "Rejecting append_entries for plane {}: {}", plane_id.raw(), err ); (0, AppendEntriesResult::LogMismatch) } } } .boxed() } fn request_vote( &self, plane_id: PlaneId, term: u64, candidate_id: u64, last_log_id: u64, last_log_term: u64, ) -> BoxFuture<((u64, u64), bool)> { async move { match self.resolve_plane_runtime(plane_id, true, false).await { Ok((Some(runtime), _)) => { let meta = runtime.meta.write().await; self.handle_request_vote_on_meta( meta, term, candidate_id, last_log_id, last_log_term, ) .await } Ok((None, _)) => { let meta = self.write_meta().await; self.handle_request_vote_on_meta( meta, term, candidate_id, last_log_id, last_log_term, ) .await } Err(err) => { warn!( "Rejecting request_vote for plane {}: {}", plane_id.raw(), err ); ((0, self.get_server_id()), false) } } } .boxed() } fn install_snapshot( &self, plane_id: PlaneId, term: u64, leader_id: 
/// RPC entry point for a client query addressed to `plane_id`.
///
/// Resolves the plane, takes a read lock on the matching metadata, and delegates to
/// `handle_client_query_on_meta`. If the plane cannot be resolved the error is logged and
/// `ClientQryResponse::LeftBehind` with zeroed log position is returned.
fn c_query<'a>(
    &'a self,
    plane_id: PlaneId,
    entry: &'a LogEntry,
) -> BoxFuture<'a, ClientQryResponse> {
    async move {
        // NOTE(review): the meaning of the two boolean flags is defined by
        // `resolve_plane_runtime`, which is not visible here - confirm before changing them.
        match self.resolve_plane_runtime(plane_id, false, false).await {
            // A dedicated runtime exists for this plane: query its own metadata.
            Ok((Some(runtime), _)) => {
                let meta = runtime.meta.read().await;
                self.handle_client_query_on_meta(meta, entry).await
            }
            // No dedicated runtime: fall back to the service-level metadata.
            Ok((None, _)) => {
                let meta = self.meta.read().await;
                self.handle_client_query_on_meta(meta, entry).await
            }
            Err(err) => {
                warn!(
                    "Rejecting client query for plane {}: {}",
                    plane_id.raw(),
                    err
                );
                ClientQryResponse::LeftBehind {
                    last_log_term: 0,
                    last_log_id: 0,
                }
            }
        }
    }
    .boxed()
}
/// RPC entry point for a ping: completes immediately with `()` and touches no state.
fn c_ping(&self) -> BoxFuture<()> {
    future::ready(()).boxed()
}
Server id {}", plane_id.raw(), self.get_server_id() ); self.become_candidate(&mut meta).await; let is_leader = self.is_leader(); info!( "Reelect result for plane {} server {}, is leader {}", plane_id.raw(), self.get_server_id(), is_leader ); is_leader } Err(err) => { warn!("Rejecting reelect for plane {}: {}", plane_id.raw(), err); false } } } .boxed() } } pub struct RaftStateMachine { pub id: u64, pub name: String, } impl RaftStateMachine { pub fn new(name: &String) -> RaftStateMachine { RaftStateMachine { id: hash_str(name), name: name.clone(), } } } #[cfg(test)] mod test { use self::client::SMClient; use self::commands::{add, get}; use crate::raft::client::RaftClient; use crate::raft::disk; use crate::raft::state_machine::master::ExecError; use crate::raft::state_machine::StateMachineCtl; use crate::raft::{ ClientCmdResponse, ClientQryResponse, LogEntry, Options, PlaneBootstrap, PlaneBootstrapError, PlaneHandle, PlaneId, PlaneSpec, RaftMsg, RaftService, Service as RaftRpcService, Storage, DEFAULT_SERVICE_ID, }; use crate::rpc::Server; use crate::utils::time::async_wait_secs; use futures::FutureExt; use std::sync::Arc; struct CounterStateMachine { value: u64, } raft_state_machine! 
{ def cmd add(value: u64) -> u64; def qry get() -> u64; } impl StateMachineCmds for CounterStateMachine { fn add(&mut self, value: u64) -> BoxFuture { self.value += value; futures::future::ready(self.value).boxed() } fn get(&self) -> BoxFuture { futures::future::ready(self.value).boxed() } } impl StateMachineCtl for CounterStateMachine { raft_sm_complete!(); fn id(&self) -> u64 { 77 } fn snapshot(&self) -> Vec { Vec::new() } fn recover(&mut self, _: Vec) -> BoxFuture<()> { futures::future::ready(()).boxed() } fn recoverable(&self) -> bool { false } } struct PersistentCounterStateMachine { value: u64, } impl StateMachineCmds for PersistentCounterStateMachine { fn add(&mut self, value: u64) -> BoxFuture { self.value += value; futures::future::ready(self.value).boxed() } fn get(&self) -> BoxFuture { futures::future::ready(self.value).boxed() } } impl StateMachineCtl for PersistentCounterStateMachine { raft_sm_complete!(); fn id(&self) -> u64 { 88 } fn snapshot(&self) -> Vec { crate::utils::serde::serialize(&self.value) } fn recover(&mut self, data: Vec) -> BoxFuture<()> { if let Some(value) = crate::utils::serde::deserialize(&data) { self.value = value; } futures::future::ready(()).boxed() } fn recoverable(&self) -> bool { true } } async fn wait_for_plane(service: &Arc, plane_id: PlaneId) -> PlaneHandle { for _ in 0..5 { if let Ok(plane) = service.plane(plane_id).await { return plane; } async_wait_secs().await; } panic!("plane {} was not materialized in time", plane_id.raw()); } async fn query_counter_locally( service: &Arc, plane_id: PlaneId, sm_id: u64, ) -> u64 { let (fn_id, _, data) = get::new().encode(); let entry = LogEntry { id: 0, term: 0, sm_id, fn_id, data, }; match RaftRpcService::c_query(service.as_ref(), plane_id, &entry).await { ClientQryResponse::Success { data: Ok(data), .. 
} => get::decode_return(&data), other => panic!("unexpected local query response: {:?}", other), } } async fn add_counter_locally( service: &Arc, plane_id: PlaneId, sm_id: u64, value: u64, ) -> u64 { let (fn_id, _, data) = add::new(&value).encode(); let entry = LogEntry { id: 0, term: 0, sm_id, fn_id, data, }; match RaftRpcService::c_command(service.as_ref(), plane_id, entry).await { ClientCmdResponse::Success { data: Ok(data), .. } => add::decode_return(&data), other => panic!("unexpected local command response: {:?}", other), } } #[tokio::test(flavor = "multi_thread")] async fn startup() { let (success, _, _) = RaftService::new_server(Options { storage: Storage::default(), address: String::from("127.0.0.1:2000"), service_id: DEFAULT_SERVICE_ID, }) .await; assert!(success); } #[tokio::test(flavor = "multi_thread")] async fn type2_plane_client_roundtrip() { let port = 4210 + (rand::random::() % 200); let addr = format!("127.0.0.1:{}", port); let service = RaftService::new(Options { storage: Storage::default(), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&addr); server.register_service(&service).await; Server::listen_and_resume(&server).await; assert!(RaftService::start(&service, false).await); service.bootstrap().await; let plane_id = PlaneId::type2(7).unwrap(); let plane = service .ensure_plane(PlaneSpec { plane_id }) .await .expect("plane should be created"); plane .register_state_machine(Box::new(CounterStateMachine { value: 0 })) .await .expect("state machine should register on type-2 plane"); plane .recover_after_register() .await .expect("type-2 plane should replay committed logs"); let client = RaftClient::new(&vec![addr.clone()], DEFAULT_SERVICE_ID) .await .expect("raft client should connect"); let plane_client = client.plane(plane_id); let sm_client = SMClient::new(77, &plane_client); assert_eq!(sm_client.add(&5).await.unwrap(), 5); assert_eq!(sm_client.get().await.unwrap(), 5); 
assert!(plane.have_state_machine(77).await.unwrap()); assert!(plane_client.have_state_machine(77).await.unwrap()); assert_eq!( plane_client.cluster_info().await.unwrap().leader_id, service.id ); } #[tokio::test(flavor = "multi_thread")] async fn loaded_type2_planes_only_reports_materialized_type2_runtimes() { let port = 4410 + (rand::random::() % 200); let addr = format!("127.0.0.1:{}", port); let service = RaftService::new(Options { storage: Storage::default(), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&addr); server.register_service(&service).await; Server::listen_and_resume(&server).await; assert!(RaftService::start(&service, false).await); service.bootstrap().await; assert!(service.loaded_type2_planes().await.is_empty()); let plane_id = PlaneId::type2(9).unwrap(); service .ensure_plane(PlaneSpec { plane_id }) .await .expect("type-2 plane should materialize"); assert_eq!(service.loaded_type2_planes().await, vec![plane_id]); } #[tokio::test(flavor = "multi_thread")] async fn type2_plane_shutdown_rejects_commands() { let port = 4610 + (rand::random::() % 200); let addr = format!("127.0.0.1:{}", port); let service = RaftService::new(Options { storage: Storage::default(), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&addr); server.register_service(&service).await; Server::listen_and_resume(&server).await; assert!(RaftService::start(&service, false).await); service.bootstrap().await; let plane_id = PlaneId::type2(8).unwrap(); let plane = service .ensure_plane(PlaneSpec { plane_id }) .await .expect("plane should be created"); plane .register_state_machine(Box::new(CounterStateMachine { value: 0 })) .await .expect("state machine should register on type-2 plane"); plane .recover_after_register() .await .expect("type-2 plane should replay committed logs"); let client = RaftClient::new(&vec![addr], DEFAULT_SERVICE_ID) .await .expect("raft client should connect"); let plane_client = 
client.plane(plane_id); let sm_client = SMClient::new(77, &plane_client); assert_eq!(sm_client.add(&5).await.unwrap(), 5); plane .shutdown() .await .expect("type-2 plane should shut down"); assert!(!plane.is_leader().await.unwrap()); assert!(matches!( sm_client.add(&1).await, Err(ExecError::ShuttingDown) )); } #[tokio::test(flavor = "multi_thread")] async fn unknown_type2_plane_does_not_fall_back_to_type1() { let addr = String::from("127.0.0.1:22120"); let service = RaftService::new(Options { storage: Storage::default(), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&addr); server.register_service(&service).await; Server::listen_and_resume(&server).await; assert!(RaftService::start(&service, false).await); service.bootstrap().await; service .register_state_machine(Box::new(CounterStateMachine { value: 0 })) .await; service.recover_after_register().await; let client = RaftClient::new(&vec![addr], DEFAULT_SERVICE_ID) .await .expect("raft client should connect"); let unknown_plane = client.plane(PlaneId::type2(999).unwrap()); assert!(!unknown_plane.have_state_machine(77).await.unwrap()); } #[tokio::test(flavor = "multi_thread")] async fn unknown_type2_plane_rejects_commands_without_leader_discovery_loop() { let addr = String::from("127.0.0.1:22122"); let service = RaftService::new(Options { storage: Storage::default(), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&addr); server.register_service(&service).await; Server::listen_and_resume(&server).await; assert!(RaftService::start(&service, false).await); service.bootstrap().await; let client = RaftClient::new(&vec![addr], DEFAULT_SERVICE_ID) .await .expect("raft client should connect"); let unknown_plane = client.plane(PlaneId::type2(1001).unwrap()); let sm_client = SMClient::new(77, &unknown_plane); assert!(matches!( sm_client.add(&1).await, Err(ExecError::ShuttingDown) )); } #[tokio::test(flavor = "multi_thread")] async fn 
root_membership_helpers_hide_config_sm_commands() { let addr = String::from("127.0.0.1:22123"); let extra_addr = String::from("127.0.0.1:22124"); let service = RaftService::new(Options { storage: Storage::default(), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&addr); server.register_service(&service).await; Server::listen_and_resume(&server).await; assert!(RaftService::start(&service, false).await); service.bootstrap().await; let extra_service = RaftService::new(Options { storage: Storage::default(), address: extra_addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let extra_server = Server::new(&extra_addr); extra_server.register_service(&extra_service).await; Server::listen_and_resume(&extra_server).await; assert!(RaftService::start(&extra_service, false).await); let client = RaftClient::new(&vec![addr.clone()], DEFAULT_SERVICE_ID) .await .expect("raft client should connect"); let members = client .root_member_addresses() .await .expect("root member list should be readable"); assert_eq!(members.len(), 1); assert_eq!(members[0], addr); assert!(client .add_root_member(&extra_addr) .await .expect("adding root member should succeed")); let members = client .root_member_addresses() .await .expect("root member list should include new member"); assert!(members.iter().any(|member| member == &extra_addr)); client .remove_root_member(&extra_addr) .await .expect("removing root member should succeed"); let members = client .root_member_addresses() .await .expect("root member list should be readable after removal"); assert!(!members.iter().any(|member| member == &extra_addr)); } #[tokio::test(flavor = "multi_thread")] async fn type2_plane_bootstrap_is_idempotent_for_same_members() { let addr = String::from("127.0.0.1:22125"); let service = RaftService::new(Options { storage: Storage::default(), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&addr); server.register_service(&service).await; 
Server::listen_and_resume(&server).await; assert!(RaftService::start(&service, false).await); service.bootstrap().await; let plane_id = PlaneId::type2(43).unwrap(); service .ensure_plane_from_seeds(PlaneBootstrap { plane_id, seed_nodes: vec![addr.clone()], }) .await .expect("initial type-2 bootstrap should succeed"); let plane = service .ensure_plane_from_seeds(PlaneBootstrap { plane_id, seed_nodes: vec![addr.clone(), addr.clone()], }) .await .expect("repeating type-2 bootstrap with same members should be idempotent"); let info = plane .cluster_info() .await .expect("type-2 cluster info should be available"); assert_eq!(info.members.len(), 1); assert_eq!(info.members[0].1, addr); } #[tokio::test(flavor = "multi_thread")] async fn type2_plane_bootstrap_rejects_conflicting_member_set() { let addr = String::from("127.0.0.1:22126"); let extra_addr = String::from("127.0.0.1:22127"); let service = RaftService::new(Options { storage: Storage::default(), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&addr); server.register_service(&service).await; Server::listen_and_resume(&server).await; assert!(RaftService::start(&service, false).await); service.bootstrap().await; let extra_service = RaftService::new(Options { storage: Storage::default(), address: extra_addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let extra_server = Server::new(&extra_addr); extra_server.register_service(&extra_service).await; Server::listen_and_resume(&extra_server).await; assert!(RaftService::start(&extra_service, false).await); extra_service.join(&vec![addr.clone()]).await.unwrap(); async_wait_secs().await; let plane_id = PlaneId::type2(44).unwrap(); service .ensure_plane_from_seeds(PlaneBootstrap { plane_id, seed_nodes: vec![addr.clone()], }) .await .expect("initial type-2 bootstrap should succeed"); assert!(extra_service.leave().await); async_wait_secs().await; async_wait_secs().await; let err = match service .ensure_plane_from_seeds(PlaneBootstrap { 
plane_id, seed_nodes: vec![addr.clone()], }) .await { Ok(_) => panic!("conflicting type-2 bootstrap should be rejected"), Err(err) => err, }; match err { PlaneBootstrapError::MembershipConflict { plane_id: conflict_plane_id, current_members, requested_members, } => { assert_eq!(conflict_plane_id, plane_id); assert_eq!(current_members, vec![addr, extra_addr]); assert_eq!(requested_members, vec![String::from("127.0.0.1:22126")]); } other => panic!("unexpected bootstrap error: {:?}", other), } } #[tokio::test(flavor = "multi_thread")] async fn type2_plane_multinode_replication_and_reelection() { let _ = env_logger::try_init(); let addr1 = String::from("127.0.0.1:22130"); let addr2 = String::from("127.0.0.1:22131"); let addr3 = String::from("127.0.0.1:22132"); let service1 = RaftService::new(Options { storage: Storage::default(), address: addr1.clone(), service_id: DEFAULT_SERVICE_ID, }); let server1 = Server::new(&addr1); server1.register_service(&service1).await; Server::listen_and_resume(&server1).await; assert!(RaftService::start(&service1, false).await); let service2 = RaftService::new(Options { storage: Storage::default(), address: addr2.clone(), service_id: DEFAULT_SERVICE_ID, }); let server2 = Server::new(&addr2); server2.register_service(&service2).await; Server::listen_and_resume(&server2).await; assert!(RaftService::start(&service2, false).await); let service3 = RaftService::new(Options { storage: Storage::default(), address: addr3.clone(), service_id: DEFAULT_SERVICE_ID, }); let server3 = Server::new(&addr3); server3.register_service(&service3).await; Server::listen_and_resume(&server3).await; assert!(RaftService::start(&service3, false).await); service1.bootstrap().await; service2.join(&vec![addr1.clone()]).await.unwrap(); service3 .join(&vec![addr1.clone(), addr2.clone()]) .await .unwrap(); async_wait_secs().await; let plane_id = PlaneId::type2(41).unwrap(); let leader_plane = service1 .ensure_plane_from_seeds(PlaneBootstrap { plane_id, seed_nodes: 
vec![addr1.clone()], }) .await .expect("type-2 leader plane should be created"); leader_plane .register_state_machine(Box::new(CounterStateMachine { value: 0 })) .await .expect("leader state machine should register"); leader_plane .recover_after_register() .await .expect("leader plane should recover after register"); let client = RaftClient::new( &vec![addr1.clone(), addr2.clone(), addr3.clone()], DEFAULT_SERVICE_ID, ) .await .expect("raft client should connect"); let plane_client = client.plane(plane_id); async_wait_secs().await; async_wait_secs().await; let follower_plane2 = wait_for_plane(&service2, plane_id).await; follower_plane2 .register_state_machine(Box::new(CounterStateMachine { value: 0 })) .await .expect("follower 2 state machine should register"); follower_plane2 .recover_after_register() .await .expect("follower 2 plane should recover after register"); let follower_plane3 = wait_for_plane(&service3, plane_id).await; follower_plane3 .register_state_machine(Box::new(CounterStateMachine { value: 0 })) .await .expect("follower 3 state machine should register"); follower_plane3 .recover_after_register() .await .expect("follower 3 plane should recover after register"); let sm_client = SMClient::new(77, &plane_client); assert_eq!(sm_client.add(&5).await.unwrap(), 5); async_wait_secs().await; assert_eq!(query_counter_locally(&service2, plane_id, 77).await, 5); assert_eq!(query_counter_locally(&service3, plane_id, 77).await, 5); leader_plane .shutdown() .await .expect("leader plane should shut down cleanly"); let candidate2 = RaftRpcService::reelect(service2.as_ref(), plane_id).await; let candidate3 = if candidate2 { false } else { RaftRpcService::reelect(service3.as_ref(), plane_id).await }; assert!( candidate2 || candidate3, "one follower should win the re-election" ); async_wait_secs().await; let plane2_after = wait_for_plane(&service2, plane_id).await; let plane3_after = wait_for_plane(&service3, plane_id).await; let plane2_is_leader = 
plane2_after.is_leader().await.unwrap(); let plane3_is_leader = plane3_after.is_leader().await.unwrap(); assert_ne!(plane2_is_leader, plane3_is_leader); let new_leader = if plane2_is_leader { &service2 } else { &service3 }; let new_follower = if plane2_is_leader { &service3 } else { &service2 }; assert_eq!(add_counter_locally(new_leader, plane_id, 77, 2).await, 7); async_wait_secs().await; assert_eq!(query_counter_locally(new_follower, plane_id, 77).await, 7); } #[tokio::test(flavor = "multi_thread")] async fn type2_plane_multinode_follower_restart_recovers_without_eager_load() { let _ = env_logger::try_init(); let dir1 = std::env::temp_dir().join(format!( "raft_type2_multi_recover_1_{}", rand::random::() )); let dir2 = std::env::temp_dir().join(format!( "raft_type2_multi_recover_2_{}", rand::random::() )); std::fs::create_dir_all(&dir1).unwrap(); std::fs::create_dir_all(&dir2).unwrap(); let addr1 = String::from("127.0.0.1:22140"); let addr2 = String::from("127.0.0.1:22141"); let service1 = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: dir1.to_string_lossy().to_string(), take_snapshots: false, append_logs: true, trim_logs: true, snapshot_log_threshold: 1000, log_compaction_threshold: 2000, }), address: addr1.clone(), service_id: DEFAULT_SERVICE_ID, }); let server1 = Server::new(&addr1); server1.register_service(&service1).await; Server::listen_and_resume(&server1).await; assert!(RaftService::start(&service1, false).await); let service2 = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: dir2.to_string_lossy().to_string(), take_snapshots: false, append_logs: true, trim_logs: true, snapshot_log_threshold: 1000, log_compaction_threshold: 2000, }), address: addr2.clone(), service_id: DEFAULT_SERVICE_ID, }); let server2 = Server::new(&addr2); server2.register_service(&service2).await; Server::listen_and_resume(&server2).await; assert!(RaftService::start(&service2, false).await); service1.bootstrap().await; 
service2.join(&vec![addr1.clone()]).await.unwrap(); async_wait_secs().await; let plane_id = PlaneId::type2(42).unwrap(); let leader_plane = service1 .ensure_plane_from_seeds(PlaneBootstrap { plane_id, seed_nodes: vec![addr1.clone()], }) .await .expect("type-2 leader plane should be created"); leader_plane .register_state_machine(Box::new(PersistentCounterStateMachine { value: 0 })) .await .expect("leader persistent state machine should register"); leader_plane .recover_after_register() .await .expect("leader plane should recover after register"); let client = RaftClient::new(&vec![addr1.clone(), addr2.clone()], DEFAULT_SERVICE_ID) .await .expect("raft client should connect"); let plane_client = client.plane(plane_id); async_wait_secs().await; async_wait_secs().await; let follower_plane = wait_for_plane(&service2, plane_id).await; follower_plane .register_state_machine(Box::new(PersistentCounterStateMachine { value: 0 })) .await .expect("follower persistent state machine should register"); follower_plane .recover_after_register() .await .expect("follower plane should recover after register"); let sm_client = SMClient::new(88, &plane_client); assert_eq!(sm_client.add(&7).await.unwrap(), 7); async_wait_secs().await; assert_eq!(query_counter_locally(&service2, plane_id, 88).await, 7); follower_plane .flush_persistence() .await .expect("follower persistence should flush"); follower_plane .shutdown() .await .expect("follower plane should shut down cleanly"); { let mut planes = service2.planes.write().await; assert!(planes.remove(&plane_id).is_some()); } let reloaded_follower = service2 .plane(plane_id) .await .expect("persisted follower plane should lazy load from disk"); reloaded_follower .register_state_machine(Box::new(PersistentCounterStateMachine { value: 0 })) .await .expect("reloaded follower state machine should register"); reloaded_follower .recover_after_register() .await .expect("reloaded follower plane should replay persisted logs"); assert!( 
!reloaded_follower.is_leader().await.unwrap(), "reloaded follower should not self-promote during lazy recovery" ); async_wait_secs().await; assert_eq!(query_counter_locally(&service2, plane_id, 88).await, 7); assert_eq!(sm_client.add(&1).await.unwrap(), 8); async_wait_secs().await; assert_eq!(query_counter_locally(&service2, plane_id, 88).await, 8); let _ = std::fs::remove_dir_all(&dir1); let _ = std::fs::remove_dir_all(&dir2); } #[tokio::test(flavor = "multi_thread")] async fn type2_plane_lazy_loads_after_unload() { let temp_dir = std::env::temp_dir().join(format!("raft_type2_lazy_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let addr = String::from("127.0.0.1:22121"); let service = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: temp_dir.to_string_lossy().to_string(), take_snapshots: false, append_logs: true, trim_logs: true, snapshot_log_threshold: 1000, log_compaction_threshold: 2000, }), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&addr); server.register_service(&service).await; Server::listen_and_resume(&server).await; assert!(RaftService::start(&service, false).await); service.bootstrap().await; let hot_plane_id = PlaneId::type2(31).unwrap(); let cold_plane_id = PlaneId::type2(32).unwrap(); let hot_plane = service .ensure_plane(PlaneSpec { plane_id: hot_plane_id, }) .await .expect("hot plane should be created"); hot_plane .register_state_machine(Box::new(CounterStateMachine { value: 0 })) .await .expect("state machine should register on hot plane"); hot_plane .recover_after_register() .await .expect("hot plane should recover after register"); let cold_plane = service .ensure_plane(PlaneSpec { plane_id: cold_plane_id, }) .await .expect("cold plane should be created"); cold_plane .register_state_machine(Box::new(PersistentCounterStateMachine { value: 0 })) .await .expect("state machine should register on cold plane"); cold_plane .recover_after_register() .await 
.expect("cold plane should recover after register"); let client = RaftClient::new(&vec![addr.clone()], DEFAULT_SERVICE_ID) .await .expect("shared raft client should connect"); let hot_client = client.plane(hot_plane_id); let cold_client = client.plane(cold_plane_id); let hot_sm = SMClient::new(77, &hot_client); let cold_sm = SMClient::new(88, &cold_client); assert_eq!(hot_sm.add(&3).await.unwrap(), 3); assert_eq!(cold_sm.add(&7).await.unwrap(), 7); cold_plane .flush_persistence() .await .expect("cold plane persistence should flush"); cold_plane .shutdown() .await .expect("cold plane should shut down cleanly"); { let mut planes = service.planes.write().await; assert!(planes.remove(&cold_plane_id).is_some()); } assert_eq!(hot_sm.add(&2).await.unwrap(), 5); let reloaded_cold_plane = service .plane(cold_plane_id) .await .expect("persisted cold plane should lazy load from disk"); reloaded_cold_plane .register_state_machine(Box::new(PersistentCounterStateMachine { value: 0 })) .await .expect("persistent state machine should register after lazy load"); reloaded_cold_plane .recover_after_register() .await .expect("lazy-loaded plane should replay persisted logs"); let reloaded_cold_client = client.plane(cold_plane_id); let reloaded_cold_sm = SMClient::new(88, &reloaded_cold_client); assert!(reloaded_cold_plane.is_leader().await.unwrap()); assert_eq!(reloaded_cold_sm.get().await.unwrap(), 7); assert_eq!(reloaded_cold_sm.add(&1).await.unwrap(), 8); let _ = std::fs::remove_dir_all(&temp_dir); } #[tokio::test(flavor = "multi_thread")] async fn server_membership() { let _ = env_logger::try_init(); let s1_addr = String::from("127.0.0.1:2001"); let s2_addr = String::from("127.0.0.1:2002"); let s3_addr = String::from("127.0.0.1:2003"); let service1 = RaftService::new(Options { storage: Storage::default(), address: s1_addr.clone(), service_id: DEFAULT_SERVICE_ID, }); info!("Starting server 1"); let server1 = Server::new(&s1_addr); info!("Register raft service for server 1"); 
server1.register_service(&service1).await;
info!("Listening server 1");
Server::listen_and_resume(&server1).await;
info!("Start raft service server 1");
assert!(RaftService::start(&service1, false).await);
info!("Bootstrap raft service server 1");
service1.bootstrap().await;
// A freshly bootstrapped node is the only member of its cluster.
let num_members = service1.num_members().await;
assert_eq!(num_members, 1);

// --- Server 2: start it and join through server 1 ---
info!("Starting server 2");
let server2 = Server::new(&s2_addr);
info!("Register raft service for server 2");
let service2 = RaftService::new(Options {
    storage: Storage::default(),
    address: s2_addr.clone(),
    service_id: DEFAULT_SERVICE_ID,
});
server2.register_service(&service2).await;
info!("Listening server 2");
Server::listen_and_resume(&server2).await;
info!("Start raft service for server 2");
assert!(RaftService::start(&service2, false).await);
info!("Server 2 join with server 1");
let join_result = service2.join(&vec![s1_addr.clone()]).await;
match join_result {
    Err(ExecError::ServersUnreachable) => panic!("Server unreachable"),
    Err(ExecError::CannotConstructClient) => panic!("Cannot Construct Client"),
    // Format the error into the panic message. Passing the value itself as
    // the sole `panic!` argument produces a non-string payload, which the
    // test harness reports as an opaque `Box<dyn Any>` and which newer
    // editions reject outright.
    Err(e) => panic!("{:?}", e),
    Ok(join_success) => assert!(join_success),
}
assert!(join_result.is_ok());
info!("Checking number of members in both side");
assert_eq!(service1.num_members().await, 2);
assert_eq!(service2.num_members().await, 2);

// --- Server 3: start it and join through servers 1 and 2 ---
info!("Starting server 3");
let service3 = RaftService::new(Options {
    storage: Storage::default(),
    address: s3_addr.clone(),
    service_id: DEFAULT_SERVICE_ID,
});
let server3 = Server::new(&s3_addr);
Server::listen_and_resume(&server3).await;
info!("Register raft service for server 3");
server3.register_service(&service3).await;
info!("Start raft service for server 3");
assert!(RaftService::start(&service3, false).await);
info!("Server 3 join server 1 and server 2");
let join_result = service3.join(&vec![s1_addr.clone(), s2_addr.clone()]).await;
assert!(join_result.unwrap());
info!("Checking numbers of users on 3 servers");
assert_eq!(service1.num_members().await, 3);
assert_eq!(service2.num_members().await, 3); assert_eq!(service3.num_members().await, 3); async_wait_secs().await; // test remove member info!( "Server 1 ({}) is leaving, leader {}", service1.id, service1.leader_id().await ); assert!(service1.leave().await); async_wait_secs().await; info!("Check number of servers, should be 2"); assert_eq!(service2.num_members().await, 2); assert_eq!(service3.num_members().await, 2); async_wait_secs().await; info!( "Server 2 ({}) is leaving, leader {}", server2.server_id, service2.leader_id().await ); assert!(service2.leave().await); // there will be some unavailability in leader transaction async_wait_secs().await; async_wait_secs().await; async_wait_secs().await; assert_eq!(service3.num_members().await, 1); } #[tokio::test(flavor = "multi_thread")] async fn log_replication() { let _ = env_logger::try_init(); info!("Testing log replications"); let s1_addr = String::from("127.0.0.1:2004"); let s2_addr = String::from("127.0.0.1:2005"); let s3_addr = String::from("127.0.0.1:2006"); let s4_addr = String::from("127.0.0.1:2007"); let s5_addr = String::from("127.0.0.1:2008"); let service1 = RaftService::new(Options { storage: Storage::default(), address: s1_addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let service2 = RaftService::new(Options { storage: Storage::default(), address: s2_addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let service3 = RaftService::new(Options { storage: Storage::default(), address: s3_addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let service4 = RaftService::new(Options { storage: Storage::default(), address: s4_addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let service5 = RaftService::new(Options { storage: Storage::default(), address: s5_addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let server_list = vec![ s1_addr.clone(), s2_addr.clone(), s3_addr.clone(), s4_addr.clone(), ]; info!("Start server 1"); let server1 = Server::new(&s1_addr); info!("Register raft service for server 1"); 
server1.register_service(&service1).await; info!("Listen server 1"); Server::listen_and_resume(&server1).await; info!("Starting raft service for server 1"); assert!(RaftService::start(&service1, false).await); info!("Bootstrap raft for server 1"); assert_eq!(service1.probe_and_join(&server_list).await.unwrap(), false); info!("Starting server 2"); let server2 = Server::new(&s2_addr); info!("Listening server 2"); Server::listen_and_resume(&server2).await; info!("Register raft service for server 2"); server2.register_service(&service2).await; info!("Start raft service for server 2"); assert!(RaftService::start(&service2, false).await); info!("Server 2 join cluster"); let join_result = service2.probe_and_join(&server_list).await; join_result.unwrap(); info!("Starting server 3"); let server3 = Server::new(&s3_addr); info!("Register raft service for server 3"); server3.register_service(&service3).await; info!("Listening for server 3"); Server::listen_and_resume(&server3).await; info!("Starting raft service for server 3"); assert!(RaftService::start(&service3, false).await); info!("Server 3 join the cluster"); let join_result = service3.probe_and_join(&server_list).await; join_result.unwrap(); info!("Starting server 4"); let server4 = Server::new(&s4_addr); info!("Register raft service for server 4"); server4.register_service(&service4).await; info!("Listening for server 4"); Server::listen_and_resume(&server4).await; info!("Starting raft service for server 4"); assert!(RaftService::start(&service4, false).await); info!("Server 4 join cluster"); let join_result = service4.probe_and_join(&server_list).await; join_result.unwrap(); info!("Starting server 5"); let server5 = Server::new(&s5_addr); info!("Register raft service for server 5"); server5.register_service(&service5).await; info!("Listening for server 5"); Server::listen_and_resume(&server5).await; info!("Starting raft service for server 5"); assert!(RaftService::start(&service5, false).await); info!("Server 5 join 
cluster"); let join_result = service5.probe_and_join(&server_list).await; join_result.unwrap(); info!("Waiting for seconds for consistency check"); async_wait_secs().await; // wait for membership replication to take effect async_wait_secs().await; async_wait_secs().await; info!("Number of logs should be the same"); assert_eq!(service1.num_logs().await, service2.num_logs().await); assert_eq!(service2.num_logs().await, service3.num_logs().await); assert_eq!(service3.num_logs().await, service4.num_logs().await); assert_eq!(service4.num_logs().await, service5.num_logs().await); assert_eq!(service5.num_logs().await, 4); // check all logs replicated info!("All servers should have the same leader id on record"); assert_eq!(service1.leader_id().await, service1.id); assert_eq!(service2.leader_id().await, service1.id); assert_eq!(service3.leader_id().await, service1.id); assert_eq!(service4.leader_id().await, service1.id); assert_eq!(service5.leader_id().await, service1.id); } mod state_machine { use super::*; use crate::raft::client::RaftClient; use crate::raft::disk; use crate::raft::state_machine::configs::CONFIG_SM_ID; use crate::raft::state_machine::master::MasterStateMachine; use crate::raft::{ LifecycleState, LogEntry, Membership, RaftMeta, Service, SnapshotEntity, }; use crate::utils::time::async_wait; use futures::stream::FuturesUnordered; use std::collections::BTreeMap; use std::sync::Arc; use std::time::Duration; raft_state_machine! 
{ def qry answer_to_the_universe(name: String) -> String; def qry get_shot() -> i32; def cmd take_a_shot(num: i32) -> i32; } struct SM { shots: i32, } impl StateMachineCmds for SM { fn answer_to_the_universe<'a>(&'a self, name: String) -> BoxFuture<'_, String> { future::ready(format!("{}, the answer is 42", name)).boxed() } fn take_a_shot(&mut self, num: i32) -> BoxFuture { self.shots -= num; info!("Shot...{}...now...{}", num, self.shots); future::ready(self.shots).boxed() } fn get_shot(&self) -> BoxFuture { future::ready(self.shots).boxed() } } impl StateMachineCtl for SM { raft_sm_complete!(); fn id(&self) -> u64 { 15 } fn snapshot(&self) -> Vec { // Serialize the shots value crate::utils::serde::serialize(&self.shots) } fn recover(&mut self, data: Vec) -> BoxFuture<()> { // Deserialize and restore the shots value if !data.is_empty() { self.shots = crate::utils::serde::deserialize(&data).unwrap(); info!("SM recovered state: shots={}", self.shots); } future::ready(()).boxed() } fn recoverable(&self) -> bool { true } } #[tokio::test(flavor = "multi_thread")] async fn query_and_command() { let _ = env_logger::try_init(); info!("TESTING CALLBACK"); let addr = String::from("127.0.0.1:2009"); let raft_service = RaftService::new(Options { storage: Storage::default(), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let sm = SM { shots: 10 }; let server = Server::new(&addr); let sm_id = sm.id(); server.register_service(&raft_service).await; Server::listen_and_resume(&server).await; RaftService::start(&raft_service, false).await; raft_service.register_state_machine(Box::new(sm)).await; raft_service.bootstrap().await; async_wait_secs().await; let raft_client = RaftClient::new(&vec![addr], DEFAULT_SERVICE_ID) .await .unwrap(); let sm_client = client::SMClient::new(sm_id, &raft_client); assert_eq!( sm_client .answer_to_the_universe(&"Alice".to_string()) .await .unwrap(), "Alice, the answer is 42" ); assert_eq!(sm_client.take_a_shot(&2).await.unwrap(), 8); } 
#[tokio::test(flavor = "multi_thread")] async fn multi_server_command() { let _ = env_logger::try_init(); // 5 servers let base_port = 4810 + (rand::random::() % 200); let addresses: Vec<_> = (0..5) .map(|offset| format!("127.0.0.1:{}", base_port + offset)) .collect(); let raft_services = addresses .iter() .map(|addr| { let addr = addr.clone(); async move { let raft_service = RaftService::new(Options { storage: Storage::default(), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let sm = SM { shots: 10 }; let server = Server::new(&addr); server.register_service(&raft_service).await; Server::listen_and_resume(&server).await; RaftService::start(&raft_service, false).await; raft_service.register_state_machine(Box::new(sm)).await; raft_service } }) .collect::>() .collect::>() .await; raft_services[0].bootstrap().await; for i in 1..raft_services.len() { raft_services[i].join(&addresses).await.unwrap(); } info!("Waiting cluster to be stable"); async_wait(Duration::from_secs(2)).await; let raft_client = RaftClient::new(&addresses, DEFAULT_SERVICE_ID) .await .unwrap(); let sm_client = Arc::new(client::SMClient::new(15, &raft_client)); info!("Mass command"); for _ in 0..100 { sm_client.take_a_shot(&-1).await.unwrap(); } async_wait(Duration::from_secs(5)).await; info!("Mass query"); for i in 0..100 { assert_eq!( sm_client.get_shot().await.unwrap(), 110, "fail at test {}", i ); } info!("Tests after leader transfer"); info!("Leader should be consistent"); for svr in &raft_services { assert_eq!(svr.leader_id().await, raft_services[0].id); } // Leader leave debug!("Leader leave cluster"); raft_services[0].leave().await; async_wait(Duration::from_secs(5)).await; debug!("Leader should be changed"); let new_leader = raft_services[1].leader_id().await; for (i, svr) in raft_services.iter().enumerate() { if svr.id == raft_services[0].id { continue; } assert_ne!( svr.leader_id().await, raft_services[0].id, "id {} at node {}", i, svr.id ); assert_eq!(svr.leader_id().await, 
new_leader); } info!("Now we have leader {}", new_leader); info!("Mass command"); for i in 0..10 { let res = sm_client.take_a_shot(&1).await; assert!(res.is_ok(), "{:?} at {}", res, i); } async_wait(Duration::from_secs(5)).await; info!("Mass query"); for i in 0..100 { assert_eq!( sm_client.get_shot().await.unwrap(), 100, "fail at test {}", i ); } } #[tokio::test(flavor = "multi_thread")] async fn snapshot_disk_persistence() { let _ = env_logger::try_init(); info!("TESTING SNAPSHOT DISK PERSISTENCE"); let temp_dir = std::env::temp_dir().join(format!("raft_test_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let data_path = temp_dir.to_str().unwrap().to_string(); let addr = String::from("127.0.0.1:3000"); // Create service with disk storage let raft_service = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: data_path.clone(), take_snapshots: true, append_logs: true, trim_logs: true, snapshot_log_threshold: 10, log_compaction_threshold: 20, }), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let sm = SM { shots: 100 }; let server = Server::new(&addr); let sm_id = sm.id(); server.register_service(&raft_service).await; Server::listen_and_resume(&server).await; RaftService::start(&raft_service, false).await; raft_service.register_state_machine(Box::new(sm)).await; raft_service.bootstrap().await; async_wait_secs().await; // Manually trigger a snapshot to test persistence { let mut meta = raft_service.write_meta().await; // Execute some state changes first for _ in 0..5 { meta.last_applied += 1; } raft_service.take_snapshot(&mut meta).await; assert!( meta.last_snapshot_index > 0, "Snapshot should have been created" ); info!( "Manual snapshot created at index {}", meta.last_snapshot_index ); } // Verify snapshot file exists on disk let snapshot_path = std::path::PathBuf::from(&data_path).join("snapshot.dat"); assert!(snapshot_path.exists(), "Snapshot file should exist on disk"); info!("Snapshot file verified on 
disk"); // Clean up std::fs::remove_dir_all(&temp_dir).unwrap(); info!("Snapshot persistence test passed"); } #[tokio::test(flavor = "multi_thread")] async fn snapshot_persistence_and_recovery() { let _ = env_logger::try_init(); info!("TESTING SNAPSHOT FILE PERSISTENCE AND RELOAD"); let temp_dir = std::env::temp_dir().join(format!("raft_persist_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); // Create a state machine and snapshot it let mut sm1 = SM { shots: 42 }; let snapshot_data = sm1.snapshot(); info!("Created snapshot with shots=42"); // Create snapshot entity let snapshot_entity = SnapshotEntity { last_included_index: 100, last_included_term: 5, snapshot: snapshot_data, }; // Write to disk let mut storage = disk::StorageEntity { logs: None, snapshot: None, last_term: 0, base_path: temp_dir.clone(), plane_id: PlaneId::type1(), }; storage.write_snapshot(&snapshot_entity).await.unwrap(); info!("Snapshot persisted to disk"); // Verify file exists let snapshot_file = temp_dir.join("snapshot.dat"); assert!(snapshot_file.exists(), "Snapshot file should exist"); // Load snapshot from disk let loaded_snapshot = storage.read_snapshot().await.unwrap(); assert!(loaded_snapshot.is_some(), "Should load snapshot"); let loaded = loaded_snapshot.unwrap(); assert_eq!(loaded.last_included_index, 100); assert_eq!(loaded.last_included_term, 5); // Create a new state machine and recover from loaded snapshot let mut sm2 = SM { shots: 999 }; // Different initial state sm2.recover(loaded.snapshot).await; // Verify recovery assert_eq!(sm2.shots, 42, "Should recover to snapshot value"); info!("Successfully recovered state from persisted snapshot!"); // Clean up std::fs::remove_dir_all(&temp_dir).unwrap(); } #[tokio::test(flavor = "multi_thread")] async fn test_snapshot_recovery_on_startup() { let _ = env_logger::try_init(); info!("TESTING SNAPSHOT RECOVERY ON STARTUP"); let temp_dir = std::env::temp_dir().join(format!("raft_recovery_{}", rand::random::())); 
std::fs::create_dir_all(&temp_dir).unwrap(); // Create a snapshot file on disk let snapshot = SnapshotEntity { last_included_index: 100, last_included_term: 5, snapshot: vec![42u8; 100], // Some test data }; let mut storage = disk::StorageEntity { logs: None, snapshot: None, last_term: 0, base_path: temp_dir.clone(), plane_id: PlaneId::type1(), }; // Write snapshot to disk storage.write_snapshot(&snapshot).await.unwrap(); info!("Snapshot written to disk for recovery test"); // Verify load_snapshot_on_startup would work let loaded = storage.read_snapshot().await.unwrap(); assert!(loaded.is_some(), "Should load snapshot from disk"); let recovered = loaded.unwrap(); assert_eq!(recovered.last_included_index, 100); assert_eq!(recovered.last_included_term, 5); assert_eq!(recovered.snapshot.len(), 100); info!("Snapshot recovery test passed - would recover on startup"); // Clean up std::fs::remove_dir_all(&temp_dir).unwrap(); } #[tokio::test(flavor = "multi_thread")] async fn test_snapshot_write_and_read() { let _ = env_logger::try_init(); info!("TESTING SNAPSHOT WRITE AND READ FROM DISK"); let temp_dir = std::env::temp_dir().join(format!("raft_snapshot_io_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); // Create a snapshot entity let test_data = vec![1u8, 2, 3, 4, 5, 42, 100]; let snapshot = SnapshotEntity { last_included_index: 42, last_included_term: 5, snapshot: test_data.clone(), }; // Create storage entity let mut storage = disk::StorageEntity { logs: None, snapshot: None, last_term: 0, base_path: temp_dir.clone(), plane_id: PlaneId::type1(), }; // Write snapshot storage.write_snapshot(&snapshot).await.unwrap(); info!("Snapshot written to disk"); // Verify file exists let snapshot_file = temp_dir.join("snapshot.dat"); assert!(snapshot_file.exists(), "Snapshot file should exist"); // Read snapshot back let loaded = storage.read_snapshot().await.unwrap(); assert!(loaded.is_some(), "Should load snapshot"); let loaded_snapshot = loaded.unwrap(); 
assert_eq!(loaded_snapshot.last_included_index, 42); assert_eq!(loaded_snapshot.last_included_term, 5); assert_eq!(loaded_snapshot.snapshot, test_data); info!("Snapshot read successfully and data matches"); // Clean up std::fs::remove_dir_all(&temp_dir).unwrap(); } #[tokio::test(flavor = "multi_thread")] async fn test_snapshot_corruption_detection() { let _ = env_logger::try_init(); info!("TESTING SNAPSHOT CORRUPTION DETECTION"); let temp_dir = std::env::temp_dir() .join(format!("raft_snapshot_corrupt_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let snapshot = SnapshotEntity { last_included_index: 10, last_included_term: 2, snapshot: vec![1, 2, 3, 4, 5], }; let mut storage = disk::StorageEntity { logs: None, snapshot: None, last_term: 0, base_path: temp_dir.clone(), plane_id: PlaneId::type1(), }; // Write valid snapshot storage.write_snapshot(&snapshot).await.unwrap(); // Corrupt the file by modifying some bytes let snapshot_file = temp_dir.join("snapshot.dat"); let mut file_data = std::fs::read(&snapshot_file).unwrap(); if file_data.len() > 20 { file_data[20] ^= 0xFF; // Flip some bits std::fs::write(&snapshot_file, file_data).unwrap(); } // Try to read corrupted snapshot let result = storage.read_snapshot().await; assert!(result.is_ok(), "Should not error on corruption"); assert!( result.unwrap().is_none(), "Should return None for corrupted snapshot" ); info!("Corruption detection working correctly"); // Clean up std::fs::remove_dir_all(&temp_dir).unwrap(); } #[tokio::test(flavor = "multi_thread")] async fn test_log_compaction_removes_old_logs() { let _ = env_logger::try_init(); info!("TESTING LOG COMPACTION"); let temp_dir = std::env::temp_dir().join(format!("raft_compact_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let addr = String::from("127.0.0.1:3100"); let raft_service = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: temp_dir.to_str().unwrap().to_string(), take_snapshots: true, 
append_logs: true, trim_logs: true, snapshot_log_threshold: 5, log_compaction_threshold: 10, }), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let sm = SM { shots: 100 }; let server = Server::new(&addr); server.register_service(&raft_service).await; Server::listen_and_resume(&server).await; RaftService::start(&raft_service, false).await; raft_service.register_state_machine(Box::new(sm)).await; raft_service.bootstrap().await; async_wait_secs().await; // Get baseline log count and highest log ID let (baseline_count, max_log_id) = { let meta = raft_service.read_meta().await; let logs = meta.logs.read().await; let max_id = logs.keys().max().copied().unwrap_or(0); (logs.len(), max_id) }; info!("Baseline: {} logs, max_id: {}", baseline_count, max_log_id); // Add 20 more logs with sequential IDs let new_log_start = max_log_id + 1; let new_log_end = new_log_start + 19; { let meta = raft_service.read_meta().await; let mut logs = meta.logs.write().await; for i in new_log_start..=new_log_end { logs.insert( i, LogEntry { id: i, term: 1, sm_id: 15, fn_id: 1, data: vec![], }, ); } } let after_add = raft_service.num_logs().await; info!("After adding 20 logs: {}", after_add); assert_eq!(after_add, baseline_count + 20, "Should have added 20 logs"); // Create snapshot that covers first half of new logs let snapshot_index = new_log_start + 9; // Cover first 10 of our new logs { let mut meta = raft_service.write_meta().await; meta.last_applied = snapshot_index; raft_service.take_snapshot(&mut meta).await; assert_eq!(meta.last_snapshot_index, snapshot_index); } let final_count = raft_service.num_logs().await; info!("Final log count after compaction: {}", final_count); // Should have compacted logs up to snapshot_index // Remaining: baseline logs after snapshot_index + remaining new logs assert!( final_count < after_add, "Should have compacted some logs: before={}, after={}", after_add, final_count ); // Clean up std::fs::remove_dir_all(&temp_dir).unwrap(); info!("Log 
compaction test passed"); } #[tokio::test(flavor = "multi_thread")] async fn test_snapshot_threshold_configuration() { let _ = env_logger::try_init(); info!("TESTING SNAPSHOT THRESHOLD CONFIGURATION"); let temp_dir = std::env::temp_dir().join(format!("raft_threshold_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let addr = String::from("127.0.0.1:3101"); // Create with custom thresholds let raft_service = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: temp_dir.to_str().unwrap().to_string(), take_snapshots: true, append_logs: true, trim_logs: true, snapshot_log_threshold: 3, // Very low for testing log_compaction_threshold: 6, }), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let sm = SM { shots: 100 }; let server = Server::new(&addr); server.register_service(&raft_service).await; Server::listen_and_resume(&server).await; RaftService::start(&raft_service, false).await; raft_service.register_state_machine(Box::new(sm)).await; raft_service.bootstrap().await; async_wait_secs().await; // Simulate some activity { let mut meta = raft_service.write_meta().await; meta.last_applied = 5; // Above threshold of 3 // Test should_take_snapshot let should_snapshot = raft_service.should_take_snapshot(&meta, 10); assert!( should_snapshot, "Should trigger snapshot when last_applied (5) > threshold (3)" ); info!("Threshold check passed"); } // Clean up std::fs::remove_dir_all(&temp_dir).unwrap(); } #[tokio::test(flavor = "multi_thread")] async fn test_state_machine_snapshot_and_recovery() { let _ = env_logger::try_init(); info!("TESTING STATE MACHINE SNAPSHOT AND RECOVERY"); // Create SM with initial state let mut sm = SM { shots: 42 }; // Take snapshot let snapshot_data = sm.snapshot(); info!("Snapshot taken, size: {} bytes", snapshot_data.len()); // Modify state sm.shots = 999; assert_eq!(sm.shots, 999); // Recover from snapshot sm.recover(snapshot_data).await; // Verify state was restored assert_eq!(sm.shots, 42, 
"State should be recovered to snapshot value"); info!("State machine recovery test passed"); } #[tokio::test(flavor = "multi_thread")] async fn test_install_snapshot_compacts_logs() { let _ = env_logger::try_init(); info!("TESTING install_snapshot COMPACTS LOGS"); let temp_dir = std::env::temp_dir().join(format!("raft_install_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let addr = String::from("127.0.0.1:3102"); let raft_service = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: temp_dir.to_str().unwrap().to_string(), take_snapshots: true, append_logs: true, trim_logs: true, snapshot_log_threshold: 10, log_compaction_threshold: 20, }), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let sm = SM { shots: 100 }; let server = Server::new(&addr); server.register_service(&raft_service).await; Server::listen_and_resume(&server).await; RaftService::start(&raft_service, false).await; raft_service.register_state_machine(Box::new(sm)).await; raft_service.bootstrap().await; async_wait_secs().await; // Add logs { let meta = raft_service.read_meta().await; let mut logs = meta.logs.write().await; for i in 1..=15u64 { logs.insert( i, LogEntry { id: i, term: 1, sm_id: 15, fn_id: 1, data: vec![], }, ); } } let before_count = raft_service.num_logs().await; info!("Logs before install_snapshot: {}", before_count); // Create valid snapshot data (SnapshotDataItems format) use crate::raft::state_machine::master::SnapshotDataItems; let snapshot_items: SnapshotDataItems = vec![ (CONFIG_SM_ID, vec![1u8, 2, 3]), // Config SM snapshot (15u64, vec![42u8; 10]), // Test SM snapshot ]; let snapshot_data = crate::utils::serde::serialize(&snapshot_items); // Simulate receiving a snapshot via install_snapshot let _result = (&*raft_service as &dyn Service) .install_snapshot( PlaneId::type1(), 1, // term 12345, // leader_id 10, // last_included_index 1, // last_included_term snapshot_data, ) .await; let after_count = 
raft_service.num_logs().await; info!("Logs after install_snapshot: {}", after_count); // Should have removed logs 1-10, keeping logs with id > 10 assert!( after_count < before_count, "Should have compacted logs: before={}, after={}", before_count, after_count ); // Verify logs 1-10 are gone { let meta = raft_service.read_meta().await; let logs = meta.logs.read().await; for i in 1..=10 { assert!( !logs.contains_key(&i), "Log {} should have been compacted", i ); } // Logs 11-15 should still exist for i in 11..=15 { assert!(logs.contains_key(&i), "Log {} should still exist", i); } } // Verify snapshot metadata was updated let meta = raft_service.read_meta().await; assert_eq!(meta.last_snapshot_index, 10); assert_eq!(meta.last_snapshot_term, 1); // Clean up std::fs::remove_dir_all(&temp_dir).unwrap(); info!("install_snapshot compaction test passed"); } #[tokio::test(flavor = "multi_thread")] async fn test_wal_logs_written_to_disk() { let _ = env_logger::try_init(); info!("TESTING WAL - LOGS WRITTEN TO DISK"); let temp_dir = std::env::temp_dir().join(format!("raft_wal_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let data_path = temp_dir.to_str().unwrap().to_string(); let addr = String::from("127.0.0.1:3200"); let raft_service = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: data_path.clone(), take_snapshots: false, // Disable snapshots to focus on logs append_logs: true, // Enable WAL trim_logs: false, snapshot_log_threshold: 10000, log_compaction_threshold: 20000, }), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let sm = SM { shots: 100 }; let server = Server::new(&addr); let sm_id = sm.id(); server.register_service(&raft_service).await; Server::listen_and_resume(&server).await; RaftService::start(&raft_service, false).await; raft_service.register_state_machine(Box::new(sm)).await; raft_service.bootstrap().await; async_wait_secs().await; // Verify log file doesn't exist yet or is small let 
log_file_path = std::path::PathBuf::from(&data_path).join("log.dat"); let initial_size = if log_file_path.exists() { std::fs::metadata(&log_file_path).unwrap().len() } else { 0 }; info!("Initial log file size: {} bytes", initial_size); // Execute commands - should write to WAL let raft_client = RaftClient::new(&vec![addr.clone()], DEFAULT_SERVICE_ID) .await .unwrap(); let sm_client = client::SMClient::new(sm_id, &raft_client); info!("Executing 10 commands that should be written to WAL"); for i in 0..10 { sm_client.take_a_shot(&1).await.unwrap(); } async_wait(Duration::from_secs(2)).await; // Verify log file exists and grew assert!(log_file_path.exists(), "WAL log file should exist"); let final_size = std::fs::metadata(&log_file_path).unwrap().len(); info!("Final log file size: {} bytes", final_size); assert!( final_size > initial_size, "Log file should have grown: initial={}, final={}", initial_size, final_size ); // Clean up std::fs::remove_dir_all(&temp_dir).unwrap(); info!("WAL persistence test passed"); } #[tokio::test(flavor = "multi_thread")] async fn test_wal_recovery_after_crash() { let _ = env_logger::try_init(); info!("TESTING WAL - RECOVERY AFTER SIMULATED CRASH"); let temp_dir = std::env::temp_dir().join(format!("raft_wal_crash_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let data_path = temp_dir.to_str().unwrap().to_string(); let port = 3300 + (rand::random::() % 100); let addr = format!("127.0.0.1:{}", port); // Phase 1: Write some logs let num_logs_before_crash; { info!("Phase 1: Starting first instance and writing logs"); let raft_service = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: data_path.clone(), take_snapshots: false, // No snapshots, only WAL append_logs: true, trim_logs: false, snapshot_log_threshold: 10000, log_compaction_threshold: 20000, }), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let sm = SM { shots: 100 }; let server = Server::new(&addr); let sm_id = sm.id(); 
server.register_service(&raft_service).await; Server::listen_and_resume(&server).await; raft_service.register_state_machine(Box::new(sm)).await; RaftService::start(&raft_service, false).await; raft_service.bootstrap().await; async_wait_secs().await; let raft_client = RaftClient::new(&vec![addr.clone()], DEFAULT_SERVICE_ID) .await .unwrap(); let sm_client = client::SMClient::new(sm_id, &raft_client); // Execute commands info!("Executing 5 commands"); for _ in 0..5 { sm_client.take_a_shot(&1).await.unwrap(); } async_wait(Duration::from_secs(2)).await; // Verify state before "crash" let state_before = sm_client.get_shot().await.unwrap(); assert_eq!(state_before, 95); info!("State before crash: {}", state_before); num_logs_before_crash = raft_service.num_logs().await; info!("Logs before crash: {}", num_logs_before_crash); // Simulate crash - just drop everything drop(sm_client); drop(raft_client); drop(raft_service); drop(server); info!("Simulated crash - dropped all services"); } async_wait(Duration::from_secs(2)).await; // Phase 2: Recover from WAL { info!("Phase 2: Starting second instance and recovering from WAL"); let port2 = port + 1; let addr2 = format!("127.0.0.1:{}", port2); let raft_service2 = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: data_path.clone(), // Same data directory! take_snapshots: false, append_logs: true, trim_logs: false, snapshot_log_threshold: 10000, log_compaction_threshold: 20000, }), address: addr2.clone(), service_id: DEFAULT_SERVICE_ID, }); // New SM with different initial state let sm2 = SM { shots: 999 }; // Different from crashed instance let server2 = Server::new(&addr2); let sm_id = sm2.id(); server2.register_service(&raft_service2).await; Server::listen_and_resume(&server2).await; raft_service2.register_state_machine(Box::new(sm2)).await; // This should load logs from disk! 
RaftService::start(&raft_service2, false).await; raft_service2.bootstrap().await; async_wait(Duration::from_secs(2)).await; // Check that logs were recovered let num_logs_after = raft_service2.num_logs().await; info!("Logs after recovery: {}", num_logs_after); assert!(num_logs_after > 0, "Should have recovered logs from disk"); // The logs should be similar to before crash // (might have some membership logs added/removed) let expected_min = num_logs_before_crash.saturating_sub(10); assert!( num_logs_after >= expected_min, "Should recover most logs: before={}, after={}, expected >= {}", num_logs_before_crash, num_logs_after, expected_min ); info!("WAL recovery test passed - logs recovered from disk!"); } // Clean up std::fs::remove_dir_all(&temp_dir).unwrap(); } #[tokio::test(flavor = "multi_thread")] async fn test_wal_log_file_format() { let _ = env_logger::try_init(); info!("TESTING WAL - LOG FILE FORMAT AND CONTENTS"); let temp_dir = std::env::temp_dir().join(format!("raft_wal_format_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); // Create storage and write some logs let mut storage = disk::StorageEntity { logs: None, snapshot: None, last_term: 0, base_path: temp_dir.clone(), plane_id: PlaneId::type1(), }; // Create test logs in memory let mut logs = BTreeMap::new(); for i in 1..=5u64 { logs.insert( i, LogEntry { id: i, term: 1, sm_id: 15, fn_id: 1, data: vec![i as u8, (i * 2) as u8], }, ); } // Create minimal RaftMeta for testing use async_std::sync::RwLock; let meta = RaftMeta { term: 1, vote_for: None, timeout: 10000, last_checked: 0, membership: Membership::Undefined, logs: Arc::new(RwLock::new(BTreeMap::new())), state_machine: Arc::new(RwLock::new(MasterStateMachine::new(DEFAULT_SERVICE_ID))), commit_index: 5, last_applied: 5, leader_id: 0, storage: None, last_snapshot_index: 0, last_snapshot_term: 0, lifecycle: LifecycleState::Running, }; let meta_lock = async_std::sync::RwLock::new(meta); let meta_guard = meta_lock.write().await; 
let logs_lock = async_std::sync::RwLock::new(logs); let logs_guard = logs_lock.write().await; // Open log file manually let log_path = temp_dir.join("log.dat"); storage.logs = Some(tokio::fs::File::create(&log_path).await.unwrap()); // Write logs to disk storage.append_logs(&meta_guard, &logs_guard).await.unwrap(); drop(storage); info!("Wrote 5 logs to WAL"); // Verify file exists and has content assert!(log_path.exists(), "Log file should exist"); let file_size = std::fs::metadata(&log_path).unwrap().len(); info!("Log file size: {} bytes", file_size); assert!(file_size > 0, "Log file should have content"); // Read back and verify we can parse it let file_contents = std::fs::read(&log_path).unwrap(); assert!( file_contents.len() > 50, "Log file should contain serialized entries" ); info!("WAL file format test passed"); // Clean up std::fs::remove_dir_all(&temp_dir).unwrap(); } #[tokio::test(flavor = "multi_thread")] async fn test_wal_fsync_durability() { let _ = env_logger::try_init(); info!("TESTING WAL - FSYNC DURABILITY GUARANTEE"); let temp_dir = std::env::temp_dir().join(format!("raft_wal_fsync_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let data_path = temp_dir.to_str().unwrap().to_string(); let addr = String::from("127.0.0.1:3400"); let raft_service = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: data_path.clone(), take_snapshots: false, append_logs: true, trim_logs: false, snapshot_log_threshold: 10000, log_compaction_threshold: 20000, }), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let sm = SM { shots: 100 }; let server = Server::new(&addr); let sm_id = sm.id(); server.register_service(&raft_service).await; Server::listen_and_resume(&server).await; RaftService::start(&raft_service, false).await; raft_service.register_state_machine(Box::new(sm)).await; raft_service.bootstrap().await; async_wait_secs().await; let raft_client = RaftClient::new(&vec![addr.clone()], DEFAULT_SERVICE_ID) 
.await .unwrap(); let sm_client = client::SMClient::new(sm_id, &raft_client); // Execute ONE command info!("Executing single command"); sm_client.take_a_shot(&1).await.unwrap(); // Small wait to ensure write completes async_wait(Duration::from_millis(500)).await; // Check log file was written immediately let log_file_path = std::path::PathBuf::from(&data_path).join("log.dat"); assert!( log_file_path.exists(), "Log file should exist after one command" ); // Get file modification time let metadata1 = std::fs::metadata(&log_file_path).unwrap(); let modified1 = metadata1.modified().unwrap(); info!("Log file modified at: {:?}", modified1); // Execute another command async_wait(Duration::from_millis(100)).await; sm_client.take_a_shot(&1).await.unwrap(); async_wait(Duration::from_millis(500)).await; // Verify file was modified again (new write) let metadata2 = std::fs::metadata(&log_file_path).unwrap(); let modified2 = metadata2.modified().unwrap(); let size2 = metadata2.len(); assert!( modified2 >= modified1, "Log file should be updated after second command" ); assert!( size2 > metadata1.len(), "Log file should grow with new entries" ); info!("WAL fsync durability test passed - each command persisted"); // Clean up std::fs::remove_dir_all(&temp_dir).unwrap(); } #[tokio::test(flavor = "multi_thread")] async fn test_wal_log_recovery_integration() { let _ = env_logger::try_init(); info!("TESTING WAL - LOG RECOVERY (logs are deltas, need same initial state)"); let temp_dir = std::env::temp_dir().join(format!("raft_wal_state_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let data_path = temp_dir.to_str().unwrap().to_string(); let port1 = 3500 + (rand::random::() % 50); let port2 = port1 + 100; let addr1 = format!("127.0.0.1:{}", port1); let addr2 = format!("127.0.0.1:{}", port2); let sm_id = 15u64; let expected_final_state; // Phase 1: Create initial state { info!("Phase 1: Creating initial state with WAL enabled"); let raft_service = 
RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: data_path.clone(), take_snapshots: false, // Only WAL, no snapshots append_logs: true, trim_logs: false, snapshot_log_threshold: 10000, log_compaction_threshold: 20000, }), address: addr1.clone(), service_id: DEFAULT_SERVICE_ID, }); let sm = SM { shots: 100 }; let server = Server::new(&addr1); server.register_service(&raft_service).await; Server::listen_and_resume(&server).await; RaftService::start(&raft_service, false).await; raft_service.register_state_machine(Box::new(sm)).await; raft_service.bootstrap().await; async_wait_secs().await; let raft_client = RaftClient::new(&vec![addr1.clone()], DEFAULT_SERVICE_ID) .await .unwrap(); let sm_client = client::SMClient::new(sm_id, &raft_client); // Execute commands that modify state info!("Executing 7 commands: take_a_shot(3) each time"); for _ in 0..7 { sm_client.take_a_shot(&3).await.unwrap(); } async_wait(Duration::from_secs(2)).await; // Record expected state (100 - 7*3 = 79) expected_final_state = sm_client.get_shot().await.unwrap(); info!("State before crash: {}", expected_final_state); assert_eq!(expected_final_state, 79); // Verify logs were written let log_file = std::path::PathBuf::from(&data_path).join("log.dat"); assert!(log_file.exists(), "WAL should exist"); info!("Simulating crash..."); drop(sm_client); drop(raft_client); drop(raft_service); drop(server); } async_wait(Duration::from_secs(2)).await; // Phase 2: Recover from WAL { info!("Phase 2: Recovering from WAL after crash"); let raft_service2 = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: data_path.clone(), // Same directory! 
take_snapshots: false, append_logs: true, trim_logs: false, snapshot_log_threshold: 10000, log_compaction_threshold: 20000, }), address: addr2.clone(), // Different port service_id: DEFAULT_SERVICE_ID, }); // Start with SAME initial state (WAL only replays commands, not full state) let sm2 = SM { shots: 100 }; // Same as first instance let server2 = Server::new(&addr2); server2.register_service(&raft_service2).await; Server::listen_and_resume(&server2).await; // Register state machine BEFORE start (important for recovery) raft_service2.register_state_machine(Box::new(sm2)).await; // This should load logs from disk! RaftService::start(&raft_service2, false).await; raft_service2.bootstrap().await; async_wait(Duration::from_secs(3)).await; // Verify logs were recovered let recovered_logs = raft_service2.num_logs().await; info!("Recovered {} logs from WAL", recovered_logs); assert!(recovered_logs > 0, "Should have recovered logs"); // Manually apply the recovered logs { let mut meta = raft_service2.write_meta().await; info!( "Before applying: last_applied={}, commit_index={}", meta.last_applied, meta.commit_index ); super::super::check_commit(&mut meta).await; info!("After applying: last_applied={}", meta.last_applied); } // Verify state was recovered let raft_client2 = RaftClient::new(&vec![addr2.clone()], DEFAULT_SERVICE_ID) .await .unwrap(); let sm_client2 = client::SMClient::new(sm_id, &raft_client2); let recovered_state = sm_client2.get_shot().await.unwrap(); info!("State after recovery: {}", recovered_state); // WAL recovers committed logs only, so might be within 1-2 commands // of expected state (uncommitted commands are lost, which is correct) let diff = (recovered_state as i32 - expected_final_state as i32).abs(); assert!( diff <= 6, // Allow for 2 uncommitted commands (2 * 3 = 6) "State should be close to expected: expected={}, got={}, diff={}", expected_final_state, recovered_state, diff ); // Most importantly, verify logs were actually recovered assert!( 
recovered_logs > 5, "Should have recovered multiple logs from WAL" ); info!("✅ WAL recovery test PASSED - logs recovered and replayed!"); } // Clean up std::fs::remove_dir_all(&temp_dir).unwrap(); } #[tokio::test(flavor = "multi_thread")] async fn test_wal_deterministic_encoding() { let _ = env_logger::try_init(); info!("TESTING WAL - DETERMINISTIC ENCODING"); // Create two identical log entries let log1 = LogEntry { id: 42, term: 5, sm_id: 15, fn_id: 99, data: vec![1, 2, 3, 4, 5], }; let log2 = LogEntry { id: 42, term: 5, sm_id: 15, fn_id: 99, data: vec![1, 2, 3, 4, 5], }; // Create two DiskLogEntries from them let disk_entry1 = disk::DiskLogEntry { term: 5, commit_index: 100, last_applied: 100, log: log1, }; let disk_entry2 = disk::DiskLogEntry { term: 5, commit_index: 100, last_applied: 100, log: log2, }; // Encode both let encoded1 = disk_entry1.encode(); let encoded2 = disk_entry2.encode(); // Verify determinism: same input → same output assert_eq!(encoded1, encoded2, "Encoding should be deterministic"); info!("✓ Deterministic encoding verified"); // Verify encoding is byte-for-byte identical assert_eq!(encoded1.len(), encoded2.len()); for i in 0..encoded1.len() { assert_eq!( encoded1[i], encoded2[i], "Byte {} differs: {} vs {}", i, encoded1[i], encoded2[i] ); } info!("✓ Byte-for-byte identical"); // Decode and verify correctness let decoded = disk::DiskLogEntry::decode(&encoded1).unwrap(); assert_eq!(decoded.term, 5); assert_eq!(decoded.commit_index, 100); assert_eq!(decoded.last_applied, 100); assert_eq!(decoded.log.id, 42); assert_eq!(decoded.log.term, 5); assert_eq!(decoded.log.sm_id, 15); assert_eq!(decoded.log.fn_id, 99); assert_eq!(decoded.log.data, vec![1, 2, 3, 4, 5]); info!("✓ Decoding produces correct values"); // Verify encoding size is predictable let expected_size = 8 * 8 + 5; // 8 u64 fields + 5 data bytes = 69 bytes assert_eq!(encoded1.len(), expected_size); info!("✓ Encoding size is predictable: {} bytes", expected_size); info!("Deterministic 
encoding test passed!"); } #[tokio::test(flavor = "multi_thread")] async fn test_wal_encoding_with_empty_data() { let _ = env_logger::try_init(); info!("TESTING WAL - ENCODING WITH EMPTY DATA"); let entry = disk::DiskLogEntry { term: 1, commit_index: 10, last_applied: 10, log: LogEntry { id: 10, term: 1, sm_id: 1, fn_id: 1, data: vec![], // Empty data }, }; let encoded = entry.encode(); assert_eq!(encoded.len(), 64, "Empty data should encode to 64 bytes"); let decoded = disk::DiskLogEntry::decode(&encoded).unwrap(); assert_eq!(decoded.log.data.len(), 0); assert_eq!(decoded.log.id, 10); info!("Empty data encoding test passed"); } #[tokio::test(flavor = "multi_thread")] async fn test_wal_encoding_with_large_data() { let _ = env_logger::try_init(); info!("TESTING WAL - ENCODING WITH LARGE DATA"); let large_data = vec![42u8; 10000]; let entry = disk::DiskLogEntry { term: 99, commit_index: 500, last_applied: 500, log: LogEntry { id: 500, term: 99, sm_id: 7, fn_id: 3, data: large_data.clone(), }, }; let encoded = entry.encode(); assert_eq!( encoded.len(), 64 + 10000, "Should be 64 header + 10000 data" ); let decoded = disk::DiskLogEntry::decode(&encoded).unwrap(); assert_eq!(decoded.log.data.len(), 10000); assert_eq!(decoded.log.data, large_data); assert_eq!(decoded.term, 99); info!("Large data encoding test passed"); } #[tokio::test(flavor = "multi_thread")] async fn test_wal_only_minimal_rsm_recovery() { let _ = env_logger::try_init(); info!("=============================================="); info!("WAL-ONLY E2E (macro SM): minimal commands, no snapshot"); info!("=============================================="); let temp_dir = std::env::temp_dir().join(format!("raft_wal_only_min_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let data_path = temp_dir.to_str().unwrap().to_string(); let port1 = 3700 + (rand::random::() % 50); let port2 = port1 + 100; let addr1 = format!("127.0.0.1:{}", port1); let addr2 = format!("127.0.0.1:{}", port2); let sm_id = 
15u64; let expected_state: i32; // ===== PHASE 1: Start single-node with WAL only and execute small commands ===== { let service = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: data_path.clone(), take_snapshots: false, // WAL-only append_logs: true, trim_logs: false, snapshot_log_threshold: 10000, log_compaction_threshold: 20000, }), address: addr1.clone(), service_id: DEFAULT_SERVICE_ID, }); let sm = SM { shots: 50 }; let server = Server::new(&addr1); server.register_service(&service).await; Server::listen_and_resume(&server).await; service.register_state_machine(Box::new(sm)).await; RaftService::start(&service, false).await; service.bootstrap().await; async_wait(Duration::from_secs(2)).await; let client = RaftClient::new(&vec![addr1.clone()], DEFAULT_SERVICE_ID) .await .unwrap(); let sm_client = client::SMClient::new(sm_id, &client); // Execute 3 small commands: total delta = 6 let deltas = [1, 2, 3]; for d in deltas.iter() { sm_client.take_a_shot(d).await.unwrap(); } // Ensure committed logs are applied and WAL is flushed before crash (deterministic) { let mut meta = service.write_meta().await; super::super::check_commit(&mut meta).await; if let Some(storage_mutex) = &meta.storage { let mut storage = storage_mutex.lock().await; let _ = storage.flush_wal().await; } } // Persist current commit progress as an extra safety barrier service.flush_persistence().await; async_wait(Duration::from_secs(1)).await; // Record actual state before crash (source of truth for recovery) expected_state = sm_client.get_shot().await.unwrap(); info!("State before crash: {}", expected_state); // Verify WAL exists let wal_file = std::path::PathBuf::from(&data_path).join("log.dat"); assert!(wal_file.exists(), "WAL file should exist"); // Graceful shutdown instead of drop drop(sm_client); drop(client); service.shutdown().await; server.shutdown().await; } async_wait(Duration::from_secs(2)).await; // ===== PHASE 2: Recover from WAL only ===== { let service2 = 
RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: data_path.clone(), // Same directory take_snapshots: false, append_logs: true, trim_logs: false, snapshot_log_threshold: 10000, log_compaction_threshold: 20000, }), address: addr2.clone(), service_id: DEFAULT_SERVICE_ID, }); // IMPORTANT: start with same initial state; WAL replays deltas let sm2 = SM { shots: 50 }; let server2 = Server::new(&addr2); server2.register_service(&service2).await; Server::listen_and_resume(&server2).await; service2.register_state_machine(Box::new(sm2)).await; RaftService::start(&service2, false).await; service2.bootstrap().await; async_wait(Duration::from_secs(2)).await; // Ensure logs were recovered let recovered = service2.num_logs().await; info!("Recovered {} logs from WAL", recovered); assert!(recovered > 0); // Apply recovered logs { let mut meta = service2.write_meta().await; super::super::check_commit(&mut meta).await; } // Ensure all committed logs are applied after restart { let mut meta = service2.write_meta().await; super::super::check_commit(&mut meta).await; } // Verify state equals pre-crash state let client2 = RaftClient::new(&vec![addr2.clone()], DEFAULT_SERVICE_ID) .await .unwrap(); let sm_client2 = client::SMClient::new(sm_id, &client2); let recovered_state = sm_client2.get_shot().await.unwrap(); info!("Recovered state: {}", recovered_state); // Assert equality pre vs post crash assert_eq!( recovered_state, expected_state, "recovered state should equal pre-crash state" ); // Clean up drop(sm_client2); drop(client2); drop(service2); drop(server2); } std::fs::remove_dir_all(&temp_dir).unwrap(); } // ── NEW RECOVERY TESTS ─────────────────────────────────────────────── /// Simulate abrupt crash (drop without shutdown/flush) and verify the SM /// recovers to the exact pre-crash state via WAL + commit.idx. /// Asserts "not start over": recovered state ≠ fresh initial state. 
#[tokio::test(flavor = "multi_thread")] async fn test_abrupt_crash_full_recovery() { let _ = env_logger::try_init(); info!("=== TEST: abrupt crash → full WAL recovery ==="); let temp_dir = std::env::temp_dir().join(format!("raft_abrupt_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let data_path = temp_dir.to_str().unwrap().to_string(); let sm_id = 15u64; let initial_shots = 100i32; let num_cmds = 8i32; let expected = initial_shots - num_cmds; // 92 let port1 = 4001u16 + (rand::random::() % 20); let addr1 = format!("127.0.0.1:{}", port1); // ─── Phase 1: run commands then drop abruptly (no shutdown/flush) ─── { let svc = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: data_path.clone(), take_snapshots: false, append_logs: true, trim_logs: false, snapshot_log_threshold: 10000, log_compaction_threshold: 20000, }), address: addr1.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&addr1); server.register_service(&svc).await; Server::listen_and_resume(&server).await; svc.register_state_machine(Box::new(SM { shots: initial_shots })).await; RaftService::start(&svc, false).await; svc.bootstrap().await; async_wait_secs().await; let client = RaftClient::new(&vec![addr1.clone()], DEFAULT_SERVICE_ID).await.unwrap(); let sm_client = client::SMClient::new(sm_id, &client); for _ in 0..num_cmds { sm_client.take_a_shot(&1).await.unwrap(); } async_wait(Duration::from_secs(2)).await; let before = sm_client.get_shot().await.unwrap(); assert_eq!(before, expected, "pre-crash state wrong"); info!("State before crash: {}", before); // Verify WAL and commit.idx exist assert!(temp_dir.join("log.dat").exists(), "WAL must exist"); assert!(temp_dir.join("commit.idx").exists(), "commit.idx must exist (written per-command)"); // ABRUPT CRASH — no shutdown(), no flush_persistence() drop(sm_client); drop(client); drop(svc); drop(server); info!("Abrupt crash simulated (all handles dropped)"); } 
async_wait(Duration::from_secs(2)).await; // ─── Phase 2: restart with same initial state, recover ─── let port2 = port1 + 50; let addr2 = format!("127.0.0.1:{}", port2); { let svc2 = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: data_path.clone(), take_snapshots: false, append_logs: true, trim_logs: false, snapshot_log_threshold: 10000, log_compaction_threshold: 20000, }), address: addr2.clone(), service_id: DEFAULT_SERVICE_ID, }); let server2 = Server::new(&addr2); server2.register_service(&svc2).await; Server::listen_and_resume(&server2).await; // Register SM with SAME initial shots so replay produces the right result svc2.register_state_machine(Box::new(SM { shots: initial_shots })).await; RaftService::start(&svc2, false).await; svc2.bootstrap().await; // Replay committed WAL logs into the SM svc2.recover_after_register().await; async_wait(Duration::from_secs(2)).await; let client2 = RaftClient::new(&vec![addr2.clone()], DEFAULT_SERVICE_ID).await.unwrap(); let sm_client2 = client::SMClient::new(sm_id, &client2); let recovered = sm_client2.get_shot().await.unwrap(); info!("Recovered state: {} (expected {})", recovered, expected); assert_ne!(recovered, initial_shots, "Must not equal untouched initial state"); assert_eq!(recovered, expected, "Must recover exact pre-crash state via WAL"); drop(sm_client2); drop(client2); drop(svc2); drop(server2); } std::fs::remove_dir_all(&temp_dir).unwrap(); info!("=== PASS: abrupt crash recovery ==="); } /// Write N valid WAL entries, then append partial bytes to simulate a power cut /// mid-entry. Recovery must load all N good entries and truncate the corrupt tail. 
#[tokio::test(flavor = "multi_thread")] async fn test_wal_partial_write_truncation() { let _ = env_logger::try_init(); info!("=== TEST: partial WAL write → truncation on recovery ==="); let temp_dir = std::env::temp_dir().join(format!("raft_partial_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let log_path = temp_dir.join("log.dat"); const N: usize = 5; const DATA_LEN: usize = 8; // bytes per entry's data field // ─── Phase 1: write N complete entries via StorageEntity ─── { let mut logs = BTreeMap::new(); for i in 1..=N as u64 { logs.insert(i, LogEntry { id: i, term: 1, sm_id: 15, fn_id: 1, data: vec![i as u8; DATA_LEN] }); } let meta = RaftMeta { term: 1, vote_for: None, timeout: 10000, last_checked: 0, membership: Membership::Undefined, logs: Arc::new(async_std::sync::RwLock::new(BTreeMap::new())), state_machine: Arc::new(async_std::sync::RwLock::new( MasterStateMachine::new(DEFAULT_SERVICE_ID))), commit_index: N as u64, last_applied: N as u64, leader_id: 0, storage: None, last_snapshot_index: 0, last_snapshot_term: 0, lifecycle: LifecycleState::Running, }; let meta_lock = async_std::sync::RwLock::new(meta); let meta_guard = meta_lock.write().await; let logs_lock = async_std::sync::RwLock::new(logs); let logs_guard = logs_lock.write().await; let mut storage = disk::StorageEntity { logs: Some(tokio::fs::File::create(&log_path).await.unwrap()), snapshot: None, last_term: 0, base_path: temp_dir.clone(), plane_id: PlaneId::type1(), }; storage.append_logs(&meta_guard, &logs_guard).await.unwrap(); // drop storage to flush/close } let size_good = std::fs::metadata(&log_path).unwrap().len(); assert!(size_good > 0, "WAL must have content after {} entries", N); info!("WAL size after {} complete entries: {} bytes", N, size_good); // ─── Phase 2: append 5 garbage bytes (partial length prefix) ─── { use std::io::Write as _; let mut f = std::fs::OpenOptions::new().append(true).open(&log_path).unwrap(); f.write_all(&[0xDE, 0xAD, 0xBE, 0xEF, 
0xCA]).unwrap(); f.sync_all().unwrap(); } let size_with_garbage = std::fs::metadata(&log_path).unwrap().len(); assert_eq!(size_with_garbage, size_good + 5, "File should be exactly 5 bytes larger after injecting garbage"); // ─── Phase 3: recover using new_with_options ─── let mut term = 0u64; let mut commit_index = 0u64; let mut last_applied = 0u64; let mut recovered_logs = BTreeMap::new(); let opts = Options { storage: Storage::DISK(disk::DiskOptions { path: temp_dir.to_str().unwrap().to_string(), take_snapshots: false, append_logs: true, trim_logs: false, snapshot_log_threshold: 10000, log_compaction_threshold: 20000, }), address: "127.0.0.1:0".to_string(), service_id: DEFAULT_SERVICE_ID, }; let _storage = disk::StorageEntity::new_with_options( &opts, &mut term, &mut commit_index, &mut last_applied, &mut recovered_logs ).unwrap(); // All N good entries must be present assert_eq!(recovered_logs.len(), N, "Should recover exactly {} entries; got {}", N, recovered_logs.len()); // The corrupt 5-byte tail must have been truncated let size_after = std::fs::metadata(&log_path).unwrap().len(); assert_eq!(size_after, size_good, "WAL file must be truncated back to {} bytes; got {}", size_good, size_after); info!("Recovered {} entries; corrupt tail truncated ({} → {} bytes)", N, size_with_garbage, size_after); std::fs::remove_dir_all(&temp_dir).unwrap(); info!("=== PASS: partial write truncation ==="); } /// Write N WAL entries, then flip bytes in the CRC field of entry K. /// Recovery must stop at entry K (recovering K entries, not K+1..N) /// and the file must be truncated at the corruption boundary. 
#[tokio::test(flavor = "multi_thread")] async fn test_wal_crc_corruption_stops_at_bad_entry() { let _ = env_logger::try_init(); info!("=== TEST: CRC corruption → partial recovery stops at bad entry ==="); let temp_dir = std::env::temp_dir().join(format!("raft_crc_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let log_path = temp_dir.join("log.dat"); const N: usize = 6; const DATA_LEN: usize = 8; // Corrupt entry at 0-based index CORRUPT_IDX; first CORRUPT_IDX entries should survive. const CORRUPT_IDX: usize = 3; // On-disk layout per entry: [8 len][4 CRC][64 fixed][DATA_LEN data] const RECORD_SIZE: usize = 8 + 4 + 64 + DATA_LEN; // = 84 bytes // ─── Phase 1: write N complete entries ─── { let mut logs = BTreeMap::new(); for i in 1..=N as u64 { logs.insert(i, LogEntry { id: i, term: 1, sm_id: 15, fn_id: 1, data: vec![i as u8; DATA_LEN] }); } let meta = RaftMeta { term: 1, vote_for: None, timeout: 10000, last_checked: 0, membership: Membership::Undefined, logs: Arc::new(async_std::sync::RwLock::new(BTreeMap::new())), state_machine: Arc::new(async_std::sync::RwLock::new( MasterStateMachine::new(DEFAULT_SERVICE_ID))), commit_index: N as u64, last_applied: N as u64, leader_id: 0, storage: None, last_snapshot_index: 0, last_snapshot_term: 0, lifecycle: LifecycleState::Running, }; let meta_lock = async_std::sync::RwLock::new(meta); let meta_guard = meta_lock.write().await; let logs_lock = async_std::sync::RwLock::new(logs); let logs_guard = logs_lock.write().await; let mut storage = disk::StorageEntity { logs: Some(tokio::fs::File::create(&log_path).await.unwrap()), snapshot: None, last_term: 0, base_path: temp_dir.clone(), plane_id: PlaneId::type1(), }; storage.append_logs(&meta_guard, &logs_guard).await.unwrap(); } let size_before = std::fs::metadata(&log_path).unwrap().len(); assert_eq!(size_before, (N * RECORD_SIZE) as u64, "WAL size mismatch: expected {} bytes for {} entries", N * RECORD_SIZE, N); // ─── Phase 2: flip all 4 CRC bytes of entry 
CORRUPT_IDX ─── { let mut file_data = std::fs::read(&log_path).unwrap(); // CRC starts at byte 8 (after 8-byte length prefix) within each record let crc_offset = CORRUPT_IDX * RECORD_SIZE + 8; file_data[crc_offset] ^= 0xFF; file_data[crc_offset + 1] ^= 0xFF; file_data[crc_offset + 2] ^= 0xFF; file_data[crc_offset + 3] ^= 0xFF; std::fs::write(&log_path, &file_data).unwrap(); info!("Corrupted CRC of entry {} at byte offset {}", CORRUPT_IDX, crc_offset); } // ─── Phase 3: recover ─── let mut term = 0u64; let mut commit_index = 0u64; let mut last_applied = 0u64; let mut recovered_logs = BTreeMap::new(); let opts = Options { storage: Storage::DISK(disk::DiskOptions { path: temp_dir.to_str().unwrap().to_string(), take_snapshots: false, append_logs: true, trim_logs: false, snapshot_log_threshold: 10000, log_compaction_threshold: 20000, }), address: "127.0.0.1:0".to_string(), service_id: DEFAULT_SERVICE_ID, }; let _storage = disk::StorageEntity::new_with_options( &opts, &mut term, &mut commit_index, &mut last_applied, &mut recovered_logs ).unwrap(); // Only the CORRUPT_IDX entries before the corruption survive assert_eq!(recovered_logs.len(), CORRUPT_IDX, "Should recover exactly {} entries before corruption; got {}", CORRUPT_IDX, recovered_logs.len()); // Verify the recovered entries are the correct ones (ids 1..CORRUPT_IDX) for id in 1..=CORRUPT_IDX as u64 { assert!(recovered_logs.contains_key(&id), "Entry id={} should be present", id); } for id in (CORRUPT_IDX + 1) as u64..=N as u64 { assert!(!recovered_logs.contains_key(&id), "Entry id={} should have been dropped (after corruption)", id); } // File truncated at the corruption boundary let expected_truncated_size = (CORRUPT_IDX * RECORD_SIZE) as u64; let actual_size = std::fs::metadata(&log_path).unwrap().len(); assert_eq!(actual_size, expected_truncated_size, "WAL must be truncated to {} bytes at corruption; got {}", expected_truncated_size, actual_size); info!("CRC corruption test: {} good entries recovered, {} corrupt 
entries dropped, \ file truncated from {} to {} bytes", CORRUPT_IDX, N - CORRUPT_IDX, size_before, actual_size); std::fs::remove_dir_all(&temp_dir).unwrap(); info!("=== PASS: CRC corruption stops recovery at bad entry ==="); } /// Crash after snapshot + additional WAL entries. /// On restart the SM must recover from the snapshot and then replay the /// post-snapshot WAL log entries — proving "partial recovery, not start over". #[tokio::test(flavor = "multi_thread")] async fn test_snapshot_plus_wal_not_start_over() { let _ = env_logger::try_init(); info!("=== TEST: snapshot + post-snapshot WAL crash recovery ==="); let temp_dir = std::env::temp_dir().join(format!("raft_snap_wal_{}", rand::random::())); std::fs::create_dir_all(&temp_dir).unwrap(); let data_path = temp_dir.to_str().unwrap().to_string(); let sm_id = 15u64; const INITIAL: i32 = 100; const CMDS_BEFORE_SNAP: i32 = 5; // shots: 100 → 95 const CMDS_AFTER_SNAP: i32 = 3; // shots: 95 → 92 let expected = INITIAL - CMDS_BEFORE_SNAP - CMDS_AFTER_SNAP; // 92 let port1 = 4050u16 + (rand::random::() % 20); let addr1 = format!("127.0.0.1:{}", port1); // ─── Phase 1: commands → snapshot → more commands → abrupt crash ─── { let svc = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: data_path.clone(), take_snapshots: true, append_logs: true, trim_logs: false, snapshot_log_threshold: 10000, // won't auto-trigger; we do it manually log_compaction_threshold: 20000, }), address: addr1.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&addr1); server.register_service(&svc).await; Server::listen_and_resume(&server).await; svc.register_state_machine(Box::new(SM { shots: INITIAL })).await; RaftService::start(&svc, false).await; svc.bootstrap().await; async_wait_secs().await; let client = RaftClient::new(&vec![addr1.clone()], DEFAULT_SERVICE_ID).await.unwrap(); let sm_client = client::SMClient::new(sm_id, &client); // Execute CMDS_BEFORE_SNAP commands for _ in 0..CMDS_BEFORE_SNAP { 
sm_client.take_a_shot(&1).await.unwrap(); } async_wait(Duration::from_secs(1)).await; let state_before_snap = sm_client.get_shot().await.unwrap(); assert_eq!(state_before_snap, INITIAL - CMDS_BEFORE_SNAP); info!("State before snapshot: {}", state_before_snap); // Explicitly take a snapshot at this point { let mut meta = svc.write_meta().await; svc.take_snapshot(&mut meta).await; info!("Snapshot taken at state {}", state_before_snap); } assert!(temp_dir.join("snapshot.dat").exists(), "snapshot.dat must exist"); // Execute CMDS_AFTER_SNAP more commands (post-snapshot WAL entries) for _ in 0..CMDS_AFTER_SNAP { sm_client.take_a_shot(&1).await.unwrap(); } async_wait(Duration::from_secs(1)).await; let state_before_crash = sm_client.get_shot().await.unwrap(); assert_eq!(state_before_crash, expected); info!("State before crash: {}", state_before_crash); assert!(temp_dir.join("log.dat").exists(), "WAL must exist"); assert!(temp_dir.join("commit.idx").exists(), "commit.idx must exist"); // Abrupt crash drop(sm_client); drop(client); drop(svc); drop(server); info!("Abrupt crash (post snapshot + {} WAL entries)", CMDS_AFTER_SNAP); } async_wait(Duration::from_secs(2)).await; // ─── Phase 2: restart with DIFFERENT initial state (999) ─── // If the SM "starts over", it would show 999 (no recovery) or 996 (999 - CMDS_AFTER_SNAP, // only post-snapshot WAL replay from wrong base). Correct recovery gives exactly 92. 
let port2 = port1 + 50; let addr2 = format!("127.0.0.1:{}", port2); { let svc2 = RaftService::new(Options { storage: Storage::DISK(disk::DiskOptions { path: data_path.clone(), take_snapshots: true, append_logs: true, trim_logs: false, snapshot_log_threshold: 10000, log_compaction_threshold: 20000, }), address: addr2.clone(), service_id: DEFAULT_SERVICE_ID, }); let server2 = Server::new(&addr2); server2.register_service(&svc2).await; Server::listen_and_resume(&server2).await; // IMPORTANT: register SM AFTER start() so that load_snapshot_on_startup() // stores the snapshot bytes before register() applies them. // start() loads snapshot → stores snapshot[15] → register() applies snapshot → SM.shots=95 RaftService::start(&svc2, false).await; svc2.register_state_machine(Box::new(SM { shots: 999 })).await; svc2.bootstrap().await; // Replay post-snapshot WAL entries (entries after snapshot index → shots 95→92) svc2.recover_after_register().await; async_wait(Duration::from_secs(2)).await; let client2 = RaftClient::new(&vec![addr2.clone()], DEFAULT_SERVICE_ID).await.unwrap(); let sm_client2 = client::SMClient::new(sm_id, &client2); let recovered = sm_client2.get_shot().await.unwrap(); info!("Recovered state: {} (expected {})", recovered, expected); // Core assertions assert_ne!(recovered, 999, "SM must NOT show initial 999 — that would be 'starting over'"); assert_ne!(recovered, INITIAL, "SM must NOT show {} — that would mean snapshot was ignored", INITIAL); assert_eq!(recovered, expected, "SM must recover to exact pre-crash state via snapshot + WAL replay"); drop(sm_client2); drop(client2); drop(svc2); drop(server2); } std::fs::remove_dir_all(&temp_dir).unwrap(); info!("=== PASS: snapshot + WAL crash recovery (not start over) ==="); } } } ================================================ FILE: src/raft/state_machine/callback/client.rs ================================================ use super::*; use crate::utils::time::get_time; use async_std::sync::*; use 
futures::future::BoxFuture; use futures::stream::FuturesUnordered; use std::collections::HashMap; use std::sync::Arc; trait SubFunc = Fn(Vec) -> BoxFuture<'static, ()>; trait BoxedSubFunc = SubFunc + Send + Sync; pub struct SubscriptionService { pub subs: RwLock, u64)>>>, pub server_address: String, pub session_id: u64, } impl Service for SubscriptionService { fn notify<'a>(&'a self, key: SubKey, data: &'a Vec) -> BoxFuture<'a, ()> { debug!("Received notification for key {:?}", key); async move { let subs = self.subs.read().await; if let Some(subs) = subs.get(&key) { let subs = Pin::new(subs); let futs: FuturesUnordered<_> = subs .iter() .map(|(fun, _)| { let fun_pinned = Pin::new(fun); fun_pinned(data.clone()) }) .collect(); // Spawn async task DETACHED with the function to avoid deadlocks inside raft state machine tokio::spawn(async move { let _: Vec<_> = futs.collect().await; }); } } .boxed() } } dispatch_rpc_service_functions!(SubscriptionService); service_with_id!(SubscriptionService, DEFAULT_SERVICE_ID); impl SubscriptionService { pub async fn initialize(server: &Arc) -> Arc { let service = Arc::new(SubscriptionService { subs: RwLock::new(HashMap::new()), server_address: server.address().clone(), session_id: get_time() as u64, }); server.register_service(&service).await; service } } ================================================ FILE: src/raft/state_machine/callback/mod.rs ================================================ use bifrost_plugins::hash_ident; use crate::raft::PlaneId; pub mod client; pub mod server; pub use server::SMCallback; #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, serde::Deserialize, serde::Serialize)] pub struct SubKey { pub service_id: u64, pub plane_id: PlaneId, pub sm_id: u64, pub fn_id: u64, pub pattern_id: u64, } impl SubKey { pub const fn new( service_id: u64, plane_id: PlaneId, sm_id: u64, fn_id: u64, pattern_id: u64, ) -> Self { Self { service_id, plane_id, sm_id, fn_id, pattern_id, } } } pub static DEFAULT_SERVICE_ID: u64 = 
hash_ident!(BIFROST_RAFT_SM_CALLBACK_DEFAULT_SERVICE) as u64; service! { rpc notify(key: SubKey, data: &Vec); } #[cfg(test)] mod test { use crate::raft::client::RaftClient; use crate::raft::state_machine::callback::server::SMCallback; use crate::raft::state_machine::StateMachineCtl; use crate::raft::{Options, PlaneId, PlaneSpec, RaftService, Storage, DEFAULT_SERVICE_ID}; use crate::rpc::Server; use crate::utils::time::async_wait_secs; use future::FutureExt; use std::sync::atomic::*; use std::sync::Arc; pub struct Trigger { count: u64, callback: SMCallback, } raft_state_machine! { def cmd trigger(); def sub on_trigged() -> u64; } impl StateMachineCmds for Trigger { fn trigger(&mut self) -> BoxFuture<()> { self.count += 1; async move { self.callback .notify(commands::on_trigged::new(), self.count) .await .unwrap(); } .boxed() } } impl StateMachineCtl for Trigger { raft_sm_complete!(); fn id(&self) -> u64 { 10 } fn snapshot(&self) -> Vec { unreachable!() } fn recover(&mut self, _: Vec) -> BoxFuture<()> { future::ready(()).boxed() } fn recoverable(&self) -> bool { false } } #[tokio::test(flavor = "multi_thread")] async fn dummy() { let _ = env_logger::try_init(); info!("TESTING CALLBACK"); let addr = String::from("127.0.0.1:2110"); let raft_service = RaftService::new(Options { storage: Storage::default(), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&addr); let dummy_sm = Trigger { count: 0, callback: SMCallback::new(10, raft_service.clone()).await, }; let sm_id = dummy_sm.id(); server.register_service(&raft_service).await; Server::listen_and_resume(&server).await; RaftService::start(&raft_service, false).await; raft_service .register_state_machine(Box::new(dummy_sm)) .await; raft_service.bootstrap().await; async_wait_secs().await; let raft_client = RaftClient::new(&vec![addr], DEFAULT_SERVICE_ID) .await .unwrap(); let sm_client = Arc::new(client::SMClient::new(sm_id, &raft_client)); let loops = 10; let counter = 
Arc::new(AtomicUsize::new(0)); let counter_clone = counter.clone(); let sumer = Arc::new(AtomicUsize::new(0)); let sumer_clone = sumer.clone(); let mut expected_sum = 0; RaftClient::prepare_subscription(&server).await; sm_client .on_trigged(move |res: u64| { counter_clone.fetch_add(1, Ordering::Relaxed); sumer_clone.fetch_add(res as usize, Ordering::Relaxed); info!("CALLBACK TRIGGERED {}", res); future::ready(()).boxed() }) .await .unwrap() .unwrap(); for i in 0..loops { let sm_client = sm_client.clone(); expected_sum += i + 1; tokio::spawn(async move { sm_client.trigger().await.unwrap(); }); } async_wait_secs().await; assert_eq!(counter.load(Ordering::Relaxed), loops); assert_eq!(sumer.load(Ordering::Relaxed), expected_sum); } #[tokio::test(flavor = "multi_thread")] async fn dummy_type2_plane() { let _ = env_logger::try_init(); let addr = String::from("127.0.0.1:2111"); let raft_service = RaftService::new(Options { storage: Storage::default(), address: addr.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&addr); server.register_service(&raft_service).await; Server::listen_and_resume(&server).await; assert!(RaftService::start(&raft_service, false).await); raft_service.bootstrap().await; let plane_id = PlaneId::type2(11).unwrap(); let plane = raft_service .ensure_plane(PlaneSpec { plane_id }) .await .expect("plane should be created"); let dummy_sm = Trigger { count: 0, callback: plane .callback(10) .await .expect("type-2 callback should bind to plane-local subscriptions"), }; let sm_id = dummy_sm.id(); plane .register_state_machine(Box::new(dummy_sm)) .await .expect("type-2 state machine should register"); plane .recover_after_register() .await .expect("type-2 plane should finish registration recovery"); let raft_client = RaftClient::new(&vec![addr], DEFAULT_SERVICE_ID) .await .unwrap(); let plane_client = raft_client.plane(plane_id); let sm_client = Arc::new(client::SMClient::new(sm_id, &plane_client)); let loops = 5; let counter = 
Arc::new(AtomicUsize::new(0)); let counter_clone = counter.clone(); let sumer = Arc::new(AtomicUsize::new(0)); let sumer_clone = sumer.clone(); let mut expected_sum = 0; RaftClient::prepare_subscription(&server).await; sm_client .on_trigged(move |res: u64| { counter_clone.fetch_add(1, Ordering::Relaxed); sumer_clone.fetch_add(res as usize, Ordering::Relaxed); future::ready(()).boxed() }) .await .unwrap() .unwrap(); for i in 0..loops { let sm_client = sm_client.clone(); expected_sum += i + 1; tokio::spawn(async move { sm_client.trigger().await.unwrap(); }); } async_wait_secs().await; assert_eq!(counter.load(Ordering::Relaxed), loops); assert_eq!(sumer.load(Ordering::Relaxed), expected_sum); } } ================================================ FILE: src/raft/state_machine/callback/server.rs ================================================ use super::super::OpType; use super::*; use crate::raft::{PlaneError, PlaneId, RaftMsg, RaftService}; use crate::rpc; use async_std::sync::*; use bifrost_hasher::{hash_bytes, hash_str}; use futures::stream::FuturesUnordered; use serde; use serde::{Deserialize, Serialize}; use std::any::Any; use std::collections::{HashMap, HashSet}; use std::sync::Arc; pub struct Subscriber { pub session_id: u64, pub client: Arc, } pub struct Subscriptions { next_id: u64, subscribers: HashMap, suber_subs: HashMap>, //suber_id -> sub_id subscriptions: HashMap>, // key -> sub_id sub_suber: HashMap, sub_to_key: HashMap, //sub_id -> sub_key } impl Subscriptions { pub fn new() -> Subscriptions { Subscriptions { next_id: 0, subscribers: HashMap::new(), suber_subs: HashMap::new(), subscriptions: HashMap::new(), sub_suber: HashMap::new(), sub_to_key: HashMap::new(), } } pub async fn subscribe( &mut self, key: SubKey, address: &String, session_id: u64, ) -> Result { let suber_id = hash_str(address); let suber_exists = self.subscribers.contains_key(&suber_id); let sub_id = self.next_id; debug!( "Subscription {:?} from {}, address {}, plane {}, fn {}, pattern 
{}", key, suber_id, address, key.plane_id.raw(), key.fn_id, key.pattern_id ); let require_reload_suber = if suber_exists { match self.subscribers.get(&suber_id) { Some(subscriber) => { let session_match = subscriber.session_id == session_id; if !session_match { self.remove_subscriber(suber_id); true } else { false } } None => { error!("Subscriber {} exists flag is true but not found in map - data inconsistency", suber_id); // Treat as if subscriber doesn't exist - require reload true } } } else { true }; if !suber_exists || require_reload_suber { self.subscribers.insert( suber_id, Subscriber { session_id, client: { if let Ok(client) = RPCClient::new_async(address).await { AsyncServiceClient::new(&client) } else { return Err(()); } }, }, ); } self.suber_subs .entry(suber_id) .or_insert_with(|| HashSet::new()) .insert(sub_id); self.subscriptions .entry(key) .or_insert_with(|| HashSet::new()) .insert(sub_id); self.sub_to_key.insert(sub_id, key); self.sub_suber.insert(sub_id, suber_id); self.next_id += 1; Ok(sub_id) } pub fn remove_subscriber(&mut self, suber_id: u64) { debug!("Removing subscriber {}", suber_id); let suber_subs = if let Some(sub_ids) = self.suber_subs.get(&suber_id) { sub_ids.iter().cloned().collect() } else { Vec::::new() }; for subs_id in suber_subs { self.remove_subscription(subs_id) } self.subscribers.remove(&suber_id); self.suber_subs.remove(&suber_id); } pub fn remove_subscription(&mut self, id: u64) { debug!("Removing subscription {}", id); let sub_key = self.sub_to_key.remove(&id); if let Some(sub_key) = sub_key { if let Some(ref mut sub_subers) = self.subscriptions.get_mut(&sub_key) { sub_subers.remove(&id); self.sub_suber.remove(&id); } } } } // used for raft services to subscribe directly from state machine instances pub struct InternalSubscription { action: Box, } pub struct SMCallback { pub subscriptions: Arc>, pub raft_service: Arc, pub internal_subs: RwLock>>, pub plane_id: PlaneId, pub sm_id: u64, } #[derive(Debug, Clone, Copy, 
Serialize, Deserialize)] pub enum NotifyError { IsNotLeader, OpTypeNotSubscribe, CannotFindSubscription, CannotFindSubscribers, CannotFindSubscriber, CannotCastInternalSub, } impl SMCallback { pub async fn new(state_machine_id: u64, raft_service: Arc) -> SMCallback { Self::new_on_plane(state_machine_id, PlaneId::type1(), raft_service) .await .expect("type-1 callback construction should not fail") } pub async fn new_on_plane( state_machine_id: u64, plane_id: PlaneId, raft_service: Arc, ) -> Result { let subscriptions = raft_service.subscriptions_on_plane(plane_id).await?; Ok(SMCallback { subscriptions, raft_service: raft_service.clone(), plane_id, sm_id: state_machine_id, internal_subs: RwLock::new(HashMap::new()), }) } pub async fn notify( &self, msg: M, message: R, ) -> Result<(usize, Vec, Vec>), NotifyError> where R: serde::Serialize + Send + Sync + Clone + Any + Unpin + 'static, M: RaftMsg + 'static, { let is_leader = self .raft_service .is_leader_on_plane(self.plane_id) .await .unwrap_or(false); if !is_leader { debug!( "Will not send notification from {} on plane {} because this node is not a leader", self.raft_service.get_server_id(), self.plane_id.raw() ); return Err(NotifyError::IsNotLeader); } let (fn_id, op_type, pattern_data) = msg.encode(); return match op_type { OpType::SUBSCRIBE => { let pattern_id = hash_bytes(&pattern_data.as_slice()); let raft_sid = self.raft_service.options.service_id; let sm_id = self.sm_id; let key = SubKey::new(raft_sid, self.plane_id, sm_id, fn_id, pattern_id); let internal_subs = self.internal_subs.read().await; let svr_subs = self.subscriptions.read().await; debug!( "Sending notification, func {}, op: {:?}, pattern_id {}", fn_id, op_type, pattern_id ); if let Some(internal_subs) = internal_subs.get(&pattern_id) { for is in internal_subs { (is.action)(&message) } } else { trace!("Cannot found internal subs {}", pattern_id); } if let Some(sub_ids) = svr_subs.subscriptions.get(&key) { let sub_result_futs: FuturesUnordered<_> = 
sub_ids .iter() .map(|sub_id| { let message = Pin::new(&message); async move { let svr_subs = self.subscriptions.read().await; if let Some(subscriber_id) = svr_subs.sub_suber.get(&sub_id) { if let Some(subscriber) = svr_subs.subscribers.get(&subscriber_id) { let data = crate::utils::serde::serialize(&*message); let client = &subscriber.client; debug!( "Sending out callback notification to sub id {}", sub_id ); let client_result = client.notify(key, &data).await; Ok(client_result) } else { Err(NotifyError::CannotFindSubscriber) } } else { Err(NotifyError::CannotFindSubscribers) } } }) .collect(); let sub_result: Vec<_> = sub_result_futs.collect().await; let errors: Vec = sub_result .iter() .filter_map(|r| { if let Err(e) = r { Some(e.clone()) } else { None } }) .collect(); let response: Vec<_> = sub_result .into_iter() .filter_map(|r| if let Ok(value) = r { Some(value) } else { None }) .collect(); Ok((sub_ids.len(), errors, response)) } else { Err(NotifyError::CannotFindSubscription) } } _ => Err(NotifyError::OpTypeNotSubscribe), }; } pub async fn internal_subscribe(&self, msg: M, trigger: F) -> Result<(), NotifyError> where M: RaftMsg, F: Fn(&R) + Sync + Send + 'static, R: 'static, { let (_, op_type, pattern_data) = msg.encode(); match op_type { OpType::SUBSCRIBE => { let pattern_id = hash_bytes(&pattern_data.as_slice()); let mut internal_subs = self.internal_subs.write().await; internal_subs .entry(pattern_id) .or_insert_with(|| Vec::new()) .push(InternalSubscription { action: Box::new(move |any: &dyn Any| match any.downcast_ref::() { Some(r) => trigger(r), None => warn!("type mismatch in internal subscription"), }), }); Ok(()) } _ => Err(NotifyError::OpTypeNotSubscribe), } } } pub async fn notify(callback: &Option, msg: M, data: F) where F: FnOnce() -> R, M: RaftMsg + Send + 'static, R: serde::Serialize + Send + Sync + Clone + Unpin + Any + 'static, { if let Some(ref callback) = *callback { match callback.notify(msg, data()).await { Ok(_) | 
Err(NotifyError::IsNotLeader) => {} Err(e) => warn!( "Cannot send nofication, failed after called due to: {:?}", e ), } } else { warn!("Cannot send notification, callback handler is empty"); } } #[cfg(test)] mod tests { use super::*; #[test] fn test_subscriptions_new() { let subs = Subscriptions::new(); assert_eq!(subs.next_id, 0); assert!(subs.subscribers.is_empty()); assert!(subs.suber_subs.is_empty()); assert!(subs.subscriptions.is_empty()); assert!(subs.sub_suber.is_empty()); assert!(subs.sub_to_key.is_empty()); } #[test] fn test_remove_subscription_nonexistent() { let mut subs = Subscriptions::new(); // Remove non-existent subscription should not crash subs.remove_subscription(999); assert!(subs.sub_to_key.is_empty()); assert!(subs.subscriptions.is_empty()); } #[test] fn test_remove_subscription() { let mut subs = Subscriptions::new(); // Manually add a subscription let sub_id = 1u64; let sub_key = SubKey::new(0, PlaneId::type1(), 0, 100, 200); subs.sub_to_key.insert(sub_id, sub_key); subs.subscriptions .entry(sub_key) .or_insert_with(HashSet::new) .insert(sub_id); subs.sub_suber.insert(sub_id, 42u64); // Now remove it subs.remove_subscription(sub_id); assert!(!subs.sub_to_key.contains_key(&sub_id)); assert!(!subs.sub_suber.contains_key(&sub_id)); if let Some(subs_set) = subs.subscriptions.get(&sub_key) { assert!(!subs_set.contains(&sub_id)); } } #[test] fn test_remove_subscriber() { let mut subs = Subscriptions::new(); let suber_id = 42u64; let sub_id = 1u64; let sub_key = SubKey::new(0, PlaneId::type1(), 0, 100, 200); // Manually set up subscriber with subscription subs.suber_subs .entry(suber_id) .or_insert_with(HashSet::new) .insert(sub_id); subs.sub_to_key.insert(sub_id, sub_key); subs.subscriptions .entry(sub_key) .or_insert_with(HashSet::new) .insert(sub_id); subs.sub_suber.insert(sub_id, suber_id); // Remove the subscriber subs.remove_subscriber(suber_id); assert!(!subs.suber_subs.contains_key(&suber_id)); 
assert!(!subs.subscribers.contains_key(&suber_id)); assert!(!subs.sub_to_key.contains_key(&sub_id)); assert!(!subs.sub_suber.contains_key(&sub_id)); } #[test] fn test_remove_subscriber_nonexistent() { let mut subs = Subscriptions::new(); // Remove non-existent subscriber should not crash subs.remove_subscriber(999); assert!(subs.subscribers.is_empty()); } #[test] fn test_notify_error_debug() { // Test that NotifyError can be debugged and cloned let error = NotifyError::IsNotLeader; let cloned = error.clone(); assert!(matches!(cloned, NotifyError::IsNotLeader)); // Test all variants let _ = NotifyError::OpTypeNotSubscribe; let _ = NotifyError::CannotFindSubscription; let _ = NotifyError::CannotFindSubscribers; let _ = NotifyError::CannotFindSubscriber; let _ = NotifyError::CannotCastInternalSub; } #[test] fn test_subscriptions_next_id_increment() { let mut subs = Subscriptions::new(); assert_eq!(subs.next_id, 0); // Simulate what subscribe does with next_id let first_id = subs.next_id; subs.next_id += 1; let second_id = subs.next_id; subs.next_id += 1; assert_eq!(first_id, 0); assert_eq!(second_id, 1); assert_eq!(subs.next_id, 2); } #[test] fn test_subscriptions_multiple_subs_per_subscriber() { let mut subs = Subscriptions::new(); let suber_id = 42u64; let sub_id1 = 1u64; let sub_id2 = 2u64; let sub_key1 = SubKey::new(0, PlaneId::type1(), 0, 100, 200); let sub_key2 = SubKey::new(0, PlaneId::type1(), 0, 101, 201); // Add two subscriptions for same subscriber subs.suber_subs .entry(suber_id) .or_insert_with(HashSet::new) .insert(sub_id1); subs.suber_subs .entry(suber_id) .or_insert_with(HashSet::new) .insert(sub_id2); subs.sub_to_key.insert(sub_id1, sub_key1); subs.sub_to_key.insert(sub_id2, sub_key2); subs.sub_suber.insert(sub_id1, suber_id); subs.sub_suber.insert(sub_id2, suber_id); // Verify both subscriptions are tracked let subscriber_subs = subs.suber_subs.get(&suber_id).unwrap(); assert_eq!(subscriber_subs.len(), 2); assert!(subscriber_subs.contains(&sub_id1)); 
assert!(subscriber_subs.contains(&sub_id2)); // Remove the subscriber - should remove both subscriptions subs.remove_subscriber(suber_id); assert!(!subs.sub_to_key.contains_key(&sub_id1)); assert!(!subs.sub_to_key.contains_key(&sub_id2)); assert!(!subs.suber_subs.contains_key(&suber_id)); } } ================================================ FILE: src/raft/state_machine/configs.rs ================================================ use crate::raft::state_machine::callback::server::Subscriptions; use crate::raft::state_machine::callback::SubKey; use crate::raft::state_machine::StateMachineCtl; use crate::raft::AsyncServiceClient; use crate::rpc::{self, ServiceClient}; use async_std::sync::*; use bifrost_hasher::hash_str; use futures::FutureExt; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::sync::Arc; pub const CONFIG_SM_ID: u64 = 1; #[derive(Clone)] pub struct RaftMember { pub rpc: Arc, pub address: String, pub id: u64, } pub struct Configures { pub members: HashMap, // keep it in arc lock for reference in callback server.rs pub subscriptions: Arc>, service_id: u64, } pub type MemberConfigSnapshot = HashSet; #[derive(Serialize, Deserialize, Debug)] pub struct ConfigSnapshot { members: MemberConfigSnapshot, //TODO: snapshot for subscriptions } raft_state_machine! 
{ def cmd new_member_(address: String) -> bool; def cmd del_member_(address: String); def qry member_address() -> Vec; def cmd subscribe(key: SubKey, address: String, session_id: u64) -> Result; def cmd unsubscribe(sub_id: u64); } impl StateMachineCmds for Configures { fn new_member_(&mut self, address: String) -> BoxFuture { async move { let addr = address.clone(); let id = hash_str(&addr); if !self.members.contains_key(&id) { match rpc::DEFAULT_CLIENT_POOL.get(&address).await { Ok(client) => { self.members.insert( id, RaftMember { rpc: AsyncServiceClient::new_with_service_id( self.service_id, &client, ), address, id, }, ); return true; } Err(_) => {} } } false } .boxed() } fn del_member_(&mut self, address: String) -> BoxFuture<()> { async move { let hash = hash_str(&address); self.members.remove(&hash); } .boxed() } fn member_address(&self) -> BoxFuture> { future::ready(self.members.values().map(|m| m.address.clone()).collect()).boxed() } fn subscribe( &mut self, key: SubKey, address: String, session_id: u64, ) -> BoxFuture> { async move { let mut subs = self.subscriptions.write().await; subs.subscribe(key, &address, session_id).await } .boxed() } fn unsubscribe(&mut self, sub_id: u64) -> BoxFuture<()> { async move { let mut subs = self.subscriptions.write().await; subs.remove_subscription(sub_id); } .boxed() } } impl StateMachineCtl for Configures { raft_sm_complete!(); fn id(&self) -> u64 { CONFIG_SM_ID } fn snapshot(&self) -> Vec { let mut snapshot = ConfigSnapshot { members: HashSet::with_capacity(self.members.len()), }; for (_, member) in self.members.iter() { snapshot.members.insert(member.address.clone()); } crate::utils::serde::serialize(&snapshot) } fn recover(&mut self, data: Vec) -> BoxFuture<()> { match crate::utils::serde::deserialize::(&data) { Some(snapshot) => self.recover_members(snapshot.members).boxed(), None => { error!( "Failed to deserialize config state machine snapshot. Config recovery failed." 
);
                // Return empty future - state machine will start with empty config
                future::ready(()).boxed()
            }
        }
    }

    fn recoverable(&self) -> bool {
        true
    }
}

impl Configures {
    /// Creates the configuration state machine for the Raft service `service_id`.
    ///
    /// The member table starts empty and a fresh subscription registry is allocated.
    pub fn new(service_id: u64) -> Configures {
        Configures {
            members: HashMap::new(),
            service_id,
            subscriptions: Arc::new(RwLock::new(Subscriptions::new())),
        }
    }

    /// Reconciles the member table with the address set stored in `snapshot`.
    ///
    /// Members whose address is not in the snapshot are removed, and snapshot addresses
    /// that are not yet members are added through `new_member`. An address that cannot
    /// be added is logged and skipped so the remaining members are still recovered.
    async fn recover_members(&mut self, snapshot: MemberConfigSnapshot) {
        let current: MemberConfigSnapshot = self
            .members
            .values()
            .map(|member| member.address.clone())
            .collect();
        for stale in current.difference(&snapshot) {
            self.del_member(stale.clone()).await;
        }
        for missing in snapshot.difference(&current) {
            // `new_member` reports failure through its return value; it used to be
            // discarded here, which hid members lost during recovery.
            if !self.new_member(missing.clone()).await {
                warn!(
                    "Cannot restore member {} from config snapshot, member skipped",
                    missing
                );
            }
        }
    }

    /// Adds the server at `address` as a member; returns `true` when it was inserted.
    pub async fn new_member(&mut self, address: String) -> bool {
        self.new_member_(address).await
    }

    /// Removes the member registered under `address`, if there is one.
    pub async fn del_member(&mut self, address: String) {
        self.del_member_(address).await
    }

    /// Returns `true` when a member with the given id is present in the member table.
    pub fn member_existed(&self, id: u64) -> bool {
        self.members.contains_key(&id)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::raft::state_machine::StateMachineCtl;

    #[test]
    fn test_configures_new() {
        let service_id = 12345u64;
        let config = Configures::new(service_id);
        assert_eq!(config.service_id, service_id);
        assert!(config.members.is_empty());
        assert_eq!(config.id(), CONFIG_SM_ID);
    }

    #[test]
    fn test_configures_id() {
        let config = Configures::new(1);
        assert_eq!(config.id(), CONFIG_SM_ID);
        assert_eq!(CONFIG_SM_ID, 1);
    }

    #[test]
    fn test_configures_recoverable() {
        let config = Configures::new(1);
        assert!(config.recoverable());
    }

    #[test]
    fn test_member_existed() {
        let config = Configures::new(1);
        let member_id = hash_str(&String::from("test_member"));
        assert!(!config.member_existed(member_id));
        assert!(!config.member_existed(123456));
        assert!(!config.member_existed(0));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_member_address() {
        let config = Configures::new(1);
        let addresses =
config.member_address().await; assert!(addresses.is_empty()); } #[test] fn test_snapshot_empty() { let config = Configures::new(1); let snapshot = config.snapshot(); assert!(!snapshot.is_empty()); let deserialized: Option = crate::utils::serde::deserialize(&snapshot); assert!(deserialized.is_some()); let snapshot_data = deserialized.unwrap(); assert!(snapshot_data.members.is_empty()); } #[tokio::test(flavor = "multi_thread")] async fn test_recover_empty_snapshot() { let mut config = Configures::new(1); let snapshot = config.snapshot(); config.recover(snapshot).await; assert!(config.members.is_empty()); } #[tokio::test(flavor = "multi_thread")] async fn test_recover_invalid_data() { let mut config = Configures::new(1); // Try to recover from invalid data let invalid_data = vec![0xFF, 0xFF, 0xFF]; config.recover(invalid_data).await; // Should not crash, just log error and continue with empty config assert!(config.members.is_empty()); } #[tokio::test(flavor = "multi_thread")] async fn test_del_member() { let mut config = Configures::new(1); // Delete non-existent member should not crash config.del_member(String::from("non_existent")).await; assert!(config.members.is_empty()); } #[test] fn test_config_snapshot_serialization() { let mut snapshot = ConfigSnapshot { members: HashSet::new(), }; snapshot.members.insert(String::from("member1")); snapshot.members.insert(String::from("member2")); let serialized = crate::utils::serde::serialize(&snapshot); let deserialized: Option = crate::utils::serde::deserialize(&serialized); assert!(deserialized.is_some()); let recovered = deserialized.unwrap(); assert_eq!(recovered.members.len(), 2); assert!(recovered.members.contains("member1")); assert!(recovered.members.contains("member2")); } #[test] fn test_configures_has_subscriptions() { let config = Configures::new(1); // Verify subscriptions Arc is initialized assert_eq!(Arc::strong_count(&config.subscriptions), 1); } // 
============================================================================ // REAL-WORLD MEMBERSHIP CHANGE SCENARIOS // ============================================================================ #[tokio::test(flavor = "multi_thread")] async fn test_membership_server_join() { use crate::rpc::Server; let mut config = Configures::new(1); // Initially no members assert!(config.members.is_empty()); assert_eq!(config.member_address().await.len(), 0); // Start a test server for the member to join let member_addr = String::from("127.0.0.1:4100"); let server = Server::new(&member_addr); Server::listen_and_resume(&server).await; tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; // Attempt to add new member let result = config.new_member(member_addr.clone()).await; assert!(result, "Member should join successfully"); // Verify member was added let member_id = hash_str(&member_addr); assert!(config.member_existed(member_id)); assert_eq!(config.members.len(), 1); let addresses = config.member_address().await; assert_eq!(addresses.len(), 1); assert!(addresses.contains(&member_addr)); // Cleanup server.shutdown().await; } #[tokio::test(flavor = "multi_thread")] async fn test_membership_server_leave() { use crate::rpc::Server; let mut config = Configures::new(1); // Add a member first let member_addr = String::from("127.0.0.1:4200"); let server = Server::new(&member_addr); Server::listen_and_resume(&server).await; tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; config.new_member(member_addr.clone()).await; let member_id = hash_str(&member_addr); assert!(config.member_existed(member_id)); // Now remove the member config.del_member(member_addr.clone()).await; // Verify member was removed assert!(!config.member_existed(member_id)); assert_eq!(config.members.len(), 0); let addresses = config.member_address().await; assert_eq!(addresses.len(), 0); // Cleanup server.shutdown().await; } #[tokio::test(flavor = "multi_thread")] async fn 
test_membership_server_rejoin() { use crate::rpc::Server; let mut config = Configures::new(1); let member_addr = String::from("127.0.0.1:4300"); let server = Server::new(&member_addr); Server::listen_and_resume(&server).await; tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; // Initial join let result1 = config.new_member(member_addr.clone()).await; assert!(result1, "First join should succeed"); let member_id = hash_str(&member_addr); assert!(config.member_existed(member_id)); // Leave config.del_member(member_addr.clone()).await; assert!(!config.member_existed(member_id)); // Rejoin - should succeed let result2 = config.new_member(member_addr.clone()).await; assert!(result2, "Rejoin should succeed"); assert!(config.member_existed(member_id)); // Cleanup server.shutdown().await; } #[tokio::test(flavor = "multi_thread")] async fn test_membership_multiple_servers_join() { use crate::rpc::Server; let mut config = Configures::new(1); // Add multiple servers let addrs = vec![ String::from("127.0.0.1:4400"), String::from("127.0.0.1:4401"), String::from("127.0.0.1:4402"), ]; let mut servers = vec![]; for addr in &addrs { let server = Server::new(addr); Server::listen_and_resume(&server).await; servers.push(server); } tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; // Join all servers for addr in &addrs { let result = config.new_member(addr.clone()).await; assert!(result, "Server {} should join", addr); } // Verify all are members assert_eq!(config.members.len(), 3); let member_addrs = config.member_address().await; assert_eq!(member_addrs.len(), 3); for addr in &addrs { assert!(member_addrs.contains(addr), "Should contain {}", addr); let member_id = hash_str(addr); assert!(config.member_existed(member_id)); } // Cleanup for server in servers { server.shutdown().await; } } #[tokio::test(flavor = "multi_thread")] async fn test_membership_partial_leave() { use crate::rpc::Server; let mut config = Configures::new(1); let addrs = vec![ 
String::from("127.0.0.1:4500"), String::from("127.0.0.1:4501"), String::from("127.0.0.1:4502"), ]; let mut servers = vec![]; for addr in &addrs { let server = Server::new(addr); Server::listen_and_resume(&server).await; servers.push(server); } tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; // All join for addr in &addrs { config.new_member(addr.clone()).await; } assert_eq!(config.members.len(), 3); // Remove middle server config.del_member(addrs[1].clone()).await; // Verify only 2 remain assert_eq!(config.members.len(), 2); let member_addrs = config.member_address().await; assert!(member_addrs.contains(&addrs[0])); assert!(!member_addrs.contains(&addrs[1])); // Removed assert!(member_addrs.contains(&addrs[2])); // Cleanup for server in servers { server.shutdown().await; } } #[tokio::test(flavor = "multi_thread")] async fn test_membership_duplicate_join_attempt() { use crate::rpc::Server; let mut config = Configures::new(1); let member_addr = String::from("127.0.0.1:4600"); let server = Server::new(&member_addr); Server::listen_and_resume(&server).await; tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; // First join let result1 = config.new_member(member_addr.clone()).await; assert!(result1); // Attempt duplicate join - should fail gracefully let result2 = config.new_member(member_addr.clone()).await; assert!(!result2, "Duplicate join should fail"); // Still only one member assert_eq!(config.members.len(), 1); // Cleanup server.shutdown().await; } #[tokio::test(flavor = "multi_thread")] async fn test_membership_snapshot_with_members() { use crate::rpc::Server; let mut config = Configures::new(1); let addrs = vec![ String::from("127.0.0.1:4700"), String::from("127.0.0.1:4701"), ]; let mut servers = vec![]; for addr in &addrs { let server = Server::new(addr); Server::listen_and_resume(&server).await; servers.push(server); } tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; // Add members for addr in &addrs { 
config.new_member(addr.clone()).await; } // Take snapshot let snapshot = config.snapshot(); assert!(!snapshot.is_empty()); // Verify snapshot contains members let snapshot_data: Option = crate::utils::serde::deserialize(&snapshot); assert!(snapshot_data.is_some()); let snapshot_data = snapshot_data.unwrap(); assert_eq!(snapshot_data.members.len(), 2); for addr in &addrs { assert!(snapshot_data.members.contains(addr)); } // Cleanup for server in servers { server.shutdown().await; } } #[tokio::test(flavor = "multi_thread")] async fn test_membership_recovery_with_changes() { use crate::rpc::Server; let mut config1 = Configures::new(1); let addrs = vec![ String::from("127.0.0.1:4800"), String::from("127.0.0.1:4801"), ]; let mut servers = vec![]; for addr in &addrs { let server = Server::new(addr); Server::listen_and_resume(&server).await; servers.push(server); } tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; // Add members to first config for addr in &addrs { config1.new_member(addr.clone()).await; } // Take snapshot let snapshot = config1.snapshot(); // Create new config and recover let mut config2 = Configures::new(1); config2.recover(snapshot).await; // Add the third server let new_addr = String::from("127.0.0.1:4802"); let new_server = Server::new(&new_addr); Server::listen_and_resume(&new_server).await; servers.push(new_server); tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; // The recovered config should be able to add new members let result = config2.new_member(new_addr.clone()).await; assert!(result); // Cleanup for server in servers { server.shutdown().await; } } #[tokio::test(flavor = "multi_thread")] async fn test_membership_leave_nonexistent() { let mut config = Configures::new(1); // Try to remove a member that doesn't exist - should not crash config.del_member(String::from("127.0.0.1:9999")).await; assert_eq!(config.members.len(), 0); } #[tokio::test(flavor = "multi_thread")] async fn 
test_membership_join_unreachable_server() { let mut config = Configures::new(1); // Try to add a member that's not actually running let unreachable_addr = String::from("127.0.0.1:9998"); let result = config.new_member(unreachable_addr.clone()).await; // Should fail because server is not reachable assert!(!result, "Join should fail for unreachable server"); assert_eq!(config.members.len(), 0); } } ================================================ FILE: src/raft/state_machine/macros.rs ================================================ //TODO: Use higher order macro to merge with rpc service! macro when possible to do this in Rust. //Current major problem is inner repeated macro will be recognized as outer macro which breaks expand #[macro_export] macro_rules! raft_trait_fn { (qry $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty) => { fn $fn_name<'a>(&'a self, $($arg:$in_),*) -> ::futures::future::BoxFuture<$out>; }; (cmd $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty) => { fn $fn_name<'a>(&'a mut self, $($arg:$in_),*) -> ::futures::future::BoxFuture<$out>; }; (sub $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty) => {} } #[macro_export] macro_rules! raft_client_fn { (sub $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty) => { pub fn $fn_name(&self, f: F, $($arg:$in_),* ) -> BoxFuture, $crate::raft::state_machine::master::ExecError>> where F: Fn($out) -> BoxFuture<'static, ()> + 'static + Send + Sync { self.client.subscribe( self.sm_id, $fn_name::new($($arg,)*), f ).boxed() } }; ($others:ident $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty) => { pub async fn $fn_name(&self, $($arg:$in_),*) -> Result<$out, $crate::raft::state_machine::master::ExecError> { self.client.execute( self.sm_id, $fn_name::new($($arg,)*) ).await } }; } #[macro_export] macro_rules! 
raft_fn_op_type { (qry) => { $crate::raft::state_machine::OpType::QUERY }; (cmd) => { $crate::raft::state_machine::OpType::COMMAND }; (sub) => { $crate::raft::state_machine::OpType::SUBSCRIBE }; } #[macro_export] macro_rules! raft_dispatch_fn { ($fn_name:ident $s: ident $d: ident ( $( $arg:ident : $in_:ty ),* )) => {{ let decoded: ($($in_,)*) = match $crate::utils::serde::deserialize($d) { Some(decoded) => decoded, None => panic!("Failed to deserialize function call data for function: {}, s: {}, d: {}", stringify!($fn_name), stringify!($s), stringify!($d)), }; let ($($arg,)*) = decoded; let f_result = $s.$fn_name($($arg),*).await; Some($crate::utils::serde::serialize(&f_result)) }}; } #[macro_export] macro_rules! raft_dispatch_cmd { (cmd $fn_name:ident $s: ident $d: ident ( $( $arg:ident : $in_:ty ),* )) => { raft_dispatch_fn!($fn_name $s $d( $( $arg : $in_ ),* )) }; ($others:ident $fn_name:ident $s: ident $d: ident ( $( $arg:ident : $in_:ty ),* )) => {None}; } #[macro_export] macro_rules! raft_dispatch_qry { (qry $fn_name:ident $s: ident $d: ident ( $( $arg:ident : $in_:ty ),* )) => { raft_dispatch_fn!($fn_name $s $d( $( $arg : $in_ ),* )) }; ($others:ident $fn_name:ident $s: ident $d: ident ( $( $arg:ident : $in_:ty ),* )) => {None}; } #[macro_export] macro_rules! raft_sm_complete { () => { fn fn_dispatch_cmd<'a>( &'a mut self, fn_id: u64, data: &'a Vec, ) -> ::futures::future::BoxFuture<'a, Option>> { self.dispatch_cmd_(fn_id, data) } fn fn_dispatch_qry<'a>( &'a self, fn_id: u64, data: &'a Vec, ) -> ::futures::future::BoxFuture<'a, Option>> { self.dispatch_qry_(fn_id, data) } fn op_type(&mut self, fn_id: u64) -> Option<$crate::raft::state_machine::OpType> { self.op_type_(fn_id) } }; } #[macro_export] macro_rules! raft_state_machine { ( $( $(#[$attr:meta])* def $smt:ident $fn_name:ident( $( $arg:ident : $in_:ty ),* ) $(-> $out:ty)* ; )* ) => { raft_state_machine! 
{{ $( $(#[$attr])* def $smt $fn_name( $( $arg : $in_ ),* ) $(-> $out)*; )* }} }; ( { $(#[$attr:meta])* def $smt:ident $fn_name:ident( $( $arg:ident : $in_:ty ),* ); // No return $( $unexpanded:tt )* } $( $expanded:tt )* ) => { raft_state_machine! { { $( $unexpanded )* } $( $expanded )* $(#[$attr])* def $smt $fn_name( $( $arg : $in_ ),* ) -> (); } }; ( { $(#[$attr:meta])* def $smt:ident $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty; $( $unexpanded:tt )* } $( $expanded:tt )* ) => { raft_state_machine! { { $( $unexpanded )* } $( $expanded )* $(#[$attr])* def $smt $fn_name( $( $arg : $in_ ),* ) -> $out; } }; ( {} // all expanded $( $(#[$attr:meta])* def $smt:ident $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty; )* ) => { #[allow(unused_imports)] use futures::prelude::*; use futures::future::BoxFuture; #[allow(dead_code)] #[allow(unused_imports)] pub mod commands { use super::*; use futures::prelude::*; use serde::{Serialize, Deserialize}; $( #[derive(Serialize, Deserialize, Debug)] #[allow(non_camel_case_types)] pub struct $fn_name { pub data: Vec } impl $crate::raft::RaftMsg<$out> for $fn_name { fn encode(self) -> (u64, $crate::raft::state_machine::OpType, Vec) { ( ::bifrost_plugins::hash_ident!($fn_name) as u64, raft_fn_op_type!($smt), self.data ) } fn decode_return(data: &Vec) -> $out { $crate::utils::serde::deserialize(data).unwrap() } } impl $fn_name { pub fn new($($arg:&$in_),*) -> $fn_name { let req_data = ($($arg,)*); $fn_name { data: $crate::utils::serde::serialize(&req_data) } } } )* } #[allow(dead_code)] #[allow(unused_variables)] pub trait StateMachineCmds: $crate::raft::state_machine::StateMachineCtl { $( $(#[$attr])* raft_trait_fn!($smt $fn_name( $( $arg : $in_ ),* ) -> $out); )* fn op_type_(&self, fn_id: u64) -> Option<$crate::raft::state_machine::OpType> { match fn_id as usize { $(::bifrost_plugins::hash_ident!($fn_name) => { Some(raft_fn_op_type!($smt)) }),* _ => { debug!("Undefined function id: {}", fn_id); None } } } fn 
dispatch_cmd_<'a>(&'a mut self, fn_id: u64, data: &'a Vec) -> BoxFuture>> { async move { match fn_id as usize { $(::bifrost_plugins::hash_ident!($fn_name) => { raft_dispatch_cmd!($smt $fn_name self data( $( $arg : $in_ ),* )) }),* _ => { debug!("Undefined function id: {}. We have {}", fn_id, concat!(stringify!($($fn_name),*))); None } } }.boxed() } fn dispatch_qry_<'a>(&'a self, fn_id: u64, data: &'a Vec) -> BoxFuture>> { async move { match fn_id as usize { $(::bifrost_plugins::hash_ident!($fn_name) => { raft_dispatch_qry!($smt $fn_name self data( $( $arg : $in_ ),* )) }),* _ => { debug!("Undefined function id: {}", fn_id); None } } }.boxed() } } #[allow(dead_code)] #[allow(unused_imports)] pub mod client { use super::*; use std::sync::Arc; use super::commands::*; use $crate::raft::client::*; use $crate::raft::state_machine::master::ExecError; use $crate::raft::state_machine::StateMachineClient; use $crate::raft::client::{AsRaftPlaneClient, RaftClient, RaftPlaneClient, SubscriptionError, SubscriptionReceipt}; pub struct SMClient { client: Arc, sm_id: u64 } impl SMClient { $( $(#[$attr])* raft_client_fn!($smt $fn_name( $( $arg : &$in_ ),* ) -> $out); )* pub fn new(sm_id: u64, client: &Arc) -> Self where C: AsRaftPlaneClient + 'static, { Self { client: client.as_raft_plane_client(), sm_id: sm_id } } } impl StateMachineClient for SMClient { fn new_instance (sm_id: u64, client: &Arc) -> Self { Self::new(sm_id, client) } } } }; } ================================================ FILE: src/raft/state_machine/master.rs ================================================ use self::configs::{Configures, RaftMember, CONFIG_SM_ID}; use super::super::*; use super::*; use std::collections::HashMap; use std::error::Error; use std::fmt; use std::fmt::Display; use std::fmt::Formatter; #[derive(Serialize, Deserialize, Debug, Clone)] pub enum ExecError { SmNotFound(u64), FnNotFound(u64, u64), // (sm_id, fn_id) ServersUnreachable, CannotConstructClient, NotCommitted, ShuttingDown, 
Unknown, TooManyRetry, } pub enum RegisterResult { OK, EXISTED, RESERVED, } pub type ExecOk = Vec; pub type ExecResult = Result; pub type SubStateMachine = Box; pub type SnapshotDataItem = (u64, Vec); pub type SnapshotDataItems = Vec; raft_state_machine! {} pub struct MasterStateMachine { subs: HashMap, snapshots: HashMap>, pub configs: Configures, plane_id: PlaneId, } impl StateMachineCmds for MasterStateMachine {} impl StateMachineCtl for MasterStateMachine { raft_sm_complete!(); fn id(&self) -> u64 { 0 } fn snapshot(&self) -> Vec { let mut sms: SnapshotDataItems = Vec::with_capacity(self.subs.len()); for (sm_id, smc) in self.subs.iter() { if !smc.recoverable() { continue; } let sub_snapshot = smc.snapshot(); sms.push((*sm_id, sub_snapshot)); } sms.push((self.configs.id(), self.configs.snapshot())); let data = crate::utils::serde::serialize(&sms); data } fn recover(&mut self, data: Vec) -> BoxFuture<()> { match crate::utils::serde::deserialize::(data.as_slice()) { Some(sms) => { for (sm_id, snapshot) in sms { self.snapshots.insert(sm_id, snapshot); } } None => { error!( "Failed to deserialize master state machine snapshot for plane {}. 
State machine recovery failed.", self.plane_id.raw() ); // Clear snapshots to start fresh - this is safer than leaving corrupted state self.snapshots.clear(); } } future::ready(()).boxed() } fn recoverable(&self) -> bool { true } } pub fn parse_output<'a>(r: Option>) -> ExecResult { if let Some(d) = r { Ok(d) } else { // Caller will wrap with correct (sm_id, fn_id); default to (0,0) if unknown Err(ExecError::FnNotFound(0, 0)) } } impl MasterStateMachine { pub fn new(service_id: u64) -> MasterStateMachine { Self::new_on_plane(service_id, PlaneId::type1()) } pub fn new_on_plane(service_id: u64, plane_id: PlaneId) -> MasterStateMachine { let msm = MasterStateMachine { subs: HashMap::new(), snapshots: HashMap::new(), configs: Configures::new(service_id), plane_id, }; msm } /// Whether a given state machine id should be persisted/recovered. pub fn is_recoverable(&self, sm_id: u64) -> bool { if sm_id == CONFIG_SM_ID { return self.configs.recoverable(); } if let Some(sm) = self.subs.get(&sm_id) { return sm.recoverable(); } // Default to true if SM is not yet registered so we don't skip WAL true } pub fn register(&mut self, smc: SubStateMachine) -> RegisterResult { let id = smc.id(); if is_reserved_internal_sm_id(id) { return RegisterResult::RESERVED; } if self.subs.contains_key(&id) { return RegisterResult::EXISTED; }; self.subs.insert(id, smc); RegisterResult::OK } pub async fn recover_registered_snapshots(&mut self) { if let Some(snapshot) = self.snapshots.remove(&CONFIG_SM_ID) { self.configs.recover(snapshot).await; } let recoverable_ids: Vec = self .subs .keys() .filter(|id| self.snapshots.contains_key(id)) .copied() .collect(); for sm_id in recoverable_ids { if let Some(snapshot) = self.snapshots.remove(&sm_id) { if let Some(smc) = self.subs.get_mut(&sm_id) { smc.recover(snapshot).await; } } } } pub fn members(&self) -> &HashMap { &self.configs.members } pub async fn commit_cmd(&mut self, entry: &LogEntry) -> ExecResult { match entry.sm_id { CONFIG_SM_ID => { let out 
= self.configs.fn_dispatch_cmd(entry.fn_id, &entry.data).await;
                match out {
                    Some(d) => Ok(d),
                    None => {
                        warn!(
                            "FN not found for cmd on plane {} sm_id={}, fn_id={} at log_id={}",
                            self.plane_id.raw(),
                            entry.sm_id,
                            entry.fn_id,
                            entry.id
                        );
                        Err(ExecError::FnNotFound(entry.sm_id, entry.fn_id))
                    }
                }
            }
            _ => {
                match self.subs.get_mut(&entry.sm_id) {
                    Some(sm) => {
                        let out = sm.as_mut().fn_dispatch_cmd(entry.fn_id, &entry.data).await;
                        match out {
                            Some(data) => Ok(data),
                            None => {
                                warn!(
                                    "FN not found for cmd on plane {} sm_id={}, fn_id={} at log_id={}",
                                    self.plane_id.raw(),
                                    entry.sm_id,
                                    entry.fn_id,
                                    entry.id
                                );
                                Err(ExecError::FnNotFound(entry.sm_id, entry.fn_id))
                            }
                        }
                    }
                    None => {
                        warn!(
                            "SM not found for cmd on plane {} sm_id={} at log_id={}, have SMs: {:?}",
                            self.plane_id.raw(),
                            entry.sm_id,
                            entry.id,
                            self.subs.keys().collect::>()
                        );
                        Err(ExecError::SmNotFound(entry.sm_id))
                    }
                }
            }
        }
    }

    /// Runs a read-only query against the state machine addressed by `entry.sm_id`.
    ///
    /// `CONFIG_SM_ID` is answered by the built-in configuration state machine; every
    /// other id is looked up among the registered sub state machines.
    ///
    /// # Errors
    /// * `ExecError::SmNotFound` when nothing is registered under `entry.sm_id`.
    /// * `ExecError::FnNotFound` when the dispatcher returns `None` for `entry.fn_id`.
    pub async fn exec_qry(&self, entry: &LogEntry) -> ExecResult {
        let dispatched = if entry.sm_id == CONFIG_SM_ID {
            self.configs.fn_dispatch_qry(entry.fn_id, &entry.data).await
        } else if let Some(sm) = self.subs.get(&entry.sm_id) {
            sm.fn_dispatch_qry(entry.fn_id, &entry.data).await
        } else {
            warn!(
                "SM not found for qry on plane {} sm_id={} at log_id={}, have SMs: {:?}",
                self.plane_id.raw(),
                entry.sm_id,
                entry.id,
                self.subs.keys().collect::<Vec<_>>()
            );
            return Err(ExecError::SmNotFound(entry.sm_id));
        };
        // Both dispatch paths report an unknown function the same way, so the miss is
        // handled once here.
        if dispatched.is_none() {
            warn!(
                "FN not found for qry on plane {} sm_id={}, fn_id={} at log_id={}",
                self.plane_id.raw(),
                entry.sm_id,
                entry.fn_id,
                entry.id
            );
        }
        dispatched.ok_or(ExecError::FnNotFound(entry.sm_id, entry.fn_id))
    }

    /// Drops every registered sub state machine.
    pub fn clear_subs(&mut self) {
        self.subs.clear();
    }

    /// Returns `true` when a sub state machine is registered under `id`.
    pub fn has_sub(&self, id: &u64) -> bool {
        self.subs.contains_key(id)
    }
}

impl
Error for ExecError {} impl Display for ExecError { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "{:?}", self) } } #[cfg(test)] mod tests { use super::*; use crate::raft::state_machine::StateMachineCtl; // Mock state machine for testing struct MockStateMachine { id: u64, recoverable: bool, snapshot_data: Vec, recovered_data: Option>, } impl StateMachineCtl for MockStateMachine { raft_sm_complete!(); fn id(&self) -> u64 { self.id } fn snapshot(&self) -> Vec { self.snapshot_data.clone() } fn recover(&mut self, data: Vec) -> BoxFuture<()> { self.recovered_data = Some(data); future::ready(()).boxed() } fn recoverable(&self) -> bool { self.recoverable } } impl StateMachineCmds for MockStateMachine {} #[test] fn test_master_state_machine_new() { let service_id = 12345u64; let msm = MasterStateMachine::new(service_id); assert_eq!(msm.id(), 0); assert!(msm.subs.is_empty()); assert!(msm.snapshots.is_empty()); // configs should be initialized assert_eq!(msm.configs.id(), CONFIG_SM_ID); } #[test] fn test_master_state_machine_id() { let msm = MasterStateMachine::new(1); assert_eq!(msm.id(), 0); } #[test] fn test_master_state_machine_recoverable() { let msm = MasterStateMachine::new(1); assert!(msm.recoverable()); } #[test] fn test_is_recoverable_config_sm() { let msm = MasterStateMachine::new(1); // CONFIG_SM_ID (1) should be recoverable assert!(msm.is_recoverable(CONFIG_SM_ID)); } #[test] fn test_is_recoverable_registered_sm() { let mut msm = MasterStateMachine::new(1); let mock_sm = Box::new(MockStateMachine { id: 100, recoverable: true, snapshot_data: vec![1, 2, 3], recovered_data: None, }); msm.register(mock_sm); assert!(msm.is_recoverable(100)); } #[test] fn test_is_recoverable_non_recoverable_sm() { let mut msm = MasterStateMachine::new(1); let mock_sm = Box::new(MockStateMachine { id: 200, recoverable: false, snapshot_data: vec![], recovered_data: None, }); msm.register(mock_sm); assert!(!msm.is_recoverable(200)); } #[test] fn 
test_is_recoverable_unregistered_sm() { let msm = MasterStateMachine::new(1); // Unregistered SMs default to true so we don't skip WAL assert!(msm.is_recoverable(999)); } #[test] fn test_register_ok() { let mut msm = MasterStateMachine::new(1); let mock_sm = Box::new(MockStateMachine { id: 10, recoverable: true, snapshot_data: vec![1, 2, 3], recovered_data: None, }); let result = msm.register(mock_sm); assert!(matches!(result, RegisterResult::OK)); assert!(msm.has_sub(&10)); } #[test] fn test_register_reserved() { let mut msm = MasterStateMachine::new(1); // ID 0 is reserved for master let mock_sm0 = Box::new(MockStateMachine { id: 0, recoverable: true, snapshot_data: vec![], recovered_data: None, }); let result = msm.register(mock_sm0); assert!(matches!(result, RegisterResult::RESERVED)); // ID 1 is reserved for config let mock_sm1 = Box::new(MockStateMachine { id: 1, recoverable: true, snapshot_data: vec![], recovered_data: None, }); let result = msm.register(mock_sm1); assert!(matches!(result, RegisterResult::RESERVED)); } #[test] fn test_register_existed() { let mut msm = MasterStateMachine::new(1); let mock_sm1 = Box::new(MockStateMachine { id: 10, recoverable: true, snapshot_data: vec![1, 2, 3], recovered_data: None, }); msm.register(mock_sm1); // Try to register again with same ID let mock_sm2 = Box::new(MockStateMachine { id: 10, recoverable: true, snapshot_data: vec![4, 5, 6], recovered_data: None, }); let result = msm.register(mock_sm2); assert!(matches!(result, RegisterResult::EXISTED)); } #[tokio::test(flavor = "multi_thread")] async fn test_register_with_snapshot_recovery() { let mut msm = MasterStateMachine::new(1); // Add a snapshot for SM id 10 let snapshot_data = vec![1, 2, 3, 4, 5]; msm.snapshots.insert(10, snapshot_data.clone()); let mock_sm = Box::new(MockStateMachine { id: 10, recoverable: true, snapshot_data: vec![], recovered_data: None, }); msm.register(mock_sm); // Snapshot is kept pending until registered snapshots are explicitly 
recovered. assert_eq!(msm.snapshots.get(&10), Some(&snapshot_data)); msm.recover_registered_snapshots().await; // Snapshot should be removed after replay. assert!(!msm.snapshots.contains_key(&10)); let recovered = msm .subs .get(&10) .and_then(|sm| { let any = sm.as_ref() as &dyn std::any::Any; any.downcast_ref::() }) .and_then(|sm| sm.recovered_data.clone()); assert_eq!(recovered, Some(snapshot_data)); } #[tokio::test(flavor = "multi_thread")] async fn test_recover_registered_snapshots_applies_config_snapshot() { use crate::rpc::Server; let mut msm = MasterStateMachine::new(1); msm.configs.members.clear(); // new_member opens an RPC client to the address, so a real listener is required. let member_addr = String::from("127.0.0.1:9100"); let server = Server::new(&member_addr); Server::listen_and_resume(&server).await; let mut restored = Configures::new(99); let added = restored.new_member(member_addr.clone()).await; assert!(added, "Member should join successfully"); msm.snapshots.insert(CONFIG_SM_ID, restored.snapshot()); msm.recover_registered_snapshots().await; assert_eq!(msm.configs.members.len(), 1); assert!(!msm.snapshots.contains_key(&CONFIG_SM_ID)); server.shutdown().await; } #[test] fn test_members() { let msm = MasterStateMachine::new(1); let members = msm.members(); assert!(members.is_empty()); } #[test] fn test_clear_subs() { let mut msm = MasterStateMachine::new(1); let mock_sm = Box::new(MockStateMachine { id: 10, recoverable: true, snapshot_data: vec![1, 2, 3], recovered_data: None, }); msm.register(mock_sm); assert!(msm.has_sub(&10)); msm.clear_subs(); assert!(!msm.has_sub(&10)); assert!(msm.subs.is_empty()); } #[test] fn test_has_sub() { let mut msm = MasterStateMachine::new(1); assert!(!msm.has_sub(&10)); let mock_sm = Box::new(MockStateMachine { id: 10, recoverable: true, snapshot_data: vec![1, 2, 3], recovered_data: None, }); msm.register(mock_sm); assert!(msm.has_sub(&10)); assert!(!msm.has_sub(&20)); } #[test] fn test_snapshot_empty() { let msm 
= MasterStateMachine::new(1); let snapshot = msm.snapshot(); assert!(!snapshot.is_empty()); // Should be able to deserialize let items: Option = crate::utils::serde::deserialize(&snapshot); assert!(items.is_some()); } #[test] fn test_snapshot_with_recoverable_sm() { let mut msm = MasterStateMachine::new(1); let mock_sm = Box::new(MockStateMachine { id: 10, recoverable: true, snapshot_data: vec![1, 2, 3, 4, 5], recovered_data: None, }); msm.register(mock_sm); let snapshot = msm.snapshot(); let items: Option = crate::utils::serde::deserialize(&snapshot); assert!(items.is_some()); let items = items.unwrap(); // Should have config SM + our mock SM assert!(items.len() >= 2); // Check that our mock SM's snapshot is included let has_mock_sm = items.iter().any(|(id, _)| *id == 10); assert!(has_mock_sm); } #[test] fn test_snapshot_non_recoverable_sm_excluded() { let mut msm = MasterStateMachine::new(1); let mock_sm = Box::new(MockStateMachine { id: 20, recoverable: false, snapshot_data: vec![1, 2, 3], recovered_data: None, }); msm.register(mock_sm); let snapshot = msm.snapshot(); let items: Option = crate::utils::serde::deserialize(&snapshot); assert!(items.is_some()); let items = items.unwrap(); // Non-recoverable SM should not be in snapshot let has_mock_sm = items.iter().any(|(id, _)| *id == 20); assert!(!has_mock_sm); } #[tokio::test(flavor = "multi_thread")] async fn test_recover_valid_snapshot() { let mut msm = MasterStateMachine::new(1); let mut items = SnapshotDataItems::new(); items.push((10, vec![1, 2, 3])); items.push((20, vec![4, 5, 6])); let snapshot = crate::utils::serde::serialize(&items); msm.recover(snapshot).await; assert_eq!(msm.snapshots.get(&10), Some(&vec![1, 2, 3])); assert_eq!(msm.snapshots.get(&20), Some(&vec![4, 5, 6])); } #[tokio::test(flavor = "multi_thread")] async fn test_recover_invalid_snapshot() { let mut msm = MasterStateMachine::new(1); // Add some existing snapshots msm.snapshots.insert(100, vec![1, 2, 3]); // Recover with invalid data 
let invalid_data = vec![0xFF, 0xFF, 0xFF]; msm.recover(invalid_data).await; // Snapshots should be cleared on error assert!(msm.snapshots.is_empty()); } #[test] fn test_parse_output_some() { let data = vec![1, 2, 3, 4, 5]; let result = parse_output(Some(data.clone())); assert!(result.is_ok()); assert_eq!(result.unwrap(), data); } #[test] fn test_parse_output_none() { let result = parse_output(None); assert!(result.is_err()); assert!(matches!(result, Err(ExecError::FnNotFound(0, 0)))); } #[test] fn test_exec_error_display() { let error = ExecError::SmNotFound(123); let display = format!("{}", error); assert!(display.contains("SmNotFound")); let error2 = ExecError::FnNotFound(1, 2); let display2 = format!("{}", error2); assert!(display2.contains("FnNotFound")); } #[test] fn test_exec_error_debug() { let error = ExecError::ServersUnreachable; let debug = format!("{:?}", error); assert!(debug.contains("ServersUnreachable")); } #[test] fn test_exec_error_clone() { let error = ExecError::CannotConstructClient; let cloned = error.clone(); assert!(matches!(cloned, ExecError::CannotConstructClient)); } #[test] fn test_all_exec_error_variants() { // Test that all variants can be created let _ = ExecError::SmNotFound(1); let _ = ExecError::FnNotFound(1, 2); let _ = ExecError::ServersUnreachable; let _ = ExecError::CannotConstructClient; let _ = ExecError::NotCommitted; let _ = ExecError::ShuttingDown; let _ = ExecError::Unknown; let _ = ExecError::TooManyRetry; } #[test] fn test_register_result_variants() { // Test that all variants exist let _ = RegisterResult::OK; let _ = RegisterResult::EXISTED; let _ = RegisterResult::RESERVED; } // ============================================================================ // REAL-WORLD INTEGRATION SCENARIOS // ============================================================================ // Real-world state machine that simulates a key-value store struct KeyValueStateMachine { id: u64, data: HashMap, } impl StateMachineCtl for 
KeyValueStateMachine { raft_sm_complete!(); fn id(&self) -> u64 { self.id } fn snapshot(&self) -> Vec { crate::utils::serde::serialize(&self.data) } fn recover(&mut self, data: Vec) -> BoxFuture<()> { if let Some(recovered_data) = crate::utils::serde::deserialize(&data) { self.data = recovered_data; } future::ready(()).boxed() } fn recoverable(&self) -> bool { true } } impl StateMachineCmds for KeyValueStateMachine { fn dispatch_cmd_<'a>( &'a mut self, fn_id: u64, data: &'a Vec, ) -> BoxFuture<'a, Option>> { async move { match fn_id { 1 => { // SET command if let Some((key, value)) = crate::utils::serde::deserialize::<(String, String)>(data) { self.data.insert(key, value); Some(crate::utils::serde::serialize(&true)) } else { None } } 2 => { // DELETE command if let Some(key) = crate::utils::serde::deserialize::(data) { let existed = self.data.remove(&key).is_some(); Some(crate::utils::serde::serialize(&existed)) } else { None } } _ => None, } } .boxed() } fn dispatch_qry_<'a>( &'a self, fn_id: u64, data: &'a Vec, ) -> BoxFuture<'a, Option>> { async move { match fn_id { 10 => { // GET query if let Some(key) = crate::utils::serde::deserialize::(data) { let value = self.data.get(&key).cloned(); Some(crate::utils::serde::serialize(&value)) } else { None } } 11 => { // COUNT query Some(crate::utils::serde::serialize(&self.data.len())) } _ => None, } } .boxed() } fn op_type_(&self, _fn_id: u64) -> Option { Some(OpType::COMMAND) } } #[tokio::test(flavor = "multi_thread")] async fn test_real_world_kv_store_workflow() { let mut msm = MasterStateMachine::new(1); // Register a KV store state machine let kv_sm = Box::new(KeyValueStateMachine { id: 100, data: HashMap::new(), }); let result = msm.register(kv_sm); assert!(matches!(result, RegisterResult::OK)); // Simulate SET command: key="user:123", value="Alice" let set_data = crate::utils::serde::serialize(&(String::from("user:123"), String::from("Alice"))); let set_entry = LogEntry { id: 1, term: 1, sm_id: 100, fn_id: 1, // 
SET data: set_data, }; let result = msm.commit_cmd(&set_entry).await; assert!(result.is_ok()); // Simulate GET query: key="user:123" let get_data = crate::utils::serde::serialize(&String::from("user:123")); let get_entry = LogEntry { id: 2, term: 1, sm_id: 100, fn_id: 10, // GET data: get_data, }; let result = msm.exec_qry(&get_entry).await; assert!(result.is_ok()); let value: Option> = crate::utils::serde::deserialize(&result.unwrap()); assert_eq!(value, Some(Some(String::from("Alice")))); } #[tokio::test(flavor = "multi_thread")] async fn test_real_world_multiple_operations() { let mut msm = MasterStateMachine::new(1); let kv_sm = Box::new(KeyValueStateMachine { id: 100, data: HashMap::new(), }); msm.register(kv_sm); // Insert multiple key-value pairs for i in 0..5 { let key = format!("key:{}", i); let value = format!("value:{}", i); let data = crate::utils::serde::serialize(&(key, value)); let entry = LogEntry { id: i + 1, term: 1, sm_id: 100, fn_id: 1, // SET data, }; let result = msm.commit_cmd(&entry).await; assert!(result.is_ok()); } // Query the count let count_entry = LogEntry { id: 10, term: 1, sm_id: 100, fn_id: 11, // COUNT data: vec![], }; let result = msm.exec_qry(&count_entry).await; assert!(result.is_ok()); let count: usize = crate::utils::serde::deserialize(&result.unwrap()).unwrap(); assert_eq!(count, 5); // Delete one key let delete_data = crate::utils::serde::serialize(&String::from("key:2")); let delete_entry = LogEntry { id: 11, term: 1, sm_id: 100, fn_id: 2, // DELETE data: delete_data, }; let result = msm.commit_cmd(&delete_entry).await; assert!(result.is_ok()); // Verify count decreased let result = msm.exec_qry(&count_entry).await; let count: usize = crate::utils::serde::deserialize(&result.unwrap()).unwrap(); assert_eq!(count, 4); } #[tokio::test(flavor = "multi_thread")] async fn test_real_world_snapshot_and_recovery() { let mut msm = MasterStateMachine::new(1); let kv_sm = Box::new(KeyValueStateMachine { id: 100, data: HashMap::new(), 
}); msm.register(kv_sm); // Populate with data for i in 0..10 { let key = format!("session:{}", i); let value = format!("token:{}", i * 100); let data = crate::utils::serde::serialize(&(key, value)); let entry = LogEntry { id: i + 1, term: 1, sm_id: 100, fn_id: 1, // SET data, }; msm.commit_cmd(&entry).await.unwrap(); } // Take a snapshot let snapshot = msm.snapshot(); assert!(!snapshot.is_empty()); // Create a new master state machine and recover let mut new_msm = MasterStateMachine::new(1); new_msm.recover(snapshot).await; // Register the state machine again - it should recover data let new_kv_sm = Box::new(KeyValueStateMachine { id: 100, data: HashMap::new(), }); new_msm.register(new_kv_sm); new_msm.recover_registered_snapshots().await; // Verify data was recovered by querying let get_data = crate::utils::serde::serialize(&String::from("session:5")); let get_entry = LogEntry { id: 100, term: 2, sm_id: 100, fn_id: 10, // GET data: get_data, }; let result = new_msm.exec_qry(&get_entry).await; assert!(result.is_ok()); let value: Option> = crate::utils::serde::deserialize(&result.unwrap()); assert_eq!(value, Some(Some(String::from("token:500")))); } #[tokio::test(flavor = "multi_thread")] async fn test_real_world_multiple_state_machines() { let mut msm = MasterStateMachine::new(1); // Register two different KV stores for different purposes let users_sm = Box::new(KeyValueStateMachine { id: 100, // Users store data: HashMap::new(), }); msm.register(users_sm); let sessions_sm = Box::new(KeyValueStateMachine { id: 200, // Sessions store data: HashMap::new(), }); msm.register(sessions_sm); // Add user let user_data = crate::utils::serde::serialize(&( String::from("user:1"), String::from("alice@example.com"), )); let user_entry = LogEntry { id: 1, term: 1, sm_id: 100, fn_id: 1, data: user_data, }; msm.commit_cmd(&user_entry).await.unwrap(); // Add session for that user let session_data = crate::utils::serde::serialize(&(String::from("session:abc"), 
String::from("user:1"))); let session_entry = LogEntry { id: 2, term: 1, sm_id: 200, fn_id: 1, data: session_data, }; msm.commit_cmd(&session_entry).await.unwrap(); // Query both state machines let user_query = LogEntry { id: 3, term: 1, sm_id: 100, fn_id: 11, // COUNT data: vec![], }; let user_count: usize = crate::utils::serde::deserialize(&msm.exec_qry(&user_query).await.unwrap()).unwrap(); let session_query = LogEntry { id: 4, term: 1, sm_id: 200, fn_id: 11, // COUNT data: vec![], }; let session_count: usize = crate::utils::serde::deserialize(&msm.exec_qry(&session_query).await.unwrap()).unwrap(); assert_eq!(user_count, 1); assert_eq!(session_count, 1); } #[tokio::test(flavor = "multi_thread")] async fn test_real_world_error_handling() { let mut msm = MasterStateMachine::new(1); let kv_sm = Box::new(KeyValueStateMachine { id: 100, data: HashMap::new(), }); msm.register(kv_sm); // Try to execute command on non-existent state machine let entry = LogEntry { id: 1, term: 1, sm_id: 999, // Doesn't exist fn_id: 1, data: vec![], }; let result = msm.commit_cmd(&entry).await; assert!(result.is_err()); assert!(matches!(result, Err(ExecError::SmNotFound(999)))); // Try to execute non-existent function let bad_fn_entry = LogEntry { id: 2, term: 1, sm_id: 100, fn_id: 999, // Doesn't exist data: vec![], }; let result = msm.commit_cmd(&bad_fn_entry).await; assert!(result.is_err()); assert!(matches!(result, Err(ExecError::FnNotFound(100, 999)))); } #[tokio::test(flavor = "multi_thread")] async fn test_real_world_query_on_non_existent_sm() { let msm = MasterStateMachine::new(1); let entry = LogEntry { id: 1, term: 1, sm_id: 999, fn_id: 10, data: vec![], }; let result = msm.exec_qry(&entry).await; assert!(result.is_err()); assert!(matches!(result, Err(ExecError::SmNotFound(999)))); } #[tokio::test(flavor = "multi_thread")] async fn test_real_world_config_sm_operations() { let mut msm = MasterStateMachine::new(1); // CONFIG_SM_ID is always registered let entry = LogEntry { id: 1, 
term: 1, sm_id: CONFIG_SM_ID, fn_id: 1, // Some config function data: vec![], }; // This should route to config SM, not error with SmNotFound let result = msm.commit_cmd(&entry).await; // It may error with FnNotFound but not SmNotFound if result.is_err() { assert!(matches!(result, Err(ExecError::FnNotFound(_, _)))); } } #[tokio::test(flavor = "multi_thread")] async fn test_real_world_state_machine_lifecycle() { let mut msm = MasterStateMachine::new(1); let sm_id = 100u64; // Initially, state machine doesn't exist assert!(!msm.has_sub(&sm_id)); // Register it let kv_sm = Box::new(KeyValueStateMachine { id: sm_id, data: HashMap::new(), }); msm.register(kv_sm); assert!(msm.has_sub(&sm_id)); // Use it let data = crate::utils::serde::serialize(&(String::from("test"), String::from("value"))); let entry = LogEntry { id: 1, term: 1, sm_id, fn_id: 1, data, }; let result = msm.commit_cmd(&entry).await; assert!(result.is_ok()); // Clear all state machines msm.clear_subs(); assert!(!msm.has_sub(&sm_id)); // Try to use it after clearing - should error let entry2 = LogEntry { id: 2, term: 1, sm_id, fn_id: 1, data: vec![], }; let result = msm.commit_cmd(&entry2).await; assert!(result.is_err()); assert!(matches!(result, Err(ExecError::SmNotFound(_)))); } #[tokio::test(flavor = "multi_thread")] async fn test_real_world_concurrent_queries() { let msm = Arc::new(MasterStateMachine::new(1)); // This is a read-only test simulating concurrent queries // In real world, multiple threads would query simultaneously let query_entry = LogEntry { id: 1, term: 1, sm_id: CONFIG_SM_ID, fn_id: 1, data: vec![], }; // Simulate concurrent reads let msm1 = msm.clone(); let msm2 = msm.clone(); let entry1 = query_entry.clone(); let entry2 = query_entry.clone(); let handle1 = tokio::spawn(async move { msm1.exec_qry(&entry1).await }); let handle2 = tokio::spawn(async move { msm2.exec_qry(&entry2).await }); // Both should complete (may error with FnNotFound but shouldn't panic) let _ = 
tokio::try_join!(handle1, handle2);
    }
}

================================================
FILE: src/raft/state_machine/mod.rs
================================================
use crate::raft::client::RaftPlaneClient;
use std::any::Any;
use std::sync::Arc;

/// Where state is kept.
// NOTE(review): the `String` carried by `DISK` is presumably a storage path - confirm.
pub enum Storage {
    MEMORY,
    DISK(String),
}

/// Kind of operation a state machine function performs.
#[derive(Debug)]
pub enum OpType {
    COMMAND,
    QUERY,
    SUBSCRIBE,
}

/// Control interface every state machine exposes to the master state machine.
pub trait StateMachineCtl: Sync + Send + Any {
    /// Id the state machine is registered and addressed under.
    fn id(&self) -> u64;
    /// Serializes the current state.
    fn snapshot(&self) -> Vec;
    /// Restores state from data previously produced by `snapshot`.
    fn recover(&mut self, data: Vec) -> ::futures::future::BoxFuture<()>;
    /// Whether this state machine takes part in snapshot/recovery.
    fn recoverable(&self) -> bool;
    /// Runs the query `fn_id` with serialized arguments `data`.
    /// The master state machine maps a `None` result to `ExecError::FnNotFound`.
    fn fn_dispatch_qry<'a>(
        &'a self,
        fn_id: u64,
        data: &'a Vec,
    ) -> ::futures::future::BoxFuture<'a, Option>>;
    /// Runs the command `fn_id` with serialized arguments `data`.
    /// The master state machine maps a `None` result to `ExecError::FnNotFound`.
    fn fn_dispatch_cmd<'a>(
        &'a mut self,
        fn_id: u64,
        data: &'a Vec,
    ) -> ::futures::future::BoxFuture<'a, Option>>;
    /// Operation type of function `fn_id`, if any.
    fn op_type(&mut self, fn_id: u64) -> Option;
}

/// Lookup of the operation type for a function id.
pub trait OpTypes {
    fn op_type(&self, fn_id: u64) -> Option;
}

/// Constructor interface for client-side state machine handles.
pub trait StateMachineClient {
    fn new_instance(sm_id: u64, client: &Arc) -> Self;
}

pub const MASTER_SM_ID: u64 = 0;
pub const CONFIG_SM_ID: u64 = 1;
pub const RESERVED_INTERNAL_SM_ID_END: u64 = 2;

/// Returns `true` for ids `0..=RESERVED_INTERNAL_SM_ID_END`.
// NOTE(review): the bound is inclusive, so id 2 is reserved in addition to
// `MASTER_SM_ID` and `CONFIG_SM_ID` - confirm that is intended.
pub const fn is_reserved_internal_sm_id(sm_id: u64) -> bool {
    sm_id <= RESERVED_INTERNAL_SM_ID_END
}

#[macro_use]
pub mod macros;
pub mod callback;
pub mod configs;
pub mod master;

================================================
FILE: src/rpc/cluster.rs
================================================
use std::{future::Future, sync::Arc};

use crate::{
    conshash::ConsistentHashing,
    raft::state_machine::master::ExecError,
    rpc::{RPCError, DEFAULT_CLIENT_POOL},
};
use futures::stream::FuturesUnordered;
use tokio_stream::StreamExt;

use super::{RPCClient, ServiceClientWithId};

/// Calls `func` once per cluster member and collects every result.
///
/// Member ids are fetched with `all_server_ids`; see
/// `broadcast_with_server_ids` for how individual failures are reported.
pub async fn broadcast_to_members(
    conshash: &Arc,
    func: F,
) -> Result)>, ExecError>
where
    C: ServiceClientWithId,
    F: Fn(Arc) -> Fut + Clone + Send + 'static,
    Fut: Future> + Send,
{
    let server_ids = all_server_ids(conshash).await?;
    broadcast_with_server_ids(server_ids, conshash, func).await
}

/// Returns the ids of all members reported by the membership service.
pub async fn all_server_ids(
    conshash: &Arc,
) -> Result, ExecError> {
    let (members, _) = conshash.membership().all_members(true).await?;
    Ok(members.into_iter().map(|m| m.id))
}

/// Calls `func` with a service client for each id in `server_ids`.
///
/// The calls are driven concurrently; results are returned in completion
/// order, each paired with its server id. A failure to obtain a client is
/// logged and reported as that server's `Err` entry rather than aborting the
/// whole broadcast.
pub async fn broadcast_with_server_ids(
    server_ids: I,
    conshash: &Arc,
    func: F,
) -> Result)>, ExecError>
where
    I: Iterator,
    C: ServiceClientWithId,
    F: Fn(Arc) -> Fut + Clone + Send + 'static,
    Fut: Future> + Send,
{
    let member_futs: FuturesUnordered<_> = server_ids
        .map(|sid| {
            let func = func.clone();
            async move {
                match client_by_server_id(conshash, sid).await {
                    Ok(client) => (sid, func(client).await),
                    Err(e) => {
                        error!("Failed to get client by server id {}: {:?}", sid, e);
                        (sid, Err(e))
                    }
                }
            }
        })
        .collect();
    Ok(member_futs.collect::>().await)
}

/// Obtains a service client for `server_id` from `DEFAULT_CLIENT_POOL`.
///
/// The server address is resolved through `conshash.to_server_name`; an I/O
/// failure from the pool is returned as `RPCError::IOError`.
pub async fn client_by_server_id(
    conshash: &Arc,
    server_id: u64,
) -> Result, RPCError>
where
    C: ServiceClientWithId,
{
    DEFAULT_CLIENT_POOL
        .get_by_id(server_id, move |sid| conshash.to_server_name(sid))
        .await
        .map_err(RPCError::IOError)
        .map(|c| client_by_rpc_client(&c))
}

/// Wraps an RPC client into the service client `C`, using `C::SERVICE_ID`.
pub fn client_by_rpc_client(client: &Arc) -> Arc
where
    C: ServiceClientWithId,
{
    C::new_with_service_id(C::SERVICE_ID, client)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::rpc::RPCClient;
    use std::sync::Arc;

    // Define a test service for cluster operations
    mod test_cluster_service {
        use super::*;
        use futures::future::BoxFuture;

        service!
{ rpc get_id() -> u64; rpc echo(msg: String) -> String; } pub struct TestService { pub id: u64, } impl Service for TestService { fn get_id(&self) -> BoxFuture { futures::future::ready(self.id).boxed() } fn echo(&self, msg: String) -> BoxFuture { futures::future::ready(format!("Echo: {}", msg)).boxed() } } dispatch_rpc_service_functions!(TestService); impl ServiceClientWithId for AsyncServiceClient { const SERVICE_ID: u64 = 999; } } #[test] fn test_client_by_rpc_client_creation() { // Test that we can create a service client from an RPC client // This is a pure unit test that doesn't require network setup use bifrost_hasher::hash_str; let addr = String::from("127.0.0.1:3400"); let server_id = hash_str(&addr); // Create a mock RPC client structure - we're only testing the wrapper function // Note: We can't actually test the full functionality without network setup, // but we can verify the function signature and type conversion works // This test validates that the ServiceClientWithId trait is properly implemented // Full integration tests are in the membership and conshash test modules } // Note: Full integration tests for broadcast_to_members, all_server_ids, and // client_by_server_id are covered in the membership and conshash integration tests // since they require a complete raft cluster setup. } ================================================ FILE: src/rpc/mod.rs ================================================ #[macro_use] pub mod proto; pub mod cluster; use crate::{tcp, DISABLE_SHORTCUT}; use bifrost_hasher::hash_str; use bytes::{Buf, BufMut, BytesMut}; use futures::future::BoxFuture; use futures::prelude::*; use futures::Future; use lightning::map::*; use serde::{Deserialize, Serialize}; use std::backtrace; use std::error::Error; use std::io; use std::pin::Pin; use std::sync::Arc; use std::sync::Mutex as StdMutex; use std::time::Duration; use tokio::time::sleep; use tokio::time::*; lazy_static! 
{
    pub static ref DEFAULT_CLIENT_POOL: ClientPool = ClientPool::new();
}

/// Request-level failure reported by the server side.
#[derive(Serialize, Deserialize, Debug)]
pub enum RPCRequestError {
    FunctionIdNotFound,
    ServiceIdNotFound,
    BadRequest,
    Other,
}

/// Failure observed by an RPC caller.
#[derive(Debug)]
pub enum RPCError {
    /// The transport failed.
    IOError(io::Error),
    /// The server answered with an error status.
    RequestError(RPCRequestError),
    /// The response could not be decoded.
    ClientCannotDecodeResponse,
}

/// A service that can be registered on a `Server`.
pub trait RPCService: Sync + Send {
    fn dispatch(&self, data: BytesMut) -> BoxFuture>;
    fn register_shortcut_service(
        &self,
        service_ptr: usize,
        server_id: u64,
        service_id: u64,
    ) -> ::std::pin::Pin + Send>>;
    fn service_symbol(&self) -> &'static str;
}

pub struct Server {
    services: PtrHashMap>,
    pub address: String,
    pub server_id: u64,
    tcp_server: StdMutex>>,
    shutdown_handle: StdMutex>>,
}

// NOTE(review): no justification for this `unsafe impl` is visible here;
// confirm every field is safe to share across threads.
unsafe impl Sync for Server {}

pub struct ClientPool {
    clients: PtrHashMap>,
}

/// Encodes a service result into a response frame.
///
/// The first byte is the status: `0` for success (followed by the payload),
/// `1` for `FunctionIdNotFound`, `2` for `ServiceIdNotFound`, and `255` for
/// any other request error. Error frames carry no payload.
fn encode_res(res: Result) -> BytesMut {
    match res {
        Ok(buffer) => {
            // Allocate once and copy the payload, as `prepend_u64` does,
            // instead of collecting it byte by byte through an iterator chain.
            let mut frame = BytesMut::with_capacity(1 + buffer.len());
            frame.put_u8(0u8);
            frame.extend_from_slice(buffer.as_ref());
            frame
        }
        Err(e) => {
            let err_id = match e {
                RPCRequestError::FunctionIdNotFound => 1u8,
                RPCRequestError::ServiceIdNotFound => 2u8,
                _ => 255u8,
            };
            BytesMut::from(&[err_id][..])
        }
    }
}

/// Decodes a response frame produced by `encode_res`.
///
/// # Errors
///
/// * `RPCError::IOError` when the transport failed.
/// * `RPCError::ClientCannotDecodeResponse` when the response is empty and so
///   has no status byte.
/// * `RPCError::RequestError` when the status byte is non-zero.
fn decode_res(res: io::Result) -> Result {
    let mut res = res.map_err(RPCError::IOError)?;
    // Read the status byte without indexing: an empty response must become an
    // error, not a panic.
    let status = res.first().copied();
    match status {
        Some(0u8) => {
            res.advance(1);
            Ok(res.split())
        }
        Some(1u8) => Err(RPCError::RequestError(RPCRequestError::FunctionIdNotFound)),
        Some(2u8) => Err(RPCError::RequestError(RPCRequestError::ServiceIdNotFound)),
        Some(_) => Err(RPCError::RequestError(RPCRequestError::Other)),
        None => Err(RPCError::ClientCannotDecodeResponse),
    }
}

/// Splits a little-endian `u64` header off the front of `data`.
///
/// # Panics
///
/// Panics when `data` holds fewer than 8 bytes (see `Buf::get_u64_le`).
pub fn read_u64_head(mut data: BytesMut) -> (u64, BytesMut) {
    let num = data.get_u64_le();
    (num, data)
}

impl Server {
    /// Creates a server for `address`; its `server_id` is `hash_str(address)`.
    pub fn new(address: &String) -> Arc {
        Arc::new(Server {
            services: PtrHashMap::with_capacity(16),
            address: address.clone(),
            server_id: hash_str(address),
            tcp_server: StdMutex::new(None),
            shutdown_handle: StdMutex::new(None),
        })
    }

    pub async fn listen(server: &Arc) -> Result<(), Box> {
        let address = &server.address;
        let tcp_server = Arc::new(tcp::server::Server::new());
        // Store tcp_server
reference match server.tcp_server.lock() { Ok(mut guard) => *guard = Some(tcp_server.clone()), Err(e) => error!("Failed to store tcp_server reference: {}", e), } let server_clone = server.clone(); tcp_server .listen( address, Arc::new(move |data| { let server = server_clone.clone(); async move { let (svr_id, data) = read_u64_head(data); let service = server.services.get(&svr_id); trace!("Processing request for service {}", svr_id); match service { Some(service) => { let svr_res = service.dispatch(data).await; encode_res(svr_res) } None => { let service_list = server .services .entries() .into_iter() .map(|(sid, service)| { format!("{}:{}", sid, service.service_symbol()) }) .collect::>(); error!( "Service {} not found, have {:?}, backtrace: {:?}", svr_id, service_list.join(", "), backtrace::Backtrace::capture() ); encode_res(Err(RPCRequestError::ServiceIdNotFound)) } } } .boxed() }), ) .await } pub async fn listen_and_resume(server: &Arc) { let address = server.address.clone(); let tcp_server = Arc::new(tcp::server::Server::new()); // Store tcp_server in the server struct match server.tcp_server.lock() { Ok(mut guard) => *guard = Some(tcp_server.clone()), Err(e) => error!("Failed to store tcp_server reference: {}", e), } let server_clone = server.clone(); let handle = tokio::spawn(async move { let result = tcp_server .listen( &address, Arc::new(move |data| { let server = server_clone.clone(); async move { let (svr_id, data) = read_u64_head(data); let service = server.services.get(&svr_id); trace!("Processing request for service {}", svr_id); match service { Some(service) => { let svr_res = service.dispatch(data).await; encode_res(svr_res) } None => { let service_list = server .services .entries() .into_iter() .map(|(sid, service)| { format!("{}:{}", sid, service.service_symbol()) }) .collect::>(); error!( "Service {} not found, have {:?}, backtrace: {:?}", svr_id, service_list.join(", "), backtrace::Backtrace::capture() ); 
encode_res(Err(RPCRequestError::ServiceIdNotFound)) } } } .boxed() }), ) .await; if let Err(e) = result { error!("RPC server error: {:?}", e); } }); // Store handle match server.shutdown_handle.lock() { Ok(mut guard) => *guard = Some(handle), Err(e) => error!("Failed to store shutdown handle: {}", e), } sleep(Duration::from_secs(1)).await } pub async fn shutdown(&self) { info!("Shutting down RPC server on {}", self.address); match self.tcp_server.lock() { Ok(guard) => { if let Some(ref tcp_server) = *guard { tcp_server.shutdown(); } } Err(e) => error!("Failed to acquire tcp_server lock during shutdown: {}", e), } // Give it a moment to shut down gracefully sleep(Duration::from_millis(100)).await; } pub async fn register_service_with_id(&self, service_id: u64, service: &Arc) where T: RPCService + Sized + 'static, { let service = service.clone(); if !DISABLE_SHORTCUT { let service_ptr = Arc::into_raw(service.clone()) as usize; service .register_shortcut_service(service_ptr, self.server_id, service_id) .await; } else { debug!("SERVICE SHORTCUT DISABLED"); } info!( "Registering service {} with id {}", service.service_symbol(), service_id ); self.services.insert(service_id, service); } pub async fn register_service(&self, service: &Arc) where T: RPCServiceWithId + Sized + 'static, { self.register_service_with_id(T::SERVICE_ID, service).await } pub async fn remove_service(&self, service_id: u64) { self.services.remove(&service_id); } pub fn address(&self) -> &String { &self.address } } pub struct RPCClient { client: tcp::client::Client, pub server_id: u64, pub address: String, } pub fn prepend_u64(num: u64, data: BytesMut) -> BytesMut { let mut bytes = BytesMut::with_capacity(8 + data.len()); bytes.put_u64_le(num); bytes.extend_from_slice(data.as_ref()); bytes } impl RPCClient { pub async fn send_async( self: Pin<&Self>, svr_id: u64, data: BytesMut, ) -> Result { let client = &self.client; let payload = prepend_u64(svr_id, data); let res = client.send_msg(payload).await; 
decode_res(res) } pub async fn new_async(addr: &String) -> io::Result> { let client = tcp::client::Client::connect(addr).await?; Ok(Arc::new(RPCClient { server_id: client.server_id, client, address: addr.clone(), })) } } impl ClientPool { pub fn new() -> ClientPool { ClientPool { clients: PtrHashMap::with_capacity(16), } } pub async fn get(&self, addr: &String) -> io::Result> { let addr_clone = addr.clone(); let server_id = hash_str(addr); self.get_by_id(server_id, move |_| addr_clone).await } pub async fn get_by_id(&self, server_id: u64, addr_fn: F) -> io::Result> where F: FnOnce(u64) -> String, { let clients = &self.clients; if let Some(client) = clients.get(&server_id) { Ok(client.clone()) } else { let client = timeout( Duration::from_secs(5), RPCClient::new_async(&addr_fn(server_id)), ) .await??; clients.insert(server_id, client.clone()); Ok(client) } } } pub trait ServiceClient: Send + Sync { fn new_instance_with_service_id(server_id: u64, client: &Arc) -> Self; fn server_id(&self) -> u64; fn new_with_service_id(server_id: u64, client: &Arc) -> Arc where Self: Sized, { Arc::new(Self::new_instance_with_service_id(server_id, client)) } } pub trait ServiceClientWithId: ServiceClient { const SERVICE_ID: u64; fn new(client: &Arc) -> Arc where Self: Sized, { Self::new_with_service_id(Self::SERVICE_ID, client) } } pub trait RPCServiceWithId: RPCService { const SERVICE_ID: u64; } #[cfg(test)] mod test { use futures::future::BoxFuture; use serde::{Deserialize, Serialize}; use std::sync::Arc; use std::time::Duration; use tokio::time::sleep; pub mod simple_service { use super::*; service! 
/// End-to-end check of the `hello` and `error` RPCs against a server
/// listening on 127.0.0.1:1300.
#[tokio::test(flavor = "multi_thread")]
pub async fn simple_rpc() {
    let _ = env_logger::try_init();
    let addr = String::from("127.0.0.1:1300");
    {
        // Start the server with `HelloServer` registered under service id 0.
        let addr = addr.clone();
        let server = Server::new(&addr);
        server
            .register_service_with_id(0, &Arc::new(HelloServer))
            .await;
        Server::listen_and_resume(&server).await;
    }
    // Fixed pause before the client connects.
    sleep(Duration::from_millis(1000)).await;
    let client = RPCClient::new_async(&addr).await.unwrap();
    let service_client = AsyncServiceClient::new_with_service_id(0, &client);
    // `hello` must format the greeting from the supplied name.
    let response = service_client.hello(String::from("Jack")).await;
    let greeting_str = response.unwrap();
    info!("SERVER RESPONDED: {}", greeting_str);
    assert_eq!(greeting_str, String::from("Hello, Jack!"));
    // `error` must succeed at the RPC layer and carry the message back
    // inside the application-level `Err`.
    let expected_err_msg = String::from("This error is a good one");
    let response = service_client.error(expected_err_msg.clone());
    let error_msg = response.await.unwrap().err().unwrap();
    assert_eq!(error_msg, expected_err_msg);
}
/// End-to-end check that struct arguments and struct results survive an RPC
/// round trip, using a server listening on 127.0.0.1:1400.
#[tokio::test(flavor = "multi_thread")]
pub async fn struct_rpc() {
    let _ = env_logger::try_init();
    let addr = String::from("127.0.0.1:1400");
    {
        let addr = addr.clone();
        let server = Server::new(&addr);
        // Register `HelloServer` under service id 0.
        server
            .register_service_with_id(0, &Arc::new(HelloServer))
            .await;
        Server::listen_and_resume(&server).await;
    }
    // Fixed pause before the client connects.
    sleep(Duration::from_millis(1000)).await;
    let client = RPCClient::new_async(&addr).await.unwrap();
    let service_client = AsyncServiceClient::new_with_service_id(0, &client);
    let response = service_client.hello(Greeting {
        name: String::from("Jack"),
        time: 12,
    });
    let res = response.await.unwrap();
    let greeting_str = res.text;
    info!("SERVER RESPONDED: {}", greeting_str);
    // Both fields of the `Respond` struct must arrive intact.
    assert_eq!(greeting_str, String::from("Hello, Jack. It is 12 now!"));
    assert_eq!(42, res.owner);
}
/// Starts four servers on distinct ports, each exposing an `IdServer` under
/// a different service id, and verifies every one of them individually.
#[tokio::test(flavor = "multi_thread")]
async fn multi_server_rpc() {
    let addrs = vec![
        String::from("127.0.0.1:1500"),
        String::from("127.0.0.1:1600"),
        String::from("127.0.0.1:1700"),
        String::from("127.0.0.1:1800"),
    ];
    let mut id = 0;
    for addr in &addrs {
        {
            let addr = addr.clone();
            let server = Server::new(&addr);
            // The running counter is used both as the service id and as the
            // id the server reports back.
            server
                .register_service_with_id(id, &Arc::new(IdServer { id: id }))
                .await;
            Server::listen_and_resume(&server).await;
            id += 1;
        }
    }
    // Restart the counter so the client loop walks the same id sequence.
    id = 0;
    sleep(Duration::from_millis(1000)).await;
    for addr in &addrs {
        let client = RPCClient::new_async(addr).await.unwrap();
        let service_client = AsyncServiceClient::new_with_service_id(id, &client);
        // The server must answer with the id it was constructed with.
        let id_res = service_client.query_server_id().await;
        let id_un = id_res.unwrap();
        assert_eq!(id_un, id);
        let user_str = format!("User {}", id);
        let complex = service_client
            .query_answer(Some(user_str.to_string()))
            .await
            .unwrap();
        // Large response: `large_query` returns 1024 answers.
        let large = service_client
            .large_query(Some(user_str.to_string()))
            .await
            .unwrap();
        assert_eq!(large.len(), 1024);
        assert_eq!(complex.req, Some(user_str));
        // Large request: both vectors are sent and come back concatenated.
        let large_req = service_client
            .large_req(large.clone(), large)
            .await
            .unwrap();
        assert_eq!(large_req.len(), 1024 * 2);
        id += 1;
    }
}
/// Implements `RPCService` for the service type `$s`.
///
/// The generated impl:
/// * forwards `dispatch` to `inner_dispatch`,
/// * implements `register_shortcut_service`, which stores the service in
///   `RPC_SVRS` under the `(server_id, service_id)` key,
/// * reports the type name of `$s` from `service_symbol`.
///
/// The expansion refers to `RPC_SVRS`, `Arc`, `Future` and `.boxed()` without
/// a path, so those names must be in scope where the macro is invoked.
#[macro_export]
macro_rules! dispatch_rpc_service_functions {
    ($s:ty) => {
        use $crate::bytes::BytesMut;
        impl $crate::rpc::RPCService for $s {
            // Decodes and runs one request by delegating to `inner_dispatch`.
            fn dispatch<'a>(
                &'a self,
                data: BytesMut,
            ) -> ::std::pin::Pin<
                Box<
                    dyn Future<
                            Output = Result<$crate::bytes::BytesMut, $crate::rpc::RPCRequestError>,
                        > + Send
                        + 'a,
                >,
            >
            where
                Self: Sized,
            {
                self.inner_dispatch(data)
            }
            // Registers the service instance behind `service_ptr` for local
            // (shortcut) lookup by `(server_id, service_id)`.
            fn register_shortcut_service(
                &self,
                service_ptr: usize,
                server_id: u64,
                service_id: u64,
            ) -> ::std::pin::Pin + Send>> {
                async move {
                    let mut cbs = RPC_SVRS.write().await;
                    // SAFETY: `service_ptr` must have been produced by
                    // `Arc::into_raw` on an `Arc` of this same type `$s`, and
                    // ownership of that reference count is taken over here.
                    // NOTE(review): `Server::register_service_with_id` passes
                    // `Arc::into_raw(service.clone()) as usize`; confirm no
                    // other caller supplies a different pointer.
                    let service = unsafe { Arc::from_raw(service_ptr as *const $s) };
                    cbs.insert((server_id, service_id), service);
                }
                .boxed()
            }
            // Name of the implementing type, used for logging.
            fn service_symbol(&self) -> &'static str {
                stringify!($s)
            }
        }
    };
}
{ { $( $unexpanded )* } $( $expanded )* $(#[$attr])* rpc $fn_name( $( $arg : $in_ ),* ) -> $out; } }; ( {} // all expanded $( $(#[$attr:meta])* rpc $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty; )* ) => { use std::sync::Arc; use $crate::rpc::*; #[allow(unused_imports)] use futures::prelude::*; use std::pin::Pin; use bifrost_proc_macro::{deref_tuple_types, adjust_caller_identifiers, adjust_function_signature}; lazy_static! { pub static ref RPC_SVRS: async_std::sync::RwLock<::std::collections::BTreeMap<(u64, u64), Arc>> = async_std::sync::RwLock::new(::std::collections::BTreeMap::new()); } pub trait Service : RPCService { $( $(#[$attr])* adjust_function_signature!{ fn $fn_name<'a>(&self, $($arg:$in_),*) -> ::futures::future::BoxFuture<'a, $out>; } )* fn inner_dispatch<'a>(&'a self, data: $crate::bytes::BytesMut) -> Pin> + Send + 'a>> { let (func_id, body) = read_u64_head(data); async move { match func_id as usize { $(::bifrost_plugins::hash_ident!($fn_name) => { if let Some(data) = $crate::utils::serde::deserialize(body.as_ref()) { #[allow(unused_parens)] let tuple : deref_tuple_types!(($($in_,)*)) = data; let adjust_caller_identifiers!($($arg: $in_),*) = tuple; let f_result = self.$fn_name($($arg,)*).await; let res_data = $crate::bytes::BytesMut::from($crate::utils::serde::serialize(&f_result).as_slice()); Ok(res_data) } else { Err(RPCRequestError::BadRequest) } }),* _ => { Err(RPCRequestError::FunctionIdNotFound) } } }.boxed() } } #[allow(dead_code)] pub async fn get_local(server_id: u64, service_id: u64) -> Option> { let svrs = RPC_SVRS.read().await; match svrs.get(&(server_id, service_id)) { Some(s) => Some(s.clone()), _ => None } } #[allow(dead_code)] pub struct AsyncServiceClient { pub service_id: u64, pub client: Arc, } #[allow(dead_code)] impl AsyncServiceClient { $( #[allow(non_camel_case_types)] $(#[$attr])* pub async fn $fn_name(&self, $($arg:$in_),*) -> Result<$out, RPCError> { ImmeServiceClient::$fn_name(self.service_id, &self.client, 
/// Binds the constant service id `$id` to both sides of a service.
///
/// Implements `RPCServiceWithId` for the server type `$s` and
/// `ServiceClientWithId` for the `AsyncServiceClient` that is in scope at the
/// invocation site, giving both the same `SERVICE_ID`.
///
/// `AsyncServiceClient` is referenced without a path, so the macro has to be
/// invoked where that type is visible.
#[macro_export]
macro_rules! service_with_id {
    ($s:ty, $id:expr) => {
        // Server side: the service type knows its own id.
        impl $crate::rpc::RPCServiceWithId for $s {
            const SERVICE_ID: u64 = $id;
        }
        // Client side: the generated client uses the same id.
        impl $crate::rpc::ServiceClientWithId for AsyncServiceClient {
            const SERVICE_ID: u64 = $id;
        }
    };
}
{ rpc test(a: A, b: u32) -> bool; } } ================================================ FILE: src/tcp/client.rs ================================================ use std::sync::Arc; use std::time::Duration; use crate::tcp::{shortcut, STANDALONE_ADDRESS}; use crate::DISABLE_SHORTCUT; use bifrost_hasher::hash_str; use crate::tcp::server::TcpReq; use async_std::sync::Mutex; use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures::prelude::*; use futures::stream::SplitSink; use futures::SinkExt; use parking_lot::Mutex as SyncMutex; use std::collections::HashMap; use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering::Relaxed; use tokio::io; use tokio::net::TcpStream; use tokio::sync::oneshot; use tokio::time; use tokio_util::codec::{Framed, LengthDelimitedCodec}; pub struct Client { //client: Option>, client: Option, Bytes>>>, msg_counter: AtomicU64, senders: Arc>>>, timeout: Duration, pub server_id: u64, } impl Client { pub async fn connect_with_timeout(address: &String, timeout: Duration) -> io::Result { let server_id = hash_str(address); let senders = Arc::new(SyncMutex::new( HashMap::>::new(), )); debug!( "TCP connect to {}, server id {}, timeout {}ms", address, server_id, timeout.as_millis() ); let client = { if !DISABLE_SHORTCUT && shortcut::is_local(server_id).await { debug!("Local connection, using shortcut"); None } else { if address.eq(&STANDALONE_ADDRESS) { return Err(io::Error::new( io::ErrorKind::Other, "STANDALONE server is not found", )); } debug!("Create socket on {}", address); let socket = time::timeout(timeout, TcpStream::connect(address)).await??; let transport = Framed::new(socket, LengthDelimitedCodec::new()); let (writer, mut reader) = transport.split(); let cloned_senders = senders.clone(); debug!("Streaming messages for {}", address); let address = address.clone(); tokio::spawn(async move { while let Some(res) = reader.next().await { if let Ok(mut data) = res { let res_msg_id = data.get_u64_le(); trace!("Received msg for {}, size {}", 
res_msg_id, data.len()); let mut senders = cloned_senders.lock(); if let Some(sender) = senders.remove(&res_msg_id) { if let Err(e) = sender.send(data) { error!( "Failed to send response for msg {}: {:?}", res_msg_id, e ); } } else { error!("No sender found for response msg {}", res_msg_id); } } } debug!("Stream from TCP server {} broken", address); }); Some(Mutex::new(writer)) } }; Ok(Client { client, server_id, senders, timeout, msg_counter: AtomicU64::new(0), }) } pub async fn connect(address: &String) -> io::Result { Client::connect_with_timeout(address, Duration::from_secs(2)).await } pub async fn send_msg(&self, msg: TcpReq) -> io::Result { if let Some(ref transport) = self.client { let msg_id = self.msg_counter.fetch_add(1, Relaxed); let mut frame = BytesMut::with_capacity(8 + msg.len()); let rx = { frame.put_u64_le(msg_id); frame.extend_from_slice(msg.as_ref()); let (tx, rx) = oneshot::channel(); let mut senders = self.senders.lock(); senders.insert(msg_id, tx); rx }; trace!("Sending msg {}, size {}", msg_id, frame.len()); time::timeout(self.timeout, transport.lock().await.send(frame.freeze())).await??; trace!("Sent msg {}", msg_id); match time::timeout(self.timeout, rx).await? { Ok(response) => Ok(response), Err(e) => { error!("Failed to receive response for msg {}: {:?}", msg_id, e); Err(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "Response channel closed", )) } } } else { Ok(shortcut::call(self.server_id, msg).await?) 
} } } unsafe impl Send for Client {} #[cfg(test)] mod tests { use super::*; use bytes::{BufMut, BytesMut}; use std::time::Duration; #[tokio::test(flavor = "multi_thread")] async fn test_client_connect_timeout() { let _ = env_logger::builder().format_timestamp(None).try_init(); // Try to connect to a non-existent server with short timeout let addr = String::from("127.0.0.1:9999"); let timeout = Duration::from_millis(100); let result = Client::connect_with_timeout(&addr, timeout).await; // This should fail since there's no server assert!( result.is_err(), "Connection to non-existent server should fail" ); } #[tokio::test(flavor = "multi_thread")] async fn test_client_standalone_address() { let _ = env_logger::builder().format_timestamp(None).try_init(); // Try to connect to STANDALONE address let standalone_addr = STANDALONE_ADDRESS.to_string(); let result = Client::connect(&standalone_addr).await; assert!(result.is_err(), "Connection to STANDALONE should fail"); if let Err(e) = result { assert_eq!(e.kind(), io::ErrorKind::Other); assert!(e.to_string().contains("STANDALONE")); } } #[tokio::test(flavor = "multi_thread")] async fn test_client_server_id() { let addr = String::from("127.0.0.1:9876"); let expected_id = hash_str(&addr); // Even if connection fails, we can test server_id calculation let timeout = Duration::from_millis(50); let _ = Client::connect_with_timeout(&addr, timeout).await; // Verify hash_str produces consistent results assert_eq!(hash_str(&addr), expected_id); } } ================================================ FILE: src/tcp/mod.rs ================================================ use bifrost_hasher::hash_str; pub mod client; pub mod server; pub mod shortcut; pub static STANDALONE_ADDRESS: &'static str = "STANDALONE"; lazy_static! 
{ pub static ref STANDALONE_ADDRESS_STRING: String = String::from(STANDALONE_ADDRESS); pub static ref STANDALONE_SERVER_ID: u64 = hash_str(&STANDALONE_ADDRESS_STRING); } ================================================ FILE: src/tcp/server.rs ================================================ use super::STANDALONE_ADDRESS; use crate::tcp::shortcut; use bytes::{Buf, BufMut, BytesMut}; use futures::SinkExt; use std::error::Error; use std::future::Future; use std::pin::Pin; use std::sync::Arc; use tokio::net::TcpListener; use tokio::sync::broadcast; use tokio_stream::StreamExt; use tokio_util::codec::{Framed, LengthDelimitedCodec}; pub type RPCFuture = dyn Future; pub type BoxedRPCFuture = Box; pub type TcpReq = BytesMut; pub type TcpRes = Pin + Send>>; pub struct Server { shutdown_tx: broadcast::Sender<()>, } impl Server { pub fn new() -> Server { let (shutdown_tx, _) = broadcast::channel(1); Server { shutdown_tx } } pub fn shutdown_handle(&self) -> broadcast::Sender<()> { self.shutdown_tx.clone() } pub async fn listen( &self, addr: &String, callback: Arc TcpRes + Send + Sync>, ) -> Result<(), Box> { shortcut::register_server(addr, &callback).await; if !addr.eq(&STANDALONE_ADDRESS) { let listener = TcpListener::bind(&addr).await?; let mut shutdown_rx = self.shutdown_tx.subscribe(); info!("TCP server listening on {}", addr); loop { tokio::select! { accept_result = listener.accept() => { match accept_result { Ok((socket, addr)) => { debug!("Accepted connection from {}", addr); let callback = callback.clone(); let mut conn_shutdown_rx = self.shutdown_tx.subscribe(); tokio::spawn(async move { let mut transport = Framed::new(socket, LengthDelimitedCodec::new()); loop { tokio::select! 
{ result = transport.next() => { match result { Some(Ok(mut data)) => { let msg_id = data.get_u64_le(); let call_back_data = callback(data).await; let mut res = BytesMut::with_capacity(8 + call_back_data.len()); res.put_u64_le(msg_id); res.extend_from_slice(call_back_data.as_ref()); if let Err(e) = transport.send(res.freeze()).await { error!("Error on TCP callback {:?}", e); break; } } Some(Err(e)) => { error!("error on decoding from socket; error = {:?}", e); break; } None => { debug!("Connection closed by client"); break; } } } _ = conn_shutdown_rx.recv() => { info!("Connection handler received shutdown signal"); break; } } } // The connection will be closed at this point }); } Err(e) => error!("error accepting socket; error = {:?}", e), } } _ = shutdown_rx.recv() => { info!("TCP server on {} received shutdown signal, stopping accept loop", addr); break; } } } } info!("TCP server on {} shut down gracefully", addr); Ok(()) } pub fn shutdown(&self) { info!("Initiating TCP server shutdown"); let _ = self.shutdown_tx.send(()); } } #[cfg(test)] mod tests { use super::*; use bytes::{BufMut, BytesMut}; use futures::future::FutureExt; use std::sync::Arc; use tokio::time::{sleep, Duration}; #[tokio::test(flavor = "multi_thread")] async fn test_server_creation() { let server = Server::new(); assert!( server.shutdown_tx.receiver_count() == 0, "Should start with no subscribers" ); } #[tokio::test(flavor = "multi_thread")] async fn test_shutdown_handle() { let server = Server::new(); let handle = server.shutdown_handle(); // Subscribe a receiver so the send won't fail let mut _rx = handle.subscribe(); // Verify we can send shutdown signal let result = handle.send(()); assert!(result.is_ok(), "Should be able to send shutdown signal"); } #[tokio::test(flavor = "multi_thread")] async fn test_server_listen_and_shutdown() { let _ = env_logger::builder().format_timestamp(None).try_init(); let addr = String::from("127.0.0.1:9100"); let server = Arc::new(Server::new()); let callback 
/// Starts a server with an echo callback on 127.0.0.1:9200 and shuts it down.
///
/// NOTE(review): despite the name, no client connects and no request is sent,
/// so the echo callback is never actually invoked by this test.
#[tokio::test(flavor = "multi_thread")]
async fn test_server_callback_invocation() {
    let _ = env_logger::builder().format_timestamp(None).try_init();

    let addr = String::from("127.0.0.1:9200");
    let server = Arc::new(Server::new());

    let callback = Arc::new(|data: TcpReq| -> TcpRes {
        async move {
            // Echo back the data
            let mut response = BytesMut::new();
            response.put_slice(&data);
            response
        }
        .boxed()
    });

    let server_clone = server.clone();
    let addr_clone = addr.clone();
    let callback_clone = callback.clone();

    // Run the accept loop in the background; its result is ignored.
    tokio::spawn(async move {
        let _ = server_clone.listen(&addr_clone, callback_clone).await;
    });

    // Fixed pause before signalling shutdown.
    sleep(Duration::from_millis(500)).await;

    // Shutdown after test
    server.shutdown();
    sleep(Duration::from_millis(200)).await;
}
{ pub static ref TCP_CALLBACKS: RwLock>> = RwLock::new(BTreeMap::new()); } pub async fn register_server( server_address: &String, callback: &Arc, ) { let server_id = hash_str(server_address); let mut servers_cbs = TCP_CALLBACKS.write().await; servers_cbs.insert(server_id, callback.clone()); } pub async fn call(server_id: u64, data: TcpReq) -> Result { let server_cbs = TCP_CALLBACKS.read().await; match server_cbs.get(&server_id) { Some(c) => Ok(c(data).await), _ => Err(Error::new( ErrorKind::Other, "Cannot found callback for shortcut", )), } } pub async fn is_local(server_id: u64) -> bool { let cbs = TCP_CALLBACKS.read().await; cbs.contains_key(&server_id) } ================================================ FILE: src/utils/bindings.rs ================================================ use parking_lot::RwLock; use std::collections::HashMap; use std::sync::Arc; use thread_id; pub struct Binding where T: Clone, { default: T, thread_vals: RwLock>, } impl Binding where T: Clone, { pub fn new(default: T) -> Binding { Binding { default, thread_vals: RwLock::new(HashMap::new()), } } pub fn get(&self) -> T { let tid = thread_id::get(); let thread_map = self.thread_vals.read(); match thread_map.get(&tid) { Some(v) => v.clone(), None => self.default.clone(), } } pub fn set(&self, val: T) { let tid = thread_id::get(); let mut thread_map = self.thread_vals.write(); thread_map.insert(tid, val); } pub fn del(&self) { let tid = thread_id::get(); let mut thread_map = self.thread_vals.write(); thread_map.remove(&tid); } } pub struct RefBinding { bind: Binding>, } impl RefBinding { pub fn new(default: T) -> RefBinding { RefBinding { bind: Binding::new(Arc::new(default)), } } pub fn get(&self) -> Arc { self.bind.get() } pub fn set(&self, val: T) { self.bind.set(Arc::new(val)) } pub fn del(&self) { self.bind.del() } } #[macro_export] macro_rules! def_bindings { ($( bind $bt:ident $name:ident : $t:ty = $def_val:expr; )*) => { def_bindings! 
/// Returns the integer mean of `nums`, rounded down, or `None` when `nums`
/// is empty.
///
/// The mean is computed as `min + sum(x - min) / count`. The offsets are
/// accumulated in `u128`, so the calculation cannot overflow even when the
/// plain sum of the values would not fit in a `u64`.
pub fn avg_scale(nums: &Vec<u64>) -> Option<u64> {
    // `?` yields `None` for an empty input.
    let min_num = nums.iter().copied().min()?;
    let count = nums.len() as u128;
    let offset_sum: u128 = nums.iter().map(|&n| u128::from(n - min_num)).sum();
    // The averaged offset is at most `max - min`, so it fits in a `u64` and
    // adding it to `min_num` cannot overflow.
    Some(min_num + (offset_sum / count) as u64)
}
assert_eq!(math::min(&vec!(1, 2, -10, 4, 5)).unwrap(), -10); assert_eq!(math::min(&Vec::::new()), None); } } ================================================ FILE: src/utils/mod.rs ================================================ pub mod time; #[macro_use] pub mod bindings; pub mod math; pub mod serde; ================================================ FILE: src/utils/serde.rs ================================================ use bifrost_hasher::hash_bytes; use serde; #[cfg(not(debug_assertions))] pub fn serialize(obj: &T) -> Vec where T: serde::Serialize, { match serde_cbor::to_vec(obj) { Ok(data) => data, Err(e) => panic!("Cannot serialize: {:?}", e), } } #[cfg(not(debug_assertions))] pub fn deserialize<'a, T>(data: &'a [u8]) -> Option where T: serde::Deserialize<'a>, { match serde_cbor::from_slice(data) { Ok(obj) => Some(obj), Err(e) => { warn!( "Error on decoding data for type '{}', {}", std::any::type_name::(), e ); None } } } #[cfg(debug_assertions)] pub fn serialize(obj: &T) -> Vec where T: serde::Serialize, { match serde_json::to_vec(obj) { Ok(data) => data, Err(e) => panic!("Cannot serialize: {:?}", e), } } #[cfg(debug_assertions)] pub fn deserialize<'a, T>(data: &'a [u8]) -> Option where T: serde::Deserialize<'a>, { let type_name = std::any::type_name::(); match serde_json::from_slice(data) { Ok(obj) => Some(obj), Err(e) => { warn!( "Error on decoding data for type '{}', {}, json: {}", type_name, e, String::from_utf8_lossy(data) ); None } } } pub fn hash(obj: &T) -> u64 where T: serde::Serialize, { let data = serialize(obj); hash_bytes(data.as_slice()) } ================================================ FILE: src/utils/time.rs ================================================ use std::time::Duration; use std::time::SystemTime; use tokio::time::sleep; pub fn get_time() -> i64 { //Get current time let current_time = SystemTime::now(); let duration = current_time.duration_since(SystemTime::UNIX_EPOCH).unwrap(); //Calculate milliseconds return 
/// Converts `duration` to whole milliseconds, rounding down.
///
/// Durations whose millisecond count does not fit in a `u64` saturate to
/// `u64::MAX` instead of overflowing.
pub fn duration_to_ms(duration: Duration) -> u64 {
    // `as_millis` is a `u128`; clamp before narrowing so the cast is lossless.
    duration.as_millis().min(u128::from(u64::MAX)) as u64
}
self.map.len(); let bl = clock_b.map.len(); if al == 0 { return clock_b.map.iter().any(|(_, n)| *n > 0); } if bl == 0 { return false; } let mut a_lt_b = false; let mut b_lt_a = false; while ai < al && bi < bl { let (ak, an) = &self.map[ai]; let (bk, bn) = &clock_b.map[bi]; if ak == bk { // Two vector have the same key, compare their values ai += 1; bi += 1; if *an < *bn { a_lt_b = true; } else if *an > *bn { b_lt_a = true; } } else if ak > bk { // Clock b have a server that a does not have bi += 1; } else if ak < bk { // Clock a have a server that b does not have ai += 1; } else { unreachable!(); } } return a_lt_b && (!b_lt_a); } pub fn equals(&self, clock_b: &VectorClock) -> bool { let al = self.map.len(); let bl = clock_b.map.len(); if al == 0 && al == bl { return true; } if al != bl { if al == 0 { return clock_b.map.iter().all(|(_, n)| *n == 0); } if bl == 0 { return self.map.iter().all(|(_, n)| *n == 0); } } let mut ai = 0; let mut bi = 0; let mut a_eq_b = false; while ai < al && bi < bl { let (ak, an) = &self.map[ai]; let (bk, bn) = &clock_b.map[bi]; if ak == bk { // Two vector have the same key, compare their values if an != bn { return false; } a_eq_b = true; ai += 1; bi += 1; } else if ak > bk { // Clock b have a server that a does not have // b should either equal or happend after a bi += 1; } else if ak < bk { // Clock a have a server that b does not have ai += 1; } else { unreachable!(); } } return a_eq_b; } pub fn relation(&self, clock_b: &VectorClock) -> Relation { if self.equals(clock_b) { return Relation::Equal; } if self.happened_before(clock_b) { return Relation::Before; } if clock_b.happened_before(self) { return Relation::After; } return Relation::Concurrent; } pub fn merge_with(&mut self, clock_b: &VectorClock) { // merge_with is used to update counter for other servers (also learn from it) let mut ai = 0; let mut bi = 0; let al = self.map.len(); let bl = clock_b.map.len(); if bl == 0 { return; } if al == 0 { self.map = clock_b.map.clone(); 
return; } let mut new_map = Vec::with_capacity(self.map.len() + clock_b.map.len()); while ai < al || bi < bl { if ai >= al { ai = al - 1; } if bi >= bl { bi = bl - 1; } let (ak, an) = &self.map[ai]; let (bk, bn) = &clock_b.map[bi]; if ak == bk { // Two vector have the same key, compare their values if an < bn { new_map.push((*ak, *bn)); } else { new_map.push((*ak, *an)); } ai += 1; bi += 1; } else if ak > bk { // Clock b have a server that a does not have new_map.push((*bk, *bn)); bi += 1; } else if ak < bk { // Clock a have a server that b does not have new_map.push((*ak, *an)); ai += 1; } else { unreachable!(); } } self.map = new_map; } pub fn learn_from(&mut self, clock_b: &VectorClock) { // learn_from only insert missing servers into the clock let mut ai = 0; let mut bi = 0; let al = self.map.len(); let bl = clock_b.map.len(); if bl == 0 { return; } if al == 0 { self.map = clock_b.map.clone(); return; } let mut new_map = Vec::with_capacity(self.map.len() + clock_b.map.len()); while ai < al || bi < bl { if ai >= al { ai = al - 1; } if bi >= bl { bi = bl - 1; } let (ak, an) = &self.map[ai]; let (bk, bn) = &clock_b.map[bi]; if ak == bk { // Two vector have the same key, compare their values ai += 1; bi += 1; new_map.push((*ak, *an)); } else if ak > bk { // Clock b have a server that a does not have new_map.push((*bk, *bn)); bi += 1; } else if ak < bk { // Clock a have a server that b does not have new_map.push((*ak, *an)); ai += 1; } else { unreachable!(); } } self.map = new_map; } } pub struct ServerVectorClock { server: u64, clock: RwLock, } impl ServerVectorClock { pub fn new(server_address: &String) -> ServerVectorClock { ServerVectorClock { server: hash_str(server_address), clock: RwLock::new(VectorClock::new()), } } pub fn inc(&self) -> StandardVectorClock { let mut clock = self.clock.write(); clock.inc(self.server); clock.clone() } pub fn happened_before(&self, clock_b: &StandardVectorClock) -> bool { let clock = self.clock.read(); 
clock.happened_before(clock_b) } pub fn equals(&self, clock_b: &StandardVectorClock) -> bool { let clock = self.clock.read(); clock.equals(clock_b) } pub fn relation(&self, clock_b: &StandardVectorClock) -> Relation { let clock = self.clock.read(); clock.relation(clock_b) } pub fn merge_with(&self, clock_b: &StandardVectorClock) { let mut clock = self.clock.write(); clock.merge_with(clock_b) } pub fn learn_from(&self, clock_b: &StandardVectorClock) { let mut clock = self.clock.write(); clock.learn_from(clock_b) } pub fn to_clock(&self) -> StandardVectorClock { let clock = self.clock.read(); clock.clone() } } pub type StandardVectorClock = VectorClock; #[cfg(test)] mod test { use crate::vector_clock::{Relation, StandardVectorClock}; #[test] fn general() { let _ = env_logger::try_init(); let mut clock = StandardVectorClock::new(); let blank_clock = StandardVectorClock::new(); clock.inc(1); clock.inc(3); let old_clock = clock.clone(); clock.inc(1); clock.inc(2); info!("{:?}", clock.relation(&blank_clock)); assert!(clock > blank_clock); assert!(blank_clock < clock); assert!(blank_clock != clock); assert!( old_clock.happened_before(&clock), "old {:?}, new {:?}", old_clock, clock ); assert!( !clock.happened_before(&old_clock), "old {:?}, new {:?}", old_clock, clock ); assert!( !clock.equals(&old_clock), "old {:?}, new {:?}", old_clock, clock ); assert_eq!( clock.relation(&old_clock), Relation::After, "old {:?}, new {:?}", old_clock, clock ); assert_eq!( old_clock.relation(&clock), Relation::Before, "old {:?}, new {:?}", old_clock, clock ); let blank_clock_2 = StandardVectorClock::new(); assert!(blank_clock == blank_clock_2); } #[test] fn unaligned_clock_eq() { let _ = env_logger::try_init(); let clock_a = StandardVectorClock::from_vec(vec![(1, 2), (2, 3), (3, 4), (4, 5), (5, 6)]); let clock_b = StandardVectorClock::from_vec(vec![(2, 3), (4, 5)]); assert!(clock_a.equals(&clock_b)); assert!(clock_b.equals(&clock_a)); assert!(!clock_a.happened_before(&clock_b)); 
assert!(!clock_b.happened_before(&clock_a)); assert_eq!(clock_a.relation(&clock_b), Relation::Equal); } #[test] fn unaligned_clock_rel_disjoint_concurrent() { let _ = env_logger::try_init(); let clock_a = StandardVectorClock::from_vec(vec![(1, 2), (3, 4), (5, 6)]); let clock_b = StandardVectorClock::from_vec(vec![(0, 1), (2, 3), (7, 8), (9, 10)]); assert!(!clock_a.equals(&clock_b)); assert!(!clock_b.equals(&clock_a)); assert!(!clock_a.happened_before(&clock_b)); assert!(!clock_b.happened_before(&clock_a)); assert_eq!(clock_a.relation(&clock_b), Relation::Concurrent); } #[test] fn unaligned_clock_rel_joint_concurrent() { let _ = env_logger::try_init(); let clock_a = StandardVectorClock::from_vec(vec![(1, 2), (3, 4)]); let clock_b = StandardVectorClock::from_vec(vec![(1, 3), (3, 3)]); assert!(!clock_a.equals(&clock_b)); assert!(!clock_b.equals(&clock_a)); assert!(!clock_a.happened_before(&clock_b)); assert!(!clock_b.happened_before(&clock_a)); assert_eq!(clock_a.relation(&clock_b), Relation::Concurrent); } #[test] fn test_merge_with() { let _ = env_logger::try_init(); let mut clock_a = StandardVectorClock::from_vec(vec![(1, 2), (3, 4)]); let clock_b = StandardVectorClock::from_vec(vec![(1, 5), (2, 3), (3, 1)]); clock_a.merge_with(&clock_b); // After merge, clock_a should have max values assert_eq!( clock_a, StandardVectorClock::from_vec(vec![(1, 5), (2, 3), (3, 4)]) ); } #[test] fn test_merge_with_empty() { let mut clock_a = StandardVectorClock::from_vec(vec![(1, 2), (3, 4)]); let clock_b = StandardVectorClock::new(); let expected = clock_a.clone(); clock_a.merge_with(&clock_b); assert_eq!(clock_a, expected); } #[test] fn test_merge_with_into_empty() { let mut clock_a = StandardVectorClock::new(); let clock_b = StandardVectorClock::from_vec(vec![(1, 2), (3, 4)]); clock_a.merge_with(&clock_b); assert_eq!(clock_a, clock_b); } #[test] fn test_learn_from() { let mut clock_a = StandardVectorClock::from_vec(vec![(1, 5)]); let clock_b = 
/// A server clock that has ticked at least once must relate as `After` to an empty clock.
#[test]
fn test_server_vector_clock_relation() {
    use super::ServerVectorClock;

    let addr = String::from("127.0.0.1:9090");
    let svc = ServerVectorClock::new(&addr);

    // Only the side effect is needed here: advance the server's own counter.
    // The returned snapshot is intentionally discarded.
    let _ = svc.inc();
    let external_clock = StandardVectorClock::new();

    let rel = svc.relation(&external_clock);
    assert_eq!(rel, Relation::After);
}
StandardVectorClock::from_vec(vec![(2, 1)]); assert!(clock_a < clock_b); assert!(clock_b > clock_a); assert!(clock_a.partial_cmp(&clock_c).is_none()); // Concurrent } #[test] fn test_ord() { use std::cmp::Ordering; let clock_a = StandardVectorClock::from_vec(vec![(1, 1)]); let clock_b = StandardVectorClock::from_vec(vec![(1, 2)]); let clock_c = StandardVectorClock::from_vec(vec![(2, 1)]); assert_eq!(clock_a.cmp(&clock_b), Ordering::Less); assert_eq!(clock_b.cmp(&clock_a), Ordering::Greater); assert_eq!(clock_a.cmp(&clock_c), Ordering::Equal); // Concurrent treated as equal } #[test] fn test_happened_before_empty_clocks() { let clock_a = StandardVectorClock::new(); let mut clock_b = StandardVectorClock::new(); clock_b.inc(1); assert!(clock_a.happened_before(&clock_b)); assert!(!clock_b.happened_before(&clock_a)); } #[test] fn test_equals_empty_clocks() { let clock_a = StandardVectorClock::new(); let clock_b = StandardVectorClock::new(); assert!(clock_a.equals(&clock_b)); } #[test] fn test_equals_with_zero_values() { let clock_a = StandardVectorClock::new(); let clock_b = StandardVectorClock::from_vec(vec![(1, 0), (2, 0)]); assert!(clock_a.equals(&clock_b)); assert!(clock_b.equals(&clock_a)); } } ================================================ FILE: tests/graceful_shutdown_tests.rs ================================================ /// Tests for graceful shutdown functionality /// /// These tests verify that: /// 1. Servers actually shut down when shutdown() is called /// 2. Ports/addresses are released and can be reused /// 3. 
Background tasks stop cleanly use bifrost::raft::{Options, RaftService, Storage, DEFAULT_SERVICE_ID}; use bifrost::rpc::Server; use bifrost::tcp; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpStream; use tokio::time::{sleep, timeout}; /// Test that TCP server releases the port after shutdown #[tokio::test(flavor = "multi_thread")] async fn test_tcp_server_shutdown_releases_port() { let _ = env_logger::try_init(); let address = "127.0.0.1:19001".to_string(); // Start first TCP server let tcp_server = Arc::new(tcp::server::Server::new()); let tcp_server_clone = tcp_server.clone(); let addr_clone = address.clone(); let handle = tokio::spawn(async move { tcp_server_clone .listen(&addr_clone, Arc::new(|data| Box::pin(async move { data }))) .await .unwrap(); }); // Give it time to bind sleep(Duration::from_millis(500)).await; // Verify server is listening by connecting to it let connect_result = timeout(Duration::from_secs(2), TcpStream::connect(&address)).await; assert!( connect_result.is_ok(), "Should be able to connect to server" ); // Shutdown the server tcp_server.shutdown(); // Wait for shutdown to complete let shutdown_result = timeout(Duration::from_secs(5), handle).await; assert!( shutdown_result.is_ok(), "Server should shut down within 5 seconds" ); // Give a moment for the OS to release the port sleep(Duration::from_millis(500)).await; // Verify we can start a new server on the same port let tcp_server2 = Arc::new(tcp::server::Server::new()); let tcp_server2_clone = tcp_server2.clone(); let addr_clone2 = address.clone(); let handle2 = tokio::spawn(async move { let result = tcp_server2_clone .listen(&addr_clone2, Arc::new(|data| Box::pin(async move { data }))) .await; assert!( result.is_ok(), "Should be able to bind to the same port after shutdown" ); }); // Give it time to bind sleep(Duration::from_millis(500)).await; // Verify second server is listening let connect_result2 = timeout(Duration::from_secs(2), 
TcpStream::connect(&address)).await; assert!( connect_result2.is_ok(), "Should be able to connect to new server on same port" ); // Cleanup tcp_server2.shutdown(); let _ = timeout(Duration::from_secs(5), handle2).await; } /// Test that RPC server releases the port after shutdown #[tokio::test(flavor = "multi_thread")] async fn test_rpc_server_shutdown_releases_port() { let _ = env_logger::try_init(); let address = "127.0.0.1:19002".to_string(); // Start first RPC server let server1 = Server::new(&address); Server::listen_and_resume(&server1).await; // Verify server is listening let connect_result = timeout(Duration::from_secs(2), TcpStream::connect(&address)).await; assert!( connect_result.is_ok(), "Should be able to connect to RPC server" ); // Shutdown the server server1.shutdown().await; // Give time for shutdown to complete and port to be released sleep(Duration::from_millis(1000)).await; // Verify we can start a new server on the same port let server2 = Server::new(&address); Server::listen_and_resume(&server2).await; // Verify second server is listening let connect_result2 = timeout(Duration::from_secs(2), TcpStream::connect(&address)).await; assert!( connect_result2.is_ok(), "Should be able to connect to new RPC server on same port" ); // Cleanup server2.shutdown().await; sleep(Duration::from_millis(500)).await; } /// Test that Raft service stops its background tasks after shutdown /// Note: Ignored because RaftService contains a nested tokio runtime which cannot /// be safely dropped within another tokio test runtime context. /// The full_stack_shutdown test covers Raft shutdown in a working configuration. 
#[tokio::test(flavor = "multi_thread")] #[ignore = "RaftService nested runtime causes drop issues in test context"] async fn test_raft_service_shutdown_stops_tasks() { let _ = env_logger::try_init(); let address = "127.0.0.1:19003".to_string(); // Use scope to ensure proper cleanup { // Create and start Raft service let raft_service = RaftService::new(Options { storage: Storage::MEMORY, address: address.clone(), service_id: DEFAULT_SERVICE_ID, }); // Give initialization more time sleep(Duration::from_millis(100)).await; let started = RaftService::start(&raft_service, false).await; if !started { println!("Warning: Raft service failed to start, skipping test"); return; // Skip this test if it fails to start } // Bootstrap the cluster raft_service.bootstrap().await; // Give it time to run and stabilize sleep(Duration::from_millis(1000)).await; // Verify service is running by checking leader status assert!(raft_service.is_leader(), "Should be leader after bootstrap"); // Shutdown the service let shutdown_start = std::time::Instant::now(); raft_service.shutdown().await; let shutdown_duration = shutdown_start.elapsed(); // Verify shutdown completed in reasonable time (< 5 seconds) assert!( shutdown_duration < Duration::from_secs(5), "Shutdown should complete within 5 seconds, took {:?}", shutdown_duration ); // Verify service is no longer leader (membership should be Offline) assert!( !raft_service.is_leader(), "Should not be leader after shutdown" ); } // raft_service drops here println!("Test completed successfully"); } /// Full integration test: Start everything, shutdown, verify port is released /// Note: This test uses scoped drops to avoid runtime drop issues #[tokio::test(flavor = "multi_thread")] async fn test_full_stack_shutdown_releases_port() { let _ = env_logger::try_init(); let address = "127.0.0.1:19004".to_string(); // Scope 1: Create and start full stack, then shut it down { let raft_service = RaftService::new(Options { storage: Storage::MEMORY, address: 
address.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&address); Server::listen_and_resume(&server).await; server.register_service(&raft_service).await; let started = RaftService::start(&raft_service, false).await; assert!(started, "Raft service should start"); raft_service.bootstrap().await; // Verify everything is running sleep(Duration::from_millis(500)).await; let connect_result = timeout(Duration::from_secs(2), TcpStream::connect(&address)).await; assert!( connect_result.is_ok(), "Should be able to connect to server" ); assert!(raft_service.is_leader(), "Should be leader"); // Shutdown in reverse order (service first, then server) println!("Shutting down Raft service..."); raft_service.shutdown().await; println!("Shutting down RPC server..."); server.shutdown().await; // Give time for everything to shut down sleep(Duration::from_millis(1000)).await; } // raft_service and server drop here // Give OS time to fully release the port sleep(Duration::from_millis(500)).await; // Scope 2: Start new server on same port to verify it's released { println!("Starting new server on same port..."); let server2 = Server::new(&address); Server::listen_and_resume(&server2).await; // Verify new server is listening sleep(Duration::from_millis(500)).await; let connect_result2 = timeout(Duration::from_secs(2), TcpStream::connect(&address)).await; assert!( connect_result2.is_ok(), "Should be able to connect to new server on same port" ); // Cleanup server2.shutdown().await; sleep(Duration::from_millis(500)).await; } // server2 drops here println!("Test completed successfully"); } /// Test multiple rapid shutdown/restart cycles #[tokio::test(flavor = "multi_thread")] async fn test_rapid_shutdown_restart_cycles() { let _ = env_logger::try_init(); let address = "127.0.0.1:19005".to_string(); for i in 0..3 { println!("Cycle {}", i + 1); // Start server let server = Server::new(&address); Server::listen_and_resume(&server).await; // Verify it's listening 
sleep(Duration::from_millis(300)).await; let connect_result = timeout(Duration::from_secs(2), TcpStream::connect(&address)).await; assert!( connect_result.is_ok(), "Cycle {}: Should be able to connect", i + 1 ); // Shutdown server.shutdown().await; sleep(Duration::from_millis(500)).await; } println!("All cycles completed successfully"); } /// Test that connections are closed cleanly during shutdown #[tokio::test(flavor = "multi_thread")] async fn test_active_connections_close_on_shutdown() { let _ = env_logger::try_init(); let address = "127.0.0.1:19006".to_string(); // Start server let server = Server::new(&address); Server::listen_and_resume(&server).await; sleep(Duration::from_millis(300)).await; // Open multiple connections let mut connections = Vec::new(); for _ in 0..5 { let stream = TcpStream::connect(&address).await; assert!(stream.is_ok(), "Should be able to connect"); connections.push(stream.unwrap()); } println!("Opened {} connections", connections.len()); // Shutdown server server.shutdown().await; // Give a moment for shutdown to propagate sleep(Duration::from_millis(500)).await; // Verify we cannot open new connections let new_connect = timeout(Duration::from_secs(1), TcpStream::connect(&address)).await; assert!( new_connect.is_err() || new_connect.unwrap().is_err(), "Should not be able to connect after shutdown" ); println!("Verified server is no longer accepting connections"); } /// Test shutdown timeout behavior #[tokio::test(flavor = "multi_thread")] async fn test_shutdown_completes_within_timeout() { let _ = env_logger::try_init(); let address = "127.0.0.1:19007".to_string(); // Create full stack let raft_service = RaftService::new(Options { storage: Storage::MEMORY, address: address.clone(), service_id: DEFAULT_SERVICE_ID, }); let server = Server::new(&address); Server::listen_and_resume(&server).await; server.register_service(&raft_service).await; RaftService::start(&raft_service, false).await; raft_service.bootstrap().await; 
/// Checks that a single-node, disk-backed Raft cluster is leader again after a restart.
///
/// Phase 1 bootstraps a node on a fresh data directory, issues two config commands so that
/// logs exist, and shuts down. Phase 2 starts a new service on the same directory and
/// address WITHOUT bootstrapping and asserts that it reports itself as leader.
#[tokio::test(flavor = "multi_thread")]
async fn test_single_node_cluster_recovery_becomes_leader() {
    let _ = env_logger::try_init();

    // Use test-specific directory
    let data_path = "/tmp/bifrost_test_single_node_recovery_18000".to_string();

    // Clean up any existing data from previous runs
    // (the removal result is ignored: the directory may simply not exist yet)
    let _ = std::fs::remove_dir_all(&data_path);
    std::fs::create_dir_all(&data_path).unwrap();

    let address = "127.0.0.1:18000".to_string();

    println!("=== Phase 1: Create initial single-node cluster ===");

    // Assigned inside the Phase 1 scope below
    let initial_leader_id;

    // Phase 1: Create and run a single-node cluster with disk storage
    {
        let raft_service = RaftService::new(Options {
            storage: Storage::DISK(DiskOptions {
                path: data_path.clone(),
                take_snapshots: true,
                append_logs: true,
                trim_logs: false,
                snapshot_log_threshold: 5,
                log_compaction_threshold: 10,
            }),
            address: address.clone(),
            service_id: DEFAULT_SERVICE_ID,
        });

        let server = Server::new(&address);
        Server::listen_and_resume(&server).await;
        server.register_service(&raft_service).await;

        // Start and bootstrap
        let started = RaftService::start(&raft_service, true).await;
        assert!(started, "Phase 1: Should start successfully");

        raft_service.bootstrap().await;
        sleep(Duration::from_millis(500)).await;

        // Verify it's a leader
        assert!(raft_service.is_leader(), "Phase 1: Should be leader");
        initial_leader_id = raft_service.leader_id().await;
        assert!(
            initial_leader_id != 0,
            "Phase 1: Should have valid leader ID"
        );

        println!(
            "Phase 1: Leader ID = {}, Server ID = {}",
            initial_leader_id, raft_service.id
        );

        // Perform some operations to generate logs
        let client = RaftClient::new(&vec![address.clone()], DEFAULT_SERVICE_ID)
            .await
            .unwrap();

        println!("Phase 1: Created Raft client");

        // Generate some logs by performing state machine operations
        // Use the config state machine to add/remove a dummy member
        use bifrost::raft::state_machine::configs::commands;

        // Add a dummy member (will generate a log entry)
        // The outcome is only printed, not asserted
        let result1 = client
            .execute(
                bifrost::raft::state_machine::configs::CONFIG_SM_ID,
                commands::new_member_::new(&"dummy:9999".to_string()),
            )
            .await;
        println!("Phase 1: Added dummy member: {:?}", result1.is_ok());

        // Remove the dummy member (another log entry)
        let result2 = client
            .execute(
                bifrost::raft::state_machine::configs::CONFIG_SM_ID,
                commands::del_member_::new(&"dummy:9999".to_string()),
            )
            .await;
        println!("Phase 1: Removed dummy member: {:?}", result2.is_ok());

        // Wait for persistence
        sleep(Duration::from_millis(1000)).await;

        // Check that logs exist
        let num_logs = raft_service.num_logs().await;
        println!("Phase 1: Number of logs: {}", num_logs);
        assert!(num_logs > 0, "Phase 1: Should have generated some logs");

        // Shutdown gracefully
        println!("Phase 1: Shutting down...");
        drop(client); // Drop client first
        raft_service.shutdown().await;
        server.shutdown().await;
        sleep(Duration::from_secs(2)).await;

        // Prevent runtime drop panic in test context
        // (`forget` skips the destructors of both values instead of running them here)
        std::mem::forget(raft_service);
        std::mem::forget(server);
    }

    // Give OS time to release resources
    sleep(Duration::from_millis(500)).await;

    println!("\n=== Phase 2: Restart and recover from disk ===");

    // Phase 2: Restart the server with same storage - THIS IS THE BUG FIX TEST
    {
        let raft_service2 = RaftService::new(Options {
            storage: Storage::DISK(DiskOptions {
                path: data_path.clone(), // Same path - will recover state
                take_snapshots: true,
                append_logs: true,
                trim_logs: false,
                snapshot_log_threshold: 5,
                log_compaction_threshold: 10,
            }),
            address: address.clone(), // Same address
            service_id: DEFAULT_SERVICE_ID,
        });

        let server2 = Server::new(&address);
        Server::listen_and_resume(&server2).await;
        server2.register_service(&raft_service2).await;

        // Start - should recover and immediately become leader
        // (note: unlike Phase 1, `bootstrap()` is not called here)
        let started2 = RaftService::start(&raft_service2, true).await;
        assert!(started2, "Phase 2: Should start successfully");

        // Give it a moment to stabilize
        sleep(Duration::from_millis(1000)).await;

        // THE KEY TEST: Should be leader immediately after recovery
        let is_leader = raft_service2.is_leader();
        println!("Phase 2: Is leader? {}", is_leader);

        let leader_id2 = raft_service2.leader_id().await;
        println!(
            "Phase 2: Leader ID = {}, Server ID = {}",
            leader_id2, raft_service2.id
        );

        assert!(
            is_leader,
            "Phase 2: CRITICAL - Should be leader after recovery (single-node cluster)"
        );
        assert!(
            leader_id2 != 0,
            "Phase 2: Should have valid leader ID after recovery"
        );
        assert_eq!(
            leader_id2, raft_service2.id,
            "Phase 2: Should be its own leader"
        );

        // Verify client can connect successfully
        let client2_result = RaftClient::new(&vec![address.clone()], DEFAULT_SERVICE_ID).await;
        assert!(
            client2_result.is_ok(),
            "Phase 2: Should be able to create RaftClient (leader should be elected)"
        );

        println!("Phase 2: ✅ Single-node cluster successfully recovered and became leader!");

        // Cleanup
        raft_service2.shutdown().await;
        server2.shutdown().await;
        sleep(Duration::from_millis(500)).await;

        // Prevent runtime drop panic in test context
        std::mem::forget(raft_service2);
        std::mem::forget(server2);
    }

    // Cleanup test directory
    let _ = std::fs::remove_dir_all(&data_path);

    println!("\n✅ TEST PASSED: Single-node cluster recovery works correctly!");
}
/// Restarts a single-node, disk-backed Raft cluster three times on the same data directory.
///
/// Only the first cycle calls `bootstrap()`; every cycle asserts that the node is leader,
/// has a non-zero leader ID, and accepts a `RaftClient` connection, then issues two config
/// commands before shutting down.
#[tokio::test(flavor = "multi_thread")]
async fn test_single_node_multiple_restart_cycles() {
    let _ = env_logger::try_init();

    let data_path = "/tmp/bifrost_test_single_node_cycles_18001".to_string();

    // Clean up any existing data from previous runs
    // (the removal result is ignored: the directory may simply not exist yet)
    let _ = std::fs::remove_dir_all(&data_path);
    std::fs::create_dir_all(&data_path).unwrap();

    let address = "127.0.0.1:18001".to_string();

    // Perform 3 restart cycles
    for cycle in 1..=3 {
        println!("\n=== Cycle {} ===", cycle);

        // Same storage path and address in every cycle
        let raft_service = RaftService::new(Options {
            storage: Storage::DISK(DiskOptions {
                path: data_path.clone(),
                take_snapshots: true,
                append_logs: true,
                trim_logs: false,
                snapshot_log_threshold: 5,
                log_compaction_threshold: 10,
            }),
            address: address.clone(),
            service_id: DEFAULT_SERVICE_ID,
        });

        let server = Server::new(&address);
        Server::listen_and_resume(&server).await;
        server.register_service(&raft_service).await;

        let started = RaftService::start(&raft_service, true).await;
        assert!(started, "Cycle {}: Should start", cycle);

        if cycle == 1 {
            // First cycle: bootstrap
            raft_service.bootstrap().await;
        }

        sleep(Duration::from_millis(1000)).await;

        // Should be leader in all cycles
        assert!(
            raft_service.is_leader(),
            "Cycle {}: Should be leader",
            cycle
        );

        let leader_id = raft_service.leader_id().await;
        assert!(
            leader_id != 0,
            "Cycle {}: Should have valid leader ID",
            cycle
        );

        // Verify client works and generate some logs
        let client = RaftClient::new(&vec![address.clone()], DEFAULT_SERVICE_ID).await;
        assert!(client.is_ok(), "Cycle {}: RaftClient should connect", cycle);

        // Generate logs to ensure persistence
        // (the results of both commands are deliberately ignored)
        if let Ok(ref client) = client {
            use bifrost::raft::state_machine::configs::commands;
            // A distinct dummy address per cycle
            let dummy_addr = format!("dummy{}:9999", cycle);
            let _ = client
                .execute(
                    bifrost::raft::state_machine::configs::CONFIG_SM_ID,
                    commands::new_member_::new(&dummy_addr),
                )
                .await;
            let _ = client
                .execute(
                    bifrost::raft::state_machine::configs::CONFIG_SM_ID,
                    commands::del_member_::new(&dummy_addr),
                )
                .await;
            sleep(Duration::from_millis(500)).await; // Wait for persistence
        }

        let num_logs = raft_service.num_logs().await;
        println!(
            "Cycle {}: ✅ Leader elected successfully, {} logs",
            cycle, num_logs
        );

        // Shutdown
        raft_service.shutdown().await;
        server.shutdown().await;
        sleep(Duration::from_secs(2)).await;

        // Prevent runtime drop panic in test context
        // (`forget` skips the destructors of both values instead of running them here)
        std::mem::forget(raft_service);
        std::mem::forget(server);
    }

    // Cleanup test directory
    let _ = std::fs::remove_dir_all(&data_path);

    println!("\n✅ All 3 restart cycles passed!");
}