Repository: krojew/cdrs-tokio Branch: master Commit: 27d38d2ad9bb Files: 155 Total size: 988.6 KB Directory structure: gitextract_xk4fwcia/ ├── .github/ │ ├── FUNDING.yml │ ├── dependabot.yml │ ├── stale.yml │ └── workflows/ │ └── rust.yml ├── .gitignore ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── cassandra-ports.txt ├── cassandra-protocol/ │ ├── Cargo.toml │ ├── README.md │ └── src/ │ ├── authenticators.rs │ ├── compression.rs │ ├── consistency.rs │ ├── crc.rs │ ├── error.rs │ ├── events.rs │ ├── frame/ │ │ ├── events.rs │ │ ├── frame_decoder.rs │ │ ├── frame_encoder.rs │ │ ├── message_auth_challenge.rs │ │ ├── message_auth_response.rs │ │ ├── message_auth_success.rs │ │ ├── message_authenticate.rs │ │ ├── message_batch.rs │ │ ├── message_error.rs │ │ ├── message_event.rs │ │ ├── message_execute.rs │ │ ├── message_options.rs │ │ ├── message_prepare.rs │ │ ├── message_query.rs │ │ ├── message_ready.rs │ │ ├── message_register.rs │ │ ├── message_request.rs │ │ ├── message_response.rs │ │ ├── message_result.rs │ │ ├── message_startup.rs │ │ ├── message_supported.rs │ │ └── traits.rs │ ├── frame.rs │ ├── lib.rs │ ├── macros.rs │ ├── query/ │ │ ├── batch_query_builder.rs │ │ ├── prepare_flags.rs │ │ ├── prepared_query.rs │ │ ├── query_flags.rs │ │ ├── query_params.rs │ │ ├── query_params_builder.rs │ │ ├── query_values.rs │ │ └── utils.rs │ ├── query.rs │ ├── token.rs │ ├── types/ │ │ ├── blob.rs │ │ ├── cassandra_type.rs │ │ ├── data_serialization_types.rs │ │ ├── decimal.rs │ │ ├── duration.rs │ │ ├── from_cdrs.rs │ │ ├── list.rs │ │ ├── map.rs │ │ ├── rows.rs │ │ ├── tuple.rs │ │ ├── udt.rs │ │ ├── value.rs │ │ └── vector.rs │ └── types.rs ├── cdrs-tokio/ │ ├── Cargo.toml │ ├── examples/ │ │ ├── README.md │ │ ├── crud_operations.rs │ │ ├── generic_connection.rs │ │ ├── insert_collection.rs │ │ ├── multiple_thread.rs │ │ ├── paged_query.rs │ │ └── prepare_batch_execute.rs │ ├── src/ │ │ ├── cluster/ │ │ │ ├── cluster_metadata_manager.rs │ │ │ 
├── config_proxy.rs │ │ │ ├── config_rustls.rs │ │ │ ├── config_tcp.rs │ │ │ ├── connection_manager.rs │ │ │ ├── connection_pool.rs │ │ │ ├── control_connection.rs │ │ │ ├── keyspace_holder.rs │ │ │ ├── metadata_builder.rs │ │ │ ├── node_address.rs │ │ │ ├── node_info.rs │ │ │ ├── pager.rs │ │ │ ├── rustls_connection_manager.rs │ │ │ ├── send_envelope.rs │ │ │ ├── session.rs │ │ │ ├── session_context.rs │ │ │ ├── tcp_connection_manager.rs │ │ │ ├── token_map.rs │ │ │ ├── topology/ │ │ │ │ ├── cluster_metadata.rs │ │ │ │ ├── datacenter_metadata.rs │ │ │ │ ├── keyspace_metadata.rs │ │ │ │ ├── node.rs │ │ │ │ ├── node_distance.rs │ │ │ │ ├── node_state.rs │ │ │ │ └── replication_strategy.rs │ │ │ └── topology.rs │ │ ├── cluster.rs │ │ ├── envelope_parser.rs │ │ ├── frame_encoding.rs │ │ ├── future.rs │ │ ├── lib.rs │ │ ├── load_balancing/ │ │ │ ├── initializing_wrapper.rs │ │ │ ├── node_distance_evaluator.rs │ │ │ ├── random.rs │ │ │ ├── request.rs │ │ │ ├── round_robin.rs │ │ │ └── topology_aware.rs │ │ ├── load_balancing.rs │ │ ├── macros.rs │ │ ├── retry/ │ │ │ ├── reconnection_policy.rs │ │ │ └── retry_policy.rs │ │ ├── retry.rs │ │ ├── speculative_execution.rs │ │ ├── statement/ │ │ │ ├── statement_params.rs │ │ │ └── statement_params_builder.rs │ │ ├── statement.rs │ │ └── transport.rs │ └── tests/ │ ├── collection_types.rs │ ├── common.rs │ ├── compression.rs │ ├── derive_traits.rs │ ├── keyspace.rs │ ├── multi_node_speculative_execution.rs │ ├── multithread.rs │ ├── native_types.rs │ ├── paged_query.rs │ ├── query_values.rs │ ├── single_node_speculative_execution.rs │ ├── topology_aware.rs │ ├── tuple_types.rs │ └── user_defined_types.rs ├── cdrs-tokio-helpers-derive/ │ ├── Cargo.toml │ ├── README.md │ └── src/ │ ├── common.rs │ ├── db_mirror.rs │ ├── into_cdrs_value.rs │ ├── lib.rs │ ├── try_from_row.rs │ └── try_from_udt.rs ├── changelog.md ├── clippy.toml ├── documentation/ │ ├── README.md │ ├── batching-multiple-queries.md │ ├── cdrs-session.md │ ├── 
cluster-configuration.md │ ├── deserialization.md │ ├── preparing-and-executing-queries.md │ ├── query-values.md │ └── type-mapping.md └── rustfmt.toml ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/FUNDING.yml ================================================ # These are supported funding model platforms github: krojew ================================================ FILE: .github/dependabot.yml ================================================ # To get started with Dependabot version updates, you'll need to specify which # package ecosystems to update and where the package manifests are located. # Please see the documentation for all configuration options: # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates version: 2 updates: - package-ecosystem: "cargo" directory: "/" schedule: interval: "daily" ================================================ FILE: .github/stale.yml ================================================ # Number of days of inactivity before an issue becomes stale daysUntilStale: 60 # Number of days of inactivity before a stale issue is closed daysUntilClose: 7 # Issues with these labels will never be considered stale exemptLabels: - pinned - security # Label to use when marking an issue as stale staleLabel: wontfix # Comment to post when marking an issue as stale. Set to `false` to disable markComment: > This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. # Comment to post when closing a stale issue. 
Set to `false` to disable closeComment: false ================================================ FILE: .github/workflows/rust.yml ================================================ name: Continuous integration on: [ push, pull_request ] env: CARGO_TERM_COLOR: always jobs: test: name: Test Suite runs-on: ubuntu-latest services: cassandra: image: cassandra ports: - 9042:9042 steps: - uses: actions/checkout@v2 - name: Install minimal toolchain with clippy and rustfmt uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: stable components: rustfmt, clippy - name: Run tests # test threads must be one because else database tests will run in parallel and will result in flaky tests run: cargo test --all-features --verbose -- --test-threads=1 - name: Format check run: cargo fmt --all -- --check # Ensure that all targets compile and pass clippy checks under every possible combination of features - name: Clippy check run: cargo install cargo-hack && cargo hack --feature-powerset clippy --locked --release ================================================ FILE: .gitignore ================================================ Cargo.lock target *.bk .idea/ cdrs.iml .vscode/ ================================================ FILE: Cargo.toml ================================================ [workspace] members = [ "cassandra-protocol", "cdrs-tokio", "cdrs-tokio-helpers-derive" ] [workspace.dependencies] arc-swap = "1.7.1" uuid = "1.19.0" derivative = "2.2.0" derive_more = { version = "2.1.0", features = ["constructor", "display"] } itertools = "0.14.0" thiserror = "2.0.17" ================================================ FILE: LICENSE-APACHE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2017 CDRS Project Developers Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: LICENSE-MIT ================================================ The MIT License (MIT) Copyright (c) 2016 CDRS Project Developers Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================
# CDRS tokio

[![crates.io version](https://img.shields.io/crates/v/cdrs-tokio.svg)](https://crates.io/crates/cdrs-tokio)
![build status](https://github.com/krojew/cdrs-tokio/actions/workflows/rust.yml/badge.svg)

![CDRS tokio - async Apache Cassandra driver using tokio](./cdrs-logo.png)

CDRS is a production-ready Apache **C**assandra **d**river written in pure **R**u**s**t. Like its Java counterpart, it focuses on providing a high level of configurability to suit most use cases at any scale, while also leveraging the safety and performance of Rust.

## Features

- Asynchronous API;
- TCP/TLS connection (rustls);
- Topology-aware dynamic and configurable load balancing;
- Configurable connection strategies and pools;
- Configurable speculative execution;
- LZ4, Snappy compression;
- Cassandra-to-Rust data serialization/deserialization with custom type support;
- Pluggable authentication strategies;
- [ScyllaDB](https://www.scylladb.com/) support;
- Server events listening;
- Multiple CQL protocol version support (3, 4, 5), full spec implementation;
- Query tracing information;
- Prepared statements;
- Query paging;
- Batch statements;
- Configurable retry and reconnection policy;
- Support for interleaved queries;
- Support for Yugabyte YCQL JSONB;
- Support for beta protocol usage;

## Performance

Due to the high configurability of **CDRS**, performance will vary depending on the use case. The following benchmarks were run against the latest versions (master as of 03-12-2021) of the respective libraries (except cassandra-cpp: 2.16.0), using protocol version 4.

- `cdrs-tokio-large-pool` - **CDRS** with a per-node connection pool of twice the number of physical CPU cores
- `cdrs-tokio-small-pool` - **CDRS** with a single connection per node
- `scylladb-rust-large-pool` - `scylla` crate with a per-node connection pool of twice the number of physical CPU cores
- `scylladb-rust-small-pool` - `scylla` crate with a single connection per node
- `cassandra-cpp` - Rust bindings for the DataStax C++ driver, running on multiple threads using Tokio
- `gocql` - a driver written in Go

Benchmark charts: insert, select, mixed.

With a given use case in mind, CDRS can be tuned for peak performance.

## Documentation and examples

- [User guide](./documentation).
- [Examples](./cdrs-tokio/examples).
- [API docs](https://docs.rs/cdrs-tokio/latest/cdrs_tokio/).
- Using ScyllaDB with Rust [lesson](https://university.scylladb.com/courses/using-scylla-drivers/lessons/rust-and-scylla/).

## Getting started

This example configures a cluster consisting of a single node without authentication and uses round-robin load balancing. All other options keep their defaults.

```rust
use cdrs_tokio::cluster::session::{TcpSessionBuilder, SessionBuilder};
use cdrs_tokio::cluster::NodeTcpConfigBuilder;
use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy;
use cdrs_tokio::query::*;

#[tokio::main]
async fn main() {
    // A single contact point, no authentication.
    let cluster_config = NodeTcpConfigBuilder::new()
        .with_contact_point("127.0.0.1:9042".into())
        .build()
        .await
        .unwrap();

    // Round-robin load balancing; every other session option uses its default.
    let session = TcpSessionBuilder::new(RoundRobinLoadBalancingStrategy::new(), cluster_config)
        .build()
        .await
        .unwrap();

    let create_ks = "CREATE KEYSPACE IF NOT EXISTS test_ks WITH REPLICATION = { \
                     'class' : 'SimpleStrategy', 'replication_factor' : 1 };";
    session
        .query(create_ks)
        .await
        .expect("Keyspace create error");
}
```

## License

This project is licensed under either of

- Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0))
- MIT license ([LICENSE-MIT](LICENSE-MIT) or [http://opensource.org/licenses/MIT](http://opensource.org/licenses/MIT))

at your option.
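
The getting-started example above connects without authentication. When a cluster uses Cassandra's `PasswordAuthenticator`, the `StaticPasswordAuthenticator` in `cassandra-protocol` answers the SASL handshake with a single PLAIN-style token as its initial response: an empty authorization identity, a NUL byte, the username, another NUL byte, and the password. The standalone sketch below reproduces only that byte layout with the standard library; `plain_initial_response` and `parse_plain_token` are hypothetical helpers for illustration, not part of the cdrs-tokio API.

```rust
/// Hypothetical helper (not part of cdrs-tokio): builds the same PLAIN-style
/// token that `StaticPasswordAuthenticator::initial_response` assembles.
fn plain_initial_response(username: &str, password: &str) -> Vec<u8> {
    let mut token = Vec::with_capacity(2 + username.len() + password.len());
    token.push(0); // terminates the (empty) authorization identity
    token.extend_from_slice(username.as_bytes());
    token.push(0); // separates the username from the password
    token.extend_from_slice(password.as_bytes());
    token
}

/// Hypothetical inverse, useful for checking the layout: splits a token back
/// into (username, password), or returns `None` if it is malformed.
fn parse_plain_token(token: &[u8]) -> Option<(&str, &str)> {
    // Expect three NUL-separated fields; the first (authorization identity) must be empty.
    let mut parts = token.splitn(3, |&b| b == 0);
    let authzid = parts.next()?;
    let username = parts.next()?;
    let password = parts.next()?;
    if !authzid.is_empty() {
        return None;
    }
    Some((
        std::str::from_utf8(username).ok()?,
        std::str::from_utf8(password).ok()?,
    ))
}
```

The leading NUL corresponds to the empty authorization identity of the PLAIN mechanism (RFC 4616), which is why the token starts with a zero byte even though only a username and password are supplied.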

================================================
FILE: cassandra-ports.txt
================================================
Cassandra ports:

* 7199 - JMX (was 8080 pre Cassandra 0.8.xx)
* 7000 - Internode communication (not used if TLS enabled)
* 7001 - TLS Internode communication (used if TLS enabled)
* 9160 - Thrift client API
* 9042 - CQL native transport port

================================================
FILE: cassandra-protocol/Cargo.toml
================================================
[package]
name = "cassandra-protocol"
version = "4.0.0"
authors = ["Alex Pikalov ", "Kamil Rojewski "]
edition = "2018"
description = "Cassandra protocol implementation"
documentation = "https://docs.rs/cassandra-protocol"
homepage = "https://github.com/krojew/cdrs-tokio"
repository = "https://github.com/krojew/cdrs-tokio"
keywords = ["cassandra", "client", "cassandradb"]
license = "MIT/Apache-2.0"
categories = ["asynchronous", "database"]
rust-version = "1.74"

[features]
e2e-tests = []

[dependencies]
arc-swap.workspace = true
bitflags = "2.10.0"
bytes = "1.11.0"
chrono = { version = "0.4.31", default-features = false, features = ["std"] }
crc32fast = "1.5.0"
derivative.workspace = true
derive_more.workspace = true
float_eq = "1.0.1"
integer-encoding = "4.1.0"
itertools.workspace = true
num-bigint = "0.4.1"
lz4_flex = "0.13.0"
snap = "1.1.0"
thiserror.workspace = true
time = { version = "0.3.29", features = ["macros"] }
uuid.workspace = true

================================================
FILE: cassandra-protocol/README.md
================================================
# Cassandra protocol

[![crates.io version](https://img.shields.io/crates/v/cassandra-protocol.svg)](https://crates.io/crates/cassandra-protocol)
![build status](https://github.com/krojew/cdrs-tokio/actions/workflows/rust.yml/badge.svg)

**Cassandra** low-level protocol implementation, written in Rust.

If you wish to use **Cassandra** without dealing with protocol-level details, consider a high-level crate like **[cdrs-tokio](https://crates.io/crates/cdrs-tokio)**.

================================================
FILE: cassandra-protocol/src/authenticators.rs
================================================
use crate::error::Result;
use crate::types::CBytes;

/// Handles SASL authentication.
///
/// The lifecycle of an authenticator consists of:
/// - The `initial_response` function is called first. Its return value is sent to the
///   server to initiate the handshake.
/// - The server responds to each client response by either issuing a challenge or indicating
///   that authentication is complete (successfully or not). If a new challenge is issued,
///   the authenticator's `evaluate_challenge` function is called to produce a response
///   that is sent to the server. This challenge/response negotiation continues until
///   the server reports that authentication succeeded or an error is raised.
/// - On success, `handle_success` is called with the data returned by the server.
pub trait SaslAuthenticator {
    fn initial_response(&self) -> CBytes;

    fn evaluate_challenge(&self, challenge: CBytes) -> Result<CBytes>;

    fn handle_success(&self, data: CBytes) -> Result<()>;
}

/// Provides a new authenticator for each new connection.
pub trait SaslAuthenticatorProvider {
    fn name(&self) -> Option<&str>;

    fn create_authenticator(&self) -> Box<dyn SaslAuthenticator + Send>;
}

#[derive(Debug, Clone)]
pub struct StaticPasswordAuthenticator {
    username: String,
    password: String,
}

impl StaticPasswordAuthenticator {
    pub fn new<S: ToString>(username: S, password: S) -> StaticPasswordAuthenticator {
        StaticPasswordAuthenticator {
            username: username.to_string(),
            password: password.to_string(),
        }
    }
}

impl SaslAuthenticator for StaticPasswordAuthenticator {
    fn initial_response(&self) -> CBytes {
        // SASL PLAIN-style token: an empty authorization identity followed by
        // NUL, the username, NUL, and the password.
        let mut token = vec![0];
        token.extend_from_slice(self.username.as_bytes());
        token.push(0);
        token.extend_from_slice(self.password.as_bytes());

        CBytes::new(token)
    }

    fn evaluate_challenge(&self, _challenge: CBytes) -> Result<CBytes> {
        Err("Server challenge is not supported for StaticPasswordAuthenticator!".into())
    }

    fn handle_success(&self, _data: CBytes) -> Result<()> {
        Ok(())
    }
}

/// Authentication provider with a username and password.
#[derive(Debug, Clone)]
pub struct StaticPasswordAuthenticatorProvider {
    username: String,
    password: String,
}

impl SaslAuthenticatorProvider for StaticPasswordAuthenticatorProvider {
    fn name(&self) -> Option<&str> {
        Some("org.apache.cassandra.auth.PasswordAuthenticator")
    }

    fn create_authenticator(&self) -> Box<dyn SaslAuthenticator + Send> {
        Box::new(StaticPasswordAuthenticator::new(
            self.username.clone(),
            self.password.clone(),
        ))
    }
}

impl StaticPasswordAuthenticatorProvider {
    pub fn new<S: ToString>(username: S, password: S) -> Self {
        StaticPasswordAuthenticatorProvider {
            username: username.to_string(),
            password: password.to_string(),
        }
    }
}

#[derive(Debug, Clone)]
pub struct NoneAuthenticator;

impl SaslAuthenticator for NoneAuthenticator {
    fn initial_response(&self) -> CBytes {
        CBytes::new(vec![0])
    }

    fn evaluate_challenge(&self, _challenge: CBytes) -> Result<CBytes> {
        Err("Server challenge is not supported for NoneAuthenticator!".into())
    }

    fn handle_success(&self, _data: CBytes) -> Result<()> {
        Ok(())
    }
}

/// Provider for no authentication.
#[derive(Debug, Clone)]
pub struct NoneAuthenticatorProvider;

impl SaslAuthenticatorProvider for NoneAuthenticatorProvider {
    fn name(&self) -> Option<&str> {
        None
    }

    fn create_authenticator(&self) -> Box<dyn SaslAuthenticator + Send> {
        Box::new(NoneAuthenticator)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_static_password_authenticator_new() {
        StaticPasswordAuthenticator::new("foo", "bar");
    }

    #[test]
    fn test_static_password_authenticator_cassandra_name() {
        let auth = StaticPasswordAuthenticatorProvider::new("foo", "bar");
        assert_eq!(
            auth.name(),
            Some("org.apache.cassandra.auth.PasswordAuthenticator")
        );
    }

    #[test]
    fn test_authenticator_none_cassandra_name() {
        let auth = NoneAuthenticator;
        let provider = NoneAuthenticatorProvider;
        assert_eq!(provider.name(), None);
        assert_eq!(auth.initial_response().into_bytes().unwrap(), vec![0]);
    }
}

================================================
FILE: cassandra-protocol/src/compression.rs
================================================
/// CDRS supports traffic compression as described in the [Apache
/// Cassandra protocol](
/// https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L790).
///
/// Before compression can be used, the client and server must agree on an algorithm,
/// which is done in the STARTUP message. As a consequence, a STARTUP message
/// must never be compressed. However, once the server has received the STARTUP envelope,
/// messages can be compressed (including the response to the STARTUP request).
use derive_more::Display;
use snap::raw::{Decoder, Encoder};
use std::convert::{From, TryInto};
use std::error::Error;
use std::fmt;
use std::io;
use std::result;

type Result<T> = result::Result<T, CompressionError>;

pub const LZ4: &str = "lz4";
pub const SNAPPY: &str = "snappy";

/// An error which may occur while encoding or decoding a frame body. Since there are only two
/// compressor types, it has one variant for each.
#[derive(Debug)]
pub enum CompressionError {
    /// Snappy error.
    Snappy(snap::Error),
    /// Lz4 error.
    Lz4(io::Error),
}

impl fmt::Display for CompressionError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            CompressionError::Snappy(ref err) => write!(f, "Snappy Error: {err:?}"),
            CompressionError::Lz4(ref err) => write!(f, "Lz4 Error: {err:?}"),
        }
    }
}

impl Error for CompressionError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match *self {
            CompressionError::Snappy(ref err) => Some(err),
            CompressionError::Lz4(ref err) => Some(err),
        }
    }
}

impl Clone for CompressionError {
    fn clone(&self) -> Self {
        match self {
            CompressionError::Snappy(error) => CompressionError::Snappy(error.clone()),
            // io::Error is not Clone, so rebuild it from its kind and the inner
            // error's message (empty if there is no inner error).
            CompressionError::Lz4(error) => CompressionError::Lz4(io::Error::new(
                error.kind(),
                error
                    .get_ref()
                    .map(|error| error.to_string())
                    .unwrap_or_default(),
            )),
        }
    }
}

/// Enum which represents a type of compression. Only the bodies of non-STARTUP frames can be
/// compressed.
#[derive(Debug, PartialEq, Clone, Copy, Eq, Ord, PartialOrd, Hash, Display)]
pub enum Compression {
    /// [lz4](https://code.google.com/p/lz4/) compression
    Lz4,
    /// [snappy](https://code.google.com/p/snappy/) compression
    Snappy,
    /// No compression
    None,
}

impl Compression {
    /// Encodes `bytes` using this compression method.
    ///
    /// # Examples
    ///
    /// ```
    /// use cassandra_protocol::compression::Compression;
    ///
    /// let snappy_compression = Compression::Snappy;
    /// let bytes = String::from("Hello World").into_bytes().to_vec();
    /// let encoded = snappy_compression.encode(&bytes).unwrap();
    /// assert_eq!(snappy_compression.decode(encoded).unwrap(), bytes);
    ///
    /// ```
    pub fn encode(&self, bytes: &[u8]) -> Result<Vec<u8>> {
        match *self {
            Compression::Lz4 => Compression::encode_lz4(bytes),
            Compression::Snappy => Compression::encode_snappy(bytes),
            Compression::None => Ok(bytes.into()),
        }
    }

    /// Checks if current compression actually compresses data.
    #[inline]
    pub fn is_compressed(self) -> bool {
        self != Compression::None
    }

    /// Decodes `bytes` using this compression method.
    pub fn decode(&self, bytes: Vec<u8>) -> Result<Vec<u8>> {
        match *self {
            Compression::Lz4 => Compression::decode_lz4(bytes),
            Compression::Snappy => Compression::decode_snappy(bytes),
            Compression::None => Ok(bytes),
        }
    }

    /// Returns the lowercase name of the compression method, or `None` if no compression is used.
    pub fn as_str(&self) -> Option<&'static str> {
        match *self {
            Compression::Lz4 => Some(LZ4),
            Compression::Snappy => Some(SNAPPY),
            Compression::None => None,
        }
    }

    fn encode_snappy(bytes: &[u8]) -> Result<Vec<u8>> {
        let mut encoder = Encoder::new();
        encoder
            .compress_vec(bytes)
            .map_err(CompressionError::Snappy)
    }

    fn decode_snappy(bytes: Vec<u8>) -> Result<Vec<u8>> {
        let mut decoder = Decoder::new();
        decoder
            .decompress_vec(bytes.as_slice())
            .map_err(CompressionError::Snappy)
    }

    fn encode_lz4(bytes: &[u8]) -> Result<Vec<u8>> {
        // Reserve room for the 4-byte length header plus the worst-case compressed size.
        let len = 4 + lz4_flex::block::get_maximum_output_size(bytes.len());
        assert!(len <= i32::MAX as usize);

        let mut result = vec![0; len];

        // Prefix the block with the uncompressed length as a big-endian i32.
        let len = bytes.len() as i32;
        result[..4].copy_from_slice(&len.to_be_bytes());

        let compressed_len = lz4_flex::compress_into(bytes, &mut result[4..])
            .map_err(|error| CompressionError::Lz4(io::Error::other(error)))?;

        // Drop the unused tail of the worst-case buffer.
        result.truncate(4 + compressed_len);
        Ok(result)
    }

    fn decode_lz4(bytes: Vec<u8>) -> Result<Vec<u8>> {
        // lz4 wire format prepends a 4-byte big-endian uncompressed length so
        // the decoder knows how much memory to allocate. Validate length before
        // slicing to avoid panics on truncated input.
if bytes.len() < 4 { return Err(CompressionError::Lz4(io::Error::new( io::ErrorKind::UnexpectedEof, "lz4 payload missing 4-byte uncompressed length header", ))); } let uncompressed_size = i32::from_be_bytes( bytes[..4] .try_into() .map_err(|error| CompressionError::Lz4(io::Error::other(error)))?, ); // a negative size is impossible for a real payload; without this check // the `as usize` cast would silently reinterpret it as an enormous value // (at least 2 GiB, close to `usize::MAX` on 64-bit targets) and ask // lz4_flex to allocate a buffer that size before any decoding begins. if uncompressed_size < 0 { return Err(CompressionError::Lz4(io::Error::new( io::ErrorKind::InvalidData, format!("negative uncompressed size {uncompressed_size}"), ))); } lz4_flex::decompress(&bytes[4..], uncompressed_size as usize) .map_err(|error| CompressionError::Lz4(io::Error::other(error))) } } impl From<String> for Compression { /// Converts a `String` into `Compression`. Any string other than `lz4` or `snappy` /// yields `Compression::None`. fn from(compression_string: String) -> Compression { Compression::from(compression_string.as_str()) } } impl Compression { /// Converts `Compression` into its protocol name: `LZ4`, `SNAPPY` or `NONE`. pub fn to_protocol_string(self) -> String { match self { Compression::Lz4 => "LZ4".to_string(), Compression::Snappy => "SNAPPY".to_string(), Compression::None => "NONE".to_string(), } } /// Parses a protocol name (`lz4`/`LZ4`, `snappy`/`SNAPPY` or `none`/`NONE`); any other /// value is an error. pub fn from_protocol_string(protocol_string: &str) -> std::result::Result<Self, String> { match protocol_string { "lz4" | "LZ4" => Ok(Compression::Lz4), "snappy" | "SNAPPY" => Ok(Compression::Snappy), "none" | "NONE" => Ok(Compression::None), _ => Err("Unknown compression".to_string()), } } } impl<'a> From<&'a str> for Compression { /// It converts `str` into `Compression`.
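The lz4 body produced by `encode_lz4` and consumed by `decode_lz4` is a 4-byte big-endian `i32` holding the uncompressed length, followed by the raw lz4 block. The std-only sketch below isolates just that framing and the two header checks `decode_lz4` performs (input too short, negative length). The helper names are hypothetical, and the block is treated as opaque bytes rather than produced by lz4_flex.

```rust
use std::convert::TryInto;

/// Hypothetical helper: prepend the big-endian i32 uncompressed length
/// to an already-compressed lz4 block (the block itself is opaque here).
fn frame_lz4_body(uncompressed_len: usize, block: &[u8]) -> Result<Vec<u8>, String> {
    let len: i32 = uncompressed_len
        .try_into()
        .map_err(|_| format!("length {} does not fit in an i32", uncompressed_len))?;
    let mut framed = Vec::with_capacity(4 + block.len());
    framed.extend_from_slice(&len.to_be_bytes());
    framed.extend_from_slice(block);
    Ok(framed)
}

/// Hypothetical helper: split a framed body into (uncompressed length, block),
/// rejecting truncated headers and negative lengths before anything is allocated.
fn split_lz4_header(bytes: &[u8]) -> Result<(usize, &[u8]), String> {
    if bytes.len() < 4 {
        return Err("missing 4-byte uncompressed length header".to_string());
    }
    // The slice is exactly 4 bytes long, so this conversion cannot fail.
    let size = i32::from_be_bytes(bytes[..4].try_into().unwrap());
    if size < 0 {
        return Err(format!("negative uncompressed size {}", size));
    }
    Ok((size as usize, &bytes[4..]))
}
```

Checking the header before handing anything to the decompressor is what keeps a truncated or hostile frame from becoming a panic or an oversized allocation.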
If string is neither `lz4` nor `snappy` then /// `Compression::None` will be returned fn from(compression_str: &'a str) -> Compression { match compression_str { LZ4 => Compression::Lz4, SNAPPY => Compression::Snappy, _ => Compression::None, } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_compression_to_protocol_string() { let lz4 = Compression::Lz4; assert_eq!("LZ4", lz4.to_protocol_string()); let snappy = Compression::Snappy; assert_eq!("SNAPPY", snappy.to_protocol_string()); let none = Compression::None; assert_eq!("NONE", none.to_protocol_string()); } #[test] fn test_compression_from_protocol_str() { let lz4 = "lz4"; assert_eq!( Compression::from_protocol_string(lz4).unwrap(), Compression::Lz4 ); let lz4 = "LZ4"; assert_eq!( Compression::from_protocol_string(lz4).unwrap(), Compression::Lz4 ); let snappy = "snappy"; assert_eq!( Compression::from_protocol_string(snappy).unwrap(), Compression::Snappy ); let snappy = "SNAPPY"; assert_eq!( Compression::from_protocol_string(snappy).unwrap(), Compression::Snappy ); let none = "none"; assert_eq!( Compression::from_protocol_string(none).unwrap(), Compression::None ); let none = "NONE"; assert_eq!( Compression::from_protocol_string(none).unwrap(), Compression::None ); } #[test] fn test_compression_from_string() { let lz4 = "lz4".to_string(); assert_eq!(Compression::from(lz4), Compression::Lz4); let snappy = "snappy".to_string(); assert_eq!(Compression::from(snappy), Compression::Snappy); let none = "x".to_string(); assert_eq!(Compression::from(none), Compression::None); } #[test] fn test_compression_encode_snappy() { let snappy_compression = Compression::Snappy; let bytes = String::from("Hello World").into_bytes().to_vec(); snappy_compression .encode(&bytes) .expect("Should work without exceptions"); } #[test] fn test_compression_decode_snappy() { let snappy_compression = Compression::Snappy; let bytes = String::from("Hello World").into_bytes().to_vec(); let encoded = 
snappy_compression.encode(&bytes).unwrap(); assert_eq!(snappy_compression.decode(encoded).unwrap(), bytes); } #[test] fn test_compression_encode_lz4() { let lz4_compression = Compression::Lz4; let bytes = String::from("Hello World").into_bytes().to_vec(); lz4_compression .encode(&bytes) .expect("Should work without exceptions"); } #[test] fn test_compression_decode_lz4() { let lz4_compression = Compression::Lz4; let bytes = String::from("Hello World").into_bytes().to_vec(); let encoded = lz4_compression.encode(&bytes).unwrap(); assert_eq!(lz4_compression.decode(encoded).unwrap(), bytes); } #[test] fn test_compression_encode_none() { let none_compression = Compression::None; let bytes = String::from("Hello World").into_bytes().to_vec(); none_compression .encode(&bytes) .expect("Should work without exceptions"); } #[test] fn test_compression_decode_none() { let none_compression = Compression::None; let bytes = String::from("Hello World").into_bytes().to_vec(); let encoded = none_compression.encode(&bytes).unwrap(); assert_eq!(none_compression.decode(encoded).unwrap(), bytes); } #[test] fn test_compression_decode_lz4_with_invalid_input() { let lz4_compression = Compression::Lz4; let decode = lz4_compression.decode(vec![0, 0, 0, 0x7f]); assert!(decode.is_err()); } #[test] fn test_compression_decode_lz4_short_input_is_error_not_panic() { // the lz4 wire format prepends a 4-byte big-endian uncompressed size; // a payload shorter than that header must surface as an error rather // than panicking on `bytes[..4]` slicing. let lz4_compression = Compression::Lz4; assert!(lz4_compression.decode(vec![]).is_err()); assert!(lz4_compression.decode(vec![1, 2, 3]).is_err()); } #[test] fn test_compression_decode_lz4_negative_size_is_error_not_oom() { // a negative i32 uncompressed length cast through `as usize` becomes // a huge value (at least 2 GiB) and would otherwise hand lz4_flex an absurd // allocation request - guard against that.
let lz4_compression = Compression::Lz4; // -1 in big-endian i32 followed by a dummy compressed byte let bytes = vec![0xff, 0xff, 0xff, 0xff, 0]; assert!(lz4_compression.decode(bytes).is_err()); } #[test] fn test_compression_encode_snappy_with_non_utf8() { let snappy_compression = Compression::Snappy; let v = vec![0xff, 0xff]; let encoded = snappy_compression .encode(&v) .expect("Should work without exceptions"); assert_eq!(snappy_compression.decode(encoded).unwrap(), v); } } ================================================ FILE: cassandra-protocol/src/consistency.rs ================================================ #![warn(missing_docs)] //! The module contains Rust representation of Cassandra consistency levels. use crate::error; use crate::frame::{FromBytes, FromCursor, Serialize, Version}; use crate::types::*; use derive_more::Display; use std::convert::{From, TryFrom, TryInto}; use std::default::Default; use std::io; use std::str::FromStr; /// `Consistency` is an enum which represents Cassandra's consistency levels. /// To find more details about each consistency level please refer to the following documentation: /// #[derive(Debug, PartialEq, Clone, Copy, Display, Ord, PartialOrd, Eq, Hash, Default)] #[non_exhaustive] pub enum Consistency { /// Closest replica, as determined by the snitch. /// If all replica nodes are down, write succeeds after a hinted handoff. /// Provides low latency, guarantees writes never fail. /// Note: this consistency level can only be used for writes. /// It provides the lowest consistency and the highest availability. Any, /// /// A write must be written to the commit log and memtable of at least one replica node. /// Satisfies the needs of most users because consistency requirements are not stringent. #[default] One, /// A write must be written to the commit log and memtable of at least two replica nodes. /// Similar to ONE. Two, /// A write must be written to the commit log and memtable of at least three replica nodes. 
/// Similar to TWO. Three, /// A write must be written to the commit log and memtable on a quorum of replica nodes. /// Provides strong consistency if you can tolerate some level of failure. Quorum, /// A write must be written to the commit log and memtable on all replica nodes in the cluster /// for that partition key. /// Provides the highest consistency and the lowest availability of all levels. All, /// Strong consistency. A write must be written to the commit log and memtable on a quorum /// of replica nodes in the same data center as the coordinator node. /// Avoids the latency of inter-data center communication. /// Used in multiple data center clusters with a rack-aware replica placement strategy, /// such as NetworkTopologyStrategy, and a properly configured snitch. /// Use to maintain consistency locally (within a single data center). /// Can be used with SimpleStrategy. LocalQuorum, /// Strong consistency. A write must be written to the commit log and memtable on a quorum of /// replica nodes in every data center. /// Used in multiple data center clusters to strictly maintain consistency at the same level /// in each data center. For example, choose this level /// if you want a read to fail when a data center is down and a quorum /// cannot be reached in that data center. EachQuorum, /// Achieves linearizable consistency for lightweight transactions by preventing unconditional /// updates. This level cannot be set as the regular consistency level (the consistency /// level field configured at the driver level); it is set through the serial consistency field /// as part of the native protocol operation. See failure scenarios. Serial, /// Same as SERIAL but confined to the local data center. A write must be written conditionally /// to the commit log and memtable on a quorum of replica nodes in the same data center. /// Used for disaster recovery. See failure scenarios.
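The quorum-based levels above differ mainly in which replicas are counted. As a rough illustration, assuming Cassandra's usual rule that a quorum is `floor(rf / 2) + 1` replicas (with QUORUM taken over the sum of all data centers' replication factors), this std-only sketch computes the minimum number of replica acknowledgements each level waits for. `Level`, `quorum` and `min_acks` are hypothetical names, not part of this crate.

```rust
/// Hypothetical subset of consistency levels used for the illustration.
#[derive(Debug, Clone, Copy)]
enum Level {
    One,
    Quorum,
    All,
    LocalQuorum,
    EachQuorum,
}

/// Majority of `rf` replicas: floor(rf / 2) + 1.
fn quorum(rf: u32) -> u32 {
    rf / 2 + 1
}

/// Minimum replica acknowledgements for `level`, given the replication factor
/// of each data center and the index of the coordinator's (local) data center.
fn min_acks(level: Level, rf_per_dc: &[u32], local_dc: usize) -> u32 {
    let total: u32 = rf_per_dc.iter().sum();
    match level {
        Level::One => 1,
        Level::Quorum => quorum(total),
        Level::All => total,
        Level::LocalQuorum => quorum(rf_per_dc[local_dc]),
        // EACH_QUORUM needs a quorum inside every data center independently.
        Level::EachQuorum => rf_per_dc.iter().map(|&rf| quorum(rf)).sum(),
    }
}
```

With two data centers of RF 3, QUORUM needs 4 acknowledgements that may come from either data center, while LOCAL_QUORUM needs only 2 from the local one, which is the inter-data-center latency trade-off the LOCAL_QUORUM description refers to.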
LocalSerial, /// A write must be sent to, and successfully acknowledged by, /// at least one replica node in the local data center. /// In multiple data center clusters, a consistency level of ONE is often desirable, /// but cross-DC traffic is not. LOCAL_ONE accomplishes this. /// For security and quality reasons, you can use this consistency level /// in an offline datacenter to prevent automatic connection /// to online nodes in other data centers if an offline node goes down. LocalOne, } impl FromStr for Consistency { type Err = error::Error; fn from_str(s: &str) -> Result<Self, Self::Err> { let consistency = match s { "Any" => Consistency::Any, "One" => Consistency::One, "Two" => Consistency::Two, "Three" => Consistency::Three, "Quorum" => Consistency::Quorum, "All" => Consistency::All, "LocalQuorum" => Consistency::LocalQuorum, "EachQuorum" => Consistency::EachQuorum, "Serial" => Consistency::Serial, "LocalSerial" => Consistency::LocalSerial, "LocalOne" => Consistency::LocalOne, _ => { return Err(error::Error::General(format!( "Invalid consistency provided: {s}" ))) } }; Ok(consistency) } } impl Serialize for Consistency { fn serialize(&self, cursor: &mut io::Cursor<&mut Vec<u8>>, version: Version) { let value: i16 = (*self).into(); value.serialize(cursor, version) } } impl TryFrom<CIntShort> for Consistency { type Error = error::Error; fn try_from(value: CIntShort) -> Result<Self, Self::Error> { match value { 0x0000 => Ok(Consistency::Any), 0x0001 => Ok(Consistency::One), 0x0002 => Ok(Consistency::Two), 0x0003 => Ok(Consistency::Three), 0x0004 => Ok(Consistency::Quorum), 0x0005 => Ok(Consistency::All), 0x0006 => Ok(Consistency::LocalQuorum), 0x0007 => Ok(Consistency::EachQuorum), 0x0008 => Ok(Consistency::Serial), 0x0009 => Ok(Consistency::LocalSerial), 0x000A => Ok(Consistency::LocalOne), _ => Err(Self::Error::UnknownConsistency(value)), } } } impl From<Consistency> for CIntShort { fn from(value: Consistency) -> Self { match value { Consistency::Any => 0x0000, Consistency::One => 0x0001, Consistency::Two => 0x0002,
Consistency::Three => 0x0003, Consistency::Quorum => 0x0004, Consistency::All => 0x0005, Consistency::LocalQuorum => 0x0006, Consistency::EachQuorum => 0x0007, Consistency::Serial => 0x0008, Consistency::LocalSerial => 0x0009, Consistency::LocalOne => 0x000A, } } } impl FromBytes for Consistency { fn from_bytes(bytes: &[u8]) -> error::Result<Self> { try_i16_from_bytes(bytes) .map_err(Into::into) .and_then(TryInto::try_into) } } impl FromCursor for Consistency { fn from_cursor(cursor: &mut io::Cursor<&[u8]>, version: Version) -> error::Result<Self> { CIntShort::from_cursor(cursor, version).and_then(TryInto::try_into) } } impl Consistency { /// Returns `true` if this consistency level is confined to the local data center. #[inline] pub fn is_dc_local(self) -> bool { matches!( self, Consistency::LocalOne | Consistency::LocalQuorum | Consistency::LocalSerial ) } } #[cfg(test)] mod tests { use super::*; use crate::frame::traits::{FromBytes, FromCursor}; use std::io::Cursor; #[test] fn test_consistency_serialize() { assert_eq!(Consistency::Any.serialize_to_vec(Version::V4), &[0, 0]); assert_eq!(Consistency::One.serialize_to_vec(Version::V4), &[0, 1]); assert_eq!(Consistency::Two.serialize_to_vec(Version::V4), &[0, 2]); assert_eq!(Consistency::Three.serialize_to_vec(Version::V4), &[0, 3]); assert_eq!(Consistency::Quorum.serialize_to_vec(Version::V4), &[0, 4]); assert_eq!(Consistency::All.serialize_to_vec(Version::V4), &[0, 5]); assert_eq!( Consistency::LocalQuorum.serialize_to_vec(Version::V4), &[0, 6] ); assert_eq!( Consistency::EachQuorum.serialize_to_vec(Version::V4), &[0, 7] ); assert_eq!(Consistency::Serial.serialize_to_vec(Version::V4), &[0, 8]); assert_eq!( Consistency::LocalSerial.serialize_to_vec(Version::V4), &[0, 9] ); assert_eq!( Consistency::LocalOne.serialize_to_vec(Version::V4), &[0, 10] ); } #[test] fn test_consistency_from() { assert_eq!(Consistency::try_from(0).unwrap(), Consistency::Any); assert_eq!(Consistency::try_from(1).unwrap(), Consistency::One); assert_eq!(Consistency::try_from(2).unwrap(),
Consistency::Two); assert_eq!(Consistency::try_from(3).unwrap(), Consistency::Three); assert_eq!(Consistency::try_from(4).unwrap(), Consistency::Quorum); assert_eq!(Consistency::try_from(5).unwrap(), Consistency::All); assert_eq!(Consistency::try_from(6).unwrap(), Consistency::LocalQuorum); assert_eq!(Consistency::try_from(7).unwrap(), Consistency::EachQuorum); assert_eq!(Consistency::try_from(8).unwrap(), Consistency::Serial); assert_eq!(Consistency::try_from(9).unwrap(), Consistency::LocalSerial); assert_eq!(Consistency::try_from(10).unwrap(), Consistency::LocalOne); } #[test] fn test_consistency_from_bytes() { assert_eq!(Consistency::from_bytes(&[0, 0]).unwrap(), Consistency::Any); assert_eq!(Consistency::from_bytes(&[0, 1]).unwrap(), Consistency::One); assert_eq!(Consistency::from_bytes(&[0, 2]).unwrap(), Consistency::Two); assert_eq!( Consistency::from_bytes(&[0, 3]).unwrap(), Consistency::Three ); assert_eq!( Consistency::from_bytes(&[0, 4]).unwrap(), Consistency::Quorum ); assert_eq!(Consistency::from_bytes(&[0, 5]).unwrap(), Consistency::All); assert_eq!( Consistency::from_bytes(&[0, 6]).unwrap(), Consistency::LocalQuorum ); assert_eq!( Consistency::from_bytes(&[0, 7]).unwrap(), Consistency::EachQuorum ); assert_eq!( Consistency::from_bytes(&[0, 8]).unwrap(), Consistency::Serial ); assert_eq!( Consistency::from_bytes(&[0, 9]).unwrap(), Consistency::LocalSerial ); assert_eq!( Consistency::from_bytes(&[0, 10]).unwrap(), Consistency::LocalOne ); assert!(Consistency::from_bytes(&[0, 11]).is_err()); } #[test] fn test_consistency_from_cursor() { assert_eq!( Consistency::from_cursor(&mut Cursor::new(&[0, 0]), Version::V4).unwrap(), Consistency::Any ); assert_eq!( Consistency::from_cursor(&mut Cursor::new(&[0, 1]), Version::V4).unwrap(), Consistency::One ); assert_eq!( Consistency::from_cursor(&mut Cursor::new(&[0, 2]), Version::V4).unwrap(), Consistency::Two ); assert_eq!( Consistency::from_cursor(&mut Cursor::new(&[0, 3]), Version::V4).unwrap(), 
Consistency::Three ); assert_eq!( Consistency::from_cursor(&mut Cursor::new(&[0, 4]), Version::V4).unwrap(), Consistency::Quorum ); assert_eq!( Consistency::from_cursor(&mut Cursor::new(&[0, 5]), Version::V4).unwrap(), Consistency::All ); assert_eq!( Consistency::from_cursor(&mut Cursor::new(&[0, 6]), Version::V4).unwrap(), Consistency::LocalQuorum ); assert_eq!( Consistency::from_cursor(&mut Cursor::new(&[0, 7]), Version::V4).unwrap(), Consistency::EachQuorum ); assert_eq!( Consistency::from_cursor(&mut Cursor::new(&[0, 8]), Version::V4).unwrap(), Consistency::Serial ); assert_eq!( Consistency::from_cursor(&mut Cursor::new(&[0, 9]), Version::V4).unwrap(), Consistency::LocalSerial ); assert_eq!( Consistency::from_cursor(&mut Cursor::new(&[0, 10]), Version::V4).unwrap(), Consistency::LocalOne ); } } ================================================ FILE: cassandra-protocol/src/crc.rs ================================================ use crc32fast::Hasher; const CRC24_POLY: i32 = 0x1974f0b; const CRC24_INIT: i32 = 0x875060; /// Computes the CRC-24 checksum of `bytes`. pub fn crc24(bytes: &[u8]) -> i32 { bytes.iter().fold(CRC24_INIT, |mut crc, byte| { crc ^= (*byte as i32) << 16; for _ in 0..8 { crc <<= 1; if (crc & 0x1000000) != 0 { crc ^= CRC24_POLY; } } crc }) } /// Computes the CRC-32 checksum of `bytes`. pub fn crc32(bytes: &[u8]) -> u32 { let mut hasher = Hasher::new(); hasher.update(&[0xfa, 0x2d, 0x55, 0xca]); // Cassandra hashes these four bytes ahead of the payload, which the protocol spec does not mention.
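The `crc24` function above is a bit-by-bit CRC-24 over a 24-bit register, and `crc32` feeds four fixed bytes (`0xfa 0x2d 0x55 0xca`) into the hasher before the payload. The std-only sketch below reproduces `crc24` as written and computes the same CRC-32 without the crc32fast crate, assuming crc32fast implements the standard reflected CRC-32 (polynomial `0xEDB88320`, as used by zlib). `crc32_ieee` and `cassandra_crc32` are hypothetical names.

```rust
const CRC24_POLY: i32 = 0x1974f0b;
const CRC24_INIT: i32 = 0x875060;

/// Same bitwise CRC-24 as `crc24` in crc.rs: MSB-first, one bit per step.
fn crc24(bytes: &[u8]) -> i32 {
    bytes.iter().fold(CRC24_INIT, |mut crc, &byte| {
        crc ^= (byte as i32) << 16;
        for _ in 0..8 {
            crc <<= 1;
            if crc & 0x100_0000 != 0 {
                // The polynomial has bit 24 set, so this also clears it,
                // keeping the register within 24 bits.
                crc ^= CRC24_POLY;
            }
        }
        crc
    })
}

/// Standard reflected CRC-32 (polynomial 0xEDB88320), computed bit by bit.
fn crc32_ieee(bytes: &[u8]) -> u32 {
    let mut crc = 0xFFFF_FFFFu32;
    for &b in bytes {
        crc ^= u32::from(b);
        for _ in 0..8 {
            // All ones if the low bit is set, zero otherwise.
            let mask = (crc & 1).wrapping_neg();
            crc = (crc >> 1) ^ (0xEDB8_8320 & mask);
        }
    }
    !crc
}

/// CRC-32 over the four fixed bytes followed by the payload.
fn cassandra_crc32(bytes: &[u8]) -> u32 {
    let mut buf: Vec<u8> = vec![0xfa, 0x2d, 0x55, 0xca];
    buf.extend_from_slice(bytes);
    crc32_ieee(&buf)
}
```

Because a CRC is computed incrementally, updating one hasher with the four bytes and then with the payload gives the same value as hashing their concatenation, which is what `cassandra_crc32` does.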
hasher.update(bytes); hasher.finalize() } ================================================ FILE: cassandra-protocol/src/error.rs ================================================ use crate::compression::CompressionError; use crate::frame::message_error::ErrorBody; use crate::frame::Opcode; use crate::types::{CInt, CIntShort}; use std::fmt::{Debug, Display}; use std::io; use std::net::SocketAddr; use std::result; use std::str::Utf8Error; use std::string::FromUtf8Error; use thiserror::Error as ThisError; use uuid::Error as UuidError; pub type Result<T> = result::Result<T, Error>; /// CDRS custom error type. CDRS distinguishes two kinds of errors: errors returned by the server /// and internal errors raised within the driver itself. `io::Error` always represents an internal /// error, because IO errors can only be raised by the CDRS driver itself. `Server` errors are the /// ones returned by a server via result error frames. #[derive(Debug, ThisError)] #[non_exhaustive] pub enum Error { /// Internal IO error. #[error("IO error: {0}")] Io(#[from] io::Error), /// Internal error that may be raised during `uuid::Uuid::from_bytes` #[error("Uuid parse error: {0}")] UuidParse(#[from] UuidError), /// General error #[error("General error: {0}")] General(String), /// Internal error that may be raised during `String::from_utf8` #[error("FromUtf8 error: {0}")] FromUtf8(#[from] FromUtf8Error), /// Internal error that may be raised during `str::from_utf8` #[error("Utf8 error: {0}")] Utf8(#[from] Utf8Error), /// Internal Compression/Decompression error. #[error("Compressor error: {0}")] Compression(#[from] CompressionError), /// Server error. #[error("Server {addr} error: {body:?}")] Server { body: ErrorBody, addr: SocketAddr }, /// Timed out waiting for an operation to complete. #[error("Timeout: {0}")] Timeout(String), /// Unknown consistency. #[error("Unknown consistency: {0}")] UnknownConsistency(CIntShort), /// Unknown server event.
#[error("Unknown server event: {0}")] UnknownServerEvent(String), /// Unexpected topology change event type. #[error("Unexpected topology change type: {0}")] UnexpectedTopologyChangeType(String), /// Unexpected status change event type. #[error("Unexpected status change type: {0}")] UnexpectedStatusChangeType(String), /// Unexpected schema change event type. #[error("Unexpected schema change type: {0}")] UnexpectedSchemaChangeType(String), /// Unexpected schema change event target. #[error("Unexpected schema change target: {0}")] UnexpectedSchemaChangeTarget(String), /// Unexpected error code in additional error info. #[error("Unexpected error code: {0}")] UnexpectedErrorCode(CInt), /// Unexpected write type. #[error("Unexpected write type: {0}")] UnexpectedWriteType(String), /// Expected a request opcode, got something else. #[error("Opcode is not a request: {0}")] NonRequestOpcode(Opcode), /// Expected a response opcode, got something else. #[error("Opcode is not a response: {0}")] NonResponseOpcode(Opcode), /// Unexpected result kind. #[error("Unexpected result kind: {0}")] UnexpectedResultKind(CInt), /// Unexpected column type. #[error("Unexpected column type: {0}")] UnexpectedColumnType(CIntShort), /// Invalid replication strategy format found for the given keyspace. #[error("Invalid replication format for: {keyspace}")] InvalidReplicationFormat { keyspace: String }, /// Unexpected response to auth message. #[error("Unexpected auth response: {0}")] UnexpectedAuthResponse(Opcode), /// Unexpected startup response. #[error("Unexpected startup response: {0}")] UnexpectedStartupResponse(Opcode), /// Special error for cases when starting up a connection and protocol negotiation fails. There /// currently is no explicit server-side code for this, so the information must be inferred from /// the returned error response.
#[error("Invalid protocol used when communicating with a node: {0}")] InvalidProtocol(SocketAddr), } pub fn column_is_empty_err<T: Display>(column_name: T) -> Error { Error::General(format!("Column or Udt property '{column_name}' is empty")) } impl From<String> for Error { fn from(err: String) -> Error { Error::General(err) } } impl From<&str> for Error { fn from(err: &str) -> Error { Error::General(err.to_string()) } } impl Clone for Error { fn clone(&self) -> Self { match self { Error::Io(error) => Error::Io(io::Error::new( error.kind(), error .get_ref() .map(|error| error.to_string()) .unwrap_or_default(), )), Error::UuidParse(error) => Error::UuidParse(error.clone()), Error::General(error) => Error::General(error.clone()), Error::FromUtf8(error) => Error::FromUtf8(error.clone()), Error::Utf8(error) => Error::Utf8(*error), Error::Compression(error) => Error::Compression(error.clone()), Error::Server { body, addr } => Error::Server { body: body.clone(), addr: *addr, }, Error::Timeout(error) => Error::Timeout(error.clone()), Error::UnknownConsistency(value) => Error::UnknownConsistency(*value), Error::UnknownServerEvent(value) => Error::UnknownServerEvent(value.clone()), Error::UnexpectedTopologyChangeType(value) => { Error::UnexpectedTopologyChangeType(value.clone()) } Error::UnexpectedStatusChangeType(value) => { Error::UnexpectedStatusChangeType(value.clone()) } Error::UnexpectedSchemaChangeType(value) => { Error::UnexpectedSchemaChangeType(value.clone()) } Error::UnexpectedSchemaChangeTarget(value) => { Error::UnexpectedSchemaChangeTarget(value.clone()) } Error::UnexpectedErrorCode(value) => Error::UnexpectedErrorCode(*value), Error::UnexpectedWriteType(value) => Error::UnexpectedWriteType(value.clone()), Error::NonRequestOpcode(value) => Error::NonRequestOpcode(*value), Error::NonResponseOpcode(value) => Error::NonResponseOpcode(*value), Error::UnexpectedResultKind(value) => Error::UnexpectedResultKind(*value), Error::UnexpectedColumnType(value) =>
Error::UnexpectedColumnType(*value), Error::InvalidReplicationFormat { keyspace } => Error::InvalidReplicationFormat { keyspace: keyspace.clone(), }, Error::UnexpectedAuthResponse(value) => Error::UnexpectedAuthResponse(*value), Error::UnexpectedStartupResponse(value) => Error::UnexpectedStartupResponse(*value), Error::InvalidProtocol(addr) => Error::InvalidProtocol(*addr), } } } ================================================ FILE: cassandra-protocol/src/events.rs ================================================ use crate::frame::events::{ SchemaChange as MessageSchemaChange, ServerEvent as MessageServerEvent, SimpleServerEvent as MessageSimpleServerEvent, }; /// Full server event, including all details about the change that occurred. pub type ServerEvent = MessageServerEvent; /// Simplified server event. Use it to represent an event /// the consumer wants to listen to. pub type SimpleServerEvent = MessageSimpleServerEvent; /// Re-export of `MessageSchemaChange`. pub type SchemaChange = MessageSchemaChange; ================================================ FILE: cassandra-protocol/src/frame/events.rs ================================================ use crate::frame::traits::FromCursor; use crate::frame::{Serialize, Version}; use crate::types::{from_cursor_str, from_cursor_string_list, serialize_str, CIntShort}; use crate::{error, Error}; use derive_more::Display; use std::cmp::PartialEq; use std::convert::TryFrom; use std::io::Cursor; use std::net::SocketAddr; // Event types const TOPOLOGY_CHANGE: &str = "TOPOLOGY_CHANGE"; const STATUS_CHANGE: &str = "STATUS_CHANGE"; const SCHEMA_CHANGE: &str = "SCHEMA_CHANGE"; // Topology changes const NEW_NODE: &str = "NEW_NODE"; const REMOVED_NODE: &str = "REMOVED_NODE"; // Status changes const UP: &str = "UP"; const DOWN: &str = "DOWN"; // Schema changes const CREATED: &str = "CREATED"; const UPDATED: &str = "UPDATED"; const DROPPED: &str = "DROPPED"; // Schema change targets const KEYSPACE: &str = "KEYSPACE"; const TABLE:
&str = "TABLE"; const TYPE: &str = "TYPE"; const FUNCTION: &str = "FUNCTION"; const AGGREGATE: &str = "AGGREGATE"; /// Simplified `ServerEvent` that does not contain details /// about a concrete change. It may be useful for subscription /// when you need only the string representation of an event. #[derive(Debug, PartialEq, Copy, Clone, Ord, PartialOrd, Eq, Hash)] #[non_exhaustive] pub enum SimpleServerEvent { TopologyChange, StatusChange, SchemaChange, } impl SimpleServerEvent { pub fn as_str(&self) -> &'static str { match *self { SimpleServerEvent::TopologyChange => TOPOLOGY_CHANGE, SimpleServerEvent::StatusChange => STATUS_CHANGE, SimpleServerEvent::SchemaChange => SCHEMA_CHANGE, } } } impl From<ServerEvent> for SimpleServerEvent { fn from(event: ServerEvent) -> SimpleServerEvent { match event { ServerEvent::TopologyChange(_) => SimpleServerEvent::TopologyChange, ServerEvent::StatusChange(_) => SimpleServerEvent::StatusChange, ServerEvent::SchemaChange(_) => SimpleServerEvent::SchemaChange, } } } impl<'a> From<&'a ServerEvent> for SimpleServerEvent { fn from(event: &'a ServerEvent) -> SimpleServerEvent { match *event { ServerEvent::TopologyChange(_) => SimpleServerEvent::TopologyChange, ServerEvent::StatusChange(_) => SimpleServerEvent::StatusChange, ServerEvent::SchemaChange(_) => SimpleServerEvent::SchemaChange, } } } impl TryFrom<&str> for SimpleServerEvent { type Error = error::Error; fn try_from(value: &str) -> Result<Self, Self::Error> { match value { TOPOLOGY_CHANGE => Ok(SimpleServerEvent::TopologyChange), STATUS_CHANGE => Ok(SimpleServerEvent::StatusChange), SCHEMA_CHANGE => Ok(SimpleServerEvent::SchemaChange), value => Err(Error::UnknownServerEvent(value.into())), } } } impl PartialEq<ServerEvent> for SimpleServerEvent { fn eq(&self, full_event: &ServerEvent) -> bool { self == &SimpleServerEvent::from(full_event) } } /// Full server event that contains all details about a concrete change.
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] #[non_exhaustive] pub enum ServerEvent { /// Events related to changes in the cluster topology. TopologyChange(TopologyChange), /// Events related to changes in node status. StatusChange(StatusChange), /// Events related to schema changes. SchemaChange(SchemaChange), } impl Serialize for ServerEvent { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { match &self { ServerEvent::TopologyChange(t) => { serialize_str(cursor, TOPOLOGY_CHANGE, version); t.serialize(cursor, version); } ServerEvent::StatusChange(s) => { serialize_str(cursor, STATUS_CHANGE, version); s.serialize(cursor, version); } ServerEvent::SchemaChange(s) => { serialize_str(cursor, SCHEMA_CHANGE, version); s.serialize(cursor, version); } } } } impl PartialEq<SimpleServerEvent> for ServerEvent { fn eq(&self, event: &SimpleServerEvent) -> bool { &SimpleServerEvent::from(self) == event } } impl FromCursor for ServerEvent { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> { let event_type = from_cursor_str(cursor)?; match event_type { TOPOLOGY_CHANGE => Ok(ServerEvent::TopologyChange(TopologyChange::from_cursor( cursor, version, )?)), STATUS_CHANGE => Ok(ServerEvent::StatusChange(StatusChange::from_cursor( cursor, version, )?)), SCHEMA_CHANGE => Ok(ServerEvent::SchemaChange(SchemaChange::from_cursor( cursor, version, )?)), _ => Err(Error::UnknownServerEvent(event_type.into())), } } } /// Events related to changes in the cluster topology. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct TopologyChange { pub change_type: TopologyChangeType, pub addr: SocketAddr, } impl Serialize for TopologyChange { //noinspection DuplicatedCode fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { self.change_type.serialize(cursor, version); self.addr.serialize(cursor, version); } } impl FromCursor for TopologyChange { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) ->
error::Result<Self> { let change_type = TopologyChangeType::from_cursor(cursor, version)?; let addr = SocketAddr::from_cursor(cursor, version)?; Ok(TopologyChange { change_type, addr }) } } #[derive(Debug, Copy, Clone, PartialEq, Ord, PartialOrd, Eq, Hash, Display)] #[non_exhaustive] pub enum TopologyChangeType { NewNode, RemovedNode, } impl Serialize for TopologyChangeType { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { match &self { TopologyChangeType::NewNode => serialize_str(cursor, NEW_NODE, version), TopologyChangeType::RemovedNode => serialize_str(cursor, REMOVED_NODE, version), } } } impl FromCursor for TopologyChangeType { fn from_cursor( cursor: &mut Cursor<&[u8]>, _version: Version, ) -> error::Result<Self> { from_cursor_str(cursor).and_then(|tc| match tc { NEW_NODE => Ok(TopologyChangeType::NewNode), REMOVED_NODE => Ok(TopologyChangeType::RemovedNode), _ => Err(Error::UnexpectedTopologyChangeType(tc.into())), }) } } /// Events related to changes in node status. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct StatusChange { pub change_type: StatusChangeType, pub addr: SocketAddr, } impl Serialize for StatusChange { //noinspection DuplicatedCode fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { self.change_type.serialize(cursor, version); self.addr.serialize(cursor, version); } } impl FromCursor for StatusChange { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> { let change_type = StatusChangeType::from_cursor(cursor, version)?; let addr = SocketAddr::from_cursor(cursor, version)?; Ok(StatusChange { change_type, addr }) } } #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Display)] #[non_exhaustive] pub enum StatusChangeType { Up, Down, } impl Serialize for StatusChangeType { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { match self { StatusChangeType::Up => serialize_str(cursor, UP, version), StatusChangeType::Down =>
serialize_str(cursor, DOWN, version), } } } impl FromCursor for StatusChangeType { fn from_cursor( cursor: &mut Cursor<&[u8]>, _version: Version, ) -> error::Result<Self> { from_cursor_str(cursor).and_then(|sct| match sct { UP => Ok(StatusChangeType::Up), DOWN => Ok(StatusChangeType::Down), _ => Err(Error::UnexpectedStatusChangeType(sct.into())), }) } } /// Events related to schema changes. #[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] pub struct SchemaChange { pub change_type: SchemaChangeType, pub target: SchemaChangeTarget, pub options: SchemaChangeOptions, } impl Serialize for SchemaChange { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { self.change_type.serialize(cursor, version); self.target.serialize(cursor, version); self.options.serialize(cursor, version); } } impl FromCursor for SchemaChange { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> { let change_type = SchemaChangeType::from_cursor(cursor, version)?; let target = SchemaChangeTarget::from_cursor(cursor, version)?; let options = SchemaChangeOptions::from_cursor_and_target(cursor, &target)?; Ok(SchemaChange { change_type, target, options, }) } } /// Represents the type of a schema change.
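Every event field read with `from_cursor_str` is a native protocol `[string]`: a big-endian `[short]` length followed by that many UTF-8 bytes, which is why the `NEW_NODE` test vector in this module starts with `0, 8`. A std-only sketch of that decoding step follows; `read_string`, `parse_topology_change_type` and the local `TopologyChangeType` enum are hypothetical stand-ins for the crate's cursor-based helpers, not its actual API.

```rust
use std::convert::TryInto;

/// Hypothetical stand-in for the crate's `TopologyChangeType`.
#[derive(Debug, PartialEq)]
enum TopologyChangeType {
    NewNode,
    RemovedNode,
}

/// Reads a protocol [string] (u16 big-endian length + UTF-8 bytes) and
/// returns it together with the unread remainder of the input.
fn read_string(bytes: &[u8]) -> Result<(&str, &[u8]), String> {
    if bytes.len() < 2 {
        return Err("truncated [string] length".to_string());
    }
    let len = u16::from_be_bytes(bytes[..2].try_into().unwrap()) as usize;
    let rest = &bytes[2..];
    if rest.len() < len {
        return Err(format!("[string] needs {} bytes, got {}", len, rest.len()));
    }
    let text = std::str::from_utf8(&rest[..len]).map_err(|e| e.to_string())?;
    Ok((text, &rest[len..]))
}

/// Maps the decoded string onto a change type, mirroring `from_cursor`.
fn parse_topology_change_type(bytes: &[u8]) -> Result<TopologyChangeType, String> {
    match read_string(bytes)?.0 {
        "NEW_NODE" => Ok(TopologyChangeType::NewNode),
        "REMOVED_NODE" => Ok(TopologyChangeType::RemovedNode),
        other => Err(format!("unexpected topology change type: {}", other)),
    }
}
```

Returning the unread remainder plays the role the `Cursor` position plays in the crate, so consecutive fields such as a change type followed by an address can be read in sequence.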
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Display)] #[non_exhaustive] pub enum SchemaChangeType { Created, Updated, Dropped, } impl Serialize for SchemaChangeType { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { match self { SchemaChangeType::Created => serialize_str(cursor, CREATED, version), SchemaChangeType::Updated => serialize_str(cursor, UPDATED, version), SchemaChangeType::Dropped => serialize_str(cursor, DROPPED, version), } } } impl FromCursor for SchemaChangeType { fn from_cursor( cursor: &mut Cursor<&[u8]>, _version: Version, ) -> error::Result<Self> { from_cursor_str(cursor).and_then(|ct| match ct { CREATED => Ok(SchemaChangeType::Created), UPDATED => Ok(SchemaChangeType::Updated), DROPPED => Ok(SchemaChangeType::Dropped), _ => Err(Error::UnexpectedSchemaChangeType(ct.into())), }) } } /// The kind of schema object a change applies to. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Display)] #[non_exhaustive] pub enum SchemaChangeTarget { Keyspace, Table, Type, Function, Aggregate, } impl Serialize for SchemaChangeTarget { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { match self { SchemaChangeTarget::Keyspace => serialize_str(cursor, KEYSPACE, version), SchemaChangeTarget::Table => serialize_str(cursor, TABLE, version), SchemaChangeTarget::Type => serialize_str(cursor, TYPE, version), SchemaChangeTarget::Function => serialize_str(cursor, FUNCTION, version), SchemaChangeTarget::Aggregate => serialize_str(cursor, AGGREGATE, version), } } } impl FromCursor for SchemaChangeTarget { fn from_cursor( cursor: &mut Cursor<&[u8]>, _version: Version, ) -> error::Result<Self> { from_cursor_str(cursor).and_then(|t| match t { KEYSPACE => Ok(SchemaChangeTarget::Keyspace), TABLE => Ok(SchemaChangeTarget::Table), TYPE => Ok(SchemaChangeTarget::Type), FUNCTION => Ok(SchemaChangeTarget::Function), AGGREGATE => Ok(SchemaChangeTarget::Aggregate), _ =>
Err(Error::UnexpectedSchemaChangeTarget(t.into())), }) } } /// Information about changes made. #[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] #[non_exhaustive] pub enum SchemaChangeOptions { /// Changes related to keyspaces. Contains keyspace name. Keyspace(String), /// Changes related to tables. Contains keyspace and table names. TableType(String, String), /// Changes related to functions and aggregations. Contains: /// * keyspace containing the user defined function/aggregate /// * the function/aggregate name /// * list of strings, one string for each argument type (as CQL type) FunctionAggregate(String, String, Vec), } impl Serialize for SchemaChangeOptions { fn serialize(&self, cursor: &mut Cursor<&mut Vec>, version: Version) { match self { SchemaChangeOptions::Keyspace(ks) => { serialize_str(cursor, ks, version); } SchemaChangeOptions::TableType(ks, t) => { serialize_str(cursor, ks, version); serialize_str(cursor, t, version); } SchemaChangeOptions::FunctionAggregate(ks, fa_name, list) => { serialize_str(cursor, ks, version); serialize_str(cursor, fa_name, version); let len = list.len() as CIntShort; len.serialize(cursor, version); list.iter().for_each(|x| serialize_str(cursor, x, version)); } } } } impl SchemaChangeOptions { fn from_cursor_and_target( cursor: &mut Cursor<&[u8]>, target: &SchemaChangeTarget, ) -> error::Result { Ok(match *target { SchemaChangeTarget::Keyspace => SchemaChangeOptions::from_cursor_keyspace(cursor)?, SchemaChangeTarget::Table | SchemaChangeTarget::Type => { SchemaChangeOptions::from_cursor_table_type(cursor)? } SchemaChangeTarget::Function | SchemaChangeTarget::Aggregate => { SchemaChangeOptions::from_cursor_function_aggregate(cursor)? 
} }) } fn from_cursor_keyspace(cursor: &mut Cursor<&[u8]>) -> error::Result { Ok(SchemaChangeOptions::Keyspace( from_cursor_str(cursor)?.to_string(), )) } fn from_cursor_table_type(cursor: &mut Cursor<&[u8]>) -> error::Result { let keyspace = from_cursor_str(cursor)?.to_string(); let name = from_cursor_str(cursor)?.to_string(); Ok(SchemaChangeOptions::TableType(keyspace, name)) } fn from_cursor_function_aggregate( cursor: &mut Cursor<&[u8]>, ) -> error::Result { let keyspace = from_cursor_str(cursor)?.to_string(); let name = from_cursor_str(cursor)?.to_string(); let types = from_cursor_string_list(cursor)?; Ok(SchemaChangeOptions::FunctionAggregate( keyspace, name, types, )) } } #[cfg(test)] fn test_encode_decode(bytes: &[u8], expected: ServerEvent) { let mut ks: Cursor<&[u8]> = Cursor::new(bytes); let event = ServerEvent::from_cursor(&mut ks, Version::V4).unwrap(); assert_eq!(expected, event); let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected.serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } #[cfg(test)] mod topology_change_type_test { use super::*; use crate::frame::traits::FromCursor; use std::io::Cursor; #[test] fn from_cursor() { let a = &[0, 8, 78, 69, 87, 95, 78, 79, 68, 69]; let mut new_node: Cursor<&[u8]> = Cursor::new(a); assert_eq!( TopologyChangeType::from_cursor(&mut new_node, Version::V4).unwrap(), TopologyChangeType::NewNode ); let b = &[0, 12, 82, 69, 77, 79, 86, 69, 68, 95, 78, 79, 68, 69]; let mut removed_node: Cursor<&[u8]> = Cursor::new(b); assert_eq!( TopologyChangeType::from_cursor(&mut removed_node, Version::V4).unwrap(), TopologyChangeType::RemovedNode ); } #[test] fn serialize() { { let a = &[0, 8, 78, 69, 87, 95, 78, 79, 68, 69]; let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); let new_node = TopologyChangeType::NewNode; new_node.serialize(&mut cursor, Version::V4); assert_eq!(buffer, a); } { let b = &[0, 12, 82, 69, 77, 79, 86, 69, 68, 95, 78, 79, 68, 69]; let mut 
buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); let removed_node = TopologyChangeType::RemovedNode; removed_node.serialize(&mut cursor, Version::V4); assert_eq!(buffer, b); } } #[test] #[should_panic] fn from_cursor_wrong() { let a = &[0, 1, 78]; let mut wrong: Cursor<&[u8]> = Cursor::new(a); let _ = TopologyChangeType::from_cursor(&mut wrong, Version::V4).unwrap(); } } #[cfg(test)] mod status_change_type_test { use super::*; use crate::frame::traits::FromCursor; use std::io::Cursor; #[test] fn from_cursor() { let a = &[0, 2, 85, 80]; let mut up: Cursor<&[u8]> = Cursor::new(a); assert_eq!( StatusChangeType::from_cursor(&mut up, Version::V4).unwrap(), StatusChangeType::Up ); let b = &[0, 4, 68, 79, 87, 78]; let mut down: Cursor<&[u8]> = Cursor::new(b); assert_eq!( StatusChangeType::from_cursor(&mut down, Version::V4).unwrap(), StatusChangeType::Down ); } #[test] fn serialize() { { let a = &[0, 2, 85, 80]; let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); let up = StatusChangeType::Up; up.serialize(&mut cursor, Version::V4); assert_eq!(buffer, a); } { let b = &[0, 4, 68, 79, 87, 78]; let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); let down = StatusChangeType::Down; down.serialize(&mut cursor, Version::V4); assert_eq!(buffer, b); } } #[test] fn from_cursor_wrong() { let a = &[0, 1, 78]; let mut wrong: Cursor<&[u8]> = Cursor::new(a); let err = StatusChangeType::from_cursor(&mut wrong, Version::V4).unwrap_err(); assert!(matches!(err, Error::UnexpectedStatusChangeType(_))); } } #[cfg(test)] mod schema_change_type_test { use super::*; use crate::frame::traits::FromCursor; use std::io::Cursor; #[test] fn from_cursor() { let a = &[0, 7, 67, 82, 69, 65, 84, 69, 68]; let mut created: Cursor<&[u8]> = Cursor::new(a); assert_eq!( SchemaChangeType::from_cursor(&mut created, Version::V4).unwrap(), SchemaChangeType::Created ); let b = &[0, 7, 85, 80, 68, 65, 84, 69, 68]; let mut updated: Cursor<&[u8]> = Cursor::new(b); 
assert_eq!( SchemaChangeType::from_cursor(&mut updated, Version::V4).unwrap(), SchemaChangeType::Updated ); let c = &[0, 7, 68, 82, 79, 80, 80, 69, 68]; let mut dropped: Cursor<&[u8]> = Cursor::new(c); assert_eq!( SchemaChangeType::from_cursor(&mut dropped, Version::V4).unwrap(), SchemaChangeType::Dropped ); } #[test] fn serialize() { { let a = &[0, 7, 67, 82, 69, 65, 84, 69, 68]; let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); let created = SchemaChangeType::Created; created.serialize(&mut cursor, Version::V4); assert_eq!(buffer, a); } { let b = &[0, 7, 85, 80, 68, 65, 84, 69, 68]; let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); let updated = SchemaChangeType::Updated; updated.serialize(&mut cursor, Version::V4); assert_eq!(buffer, b); } { let c = &[0, 7, 68, 82, 79, 80, 80, 69, 68]; let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); let dropped = SchemaChangeType::Dropped; dropped.serialize(&mut cursor, Version::V4); assert_eq!(buffer, c); } } #[test] #[should_panic] fn from_cursor_wrong() { let a = &[0, 1, 78]; let mut wrong: Cursor<&[u8]> = Cursor::new(a); let _ = SchemaChangeType::from_cursor(&mut wrong, Version::V4).unwrap(); } } #[cfg(test)] mod schema_change_target_test { use super::*; use crate::frame::traits::FromCursor; use std::io::Cursor; #[test] #[allow(clippy::many_single_char_names)] fn schema_change_target() { { let bytes = &[0, 8, 75, 69, 89, 83, 80, 65, 67, 69]; let mut keyspace: Cursor<&[u8]> = Cursor::new(bytes); assert_eq!( SchemaChangeTarget::from_cursor(&mut keyspace, Version::V4).unwrap(), SchemaChangeTarget::Keyspace ); } let b = &[0, 5, 84, 65, 66, 76, 69]; let mut table: Cursor<&[u8]> = Cursor::new(b); assert_eq!( SchemaChangeTarget::from_cursor(&mut table, Version::V4).unwrap(), SchemaChangeTarget::Table ); let c = &[0, 4, 84, 89, 80, 69]; let mut _type: Cursor<&[u8]> = Cursor::new(c); assert_eq!( SchemaChangeTarget::from_cursor(&mut _type, Version::V4).unwrap(), 
SchemaChangeTarget::Type ); let d = &[0, 8, 70, 85, 78, 67, 84, 73, 79, 78]; let mut function: Cursor<&[u8]> = Cursor::new(d); assert_eq!( SchemaChangeTarget::from_cursor(&mut function, Version::V4).unwrap(), SchemaChangeTarget::Function ); let e = &[0, 9, 65, 71, 71, 82, 69, 71, 65, 84, 69]; let mut aggregate: Cursor<&[u8]> = Cursor::new(e); assert_eq!( SchemaChangeTarget::from_cursor(&mut aggregate, Version::V4).unwrap(), SchemaChangeTarget::Aggregate ); } #[test] fn serialize() { { let a = &[0, 8, 75, 69, 89, 83, 80, 65, 67, 69]; let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); let keyspace = SchemaChangeTarget::Keyspace; keyspace.serialize(&mut cursor, Version::V4); assert_eq!(buffer, a); } { let b = &[0, 5, 84, 65, 66, 76, 69]; let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); let table = SchemaChangeTarget::Table; table.serialize(&mut cursor, Version::V4); assert_eq!(buffer, b); } { let c = &[0, 4, 84, 89, 80, 69]; let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); let target_type = SchemaChangeTarget::Type; target_type.serialize(&mut cursor, Version::V4); assert_eq!(buffer, c); } { let d = &[0, 8, 70, 85, 78, 67, 84, 73, 79, 78]; let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); let function = SchemaChangeTarget::Function; function.serialize(&mut cursor, Version::V4); assert_eq!(buffer, d); } { let e = &[0, 9, 65, 71, 71, 82, 69, 71, 65, 84, 69]; let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); let aggregate = SchemaChangeTarget::Aggregate; aggregate.serialize(&mut cursor, Version::V4); assert_eq!(buffer, e); } } #[test] #[should_panic] fn from_cursor_wrong() { let a = &[0, 1, 78]; let mut wrong: Cursor<&[u8]> = Cursor::new(a); let _ = SchemaChangeTarget::from_cursor(&mut wrong, Version::V4).unwrap(); } } #[cfg(test)] mod server_event { use super::*; #[test] fn topology_change_new_node() { let bytes = &[ // topology change 0, 15, 84, 79, 80, 79, 76, 
79, 71, 89, 95, 67, 72, 65, 78, 71, 69, // new node 0, 8, 78, 69, 87, 95, 78, 79, 68, 69, // 4, 127, 0, 0, 1, 0, 0, 0, 1, // 127.0.0.1:1 ]; let expected = ServerEvent::TopologyChange(TopologyChange { change_type: TopologyChangeType::NewNode, addr: "127.0.0.1:1".parse().unwrap(), }); test_encode_decode(bytes, expected); } #[test] fn topology_change_removed_node() { let bytes = &[ // topology change 0, 15, 84, 79, 80, 79, 76, 79, 71, 89, 95, 67, 72, 65, 78, 71, 69, // removed node 0, 12, 82, 69, 77, 79, 86, 69, 68, 95, 78, 79, 68, 69, // 4, 127, 0, 0, 1, 0, 0, 0, 1, // 127.0.0.1:1 ]; let expected = ServerEvent::TopologyChange(TopologyChange { change_type: TopologyChangeType::RemovedNode, addr: "127.0.0.1:1".parse().unwrap(), }); test_encode_decode(bytes, expected); } #[test] fn status_change_up() { let bytes = &[ // status change 0, 13, 83, 84, 65, 84, 85, 83, 95, 67, 72, 65, 78, 71, 69, // up 0, 2, 85, 80, // 4, 127, 0, 0, 1, 0, 0, 0, 1, // 127.0.0.1:1 ]; let expected = ServerEvent::StatusChange(StatusChange { change_type: StatusChangeType::Up, addr: "127.0.0.1:1".parse().unwrap(), }); test_encode_decode(bytes, expected); } #[test] fn status_change_down() { let bytes = &[ // status change 0, 13, 83, 84, 65, 84, 85, 83, 95, 67, 72, 65, 78, 71, 69, // down 0, 4, 68, 79, 87, 78, // 4, 127, 0, 0, 1, 0, 0, 0, 1, // 127.0.0.1:1 ]; let expected = ServerEvent::StatusChange(StatusChange { change_type: StatusChangeType::Down, addr: "127.0.0.1:1".parse().unwrap(), }); test_encode_decode(bytes, expected); } #[test] fn schema_change_created() { // keyspace { let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // created 0, 7, 67, 82, 69, 65, 84, 69, 68, // keyspace 0, 8, 75, 69, 89, 83, 80, 65, 67, 69, // my_ks 0, 5, 109, 121, 95, 107, 115, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Created, target: SchemaChangeTarget::Keyspace, options: SchemaChangeOptions::Keyspace("my_ks".to_string()), }); 
test_encode_decode(bytes, expected); } // table { let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // created 0, 7, 67, 82, 69, 65, 84, 69, 68, // table 0, 5, 84, 65, 66, 76, 69, // my_ks 0, 5, 109, 121, 95, 107, 115, // my_table 0, 8, 109, 121, 95, 116, 97, 98, 108, 101, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Created, target: SchemaChangeTarget::Table, options: SchemaChangeOptions::TableType( "my_ks".to_string(), "my_table".to_string(), ), }); test_encode_decode(bytes, expected); } // type { let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // created 0, 7, 67, 82, 69, 65, 84, 69, 68, // type 0, 4, 84, 89, 80, 69, // my_ks 0, 5, 109, 121, 95, 107, 115, // my_table 0, 8, 109, 121, 95, 116, 97, 98, 108, 101, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Created, target: SchemaChangeTarget::Type, options: SchemaChangeOptions::TableType( "my_ks".to_string(), "my_table".to_string(), ), }); test_encode_decode(bytes, expected); } { // function let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // created 0, 7, 67, 82, 69, 65, 84, 69, 68, // function 0, 8, 70, 85, 78, 67, 84, 73, 79, 78, // my_ks 0, 5, 109, 121, 95, 107, 115, // name 0, 4, 110, 97, 109, 101, // empty list of parameters 0, 0, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Created, target: SchemaChangeTarget::Function, options: SchemaChangeOptions::FunctionAggregate( "my_ks".to_string(), "name".to_string(), Vec::new(), ), }); test_encode_decode(bytes, expected); } { // aggregate let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // created 0, 7, 67, 82, 69, 65, 84, 69, 68, // aggregate 0, 9, 65, 71, 71, 82, 69, 71, 65, 84, 69, // my_ks 0, 5, 109, 121, 95, 107, 115, // name 0, 4, 110, 97, 109, 101, // empty list 
of parameters 0, 0, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Created, target: SchemaChangeTarget::Aggregate, options: SchemaChangeOptions::FunctionAggregate( "my_ks".to_string(), "name".to_string(), Vec::new(), ), }); test_encode_decode(bytes, expected); } } #[test] fn schema_change_updated() { // keyspace { let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // updated 0, 7, 85, 80, 68, 65, 84, 69, 68, // keyspace 0, 8, 75, 69, 89, 83, 80, 65, 67, 69, // my_ks 0, 5, 109, 121, 95, 107, 115, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Updated, target: SchemaChangeTarget::Keyspace, options: SchemaChangeOptions::Keyspace("my_ks".to_string()), }); test_encode_decode(bytes, expected); } // table { let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // updated 0, 7, 85, 80, 68, 65, 84, 69, 68, // table 0, 5, 84, 65, 66, 76, 69, // my_ks 0, 5, 109, 121, 95, 107, 115, // my_table 0, 8, 109, 121, 95, 116, 97, 98, 108, 101, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Updated, target: SchemaChangeTarget::Table, options: SchemaChangeOptions::TableType( "my_ks".to_string(), "my_table".to_string(), ), }); test_encode_decode(bytes, expected); } // type { let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // updated 0, 7, 85, 80, 68, 65, 84, 69, 68, // type 0, 4, 84, 89, 80, 69, // my_ks 0, 5, 109, 121, 95, 107, 115, // my_table 0, 8, 109, 121, 95, 116, 97, 98, 108, 101, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Updated, target: SchemaChangeTarget::Type, options: SchemaChangeOptions::TableType( "my_ks".to_string(), "my_table".to_string(), ), }); test_encode_decode(bytes, expected); } // function { let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 
71, 69, // updated 0, 7, 85, 80, 68, 65, 84, 69, 68, // function 0, 8, 70, 85, 78, 67, 84, 73, 79, 78, // my_ks 0, 5, 109, 121, 95, 107, 115, // name 0, 4, 110, 97, 109, 101, // empty list of parameters 0, 0, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Updated, target: SchemaChangeTarget::Function, options: SchemaChangeOptions::FunctionAggregate( "my_ks".to_string(), "name".to_string(), Vec::new(), ), }); test_encode_decode(bytes, expected); } // aggregate { let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // updated 0, 7, 85, 80, 68, 65, 84, 69, 68, // aggregate 0, 9, 65, 71, 71, 82, 69, 71, 65, 84, 69, // my_ks 0, 5, 109, 121, 95, 107, 115, // name 0, 4, 110, 97, 109, 101, // empty list of parameters 0, 0, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Updated, target: SchemaChangeTarget::Aggregate, options: SchemaChangeOptions::FunctionAggregate( "my_ks".to_string(), "name".to_string(), Vec::new(), ), }); test_encode_decode(bytes, expected); } } #[test] fn schema_change_dropped() { // keyspace { let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // dropped 0, 7, 68, 82, 79, 80, 80, 69, 68, // keyspace 0, 8, 75, 69, 89, 83, 80, 65, 67, 69, // my_ks 0, 5, 109, 121, 95, 107, 115, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Dropped, target: SchemaChangeTarget::Keyspace, options: SchemaChangeOptions::Keyspace("my_ks".to_string()), }); test_encode_decode(bytes, expected); } // table { let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // dropped 0, 7, 68, 82, 79, 80, 80, 69, 68, // table 0, 5, 84, 65, 66, 76, 69, // my_ks 0, 5, 109, 121, 95, 107, 115, // my_table 0, 8, 109, 121, 95, 116, 97, 98, 108, 101, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: 
SchemaChangeType::Dropped, target:
SchemaChangeTarget::Table, options: SchemaChangeOptions::TableType( "my_ks".to_string(), "my_table".to_string(), ), }); test_encode_decode(bytes, expected); } // type { let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // dropped 0, 7, 68, 82, 79, 80, 80, 69, 68, // type 0, 4, 84, 89, 80, 69, // my_ks 0, 5, 109, 121, 95, 107, 115, // my_table 0, 8, 109, 121, 95, 116, 97, 98, 108, 101, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Dropped, target: SchemaChangeTarget::Type, options: SchemaChangeOptions::TableType( "my_ks".to_string(), "my_table".to_string(), ), }); test_encode_decode(bytes, expected); } // function { let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // dropped 0, 7, 68, 82, 79, 80, 80, 69, 68, // function 0, 8, 70, 85, 78, 67, 84, 73, 79, 78, // my_ks 0, 5, 109, 121, 95, 107, 115, // name 0, 4, 110, 97, 109, 101, // empty list of parameters 0, 0, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Dropped, target: SchemaChangeTarget::Function, options: SchemaChangeOptions::FunctionAggregate( "my_ks".to_string(), "name".to_string(), Vec::new(), ), }); test_encode_decode(bytes, expected); } // aggregate { let bytes = &[ // schema change 0, 13, 83, 67, 72, 69, 77, 65, 95, 67, 72, 65, 78, 71, 69, // dropped 0, 7, 68, 82, 79, 80, 80, 69, 68, // aggregate 0, 9, 65, 71, 71, 82, 69, 71, 65, 84, 69, // my_ks 0, 5, 109, 121, 95, 107, 115, // name 0, 4, 110, 97, 109, 101, // empty list of parameters 0, 0, ]; let expected = ServerEvent::SchemaChange(SchemaChange { change_type: SchemaChangeType::Dropped, target: SchemaChangeTarget::Aggregate, options: SchemaChangeOptions::FunctionAggregate( "my_ks".to_string(), "name".to_string(), Vec::new(), ), }); test_encode_decode(bytes, expected); } } } ================================================ FILE: cassandra-protocol/src/frame/frame_decoder.rs
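The decoder in this file gates on two lengths: the frame's own header and the envelope header carried inside the payload. The standalone sketch below illustrates both checks. It assumes the bit layout that `UncompressedFrameDecoder::try_decode_frame` reads: a 6-byte little-endian header with the payload length in bits 0..17, the self-contained flag at bit 17, and a CRC24 in bits 24..48. It also assumes a 9-byte envelope header whose big-endian body length sits at bytes 5..9, matching `extract_expected_payload_len`. `FrameHeader`, `parse_uncompressed_header` and `has_complete_envelope` are hypothetical names, not part of the crate. CRC verification is deliberately omitted.

```rust
/// Fields of a v5 uncompressed frame header (hypothetical helper type).
#[derive(Debug, PartialEq)]
struct FrameHeader {
    payload_len: usize,
    self_contained: bool,
    stored_crc24: u32,
}

// Assumed sizes: 6-byte uncompressed frame header, 9-byte envelope header.
const UNCOMPRESSED_HEADER_LEN: usize = 6;
const ENVELOPE_HEADER_LEN: usize = 9;

/// Reads the 48-bit little-endian header; returns None until 6 bytes are buffered.
/// The stored CRC24 is extracted but NOT verified in this sketch.
fn parse_uncompressed_header(buf: &[u8]) -> Option<FrameHeader> {
    if buf.len() < UNCOMPRESSED_HEADER_LEN {
        return None;
    }
    // Assemble the header byte by byte, least significant byte first.
    let mut header: u64 = 0;
    for (i, byte) in buf[..UNCOMPRESSED_HEADER_LEN].iter().enumerate() {
        header |= (*byte as u64) << (8 * i);
    }
    Some(FrameHeader {
        payload_len: (header & 0x1ffff) as usize,          // bits 0..17
        self_contained: (header & (1 << 17)) != 0,         // bit 17
        stored_crc24: ((header >> 24) & 0xff_ffff) as u32, // bits 24..48
    })
}

/// True once `buf` holds a whole envelope: 9 header bytes plus the body
/// whose big-endian length is stored at bytes 5..9.
fn has_complete_envelope(buf: &[u8]) -> bool {
    if buf.len() < ENVELOPE_HEADER_LEN {
        return false;
    }
    let body_len = u32::from_be_bytes([buf[5], buf[6], buf[7], buf[8]]) as usize;
    buf.len() >= ENVELOPE_HEADER_LEN + body_len
}

fn main() {
    // payload_len = 300, self-contained bit set, CRC field = 0xABCDEF.
    let raw: u64 = 300 | (1 << 17) | (0xAB_CDEF << 24);
    let bytes = raw.to_le_bytes();
    let header = parse_uncompressed_header(&bytes[..6]).unwrap();
    println!("{:?}", header);

    // Envelope header (version, flags, stream, opcode, length) announcing a 3-byte body.
    let mut envelope = vec![0x05u8, 0, 0, 1, 0x07, 0, 0, 0, 3];
    println!("{}", has_complete_envelope(&envelope)); // false: body not buffered yet
    envelope.extend_from_slice(&[1, 2, 3]);
    println!("{}", has_complete_envelope(&envelope)); // true
}
```

The frame header is little-endian while the envelope body length is big-endian. Reading either one with the wrong byte order is an easy way to mis-frame the stream.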
================================================ use crate::compression::{Compression, CompressionError}; use crate::crc::{crc24, crc32}; use crate::error::{Error, Result}; use crate::frame::{ Envelope, ParseEnvelopeError, COMPRESSED_FRAME_HEADER_LENGTH, ENVELOPE_HEADER_LEN, FRAME_TRAILER_LENGTH, MAX_FRAME_SIZE, PAYLOAD_SIZE_LIMIT, UNCOMPRESSED_FRAME_HEADER_LENGTH, }; use lz4_flex::decompress; use std::convert::TryInto; use std::io; #[inline] fn create_unexpected_self_contained_error() -> Error { "Found self-contained frame while waiting for non self-contained continuation!".into() } #[inline] fn create_header_crc_mismatch_error(computed_crc: i32, header_crc24: i32) -> Error { format!("Header CRC mismatch - expected {header_crc24}, found {computed_crc}.",).into() } #[inline] fn create_payload_crc_mismatch_error(computed_crc: u32, payload_crc32: u32) -> Error { format!("Payload CRC mismatch - read {payload_crc32}, computed {computed_crc}.",).into() } fn extract_envelopes(buffer: &[u8], compression: Compression) -> Result<(usize, Vec<Envelope>)> { let mut current_pos = 0; let mut envelopes = vec![]; loop { match Envelope::from_buffer(&buffer[current_pos..], compression) { Ok(envelope) => { envelopes.push(envelope.envelope); current_pos += envelope.envelope_len; } Err(ParseEnvelopeError::NotEnoughBytes) => break, Err(error) => return Err(error.to_string().into()), } } Ok((current_pos, envelopes)) } fn try_decode_envelopes_with_spare_data( buffer: &mut Vec<u8>, compression: Compression, ) -> Result<(Vec<Envelope>, Vec<u8>)> { let (current_pos, envelopes) = extract_envelopes(buffer.as_slice(), compression)?; Ok((envelopes, buffer.split_off(current_pos))) } fn try_decode_envelopes_without_spare_data(buffer: &[u8]) -> Result<Vec<Envelope>> { let (_, envelopes) = extract_envelopes(buffer, Compression::None)?; Ok(envelopes) } /// A decoder for frames. Since protocol v5, frames became "envelopes" and a frame now can contain /// multiple complete envelopes (self-contained frame) or a part of one bigger envelope.
pub trait FrameDecoder { /// Consumes some data and returns decoded envelopes. Decoders can be stateful, so data can be /// buffered until envelopes can be parsed. /// The buffer passed in should be cleared of consumed data by the decoder. fn consume(&mut self, data: &mut Vec<u8>, compression: Compression) -> Result<Vec<Envelope>>; } /// Pre-V5 frame decoder: envelopes are not wrapped in frames, so it decodes them directly from /// the incoming bytes and buffers any incomplete trailing envelope. #[derive(Clone, Debug)] pub struct LegacyFrameDecoder { buffer: Vec<u8>, } impl Default for LegacyFrameDecoder { fn default() -> Self { Self { buffer: Vec::with_capacity(MAX_FRAME_SIZE), } } } impl FrameDecoder for LegacyFrameDecoder { fn consume(&mut self, data: &mut Vec<u8>, compression: Compression) -> Result<Vec<Envelope>> { if self.buffer.is_empty() { // optimistic case let (envelopes, buffer) = try_decode_envelopes_with_spare_data(data, compression)?; self.buffer = buffer; data.clear(); return Ok(envelopes); } self.buffer.append(data); let (envelopes, buffer) = try_decode_envelopes_with_spare_data(&mut self.buffer, compression)?; self.buffer = buffer; Ok(envelopes) } } /// V5+ Lz4 frame decoder with CRC checksum verification.
#[derive(Clone, Debug, Default)] pub struct Lz4FrameDecoder { inner_decoder: GenericFrameDecoder, } impl FrameDecoder for Lz4FrameDecoder { //noinspection DuplicatedCode #[inline] fn consume(&mut self, data: &mut Vec<u8>, _compression: Compression) -> Result<Vec<Envelope>> { self.inner_decoder.consume(data, Self::try_decode_frame) } } impl Lz4FrameDecoder { fn try_decode_frame(buffer: &mut Vec<u8>) -> Result<Option<(bool, Vec<u8>)>> { let buffer_len = buffer.len(); if buffer_len < COMPRESSED_FRAME_HEADER_LENGTH { return Ok(None); } let header = i64::from_le_bytes(buffer[..COMPRESSED_FRAME_HEADER_LENGTH].try_into().unwrap()); let header_crc24 = ((header >> 40) & 0xffffff) as i32; let computed_crc = crc24(&header.to_le_bytes()[..5]); if header_crc24 != computed_crc { return Err(create_header_crc_mismatch_error(computed_crc, header_crc24)); } let compressed_length = (header & 0x1ffff) as usize; let compressed_payload_end = compressed_length + COMPRESSED_FRAME_HEADER_LENGTH; let frame_end = compressed_payload_end + FRAME_TRAILER_LENGTH; if buffer_len < frame_end { return Ok(None); } let compressed_payload_crc32 = u32::from_le_bytes( buffer[compressed_payload_end..frame_end] .try_into() .unwrap(), ); let computed_crc = crc32(&buffer[COMPRESSED_FRAME_HEADER_LENGTH..compressed_payload_end]); if compressed_payload_crc32 != computed_crc { return Err(create_payload_crc_mismatch_error( computed_crc, compressed_payload_crc32, )); } let self_contained = (header & (1 << 34)) != 0; let uncompressed_length = ((header >> 17) & 0x1ffff) as usize; if uncompressed_length == 0 { // protocol spec 2.2: // An uncompressed length of 0 signals that the compressed payload should be used as-is // and not decompressed.
let payload = buffer[COMPRESSED_FRAME_HEADER_LENGTH..compressed_payload_end].into(); *buffer = buffer.split_off(frame_end); return Ok(Some((self_contained, payload))); } decompress( &buffer[COMPRESSED_FRAME_HEADER_LENGTH..compressed_payload_end], uncompressed_length, ) .map_err(|error| CompressionError::Lz4(io::Error::other(error)).into()) .map(|payload| { *buffer = buffer.split_off(frame_end); Some((self_contained, payload)) }) } } /// V5+ uncompressed frame decoder with CRC checksum verification. #[derive(Clone, Debug, Default)] pub struct UncompressedFrameDecoder { inner_decoder: GenericFrameDecoder, } impl FrameDecoder for UncompressedFrameDecoder { //noinspection DuplicatedCode #[inline] fn consume(&mut self, data: &mut Vec<u8>, _compression: Compression) -> Result<Vec<Envelope>> { self.inner_decoder.consume(data, Self::try_decode_frame) } } impl UncompressedFrameDecoder { fn try_decode_frame(buffer: &mut Vec<u8>) -> Result<Option<(bool, Vec<u8>)>> { let buffer_len = buffer.len(); if buffer_len < UNCOMPRESSED_FRAME_HEADER_LENGTH { return Ok(None); } let header = if buffer_len >= 8 { i64::from_le_bytes(buffer[..8].try_into().unwrap()) & 0xffffffffffff } else { let mut header = 0; for (i, byte) in buffer[..UNCOMPRESSED_FRAME_HEADER_LENGTH] .iter() .enumerate() { header |= (*byte as i64) << (8 * i as i64); } header }; let header_crc24 = ((header >> 24) & 0xffffff) as i32; let computed_crc = crc24(&header.to_le_bytes()[..3]); if header_crc24 != computed_crc { return Err(create_header_crc_mismatch_error(computed_crc, header_crc24)); } let payload_length = (header & 0x1ffff) as usize; let payload_end = UNCOMPRESSED_FRAME_HEADER_LENGTH + payload_length; let frame_end = payload_end + FRAME_TRAILER_LENGTH; if buffer_len < frame_end { return Ok(None); } let payload_crc32 = u32::from_le_bytes(buffer[payload_end..frame_end].try_into().unwrap()); let computed_crc = crc32(&buffer[UNCOMPRESSED_FRAME_HEADER_LENGTH..payload_end]); if payload_crc32 != computed_crc { return 
Err(create_payload_crc_mismatch_error(
computed_crc, payload_crc32, )); } let self_contained = (header & (1 << 17)) != 0; let payload = buffer[UNCOMPRESSED_FRAME_HEADER_LENGTH..payload_end].into(); *buffer = buffer.split_off(frame_end); Ok(Some((self_contained, payload))) } } #[derive(Clone, Debug)] struct GenericFrameDecoder { frame_buffer: Vec<u8>, payload_buffer: Vec<u8>, expected_payload_len: Option<usize>, } impl Default for GenericFrameDecoder { fn default() -> Self { Self { frame_buffer: Vec::with_capacity(MAX_FRAME_SIZE), payload_buffer: Vec::with_capacity(PAYLOAD_SIZE_LIMIT * 2), expected_payload_len: None, } } } impl GenericFrameDecoder { fn extract_non_self_contained_envelopes(&mut self) -> Result<Vec<Envelope>> { if let Some(expected_payload_len) = self.expected_payload_len { // The Cassandra wire format encodes the body length in bytes 5..9 // of the envelope header (after version/flags/stream/opcode). The // FULL envelope on the wire is therefore ENVELOPE_HEADER_LEN bytes // of header plus expected_payload_len bytes of body, so the buffer // must contain at least that many bytes before we can decode it. // Without this header offset we would attempt to decode while the // body was still partial, lose the partial data on the buffer // truncation below, and mis-frame the next envelope. let total_envelope_len = ENVELOPE_HEADER_LEN + expected_payload_len; if self.payload_buffer.len() < total_envelope_len { return Ok(vec![]); } // Use extract_envelopes directly so we know exactly how many bytes // got consumed. drain(..consumed) preserves any trailing bytes // that may belong to the next envelope - they could legitimately // be there if a producer packed bytes from envelope N+1 into the // tail of the non-self-contained sequence for envelope N. The // previous code simply called clear() and silently lost them.
let (consumed, envelopes) = extract_envelopes(&self.payload_buffer, Compression::None)?; self.payload_buffer.drain(..consumed); // Reset envelope-tracking state so the next call re-parses the // body length from whatever envelope header remains at the start // of the buffer (or waits for one if the buffer is empty / a // partial header). Without this reset the next sequence would be // gated against the previous envelope's length. self.expected_payload_len = None; return Ok(envelopes); } if let Some(expected_payload_len) = self.extract_expected_payload_len() { self.expected_payload_len = Some(expected_payload_len); self.extract_non_self_contained_envelopes() } else { Ok(vec![]) } } fn extract_expected_payload_len(&self) -> Option<usize> { if self.payload_buffer.len() < ENVELOPE_HEADER_LEN { return None; } Some(i32::from_be_bytes(self.payload_buffer[5..9].try_into().unwrap()) as usize) } fn handle_frame( &mut self, envelopes: &mut Vec<Envelope>, self_contained: bool, frame: &mut Vec<u8>, ) -> Result<()> { if self_contained { if !self.payload_buffer.is_empty() { return Err(create_unexpected_self_contained_error()); } envelopes.append(&mut try_decode_envelopes_without_spare_data(frame)?); } else { self.payload_buffer.append(frame); envelopes.append(&mut self.extract_non_self_contained_envelopes()?); } Ok(()) } fn consume( &mut self, data: &mut Vec<u8>, try_decode_frame: impl Fn(&mut Vec<u8>) -> Result<Option<(bool, Vec<u8>)>>, ) -> Result<Vec<Envelope>> { let mut envelopes = vec![]; if self.frame_buffer.is_empty() { // optimistic case while !data.is_empty() { if let Some((self_contained, mut frame)) = try_decode_frame(data)? { self.handle_frame(&mut envelopes, self_contained, &mut frame)?; } else { // we have some data, but not a full frame yet self.frame_buffer.append(data); break; } } } else { self.frame_buffer.append(data); while !self.frame_buffer.is_empty() { if let Some((self_contained, mut frame)) = try_decode_frame(&mut self.frame_buffer)?
{ self.handle_frame(&mut envelopes, self_contained, &mut frame)?; } else { break; } } } Ok(envelopes) } } #[cfg(test)] mod tests { use super::*; use crate::frame::frame_encoder::{FrameEncoder, UncompressedFrameEncoder}; use crate::frame::{Direction, Envelope, Flags, Opcode, Version}; // Build an encoded V5 envelope whose body is `body_size` copies of `fill`. Tests that need // non-self-contained frames pick a body size larger than PAYLOAD_SIZE_LIMIT so the encoder is // forced to split the envelope across several frames. fn make_envelope(stream_id: i16, fill: u8, body_size: usize) -> Vec<u8> { Envelope { version: Version::V5, direction: Direction::Request, flags: Flags::empty(), opcode: Opcode::Query, stream_id, body: vec![fill; body_size], tracing_id: None, warnings: vec![], } .encode_with(Compression::None) .unwrap() } // Encode one envelope (which is too large to fit in a single frame) as a // sequence of non-self-contained frames. Each frame has its own header and // CRC trailer so we can simply concatenate them on the wire. fn encode_as_non_self_contained(envelope: &[u8]) -> Vec<u8> { let mut encoder = UncompressedFrameEncoder::default(); let mut wire = vec![]; let mut start = 0; while start < envelope.len() { let (consumed, frame) = encoder.finalize_non_self_contained(&envelope[start..]); wire.extend_from_slice(frame); start += consumed; encoder.reset(); } wire } #[test] fn decoder_recovers_two_consecutive_non_self_contained_envelopes() { // Use a body just over PAYLOAD_SIZE_LIMIT so each envelope spans two // frames; the second envelope is deliberately a different (smaller) // size to expose any stale `expected_payload_len` carryover.
let envelope_a = make_envelope(1, 0xAA, PAYLOAD_SIZE_LIMIT + 100); let envelope_b = make_envelope(2, 0xBB, PAYLOAD_SIZE_LIMIT + 50); let mut wire = encode_as_non_self_contained(&envelope_a); wire.extend_from_slice(&encode_as_non_self_contained(&envelope_b)); let mut decoder = UncompressedFrameDecoder::default(); let envelopes = decoder .consume(&mut wire, Compression::None) .expect("decoder must accept two consecutive non-self-contained envelopes"); // we expect to recover both envelopes intact, in order assert_eq!(envelopes.len(), 2, "should decode both envelopes"); assert_eq!(envelopes[0].stream_id, 1); assert_eq!(envelopes[0].body, vec![0xAA; PAYLOAD_SIZE_LIMIT + 100]); assert_eq!(envelopes[1].stream_id, 2); assert_eq!(envelopes[1].body, vec![0xBB; PAYLOAD_SIZE_LIMIT + 50]); } // The reviewer pointed out a defensive gap: when payload_buffer holds a // complete envelope plus the start of the next envelope (because a // hypothetical producer packed bytes across envelope boundaries inside // non-self-contained frames), the previous code called clear() on the // buffer after decoding the first envelope, losing the trailing bytes // that begin the next envelope. // // We construct that exact scenario by hand-packing one non-self-contained // frame whose payload is `envelope_a + envelope_b[..partial]`, followed // by another non-self-contained frame carrying `envelope_b[partial..]`. // A correct decoder must reconstruct both envelopes; a buggy one drops // envelope_b's prefix on `clear()` and then fails to parse the second // envelope from a misaligned start. #[test] fn decoder_preserves_trailing_bytes_across_non_self_contained_frames() { let envelope_a = make_envelope(1, 0xAA, 100); let envelope_b = make_envelope(2, 0xBB, 200); // Frame 1 carries envelope_a in full PLUS the first 100 bytes of // envelope_b. Both fit comfortably under PAYLOAD_SIZE_LIMIT, so // finalize_non_self_contained packs them into a single frame. 
let half_b = 100usize; let mut packed = envelope_a.clone(); packed.extend_from_slice(&envelope_b[..half_b]); let mut encoder = UncompressedFrameEncoder::default(); let (consumed, frame1_slice) = encoder.finalize_non_self_contained(&packed); assert_eq!( consumed, packed.len(), "test setup: whole packed slice must fit" ); let frame1: Vec = frame1_slice.to_vec(); encoder.reset(); // Frame 2 carries the remaining bytes of envelope_b. let (_, frame2_slice) = encoder.finalize_non_self_contained(&envelope_b[half_b..]); let frame2: Vec = frame2_slice.to_vec(); let mut wire = frame1; wire.extend_from_slice(&frame2); let mut decoder = UncompressedFrameDecoder::default(); let envelopes = decoder .consume(&mut wire, Compression::None) .expect("decoder must accept the cross-boundary packed frames"); assert_eq!(envelopes.len(), 2, "both envelopes must be recovered"); assert_eq!(envelopes[0].stream_id, 1); assert_eq!(envelopes[0].body, vec![0xAA; 100]); assert_eq!(envelopes[1].stream_id, 2); assert_eq!(envelopes[1].body, vec![0xBB; 200]); } } ================================================ FILE: cassandra-protocol/src/frame/frame_encoder.rs ================================================ use crate::crc::{crc24, crc32}; use crate::frame::{ COMPRESSED_FRAME_HEADER_LENGTH, FRAME_TRAILER_LENGTH, PAYLOAD_SIZE_LIMIT, UNCOMPRESSED_FRAME_HEADER_LENGTH, }; use lz4_flex::block::get_maximum_output_size; use lz4_flex::{compress, compress_into}; #[inline] fn put3b(buffer: &mut [u8], value: i32) { let value = value.to_le_bytes(); buffer[0] = value[0]; buffer[1] = value[1]; buffer[2] = value[2]; } #[inline] fn add_trailer(buffer: &mut Vec, payload_start: usize) { buffer.reserve(4); let crc = crc32(&buffer[payload_start..]).to_le_bytes(); buffer.push(crc[0]); buffer.push(crc[1]); buffer.push(crc[2]); buffer.push(crc[3]); } /// An encoder for frames. 
Since protocol *v5*, frames became "envelopes", and a frame can now contain /// multiple complete envelopes (a self-contained frame) or a part of one larger envelope. /// /// Encoders are stateful and can either: /// 1. Have multiple self-contained envelopes added. /// 2. Have a single non self-contained envelope added. /// /// In either case, the encoder is assumed to have the buffer ready to accept envelopes before /// adding the first one or after calling [`reset`](FrameEncoder::reset). At some point, the frame can become /// finalized (which is the only possible case when adding a non self-contained envelope) and the /// returned buffer is assumed to be immutable and ready to be sent. pub trait FrameEncoder { /// Determines if a payload of the given size can fit in the current frame buffer. fn can_fit(&self, len: usize) -> bool; /// Resets the internal state and prepares it for encoding envelopes. fn reset(&mut self); /// Adds a self-contained envelope to the current frame. fn add_envelope(&mut self, envelope: Vec<u8>); /// Finalizes a self-contained encoded frame in the buffer. fn finalize_self_contained(&mut self) -> &[u8]; /// Appends a large envelope and finalizes a non self-contained encoded frame in the buffer. /// Copies as much envelope data as fits in one frame and returns the number of envelope bytes /// consumed (i.e. where the next call should resume) together with the finalized frame. fn finalize_non_self_contained(&mut self, envelope: &[u8]) -> (usize, &[u8]); /// Checks if the current frame contains any envelopes. fn has_envelopes(&self) -> bool; } /// Pre-V5 frame encoder which simply encodes one envelope directly in the buffer.
#[derive(Clone, Debug, Default)] pub struct LegacyFrameEncoder { buffer: Vec, } impl FrameEncoder for LegacyFrameEncoder { #[inline] fn can_fit(&self, _len: usize) -> bool { // we support only one envelope per frame self.buffer.is_empty() } #[inline] fn reset(&mut self) { self.buffer.clear(); } #[inline] fn add_envelope(&mut self, envelope: Vec) { self.buffer = envelope; } #[inline] fn finalize_self_contained(&mut self) -> &[u8] { &self.buffer } #[inline] fn finalize_non_self_contained(&mut self, envelope: &[u8]) -> (usize, &[u8]) { // attempting to finalize a non self-contained frame via the legacy encoder - while this // will work, the legacy encoder doesn't distinguish such frames and all are considered // self-contained self.buffer.clear(); self.buffer.extend_from_slice(envelope); (envelope.len(), &self.buffer) } #[inline] fn has_envelopes(&self) -> bool { !self.buffer.is_empty() } } /// Post-V5 encoder with support for envelope frames with CRC checksum. #[derive(Clone, Debug)] pub struct UncompressedFrameEncoder { buffer: Vec, } impl FrameEncoder for UncompressedFrameEncoder { #[inline] fn can_fit(&self, len: usize) -> bool { (self.buffer.len() - UNCOMPRESSED_FRAME_HEADER_LENGTH).saturating_add(len) < PAYLOAD_SIZE_LIMIT } #[inline] fn reset(&mut self) { self.buffer.truncate(UNCOMPRESSED_FRAME_HEADER_LENGTH); } #[inline] fn add_envelope(&mut self, mut envelope: Vec) { self.buffer.append(&mut envelope); } fn finalize_self_contained(&mut self) -> &[u8] { self.write_header(true); add_trailer(&mut self.buffer, UNCOMPRESSED_FRAME_HEADER_LENGTH); &self.buffer } fn finalize_non_self_contained(&mut self, envelope: &[u8]) -> (usize, &[u8]) { let max_size = envelope.len().min(PAYLOAD_SIZE_LIMIT - 1); self.buffer.extend_from_slice(&envelope[..max_size]); self.buffer.reserve(FRAME_TRAILER_LENGTH); self.write_header(false); add_trailer(&mut self.buffer, UNCOMPRESSED_FRAME_HEADER_LENGTH); (max_size, &self.buffer) } #[inline] fn has_envelopes(&self) -> bool { 
self.buffer.len() > UNCOMPRESSED_FRAME_HEADER_LENGTH } } impl Default for UncompressedFrameEncoder { fn default() -> Self { let buffer = vec![0; UNCOMPRESSED_FRAME_HEADER_LENGTH]; Self { buffer } } } impl UncompressedFrameEncoder { fn write_header(&mut self, self_contained: bool) { let len = self.buffer.len(); debug_assert!( len < (PAYLOAD_SIZE_LIMIT + UNCOMPRESSED_FRAME_HEADER_LENGTH), "len: {} max: {}", len, PAYLOAD_SIZE_LIMIT + UNCOMPRESSED_FRAME_HEADER_LENGTH ); let mut len = (len - UNCOMPRESSED_FRAME_HEADER_LENGTH) as u64; if self_contained { len |= 1 << 17; } put3b(self.buffer.as_mut_slice(), len as i32); put3b(&mut self.buffer[3..], crc24(&len.to_le_bytes()[..3])); } } /// Post-V5 Lz4 encoder with support for envelope frames with CRC checksum. #[derive(Clone, Debug)] pub struct Lz4FrameEncoder { buffer: Vec, } impl FrameEncoder for Lz4FrameEncoder { #[inline] fn can_fit(&self, len: usize) -> bool { // we don't know the whole compressed payload size, so we need to be conservative and expect // the worst case get_maximum_output_size( (self.buffer.len() - COMPRESSED_FRAME_HEADER_LENGTH).saturating_add(len), ) < PAYLOAD_SIZE_LIMIT } #[inline] fn reset(&mut self) { self.buffer.truncate(COMPRESSED_FRAME_HEADER_LENGTH); } #[inline] fn add_envelope(&mut self, mut envelope: Vec) { self.buffer.append(&mut envelope); } fn finalize_self_contained(&mut self) -> &[u8] { let uncompressed_size = self.buffer.len() - COMPRESSED_FRAME_HEADER_LENGTH; let mut compressed_payload = compress(&self.buffer[COMPRESSED_FRAME_HEADER_LENGTH..]); self.buffer.truncate(COMPRESSED_FRAME_HEADER_LENGTH); self.buffer.append(&mut compressed_payload); self.write_header(uncompressed_size, true); add_trailer(&mut self.buffer, COMPRESSED_FRAME_HEADER_LENGTH); &self.buffer } fn finalize_non_self_contained(&mut self, envelope: &[u8]) -> (usize, &[u8]) { let mut uncompressed_size = envelope.len().min(PAYLOAD_SIZE_LIMIT - 1); let offset = uncompressed_size; self.buffer.resize( 
get_maximum_output_size(uncompressed_size) + COMPRESSED_FRAME_HEADER_LENGTH + FRAME_TRAILER_LENGTH, // add space for trailer, so we don't allocate later 0, ); let mut compressed_size = compress_into( &envelope[..uncompressed_size], &mut self.buffer[COMPRESSED_FRAME_HEADER_LENGTH..], ) .unwrap(); // we can safely unwrap, since we have at least the amount of space needed if compressed_size >= PAYLOAD_SIZE_LIMIT { // compressed size can exceed source size, therefore can exceed max payload size // Java driver simply ignores compression at this point, so ¯\_(ツ)_/¯ self.buffer[COMPRESSED_FRAME_HEADER_LENGTH ..(COMPRESSED_FRAME_HEADER_LENGTH + uncompressed_size)] .copy_from_slice(&envelope[..uncompressed_size]); compressed_size = uncompressed_size; uncompressed_size = 0; // compressed size of 0 means no compression } self.buffer .truncate(COMPRESSED_FRAME_HEADER_LENGTH + compressed_size); self.write_header(uncompressed_size, false); add_trailer(&mut self.buffer, COMPRESSED_FRAME_HEADER_LENGTH); (offset, &self.buffer) } #[inline] fn has_envelopes(&self) -> bool { self.buffer.len() > COMPRESSED_FRAME_HEADER_LENGTH } } impl Default for Lz4FrameEncoder { fn default() -> Self { let buffer = vec![0; COMPRESSED_FRAME_HEADER_LENGTH]; Self { buffer } } } impl Lz4FrameEncoder { fn write_header(&mut self, uncompressed_size: usize, self_contained: bool) { let len = self.buffer.len(); debug_assert!(len < (PAYLOAD_SIZE_LIMIT + COMPRESSED_FRAME_HEADER_LENGTH)); let mut header = (len - COMPRESSED_FRAME_HEADER_LENGTH) as u64 | ((uncompressed_size as u64) << 17); if self_contained { header |= 1 << 34; } let crc = crc24(&header.to_le_bytes()[..5]) as u64; let header = header | (crc << 40); self.buffer[..8].copy_from_slice(&header.to_le_bytes()); } } ================================================ FILE: cassandra-protocol/src/frame/message_auth_challenge.rs ================================================ use super::Serialize; use crate::error; use crate::frame::{FromCursor, Version}; use 
crate::types::CBytes; use std::io::Cursor; /// Server authentication challenge. #[derive(Debug, PartialEq, Ord, PartialOrd, Eq, Hash, Clone)] pub struct BodyResAuthChallenge { pub data: CBytes, } impl Serialize for BodyResAuthChallenge { fn serialize(&self, cursor: &mut Cursor<&mut Vec>, version: Version) { self.data.serialize(cursor, version); } } impl FromCursor for BodyResAuthChallenge { fn from_cursor( cursor: &mut Cursor<&[u8]>, version: Version, ) -> error::Result { CBytes::from_cursor(cursor, version).map(|data| BodyResAuthChallenge { data }) } } #[cfg(test)] mod tests { use super::*; use crate::frame::traits::FromCursor; use std::io::Cursor; #[test] fn body_res_auth_challenge_from_cursor() { let bytes = &[0, 0, 0, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; let expected = BodyResAuthChallenge { data: CBytes::new(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), }; { let mut cursor: Cursor<&[u8]> = Cursor::new(bytes); let body = BodyResAuthChallenge::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!( body.data.into_bytes().unwrap(), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9] ); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected.serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } } } ================================================ FILE: cassandra-protocol/src/frame/message_auth_response.rs ================================================ use crate::error; use crate::frame::{Direction, Envelope, Flags, FromCursor, Opcode, Serialize, Version}; use crate::types::CBytes; use derive_more::Constructor; use std::io::Cursor; #[derive(Debug, Constructor, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] pub struct BodyReqAuthResponse { pub data: CBytes, } impl Serialize for BodyReqAuthResponse { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec>, version: Version) { self.data.serialize(cursor, version); } } impl FromCursor for BodyReqAuthResponse { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result { 
CBytes::from_cursor(cursor, version).map(BodyReqAuthResponse::new) } } impl Envelope { /// Creates a new envelope of type `AuthResponse`. pub fn new_req_auth_response(token_bytes: CBytes, version: Version) -> Envelope { let direction = Direction::Request; let opcode = Opcode::AuthResponse; let body = BodyReqAuthResponse::new(token_bytes); Envelope::new( version, direction, Flags::empty(), opcode, 0, body.serialize_to_vec(version), None, vec![], ) } } #[cfg(test)] mod tests { use super::*; use crate::types::CBytes; #[test] fn body_req_auth_response() { let bytes = CBytes::new(vec![1, 2, 3]); let body = BodyReqAuthResponse::new(bytes); assert_eq!( body.serialize_to_vec(Version::V4), vec![0, 0, 0, 3, 1, 2, 3] ); } #[test] fn frame_body_req_auth_response() { let bytes = vec![1, 2, 3]; let frame = Envelope::new_req_auth_response(CBytes::new(bytes), Version::V4); assert_eq!(frame.version, Version::V4); assert_eq!(frame.opcode, Opcode::AuthResponse); assert_eq!(frame.body, &[0, 0, 0, 3, 1, 2, 3]); assert_eq!(frame.tracing_id, None); assert!(frame.warnings.is_empty()); } } ================================================ FILE: cassandra-protocol/src/frame/message_auth_success.rs ================================================ use super::Serialize; use crate::error; use crate::frame::{FromCursor, Version}; use crate::types::CBytes; use std::io::Cursor; /// `BodyReqAuthSuccess` is the envelope body that represents a successful authentication response.
#[derive(Debug, PartialEq, Ord, PartialOrd, Eq, Hash, Clone)] pub struct BodyReqAuthSuccess { pub data: CBytes, } impl Serialize for BodyReqAuthSuccess { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { self.data.serialize(cursor, version); } } impl FromCursor for BodyReqAuthSuccess { fn from_cursor( cursor: &mut Cursor<&[u8]>, version: Version, ) -> error::Result<Self> { CBytes::from_cursor(cursor, version).map(|data| BodyReqAuthSuccess { data }) } } #[cfg(test)] mod tests { use super::*; use crate::frame::traits::FromCursor; use std::io::Cursor; #[test] fn body_req_auth_success() { let bytes = &[0, 0, 0, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; let expected = BodyReqAuthSuccess { data: CBytes::new(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), }; { let mut cursor: Cursor<&[u8]> = Cursor::new(bytes); let body = BodyReqAuthSuccess::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!( body.data.into_bytes().unwrap(), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9] ); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected.serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } } } ================================================ FILE: cassandra-protocol/src/frame/message_authenticate.rs ================================================ use super::Serialize; use crate::error; use crate::frame::{FromCursor, Version}; use crate::types::{from_cursor_str, serialize_str}; use std::io::Cursor; /// A server request to authenticate; `data` holds the full class name of the authenticator in use.
#[derive(Debug, PartialEq, Ord, PartialOrd, Eq, Hash, Clone)] pub struct BodyResAuthenticate { pub data: String, } impl Serialize for BodyResAuthenticate { fn serialize(&self, cursor: &mut Cursor<&mut Vec>, version: Version) { serialize_str(cursor, &self.data, version); } } impl FromCursor for BodyResAuthenticate { fn from_cursor( cursor: &mut Cursor<&[u8]>, _version: Version, ) -> error::Result { Ok(BodyResAuthenticate { data: from_cursor_str(cursor)?.to_string(), }) } } #[cfg(test)] mod tests { use super::*; use crate::frame::traits::FromCursor; use crate::frame::Version; use std::io::Cursor; #[test] fn body_res_authenticate() { // string "abcde" let bytes = [0, 5, 97, 98, 99, 100, 101]; let expected = BodyResAuthenticate { data: "abcde".into(), }; { let mut cursor: Cursor<&[u8]> = Cursor::new(&bytes); let auth = BodyResAuthenticate::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(auth, expected); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected.serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } } } ================================================ FILE: cassandra-protocol/src/frame/message_batch.rs ================================================ use crate::consistency::Consistency; use crate::frame::{Direction, Envelope, Flags, FromCursor, Opcode, Serialize, Version}; use crate::query::QueryFlags; use crate::query::QueryValues; use crate::types::value::Value; use crate::types::{ from_cursor_str, from_cursor_str_long, serialize_str, serialize_str_long, CBytesShort, CInt, CIntShort, CLong, }; use crate::{error, Error}; use derive_more::{Constructor, Display}; use std::convert::{TryFrom, TryInto}; use std::io::{Cursor, Read}; #[derive(Debug, Clone, Constructor, PartialEq, Eq)] pub struct BodyReqBatch { pub batch_type: BatchType, pub queries: Vec, pub consistency: Consistency, pub serial_consistency: Option, pub timestamp: Option, pub keyspace: Option, pub now_in_seconds: Option, } impl Serialize for 
BodyReqBatch { fn serialize(&self, cursor: &mut Cursor<&mut Vec>, version: Version) { let batch_type = u8::from(self.batch_type); batch_type.serialize(cursor, version); let len = self.queries.len() as CIntShort; len.serialize(cursor, version); for query in &self.queries { query.serialize(cursor, version); } let consistency: CIntShort = self.consistency.into(); consistency.serialize(cursor, version); let mut flags = QueryFlags::empty(); if self.serial_consistency.is_some() { flags.insert(QueryFlags::WITH_SERIAL_CONSISTENCY) } if self.timestamp.is_some() { flags.insert(QueryFlags::WITH_DEFAULT_TIMESTAMP) } if self.keyspace.is_some() { flags.insert(QueryFlags::WITH_KEYSPACE) } if self.now_in_seconds.is_some() { flags.insert(QueryFlags::WITH_NOW_IN_SECONDS) } flags.serialize(cursor, version); if let Some(serial_consistency) = self.serial_consistency { let serial_consistency: CIntShort = serial_consistency.into(); serial_consistency.serialize(cursor, version); } if let Some(timestamp) = self.timestamp { timestamp.serialize(cursor, version); } if let Some(keyspace) = &self.keyspace { serialize_str(cursor, keyspace.as_str(), version); } if let Some(now_in_seconds) = self.now_in_seconds { now_in_seconds.serialize(cursor, version); } } } impl FromCursor for BodyReqBatch { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result { let mut batch_type = [0]; cursor.read_exact(&mut batch_type)?; let batch_type = BatchType::try_from(batch_type[0])?; let len = CIntShort::from_cursor(cursor, version)?; let mut queries = Vec::with_capacity(len as usize); for _ in 0..len { queries.push(BatchQuery::from_cursor(cursor, version)?); } let consistency = CIntShort::from_cursor(cursor, version).and_then(TryInto::try_into)?; let query_flags = QueryFlags::from_cursor(cursor, version)?; let serial_consistency = if query_flags.contains(QueryFlags::WITH_SERIAL_CONSISTENCY) { Some(CIntShort::from_cursor(cursor, version).and_then(TryInto::try_into)?) 
} else { None }; let timestamp = if query_flags.contains(QueryFlags::WITH_DEFAULT_TIMESTAMP) { Some(CLong::from_cursor(cursor, version)?) } else { None }; let keyspace = if query_flags.contains(QueryFlags::WITH_KEYSPACE) { Some(from_cursor_str(cursor).map(|keyspace| keyspace.to_string())?) } else { None }; let now_in_seconds = if query_flags.contains(QueryFlags::WITH_NOW_IN_SECONDS) { Some(CInt::from_cursor(cursor, version)?) } else { None }; Ok(BodyReqBatch::new( batch_type, queries, consistency, serial_consistency, timestamp, keyspace, now_in_seconds, )) } } /// Batch type #[derive(Debug, Clone, Copy, PartialEq, Ord, PartialOrd, Eq, Hash, Display)] #[non_exhaustive] pub enum BatchType { /// The batch will be "logged". This is equivalent to a /// normal CQL3 batch statement. Logged, /// The batch will be "unlogged". Unlogged, /// The batch will be a "counter" batch (and non-counter /// statements will be rejected). Counter, } impl TryFrom for BatchType { type Error = Error; fn try_from(value: u8) -> Result { match value { 0 => Ok(BatchType::Logged), 1 => Ok(BatchType::Unlogged), 2 => Ok(BatchType::Counter), _ => Err(Error::General(format!("Unknown batch type: {value}"))), } } } impl From for u8 { fn from(value: BatchType) -> Self { match value { BatchType::Logged => 0, BatchType::Unlogged => 1, BatchType::Counter => 2, } } } /// Contains either an id of a prepared query or CQL string. #[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] pub enum BatchQuerySubj { PreparedId(CBytesShort), QueryString(String), } /// The structure that represents a query to be batched. #[derive(Debug, Clone, Constructor, PartialEq, Eq)] pub struct BatchQuery { /// Contains either id of a prepared query or a query itself. pub subject: BatchQuerySubj, /// **Important note:** QueryValues::NamedValues does not work and should not be /// used for batches. It is specified in a way that makes it impossible for the server /// to implement. 
This will be fixed in a future version of the native /// protocol. See for /// more details pub values: QueryValues, } impl Serialize for BatchQuery { fn serialize(&self, cursor: &mut Cursor<&mut Vec>, version: Version) { match &self.subject { BatchQuerySubj::PreparedId(id) => { 1u8.serialize(cursor, version); id.serialize(cursor, version); } BatchQuerySubj::QueryString(s) => { 0u8.serialize(cursor, version); serialize_str_long(cursor, s, version); } } let len = self.values.len() as CIntShort; len.serialize(cursor, version); self.values.serialize(cursor, version); } } impl FromCursor for BatchQuery { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result { let mut is_prepared = [0]; cursor.read_exact(&mut is_prepared)?; let is_prepared = is_prepared[0] != 0; let subject = if is_prepared { BatchQuerySubj::PreparedId(CBytesShort::from_cursor(cursor, version)?) } else { BatchQuerySubj::QueryString(from_cursor_str_long(cursor).map(Into::into)?) }; let len = CIntShort::from_cursor(cursor, version)?; // assuming names are not present due to // https://issues.apache.org/jira/browse/CASSANDRA-10246 let mut values = Vec::with_capacity(len as usize); for _ in 0..len { values.push(Value::from_cursor(cursor, version)?); } Ok(BatchQuery::new(subject, QueryValues::SimpleValues(values))) } } impl Envelope { pub fn new_req_batch(query: BodyReqBatch, flags: Flags, version: Version) -> Envelope { let direction = Direction::Request; let opcode = Opcode::Batch; Envelope::new( version, direction, flags, opcode, 0, query.serialize_to_vec(version), None, vec![], ) } } #[cfg(test)] mod tests { use crate::consistency::Consistency; use crate::frame::message_batch::{BatchQuery, BatchQuerySubj, BatchType, BodyReqBatch}; use crate::frame::traits::Serialize; use crate::frame::{FromCursor, Version}; use crate::query::QueryValues; use crate::types::prelude::Value; use std::io::Cursor; #[test] fn should_deserialize_query() { let data = [0, 0, 0, 0, 1, 65, 0, 1, 0xff, 0xff, 
0xff, 0xfe]; let mut cursor = Cursor::new(data.as_slice()); let query = BatchQuery::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(query.subject, BatchQuerySubj::QueryString("A".into())); assert_eq!(query.values, QueryValues::SimpleValues(vec![Value::NotSet])); } #[test] fn should_deserialize_body() { let data = [0, 0, 0, 0, 0, 0x10 | 0x20, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8]; let mut cursor = Cursor::new(data.as_slice()); let body = BodyReqBatch::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(body.batch_type, BatchType::Logged); assert!(body.queries.is_empty()); assert_eq!(body.consistency, Consistency::Any); assert_eq!(body.serial_consistency, Some(Consistency::One)); assert_eq!(body.timestamp, Some(0x0102030405060708)); } #[test] fn should_support_keyspace() { let keyspace = "abc"; let body = BodyReqBatch::new( BatchType::Logged, vec![], Consistency::Any, None, None, Some(keyspace.into()), None, ); let data = body.serialize_to_vec(Version::V5); let body = BodyReqBatch::from_cursor(&mut Cursor::new(data.as_slice()), Version::V5).unwrap(); assert_eq!(body.keyspace, Some(keyspace.to_string())); } #[test] fn should_support_now_in_seconds() { let now_in_seconds = 4; let body = BodyReqBatch::new( BatchType::Logged, vec![], Consistency::Any, None, None, None, Some(now_in_seconds), ); let data = body.serialize_to_vec(Version::V5); let body = BodyReqBatch::from_cursor(&mut Cursor::new(data.as_slice()), Version::V5).unwrap(); assert_eq!(body.now_in_seconds, Some(now_in_seconds)); } } ================================================ FILE: cassandra-protocol/src/frame/message_error.rs ================================================ use super::Serialize; use crate::consistency::Consistency; use crate::frame::traits::FromCursor; use crate::frame::Version; use crate::types::*; use crate::{error, Error}; /// This module contains [Cassandra's errors]() /// that the server can send to the client.
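```rust
// Illustrative sketch only, not part of cassandra-protocol: it shows the wire layout that
// `ErrorBody::from_cursor` reads. Per the native protocol specification, an ERROR body starts
// with an [int] error code (4 bytes, big-endian) followed by a [string] message (a 2-byte
// big-endian length, then UTF-8 bytes); whatever follows carries the type-specific fields
// (e.g. cl/required/alive for Unavailable). `parse_error_prefix` is a hypothetical name.
use std::convert::TryInto;

// Returns (error code, message, remaining type-specific bytes), or `None` if the buffer is
// truncated or the message is not valid UTF-8.
fn parse_error_prefix(body: &[u8]) -> Option<(i32, String, &[u8])> {
    // [int]: 4-byte big-endian signed error code
    let code = i32::from_be_bytes(body.get(0..4)?.try_into().ok()?);
    // [string]: 2-byte big-endian length followed by that many UTF-8 bytes
    let len = u16::from_be_bytes(body.get(4..6)?.try_into().ok()?) as usize;
    let message = std::str::from_utf8(body.get(6..6 + len)?).ok()?.to_string();
    Some((code, message, &body[6 + len..]))
}

// Usage: the bytes [0, 0, 0x22, 0x00, 0, 3, b'b', b'a', b'd'] decode to code 0x2200
// (mapped to `ErrorType::Invalid` by `from_cursor_with_code`) and the message "bad".
```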
use derive_more::Display; use std::collections::HashMap; use std::io::{Cursor, Read}; use std::net::SocketAddr; /// CDRS error which could be returned by the Cassandra server as a response. As in the specification, /// it contains an error code and an error message. Apart from those, depending on the type of error, /// it may contain additional information, which is carried by the `ty` field. #[derive(Debug, PartialEq, Eq, Clone)] pub struct ErrorBody { /// Error message. pub message: String, /// The type of error, possibly including type specific additional information. pub ty: ErrorType, } impl Serialize for ErrorBody { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { self.ty.to_error_code().serialize(cursor, version); serialize_str(cursor, &self.message, version); self.ty.serialize(cursor, version); } } impl FromCursor for ErrorBody { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> { let error_code = CInt::from_cursor(cursor, version)?; let message = from_cursor_str(cursor)?.to_string(); let ty = ErrorType::from_cursor_with_code(cursor, error_code, version)?; Ok(ErrorBody { message, ty }) } } impl ErrorBody { /// Returns `true` if the error indicates that an unsupported protocol version was used. This is /// a special case which is used in some situations to detect when a node should not be contacted. pub fn is_bad_protocol(&self) -> bool { // based on ProtocolInitHandler from the Datastax driver (self.ty == ErrorType::Server || self.ty == ErrorType::Protocol) && (self .message .contains("Invalid or unsupported protocol version") || self.message.contains("Beta version of the protocol used")) } } /// Protocol-dependent failure information. V5 contains a map of endpoint->code entries, while /// previous versions contain only the failure count. #[derive(Debug, PartialEq, Eq, Clone)] #[non_exhaustive] pub enum FailureInfo { /// Represents the number of nodes that experienced a failure while executing the request.
NumFailures(CInt), /// Error code map for affected nodes. ReasonMap(HashMap<SocketAddr, CIntShort>), } impl Serialize for FailureInfo { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { match self { FailureInfo::NumFailures(count) => count.serialize(cursor, version), FailureInfo::ReasonMap(map) => { let num_failures = map.len() as CInt; num_failures.serialize(cursor, version); for (endpoint, error_code) in map { endpoint.serialize(cursor, version); error_code.serialize(cursor, version); } } } } } impl FromCursor for FailureInfo { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> { Ok(match version { Version::V3 | Version::V4 => Self::NumFailures(CInt::from_cursor(cursor, version)?), Version::V5 => { let num_failures = CInt::from_cursor(cursor, version)?; let mut map = HashMap::with_capacity(num_failures as usize); for _ in 0..num_failures { let endpoint = SocketAddr::from_cursor(cursor, version)?; let error_code = CIntShort::from_cursor(cursor, version)?; map.insert(endpoint, error_code); } Self::ReasonMap(map) } }) } } /// Additional error info in accordance with /// [Cassandra protocol v4]().
#[derive(Debug, PartialEq, Eq, Clone)] #[non_exhaustive] pub enum ErrorType { Server, Protocol, Authentication, Unavailable(UnavailableError), Overloaded, IsBootstrapping, Truncate, WriteTimeout(WriteTimeoutError), ReadTimeout(ReadTimeoutError), ReadFailure(ReadFailureError), FunctionFailure(FunctionFailureError), WriteFailure(WriteFailureError), Syntax, Unauthorized, Invalid, Config, AlreadyExists(AlreadyExistsError), Unprepared(UnpreparedError), } impl Serialize for ErrorType { fn serialize(&self, cursor: &mut Cursor<&mut Vec>, version: Version) { match self { ErrorType::Unavailable(unavailable) => unavailable.serialize(cursor, version), ErrorType::WriteTimeout(write_timeout) => write_timeout.serialize(cursor, version), ErrorType::ReadTimeout(read_timeout) => read_timeout.serialize(cursor, version), ErrorType::ReadFailure(read_failure) => read_failure.serialize(cursor, version), ErrorType::FunctionFailure(function_failure) => { function_failure.serialize(cursor, version) } ErrorType::WriteFailure(write_failure) => write_failure.serialize(cursor, version), ErrorType::AlreadyExists(already_exists) => already_exists.serialize(cursor, version), ErrorType::Unprepared(unprepared) => unprepared.serialize(cursor, version), _ => {} } } } impl ErrorType { pub fn from_cursor_with_code( cursor: &mut Cursor<&[u8]>, error_code: CInt, version: Version, ) -> error::Result { match error_code { 0x0000 => Ok(ErrorType::Server), 0x000A => Ok(ErrorType::Protocol), 0x0100 => Ok(ErrorType::Authentication), 0x1000 => UnavailableError::from_cursor(cursor, version).map(ErrorType::Unavailable), 0x1001 => Ok(ErrorType::Overloaded), 0x1002 => Ok(ErrorType::IsBootstrapping), 0x1003 => Ok(ErrorType::Truncate), 0x1100 => WriteTimeoutError::from_cursor(cursor, version).map(ErrorType::WriteTimeout), 0x1200 => ReadTimeoutError::from_cursor(cursor, version).map(ErrorType::ReadTimeout), 0x1300 => ReadFailureError::from_cursor(cursor, version).map(ErrorType::ReadFailure), 0x1400 => { 
FunctionFailureError::from_cursor(cursor, version).map(ErrorType::FunctionFailure) } 0x1500 => WriteFailureError::from_cursor(cursor, version).map(ErrorType::WriteFailure), 0x2000 => Ok(ErrorType::Syntax), 0x2100 => Ok(ErrorType::Unauthorized), 0x2200 => Ok(ErrorType::Invalid), 0x2300 => Ok(ErrorType::Config), 0x2400 => { AlreadyExistsError::from_cursor(cursor, version).map(ErrorType::AlreadyExists) } 0x2500 => UnpreparedError::from_cursor(cursor, version).map(ErrorType::Unprepared), _ => Err(Error::UnexpectedErrorCode(error_code)), } } pub fn to_error_code(&self) -> CInt { match self { ErrorType::Server => 0x0000, ErrorType::Protocol => 0x000A, ErrorType::Authentication => 0x0100, ErrorType::Unavailable(_) => 0x1000, ErrorType::Overloaded => 0x1001, ErrorType::IsBootstrapping => 0x1002, ErrorType::Truncate => 0x1003, ErrorType::WriteTimeout(_) => 0x1100, ErrorType::ReadTimeout(_) => 0x1200, ErrorType::ReadFailure(_) => 0x1300, ErrorType::FunctionFailure(_) => 0x1400, ErrorType::WriteFailure(_) => 0x1500, ErrorType::Syntax => 0x2000, ErrorType::Unauthorized => 0x2100, ErrorType::Invalid => 0x2200, ErrorType::Config => 0x2300, ErrorType::AlreadyExists(_) => 0x2400, ErrorType::Unprepared(_) => 0x2500, } } } /// Additional info about /// [unavailable exception]() #[derive(Debug, PartialEq, Ord, PartialOrd, Eq, Copy, Clone, Hash)] pub struct UnavailableError { /// Consistency level of query. pub cl: Consistency, /// Number of nodes that should be available to respect `cl`. pub required: CInt, /// Number of replicas that were known to be alive.
pub alive: CInt, } impl Serialize for UnavailableError { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { self.cl.serialize(cursor, version); self.required.serialize(cursor, version); self.alive.serialize(cursor, version); } } impl FromCursor for UnavailableError { fn from_cursor( cursor: &mut Cursor<&[u8]>, version: Version, ) -> error::Result<Self> { let cl = Consistency::from_cursor(cursor, version)?; let required = CInt::from_cursor(cursor, version)?; let alive = CInt::from_cursor(cursor, version)?; Ok(UnavailableError { cl, required, alive, }) } } /// Timeout exception during a write request. #[derive(Debug, PartialEq, Clone, Ord, PartialOrd, Eq, Hash)] pub struct WriteTimeoutError { /// Consistency level of query. pub cl: Consistency, /// `i32` representing the number of nodes having acknowledged the request. pub received: CInt, /// `i32` representing the number of replicas whose acknowledgement is required to achieve `cl`. pub block_for: CInt, /// Describes the type of the write that timed out. pub write_type: WriteType, /// The number of contentions that occurred during the CAS operation. The field is only present when /// the `write_type` is `Cas`.
pub contentions: Option, } impl Serialize for WriteTimeoutError { fn serialize(&self, cursor: &mut Cursor<&mut Vec>, version: Version) { self.cl.serialize(cursor, version); self.received.serialize(cursor, version); self.block_for.serialize(cursor, version); self.write_type.serialize(cursor, version); if let Some(contentions) = self.contentions { contentions.serialize(cursor, version); } } } impl FromCursor for WriteTimeoutError { fn from_cursor( cursor: &mut Cursor<&[u8]>, version: Version, ) -> error::Result { let cl = Consistency::from_cursor(cursor, version)?; let received = CInt::from_cursor(cursor, version)?; let block_for = CInt::from_cursor(cursor, version)?; let write_type = WriteType::from_cursor(cursor, version)?; let contentions = if write_type == WriteType::Cas { Some(CIntShort::from_cursor(cursor, version)?) } else { None }; Ok(WriteTimeoutError { cl, received, block_for, write_type, contentions, }) } } /// Timeout exception during a read request. #[derive(Debug, PartialEq, Ord, PartialOrd, Eq, Copy, Clone, Hash)] pub struct ReadTimeoutError { /// Consistency level of query. pub cl: Consistency, /// `i32` representing the number of nodes having acknowledged the request. pub received: CInt, /// `i32` representing the number of replicas whose acknowledgement is required to achieve `cl`. pub block_for: CInt, data_present: u8, } impl Serialize for ReadTimeoutError { fn serialize(&self, cursor: &mut Cursor<&mut Vec>, version: Version) { self.cl.serialize(cursor, version); self.received.serialize(cursor, version); self.block_for.serialize(cursor, version); self.data_present.serialize(cursor, version); } } impl ReadTimeoutError { /// Shows if a replica has responded to a query. 
    #[inline]
    pub fn replica_has_responded(&self) -> bool {
        self.data_present != 0
    }
}

impl FromCursor for ReadTimeoutError {
    fn from_cursor(
        cursor: &mut Cursor<&[u8]>,
        version: Version,
    ) -> error::Result<ReadTimeoutError> {
        let cl = Consistency::from_cursor(cursor, version)?;
        let received = CInt::from_cursor(cursor, version)?;
        let block_for = CInt::from_cursor(cursor, version)?;

        // `data_present` is a single byte; 0 means the replica asked for data did not respond.
        let mut buff = [0];
        cursor.read_exact(&mut buff)?;

        let data_present = buff[0];

        Ok(ReadTimeoutError {
            cl,
            received,
            block_for,
            data_present,
        })
    }
}

/// A non-timeout exception during a read request.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ReadFailureError {
    /// Consistency level of query.
    pub cl: Consistency,
    /// The number of nodes having acknowledged the request.
    pub received: CInt,
    /// The number of replicas whose acknowledgement is required to achieve `cl`.
    pub block_for: CInt,
    /// Failure information.
    pub failure_info: FailureInfo,
    data_present: u8,
}

impl Serialize for ReadFailureError {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        self.cl.serialize(cursor, version);
        self.received.serialize(cursor, version);
        self.block_for.serialize(cursor, version);
        self.failure_info.serialize(cursor, version);
        self.data_present.serialize(cursor, version);
    }
}

impl ReadFailureError {
    /// Returns `true` if the replica that was asked for data has responded.
    #[inline]
    pub fn replica_has_responded(&self) -> bool {
        self.data_present != 0
    }
}

impl FromCursor for ReadFailureError {
    fn from_cursor(
        cursor: &mut Cursor<&[u8]>,
        version: Version,
    ) -> error::Result<ReadFailureError> {
        let cl = Consistency::from_cursor(cursor, version)?;
        let received = CInt::from_cursor(cursor, version)?;
        let block_for = CInt::from_cursor(cursor, version)?;
        let failure_info = FailureInfo::from_cursor(cursor, version)?;

        let mut buff = [0];
        cursor.read_exact(&mut buff)?;

        let data_present = buff[0];

        Ok(ReadFailureError {
            cl,
            received,
            block_for,
            failure_info,
            data_present,
        })
    }
}

/// A (user-defined) function failed during execution.
#[derive(Debug, PartialEq, Ord, PartialOrd, Eq, Hash, Clone)]
pub struct FunctionFailureError {
    /// The keyspace of the failed function.
    pub keyspace: String,
    /// The name of the failed function.
    pub function: String,
    /// One string for each argument type (as CQL type) of the failed function.
    pub arg_types: Vec<String>,
}

impl Serialize for FunctionFailureError {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        serialize_str(cursor, &self.keyspace, version);
        serialize_str(cursor, &self.function, version);
        serialize_str_list(cursor, self.arg_types.iter().map(|x| x.as_str()), version);
    }
}

impl FromCursor for FunctionFailureError {
    fn from_cursor(
        cursor: &mut Cursor<&[u8]>,
        _version: Version,
    ) -> error::Result<FunctionFailureError> {
        let keyspace = from_cursor_str(cursor)?.to_string();
        let function = from_cursor_str(cursor)?.to_string();
        let arg_types = from_cursor_string_list(cursor)?;

        Ok(FunctionFailureError {
            keyspace,
            function,
            arg_types,
        })
    }
}

/// A non-timeout exception during a write request.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct WriteFailureError {
    /// Consistency of the query having triggered the exception.
    pub cl: Consistency,
    /// The number of nodes having answered the request.
    pub received: CInt,
    /// The number of replicas whose acknowledgement is required to achieve `cl`.
    pub block_for: CInt,
    /// Failure information.
    pub failure_info: FailureInfo,
    /// Describes the type of the write that failed.
    pub write_type: WriteType,
}

impl Serialize for WriteFailureError {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        self.cl.serialize(cursor, version);
        self.received.serialize(cursor, version);
        self.block_for.serialize(cursor, version);
        self.failure_info.serialize(cursor, version);
        self.write_type.serialize(cursor, version);
    }
}

impl FromCursor for WriteFailureError {
    fn from_cursor(
        cursor: &mut Cursor<&[u8]>,
        version: Version,
    ) -> error::Result<WriteFailureError> {
        let cl = Consistency::from_cursor(cursor, version)?;
        let received = CInt::from_cursor(cursor, version)?;
        let block_for = CInt::from_cursor(cursor, version)?;
        let failure_info = FailureInfo::from_cursor(cursor, version)?;
        let write_type = WriteType::from_cursor(cursor, version)?;

        Ok(WriteFailureError {
            cl,
            received,
            block_for,
            failure_info,
            write_type,
        })
    }
}

/// Describes the type of the write that failed.
/// [Read more...](https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L1118)
#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Hash, Clone, Display)]
#[non_exhaustive]
pub enum WriteType {
    /// The write was a non-batched non-counter write.
    Simple,
    /// The write was a (logged) batch write. If this type is received, it means the batch log
    /// has been successfully written.
    Batch,
    /// The write was an unlogged batch. No batch log write has been attempted.
    UnloggedBatch,
    /// The write was a counter write (batched or not).
    Counter,
    /// The failure occurred during the write to the batch log when a (logged) batch
    /// write was requested.
    BatchLog,
    /// The timeout occurred during the Compare And Set write/update.
    Cas,
    /// The timeout occurred when a write involved a materialized view (MV) update and the local
    /// view lock for the key could not be acquired within the timeout.
    View,
    /// The timeout occurred because `cdc_total_space` was exceeded while writing to data tracked
    /// by CDC.
    Cdc,
    /// Unknown write type.
    Unknown(String),
}

impl Serialize for WriteType {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        match self {
            WriteType::Simple => serialize_str(cursor, "SIMPLE", version),
            WriteType::Batch => serialize_str(cursor, "BATCH", version),
            WriteType::UnloggedBatch => serialize_str(cursor, "UNLOGGED_BATCH", version),
            WriteType::Counter => serialize_str(cursor, "COUNTER", version),
            WriteType::BatchLog => serialize_str(cursor, "BATCH_LOG", version),
            WriteType::Cas => serialize_str(cursor, "CAS", version),
            WriteType::View => serialize_str(cursor, "VIEW", version),
            WriteType::Cdc => serialize_str(cursor, "CDC", version),
            WriteType::Unknown(write_type) => serialize_str(cursor, write_type, version),
        }
    }
}

impl FromCursor for WriteType {
    fn from_cursor(cursor: &mut Cursor<&[u8]>, _version: Version) -> error::Result<WriteType> {
        match from_cursor_str(cursor)? {
            "SIMPLE" => Ok(WriteType::Simple),
            "BATCH" => Ok(WriteType::Batch),
            "UNLOGGED_BATCH" => Ok(WriteType::UnloggedBatch),
            "COUNTER" => Ok(WriteType::Counter),
            "BATCH_LOG" => Ok(WriteType::BatchLog),
            "CAS" => Ok(WriteType::Cas),
            "VIEW" => Ok(WriteType::View),
            "CDC" => Ok(WriteType::Cdc),
            // Keep unrecognized write types instead of failing to decode the error.
            wt => Ok(WriteType::Unknown(wt.into())),
        }
    }
}

/// The query attempted to create a keyspace or a table that already exists.
/// [Read more...](https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L1140)
#[derive(Debug, PartialEq, Ord, PartialOrd, Eq, Hash, Clone)]
pub struct AlreadyExistsError {
    /// Either the keyspace that already exists, or the keyspace containing the table that
    /// already exists.
    pub ks: String,
    /// The name of the table that already exists. Empty if the query was creating a keyspace.
    pub table: String,
}

impl Serialize for AlreadyExistsError {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        serialize_str(cursor, &self.ks, version);
        serialize_str(cursor, &self.table, version);
    }
}

impl FromCursor for AlreadyExistsError {
    fn from_cursor(
        cursor: &mut Cursor<&[u8]>,
        _version: Version,
    ) -> error::Result<AlreadyExistsError> {
        let ks = from_cursor_str(cursor)?.to_string();
        let table = from_cursor_str(cursor)?.to_string();

        Ok(AlreadyExistsError { ks, table })
    }
}

/// Returned when executing a prepared statement whose ID is not known by this host.
/// [Read more...]()
#[derive(Debug, PartialEq, Ord, PartialOrd, Eq, Hash, Clone)]
pub struct UnpreparedError {
    /// Unknown ID.
    pub id: CBytesShort,
}

impl Serialize for UnpreparedError {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        self.id.serialize(cursor, version);
    }
}

impl FromCursor for UnpreparedError {
    fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> {
        let id = CBytesShort::from_cursor(cursor, version)?;

        Ok(UnpreparedError { id })
    }
}

//noinspection DuplicatedCode
#[cfg(test)]
fn test_encode_decode(bytes: &[u8], expected: ErrorBody) {
    {
        let mut cursor: Cursor<&[u8]> = Cursor::new(bytes);
        let result = ErrorBody::from_cursor(&mut cursor, Version::V4).unwrap();
        assert_eq!(expected, result);
    }

    {
        let mut buffer = Vec::new();
        let mut cursor = Cursor::new(&mut buffer);
        expected.serialize(&mut cursor, Version::V4);
        assert_eq!(buffer, bytes);
    }
}

#[cfg(test)]
mod error_tests {
    use super::*;

    #[test]
    fn server() {
        let bytes = &[
            0, 0, 0, 0, // server
            0, 3, 102, 111, 111, // message - foo
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::Server,
        };
        test_encode_decode(bytes, expected);
    }

    #[test]
    fn protocol() {
        let bytes = &[
            0, 0, 0, 10, // protocol
            0, 3, 102, 111, 111, // message - foo
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::Protocol,
        };
        test_encode_decode(bytes, expected);
    }

    #[test]
    fn authentication() {
        let bytes = &[
            0, 0, 1, 0, // authentication error
            0, 3, 102, 111, 111, // message - foo
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::Authentication,
        };
        test_encode_decode(bytes, expected);
    }

    #[test]
    fn unavailable() {
        let bytes = &[
            0, 0, 16, 0, // unavailable
            0, 3, 102, 111, 111, // message - foo
            //
            // unavailable error
            0, 0, // consistency any
            0, 0, 0, 1, // required
            0, 0, 0, 1, // alive
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::Unavailable(UnavailableError {
                cl: Consistency::Any,
                required: 1,
                alive: 1,
            }),
        };
        test_encode_decode(bytes, expected);
    }

    #[test]
    fn overloaded() {
        let bytes = &[
            0, 0, 16, 1, // overloaded
            0, 3, 102, 111, 111, // message - foo
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::Overloaded,
        };
        test_encode_decode(bytes, expected);
    }

    #[test]
    fn is_bootstrapping() {
        let bytes = &[
            0, 0, 16, 2, // is bootstrapping
            0, 3, 102, 111, 111, // message - foo
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::IsBootstrapping,
        };
        test_encode_decode(bytes, expected);
    }

    #[test]
    fn truncate() {
        let bytes = &[
            0, 0, 16, 3, // truncate
            0, 3, 102, 111, 111, // message - foo
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::Truncate,
        };
        test_encode_decode(bytes, expected);
    }

    #[test]
    fn write_timeout() {
        let bytes = &[
            0, 0, 17, 0, // write timeout
            0, 3, 102, 111, 111, // message - foo
            //
            // timeout error
            0, 0, // consistency any
            0, 0, 0, 1, // received
            0, 0, 0, 1, // block_for
            0, 6, 83, 73, 77, 80, 76, 69, // write type simple
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::WriteTimeout(WriteTimeoutError {
                cl: Consistency::Any,
                received: 1,
                block_for: 1,
                write_type: WriteType::Simple,
                contentions: None,
            }),
        };
        test_encode_decode(bytes, expected);
    }

    #[test]
    fn read_timeout() {
        let bytes = &[
            0, 0, 18, 0, // read timeout
            0, 3, 102, 111, 111, // message - foo
            //
            // read timeout
            0, 0, // consistency any
            0, 0, 0, 1, // received
            0, 0, 0, 1, // block_for
            0, // data present
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::ReadTimeout(ReadTimeoutError {
                cl: Consistency::Any,
                received: 1,
                block_for: 1,
                data_present: 0,
            }),
        };
        test_encode_decode(bytes, expected);
    }

    #[test]
    fn read_failure() {
        let bytes = &[
            0, 0, 19, 0, // read failure
            0, 3, 102, 111, 111, // message - foo
            //
            // read failure
            0, 0, // consistency any
            0, 0, 0, 1, // received
            0, 0, 0, 1, // block_for
            0, 0, 0, 1, // number of failures
            0, // data present
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::ReadFailure(ReadFailureError {
                cl: Consistency::Any,
                received: 1,
                block_for: 1,
                failure_info: FailureInfo::NumFailures(1),
                data_present: 0,
            }),
        };
        test_encode_decode(bytes, expected);
    }

    #[test]
    fn syntax() {
        let bytes = &[
            0, 0, 32, 0, // syntax
            0, 3, 102, 111, 111, // message - foo
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::Syntax,
        };
        test_encode_decode(bytes, expected);
    }

    #[test]
    fn unauthorized() {
        let bytes = &[
            0, 0, 33, 0, // unauthorized
            0, 3, 102, 111, 111, // message - foo
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::Unauthorized,
        };
        test_encode_decode(bytes, expected);
    }

    #[test]
    fn invalid() {
        let bytes = &[
            0, 0, 34, 0, // invalid
            0, 3, 102, 111, 111, // message - foo
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::Invalid,
        };
        test_encode_decode(bytes, expected);
    }

    #[test]
    fn config() {
        let bytes = &[
            0, 0, 35, 0, // config
            0, 3, 102, 111, 111, // message - foo
        ];
        let expected = ErrorBody {
            message: "foo".into(),
            ty: ErrorType::Config,
        };
        test_encode_decode(bytes, expected);
    }
}

================================================
FILE: cassandra-protocol/src/frame/message_event.rs
================================================
use crate::error;
use crate::frame::events::ServerEvent;
use crate::frame::Serialize;
use crate::frame::{FromCursor, Version};
use std::io::Cursor;

#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct BodyResEvent {
    pub event: ServerEvent,
}

impl Serialize for BodyResEvent {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        self.event.serialize(cursor, version);
    }
}

impl FromCursor for BodyResEvent {
    fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> {
        let event = ServerEvent::from_cursor(cursor, version)?;
        Ok(BodyResEvent { event })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::frame::events::*;
    use crate::frame::traits::FromCursor;
    use std::io::Cursor;

    #[test]
    fn body_res_event() {
        let bytes = &[
            // TOPOLOGY_CHANGE
            0, 15, 84, 79, 80, 79, 76, 79, 71, 89, 95, 67, 72, 65, 78, 71, 69,
            // NEW_NODE
            0, 8, 78, 69, 87, 95, 78, 79, 68, 69,
            //
            4, 127, 0, 0, 1, 0, 0, 0, 1, // inet - 127.0.0.1:1
        ];
        let expected = ServerEvent::TopologyChange(TopologyChange {
            change_type: TopologyChangeType::NewNode,
            addr: "127.0.0.1:1".parse().unwrap(),
        });

        {
            let mut cursor: Cursor<&[u8]> = Cursor::new(bytes);
            let event = BodyResEvent::from_cursor(&mut cursor, Version::V4)
                .unwrap()
                .event;
            assert_eq!(event, expected);
        }

        {
            let mut buffer = Vec::new();
            let mut cursor = Cursor::new(&mut buffer);
            expected.serialize(&mut cursor, Version::V4);
            assert_eq!(buffer, bytes);
        }
    }
}

================================================
FILE: cassandra-protocol/src/frame/message_execute.rs
================================================
use crate::error;
use crate::frame::{Direction, Envelope, Flags, FromCursor, Opcode, Serialize, Version};
use crate::query::QueryParams;
use crate::types::CBytesShort;
use derive_more::Constructor;
use std::io::Cursor;

/// The structure that represents the body of an envelope of type `execute`.
#[derive(Debug, Constructor, Eq, PartialEq, Clone)]
pub struct BodyReqExecute<'a> {
    pub id: &'a CBytesShort,
    pub result_metadata_id: Option<&'a CBytesShort>,
    pub query_parameters: &'a QueryParams,
}

impl<'a> Serialize for BodyReqExecute<'a> {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        self.id.serialize(cursor, version);

        if let Some(result_metadata_id) = self.result_metadata_id {
            result_metadata_id.serialize(cursor, version);
        }

        self.query_parameters.serialize(cursor, version);
    }

    #[inline]
    fn serialize_to_vec(&self, version: Version) -> Vec<u8> {
        // Capacity hint only: the query parameters are not counted, so the buffer may grow.
        let mut buf = Vec::with_capacity(
            self.id.serialized_len()
                + self
                    .result_metadata_id
                    .map(|id| id.serialized_len())
                    .unwrap_or(0),
        );

        self.serialize(&mut Cursor::new(&mut buf), version);
        buf
    }
}

/// The structure that represents an owned body of an envelope of type `execute`.
#[derive(Debug, Constructor, Clone, Eq, PartialEq, Default)]
pub struct BodyReqExecuteOwned {
    pub id: CBytesShort,
    pub result_metadata_id: Option<CBytesShort>,
    pub query_parameters: QueryParams,
}

impl FromCursor for BodyReqExecuteOwned {
    fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> {
        let id = CBytesShort::from_cursor(cursor, version)?;

        let result_metadata_id = if version >= Version::V5 {
            Some(CBytesShort::from_cursor(cursor, version)?)
        } else {
            None
        };

        let query_parameters = QueryParams::from_cursor(cursor, version)?;

        Ok(BodyReqExecuteOwned::new(
            id,
            result_metadata_id,
            query_parameters,
        ))
    }
}

impl Serialize for BodyReqExecuteOwned {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        BodyReqExecute::new(
            &self.id,
            self.result_metadata_id.as_ref(),
            &self.query_parameters,
        )
        .serialize(cursor, version);
    }
}

impl Envelope {
    pub fn new_req_execute(
        id: &CBytesShort,
        result_metadata_id: Option<&CBytesShort>, // only required for protocol >= V5
        query_parameters: &QueryParams,
        flags: Flags,
        version: Version,
    ) -> Envelope {
        let direction = Direction::Request;
        let opcode = Opcode::Execute;

        let body = BodyReqExecute::new(id, result_metadata_id, query_parameters);

        Envelope::new(
            version,
            direction,
            flags,
            opcode,
            0,
            body.serialize_to_vec(version),
            None,
            vec![],
        )
    }
}

#[cfg(test)]
mod tests {
    use crate::consistency::Consistency;
    use crate::frame::message_execute::BodyReqExecuteOwned;
    use crate::frame::traits::Serialize;
    use crate::frame::{FromCursor, Version};
    use crate::query::QueryParams;
    use crate::types::CBytesShort;
    use std::io::Cursor;

    #[test]
    fn should_deserialize_body() {
        let data = [0, 1, 2, 0, 0, 0];
        let mut cursor = Cursor::new(data.as_slice());

        let body = BodyReqExecuteOwned::from_cursor(&mut cursor, Version::V4).unwrap();
        assert_eq!(body.id, CBytesShort::new(vec![2]));
        assert_eq!(body.query_parameters.consistency, Consistency::Any);
    }

    #[test]
    fn should_support_result_metadata_id() {
        let body = BodyReqExecuteOwned::new(
            CBytesShort::new(vec![1]),
            Some(CBytesShort::new(vec![2])),
            QueryParams::default(),
        );
        let data = body.serialize_to_vec(Version::V5);
        assert_eq!(
            BodyReqExecuteOwned::from_cursor(&mut Cursor::new(&data), Version::V5).unwrap(),
            body
        );
    }
}

================================================
FILE: cassandra-protocol/src/frame/message_options.rs
================================================
use crate::error;
use crate::frame::{Direction, Envelope, Flags, FromCursor, Opcode, Serialize, Version};
use std::io::Cursor;

/// The structure which represents the body of an envelope of type `options`.
#[derive(Debug, Default, Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
pub struct BodyReqOptions;

impl Serialize for BodyReqOptions {
    #[inline(always)]
    fn serialize(&self, _cursor: &mut Cursor<&mut Vec<u8>>, _version: Version) {}
}

impl FromCursor for BodyReqOptions {
    #[inline(always)]
    fn from_cursor(_cursor: &mut Cursor<&[u8]>, _version: Version) -> error::Result<Self> {
        Ok(BodyReqOptions)
    }
}

impl Envelope {
    /// Creates a new envelope of type `options`.
    pub fn new_req_options(version: Version) -> Envelope {
        let direction = Direction::Request;
        let opcode = Opcode::Options;
        let body: BodyReqOptions = Default::default();

        Envelope::new(
            version,
            direction,
            Flags::empty(),
            opcode,
            0,
            body.serialize_to_vec(version),
            None,
            vec![],
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_frame_options() {
        let frame = Envelope::new_req_options(Version::V4);
        assert_eq!(frame.version, Version::V4);
        assert_eq!(frame.opcode, Opcode::Options);
        assert!(frame.body.is_empty());
    }
}

================================================
FILE: cassandra-protocol/src/frame/message_prepare.rs
================================================
use crate::error;
use crate::frame::{Direction, Envelope, Flags, FromCursor, Opcode, Serialize, Version};
use crate::query::PrepareFlags;
use crate::types::{
    from_cursor_str, from_cursor_str_long, serialize_str, serialize_str_long, INT_LEN, SHORT_LEN,
};
use std::io::Cursor;

/// Struct that represents the body of an envelope of type `prepare`.
#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Hash, Clone, Default)]
pub struct BodyReqPrepare {
    pub query: String,
    pub keyspace: Option<String>,
}

impl BodyReqPrepare {
    /// Creates a new body of an envelope of type `prepare` that prepares query `query`.
    #[inline]
    pub fn new(query: String, keyspace: Option<String>) -> BodyReqPrepare {
        BodyReqPrepare { query, keyspace }
    }
}

impl Serialize for BodyReqPrepare {
    #[inline]
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        serialize_str_long(cursor, &self.query, version);

        // Prepare flags (and the optional keyspace) only exist in protocol V5 and later.
        if version >= Version::V5 {
            if let Some(keyspace) = &self.keyspace {
                PrepareFlags::WITH_KEYSPACE.serialize(cursor, version);
                serialize_str(cursor, keyspace.as_str(), version);
            } else {
                PrepareFlags::empty().serialize(cursor, version);
            }
        }
    }

    #[inline]
    fn serialize_to_vec(&self, version: Version) -> Vec<u8> {
        let mut buf = if version >= Version::V5 {
            Vec::with_capacity(
                INT_LEN * 2
                    + self.query.len()
                    + self
                        .keyspace
                        .as_ref()
                        .map(|keyspace| SHORT_LEN + keyspace.len())
                        .unwrap_or(0),
            )
        } else {
            Vec::with_capacity(INT_LEN + self.query.len())
        };

        self.serialize(&mut Cursor::new(&mut buf), version);
        buf
    }
}

impl FromCursor for BodyReqPrepare {
    #[inline]
    fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> {
        if version >= Version::V5 {
            from_cursor_str_long(cursor)
                .and_then(|query| {
                    PrepareFlags::from_cursor(cursor, version).map(|flags| (query, flags))
                })
                .and_then(|(query, flags)| {
                    if flags.contains(PrepareFlags::WITH_KEYSPACE) {
                        from_cursor_str(cursor).map(|keyspace| {
                            BodyReqPrepare::new(query.into(), Some(keyspace.into()))
                        })
                    } else {
                        Ok(BodyReqPrepare::new(query.into(), None))
                    }
                })
        } else {
            from_cursor_str_long(cursor).map(|query| BodyReqPrepare::new(query.into(), None))
        }
    }
}

impl Envelope {
    pub fn new_req_prepare(
        query: String,
        keyspace: Option<String>,
        flags: Flags,
        version: Version,
    ) -> Envelope {
        let direction = Direction::Request;
        let opcode = Opcode::Prepare;
        let body = BodyReqPrepare::new(query, keyspace);

        Envelope::new(
            version,
            direction,
            flags,
            opcode,
            0,
            body.serialize_to_vec(version),
            None,
            vec![],
        )
    }
}

#[cfg(test)]
mod tests {
    use crate::frame::message_prepare::BodyReqPrepare;
    use crate::frame::{FromCursor, Serialize, Version};
    use std::io::Cursor;

    #[test]
    fn should_deserialize_body() {
        let data = [0, 0, 0, 3, 102, 111, 111, 0];
        let mut cursor = Cursor::new(data.as_slice());

        let body = BodyReqPrepare::from_cursor(&mut cursor, Version::V4).unwrap();
        assert_eq!(body.query, "foo");
    }

    #[test]
    fn should_support_keyspace() {
        let keyspace = "abc";
        let query = "test";

        let body = BodyReqPrepare::new(query.into(), Some(keyspace.into()));

        // V4 has no keyspace field, so it is dropped on the wire.
        let data_v4 = body.serialize_to_vec(Version::V4);
        let body_v4 =
            BodyReqPrepare::from_cursor(&mut Cursor::new(data_v4.as_slice()), Version::V4).unwrap();
        assert_eq!(body_v4.query, query);
        assert_eq!(body_v4.keyspace, None);

        let data_v5 = body.serialize_to_vec(Version::V5);
        let body_v5 =
            BodyReqPrepare::from_cursor(&mut Cursor::new(data_v5.as_slice()), Version::V5).unwrap();
        assert_eq!(body_v5.query, query);
        assert_eq!(body_v5.keyspace, Some(keyspace.to_string()));
    }
}

================================================
FILE: cassandra-protocol/src/frame/message_query.rs
================================================
use crate::consistency::Consistency;
use crate::error;
use crate::frame::traits::FromCursor;
use crate::frame::{Direction, Envelope, Flags, Opcode, Serialize, Version};
use crate::query::{QueryParams, QueryValues};
use crate::types::{from_cursor_str_long, serialize_str_long, CBytes, CInt, CLong, INT_LEN};
use std::io::Cursor;

/// Structure which represents the body of a `QUERY` request.
#[derive(Debug, PartialEq, Eq, Clone, Default)]
pub struct BodyReqQuery {
    /// Query string.
    pub query: String,
    /// Query parameters.
    pub query_params: QueryParams,
}

impl BodyReqQuery {
    #[allow(clippy::too_many_arguments)]
    fn new(
        query: String,
        consistency: Consistency,
        values: Option<QueryValues>,
        with_names: bool,
        page_size: Option<CInt>,
        paging_state: Option<CBytes>,
        serial_consistency: Option<Consistency>,
        timestamp: Option<CLong>,
        keyspace: Option<String>,
        now_in_seconds: Option<CInt>,
    ) -> BodyReqQuery {
        BodyReqQuery {
            query,
            query_params: QueryParams {
                consistency,
                with_names,
                values,
                page_size,
                paging_state,
                serial_consistency,
                timestamp,
                keyspace,
                now_in_seconds,
            },
        }
    }
}

impl FromCursor for BodyReqQuery {
    fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> {
        let query = from_cursor_str_long(cursor)?.to_string();
        let query_params = QueryParams::from_cursor(cursor, version)?;

        Ok(BodyReqQuery {
            query,
            query_params,
        })
    }
}

impl Serialize for BodyReqQuery {
    #[inline]
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        serialize_str_long(cursor, &self.query, version);
        self.query_params.serialize(cursor, version);
    }

    #[inline]
    fn serialize_to_vec(&self, version: Version) -> Vec<u8> {
        let mut buf = Vec::with_capacity(INT_LEN + self.query.len());

        self.serialize(&mut Cursor::new(&mut buf), version);
        buf
    }
}

impl Envelope {
    #[allow(clippy::too_many_arguments)]
    pub fn new_req_query(
        query: String,
        consistency: Consistency,
        values: Option<QueryValues>,
        with_names: bool,
        page_size: Option<CInt>,
        paging_state: Option<CBytes>,
        serial_consistency: Option<Consistency>,
        timestamp: Option<CLong>,
        keyspace: Option<String>,
        now_in_seconds: Option<CInt>,
        flags: Flags,
        version: Version,
    ) -> Envelope {
        let direction = Direction::Request;
        let opcode = Opcode::Query;
        let body = BodyReqQuery::new(
            query,
            consistency,
            values,
            with_names,
            page_size,
            paging_state,
            serial_consistency,
            timestamp,
            keyspace,
            now_in_seconds,
        );

        Envelope::new(
            version,
            direction,
            flags,
            opcode,
            0,
            body.serialize_to_vec(version),
            None,
            vec![],
        )
    }

    #[inline]
    pub fn new_query(query: BodyReqQuery, flags: Flags, version: Version) -> Envelope {
        Envelope::new_req_query(
            query.query,
            query.query_params.consistency,
            query.query_params.values,
            query.query_params.with_names,
            query.query_params.page_size,
            query.query_params.paging_state,
            query.query_params.serial_consistency,
            query.query_params.timestamp,
            query.query_params.keyspace,
            query.query_params.now_in_seconds,
            flags,
            version,
        )
    }
}

================================================
FILE: cassandra-protocol/src/frame/message_ready.rs
================================================
use crate::error;
use crate::frame::{FromCursor, Serialize, Version};
use std::io::Cursor;

#[derive(Clone, Debug, PartialEq, Default, Ord, PartialOrd, Eq, Hash, Copy)]
pub struct BodyResReady;

impl Serialize for BodyResReady {
    #[inline(always)]
    fn serialize(&self, _cursor: &mut Cursor<&mut Vec<u8>>, _version: Version) {}
}

impl FromCursor for BodyResReady {
    #[inline(always)]
    fn from_cursor(_cursor: &mut Cursor<&[u8]>, _version: Version) -> error::Result<Self> {
        Ok(BodyResReady)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn body_res_ready_new() {
        let body: BodyResReady = Default::default();
        assert_eq!(body, BodyResReady);
    }

    #[test]
    fn body_res_ready_serialize() {
        let body = BodyResReady;
        assert!(body.serialize_to_vec(Version::V4).is_empty());
    }
}

================================================
FILE: cassandra-protocol/src/frame/message_register.rs
================================================
use crate::error;
use crate::frame::events::SimpleServerEvent;
use crate::frame::{Direction, Envelope, Flags, FromCursor, Opcode, Serialize, Version};
use crate::types::{from_cursor_string_list, serialize_str_list};
use derive_more::Constructor;
use itertools::Itertools;
use std::convert::TryFrom;
use std::io::Cursor;

/// The structure which represents the body of an envelope of type `register`.
#[derive(Debug, Constructor, Default, Ord, PartialOrd, Eq, PartialEq, Hash, Clone)]
pub struct BodyReqRegister {
    pub events: Vec<SimpleServerEvent>,
}

impl Serialize for BodyReqRegister {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        let events = self.events.iter().map(|event| event.as_str());
        serialize_str_list(cursor, events, version);
    }
}

impl FromCursor for BodyReqRegister {
    fn from_cursor(cursor: &mut Cursor<&[u8]>, _version: Version) -> error::Result<Self> {
        let events = from_cursor_string_list(cursor)?;
        // Fails on the first event name that is not a known `SimpleServerEvent`.
        events
            .iter()
            .map(|event| SimpleServerEvent::try_from(event.as_str()))
            .try_collect()
            .map(BodyReqRegister::new)
    }
}

impl Envelope {
    /// Creates a new envelope of type `REGISTER`.
    pub fn new_req_register(events: Vec<SimpleServerEvent>, version: Version) -> Envelope {
        let direction = Direction::Request;
        let opcode = Opcode::Register;
        let register_body = BodyReqRegister::new(events);

        Envelope::new(
            version,
            direction,
            Flags::empty(),
            opcode,
            0,
            register_body.serialize_to_vec(version),
            None,
            vec![],
        )
    }
}

#[cfg(test)]
mod tests {
    use crate::events::SimpleServerEvent;
    use crate::frame::message_register::BodyReqRegister;
    use crate::frame::{FromCursor, Version};
    use std::io::Cursor;

    #[test]
    fn should_deserialize_body() {
        let data = [
            0, 1, 0, 15, 0x54, 0x4f, 0x50, 0x4f, 0x4c, 0x4f, 0x47, 0x59, 0x5f, 0x43, 0x48, 0x41,
            0x4e, 0x47, 0x45,
        ];
        let mut cursor = Cursor::new(data.as_slice());

        let body = BodyReqRegister::from_cursor(&mut cursor, Version::V4).unwrap();
        assert_eq!(body.events, vec![SimpleServerEvent::TopologyChange]);
    }
}

================================================
FILE: cassandra-protocol/src/frame/message_request.rs
================================================
use std::io::Cursor;

use crate::frame::message_auth_response::BodyReqAuthResponse;
use crate::frame::message_batch::BodyReqBatch;
use crate::frame::message_execute::BodyReqExecuteOwned;
use crate::frame::message_options::BodyReqOptions;
use crate::frame::message_prepare::BodyReqPrepare;
use crate::frame::message_query::BodyReqQuery;
use crate::frame::message_register::BodyReqRegister;
use crate::frame::message_startup::BodyReqStartup;
use crate::frame::{FromCursor, Opcode, Serialize, Version};
use crate::{error, Error};

#[derive(Debug, PartialEq, Eq, Clone)]
#[allow(clippy::large_enum_variant)]
#[non_exhaustive]
pub enum RequestBody {
    Startup(BodyReqStartup),
    Options(BodyReqOptions),
    Query(BodyReqQuery),
    Prepare(BodyReqPrepare),
    Execute(BodyReqExecuteOwned),
    Register(BodyReqRegister),
    Batch(BodyReqBatch),
    AuthResponse(BodyReqAuthResponse),
}

impl Serialize for RequestBody {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        match self {
            RequestBody::Query(body) => body.serialize(cursor, version),
            RequestBody::Startup(body) => body.serialize(cursor, version),
            RequestBody::Options(body) => body.serialize(cursor, version),
            RequestBody::Prepare(body) => body.serialize(cursor, version),
            RequestBody::Execute(body) => body.serialize(cursor, version),
            RequestBody::Register(body) => body.serialize(cursor, version),
            RequestBody::Batch(body) => body.serialize(cursor, version),
            RequestBody::AuthResponse(body) => body.serialize(cursor, version),
        }
    }
}

impl RequestBody {
    pub fn try_from(
        bytes: &[u8],
        response_type: Opcode,
        version: Version,
    ) -> error::Result<RequestBody> {
        let mut cursor: Cursor<&[u8]> = Cursor::new(bytes);
        match response_type {
            Opcode::Startup => {
                BodyReqStartup::from_cursor(&mut cursor, version).map(RequestBody::Startup)
            }
            Opcode::Options => {
                BodyReqOptions::from_cursor(&mut cursor, version).map(RequestBody::Options)
            }
            Opcode::Query => {
                BodyReqQuery::from_cursor(&mut cursor, version).map(RequestBody::Query)
            }
            Opcode::Prepare => {
                BodyReqPrepare::from_cursor(&mut cursor, version).map(RequestBody::Prepare)
            }
            Opcode::Execute => {
                BodyReqExecuteOwned::from_cursor(&mut cursor, version).map(RequestBody::Execute)
            }
            Opcode::Register => {
                BodyReqRegister::from_cursor(&mut cursor, version).map(RequestBody::Register)
            }
            Opcode::Batch => {
                BodyReqBatch::from_cursor(&mut cursor, version).map(RequestBody::Batch)
            }
            Opcode::AuthResponse => BodyReqAuthResponse::from_cursor(&mut cursor, version)
                .map(RequestBody::AuthResponse),
            _ => Err(Error::NonRequestOpcode(response_type)),
        }
    }
}

================================================
FILE: cassandra-protocol/src/frame/message_response.rs
================================================
use std::io::Cursor;

use crate::frame::message_auth_challenge::BodyResAuthChallenge;
use crate::frame::message_auth_success::BodyReqAuthSuccess;
use crate::frame::message_authenticate::BodyResAuthenticate;
use crate::frame::message_error::ErrorBody;
use crate::frame::message_event::BodyResEvent;
use crate::frame::message_result::{
    BodyResResultPrepared, BodyResResultRows, BodyResResultSetKeyspace, ResResultBody,
    RowsMetadata,
};
use crate::frame::message_supported::BodyResSupported;
use crate::frame::{FromCursor, Opcode, Version};
use crate::types::rows::Row;
use crate::{error, Error};

#[derive(Debug, PartialEq, Eq, Clone)]
#[non_exhaustive]
pub enum ResponseBody {
    Error(ErrorBody),
    Ready,
    Authenticate(BodyResAuthenticate),
    Supported(BodyResSupported),
    Result(ResResultBody),
    Event(BodyResEvent),
    AuthChallenge(BodyResAuthChallenge),
    AuthSuccess(BodyReqAuthSuccess),
}

// This implementation is incomplete so only enable in tests
#[cfg(test)]
use crate::frame::Serialize;

#[cfg(test)]
impl Serialize for ResponseBody {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        match self {
            ResponseBody::Error(error_body) => {
                error_body.serialize(cursor, version);
            }
            ResponseBody::Ready => {}
            ResponseBody::Authenticate(auth) => {
                auth.serialize(cursor, version);
            }
            ResponseBody::Supported(supported) => {
                supported.serialize(cursor, version);
            }
            ResponseBody::Result(result) => {
                result.serialize(cursor, version);
            }
            ResponseBody::Event(event) => {
                event.serialize(cursor, version);
            }
            ResponseBody::AuthChallenge(auth_challenge) => {
                auth_challenge.serialize(cursor, version);
            }
ResponseBody::AuthSuccess(auth_success) => { auth_success.serialize(cursor, version); } } } } impl ResponseBody { pub fn try_from( bytes: &[u8], response_type: Opcode, version: Version, ) -> error::Result<ResponseBody> { let mut cursor: Cursor<&[u8]> = Cursor::new(bytes); match response_type { Opcode::Error => ErrorBody::from_cursor(&mut cursor, version).map(ResponseBody::Error), Opcode::Ready => Ok(ResponseBody::Ready), Opcode::Authenticate => BodyResAuthenticate::from_cursor(&mut cursor, version) .map(ResponseBody::Authenticate), Opcode::Supported => { BodyResSupported::from_cursor(&mut cursor, version).map(ResponseBody::Supported) } Opcode::Result => { ResResultBody::from_cursor(&mut cursor, version).map(ResponseBody::Result) } Opcode::Event => { BodyResEvent::from_cursor(&mut cursor, version).map(ResponseBody::Event) } Opcode::AuthChallenge => BodyResAuthChallenge::from_cursor(&mut cursor, version) .map(ResponseBody::AuthChallenge), Opcode::AuthSuccess => { BodyReqAuthSuccess::from_cursor(&mut cursor, version).map(ResponseBody::AuthSuccess) } _ => Err(Error::NonResponseOpcode(response_type)), } } pub fn into_rows(self) -> Option<Vec<Row>> { match self { ResponseBody::Result(res) => res.into_rows(), _ => None, } } pub fn as_rows_metadata(&self) -> Option<&RowsMetadata> { match self { ResponseBody::Result(res) => res.as_rows_metadata(), _ => None, } } pub fn as_cols(&self) -> Option<&BodyResResultRows> { match *self { ResponseBody::Result(ResResultBody::Rows(ref rows)) => Some(rows), _ => None, } } /// Unwraps the body and returns BodyResResultPrepared, which contains the result of a /// PREPARE query. pub fn into_prepared(self) -> Option<BodyResResultPrepared> { match self { ResponseBody::Result(res) => res.into_prepared(), _ => None, } } /// Unwraps the body and returns BodyResResultSetKeyspace, which contains the result of a /// USE keyspace query. 
pub fn into_set_keyspace(self) -> Option<BodyResResultSetKeyspace> { match self { ResponseBody::Result(res) => res.into_set_keyspace(), _ => None, } } /// Unwraps the body and returns BodyResEvent. pub fn into_server_event(self) -> Option<BodyResEvent> { match self { ResponseBody::Event(event) => Some(event), _ => None, } } pub fn authenticator(&self) -> Option<&str> { match *self { ResponseBody::Authenticate(ref auth) => Some(auth.data.as_str()), _ => None, } } pub fn into_error(self) -> Option<ErrorBody> { match self { ResponseBody::Error(err) => Some(err), _ => None, } } } ================================================ FILE: cassandra-protocol/src/frame/message_result.rs ================================================ use crate::error; use crate::error::Error; use crate::frame::events::SchemaChange; use crate::frame::{FromBytes, FromCursor, Serialize, Version}; use crate::types::rows::Row; use crate::types::{ from_cursor_str, serialize_str, try_i16_from_bytes, try_i32_from_bytes, try_u64_from_bytes, CBytes, CBytesShort, CInt, CIntShort, INT_LEN, SHORT_LEN, }; use bitflags::bitflags; use derive_more::{Constructor, Display}; use std::convert::{TryFrom, TryInto}; use std::io::{Cursor, Error as IoError, Read}; /// `ResultKind` is an enum that represents the kinds of result. #[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone, Hash, Display)] #[non_exhaustive] pub enum ResultKind { /// Void result. Void, /// Rows result. Rows, /// Set keyspace result. SetKeyspace, /// Prepared result. Prepared, /// Schema change result. 
SchemaChange, } impl Serialize for ResultKind { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { CInt::from(*self).serialize(cursor, version); } } impl FromBytes for ResultKind { fn from_bytes(bytes: &[u8]) -> error::Result<ResultKind> { try_i32_from_bytes(bytes) .map_err(Into::into) .and_then(ResultKind::try_from) } } impl From<ResultKind> for CInt { fn from(value: ResultKind) -> Self { match value { ResultKind::Void => 0x0001, ResultKind::Rows => 0x0002, ResultKind::SetKeyspace => 0x0003, ResultKind::Prepared => 0x0004, ResultKind::SchemaChange => 0x0005, } } } impl TryFrom<CInt> for ResultKind { type Error = Error; fn try_from(value: CInt) -> Result<Self, Self::Error> { match value { 0x0001 => Ok(ResultKind::Void), 0x0002 => Ok(ResultKind::Rows), 0x0003 => Ok(ResultKind::SetKeyspace), 0x0004 => Ok(ResultKind::Prepared), 0x0005 => Ok(ResultKind::SchemaChange), _ => Err(Error::UnexpectedResultKind(value)), } } } impl FromCursor for ResultKind { fn from_cursor(cursor: &mut Cursor<&[u8]>, _version: Version) -> error::Result<ResultKind> { let mut buff = [0; INT_LEN]; cursor.read_exact(&mut buff)?; let rk = CInt::from_be_bytes(buff); rk.try_into() } } /// `ResResultBody` is a generalized enum that represents all kinds of RESULT responses. Each /// variant wraps the related body type. #[derive(Debug, PartialEq, Eq, Clone, Hash)] #[non_exhaustive] pub enum ResResultBody { /// Void response body. It carries no data. Void, /// Rows response body. It represents a response body that contains rows. Rows(BodyResResultRows), /// Set keyspace body. It represents the body of a set_keyspace result and contains /// the name of the keyspace that was just set. SetKeyspace(BodyResResultSetKeyspace), /// Prepared response body. 
Prepared(BodyResResultPrepared), /// Schema change body. SchemaChange(SchemaChange), } impl Serialize for ResResultBody { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { match &self { ResResultBody::Void => { ResultKind::Void.serialize(cursor, version); } ResResultBody::Rows(rows) => { ResultKind::Rows.serialize(cursor, version); rows.serialize(cursor, version); } ResResultBody::SetKeyspace(set_keyspace) => { ResultKind::SetKeyspace.serialize(cursor, version); set_keyspace.serialize(cursor, version); } ResResultBody::Prepared(prepared) => { ResultKind::Prepared.serialize(cursor, version); prepared.serialize(cursor, version); } ResResultBody::SchemaChange(schema_change) => { ResultKind::SchemaChange.serialize(cursor, version); schema_change.serialize(cursor, version); } } } } impl ResResultBody { fn parse_body_from_cursor( cursor: &mut Cursor<&[u8]>, result_kind: ResultKind, version: Version, ) -> error::Result<ResResultBody> { Ok(match result_kind { ResultKind::Void => ResResultBody::Void, ResultKind::Rows => { ResResultBody::Rows(BodyResResultRows::from_cursor(cursor, version)?) } ResultKind::SetKeyspace => { ResResultBody::SetKeyspace(BodyResResultSetKeyspace::from_cursor(cursor, version)?) } ResultKind::Prepared => { ResResultBody::Prepared(BodyResResultPrepared::from_cursor(cursor, version)?) } ResultKind::SchemaChange => { ResResultBody::SchemaChange(SchemaChange::from_cursor(cursor, version)?) } }) } /// Converts the body into `Vec<Row>` if it is of kind `Rows` and returns `None` otherwise. 
pub fn into_rows(self) -> Option<Vec<Row>> { match self { ResResultBody::Rows(rows_body) => Some(Row::from_body(rows_body)), _ => None, } } /// Returns `Some` rows metadata if the envelope result is of kind rows and `None` otherwise. pub fn as_rows_metadata(&self) -> Option<&RowsMetadata> { match self { ResResultBody::Rows(rows_body) => Some(&rows_body.metadata), _ => None, } } /// Unwraps the body and returns BodyResResultPrepared, which contains the result of a /// PREPARE query. pub fn into_prepared(self) -> Option<BodyResResultPrepared> { match self { ResResultBody::Prepared(p) => Some(p), _ => None, } } /// Unwraps the body and returns BodyResResultSetKeyspace, which contains the result of a /// USE keyspace query. pub fn into_set_keyspace(self) -> Option<BodyResResultSetKeyspace> { match self { ResResultBody::SetKeyspace(p) => Some(p), _ => None, } } } impl ResResultBody { pub fn from_cursor( cursor: &mut Cursor<&[u8]>, version: Version, ) -> error::Result<ResResultBody> { let result_kind = ResultKind::from_cursor(cursor, version)?; ResResultBody::parse_body_from_cursor(cursor, result_kind, version) } } /// Set keyspace result body. It contains the keyspace name. #[derive(Debug, Constructor, PartialEq, Ord, PartialOrd, Eq, Clone, Hash)] pub struct BodyResResultSetKeyspace { /// Name of the keyspace that was set. pub body: String, } impl Serialize for BodyResResultSetKeyspace { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { serialize_str(cursor, &self.body, version); } } impl FromCursor for BodyResResultSetKeyspace { fn from_cursor( cursor: &mut Cursor<&[u8]>, _version: Version, ) -> error::Result<BodyResResultSetKeyspace> { from_cursor_str(cursor).map(|x| BodyResResultSetKeyspace::new(x.to_string())) } } /// Structure that represents a result of type /// [rows](https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L533). #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct BodyResResultRows { /// Rows metadata pub metadata: RowsMetadata, /// Number of rows. 
pub rows_count: CInt, /// From spec: it is composed of `rows_count` rows. pub rows_content: Vec<Vec<CBytes>>, /// Protocol version. pub protocol_version: Version, } impl Serialize for BodyResResultRows { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { self.metadata.serialize(cursor, version); self.rows_count.serialize(cursor, version); self.rows_content .iter() .flatten() .for_each(|x| x.serialize(cursor, version)); } } impl BodyResResultRows { fn rows_content( cursor: &mut Cursor<&[u8]>, rows_count: i32, columns_count: i32, version: Version, ) -> error::Result<Vec<Vec<CBytes>>> { (0..rows_count) .map(|_| { (0..columns_count) .map(|_| CBytes::from_cursor(cursor, version)) .collect::<Result<_, _>>() }) .collect::<Result<_, _>>() } } impl FromCursor for BodyResResultRows { fn from_cursor( cursor: &mut Cursor<&[u8]>, version: Version, ) -> error::Result<BodyResResultRows> { let metadata = RowsMetadata::from_cursor(cursor, version)?; let rows_count = CInt::from_cursor(cursor, version)?; let rows_content = BodyResResultRows::rows_content(cursor, rows_count, metadata.columns_count, version)?; Ok(BodyResResultRows { metadata, rows_count, rows_content, protocol_version: version, }) } } /// Rows metadata. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct RowsMetadata { /// Flags. pub flags: RowsMetadataFlags, /// Number of columns. pub columns_count: i32, /// Paging state. pub paging_state: Option<CBytes>, /// New, changed result set metadata. The new metadata ID must also be used in subsequent /// executions of the corresponding prepared statement, if any. pub new_metadata_id: Option<CBytesShort>, // In fact, by specification, Vec<String> should have only two elements representing the // (unique) keyspace name and table name the columns belong to /// `Option` that may contain global table space. pub global_table_spec: Option<TableSpec>, /// List of column specifications. 
pub col_specs: Vec<ColSpec>, } impl Serialize for RowsMetadata { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { // First we need to assert that the flags match up with the data we were provided. // If they don't match up, then it is impossible to encode. assert_eq!( self.flags.contains(RowsMetadataFlags::HAS_MORE_PAGES), self.paging_state.is_some() ); match ( self.flags.contains(RowsMetadataFlags::NO_METADATA), self.flags.contains(RowsMetadataFlags::GLOBAL_TABLE_SPACE), ) { (false, false) => { assert!(self.global_table_spec.is_none()); assert!(!self.col_specs.is_empty()); } (false, true) => { assert!(!self.col_specs.is_empty()); } (true, _) => { assert!(self.global_table_spec.is_none()); assert!(self.col_specs.is_empty()); } } self.flags.serialize(cursor, version); self.columns_count.serialize(cursor, version); if let Some(paging_state) = &self.paging_state { paging_state.serialize(cursor, version); } if let Some(new_metadata_id) = &self.new_metadata_id { new_metadata_id.serialize(cursor, version); } if let Some(global_table_spec) = &self.global_table_spec { global_table_spec.serialize(cursor, version); } self.col_specs .iter() .for_each(|x| x.serialize(cursor, version)); } } impl FromCursor for RowsMetadata { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<RowsMetadata> { let flags = RowsMetadataFlags::from_bits_truncate(CInt::from_cursor(cursor, version)?); let columns_count = CInt::from_cursor(cursor, version)?; let paging_state = if flags.contains(RowsMetadataFlags::HAS_MORE_PAGES) { Some(CBytes::from_cursor(cursor, version)?) } else { None }; if flags.contains(RowsMetadataFlags::NO_METADATA) { return Ok(RowsMetadata { flags, columns_count, paging_state, new_metadata_id: None, global_table_spec: None, col_specs: vec![], }); } let new_metadata_id = if flags.contains(RowsMetadataFlags::METADATA_CHANGED) { Some(CBytesShort::from_cursor(cursor, version)?) 
} else { None }; let has_global_table_space = flags.contains(RowsMetadataFlags::GLOBAL_TABLE_SPACE); let global_table_spec = extract_global_table_space(cursor, has_global_table_space, version)?; let col_specs = ColSpec::parse_colspecs(cursor, columns_count, has_global_table_space, version)?; Ok(RowsMetadata { flags, columns_count, paging_state, new_metadata_id, global_table_spec, col_specs, }) } } bitflags! { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct RowsMetadataFlags: i32 { const GLOBAL_TABLE_SPACE = 0x0001; const HAS_MORE_PAGES = 0x0002; const NO_METADATA = 0x0004; const METADATA_CHANGED = 0x0008; } } impl Serialize for RowsMetadataFlags { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { self.bits().serialize(cursor, version) } } impl From<RowsMetadataFlags> for i32 { fn from(value: RowsMetadataFlags) -> Self { value.bits() } } impl FromBytes for RowsMetadataFlags { fn from_bytes(bytes: &[u8]) -> error::Result<RowsMetadataFlags> { try_u64_from_bytes(bytes).map_err(Into::into).and_then(|f| { RowsMetadataFlags::from_bits(f as i32) .ok_or_else(|| "Unexpected rows metadata flag".into()) }) } } /// Table specification. #[derive(Debug, Clone, PartialEq, Ord, PartialOrd, Eq, Hash)] pub struct TableSpec { pub ks_name: String, pub table_name: String, } impl Serialize for TableSpec { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { serialize_str(cursor, &self.ks_name, version); serialize_str(cursor, &self.table_name, version); } } impl FromCursor for TableSpec { fn from_cursor(cursor: &mut Cursor<&[u8]>, _version: Version) -> error::Result<TableSpec> { let ks_name = from_cursor_str(cursor)?.to_string(); let table_name = from_cursor_str(cursor)?.to_string(); Ok(TableSpec { ks_name, table_name, }) } } /// Single column specification. 
#[derive(Debug, Clone, PartialEq, Ord, PartialOrd, Eq, Hash)] pub struct ColSpec { /// The initial `<ksname>` and `<tablename>` are strings and are only present /// if the Global_tables_spec flag is NOT set pub table_spec: Option<TableSpec>, /// Column name pub name: String, /// Column type, as defined in section 4.2.5.2 of the spec pub col_type: ColTypeOption, } impl Serialize for ColSpec { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { if let Some(table_spec) = &self.table_spec { table_spec.serialize(cursor, version); } serialize_str(cursor, &self.name, version); self.col_type.serialize(cursor, version); } } impl ColSpec { pub fn parse_colspecs( cursor: &mut Cursor<&[u8]>, column_count: i32, has_global_table_space: bool, version: Version, ) -> error::Result<Vec<ColSpec>> { (0..column_count) .map(|_| { let table_spec = if !has_global_table_space { Some(TableSpec::from_cursor(cursor, version)?) } else { None }; let name = from_cursor_str(cursor)?.to_string(); let col_type = ColTypeOption::from_cursor(cursor, version)?; Ok(ColSpec { table_spec, name, col_type, }) }) .collect::<Result<_, _>>() } } /// Cassandra data types that can be returned by a server. 
#[derive(Debug, Clone, Display, Copy, Ord, PartialOrd, Eq, PartialEq, Hash)] #[non_exhaustive] pub enum ColType { Custom, Ascii, Bigint, Blob, Boolean, Counter, Decimal, Double, Float, Int, Timestamp, Uuid, Varchar, Varint, Timeuuid, Inet, Date, Time, Smallint, Tinyint, Duration, List, Map, Set, Udt, Tuple, } impl TryFrom<CIntShort> for ColType { type Error = Error; fn try_from(value: CIntShort) -> Result<Self, Self::Error> { match value { 0x0000 => Ok(ColType::Custom), 0x0001 => Ok(ColType::Ascii), 0x0002 => Ok(ColType::Bigint), 0x0003 => Ok(ColType::Blob), 0x0004 => Ok(ColType::Boolean), 0x0005 => Ok(ColType::Counter), 0x0006 => Ok(ColType::Decimal), 0x0007 => Ok(ColType::Double), 0x0008 => Ok(ColType::Float), 0x0009 => Ok(ColType::Int), 0x000B => Ok(ColType::Timestamp), 0x000C => Ok(ColType::Uuid), 0x000D => Ok(ColType::Varchar), 0x000E => Ok(ColType::Varint), 0x000F => Ok(ColType::Timeuuid), 0x0010 => Ok(ColType::Inet), 0x0011 => Ok(ColType::Date), 0x0012 => Ok(ColType::Time), 0x0013 => Ok(ColType::Smallint), 0x0014 => Ok(ColType::Tinyint), 0x0015 => Ok(ColType::Duration), 0x0020 => Ok(ColType::List), 0x0021 => Ok(ColType::Map), 0x0022 => Ok(ColType::Set), 0x0030 => Ok(ColType::Udt), 0x0031 => Ok(ColType::Tuple), 0x0080 => Ok(ColType::Varchar), _ => Err(Error::UnexpectedColumnType(value)), } } } impl FromBytes for ColType { fn from_bytes(bytes: &[u8]) -> error::Result<ColType> { try_i16_from_bytes(bytes) .map_err(Into::into) .and_then(ColType::try_from) } } impl Serialize for ColType { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { (match self { ColType::Custom => 0x0000, ColType::Ascii => 0x0001, ColType::Bigint => 0x0002, ColType::Blob => 0x0003, ColType::Boolean => 0x0004, ColType::Counter => 0x0005, ColType::Decimal => 0x0006, ColType::Double => 0x0007, ColType::Float => 0x0008, ColType::Int => 0x0009, ColType::Timestamp => 0x000B, ColType::Uuid => 0x000C, ColType::Varchar => 0x000D, ColType::Varint => 0x000E, ColType::Timeuuid => 0x000F, ColType::Inet => 
0x0010, ColType::Date => 0x0011, ColType::Time => 0x0012, ColType::Smallint => 0x0013, ColType::Tinyint => 0x0014, ColType::Duration => 0x0015, ColType::List => 0x0020, ColType::Map => 0x0021, ColType::Set => 0x0022, ColType::Udt => 0x0030, ColType::Tuple => 0x0031, } as CIntShort) .serialize(cursor, version); } } impl FromCursor for ColType { fn from_cursor(cursor: &mut Cursor<&[u8]>, _version: Version) -> error::Result<ColType> { let mut buff = [0; SHORT_LEN]; cursor.read_exact(&mut buff)?; let t = CIntShort::from_be_bytes(buff); t.try_into() } } /// Cassandra option that represents a column type. #[derive(Debug, Clone, PartialEq, Ord, PartialOrd, Eq, Hash)] pub struct ColTypeOption { /// Id refers to `ColType`. pub id: ColType, /// Value that depends on the column type. pub value: Option<ColTypeOptionValue>, } impl Serialize for ColTypeOption { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { self.id.serialize(cursor, version); if let Some(value) = &self.value { value.serialize(cursor, version); } } } impl FromCursor for ColTypeOption { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<ColTypeOption> { let id = ColType::from_cursor(cursor, version)?; let value = match id { ColType::Custom => Some(ColTypeOptionValue::CString( from_cursor_str(cursor)?.to_string(), )), ColType::Set => { let col_type = ColTypeOption::from_cursor(cursor, version)?; Some(ColTypeOptionValue::CSet(Box::new(col_type))) } ColType::List => { let col_type = ColTypeOption::from_cursor(cursor, version)?; Some(ColTypeOptionValue::CList(Box::new(col_type))) } ColType::Udt => Some(ColTypeOptionValue::UdtType(CUdt::from_cursor( cursor, version, )?)), ColType::Tuple => Some(ColTypeOptionValue::TupleType(CTuple::from_cursor( cursor, version, )?)), ColType::Map => { let name_type = ColTypeOption::from_cursor(cursor, version)?; let value_type = ColTypeOption::from_cursor(cursor, version)?; Some(ColTypeOptionValue::CMap( Box::new(name_type), Box::new(value_type), )) } _ => None, }; 
Ok(ColTypeOption { id, value }) } } /// Enum that represents all possible types of `value` of `ColTypeOption`. #[derive(Debug, Clone, PartialEq, Ord, PartialOrd, Eq, Hash)] #[non_exhaustive] pub enum ColTypeOptionValue { CString(String), ColType(ColType), CSet(Box<ColTypeOption>), CList(Box<ColTypeOption>), UdtType(CUdt), TupleType(CTuple), CMap(Box<ColTypeOption>, Box<ColTypeOption>), } impl Serialize for ColTypeOptionValue { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { match self { Self::CString(c) => serialize_str(cursor, c, version), Self::ColType(c) => c.serialize(cursor, version), Self::CSet(c) => c.serialize(cursor, version), Self::CList(c) => c.serialize(cursor, version), Self::UdtType(c) => c.serialize(cursor, version), Self::TupleType(c) => c.serialize(cursor, version), Self::CMap(v1, v2) => { v1.serialize(cursor, version); v2.serialize(cursor, version); } } } } /// User defined type. #[derive(Debug, Clone, PartialEq, Ord, PartialOrd, Eq, Hash)] pub struct CUdt { /// Keyspace name. pub ks: String, /// Udt name pub udt_name: String, /// List of pairs `(name, type)`, where name is the field name and type is the field's type. 
pub descriptions: Vec<(String, ColTypeOption)>, } impl Serialize for CUdt { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { serialize_str(cursor, &self.ks, version); serialize_str(cursor, &self.udt_name, version); (self.descriptions.len() as i16).serialize(cursor, version); self.descriptions.iter().for_each(|(name, col_type)| { serialize_str(cursor, name, version); col_type.serialize(cursor, version); }); } } impl FromCursor for CUdt { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<CUdt> { let ks = from_cursor_str(cursor)?.to_string(); let udt_name = from_cursor_str(cursor)?.to_string(); let mut buff = [0; SHORT_LEN]; cursor.read_exact(&mut buff)?; let n = i16::from_be_bytes(buff); let mut descriptions = Vec::with_capacity(n as usize); for _ in 0..n { let name = from_cursor_str(cursor)?.to_string(); let col_type = ColTypeOption::from_cursor(cursor, version)?; descriptions.push((name, col_type)); } Ok(CUdt { ks, udt_name, descriptions, }) } } /// Tuple type. /// [Read more...](https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L608) #[derive(Debug, Clone, PartialEq, Ord, PartialOrd, Eq, Hash)] pub struct CTuple { /// List of types. 
pub types: Vec<ColTypeOption>, } impl Serialize for CTuple { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { (self.types.len() as i16).serialize(cursor, version); self.types.iter().for_each(|f| f.serialize(cursor, version)); } } impl FromCursor for CTuple { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<CTuple> { let mut buff = [0; SHORT_LEN]; cursor.read_exact(&mut buff)?; let n = i16::from_be_bytes(buff); let mut types = Vec::with_capacity(n as usize); for _ in 0..n { let col_type = ColTypeOption::from_cursor(cursor, version)?; types.push(col_type); } Ok(CTuple { types }) } } /// The structure represents the body of a response envelope of type `prepared`. #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct BodyResResultPrepared { /// ID of the prepared request pub id: CBytesShort, /// Result metadata ID (only available since V5) pub result_metadata_id: Option<CBytesShort>, /// Metadata pub metadata: PreparedMetadata, /// It is defined exactly the same as in the Rows /// documentation. pub result_metadata: RowsMetadata, } impl Serialize for BodyResResultPrepared { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { self.id.serialize(cursor, version); if let Some(result_metadata_id) = &self.result_metadata_id { result_metadata_id.serialize(cursor, version); } self.metadata.serialize(cursor, version); self.result_metadata.serialize(cursor, version); } } impl BodyResResultPrepared { fn from_cursor( cursor: &mut Cursor<&[u8]>, version: Version, ) -> error::Result<BodyResResultPrepared> { let id = CBytesShort::from_cursor(cursor, version)?; let result_metadata_id = if version == Version::V5 { Some(CBytesShort::from_cursor(cursor, version)?) } else { None }; let metadata = PreparedMetadata::from_cursor(cursor, version)?; let result_metadata = RowsMetadata::from_cursor(cursor, version)?; Ok(BodyResResultPrepared { id, result_metadata_id, metadata, result_metadata, }) } } bitflags! 
{ pub struct PreparedMetadataFlags: i32 { const GLOBAL_TABLE_SPACE = 0x0001; } } impl Serialize for PreparedMetadataFlags { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { self.bits().serialize(cursor, version); } } /// The structure that represents the metadata of a prepared response. #[derive(Debug, Clone, PartialEq, Ord, PartialOrd, Eq, Hash)] pub struct PreparedMetadata { pub pk_indexes: Vec<i16>, pub global_table_spec: Option<TableSpec>, pub col_specs: Vec<ColSpec>, } impl Serialize for PreparedMetadata { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { if self.global_table_spec.is_some() { PreparedMetadataFlags::GLOBAL_TABLE_SPACE } else { PreparedMetadataFlags::empty() } .serialize(cursor, version); let columns_count = self.col_specs.len() as i32; columns_count.serialize(cursor, version); let pk_count = self.pk_indexes.len() as i32; pk_count.serialize(cursor, version); self.pk_indexes .iter() .for_each(|f| f.serialize(cursor, version)); if let Some(global_table_spec) = &self.global_table_spec { global_table_spec.serialize(cursor, version); } self.col_specs .iter() .for_each(|x| x.serialize(cursor, version)); } } impl PreparedMetadata { fn from_cursor( cursor: &mut Cursor<&[u8]>, version: Version, ) -> error::Result<PreparedMetadata> { let flags = PreparedMetadataFlags::from_bits_truncate(CInt::from_cursor(cursor, version)?); let columns_count = CInt::from_cursor(cursor, version)?; let pk_count = if let Version::V3 = version { 0 } else { // v4 or v5 CInt::from_cursor(cursor, version)? 
}; let pk_indexes = (0..pk_count) .map(|_| { let mut buff = [0; SHORT_LEN]; cursor.read_exact(&mut buff)?; Ok(i16::from_be_bytes(buff)) }) .collect::<Result<Vec<_>, IoError>>()?; let has_global_table_space = flags.contains(PreparedMetadataFlags::GLOBAL_TABLE_SPACE); let global_table_spec = extract_global_table_space(cursor, has_global_table_space, version)?; let col_specs = ColSpec::parse_colspecs(cursor, columns_count, has_global_table_space, version)?; Ok(PreparedMetadata { pk_indexes, global_table_spec, col_specs, }) } } fn extract_global_table_space( cursor: &mut Cursor<&[u8]>, has_global_table_space: bool, version: Version, ) -> error::Result<Option<TableSpec>> { Ok(if has_global_table_space { Some(TableSpec::from_cursor(cursor, version)?) } else { None }) } //noinspection DuplicatedCode #[cfg(test)] fn test_encode_decode(bytes: &[u8], expected: ResResultBody) { { let mut cursor: Cursor<&[u8]> = Cursor::new(bytes); let result = ResResultBody::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(expected, result); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected.serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } } #[cfg(test)] mod cudt { use super::*; //noinspection DuplicatedCode #[test] fn cudt() { let bytes = &[ 0, 3, 98, 97, 114, // keyspace name - bar 0, 3, 102, 111, 111, // udt_name - foo 0, 2, // length // pair 1 0, 3, 98, 97, 114, //name - bar 0, 9, // col type int // // // pair 2 0, 3, 102, 111, 111, // name - foo 0, 9, // col type int ]; let expected = CUdt { ks: "bar".into(), udt_name: "foo".into(), descriptions: vec![ ( "bar".into(), ColTypeOption { id: ColType::Int, value: None, }, ), ( "foo".into(), ColTypeOption { id: ColType::Int, value: None, }, ), ], }; { let mut cursor: Cursor<&[u8]> = Cursor::new(bytes); let udt = CUdt::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(udt, expected); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected.serialize(&mut cursor, Version::V4); 
assert_eq!(buffer, bytes); } } } #[cfg(test)] //noinspection DuplicatedCode mod ctuple { use super::*; #[test] fn ctuple() { let bytes = &[0, 3, 0, 9, 0, 9, 0, 9]; let expected = CTuple { types: vec![ ColTypeOption { id: ColType::Int, value: None, }; 3 ], }; { let mut cursor: Cursor<&[u8]> = Cursor::new(bytes); let tuple = CTuple::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(tuple, expected); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected.serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } } } #[cfg(test)] //noinspection DuplicatedCode mod col_spec { use super::*; #[test] fn col_spec_with_table_spec() { let bytes = &[ // table spec 0, 3, 98, 97, 114, // bar 0, 3, 102, 111, 111, //foo // 0, 3, 102, 111, 111, //name - foo // 0, 9, // col type - int ]; let expected = vec![ColSpec { table_spec: Some(TableSpec { ks_name: "bar".into(), table_name: "foo".into(), }), name: "foo".into(), col_type: ColTypeOption { id: ColType::Int, value: None, }, }]; { let mut cursor: Cursor<&[u8]> = Cursor::new(bytes); let col_spec = ColSpec::parse_colspecs(&mut cursor, 1, false, Version::V4).unwrap(); assert_eq!(col_spec, expected); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected[0].serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } } #[test] fn col_spec_without_table_spec() { let bytes = &[ 0, 3, 102, 111, 111, //name - foo // 0, 9, // col type - int ]; let expected = vec![ColSpec { table_spec: None, name: "foo".into(), col_type: ColTypeOption { id: ColType::Int, value: None, }, }]; { let mut cursor: Cursor<&[u8]> = Cursor::new(bytes); let col_spec = ColSpec::parse_colspecs(&mut cursor, 1, true, Version::V4).unwrap(); assert_eq!(col_spec, expected); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected[0].serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } } } #[cfg(test)] //noinspection DuplicatedCode mod col_type_option { use 
super::*; #[test] fn col_type_options_int() { let bytes = &[0, 9]; let expected = ColTypeOption { id: ColType::Int, value: None, }; { let mut cursor: Cursor<&[u8]> = Cursor::new(bytes); let col_type_option = ColTypeOption::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(col_type_option, expected); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected.serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } } #[test] fn col_type_options_map() { let bytes = &[0, 33, 0, 9, 0, 9]; let expected = ColTypeOption { id: ColType::Map, value: Some(ColTypeOptionValue::CMap( Box::new(ColTypeOption { id: ColType::Int, value: None, }), Box::new(ColTypeOption { id: ColType::Int, value: None, }), )), }; { let mut cursor: Cursor<&[u8]> = Cursor::new(bytes); let col_type_option = ColTypeOption::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(col_type_option, expected); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected.serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } } } #[cfg(test)] //noinspection DuplicatedCode mod table_spec { use super::*; #[test] fn table_spec() { let bytes = &[ 0, 3, 98, 97, 114, // bar 0, 3, 102, 111, 111, //foo ]; let expected = TableSpec { ks_name: "bar".into(), table_name: "foo".into(), }; { let mut cursor: Cursor<&[u8]> = Cursor::new(bytes); let table_spec = TableSpec::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(table_spec, expected); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected.serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } } } #[cfg(test)] mod void { use super::*; #[test] fn test_void() { let bytes = &[0, 0, 0, 1]; let expected = ResResultBody::Void; test_encode_decode(bytes, expected); } } #[cfg(test)] //noinspection DuplicatedCode mod rows_metadata { use super::*; #[test] fn rows_metadata() { let bytes = &[ 0, 0, 0, 8, // rows metadata flag 0, 0, 0, 2, // columns count 
0, 1, 1, // new metadata id // // Col Spec 1 0, 7, 107, 115, 110, 97, 109, 101, 49, // ksname1 0, 9, 116, 97, 98, 108, 101, 110, 97, 109, 101, // tablename 0, 3, 102, 111, 111, // name 0, 9, // col type id // // Col spec 2 0, 7, 107, 115, 110, 97, 109, 101, 49, // ksname1 0, 9, 116, 97, 98, 108, 101, 110, 97, 109, 101, // tablename 0, 3, 98, 97, 114, // name 0, 19, // col type ]; let expected = RowsMetadata { flags: RowsMetadataFlags::METADATA_CHANGED, columns_count: 2, paging_state: None, new_metadata_id: Some(CBytesShort::new(vec![1])), global_table_spec: None, col_specs: vec![ ColSpec { table_spec: Some(TableSpec { ks_name: "ksname1".into(), table_name: "tablename".into(), }), name: "foo".into(), col_type: ColTypeOption { id: ColType::Int, value: None, }, }, ColSpec { table_spec: Some(TableSpec { ks_name: "ksname1".into(), table_name: "tablename".into(), }), name: "bar".into(), col_type: ColTypeOption { id: ColType::Smallint, value: None, }, }, ], }; { let mut cursor: Cursor<&[u8]> = Cursor::new(bytes); let metadata = RowsMetadata::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(metadata, expected); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected.serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } } } #[cfg(test)] //noinspection DuplicatedCode mod rows { use super::*; #[test] fn test_rows() { let bytes = &[ 0, 0, 0, 2, // rows flag 0, 0, 0, 0, // rows metadata flag 0, 0, 0, 2, // columns count // // Col Spec 1 0, 7, 107, 115, 110, 97, 109, 101, 49, // ksname1 0, 9, 116, 97, 98, 108, 101, 110, 97, 109, 101, // tablename 0, 3, 102, 111, 111, // name 0, 9, // col type id // // Col spec 2 0, 7, 107, 115, 110, 97, 109, 101, 49, // ksname1 0, 9, 116, 97, 98, 108, 101, 110, 97, 109, 101, // tablename 0, 3, 98, 97, 114, // name 0, 19, // col type 0, 0, 0, 0, // rows count ]; let expected = ResResultBody::Rows(BodyResResultRows { metadata: RowsMetadata { flags: RowsMetadataFlags::empty(), columns_count: 2, 
paging_state: None, new_metadata_id: None, global_table_spec: None, col_specs: vec![ ColSpec { table_spec: Some(TableSpec { ks_name: "ksname1".into(), table_name: "tablename".into(), }), name: "foo".into(), col_type: ColTypeOption { id: ColType::Int, value: None, }, }, ColSpec { table_spec: Some(TableSpec { ks_name: "ksname1".into(), table_name: "tablename".into(), }), name: "bar".into(), col_type: ColTypeOption { id: ColType::Smallint, value: None, }, }, ], }, rows_count: 0, rows_content: vec![], protocol_version: Version::V4, }); test_encode_decode(bytes, expected); } #[test] fn test_rows_no_metadata() { let bytes = &[ 0, 0, 0, 2, // rows flag 0, 0, 0, 4, // rows metadata flag 0, 0, 0, 3, // columns count 0, 0, 0, 0, // rows count ]; let expected = ResResultBody::Rows(BodyResResultRows { metadata: RowsMetadata { flags: RowsMetadataFlags::NO_METADATA, columns_count: 3, paging_state: None, new_metadata_id: None, global_table_spec: None, col_specs: vec![], }, rows_count: 0, rows_content: vec![], protocol_version: Version::V4, }); test_encode_decode(bytes, expected); } } #[cfg(test)] mod keyspace { use super::*; #[test] fn test_set_keyspace() { let bytes = &[ 0, 0, 0, 3, // keyspace flag 0, 4, 98, 108, 97, 104, // blah ]; let expected = ResResultBody::SetKeyspace(BodyResResultSetKeyspace { body: "blah".into(), }); test_encode_decode(bytes, expected); } } #[cfg(test)] //noinspection DuplicatedCode mod prepared_metadata { use super::*; #[test] fn prepared_metadata() { let bytes = &[ 0, 0, 0, 0, // global table space flag 0, 0, 0, 2, // columns counts 0, 0, 0, 1, // pk_count 0, 0, // pk_index // // col specs // col spec 1 0, 7, 107, 115, 110, 97, 109, 101, 49, // ksname1 0, 9, 116, 97, 98, 108, 101, 110, 97, 109, 101, // tablename 0, 3, 102, 111, 111, // foo 0, 9, // id // // col spec 2 0, 7, 107, 115, 110, 97, 109, 101, 49, // ksname1 0, 9, 116, 97, 98, 108, 101, 110, 97, 109, 101, // tablename 0, 3, 98, 97, 114, // bar 0, 19, // id ]; let expected = PreparedMetadata { 
pk_indexes: vec![0], global_table_spec: None, col_specs: vec![ ColSpec { table_spec: Some(TableSpec { ks_name: "ksname1".into(), table_name: "tablename".into(), }), name: "foo".into(), col_type: ColTypeOption { id: ColType::Int, value: None, }, }, ColSpec { table_spec: Some(TableSpec { ks_name: "ksname1".into(), table_name: "tablename".into(), }), name: "bar".into(), col_type: ColTypeOption { id: ColType::Smallint, value: None, }, }, ], }; { let mut cursor: Cursor<&[u8]> = Cursor::new(bytes); let metadata = PreparedMetadata::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(metadata, expected); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected.serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } } } #[cfg(test)] //noinspection DuplicatedCode mod prepared { use super::*; use crate::types::{to_short, CBytesShort}; #[test] fn test_prepared() { let bytes = &[ 0, 0, 0, 4, // prepared 0, 2, 0, 1, // id // // prepared flags 0, 0, 0, 0, // global table space flag 0, 0, 0, 2, // columns counts 0, 0, 0, 1, // pk_count 0, 0, // pk_index // // col specs // col spec 1 0, 7, 107, 115, 110, 97, 109, 101, 49, // ksname1 0, 9, 116, 97, 98, 108, 101, 110, 97, 109, 101, // tablename 0, 3, 102, 111, 111, // foo 0, 9, // id // // col spec 2 0, 7, 107, 115, 110, 97, 109, 101, 49, // ksname1 0, 9, 116, 97, 98, 108, 101, 110, 97, 109, 101, // tablename 0, 3, 98, 97, 114, // bar 0, 19, // id // // rows metadata 0, 0, 0, 0, // empty flags 0, 0, 0, 2, // columns count 0, 7, 107, 115, 110, 97, 109, 101, 49, // ksname1 0, 9, 116, 97, 98, 108, 101, 110, 97, 109, 101, // tablename 0, 3, 102, 111, 111, // foo 0, 9, // int 0, 7, 107, 115, 110, 97, 109, 101, 49, // ksname1 0, 9, 116, 97, 98, 108, 101, 110, 97, 109, 101, // tablename 0, 3, 98, 97, 114, // bar 0, 19, // id ]; let expected = ResResultBody::Prepared(BodyResResultPrepared { id: CBytesShort::new(to_short(1)), result_metadata_id: None, metadata: PreparedMetadata { pk_indexes: 
vec![0], global_table_spec: None, col_specs: vec![ ColSpec { table_spec: Some(TableSpec { ks_name: "ksname1".into(), table_name: "tablename".into(), }), name: "foo".into(), col_type: ColTypeOption { id: ColType::Int, value: None, }, }, ColSpec { table_spec: Some(TableSpec { ks_name: "ksname1".into(), table_name: "tablename".into(), }), name: "bar".into(), col_type: ColTypeOption { id: ColType::Smallint, value: None, }, }, ], }, result_metadata: RowsMetadata { flags: RowsMetadataFlags::empty(), columns_count: 2, paging_state: None, new_metadata_id: None, global_table_spec: None, col_specs: vec![ ColSpec { table_spec: Some(TableSpec { ks_name: "ksname1".into(), table_name: "tablename".into(), }), name: "foo".into(), col_type: ColTypeOption { id: ColType::Int, value: None, }, }, ColSpec { table_spec: Some(TableSpec { table_name: "tablename".into(), ks_name: "ksname1".into(), }), name: "bar".into(), col_type: ColTypeOption { id: ColType::Smallint, value: None, }, }, ], }, }); test_encode_decode(bytes, expected); } } #[cfg(test)] mod schema_change { use super::*; use crate::frame::events::{SchemaChangeOptions, SchemaChangeTarget, SchemaChangeType}; #[test] fn test_schema_change() { let bytes = &[ 0, 0, 0, 5, // schema change 0, 7, 67, 82, 69, 65, 84, 69, 68, // change type - created 0, 8, 75, 69, 89, 83, 80, 65, 67, 69, // target keyspace 0, 4, 98, 108, 97, 104, // options - blah ]; let expected = ResResultBody::SchemaChange(SchemaChange { change_type: SchemaChangeType::Created, target: SchemaChangeTarget::Keyspace, options: SchemaChangeOptions::Keyspace("blah".into()), }); test_encode_decode(bytes, expected); } } ================================================ FILE: cassandra-protocol/src/frame/message_startup.rs ================================================ use crate::error; use crate::frame::{Direction, Envelope, Flags, FromCursor, Opcode, Serialize, Version}; use crate::types::{from_cursor_str, serialize_str, CIntShort}; use std::collections::HashMap; use 
std::io::Cursor; const CQL_VERSION: &str = "CQL_VERSION"; const CQL_VERSION_VAL: &str = "3.0.0"; const COMPRESSION: &str = "COMPRESSION"; const DRIVER_NAME: &str = "DRIVER_NAME"; const DRIVER_VERSION: &str = "DRIVER_VERSION"; #[derive(Debug, PartialEq, Eq, Default, Clone)] pub struct BodyReqStartup { pub map: HashMap<String, String>, } impl BodyReqStartup { pub fn new(compression: Option<String>, version: Version) -> BodyReqStartup { let mut map = HashMap::new(); map.insert(CQL_VERSION.into(), CQL_VERSION_VAL.into()); if let Some(c) = compression { map.insert(COMPRESSION.into(), c); } if version >= Version::V5 { map.insert(DRIVER_NAME.into(), "cdrs-tokio".into()); if let Some(version) = option_env!("CARGO_PKG_VERSION") { map.insert(DRIVER_VERSION.into(), version.into()); } } BodyReqStartup { map } } } impl Serialize for BodyReqStartup { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { let num = self.map.len() as CIntShort; num.serialize(cursor, version); for (key, val) in &self.map { serialize_str(cursor, key, version); serialize_str(cursor, val, version); } } } impl FromCursor for BodyReqStartup { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<BodyReqStartup> { let num = CIntShort::from_cursor(cursor, version)?; let mut map = HashMap::with_capacity(num as usize); for _ in 0..num { map.insert( from_cursor_str(cursor)?.to_string(), from_cursor_str(cursor)?.to_string(), ); } Ok(BodyReqStartup { map }) } } impl Envelope { /// Creates a new envelope of type `startup`.
pub fn new_req_startup(compression: Option<String>, version: Version) -> Envelope { let direction = Direction::Request; let opcode = Opcode::Startup; let body = BodyReqStartup::new(compression, version); Envelope::new( version, direction, Flags::empty(), opcode, 0, body.serialize_to_vec(version), None, vec![], ) } } #[cfg(test)] mod test { use super::*; use crate::frame::{Envelope, Flags, Opcode, Version}; #[test] fn new_body_req_startup_some_compression() { let compression = "test_compression"; let body = BodyReqStartup::new(Some(compression.into()), Version::V4); assert_eq!( body.map.get("CQL_VERSION"), Some("3.0.0".to_string()).as_ref() ); assert_eq!( body.map.get("COMPRESSION"), Some(compression.to_string()).as_ref() ); assert_eq!(body.map.len(), 2); } #[test] fn new_body_req_startup_none_compression() { let body = BodyReqStartup::new(None, Version::V4); assert_eq!( body.map.get("CQL_VERSION"), Some("3.0.0".to_string()).as_ref() ); assert_eq!(body.map.len(), 1); } #[test] fn new_req_startup() { let compression = Some("test_compression".to_string()); let frame = Envelope::new_req_startup(compression, Version::V4); assert_eq!(frame.version, Version::V4); assert_eq!(frame.flags, Flags::empty()); assert_eq!(frame.opcode, Opcode::Startup); assert_eq!(frame.tracing_id, None); assert!(frame.warnings.is_empty()); } #[test] fn body_req_startup_from_cursor() { let bytes = vec![ 0, 3, 0, 11, 68, 82, 73, 86, 69, 82, 95, 78, 65, 77, 69, 0, 22, 68, 97, 116, 97, 83, 116, 97, 120, 32, 80, 121, 116, 104, 111, 110, 32, 68, 114, 105, 118, 101, 114, 0, 14, 68, 82, 73, 86, 69, 82, 95, 86, 69, 82, 83, 73, 79, 78, 0, 6, 51, 46, 50, 53, 46, 48, 0, 11, 67, 81, 76, 95, 86, 69, 82, 83, 73, 79, 78, 0, 5, 51, 46, 52, 46, 53, ]; let mut cursor = Cursor::new(bytes.as_slice()); let body = BodyReqStartup::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(body.map.len(), 3); assert_eq!(body.map.get("CQL_VERSION").map(String::as_str), Some("3.4.5")); } } ================================================ FILE: cassandra-protocol/src/frame/message_supported.rs ================================================ use
super::Serialize; use crate::error; use crate::frame::{FromCursor, Version}; use crate::types::{from_cursor_str, from_cursor_string_list, serialize_str, CIntShort, SHORT_LEN}; use std::collections::HashMap; use std::io::{Cursor, Read}; #[derive(Debug, PartialEq, Eq, Clone, Default)] pub struct BodyResSupported { pub data: HashMap<String, Vec<String>>, } impl Serialize for BodyResSupported { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { (self.data.len() as CIntShort).serialize(cursor, version); self.data.iter().for_each(|(key, value)| { serialize_str(cursor, key.as_str(), version); (value.len() as CIntShort).serialize(cursor, version); value .iter() .for_each(|s| serialize_str(cursor, s.as_str(), version)); }) } } impl FromCursor for BodyResSupported { fn from_cursor( cursor: &mut Cursor<&[u8]>, _version: Version, ) -> error::Result<BodyResSupported> { let mut buff = [0; SHORT_LEN]; cursor.read_exact(&mut buff)?; let l = i16::from_be_bytes(buff) as usize; let mut data: HashMap<String, Vec<String>> = HashMap::with_capacity(l); for _ in 0..l { let name = from_cursor_str(cursor)?.to_string(); let val = from_cursor_string_list(cursor)?; data.insert(name, val); } Ok(BodyResSupported { data }) } } #[cfg(test)] mod tests { use super::*; use crate::frame::traits::FromCursor; use crate::frame::Version; use std::io::Cursor; #[test] fn body_res_supported() { let bytes = [ 0, 1, // n options // 1st option 0, 2, 97, 98, // key [string] "ab" 0, 2, 0, 1, 97, 0, 1, 98, /* value ["a", "b"] */ ]; let mut data: HashMap<String, Vec<String>> = HashMap::new(); data.insert("ab".into(), vec!["a".into(), "b".into()]); let expected = BodyResSupported { data }; { let mut cursor: Cursor<&[u8]> = Cursor::new(&bytes); let supported = BodyResSupported::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(supported, expected); } { let mut buffer = Vec::new(); let mut cursor = Cursor::new(&mut buffer); expected.serialize(&mut cursor, Version::V4); assert_eq!(buffer, bytes); } } } ================================================ FILE:
cassandra-protocol/src/frame/traits.rs ================================================ use crate::error; use crate::frame::Version; use crate::query; use num_bigint::BigInt; use std::io::{Cursor, Write}; /// Trait that should be implemented by all types that wish to be serialized to a buffer. pub trait Serialize { /// Serializes the given value using the cursor. fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version); /// Wrapper for easily starting hierarchical serialization. fn serialize_to_vec(&self, version: Version) -> Vec<u8> { let mut buf = vec![]; self.serialize(&mut Cursor::new(&mut buf), version); buf } } /// `FromBytes` should be used to parse an array of bytes into a structure. pub trait FromBytes { /// Takes an array of bytes and returns the parsed implementor. fn from_bytes(bytes: &[u8]) -> error::Result<Self> where Self: Sized; } /// `FromCursor` should be used to get a parsed structure from an `io::Cursor` /// that is bound to an array of bytes. pub trait FromCursor { /// Tries to parse Self from a cursor of bytes. fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> where Self: Sized; } /// The trait that allows transformation of `Self` to CDRS query values.
pub trait IntoQueryValues { fn into_query_values(self) -> query::QueryValues; } pub trait TryFromRow: Sized { fn try_from_row(row: crate::types::rows::Row) -> error::Result<Self>; } pub trait TryFromUdt: Sized { fn try_from_udt(udt: crate::types::udt::Udt) -> error::Result<Self>; } impl<const S: usize> Serialize for [u8; S] { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, _version: Version) { let _ = cursor.write(self); } } impl Serialize for &[u8] { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, _version: Version) { let _ = cursor.write(self); } } impl Serialize for Vec<u8> { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, _version: Version) { let _ = cursor.write(self); } } impl Serialize for BigInt { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, _version: Version) { let _ = cursor.write(&self.to_signed_bytes_be()); } } macro_rules! impl_serialized { ($t:ty) => { impl Serialize for $t { #[inline] fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, _version: Version) { let _ = cursor.write(&self.to_be_bytes()); } } }; } impl_serialized!(i8); impl_serialized!(i16); impl_serialized!(i32); impl_serialized!(i64); impl_serialized!(u8); impl_serialized!(u16); impl_serialized!(u32); impl_serialized!(u64); ================================================ FILE: cassandra-protocol/src/frame.rs ================================================ use crate::compression::{Compression, CompressionError}; use crate::frame::message_request::RequestBody; use crate::frame::message_response::ResponseBody; use crate::types::data_serialization_types::decode_timeuuid; use crate::types::{from_cursor_string_list, try_i16_from_bytes, try_i32_from_bytes, UUID_LEN}; use bitflags::bitflags; use derivative::Derivative; use derive_more::{Constructor, Display}; use std::convert::TryFrom; use std::io::Cursor; use thiserror::Error; use uuid::Uuid; pub use crate::frame::traits::*; /// Number of bytes in the envelope header. const ENVELOPE_HEADER_LEN: usize = 9; /// Number of stream bytes in
accordance with the protocol. pub const STREAM_LEN: usize = 2; /// Number of body length bytes in accordance with the protocol. pub const LENGTH_LEN: usize = 4; pub mod events; pub mod frame_decoder; pub mod frame_encoder; pub mod message_auth_challenge; pub mod message_auth_response; pub mod message_auth_success; pub mod message_authenticate; pub mod message_batch; pub mod message_error; pub mod message_event; pub mod message_execute; pub mod message_options; pub mod message_prepare; pub mod message_query; pub mod message_ready; pub mod message_register; pub mod message_request; pub mod message_response; pub mod message_result; pub mod message_startup; pub mod message_supported; pub mod traits; use crate::error; pub const EVENT_STREAM_ID: i16 = -1; /// Returns the larger of two values (a `const fn` stand-in for `std::cmp::max`, which cannot be used in a const context). const fn const_max(a: usize, b: usize) -> usize { if a < b { b } else { a } } /// Maximum size of a frame payload: either several aggregated envelopes or part of a single envelope. pub const PAYLOAD_SIZE_LIMIT: usize = 1 << 17; const UNCOMPRESSED_FRAME_HEADER_LENGTH: usize = 6; const COMPRESSED_FRAME_HEADER_LENGTH: usize = 8; const FRAME_TRAILER_LENGTH: usize = 4; /// Maximum size of an entire frame. pub const MAX_FRAME_SIZE: usize = PAYLOAD_SIZE_LIMIT + const_max( UNCOMPRESSED_FRAME_HEADER_LENGTH, COMPRESSED_FRAME_HEADER_LENGTH, ) + FRAME_TRAILER_LENGTH; /// Cassandra stream identifier. pub type StreamId = i16; #[derive(Debug, Clone, Eq, PartialEq, Hash, Constructor)] pub struct ParsedEnvelope { /// How many bytes from the buffer have been read. pub envelope_len: usize, /// The parsed envelope.
pub envelope: Envelope, } #[derive(Derivative, Clone, PartialEq, Eq, Hash)] #[derivative(Debug)] pub struct Envelope { pub version: Version, pub direction: Direction, pub flags: Flags, pub opcode: Opcode, pub stream_id: StreamId, #[derivative(Debug = "ignore")] pub body: Vec<u8>, pub tracing_id: Option<Uuid>, pub warnings: Vec<String>, } impl Envelope { #[inline] #[allow(clippy::too_many_arguments)] pub fn new( version: Version, direction: Direction, flags: Flags, opcode: Opcode, stream_id: StreamId, body: Vec<u8>, tracing_id: Option<Uuid>, warnings: Vec<String>, ) -> Self { Envelope { version, direction, flags, opcode, stream_id, body, tracing_id, warnings, } } #[inline] pub fn request_body(&self) -> error::Result<RequestBody> { RequestBody::try_from(self.body.as_slice(), self.opcode, self.version) } #[inline] pub fn response_body(&self) -> error::Result<ResponseBody> { ResponseBody::try_from(self.body.as_slice(), self.opcode, self.version) } #[inline] pub fn tracing_id(&self) -> &Option<Uuid> { &self.tracing_id } #[inline] pub fn warnings(&self) -> &[String] { &self.warnings } /// Parses the raw bytes of a Cassandra envelope, returning a [`ParsedEnvelope`] struct. /// The typical use case is reading from a buffer that may contain 0 or more envelopes and where /// the last envelope may be incomplete. The possible return values are: /// * `Ok(ParsedEnvelope)` - The first envelope in the buffer has been successfully parsed. /// * `Err(ParseEnvelopeError::NotEnoughBytes)` - There are not enough bytes to parse a single envelope; [`Envelope::from_buffer`] should be called again once more bytes may be available. /// * `Err(_)` - The envelope is malformed and you should close the connection, as this method does not provide a way to tell how many bytes to advance the buffer in this case.
pub fn from_buffer( data: &[u8], compression: Compression, ) -> Result<ParsedEnvelope, ParseEnvelopeError> { if data.len() < ENVELOPE_HEADER_LEN { return Err(ParseEnvelopeError::NotEnoughBytes); } let body_len = try_i32_from_bytes(&data[5..9]).unwrap() as usize; let envelope_len = ENVELOPE_HEADER_LEN + body_len; if data.len() < envelope_len { return Err(ParseEnvelopeError::NotEnoughBytes); } let version = Version::try_from(data[0]) .map_err(|_| ParseEnvelopeError::UnsupportedVersion(data[0] & 0x7f))?; let direction = Direction::from(data[0]); let flags = Flags::from_bits_truncate(data[1]); let stream_id = try_i16_from_bytes(&data[2..4]).unwrap(); let opcode = Opcode::try_from(data[4]) .map_err(|_| ParseEnvelopeError::UnsupportedOpcode(data[4]))?; let body_bytes = &data[ENVELOPE_HEADER_LEN..envelope_len]; let full_body = if flags.contains(Flags::COMPRESSION) { compression.decode(body_bytes.to_vec()) } else { Compression::None.decode(body_bytes.to_vec()) } .map_err(ParseEnvelopeError::DecompressionError)?; let body_len = full_body.len(); // Use cursor to get tracing id, warnings and actual body let mut body_cursor = Cursor::new(full_body.as_slice()); let tracing_id = if flags.contains(Flags::TRACING) && direction == Direction::Response { let mut tracing_bytes = [0; UUID_LEN]; std::io::Read::read_exact(&mut body_cursor, &mut tracing_bytes).unwrap(); Some(decode_timeuuid(&tracing_bytes).map_err(ParseEnvelopeError::InvalidUuid)?) } else { None }; let warnings = if flags.contains(Flags::WARNING) { from_cursor_string_list(&mut body_cursor) .map_err(ParseEnvelopeError::InvalidWarnings)?
} else { vec![] }; let mut body = Vec::with_capacity(body_len - body_cursor.position() as usize); std::io::Read::read_to_end(&mut body_cursor, &mut body) .expect("Read cannot fail because cursor is backed by slice"); Ok(ParsedEnvelope::new( envelope_len, Envelope { version, direction, flags, opcode, stream_id, body, tracing_id, warnings, }, )) } pub fn check_envelope_size(data: &[u8]) -> Result<usize, CheckEnvelopeSizeError> { if data.len() < ENVELOPE_HEADER_LEN { return Err(CheckEnvelopeSizeError::NotEnoughBytes); } let body_len = try_i32_from_bytes(&data[5..9]).unwrap() as usize; let envelope_len = ENVELOPE_HEADER_LEN + body_len; if data.len() < envelope_len { return Err(CheckEnvelopeSizeError::NotEnoughBytes); } let _ = Version::try_from(data[0]) .map_err(|_| CheckEnvelopeSizeError::UnsupportedVersion(data[0] & 0x7f))?; Ok(envelope_len) } pub fn encode_with(&self, compressor: Compression) -> error::Result<Vec<u8>> { // compression is ignored since v5 let is_compressed = self.version < Version::V5 && compressor.is_compressed(); let combined_version_byte = u8::from(self.version) | u8::from(self.direction); let flag_byte = (if is_compressed { self.flags | Flags::COMPRESSION } else { self.flags.difference(Flags::COMPRESSION) }) .bits(); let opcode_byte = u8::from(self.opcode); let mut v = Vec::with_capacity(ENVELOPE_HEADER_LEN); v.push(combined_version_byte); v.push(flag_byte); v.extend_from_slice(&self.stream_id.to_be_bytes()); v.push(opcode_byte); let mut flags_buffer = vec![]; if self.flags.contains(Flags::TRACING) && self.direction == Direction::Response { let mut tracing_id = self .tracing_id .ok_or_else(|| { error::Error::Io(std::io::Error::other( "Tracing flag was set but Envelope has no tracing_id", )) })?
.into_bytes() .to_vec(); flags_buffer.append(&mut tracing_id); }; if self.flags.contains(Flags::WARNING) && self.direction == Direction::Response { let warnings_len = self.warnings.len() as i16; flags_buffer.extend_from_slice(&warnings_len.to_be_bytes()); for warning in &self.warnings { let warning_len = warning.len() as i16; flags_buffer.extend_from_slice(&warning_len.to_be_bytes()); flags_buffer.extend_from_slice(warning.as_bytes()); } } if is_compressed { // avoid having to copy the body if there is nothing in flags_buffer let encoded_body = if flags_buffer.is_empty() { compressor.encode(&self.body)? } else { flags_buffer.extend_from_slice(&self.body); compressor.encode(&flags_buffer)? }; let body_len = encoded_body.len() as i32; v.extend_from_slice(&body_len.to_be_bytes()); v.extend_from_slice(&encoded_body); } else { // avoid having to copy the body if there is nothing in flags_buffer if flags_buffer.is_empty() { let body_len = self.body.len() as i32; v.extend_from_slice(&body_len.to_be_bytes()); v.extend_from_slice(&self.body); } else { let body_len = self.body.len() as i32 + flags_buffer.len() as i32; v.extend_from_slice(&body_len.to_be_bytes()); flags_buffer.extend_from_slice(&self.body); v.append(&mut flags_buffer); } } Ok(v) } } #[derive(Debug, Error)] #[non_exhaustive] pub enum CheckEnvelopeSizeError { #[error("Not enough bytes!")] NotEnoughBytes, #[error("Unsupported version: {0}")] UnsupportedVersion(u8), #[error("Unsupported opcode: {0}")] UnsupportedOpcode(u8), } #[derive(Debug, Error)] #[non_exhaustive] pub enum ParseEnvelopeError { /// There are not enough bytes to parse a single envelope; [`Envelope::from_buffer`] should be called again once more bytes may be available. #[error("Not enough bytes!")] NotEnoughBytes, /// The version is not supported by cassandra-protocol; a server implementation should handle this by returning a server error with the message "Invalid or unsupported protocol version".
#[error("Unsupported version: {0}")] UnsupportedVersion(u8), #[error("Unsupported opcode: {0}")] UnsupportedOpcode(u8), #[error("Decompression error: {0}")] DecompressionError(CompressionError), #[error("Invalid uuid: {0}")] InvalidUuid(uuid::Error), #[error("Invalid warnings: {0}")] InvalidWarnings(error::Error), } /// Protocol version. #[derive(Debug, PartialEq, Copy, Clone, Ord, PartialOrd, Eq, Hash, Display)] #[non_exhaustive] pub enum Version { V3, V4, V5, } impl From<Version> for u8 { fn from(value: Version) -> Self { match value { Version::V3 => 3, Version::V4 => 4, Version::V5 => 5, } } } impl TryFrom<u8> for Version { type Error = error::Error; fn try_from(version: u8) -> Result<Self, Self::Error> { match version & 0x7F { 3 => Ok(Version::V3), 4 => Ok(Version::V4), 5 => Ok(Version::V5), v => Err(error::Error::General(format!( "Unknown cassandra version: {v}" ))), } } } impl Version { /// Number of bytes that represent Cassandra frame's version. pub const BYTE_LENGTH: usize = 1; } #[derive(Debug, PartialEq, Copy, Clone, Ord, PartialOrd, Eq, Hash, Display)] pub enum Direction { Request, Response, } impl From<Direction> for u8 { fn from(value: Direction) -> u8 { match value { Direction::Request => 0x00, Direction::Response => 0x80, } } } impl From<u8> for Direction { fn from(value: u8) -> Self { match value & 0x80 { 0 => Direction::Request, _ => Direction::Response, } } } bitflags! { /// Envelope flags #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct Flags: u8 { const COMPRESSION = 0x01; const TRACING = 0x02; const CUSTOM_PAYLOAD = 0x04; const WARNING = 0x08; const BETA = 0x10; } } impl Default for Flags { #[inline] fn default() -> Self { Flags::empty() } } impl Flags { /// Number of flag bytes in accordance with the protocol.
pub const BYTE_LENGTH: usize = 1; } #[derive(Debug, PartialEq, Copy, Clone, Ord, PartialOrd, Eq, Hash, Display)] #[non_exhaustive] pub enum Opcode { Error, Startup, Ready, Authenticate, Options, Supported, Query, Result, Prepare, Execute, Register, Event, Batch, AuthChallenge, AuthResponse, AuthSuccess, } impl Opcode { /// Number of opcode bytes in accordance with the protocol. pub const BYTE_LENGTH: usize = 1; } impl From<Opcode> for u8 { fn from(value: Opcode) -> Self { match value { Opcode::Error => 0x00, Opcode::Startup => 0x01, Opcode::Ready => 0x02, Opcode::Authenticate => 0x03, Opcode::Options => 0x05, Opcode::Supported => 0x06, Opcode::Query => 0x07, Opcode::Result => 0x08, Opcode::Prepare => 0x09, Opcode::Execute => 0x0A, Opcode::Register => 0x0B, Opcode::Event => 0x0C, Opcode::Batch => 0x0D, Opcode::AuthChallenge => 0x0E, Opcode::AuthResponse => 0x0F, Opcode::AuthSuccess => 0x10, } } } impl TryFrom<u8> for Opcode { type Error = error::Error; fn try_from(value: u8) -> Result<Self, <Self as TryFrom<u8>>::Error> { match value { 0x00 => Ok(Opcode::Error), 0x01 => Ok(Opcode::Startup), 0x02 => Ok(Opcode::Ready), 0x03 => Ok(Opcode::Authenticate), 0x05 => Ok(Opcode::Options), 0x06 => Ok(Opcode::Supported), 0x07 => Ok(Opcode::Query), 0x08 => Ok(Opcode::Result), 0x09 => Ok(Opcode::Prepare), 0x0A => Ok(Opcode::Execute), 0x0B => Ok(Opcode::Register), 0x0C => Ok(Opcode::Event), 0x0D => Ok(Opcode::Batch), 0x0E => Ok(Opcode::AuthChallenge), 0x0F => Ok(Opcode::AuthResponse), 0x10 => Ok(Opcode::AuthSuccess), _ => Err(error::Error::General(format!("Unknown opcode: {value}"))), } } } #[cfg(test)] mod helpers { use super::*; pub fn test_encode_decode_roundtrip_response( raw_envelope: &[u8], envelope: Envelope, body: ResponseBody, ) { // test encode let encoded_body = body.serialize_to_vec(Version::V4); assert_eq!( &envelope.body, &encoded_body, "encoded body did not match envelope's body" ); let encoded_envelope = envelope.encode_with(Compression::None).unwrap(); assert_eq!( raw_envelope, &encoded_envelope, "encoded
envelope did not match expected raw envelope" ); // test decode let decoded_envelope = Envelope::from_buffer(raw_envelope, Compression::None) .unwrap() .envelope; assert_eq!(decoded_envelope, envelope); let decoded_body = envelope.response_body().unwrap(); assert_eq!( body, decoded_body, "decoded envelope.body did not match body" ) } pub fn test_encode_decode_roundtrip_request( raw_envelope: &[u8], envelope: Envelope, body: RequestBody, ) { // test encode let encoded_body = body.serialize_to_vec(Version::V4); assert_eq!( &envelope.body, &encoded_body, "encoded body did not match envelope's body" ); let encoded_envelope = envelope.encode_with(Compression::None).unwrap(); assert_eq!( raw_envelope, &encoded_envelope, "encoded envelope did not match expected raw envelope" ); // test decode let decoded_envelope = Envelope::from_buffer(raw_envelope, Compression::None) .unwrap() .envelope; assert_eq!(envelope, decoded_envelope); let decoded_body = envelope.request_body().unwrap(); assert_eq!( body, decoded_body, "decoded envelope.body did not match body" ) } /// Use this when the body binary representation is nondeterministic but the body typed representation is deterministic pub fn test_encode_decode_roundtrip_nondeterministic_request( mut envelope: Envelope, body: RequestBody, ) { // test encode envelope.body = body.serialize_to_vec(Version::V4); // test decode let decoded_body = envelope.request_body().unwrap(); assert_eq!( body, decoded_body, "decoded envelope.body did not match body" ) } } //noinspection DuplicatedCode #[cfg(test)] mod tests { use super::*; use crate::consistency::Consistency; use crate::frame::frame_decoder::{ FrameDecoder, LegacyFrameDecoder, Lz4FrameDecoder, UncompressedFrameDecoder, }; use crate::frame::frame_encoder::{ FrameEncoder, LegacyFrameEncoder, Lz4FrameEncoder, UncompressedFrameEncoder, }; use crate::frame::message_query::BodyReqQuery; use crate::query::query_params::QueryParams; use crate::query::query_values::QueryValues; use 
crate::types::value::Value; use crate::types::CBytes; #[test] fn test_frame_version_as_byte() { assert_eq!(u8::from(Version::V3), 0x03); assert_eq!(u8::from(Version::V4), 0x04); assert_eq!(u8::from(Version::V5), 0x05); assert_eq!(u8::from(Direction::Request), 0x00); assert_eq!(u8::from(Direction::Response), 0x80); } #[test] fn test_frame_version_from() { assert_eq!(Version::try_from(0x03).unwrap(), Version::V3); assert_eq!(Version::try_from(0x83).unwrap(), Version::V3); assert_eq!(Version::try_from(0x04).unwrap(), Version::V4); assert_eq!(Version::try_from(0x84).unwrap(), Version::V4); assert_eq!(Version::try_from(0x05).unwrap(), Version::V5); assert_eq!(Version::try_from(0x85).unwrap(), Version::V5); assert_eq!(Direction::from(0x03), Direction::Request); assert_eq!(Direction::from(0x04), Direction::Request); assert_eq!(Direction::from(0x05), Direction::Request); assert_eq!(Direction::from(0x83), Direction::Response); assert_eq!(Direction::from(0x84), Direction::Response); assert_eq!(Direction::from(0x85), Direction::Response); } #[test] fn test_opcode_as_byte() { assert_eq!(u8::from(Opcode::Error), 0x00); assert_eq!(u8::from(Opcode::Startup), 0x01); assert_eq!(u8::from(Opcode::Ready), 0x02); assert_eq!(u8::from(Opcode::Authenticate), 0x03); assert_eq!(u8::from(Opcode::Options), 0x05); assert_eq!(u8::from(Opcode::Supported), 0x06); assert_eq!(u8::from(Opcode::Query), 0x07); assert_eq!(u8::from(Opcode::Result), 0x08); assert_eq!(u8::from(Opcode::Prepare), 0x09); assert_eq!(u8::from(Opcode::Execute), 0x0A); assert_eq!(u8::from(Opcode::Register), 0x0B); assert_eq!(u8::from(Opcode::Event), 0x0C); assert_eq!(u8::from(Opcode::Batch), 0x0D); assert_eq!(u8::from(Opcode::AuthChallenge), 0x0E); assert_eq!(u8::from(Opcode::AuthResponse), 0x0F); assert_eq!(u8::from(Opcode::AuthSuccess), 0x10); } #[test] fn test_opcode_from() { assert_eq!(Opcode::try_from(0x00).unwrap(), Opcode::Error); assert_eq!(Opcode::try_from(0x01).unwrap(), Opcode::Startup); 
assert_eq!(Opcode::try_from(0x02).unwrap(), Opcode::Ready); assert_eq!(Opcode::try_from(0x03).unwrap(), Opcode::Authenticate); assert_eq!(Opcode::try_from(0x05).unwrap(), Opcode::Options); assert_eq!(Opcode::try_from(0x06).unwrap(), Opcode::Supported); assert_eq!(Opcode::try_from(0x07).unwrap(), Opcode::Query); assert_eq!(Opcode::try_from(0x08).unwrap(), Opcode::Result); assert_eq!(Opcode::try_from(0x09).unwrap(), Opcode::Prepare); assert_eq!(Opcode::try_from(0x0A).unwrap(), Opcode::Execute); assert_eq!(Opcode::try_from(0x0B).unwrap(), Opcode::Register); assert_eq!(Opcode::try_from(0x0C).unwrap(), Opcode::Event); assert_eq!(Opcode::try_from(0x0D).unwrap(), Opcode::Batch); assert_eq!(Opcode::try_from(0x0E).unwrap(), Opcode::AuthChallenge); assert_eq!(Opcode::try_from(0x0F).unwrap(), Opcode::AuthResponse); assert_eq!(Opcode::try_from(0x10).unwrap(), Opcode::AuthSuccess); } #[test] fn test_ready() { let raw_envelope = vec![4, 0, 0, 0, 2, 0, 0, 0, 0]; let envelope = Envelope { version: Version::V4, direction: Direction::Request, flags: Flags::empty(), opcode: Opcode::Ready, stream_id: 0, body: vec![], tracing_id: None, warnings: vec![], }; let body = ResponseBody::Ready; helpers::test_encode_decode_roundtrip_response(&raw_envelope, envelope, body); } #[test] fn test_query_minimal() { let raw_envelope = [ 4, 0, 0, 0, 7, 0, 0, 0, 11, 0, 0, 0, 4, 98, 108, 97, 104, 0, 0, 64, ]; let envelope = Envelope { version: Version::V4, direction: Direction::Request, flags: Flags::empty(), opcode: Opcode::Query, stream_id: 0, body: vec![0, 0, 0, 4, 98, 108, 97, 104, 0, 0, 64], tracing_id: None, warnings: vec![], }; let body = RequestBody::Query(BodyReqQuery { query: "blah".into(), query_params: QueryParams { consistency: Consistency::Any, with_names: true, values: None, page_size: None, paging_state: None, serial_consistency: None, timestamp: None, keyspace: None, now_in_seconds: None, }, }); helpers::test_encode_decode_roundtrip_request(&raw_envelope, envelope, body); } #[test] fn 
test_query_simple_values() { let raw_envelope = [ 4, 0, 0, 0, 7, 0, 0, 0, 30, 0, 0, 0, 10, 115, 111, 109, 101, 32, 113, 117, 101, 114, 121, 0, 8, 1, 0, 2, 0, 0, 0, 3, 1, 2, 3, 255, 255, 255, 255, ]; let envelope = Envelope { version: Version::V4, direction: Direction::Request, flags: Flags::empty(), opcode: Opcode::Query, stream_id: 0, body: vec![ 0, 0, 0, 10, 115, 111, 109, 101, 32, 113, 117, 101, 114, 121, 0, 8, 1, 0, 2, 0, 0, 0, 3, 1, 2, 3, 255, 255, 255, 255, ], tracing_id: None, warnings: vec![], }; let body = RequestBody::Query(BodyReqQuery { query: "some query".into(), query_params: QueryParams { consistency: Consistency::Serial, with_names: false, values: Some(QueryValues::SimpleValues(vec![ Value::Some(vec![1, 2, 3]), Value::Null, ])), page_size: None, paging_state: None, serial_consistency: None, timestamp: None, keyspace: None, now_in_seconds: None, }, }); helpers::test_encode_decode_roundtrip_request(&raw_envelope, envelope, body); } #[test] fn test_query_named_values() { let envelope = Envelope { version: Version::V4, direction: Direction::Request, flags: Flags::empty(), opcode: Opcode::Query, stream_id: 0, body: vec![], tracing_id: None, warnings: vec![], }; let body = RequestBody::Query(BodyReqQuery { query: "another query".into(), query_params: QueryParams { consistency: Consistency::Three, with_names: true, values: Some(QueryValues::NamedValues( vec![ ("foo".to_string(), Value::Some(vec![11, 12, 13])), ("bar".to_string(), Value::NotSet), ("baz".to_string(), Value::Some(vec![42, 10, 99, 100, 4])), ] .into_iter() .collect(), )), page_size: Some(4), paging_state: Some(CBytes::new(vec![0, 1, 2, 3])), serial_consistency: Some(Consistency::One), timestamp: Some(2000), keyspace: None, now_in_seconds: None, }, }); helpers::test_encode_decode_roundtrip_nondeterministic_request(envelope, body); } #[test] fn test_result_prepared_statement() { use crate::frame::message_result::{ BodyResResultPrepared, ColSpec, ColType, ColTypeOption, PreparedMetadata, 
ResResultBody, RowsMetadata, RowsMetadataFlags, TableSpec, }; use crate::types::CBytesShort; let raw_envelope = [ 132, 0, 0, 0, 8, 0, 0, 0, 97, // cassandra header 0, 0, 0, 4, // prepared statement result 0, 16, 195, 165, 42, 38, 120, 170, 232, 144, 214, 187, 158, 200, 160, 226, 27, 73, // id 0, 0, 0, 1, // prepared metadata flags 0, 0, 0, 3, // columns count 0, 0, 0, 1, // pk count 0, 0, // pk index 1 0, 23, 116, 101, 115, 116, 95, 112, 114, 101, 112, 97, 114, 101, 95, 115, 116, 97, 116, 101, 109, 101, 110, 116, 115, // global_table_spec.ks_name = test_prepare_statements 0, 7, 116, 97, 98, 108, 101, 95, 49, // global_table_spec.table_name = table_1 0, 2, 105, 100, // ColSpec.name = "id" 0, 9, // ColSpec.col_type = Int 0, 1, 120, // ColSpec.name = "x" 0, 9, // ColSpec.col_type = Int 0, 4, 110, 97, 109, 101, // ColSpec.name = "name" 0, 13, // ColSpec.col_type = VarChar 0, 0, 0, 4, // row metadata flags 0, 0, 0, 0, // columns count ]; let envelope = Envelope { version: Version::V4, direction: Direction::Response, flags: Flags::empty(), opcode: Opcode::Result, stream_id: 0, body: vec![ 0, 0, 0, 4, // prepared statement result 0, 16, 195, 165, 42, 38, 120, 170, 232, 144, 214, 187, 158, 200, 160, 226, 27, 73, // id 0, 0, 0, 1, // prepared metadata flags 0, 0, 0, 3, // columns count 0, 0, 0, 1, // pk count 0, 0, // pk index 1 0, 23, 116, 101, 115, 116, 95, 112, 114, 101, 112, 97, 114, 101, 95, 115, 116, 97, 116, 101, 109, 101, 110, 116, 115, // global_table_spec.ks_name = test_prepare_statements 0, 7, 116, 97, 98, 108, 101, 95, 49, // global_table_spec.table_name = table_1 0, 2, 105, 100, // ColSpec.name = "id" 0, 9, // ColSpec.col_type = Int 0, 1, 120, // ColSpec.name = "x" 0, 9, // ColSpec.col_type = Int 0, 4, 110, 97, 109, 101, // ColSpec.name = "name" 0, 13, // ColSpec.col_type = VarChar 0, 0, 0, 4, // row metadata flags 0, 0, 0, 0, // columns count ], tracing_id: None, warnings: vec![], }; let body = ResponseBody::Result(ResResultBody::Prepared(BodyResResultPrepared 
{ id: CBytesShort::new(vec![ 195, 165, 42, 38, 120, 170, 232, 144, 214, 187, 158, 200, 160, 226, 27, 73, ]), result_metadata_id: None, metadata: PreparedMetadata { pk_indexes: vec![0], global_table_spec: Some(TableSpec { ks_name: "test_prepare_statements".into(), table_name: "table_1".into(), }), col_specs: vec![ ColSpec { table_spec: None, name: "id".into(), col_type: ColTypeOption { id: ColType::Int, value: None, }, }, ColSpec { table_spec: None, name: "x".into(), col_type: ColTypeOption { id: ColType::Int, value: None, }, }, ColSpec { table_spec: None, name: "name".into(), col_type: ColTypeOption { id: ColType::Varchar, value: None, }, }, ], }, result_metadata: RowsMetadata { flags: RowsMetadataFlags::NO_METADATA, columns_count: 0, paging_state: None, new_metadata_id: None, global_table_spec: None, col_specs: vec![], }, })); helpers::test_encode_decode_roundtrip_response(&raw_envelope, envelope, body); } fn create_small_envelope_data() -> (Envelope, Vec<u8>) { let raw_envelope = vec![ 4, 0, 0, 0, 7, 0, 0, 0, 30, 0, 0, 0, 10, 115, 111, 109, 101, 32, 113, 117, 101, 114, 121, 0, 8, 1, 0, 2, 0, 0, 0, 3, 1, 2, 3, 255, 255, 255, 255, ]; let envelope = Envelope { version: Version::V4, direction: Direction::Request, flags: Flags::empty(), opcode: Opcode::Query, stream_id: 0, body: vec![ 0, 0, 0, 10, 115, 111, 109, 101, 32, 113, 117, 101, 114, 121, 0, 8, 1, 0, 2, 0, 0, 0, 3, 1, 2, 3, 255, 255, 255, 255, ], tracing_id: None, warnings: vec![], }; (envelope, raw_envelope) } fn create_large_envelope_data() -> (Envelope, Vec<u8>) { let body: Vec<u8> = (0..262144).map(|value| (value % 256) as u8).collect(); let mut raw_envelope = vec![4, 0, 0, 0, 7, 0, 4, 0, 0]; raw_envelope.append(&mut body.clone()); let envelope = Envelope { version: Version::V4, direction: Direction::Request, flags: Flags::empty(), opcode: Opcode::Query, stream_id: 0, body, tracing_id: None, warnings: vec![], }; (envelope, raw_envelope) } #[test] fn should_encode_and_decode_legacy_frames() { let (envelope,
raw_envelope) = create_small_envelope_data(); let mut encoder = LegacyFrameEncoder::default(); assert!(encoder.can_fit(raw_envelope.len())); encoder.add_envelope(raw_envelope.clone()); assert!(!encoder.can_fit(1)); let mut frame = encoder.finalize_self_contained().to_vec(); assert_eq!(frame, raw_envelope); let mut decoder = LegacyFrameDecoder::default(); let envelopes = decoder.consume(&mut frame, Compression::None).unwrap(); assert_eq!(envelopes.len(), 1); assert_eq!(envelopes[0], envelope); encoder.reset(); assert!(encoder.can_fit(raw_envelope.len())); } #[test] fn should_encode_and_decode_uncompressed_self_contained_frames() { let (envelope, raw_envelope) = create_small_envelope_data(); let mut encoder = UncompressedFrameEncoder::default(); assert!(encoder.can_fit(raw_envelope.len())); encoder.add_envelope(raw_envelope.clone()); assert!(encoder.can_fit(raw_envelope.len())); encoder.add_envelope(raw_envelope); let mut buffer1 = encoder.finalize_self_contained().to_vec(); let mut buffer2 = buffer1.split_off(5); let mut decoder = UncompressedFrameDecoder::default(); let envelopes = decoder.consume(&mut buffer1, Compression::None).unwrap(); assert!(buffer1.is_empty()); assert!(envelopes.is_empty()); let envelopes = decoder.consume(&mut buffer2, Compression::None).unwrap(); assert!(buffer2.is_empty()); assert_eq!(envelopes.len(), 2); assert_eq!(envelopes[0], envelope); assert_eq!(envelopes[1], envelope); } #[test] fn should_encode_and_decode_uncompressed_non_self_contained_frames() { let (envelope, raw_envelope) = create_large_envelope_data(); let mut encoder = UncompressedFrameEncoder::default(); assert!(!encoder.can_fit(raw_envelope.len())); let data_len = raw_envelope.len(); let mut data_start = 0; let mut buffer1 = vec![]; while data_start < data_len { let (data_start_offset, frame) = encoder.finalize_non_self_contained(&raw_envelope[data_start..]); data_start += data_start_offset; buffer1.extend_from_slice(frame); encoder.reset(); } let mut buffer2 = 
buffer1.split_off(PAYLOAD_SIZE_LIMIT); let mut decoder = UncompressedFrameDecoder::default(); let envelopes = decoder.consume(&mut buffer1, Compression::None).unwrap(); assert!(buffer1.is_empty()); assert!(envelopes.is_empty()); let envelopes = decoder.consume(&mut buffer2, Compression::None).unwrap(); assert!(buffer2.is_empty()); assert_eq!(envelopes.len(), 1); assert_eq!(envelopes[0], envelope); } #[test] fn should_encode_and_decode_compressed_self_contained_frames() { let (envelope, raw_envelope) = create_small_envelope_data(); let mut encoder = Lz4FrameEncoder::default(); assert!(encoder.can_fit(raw_envelope.len())); encoder.add_envelope(raw_envelope.clone()); assert!(encoder.can_fit(raw_envelope.len())); encoder.add_envelope(raw_envelope); let mut buffer1 = encoder.finalize_self_contained().to_vec(); let mut buffer2 = buffer1.split_off(5); let mut decoder = Lz4FrameDecoder::default(); let envelopes = decoder.consume(&mut buffer1, Compression::None).unwrap(); assert!(buffer1.is_empty()); assert!(envelopes.is_empty()); let envelopes = decoder.consume(&mut buffer2, Compression::None).unwrap(); assert!(buffer2.is_empty()); assert_eq!(envelopes.len(), 2); assert_eq!(envelopes[0], envelope); assert_eq!(envelopes[1], envelope); } #[test] fn should_encode_and_decode_compressed_non_self_contained_frames() { let (envelope, raw_envelope) = create_large_envelope_data(); let mut encoder = Lz4FrameEncoder::default(); assert!(!encoder.can_fit(raw_envelope.len())); let data_len = raw_envelope.len(); let mut data_start = 0; let mut buffer1 = vec![]; while data_start < data_len { let (data_start_offset, frame) = encoder.finalize_non_self_contained(&raw_envelope[data_start..]); data_start += data_start_offset; buffer1.extend_from_slice(frame); encoder.reset(); } let mut buffer2 = buffer1.split_off(1000); let mut decoder = Lz4FrameDecoder::default(); let envelopes = decoder.consume(&mut buffer1, Compression::None).unwrap(); assert!(buffer1.is_empty()); 
assert!(envelopes.is_empty()); let envelopes = decoder.consume(&mut buffer2, Compression::None).unwrap(); assert!(buffer2.is_empty()); assert_eq!(envelopes.len(), 1); assert_eq!(envelopes[0], envelope); } } #[cfg(test)] mod flags { use super::*; use crate::consistency::Consistency; use crate::frame::message_query::BodyReqQuery; use crate::frame::message_result::ResResultBody; use crate::query::query_params::QueryParams; #[test] fn test_tracing_id_request() { let raw_envelope = [ 4, // version 2, // flags 0, 12, // stream id 7, // opcode 0, 0, 0, 11, //length 0, 0, 0, 4, 98, 108, 97, 104, 0, 0, 64, // body ]; let envelope = Envelope { version: Version::V4, direction: Direction::Request, flags: Flags::TRACING, opcode: Opcode::Query, stream_id: 12, body: vec![0, 0, 0, 4, 98, 108, 97, 104, 0, 0, 64], tracing_id: None, warnings: vec![], }; let body = RequestBody::Query(BodyReqQuery { query: "blah".into(), query_params: QueryParams { consistency: Consistency::Any, with_names: true, values: None, page_size: None, paging_state: None, serial_consistency: None, timestamp: None, keyspace: None, now_in_seconds: None, }, }); helpers::test_encode_decode_roundtrip_request(&raw_envelope, envelope, body); } #[test] fn test_tracing_id_response() { let raw_envelope = [ 132, //version 2, // flags 0, 12, // stream id 8, //opcode 0, 0, 0, 20, // length 4, 54, 67, 12, 43, 2, 98, 76, 32, 50, 87, 5, 1, 33, 43, 87, // tracing_id 0, 0, 0, 1, // body ]; let envelope = Envelope { version: Version::V4, direction: Direction::Response, flags: Flags::TRACING, opcode: Opcode::Result, stream_id: 12, body: vec![0, 0, 0, 1], tracing_id: Some(uuid::Uuid::from_bytes([ 4, 54, 67, 12, 43, 2, 98, 76, 32, 50, 87, 5, 1, 33, 43, 87, ])), warnings: vec![], }; let body = ResponseBody::Result(ResResultBody::Void); helpers::test_encode_decode_roundtrip_response(&raw_envelope, envelope, body); } #[test] fn test_warnings_response() { let raw_envelope = [ 132, // version 8, // flags 5, 64, // stream id 8, // opcode 
0, 0, 0, 19, // length // warnings 0, 1, 0, 11, 72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100, // warnings 0, 0, 0, 1, // body ]; let body = ResponseBody::Result(ResResultBody::Void); let envelope = Envelope { version: Version::V4, opcode: Opcode::Result, flags: Flags::WARNING, direction: Direction::Response, stream_id: 1344, tracing_id: None, body: vec![0, 0, 0, 1], warnings: vec!["Hello World".into()], }; helpers::test_encode_decode_roundtrip_response(&raw_envelope, envelope, body); } } ================================================ FILE: cassandra-protocol/src/lib.rs ================================================ //! A generic Cassandra protocol crate. //! Built in coordination with cdrs-tokio, but flexible enough for many other use cases. extern crate core; #[macro_use] mod macros; pub mod frame; pub mod query; pub mod types; pub mod authenticators; pub mod compression; pub mod consistency; pub mod crc; pub mod error; pub mod events; pub mod token; pub type Error = error::Error; pub type Result<T> = error::Result<T>; ================================================ FILE: cassandra-protocol/src/macros.rs ================================================ #[macro_export] /// Transforms arguments to values consumed by queries. macro_rules! query_values { ($($value:expr),*) => { { use cassandra_protocol::types::value::Value; use cassandra_protocol::query::QueryValues; let mut values: Vec<Value> = Vec::new(); $( values.push($value.into()); )* QueryValues::SimpleValues(values) } }; ($($name:expr => $value:expr),*) => { { use cassandra_protocol::types::value::Value; use cassandra_protocol::query::QueryValues; use std::collections::HashMap; let mut values: HashMap<String, Value> = HashMap::new(); $( values.insert($name.to_string(), $value.into()); )* QueryValues::NamedValues(values) } }; } macro_rules!
vector_as_rust { (f32) => { impl AsRustType<Vec<f32>> for Vector { fn as_rust_type(&self) -> Result<Option<Vec<f32>>> { let mut result: Vec<f32> = Vec::new(); for data_value in &self.data { let float = decode_float(data_value.as_slice().ok_or_else(|| { Error::General(format!("Failed to convert {:?} into float", data_value)) })?)?; result.push(float); } Ok(Some(result)) } } }; } macro_rules! list_as_rust { (List) => ( impl AsRustType<Vec<List>> for List { fn as_rust_type(&self) -> Result<Option<Vec<List>>> { match self.metadata.value { Some(ColTypeOptionValue::CList(ref type_option)) | Some(ColTypeOptionValue::CSet(ref type_option)) => { let type_option_ref = type_option.as_ref(); let protocol_version = self.protocol_version; let convert = self .map(|bytes| { as_rust_type!(type_option_ref, bytes, protocol_version, List) .unwrap() // an item in a list is supposed to be a non-null value. // TODO: check if it's true .unwrap() }); Ok(Some(convert)) }, _ => Err(Error::General(format!("Invalid conversion. \ Cannot convert {:?} into List (valid types: List, Set).", self.metadata.value))) } } } ); (Map) => ( impl AsRustType<Vec<Map>> for List { fn as_rust_type(&self) -> Result<Option<Vec<Map>>> { match self.metadata.value { Some(ColTypeOptionValue::CList(ref type_option)) | Some(ColTypeOptionValue::CSet(ref type_option)) => { let type_option_ref = type_option.as_ref(); let protocol_version = self.protocol_version; let convert = self .map(|bytes| { as_rust_type!(type_option_ref, bytes, protocol_version, Map) .unwrap() // an item in a list is supposed to be a non-null value. // TODO: check if it's true .unwrap() }); Ok(Some(convert)) }, _ => Err(Error::General(format!("Invalid conversion.
\ Cannot convert {:?} into List (valid types: List, Set).", self.metadata.value))) } } } ); (Udt) => ( impl AsRustType<Vec<Udt>> for List { fn as_rust_type(&self) -> Result<Option<Vec<Udt>>> { match self.metadata.value { Some(ColTypeOptionValue::CList(ref type_option)) | Some(ColTypeOptionValue::CSet(ref type_option)) => { let type_option_ref = type_option.as_ref(); let protocol_version = self.protocol_version; let convert = self .map(|bytes| { as_rust_type!(type_option_ref, bytes, protocol_version, Udt) .unwrap() // an item in a list is supposed to be a non-null value. // TODO: check if it's true .unwrap() }); Ok(Some(convert)) }, _ => Err(Error::General(format!("Invalid conversion. \ Cannot convert {:?} into List (valid types: List, Set).", self.metadata.value))) } } } ); (Tuple) => ( impl AsRustType<Vec<Tuple>> for List { fn as_rust_type(&self) -> Result<Option<Vec<Tuple>>> { match self.metadata.value { Some(ColTypeOptionValue::CList(ref type_option)) | Some(ColTypeOptionValue::CSet(ref type_option)) => { let type_option_ref = type_option.as_ref(); let protocol_version = self.protocol_version; let convert = self .map(|bytes| { as_rust_type!(type_option_ref, bytes, protocol_version, Tuple) .unwrap() // an item in a list is supposed to be a non-null value. // TODO: check if it's true .unwrap() }); Ok(Some(convert)) }, _ => Err(Error::General(format!("Invalid conversion. \ Cannot convert {:?} into List (valid types: List, Set).", self.metadata.value))) } } } ); ($($into_type:tt)+) => ( impl AsRustType<Vec<$($into_type)+>> for List { fn as_rust_type(&self) -> Result<Option<Vec<$($into_type)+>>> { match self.metadata.value { Some(ColTypeOptionValue::CList(ref type_option)) | Some(ColTypeOptionValue::CSet(ref type_option)) => { let type_option_ref = type_option.as_ref(); let convert = self .map(|bytes| { as_rust_type!(type_option_ref, bytes, $($into_type)+) .unwrap() // an item in a list is supposed to be a non-null value. // TODO: check if it's true .unwrap() }); Ok(Some(convert)) }, _ => Err(Error::General(format!("Invalid conversion.
\ Cannot convert {:?} into List (valid types: List, Set).", self.metadata.value))) } } } ); } macro_rules! list_as_cassandra_type { () => { impl crate::types::AsCassandraType for List { fn as_cassandra_type( &self, ) -> Result<Option<crate::types::cassandra_type::CassandraType>> { use crate::error::Error; use crate::types::cassandra_type::wrapper_fn; use crate::types::cassandra_type::CassandraType; use std::ops::Deref; let protocol_version = self.protocol_version; match self.metadata.value { Some(ColTypeOptionValue::CList(ref type_option)) | Some(ColTypeOptionValue::CSet(ref type_option)) => { let type_option_ref = type_option.deref().clone(); let wrapper = wrapper_fn(&type_option_ref.id); let convert = self .try_map(|bytes| wrapper(bytes, &type_option_ref, protocol_version)); convert.map(|convert| Some(CassandraType::List(convert))) } _ => Err(Error::General(format!( "Invalid conversion. \ Cannot convert {:?} into List (valid types: List, Set).", self.metadata.value ))), } } } }; } macro_rules! vector_as_cassandra_type { () => { impl crate::types::AsCassandraType for Vector { fn as_cassandra_type( &self, ) -> Result<Option<crate::types::cassandra_type::CassandraType>> { use crate::error::Error; use crate::types::cassandra_type::wrapper_fn; use crate::types::cassandra_type::CassandraType; let protocol_version = self.protocol_version; match &self.metadata { ColTypeOption { id: ColType::Custom, value, } => { if let Some(value) = value { let VectorInfo { internal_type, .. } = get_vector_type_info(value)?; if internal_type == "FloatType" { let internal_type_option = ColTypeOption { id: ColType::Float, value: None, }; let wrapper = wrapper_fn(&ColType::Float); let convert = self.try_map(|bytes| { wrapper(bytes, &internal_type_option, protocol_version) }); return convert.map(|convert| Some(CassandraType::Vector(convert))); } else { return Err(Error::General(format!( "Invalid conversion.
\ Cannot convert Vector<{:?}> into Vector (valid types: Vector", internal_type ))); } } else { return Err(Error::General("Custom type string is none".to_string())); } } _ => Err(Error::General(format!( "Invalid conversion. \ Cannot convert {:?} into Vector (valid types: Custom).", self.metadata.value ))), } } } }; } macro_rules! map_as_cassandra_type { () => { impl crate::types::AsCassandraType for Map { fn as_cassandra_type( &self, ) -> Result<Option<crate::types::cassandra_type::CassandraType>> { use crate::types::cassandra_type::wrapper_fn; use crate::types::cassandra_type::CassandraType; use itertools::Itertools; use std::ops::Deref; if let Some(ColTypeOptionValue::CMap( ref key_col_type_option, ref value_col_type_option, )) = self.metadata.value { let key_col_type_option = key_col_type_option.deref().clone(); let value_col_type_option = value_col_type_option.deref().clone(); let key_wrapper = wrapper_fn(&key_col_type_option.id); let value_wrapper = wrapper_fn(&value_col_type_option.id); let protocol_version = self.protocol_version; return self .data .iter() .map(|(key, value)| { key_wrapper(key, &key_col_type_option, protocol_version).and_then( |key| { value_wrapper(value, &value_col_type_option, protocol_version) .map(|value| (key, value)) }, ) }) .try_collect() .map(|map| Some(CassandraType::Map(map))); } else { panic!("not a map") } } } }; } macro_rules! tuple_as_cassandra_type { () => { impl crate::types::AsCassandraType for Tuple { fn as_cassandra_type( &self, ) -> Result<Option<crate::types::cassandra_type::CassandraType>> { use crate::types::cassandra_type::wrapper_fn; use crate::types::cassandra_type::CassandraType; use itertools::Itertools; let protocol_version = self.protocol_version; let values = self .data .iter() .map(|(col_type, bytes)| { let wrapper = wrapper_fn(&col_type.id); wrapper(&bytes, col_type, protocol_version) }) .try_collect()?; Ok(Some(CassandraType::Tuple(values))) } } }; } macro_rules!
udt_as_cassandra_type { () => { impl crate::types::AsCassandraType for Udt { fn as_cassandra_type( &self, ) -> Result<Option<crate::types::cassandra_type::CassandraType>> { use crate::types::cassandra_type::wrapper_fn; use crate::types::cassandra_type::CassandraType; use std::collections::HashMap; let mut map = HashMap::with_capacity(self.data.len()); let protocol_version = self.protocol_version; for (key, (col_type, bytes)) in &self.data { let wrapper = wrapper_fn(&col_type.id); let value = wrapper(&bytes, col_type, protocol_version)?; map.insert(key.clone(), value); } Ok(Some(CassandraType::Udt(map))) } } }; } macro_rules! map_as_rust { ({ Tuple }, { List }) => ( impl AsRustType<HashMap<Tuple, List>> for Map { /// Converts `Map` into `HashMap` for blob values. fn as_rust_type(&self) -> Result<Option<HashMap<Tuple, List>>> { if let Some(ColTypeOptionValue::CMap(key_type_option, val_type_option)) = &self.metadata.value { let mut map = HashMap::with_capacity(self.data.len()); let key_type_option = key_type_option.as_ref(); let val_type_option = val_type_option.as_ref(); let protocol_version = self.protocol_version; for (key, val) in self.data.iter() { let key = as_rust_type!(key_type_option, key, protocol_version, Tuple)?; let val = as_rust_type!(val_type_option, val, protocol_version, List)?; if let (Some(key), Some(val)) = (key, val) { map.insert(key, val); } } Ok(Some(map)) } else { Err(format!("Invalid column type for map: {:?}", self.metadata.value).into()) } } } ); ({ Tuple }, { Map }) => ( impl AsRustType<HashMap<Tuple, Map>> for Map { /// Converts `Map` into `HashMap` for blob values.
fn as_rust_type(&self) -> Result<Option<HashMap<Tuple, Map>>> { if let Some(ColTypeOptionValue::CMap(key_type_option, val_type_option)) = &self.metadata.value { let mut map = HashMap::with_capacity(self.data.len()); let key_type_option = key_type_option.as_ref(); let val_type_option = val_type_option.as_ref(); let protocol_version = self.protocol_version; for (key, val) in self.data.iter() { let key = as_rust_type!(key_type_option, key, protocol_version, Tuple)?; let val = as_rust_type!(val_type_option, val, protocol_version, Map)?; if let (Some(key), Some(val)) = (key, val) { map.insert(key, val); } } Ok(Some(map)) } else { Err(format!("Invalid column type for map: {:?}", self.metadata.value).into()) } } } ); ({ Tuple }, { Udt }) => ( impl AsRustType<HashMap<Tuple, Udt>> for Map { /// Converts `Map` into `HashMap` for blob values. fn as_rust_type(&self) -> Result<Option<HashMap<Tuple, Udt>>> { if let Some(ColTypeOptionValue::CMap(key_type_option, val_type_option)) = &self.metadata.value { let mut map = HashMap::with_capacity(self.data.len()); let key_type_option = key_type_option.as_ref(); let val_type_option = val_type_option.as_ref(); let protocol_version = self.protocol_version; for (key, val) in self.data.iter() { let key = as_rust_type!(key_type_option, key, protocol_version, Tuple)?; let val = as_rust_type!(val_type_option, val, protocol_version, Udt)?; if let (Some(key), Some(val)) = (key, val) { map.insert(key, val); } } Ok(Some(map)) } else { Err(format!("Invalid column type for map: {:?}", self.metadata.value).into()) } } } ); ({ Tuple }, { Tuple }) => ( impl AsRustType<HashMap<Tuple, Tuple>> for Map { /// Converts `Map` into `HashMap` for blob values.
fn as_rust_type(&self) -> Result<Option<HashMap<Tuple, Tuple>>> { if let Some(ColTypeOptionValue::CMap(key_type_option, val_type_option)) = &self.metadata.value { let mut map = HashMap::with_capacity(self.data.len()); let key_type_option = key_type_option.as_ref(); let val_type_option = val_type_option.as_ref(); let protocol_version = self.protocol_version; for (key, val) in self.data.iter() { let key = as_rust_type!(key_type_option, key, protocol_version, Tuple)?; let val = as_rust_type!(val_type_option, val, protocol_version, Tuple)?; if let (Some(key), Some(val)) = (key, val) { map.insert(key, val); } } Ok(Some(map)) } else { Err(format!("Invalid column type for map: {:?}", self.metadata.value).into()) } } } ); ({ Tuple }, { $($val_type:tt)+ }) => ( impl AsRustType<HashMap<Tuple, $($val_type)+>> for Map { /// Converts `Map` into `HashMap` for blob values. fn as_rust_type(&self) -> Result<Option<HashMap<Tuple, $($val_type)+>>> { if let Some(ColTypeOptionValue::CMap(key_type_option, val_type_option)) = &self.metadata.value { let mut map = HashMap::with_capacity(self.data.len()); let key_type_option = key_type_option.as_ref(); let val_type_option = val_type_option.as_ref(); let protocol_version = self.protocol_version; for (key, val) in self.data.iter() { let key = as_rust_type!(key_type_option, key, protocol_version, Tuple)?; let val = as_rust_type!(val_type_option, val, $($val_type)+)?; if let (Some(key), Some(val)) = (key, val) { map.insert(key, val); } } Ok(Some(map)) } else { Err(format!("Invalid column type for map: {:?}", self.metadata.value).into()) } } } ); ({ $($key_type:tt)+ }, { List }) => ( impl AsRustType<HashMap<$($key_type)+, List>> for Map { /// Converts `Map` into `HashMap` for blob values.
fn as_rust_type(&self) -> Result<Option<HashMap<$($key_type)+, List>>> { if let Some(ColTypeOptionValue::CMap(key_type_option, val_type_option)) = &self.metadata.value { let mut map = HashMap::with_capacity(self.data.len()); let key_type_option = key_type_option.as_ref(); let val_type_option = val_type_option.as_ref(); let protocol_version = self.protocol_version; for (key, val) in self.data.iter() { let key = as_rust_type!(key_type_option, key, $($key_type)+)?; let val = as_rust_type!(val_type_option, val, protocol_version, List)?; if let (Some(key), Some(val)) = (key, val) { map.insert(key, val); } } Ok(Some(map)) } else { Err(format!("Invalid column type for map: {:?}", self.metadata.value).into()) } } } ); ({ $($key_type:tt)+ }, { Map }) => ( impl AsRustType<HashMap<$($key_type)+, Map>> for Map { /// Converts `Map` into `HashMap` for blob values. fn as_rust_type(&self) -> Result<Option<HashMap<$($key_type)+, Map>>> { if let Some(ColTypeOptionValue::CMap(key_type_option, val_type_option)) = &self.metadata.value { let mut map = HashMap::with_capacity(self.data.len()); let key_type_option = key_type_option.as_ref(); let val_type_option = val_type_option.as_ref(); let protocol_version = self.protocol_version; for (key, val) in self.data.iter() { let key = as_rust_type!(key_type_option, key, $($key_type)+)?; let val = as_rust_type!(val_type_option, val, protocol_version, Map)?; if let (Some(key), Some(val)) = (key, val) { map.insert(key, val); } } Ok(Some(map)) } else { Err(format!("Invalid column type for map: {:?}", self.metadata.value).into()) } } } ); ({ $($key_type:tt)+ }, { Udt }) => ( impl AsRustType<HashMap<$($key_type)+, Udt>> for Map { /// Converts `Map` into `HashMap` for blob values.
fn as_rust_type(&self) -> Result<Option<HashMap<$($key_type)+, Udt>>> { if let Some(ColTypeOptionValue::CMap(key_type_option, val_type_option)) = &self.metadata.value { let mut map = HashMap::with_capacity(self.data.len()); let key_type_option = key_type_option.as_ref(); let val_type_option = val_type_option.as_ref(); let protocol_version = self.protocol_version; for (key, val) in self.data.iter() { let key = as_rust_type!(key_type_option, key, $($key_type)+)?; let val = as_rust_type!(val_type_option, val, protocol_version, Udt)?; if let (Some(key), Some(val)) = (key, val) { map.insert(key, val); } } Ok(Some(map)) } else { Err(format!("Invalid column type for map: {:?}", self.metadata.value).into()) } } } ); ({ $($key_type:tt)+ }, { Tuple }) => ( impl AsRustType<HashMap<$($key_type)+, Tuple>> for Map { /// Converts `Map` into `HashMap` for blob values. fn as_rust_type(&self) -> Result<Option<HashMap<$($key_type)+, Tuple>>> { if let Some(ColTypeOptionValue::CMap(key_type_option, val_type_option)) = &self.metadata.value { let mut map = HashMap::with_capacity(self.data.len()); let key_type_option = key_type_option.as_ref(); let val_type_option = val_type_option.as_ref(); let protocol_version = self.protocol_version; for (key, val) in self.data.iter() { let key = as_rust_type!(key_type_option, key, $($key_type)+)?; let val = as_rust_type!(val_type_option, val, protocol_version, Tuple)?; if let (Some(key), Some(val)) = (key, val) { map.insert(key, val); } } Ok(Some(map)) } else { Err(format!("Invalid column type for map: {:?}", self.metadata.value).into()) } } } ); ({ $($key_type:tt)+ }, { $($val_type:tt)+ }) => ( impl AsRustType<HashMap<$($key_type)+, $($val_type)+>> for Map { /// Converts `Map` into `HashMap` for blob values.
fn as_rust_type(&self) -> Result<Option<HashMap<$($key_type)+, $($val_type)+>>> { if let Some(ColTypeOptionValue::CMap(key_type_option, val_type_option)) = &self.metadata.value { let mut map = HashMap::with_capacity(self.data.len()); let key_type_option = key_type_option.as_ref(); let val_type_option = val_type_option.as_ref(); for (key, val) in self.data.iter() { let key = as_rust_type!(key_type_option, key, $($key_type)+)?; let val = as_rust_type!(val_type_option, val, $($val_type)+)?; if let (Some(key), Some(val)) = (key, val) { map.insert(key, val); } } Ok(Some(map)) } else { Err(format!("Invalid column type for map: {:?}", self.metadata.value).into()) } } } ); } macro_rules! into_rust_by_name { (Row, List) => ( impl IntoRustByName<List> for Row { fn get_by_name(&self, name: &str) -> Result<Option<List>> { let protocol_version = self.protocol_version; self.col_spec_by_name(name) .ok_or(column_is_empty_err(name)) .and_then(|(col_spec, cbytes)| { let col_type = &col_spec.col_type; as_rust_type!(col_type, cbytes, protocol_version, List) }) } } ); (Row, Vector) => ( impl IntoRustByName<Vector> for Row { fn get_by_name(&self, name: &str) -> Result<Option<Vector>> { let protocol_version = self.protocol_version; self.col_spec_by_name(name) .ok_or(column_is_empty_err(name)) .and_then(|(col_spec, cbytes)| { let col_type = &col_spec.col_type; as_rust_type!(col_type, cbytes, protocol_version, Vector) }) } } ); (Row, Map) => ( impl IntoRustByName<Map> for Row { fn get_by_name(&self, name: &str) -> Result<Option<Map>> { let protocol_version = self.protocol_version; self.col_spec_by_name(name) .ok_or(column_is_empty_err(name)) .and_then(|(col_spec, cbytes)| { let col_type = &col_spec.col_type; as_rust_type!(col_type, cbytes, protocol_version, Map) }) } } ); (Row, Udt) => ( impl IntoRustByName<Udt> for Row { fn get_by_name(&self, name: &str) -> Result<Option<Udt>> { let protocol_version = self.protocol_version; self.col_spec_by_name(name) .ok_or(column_is_empty_err(name)) .and_then(|(col_spec, cbytes)| { let col_type = &col_spec.col_type; as_rust_type!(col_type, cbytes, protocol_version, Udt)
}) } } ); (Row, Tuple) => ( impl IntoRustByName<Tuple> for Row { fn get_by_name(&self, name: &str) -> Result<Option<Tuple>> { let protocol_version = self.protocol_version; self.col_spec_by_name(name) .ok_or(column_is_empty_err(name)) .and_then(|(col_spec, cbytes)| { let col_type = &col_spec.col_type; as_rust_type!(col_type, cbytes, protocol_version, Tuple) }) } } ); (Row, $($into_type:tt)+) => ( impl IntoRustByName<$($into_type)+> for Row { fn get_by_name(&self, name: &str) -> Result<Option<$($into_type)+>> { self.col_spec_by_name(name) .ok_or(column_is_empty_err(name)) .and_then(|(col_spec, cbytes)| { let col_type = &col_spec.col_type; as_rust_type!(col_type, cbytes, $($into_type)+) }) } } ); (Udt, List) => ( impl IntoRustByName<List> for Udt { fn get_by_name(&self, name: &str) -> Result<Option<List>> { let protocol_version = self.protocol_version; self.data.get(name) .ok_or(column_is_empty_err(name)) .and_then(|v| { let &(ref col_type, ref bytes) = v; let converted = as_rust_type!(col_type, bytes, protocol_version, List); converted.map_err(|err| err.into()) }) } } ); (Udt, Map) => ( impl IntoRustByName<Map> for Udt { fn get_by_name(&self, name: &str) -> Result<Option<Map>> { let protocol_version = self.protocol_version; self.data.get(name) .ok_or(column_is_empty_err(name)) .and_then(|v| { let &(ref col_type, ref bytes) = v; let converted = as_rust_type!(col_type, bytes, protocol_version, Map); converted.map_err(|err| err.into()) }) } } ); (Udt, Udt) => ( impl IntoRustByName<Udt> for Udt { fn get_by_name(&self, name: &str) -> Result<Option<Udt>> { let protocol_version = self.protocol_version; self.data.get(name) .ok_or(column_is_empty_err(name)) .and_then(|v| { let &(ref col_type, ref bytes) = v; let converted = as_rust_type!(col_type, bytes, protocol_version, Udt); converted.map_err(|err| err.into()) }) } } ); (Udt, Tuple) => ( impl IntoRustByName<Tuple> for Udt { fn get_by_name(&self, name: &str) -> Result<Option<Tuple>> { let protocol_version = self.protocol_version; self.data.get(name) .ok_or(column_is_empty_err(name)) .and_then(|v| { let &(ref col_type, ref bytes) = v; let
converted = as_rust_type!(col_type, bytes, protocol_version, Tuple); converted.map_err(|err| err.into()) }) } } ); (Udt, $($into_type:tt)+) => ( impl IntoRustByName<$($into_type)+> for Udt { fn get_by_name(&self, name: &str) -> Result> { self.data.get(name) .ok_or(column_is_empty_err(name)) .and_then(|v| { let &(ref col_type, ref bytes) = v; let converted = as_rust_type!(col_type, bytes, $($into_type)+); converted.map_err(|err| err.into()) }) } } ); } macro_rules! into_rust_by_index { (Tuple, List) => ( impl IntoRustByIndex for Tuple { fn get_by_index(&self, index: usize) -> Result> { let protocol_version = self.protocol_version; self.data .get(index) .ok_or(column_is_empty_err(index)) .and_then(|v| { let &(ref col_type, ref bytes) = v; let converted = as_rust_type!(col_type, bytes, protocol_version, List); converted.map_err(|err| err.into()) }) } } ); (Tuple, Map) => ( impl IntoRustByIndex for Tuple { fn get_by_index(&self, index: usize) -> Result> { let protocol_version = self.protocol_version; self.data .get(index) .ok_or(column_is_empty_err(index)) .and_then(|v| { let &(ref col_type, ref bytes) = v; let converted = as_rust_type!(col_type, bytes, protocol_version, Map); converted.map_err(|err| err.into()) }) } } ); (Tuple, Udt) => ( impl IntoRustByIndex for Tuple { fn get_by_index(&self, index: usize) -> Result> { let protocol_version = self.protocol_version; self.data .get(index) .ok_or(column_is_empty_err(index)) .and_then(|v| { let &(ref col_type, ref bytes) = v; let converted = as_rust_type!(col_type, bytes, protocol_version, Udt); converted.map_err(|err| err.into()) }) } } ); (Tuple, Tuple) => ( impl IntoRustByIndex for Tuple { fn get_by_index(&self, index: usize) -> Result> { let protocol_version = self.protocol_version; self.data .get(index) .ok_or(column_is_empty_err(index)) .and_then(|v| { let &(ref col_type, ref bytes) = v; let converted = as_rust_type!(col_type, bytes, protocol_version, Tuple); converted.map_err(|err| err.into()) }) } } ); (Tuple, 
$($into_type:tt)+) => ( impl IntoRustByIndex<$($into_type)+> for Tuple { fn get_by_index(&self, index: usize) -> Result> { self.data .get(index) .ok_or(column_is_empty_err(index)) .and_then(|v| { let &(ref col_type, ref bytes) = v; let converted = as_rust_type!(col_type, bytes, $($into_type)+); converted.map_err(|err| err.into()) }) } } ); (Row, List) => ( impl IntoRustByIndex for Row { fn get_by_index(&self, index: usize) -> Result> { let protocol_version = self.protocol_version; self.col_spec_by_index(index) .ok_or(column_is_empty_err(index)) .and_then(|(col_spec, cbytes)| { let col_type = &col_spec.col_type; as_rust_type!(col_type, cbytes, protocol_version, List) }) } } ); (Row, Map) => ( impl IntoRustByIndex for Row { fn get_by_index(&self, index: usize) -> Result> { let protocol_version = self.protocol_version; self.col_spec_by_index(index) .ok_or(column_is_empty_err(index)) .and_then(|(col_spec, cbytes)| { let col_type = &col_spec.col_type; as_rust_type!(col_type, cbytes, protocol_version, Map) }) } } ); (Row, Udt) => ( impl IntoRustByIndex for Row { fn get_by_index(&self, index: usize) -> Result> { let protocol_version = self.protocol_version; self.col_spec_by_index(index) .ok_or(column_is_empty_err(index)) .and_then(|(col_spec, cbytes)| { let col_type = &col_spec.col_type; as_rust_type!(col_type, cbytes, protocol_version, Udt) }) } } ); (Row, Tuple) => ( impl IntoRustByIndex for Row { fn get_by_index(&self, index: usize) -> Result> { let protocol_version = self.protocol_version; self.col_spec_by_index(index) .ok_or(column_is_empty_err(index)) .and_then(|(col_spec, cbytes)| { let col_type = &col_spec.col_type; as_rust_type!(col_type, cbytes, protocol_version, Tuple) }) } } ); (Row, $($into_type:tt)+) => ( impl IntoRustByIndex<$($into_type)+> for Row { fn get_by_index(&self, index: usize) -> Result> { self.col_spec_by_index(index) .ok_or(column_is_empty_err(index)) .and_then(|(col_spec, cbytes)| { let col_type = &col_spec.col_type; as_rust_type!(col_type, 
                            cbytes, $($into_type)+)
                    })
            }
        }
    );
}

macro_rules! as_res_opt {
    ($data_value:ident, $deserialize:expr) => {
        match $data_value.as_slice() {
            Some(ref bytes) => ($deserialize)(bytes).map(Some).map_err(Into::into),
            None => Ok(None),
        }
    };
}

/// Decodes a Cassandra value into the requested Rust type, given the column type
/// (`ColTypeOption`), the raw value (`CBytes`) and the target Rust type. A null value
/// yields `Ok(None)`; a column type that cannot be converted into the target yields an error.
macro_rules! as_rust_type {
    ($data_type_option:ident, $data_value:ident, Blob) => {
        match $data_type_option.id {
            ColType::Blob => as_res_opt!($data_value, decode_blob),
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(crate::frame::message_result::ColTypeOptionValue::CString(value)) =
                        &$data_type_option.value
                    {
                        if value.as_str() == "org.apache.cassandra.db.marshal.BytesType" {
                            return as_res_opt!($data_value, decode_blob);
                        }
                    }

                    Err(crate::error::Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into Vec<u8> (valid types: org.apache.cassandra.db.marshal.BytesType).",
                        $data_type_option
                    )))
                };

                unmarshal()
            }
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into Vec<u8> (valid types: Blob).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, String) => {
        match $data_type_option.id {
            ColType::Custom => as_res_opt!($data_value, decode_custom),
            ColType::Ascii => as_res_opt!($data_value, decode_ascii),
            ColType::Varchar => as_res_opt!($data_value, decode_varchar),
            // TODO: clarify when to use decode_text.
            // it's not mentioned in
            // https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L582
            // ColType::XXX => decode_text($data_value)?
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into String (valid types: Custom, Ascii, Varchar).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, bool) => {
        match $data_type_option.id {
            ColType::Boolean => as_res_opt!($data_value, decode_boolean),
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        if value.as_str() == "org.apache.cassandra.db.marshal.BooleanType" {
                            return as_res_opt!($data_value, decode_boolean);
                        }
                    }

                    Err(crate::error::Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into bool (valid types: org.apache.cassandra.db.marshal.BooleanType).",
                        $data_type_option
                    )))
                };

                unmarshal()
            }
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into bool (valid types: Boolean, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, i64) => {
        match $data_type_option.id {
            ColType::Bigint => as_res_opt!($data_value, decode_bigint),
            ColType::Timestamp => as_res_opt!($data_value, decode_timestamp),
            ColType::Time => as_res_opt!($data_value, decode_time),
            ColType::Counter => as_res_opt!($data_value, decode_bigint),
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        match value.as_str() {
                            "org.apache.cassandra.db.marshal.LongType"
                            | "org.apache.cassandra.db.marshal.CounterColumnType" => return as_res_opt!($data_value, decode_bigint),
                            "org.apache.cassandra.db.marshal.TimestampType" => return as_res_opt!($data_value, decode_timestamp),
                            "org.apache.cassandra.db.marshal.TimeType" => return as_res_opt!($data_value, decode_time),
                            _ => {}
                        }
                    }

                    Err(crate::error::Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into i64 (valid types: org.apache.cassandra.db.marshal.{{LongType|CounterColumnType|TimestampType|TimeType}}).",
                        $data_type_option
                    )))
                };

                unmarshal()
            }
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into i64 (valid types: Bigint, Timestamp, Time, \
                    Counter, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, i32) => {
        match $data_type_option.id {
            ColType::Int => as_res_opt!($data_value, decode_int),
            ColType::Date => as_res_opt!($data_value, decode_date),
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(crate::frame::message_result::ColTypeOptionValue::CString(value)) =
                        &$data_type_option.value
                    {
                        match value.as_str() {
                            "org.apache.cassandra.db.marshal.Int32Type" => return as_res_opt!($data_value, decode_int),
                            "org.apache.cassandra.db.marshal.SimpleDateType" => return as_res_opt!($data_value, decode_date),
                            _ => {}
                        }
                    }

                    Err(crate::error::Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into i32 (valid types: org.apache.cassandra.db.marshal.{{Int32Type|SimpleDateType}}).",
                        $data_type_option
                    )))
                };

                unmarshal()
            }
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into i32 (valid types: Int, Date, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, i16) => {
        match $data_type_option.id {
            ColType::Smallint => as_res_opt!($data_value, decode_smallint),
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        if value.as_str() == "org.apache.cassandra.db.marshal.ShortType" {
                            return as_res_opt!($data_value, decode_smallint);
                        }
                    }

                    Err(crate::error::Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into i16 (valid types: org.apache.cassandra.db.marshal.ShortType).",
                        $data_type_option
                    )))
                };

                unmarshal()
            }
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into i16 (valid types: Smallint, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, i8) => {
        match $data_type_option.id {
            ColType::Tinyint => as_res_opt!($data_value, decode_tinyint),
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        if value.as_str() == "org.apache.cassandra.db.marshal.ByteType" {
                            return as_res_opt!($data_value, decode_tinyint);
                        }
                    }

                    Err(crate::error::Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into i8 (valid types: org.apache.cassandra.db.marshal.ByteType).",
                        $data_type_option
                    )))
                };

                unmarshal()
            }
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into i8 (valid types: Tinyint, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, NonZeroI64) => {
        match $data_type_option.id {
            ColType::Bigint => {
                as_res_opt!($data_value, decode_bigint).map(|value| value.and_then(NonZeroI64::new))
            }
            ColType::Timestamp => as_res_opt!($data_value, decode_timestamp)
                .map(|value| value.and_then(NonZeroI64::new)),
            ColType::Time => {
                as_res_opt!($data_value, decode_time).map(|value| value.and_then(NonZeroI64::new))
            }
            ColType::Counter => {
                as_res_opt!($data_value, decode_bigint).map(|value| value.and_then(NonZeroI64::new))
            }
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        match value.as_str() {
                            "org.apache.cassandra.db.marshal.LongType"
                            | "org.apache.cassandra.db.marshal.CounterColumnType" => return as_res_opt!($data_value, decode_bigint),
                            "org.apache.cassandra.db.marshal.TimestampType" => return as_res_opt!($data_value, decode_timestamp),
                            "org.apache.cassandra.db.marshal.TimeType" => return as_res_opt!($data_value, decode_time),
                            _ => {}
                        }
                    }

                    Err(Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into NonZeroI64 (valid types: org.apache.cassandra.db.marshal.{{LongType|CounterColumnType|TimestampType|TimeType}}).",
                        $data_type_option
                    )))
                };

                unmarshal().map(|value| value.and_then(NonZeroI64::new))
            }
            _ => Err(Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into NonZeroI64 (valid types: Bigint, Timestamp, Time, \
                    Counter, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, NonZeroI32) => {
        match $data_type_option.id {
            ColType::Int => {
                as_res_opt!($data_value, decode_int).map(|value| value.and_then(NonZeroI32::new))
            }
            ColType::Date => {
                as_res_opt!($data_value, decode_date).map(|value| value.and_then(NonZeroI32::new))
            }
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        match value.as_str() {
                            "org.apache.cassandra.db.marshal.Int32Type" => return as_res_opt!($data_value, decode_int),
                            "org.apache.cassandra.db.marshal.SimpleDateType" => return as_res_opt!($data_value, decode_date),
                            _ => {}
                        }
                    }

                    Err(Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into NonZeroI32 (valid types: org.apache.cassandra.db.marshal.{{Int32Type|SimpleDateType}}).",
                        $data_type_option
                    )))
                };

                unmarshal().map(|value| value.and_then(NonZeroI32::new))
            }
            _ => Err(Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into NonZeroI32 (valid types: Int, Date, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, NonZeroI16) => {
        match $data_type_option.id {
            ColType::Smallint => as_res_opt!($data_value, decode_smallint)
                .map(|value| value.and_then(NonZeroI16::new)),
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        if value.as_str() == "org.apache.cassandra.db.marshal.ShortType" {
                            return as_res_opt!($data_value, decode_smallint);
                        }
                    }

                    Err(Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into NonZeroI16 (valid types: org.apache.cassandra.db.marshal.ShortType).",
                        $data_type_option
                    )))
                };

                unmarshal().map(|value| value.and_then(NonZeroI16::new))
            }
            _ => Err(Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into NonZeroI16 (valid types: Smallint, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, NonZeroI8) => {
        match $data_type_option.id {
            ColType::Tinyint => {
                as_res_opt!($data_value, decode_tinyint).map(|value| value.and_then(NonZeroI8::new))
            }
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        if value.as_str() == "org.apache.cassandra.db.marshal.ByteType" {
                            return as_res_opt!($data_value, decode_tinyint);
                        }
                    }

                    Err(Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into NonZeroI8 (valid types: org.apache.cassandra.db.marshal.ByteType).",
                        $data_type_option
                    )))
                };

                unmarshal().map(|value| value.and_then(NonZeroI8::new))
            }
            _ => Err(Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into NonZeroI8 (valid types: Tinyint, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, f64) => {
        match $data_type_option.id {
            ColType::Double => as_res_opt!($data_value, decode_double),
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        if value.as_str() == "org.apache.cassandra.db.marshal.DoubleType" {
                            return as_res_opt!($data_value, decode_double);
                        }
                    }

                    Err(crate::error::Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into f64 (valid types: org.apache.cassandra.db.marshal.DoubleType).",
                        $data_type_option
                    )))
                };

                unmarshal()
            }
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into f64 (valid types: Double, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, f32) => {
        match $data_type_option.id {
            ColType::Float => as_res_opt!($data_value, decode_float),
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        if value.as_str() == "org.apache.cassandra.db.marshal.FloatType" {
                            return as_res_opt!($data_value, decode_float);
                        }
                    }

                    Err(crate::error::Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into f32 (valid types: org.apache.cassandra.db.marshal.FloatType).",
                        $data_type_option
                    )))
                };

                unmarshal()
            }
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into f32 (valid types: Float, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, IpAddr) => {
        match $data_type_option.id {
            ColType::Inet => as_res_opt!($data_value, decode_inet),
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        if value.as_str() == "org.apache.cassandra.db.marshal.InetAddressType" {
                            return as_res_opt!($data_value, decode_inet);
                        }
                    }

                    Err(crate::error::Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into IpAddr (valid types: org.apache.cassandra.db.marshal.InetAddressType).",
                        $data_type_option
                    )))
                };

                unmarshal()
            }
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into IpAddr (valid types: Inet, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, Uuid) => {
        match $data_type_option.id {
            ColType::Uuid | ColType::Timeuuid => as_res_opt!($data_value, decode_timeuuid),
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        match value.as_str() {
                            "org.apache.cassandra.db.marshal.UUIDType"
                            | "org.apache.cassandra.db.marshal.TimeUUIDType" => return as_res_opt!($data_value, decode_timeuuid),
                            _ => {}
                        }
                    }

                    Err(crate::error::Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into Uuid (valid types: org.apache.cassandra.db.marshal.{{UUIDType|TimeUUIDType}}).",
                        $data_type_option
                    )))
                };

                unmarshal()
            }
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into Uuid (valid types: Uuid, Timeuuid, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, $version:ident, List) => {
        match $data_type_option.id {
            ColType::List | ColType::Set => match $data_value.as_slice() {
                Some(ref bytes) => decode_list(bytes, $version)
                    .map(|data| Some(List::new($data_type_option.clone(), data, $version)))
                    .map_err(Into::into),
                None => Ok(None),
            },
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into List (valid types: List, Set).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, $version:ident, Vector) => {
        match $data_type_option.id {
            ColType::Custom => match $data_value.as_slice() {
                Some(ref bytes) => {
                    let crate::types::vector::VectorInfo {
                        internal_type: _,
                        count,
                    } = crate::types::vector::get_vector_type_info(
                        $data_type_option.value.as_ref()?,
                    )?;

                    decode_float_vector(bytes, $version, count)
                        .map(|data| Some(Vector::new($data_type_option.clone(), data, $version)))
                        .map_err(Into::into)
                }
                None => Ok(None),
            },
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into Vector (valid types: Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, $version:ident, Map) => {
        match $data_type_option.id {
            ColType::Map => match $data_value.as_slice() {
                Some(ref bytes) => decode_map(bytes, $version)
                    .map(|data| Some(Map::new(data, $data_type_option.clone(), $version)))
                    .map_err(Into::into),
                None => Ok(None),
            },
            _ => Err(Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into Map (valid types: Map).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, $version:ident, Udt) => {
        match *$data_type_option {
            ColTypeOption {
                id: ColType::Udt,
                value: Some(ColTypeOptionValue::UdtType(ref list_type_option)),
            } => match $data_value.as_slice() {
                Some(ref bytes) => decode_udt(bytes, list_type_option.descriptions.len(), $version)
                    .map(|data| Some(Udt::new(data, list_type_option, $version)))
                    .map_err(Into::into),
                None => Ok(None),
            },
            _ => Err(Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into Udt (valid types: Udt).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, $version:ident, Tuple) => {
        match *$data_type_option {
            ColTypeOption {
                id: ColType::Tuple,
                value: Some(ColTypeOptionValue::TupleType(ref list_type_option)),
            } => match $data_value.as_slice() {
                Some(ref bytes) => decode_tuple(bytes, list_type_option.types.len(), $version)
                    .map(|data| Some(Tuple::new(data, list_type_option, $version)))
                    .map_err(Into::into),
                None => Ok(None),
            },
            _ => Err(Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into Tuple (valid types: Tuple).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, PrimitiveDateTime) => {
        match $data_type_option.id {
            ColType::Timestamp => match $data_value.as_slice() {
                Some(ref bytes) => decode_timestamp(bytes)
                    .map(|ts| {
                        let unix_epoch = time::macros::date!(1970 - 01 - 01).midnight();
                        let tm = unix_epoch
                            + time::Duration::new(ts / 1_000, (ts % 1_000 * 1_000_000) as i32);
                        Some(tm)
                    })
                    .map_err(Into::into),
                None => Ok(None),
            },
            _ => Err(Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into PrimitiveDateTime (valid types: Timestamp).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, Decimal) => {
        match $data_type_option.id {
            ColType::Decimal => match $data_value.as_slice() {
                Some(ref bytes) => decode_decimal(bytes).map(Some).map_err(Into::into),
                None => Ok(None),
            },
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into Decimal (valid types: Decimal).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, NaiveDateTime) => {
        match $data_type_option.id {
            ColType::Timestamp => match $data_value.as_slice() {
                Some(ref bytes) => decode_timestamp(bytes)
                    .map(|ts| {
                        DateTime::from_timestamp(ts / 1000, (ts % 1000 * 1_000_000) as u32)
                            .map(|dt| dt.naive_utc())
                    })
                    .map_err(Into::into),
                None => Ok(None),
            },
            _ => Err(Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into NaiveDateTime (valid types: Timestamp).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, DateTime<Utc>) => {
        match $data_type_option.id {
            ColType::Timestamp => match $data_value.as_slice() {
                Some(ref bytes) => decode_timestamp(bytes)
                    .map(|ts| {
                        DateTime::from_timestamp(ts / 1000, (ts % 1000 * 1_000_000) as u32)
                    })
                    .map_err(Into::into),
                None => Ok(None),
            },
            _ => Err(Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into DateTime<Utc> (valid types: Timestamp).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, BigInt) => {
        match $data_type_option.id {
            ColType::Varint => {
                as_res_opt!($data_value, decode_varint)
            }
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        if value.as_str() == "org.apache.cassandra.db.marshal.IntegerType" {
                            return as_res_opt!($data_value, decode_varint);
                        }
                    }

                    Err(crate::error::Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into BigInt (valid types: org.apache.cassandra.db.marshal.IntegerType).",
                        $data_type_option
                    )))
                };

                unmarshal()
            }
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into BigInt (valid types: Varint, Custom).",
                $data_type_option.id
            ))),
        }
    };
    ($data_type_option:ident, $data_value:ident, Duration) => {
        match $data_type_option.id {
            ColType::Duration => {
                as_res_opt!($data_value, decode_duration)
            }
            ColType::Custom => {
                let unmarshal = || {
                    if let Some(ColTypeOptionValue::CString(value)) = &$data_type_option.value {
                        if value.as_str() == "org.apache.cassandra.db.marshal.DurationType" {
                            return as_res_opt!($data_value, decode_duration);
                        }
                    }

                    Err(crate::error::Error::General(format!(
                        "Invalid conversion. \
                    Cannot convert marshaled type {:?} into Duration (valid types: org.apache.cassandra.db.marshal.DurationType).",
                        $data_type_option
                    )))
                };

                unmarshal()
            }
            _ => Err(crate::error::Error::General(format!(
                "Invalid conversion. \
                    Cannot convert {:?} into Duration (valid types: Duration, Custom).",
                $data_type_option.id
            ))),
        }
    };
}

================================================
FILE: cassandra-protocol/src/query/batch_query_builder.rs
================================================
use crate::consistency::Consistency;
use crate::error::{Error as CError, Result as CResult};
use crate::frame::message_batch::{BatchQuery, BatchQuerySubj, BatchType, BodyReqBatch};
use crate::query::{PreparedQuery, QueryValues};
use crate::types::{CBytesShort, CInt, CLong};
use derivative::Derivative;
use derive_more::Constructor;
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq, Eq, Constructor)]
pub struct QueryBatchPreparedStatement {
    pub query: String,
    pub keyspace: Option<String>,
}

#[derive(Debug, Clone, PartialEq, Eq, Constructor, Derivative)]
pub struct QueryBatch {
    pub request: BodyReqBatch,
    #[derivative(Debug = "ignore")]
    pub prepared_queries: HashMap<CBytesShort, QueryBatchPreparedStatement>,
}

#[derive(Debug)]
pub struct BatchQueryBuilder {
    batch_type: BatchType,
    queries: Vec<BatchQuery>,
    prepared_queries: HashMap<CBytesShort, QueryBatchPreparedStatement>,
    consistency: Consistency,
    serial_consistency: Option<Consistency>,
    timestamp: Option<CLong>,
    keyspace: Option<String>,
    now_in_seconds: Option<CInt>,
}

impl Default for BatchQueryBuilder {
    fn default() -> Self {
        BatchQueryBuilder {
            batch_type: BatchType::Logged,
            queries: vec![],
            prepared_queries: HashMap::new(),
            consistency: Consistency::One,
            serial_consistency: None,
            timestamp: None,
            keyspace: None,
            now_in_seconds: None,
        }
    }
}

impl BatchQueryBuilder {
    pub fn new() -> BatchQueryBuilder {
        Default::default()
    }

    #[must_use]
    pub fn with_batch_type(mut self, batch_type: BatchType) -> Self {
        self.batch_type = batch_type;
        self
    }

    /// Adds a non-prepared query to the batch.
    #[must_use]
    pub fn add_query<T: Into<String>>(mut self, query: T, values: QueryValues) -> Self {
        self.queries.push(BatchQuery {
            subject: BatchQuerySubj::QueryString(query.into()),
            values,
        });
        self
    }

    /// Adds a prepared query to the batch.
    #[must_use]
    pub fn add_query_prepared(mut self, query: &PreparedQuery, values: QueryValues) -> Self {
        self.queries.push(BatchQuery {
            subject: BatchQuerySubj::PreparedId(query.id.clone()),
            values,
        });

        self.prepared_queries.insert(
            query.id.clone(),
            QueryBatchPreparedStatement::new(query.query.clone(), query.keyspace.clone()),
        );

        self
    }

    #[must_use]
    pub fn clear_queries(mut self) -> Self {
        self.queries = vec![];
        self
    }

    #[must_use]
    pub fn with_consistency(mut self, consistency: Consistency) -> Self {
        self.consistency = consistency;
        self
    }

    #[must_use]
    pub fn with_serial_consistency(mut self, serial_consistency: Consistency) -> Self {
        self.serial_consistency = Some(serial_consistency);
        self
    }

    #[must_use]
    pub fn with_timestamp(mut self, timestamp: CLong) -> Self {
        self.timestamp = Some(timestamp);
        self
    }

    #[must_use]
    pub fn with_keyspace(mut self, keyspace: String) -> Self {
        self.keyspace = Some(keyspace);
        self
    }

    #[must_use]
    pub fn with_now_in_seconds(mut self, now_in_seconds: CInt) -> Self {
        self.now_in_seconds = Some(now_in_seconds);
        self
    }

    pub fn build(self) -> CResult<QueryBatch> {
        let with_names_for_values = self.queries.iter().all(|q| q.values.has_names());

        if !with_names_for_values {
            let some_names_for_values = self.queries.iter().any(|q| q.values.has_names());

            if some_names_for_values {
                return Err(CError::General(String::from(
                    "Inconsistent query values - mixed with and without names values",
                )));
            }
        }

        Ok(QueryBatch::new(
            BodyReqBatch {
                batch_type: self.batch_type,
                queries: self.queries,
                consistency: self.consistency,
                serial_consistency: self.serial_consistency,
                timestamp: self.timestamp,
                keyspace: self.keyspace,
                now_in_seconds: self.now_in_seconds,
            },
            self.prepared_queries,
        ))
    }
}

================================================
FILE: cassandra-protocol/src/query/prepare_flags.rs
================================================
use bitflags::bitflags;
use std::io::{Cursor, Read};

use crate::error::Result;
use crate::frame::{FromCursor, Serialize, Version};
use crate::types::INT_LEN;

bitflags! {
    pub struct PrepareFlags: u32 {
        /// The prepare request contains explicit keyspace.
        const WITH_KEYSPACE = 0x01;
    }
}

impl Default for PrepareFlags {
    #[inline]
    fn default() -> Self {
        PrepareFlags::empty()
    }
}

impl Serialize for PrepareFlags {
    #[inline]
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        self.bits().serialize(cursor, version);
    }
}

impl FromCursor for PrepareFlags {
    fn from_cursor(cursor: &mut Cursor<&[u8]>, _version: Version) -> Result<Self> {
        let mut buff = [0; INT_LEN];
        cursor
            .read_exact(&mut buff)
            .map(|()| PrepareFlags::from_bits_truncate(u32::from_be_bytes(buff)))
            .map_err(|error| error.into())
    }
}

================================================
FILE: cassandra-protocol/src/query/prepared_query.rs
================================================
use arc_swap::ArcSwapOption;
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};

use crate::types::CBytesShort;

#[derive(Debug)]
pub struct PreparedQuery {
    pub id: CBytesShort,
    pub query: String,
    pub keyspace: Option<String>,
    pub pk_indexes: Vec<i16>,
    pub result_metadata_id: ArcSwapOption<CBytesShort>,
}

impl Clone for PreparedQuery {
    fn clone(&self) -> Self {
        Self {
            id: self.id.clone(),
            query: self.query.clone(),
            keyspace: self.keyspace.clone(),
            pk_indexes: self.pk_indexes.clone(),
            result_metadata_id: ArcSwapOption::new(self.result_metadata_id.load().clone()),
        }
    }
}

impl PartialEq for PreparedQuery {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id && *self.result_metadata_id.load() == *other.result_metadata_id.load()
    }
}

impl Eq for PreparedQuery {}

impl PartialOrd for PreparedQuery {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PreparedQuery {
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        match self.id.cmp(&other.id) {
            Ordering::Equal => self
                .result_metadata_id
                .load()
                .cmp(&other.result_metadata_id.load()),
            result => result,
        }
    }
}

impl Hash for PreparedQuery {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.id.hash(state);
        self.result_metadata_id.load().hash(state);
    }
}
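`QueryFlags` (in `query_flags.rs`) encodes its flags differently depending on the protocol version, and `QueryParams::serialize` repeats the same logic inline. That version-dependent encoding can be illustrated with plain `u32` masks.

The sketch below is a minimal, std-only approximation. The constants copy the `QueryFlags` bit values. `encode_flags` and `decode_flags` are hypothetical helpers, not part of the crate's API. For protocol v5 and later, the flags travel as a big-endian 4-byte integer. Earlier versions use a single byte, so the `as u8` cast keeps only the low eight bits. A v5-only flag such as `WITH_NOW_IN_SECONDS` (0x100) is therefore silently dropped.

```rust
// Bit values copied from `QueryFlags`; plain constants stand in for the
// `bitflags!` type so the sketch needs only the standard library.
pub const VALUE: u32 = 0x001;
pub const PAGE_SIZE: u32 = 0x004;
pub const WITH_NOW_IN_SECONDS: u32 = 0x100;

/// Hypothetical helper: encodes flag bits the way `QueryFlags::serialize` does,
/// as a 4-byte big-endian integer for v5+ and as a single byte otherwise.
pub fn encode_flags(bits: u32, v5_or_later: bool) -> Vec<u8> {
    if v5_or_later {
        bits.to_be_bytes().to_vec()
    } else {
        // Truncating cast: every bit at 0x100 or above is lost here.
        vec![bits as u8]
    }
}

/// Hypothetical helper, roughly mirroring `QueryFlags::from_cursor`; it returns
/// `None` for short input instead of an I/O error and keeps unknown bits.
pub fn decode_flags(bytes: &[u8], v5_or_later: bool) -> Option<u32> {
    if v5_or_later {
        match bytes {
            [a, b, c, d, ..] => Some(u32::from_be_bytes([*a, *b, *c, *d])),
            _ => None,
        }
    } else {
        bytes.first().map(|b| u32::from(*b))
    }
}
```

Consider `VALUE | PAGE_SIZE | WITH_NOW_IN_SECONDS` (0x105). For v5 it encodes to `[0x00, 0x00, 0x01, 0x05]` and round-trips intact. For v4 it becomes the single byte `0x05`, which decodes back as `VALUE | PAGE_SIZE` only.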
================================================
FILE: cassandra-protocol/src/query/query_flags.rs
================================================
use bitflags::bitflags;
use std::io::{Cursor, Read};

use crate::error;
use crate::frame::{FromCursor, Serialize, Version};
use crate::types::INT_LEN;

bitflags! {
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
    pub struct QueryFlags: u32 {
        /// Indicates that Query Params contain values.
        const VALUE = 0x001;
        /// Requests that the result set be returned without metadata (NO_METADATA).
        const SKIP_METADATA = 0x002;
        /// Indicates that Query Params contain page size.
        const PAGE_SIZE = 0x004;
        /// Indicates that Query Params contain paging state.
        const WITH_PAGING_STATE = 0x008;
        /// Indicates that Query Params contain serial consistency.
        const WITH_SERIAL_CONSISTENCY = 0x010;
        /// Indicates that Query Params contain default timestamp.
        const WITH_DEFAULT_TIMESTAMP = 0x020;
        /// Indicates that Query Params values are named.
        const WITH_NAMES_FOR_VALUES = 0x040;
        /// Indicates that Query Params contain keyspace name.
        const WITH_KEYSPACE = 0x080;
        /// Indicates that Query Params contain "now" in seconds.
        const WITH_NOW_IN_SECONDS = 0x100;
    }
}

impl Default for QueryFlags {
    #[inline]
    fn default() -> Self {
        QueryFlags::empty()
    }
}

impl Serialize for QueryFlags {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        if version >= Version::V5 {
            self.bits().serialize(cursor, version);
        } else {
            (self.bits() as u8).serialize(cursor, version);
        }
    }
}

impl FromCursor for QueryFlags {
    fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> {
        if version >= Version::V5 {
            let mut buff = [0; INT_LEN];
            cursor
                .read_exact(&mut buff)
                .map(|()| QueryFlags::from_bits_truncate(u32::from_be_bytes(buff)))
                .map_err(|error| error.into())
        } else {
            let mut buff = [0];
            cursor.read_exact(&mut buff)?;
            Ok(QueryFlags::from_bits_truncate(buff[0] as u32))
        }
    }
}

================================================
FILE: cassandra-protocol/src/query/query_params.rs
================================================
use std::collections::HashMap;
use std::io::Cursor;

use crate::consistency::Consistency;
use crate::frame::traits::FromCursor;
use crate::frame::{Serialize, Version};
use crate::query::query_flags::QueryFlags;
use crate::query::query_values::QueryValues;
use crate::types::{from_cursor_str, serialize_str, value::Value, CInt, CIntShort};
use crate::types::{CBytes, CLong};
use crate::Error;

/// Parameters of a query operation.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct QueryParams {
    /// Cassandra consistency level.
    pub consistency: Consistency,
    /// Whether values are provided with names.
    pub with_names: bool,
    /// Array of values.
    pub values: Option<QueryValues>,
    /// Page size.
    pub page_size: Option<CInt>,
    /// Opaque bytes representing the paging state.
    pub paging_state: Option<CBytes>,
    /// Serial `Consistency`.
    pub serial_consistency: Option<Consistency>,
    /// Timestamp.
    pub timestamp: Option<CLong>,
    /// Keyspace the query should be executed in. It supersedes the
    /// keyspace that the connection is bound to, if any.
    pub keyspace: Option<String>,
    /// Represents the current time (now) for the query. Affects TTL cell liveness in read queries
    /// and local deletion time for tombstones and TTL cells in update requests.
    pub now_in_seconds: Option<CInt>,
}

impl QueryParams {
    fn flags(&self) -> QueryFlags {
        let mut flags = QueryFlags::empty();

        if self.values.is_some() {
            flags.insert(QueryFlags::VALUE);
        }

        if self.with_names {
            flags.insert(QueryFlags::WITH_NAMES_FOR_VALUES);
        }

        if self.page_size.is_some() {
            flags.insert(QueryFlags::PAGE_SIZE);
        }

        if self.paging_state.is_some() {
            flags.insert(QueryFlags::WITH_PAGING_STATE);
        }

        if self.serial_consistency.is_some() {
            flags.insert(QueryFlags::WITH_SERIAL_CONSISTENCY);
        }

        if self.timestamp.is_some() {
            flags.insert(QueryFlags::WITH_DEFAULT_TIMESTAMP);
        }

        if self.keyspace.is_some() {
            flags.insert(QueryFlags::WITH_KEYSPACE);
        }

        if self.now_in_seconds.is_some() {
            flags.insert(QueryFlags::WITH_NOW_IN_SECONDS);
        }

        flags
    }
}

impl Serialize for QueryParams {
    fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
        let consistency: CIntShort = self.consistency.into();
        consistency.serialize(cursor, version);

        let flag_bits = self.flags().bits();
        if version >= Version::V5 {
            flag_bits.serialize(cursor, version);
        } else {
            (flag_bits as u8).serialize(cursor, version);
        }

        if let Some(values) = &self.values {
            let len = values.len() as CIntShort;
            len.serialize(cursor, version);
            values.serialize(cursor, version);
        }

        if let Some(page_size) = self.page_size {
            page_size.serialize(cursor, version);
        }

        if let Some(paging_state) = &self.paging_state {
            paging_state.serialize(cursor, version);
        }

        if let Some(serial_consistency) = self.serial_consistency {
            let serial_consistency: CIntShort = serial_consistency.into();
            serial_consistency.serialize(cursor, version);
        }

        if let Some(timestamp) = self.timestamp {
            timestamp.serialize(cursor, version);
        }

        if let Some(keyspace) = &self.keyspace {
            serialize_str(cursor, keyspace.as_str(), version);
        }

        if let Some(now_in_seconds) = self.now_in_seconds {
            now_in_seconds.serialize(cursor, version);
        }
    }
}

impl FromCursor for QueryParams {
    fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> Result<QueryParams, Error> {
        let consistency = Consistency::from_cursor(cursor, version)?;
        let flags = QueryFlags::from_cursor(cursor, version)?;

        let values = if flags.contains(QueryFlags::VALUE) {
            let number_of_values = CIntShort::from_cursor(cursor, version)?;

            if flags.contains(QueryFlags::WITH_NAMES_FOR_VALUES) {
                let mut map = HashMap::with_capacity(number_of_values as usize);
                for _ in 0..number_of_values {
                    map.insert(
                        from_cursor_str(cursor)?.to_string(),
                        Value::from_cursor(cursor, version)?,
                    );
                }

                Some(QueryValues::NamedValues(map))
            } else {
                let mut vec = Vec::with_capacity(number_of_values as usize);
                for _ in 0..number_of_values {
                    vec.push(Value::from_cursor(cursor, version)?);
                }

                Some(QueryValues::SimpleValues(vec))
            }
        } else {
            None
        };

        let page_size = if flags.contains(QueryFlags::PAGE_SIZE) {
            Some(CInt::from_cursor(cursor, version)?)
        } else {
            None
        };

        let paging_state = if flags.contains(QueryFlags::WITH_PAGING_STATE) {
            Some(CBytes::from_cursor(cursor, version)?)
        } else {
            None
        };

        let serial_consistency = if flags.contains(QueryFlags::WITH_SERIAL_CONSISTENCY) {
            Some(Consistency::from_cursor(cursor, version)?)
        } else {
            None
        };

        let timestamp = if flags.contains(QueryFlags::WITH_DEFAULT_TIMESTAMP) {
            Some(CLong::from_cursor(cursor, version)?)
        } else {
            None
        };

        let keyspace = if flags.contains(QueryFlags::WITH_KEYSPACE) {
            Some(from_cursor_str(cursor)?.to_string())
        } else {
            None
        };

        let now_in_seconds = if flags.contains(QueryFlags::WITH_NOW_IN_SECONDS) {
            Some(CInt::from_cursor(cursor, version)?)
} else { None }; let with_names = flags.contains(QueryFlags::WITH_NAMES_FOR_VALUES); Ok(QueryParams { consistency, with_names, values, page_size, paging_state, serial_consistency, timestamp, keyspace, now_in_seconds, }) } } ================================================ FILE: cassandra-protocol/src/query/query_params_builder.rs ================================================ use super::{QueryFlags, QueryParams, QueryValues}; use crate::consistency::Consistency; use crate::types::{CBytes, CInt, CLong}; #[derive(Debug, Default)] pub struct QueryParamsBuilder { consistency: Consistency, flags: Option<QueryFlags>, values: Option<QueryValues>, with_names: bool, page_size: Option<CInt>, paging_state: Option<CBytes>, serial_consistency: Option<Consistency>, timestamp: Option<CLong>, keyspace: Option<String>, now_in_seconds: Option<CInt>, } impl QueryParamsBuilder { /// Factory function that returns a new `QueryParamsBuilder`. pub fn new() -> QueryParamsBuilder { Default::default() } /// Sets new query consistency. #[must_use] pub fn with_consistency(mut self, consistency: Consistency) -> Self { self.consistency = consistency; self } /// Sets new flags. #[must_use] pub fn with_flags(mut self, flags: QueryFlags) -> Self { self.flags = Some(flags); self } /// Sets new query values. #[must_use] pub fn with_values(mut self, values: QueryValues) -> Self { self.with_names = values.has_names(); self.values = Some(values); self.flags = self.flags.or_else(|| { let mut flags = QueryFlags::VALUE; if self.with_names { flags.insert(QueryFlags::WITH_NAMES_FOR_VALUES); } Some(flags) }); self } /// Sets the "with names for values" flag. #[must_use] pub fn with_names(mut self, with_names: bool) -> Self { self.with_names = with_names; self } /// Sets new page size. #[must_use] pub fn with_page_size(mut self, size: CInt) -> Self { self.page_size = Some(size); self.flags = self.flags.or(Some(QueryFlags::PAGE_SIZE)); self } /// Sets new paging state. 
#[must_use] pub fn with_paging_state(mut self, state: CBytes) -> Self { self.paging_state = Some(state); self.flags = self.flags.or(Some(QueryFlags::WITH_PAGING_STATE)); self } /// Sets new serial consistency. #[must_use] pub fn with_serial_consistency(mut self, serial_consistency: Consistency) -> Self { self.serial_consistency = Some(serial_consistency); self } /// Sets new timestamp. #[must_use] pub fn with_timestamp(mut self, timestamp: CLong) -> Self { self.timestamp = Some(timestamp); self } /// Overrides the keyspace used for the query. #[must_use] pub fn with_keyspace(mut self, keyspace: String) -> Self { self.keyspace = Some(keyspace); self } /// Sets "now" in seconds. #[must_use] pub fn with_now_in_seconds(mut self, now_in_seconds: CInt) -> Self { self.now_in_seconds = Some(now_in_seconds); self } /// Finalizes the building process and returns the resulting `QueryParams`. #[must_use] pub fn build(self) -> QueryParams { QueryParams { consistency: self.consistency, values: self.values, with_names: self.with_names, page_size: self.page_size, paging_state: self.paging_state, serial_consistency: self.serial_consistency, timestamp: self.timestamp, keyspace: self.keyspace, now_in_seconds: self.now_in_seconds, } } } ================================================ FILE: cassandra-protocol/src/query/query_values.rs ================================================ use itertools::Itertools; use std::collections::HashMap; use std::io::Cursor; use crate::frame::{Serialize, Version}; use crate::types::serialize_str; use crate::types::value::Value; /// Enum that represents two kinds of query values: /// * positional values (without names) /// * named values #[derive(Debug, Clone, PartialEq, Eq)] pub enum QueryValues { SimpleValues(Vec<Value>), NamedValues(HashMap<String, Value>), } impl QueryValues { /// Returns `true` if the values are named and `false` otherwise. #[inline] pub fn has_names(&self) -> bool { !matches!(*self, QueryValues::SimpleValues(_)) } /// Returns the number of values. 
pub fn len(&self) -> usize { match *self { QueryValues::SimpleValues(ref v) => v.len(), QueryValues::NamedValues(ref m) => m.len(), } } #[inline] pub fn is_empty(&self) -> bool { self.len() == 0 } } impl<T: Into<Value>> From<Vec<T>> for QueryValues { /// Converts values from `Vec<T>` to positional query values (`QueryValues::SimpleValues`). fn from(values: Vec<T>) -> QueryValues { let vals = values.into_iter().map_into(); QueryValues::SimpleValues(vals.collect()) } } impl<T: Into<Value> + Clone> From<&[T]> for QueryValues { /// Converts values from a slice `&[T]` to positional query values (`QueryValues::SimpleValues`). fn from(values: &[T]) -> QueryValues { let values = values.iter().map(|v| v.clone().into()); QueryValues::SimpleValues(values.collect()) } } impl<S: ToString, V: Into<Value>> From<HashMap<S, V>> for QueryValues { /// Converts values from `HashMap<S, V>` to named query values (`QueryValues::NamedValues`). fn from(values: HashMap<S, V>) -> QueryValues { let mut map = HashMap::with_capacity(values.len()); for (name, val) in values { map.insert(name.to_string(), val.into()); } QueryValues::NamedValues(map) } } impl Serialize for QueryValues { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { match self { QueryValues::SimpleValues(v) => { for value in v { value.serialize(cursor, version); } } QueryValues::NamedValues(v) => { for (key, value) in v { serialize_str(cursor, key, version); value.serialize(cursor, version); } } } } } ================================================ FILE: cassandra-protocol/src/query/utils.rs ================================================ /// Returns the identifier in a format appropriate for concatenation in a CQL query. 
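///
/// Embedded double quotes are escaped by doubling them. An illustrative
/// doctest (assuming the crate is imported as `cassandra_protocol`):
///
/// ```
/// use cassandra_protocol::query::utils::quote;
///
/// assert_eq!(quote("users"), "\"users\"");
/// assert_eq!(quote("my\"table"), "\"my\"\"table\"");
/// ```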
#[inline] pub fn quote(text: &str) -> String { format!("\"{}\"", text.replace('"', "\"\"")) } ================================================ FILE: cassandra-protocol/src/query.rs ================================================ pub mod batch_query_builder; pub mod prepare_flags; pub mod prepared_query; pub mod query_flags; pub mod query_params; pub mod query_params_builder; pub mod query_values; pub mod utils; pub use crate::query::batch_query_builder::{BatchQueryBuilder, QueryBatch}; pub use crate::query::prepare_flags::PrepareFlags; pub use crate::query::prepared_query::PreparedQuery; pub use crate::query::query_flags::QueryFlags; pub use crate::query::query_params::QueryParams; pub use crate::query::query_params_builder::QueryParamsBuilder; pub use crate::query::query_values::QueryValues; ================================================ FILE: cassandra-protocol/src/token.rs ================================================ use crate::error::Error; use bytes::Buf; use derive_more::Constructor; use std::cmp::min; use std::convert::TryFrom; use std::num::Wrapping; const C1: Wrapping = Wrapping(0x87c3_7b91_1142_53d5_u64 as i64); const C2: Wrapping = Wrapping(0x4cf5_ad43_2745_937f_u64 as i64); /// A token on the ring. Only Murmur3 tokens are supported for now. 
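///
/// The token is computed from the serialized partition key (routing key) with
/// [`Murmur3Token::generate`]. For example, the bytes of `"testvalue"` hash to
/// `5965290492934326460`; this pair is taken from the unit test at the end of
/// this module.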
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Default, Debug, Hash, Constructor)] pub struct Murmur3Token { pub value: i64, } impl Murmur3Token { // Mirrors Cassandra's Murmur3Partitioner hashing, which (unlike reference MurmurHash3) sign-extends the tail bytes, hence the `as i8 as i64` casts below; matching this quirk keeps tokens consistent with the server. pub fn generate(mut routing_key: &[u8]) -> Self { let length = routing_key.len(); let mut h1: Wrapping<i64> = Wrapping(0); let mut h2: Wrapping<i64> = Wrapping(0); while routing_key.len() >= 16 { let mut k1 = Wrapping(routing_key.get_i64_le()); let mut k2 = Wrapping(routing_key.get_i64_le()); k1 *= C1; k1 = rotl64(k1, 31); k1 *= C2; h1 ^= k1; h1 = rotl64(h1, 27); h1 += h2; h1 = h1 * Wrapping(5) + Wrapping(0x52dce729); k2 *= C2; k2 = rotl64(k2, 33); k2 *= C1; h2 ^= k2; h2 = rotl64(h2, 31); h2 += h1; h2 = h2 * Wrapping(5) + Wrapping(0x38495ab5); } let mut k1 = Wrapping(0_i64); let mut k2 = Wrapping(0_i64); debug_assert!(routing_key.len() < 16); if routing_key.len() > 8 { for i in (8..routing_key.len()).rev() { k2 ^= Wrapping(routing_key[i] as i8 as i64) << ((i - 8) * 8); } k2 *= C2; k2 = rotl64(k2, 33); k2 *= C1; h2 ^= k2; } if !routing_key.is_empty() { for i in (0..min(8, routing_key.len())).rev() { k1 ^= Wrapping(routing_key[i] as i8 as i64) << (i * 8); } k1 *= C1; k1 = rotl64(k1, 31); k1 *= C2; h1 ^= k1; } h1 ^= Wrapping(length as i64); h2 ^= Wrapping(length as i64); h1 += h2; h2 += h1; h1 = fmix(h1); h2 = fmix(h2); h1 += h2; Murmur3Token::new(h1.0) } } impl TryFrom<String> for Murmur3Token { type Error = Error; fn try_from(value: String) -> Result<Self, Self::Error> { value .parse() .map_err(|error| format!("Error parsing token: {error}").into()) .map(Murmur3Token::new) } } impl From<i64> for Murmur3Token { fn from(value: i64) -> Self { Murmur3Token::new(value) } } #[inline] fn rotl64(v: Wrapping<i64>, n: u32) -> Wrapping<i64> { Wrapping((v.0 << n) | (v.0 as u64 >> (64 - n)) as i64) } #[inline] fn fmix(mut k: Wrapping<i64>) -> Wrapping<i64> { k ^= Wrapping((k.0 as u64 >> 33) as i64); k *= Wrapping(0xff51afd7ed558ccd_u64 as i64); k ^= Wrapping((k.0 as u64 >> 33) as i64); k *= Wrapping(0xc4ceb9fe1a85ec53_u64 as i64); k ^= 
Wrapping((k.0 as u64 >> 33) as i64); k } #[cfg(test)] mod test { use super::*; #[test] fn test_generate_murmur3_token() { for s in [ ("testvalue", 5965290492934326460), ("testvalue123", 1518494936189046133), ("example_key", -7813763279771224608), ("château", 9114062196463836094), ] { let generated_token = Murmur3Token::generate(s.0.as_bytes()); assert_eq!(generated_token.value, s.1); } } } ================================================ FILE: cassandra-protocol/src/types/blob.rs ================================================ use derive_more::Constructor; /// Represents the Cassandra `blob` type. #[derive(PartialEq, Eq, Hash, Debug, Clone, Constructor)] #[repr(transparent)] pub struct Blob(Vec<u8>); impl Blob { /// Returns a mutable reference to the underlying slice of bytes. #[inline] pub fn as_mut_slice(&mut self) -> &mut [u8] { self.0.as_mut_slice() } /// Returns the underlying vector of bytes. #[inline] pub fn into_vec(self) -> Vec<u8> { self.0 } } impl From<Vec<u8>> for Blob { #[inline] fn from(vec: Vec<u8>) -> Self { Blob::new(vec) } } impl From<&[u8]> for Blob { #[inline] fn from(value: &[u8]) -> Self { Blob::new(value.to_vec()) } } ================================================ FILE: cassandra-protocol/src/types/cassandra_type.rs ================================================ use num_bigint::BigInt; use std::collections::HashMap; use std::net::IpAddr; use super::prelude::{Blob, Decimal, Duration}; use crate::error::Result as CDRSResult; use crate::frame::message_result::{ColType, ColTypeOption}; use crate::frame::Version; use crate::types::CBytes; #[derive(Debug, PartialEq, Clone)] #[non_exhaustive] pub enum CassandraType { Ascii(String), Bigint(i64), Blob(Blob), Boolean(bool), Counter(i64), Decimal(Decimal), Double(f64), Float(f32), Int(i32), Timestamp(i64), Uuid(uuid::Uuid), Varchar(String), Varint(BigInt), Timeuuid(uuid::Uuid), Inet(IpAddr), Date(i32), Time(i64), Smallint(i16), Tinyint(i8), Duration(Duration), List(Vec<CassandraType>), Map(Vec<(CassandraType, CassandraType)>), 
Set(Vec), Udt(HashMap), Tuple(Vec), Vector(Vec), Null, } /// Get a function to convert `CBytes` and `ColTypeOption` into a `CassandraType` pub fn wrapper_fn( col_type: &ColType, ) -> &'static dyn Fn(&CBytes, &ColTypeOption, Version) -> CDRSResult { match col_type { ColType::Blob => &wrappers::blob, ColType::Ascii => &wrappers::ascii, ColType::Int => &wrappers::int, ColType::List => &wrappers::list, ColType::Custom => &wrappers::custom, ColType::Bigint => &wrappers::bigint, ColType::Boolean => &wrappers::bool, ColType::Counter => &wrappers::counter, ColType::Decimal => &wrappers::decimal, ColType::Double => &wrappers::double, ColType::Float => &wrappers::float, ColType::Timestamp => &wrappers::timestamp, ColType::Uuid => &wrappers::uuid, ColType::Varchar => &wrappers::varchar, ColType::Varint => &wrappers::varint, ColType::Timeuuid => &wrappers::timeuuid, ColType::Inet => &wrappers::inet, ColType::Date => &wrappers::date, ColType::Time => &wrappers::time, ColType::Smallint => &wrappers::smallint, ColType::Tinyint => &wrappers::tinyint, ColType::Duration => &wrappers::duration, ColType::Map => &wrappers::map, ColType::Set => &wrappers::set, ColType::Udt => &wrappers::udt, ColType::Tuple => &wrappers::tuple, } } pub mod wrappers { use super::CassandraType; use crate::error::Result as CDRSResult; use crate::frame::message_result::{ColType, ColTypeOption, ColTypeOptionValue}; use crate::frame::Version; use crate::types::data_serialization_types::*; use crate::types::list::List; use crate::types::vector::{get_vector_type_info, Vector, VectorInfo}; use crate::types::AsCassandraType; use crate::types::CBytes; use crate::types::{map::Map, tuple::Tuple, udt::Udt}; pub fn custom( bytes: &CBytes, col_type: &ColTypeOption, version: Version, ) -> CDRSResult { if let ColTypeOption { id: ColType::Custom, value: Some(value), } = col_type { let VectorInfo { internal_type: _, count, } = get_vector_type_info(value)?; if let Some(actual_bytes) = bytes.as_slice() { let vector = 
decode_float_vector(actual_bytes, version, count) .map(|data| Vector::new(col_type.clone(), data, version))? .as_cassandra_type()? .unwrap_or(CassandraType::Null); return Ok(vector); } } Ok(CassandraType::Null) } pub fn map( bytes: &CBytes, col_type: &ColTypeOption, version: Version, ) -> CDRSResult { if let Some(actual_bytes) = bytes.as_slice() { let decoded_map = decode_map(actual_bytes, version)?; Ok(Map::new(decoded_map, col_type.clone(), version) .as_cassandra_type()? .unwrap_or(CassandraType::Null)) } else { Ok(CassandraType::Null) } } pub fn set( bytes: &CBytes, col_type: &ColTypeOption, version: Version, ) -> CDRSResult { if let Some(actual_bytes) = bytes.as_slice() { let decoded_set = decode_set(actual_bytes, version)?; Ok(List::new(col_type.clone(), decoded_set, version) .as_cassandra_type()? .unwrap_or(CassandraType::Null)) } else { Ok(CassandraType::Null) } } pub fn udt( bytes: &CBytes, col_type: &ColTypeOption, version: Version, ) -> CDRSResult { if let Some(ColTypeOptionValue::UdtType(ref list_type_option)) = col_type.value { if let Some(actual_bytes) = bytes.as_slice() { let len = list_type_option.descriptions.len(); let decoded_udt = decode_udt(actual_bytes, len, version)?; return Ok(Udt::new(decoded_udt, list_type_option, version) .as_cassandra_type()? .unwrap_or(CassandraType::Null)); } } Ok(CassandraType::Null) } pub fn tuple( bytes: &CBytes, col_type: &ColTypeOption, version: Version, ) -> CDRSResult { if let Some(ColTypeOptionValue::TupleType(ref list_type_option)) = col_type.value { if let Some(actual_bytes) = bytes.as_slice() { let len = list_type_option.types.len(); let decoded_tuple = decode_tuple(actual_bytes, len, version)?; return Ok(Tuple::new(decoded_tuple, list_type_option, version) .as_cassandra_type()? 
.unwrap_or(CassandraType::Null)); } } Ok(CassandraType::Null) } pub fn null( _: &CBytes, _col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { Ok(CassandraType::Null) } pub fn blob( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, Blob)?; Ok(match t { Some(t) => CassandraType::Blob(t), None => CassandraType::Null, }) } pub fn ascii( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, String)?; Ok(match t { Some(t) => CassandraType::Ascii(t), None => CassandraType::Null, }) } pub fn int( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, i32)?; Ok(match t { Some(t) => CassandraType::Int(t), None => CassandraType::Null, }) } pub fn list( bytes: &CBytes, col_type: &ColTypeOption, version: Version, ) -> CDRSResult { let list = as_rust_type!(col_type, bytes, version, List)?; Ok(match list { Some(t) => t.as_cassandra_type()?.unwrap_or(CassandraType::Null), None => CassandraType::Null, }) } pub fn bigint( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, i64)?; Ok(match t { Some(t) => CassandraType::Bigint(t), None => CassandraType::Null, }) } pub fn counter( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, i64)?; Ok(match t { Some(t) => CassandraType::Counter(t), None => CassandraType::Null, }) } pub fn decimal( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, Decimal)?; Ok(match t { Some(t) => CassandraType::Decimal(t), None => CassandraType::Null, }) } pub fn double( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, f64)?; Ok(match t { Some(t) => CassandraType::Double(t), None => 
CassandraType::Null, }) } pub fn float( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, f32)?; Ok(match t { Some(t) => CassandraType::Float(t), None => CassandraType::Null, }) } pub fn timestamp( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, i64)?; Ok(match t { Some(t) => CassandraType::Timestamp(t), None => CassandraType::Null, }) } pub fn uuid( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, Uuid)?; Ok(match t { Some(t) => CassandraType::Uuid(t), None => CassandraType::Null, }) } pub fn varchar( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, String)?; Ok(match t { Some(t) => CassandraType::Varchar(t), None => CassandraType::Null, }) } pub fn varint( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, BigInt)?; Ok(match t { Some(t) => CassandraType::Varint(t), None => CassandraType::Null, }) } pub fn timeuuid( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, Uuid)?; Ok(match t { Some(t) => CassandraType::Timeuuid(t), None => CassandraType::Null, }) } pub fn inet( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, IpAddr)?; Ok(match t { Some(t) => CassandraType::Inet(t), None => CassandraType::Null, }) } pub fn date( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, i32)?; Ok(match t { Some(t) => CassandraType::Date(t), None => CassandraType::Null, }) } pub fn time( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, i64)?; Ok(match t { Some(t) => 
CassandraType::Time(t), None => CassandraType::Null, }) } pub fn smallint( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, i16)?; Ok(match t { Some(t) => CassandraType::Smallint(t), None => CassandraType::Null, }) } pub fn tinyint( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, i8)?; Ok(match t { Some(t) => CassandraType::Tinyint(t), None => CassandraType::Null, }) } pub fn bool( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, bool)?; Ok(match t { Some(t) => CassandraType::Boolean(t), None => CassandraType::Null, }) } pub fn duration( bytes: &CBytes, col_type: &ColTypeOption, _version: Version, ) -> CDRSResult { let t = as_rust_type!(col_type, bytes, Duration)?; Ok(match t { Some(t) => CassandraType::Duration(t), None => CassandraType::Null, }) } } ================================================ FILE: cassandra-protocol/src/types/data_serialization_types.rs ================================================ use integer_encoding::VarInt; use num_bigint::BigInt; use std::convert::TryInto; use std::io; use std::net; use std::string::FromUtf8Error; use super::blob::Blob; use super::decimal::Decimal; use super::duration::Duration; use crate::error; use crate::frame::{FromCursor, Version}; use crate::types::{ try_f32_from_bytes, try_f64_from_bytes, try_i16_from_bytes, try_i32_from_bytes, try_i64_from_bytes, CBytes, CInt, INT_LEN, }; // https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L813 const FALSE_BYTE: u8 = 0; // Decodes Cassandra `custom` data (bytes) #[inline] pub fn decode_custom(bytes: &[u8]) -> Result { // Use from_utf8 (not from_utf8_lossy) so invalid input surfaces as an // error matching the function's signature, instead of silently // substituting replacement characters. 
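// Illustrative example: b"abc" decodes to Ok("abc".to_string()), while
// b"\xff" returns Err(FromUtf8Error), because the byte 0xFF never occurs
// in valid UTF-8.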
String::from_utf8(bytes.to_vec()) } // Decodes Cassandra `ascii` data (bytes) #[inline] pub fn decode_ascii(bytes: &[u8]) -> Result<String, FromUtf8Error> { // ASCII is a subset of UTF-8, so from_utf8 accepts every valid ASCII // payload. It does not enforce 7-bit ASCII (non-ASCII UTF-8 is accepted // too); it only rejects invalid UTF-8 instead of silently lossy-converting // bytes the server should never have sent. String::from_utf8(bytes.to_vec()) } // Decodes Cassandra `varchar` data (bytes) #[inline] pub fn decode_varchar(bytes: &[u8]) -> Result<String, FromUtf8Error> { // Use from_utf8 (not from_utf8_lossy): the function signature already // promises FromUtf8Error, and lossy conversion would silently corrupt // data instead of letting the caller decide how to handle invalid // input from the server. String::from_utf8(bytes.to_vec()) } // Decodes Cassandra `bigint` data (bytes) #[inline] pub fn decode_bigint(bytes: &[u8]) -> Result<i64, io::Error> { try_i64_from_bytes(bytes) } // Decodes Cassandra `blob` data (bytes) #[inline] pub fn decode_blob(bytes: &[u8]) -> Result<Blob, io::Error> { // in fact we just pass it through. Ok(bytes.into()) } // Decodes Cassandra `boolean` data (bytes) #[inline] pub fn decode_boolean(bytes: &[u8]) -> Result<bool, io::Error> { if bytes.is_empty() { Err(io::Error::new( io::ErrorKind::UnexpectedEof, "no bytes were found", )) } else { Ok(bytes[0] != FALSE_BYTE) } } // Decodes Cassandra `int` data (bytes) #[inline] pub fn decode_int(bytes: &[u8]) -> Result<i32, io::Error> { try_i32_from_bytes(bytes) } // Decodes Cassandra `date` data (bytes) // 0: -5877641-06-23 // 2^31: 1970-1-1 // 2^32: 5881580-07-11 // The wire value is an unsigned day count with the epoch at 2^31; it is // returned here as the raw 32 bits in an i32. #[inline] pub fn decode_date(bytes: &[u8]) -> Result<i32, io::Error> { try_i32_from_bytes(bytes) } // Decodes Cassandra `decimal` data (bytes) pub fn decode_decimal(bytes: &[u8]) -> Result<Decimal, io::Error> { // wire format: 4-byte int scale followed by a variable-length signed // big-endian integer (no length prefix). At minimum we need the scale. 
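// Worked example (hand-computed, for illustration): the bytes
// [0x00, 0x00, 0x00, 0x02, 0x30, 0x39] give scale = 2 and an unscaled
// varint 0x3039 = 12345, i.e. 12345 * 10^-2 = 123.45. The varint is two's
// complement, so [0x00, 0x00, 0x00, 0x01, 0xFF, 0x85] gives scale = 1 and
// unscaled = -123, i.e. -12.3.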
if bytes.len() < INT_LEN { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "decimal requires at least 4 bytes for the scale", )); } let lr = bytes.split_at(INT_LEN); let scale = try_i32_from_bytes(lr.0)?; let unscaled = decode_varint(lr.1)?; Ok(Decimal::new(unscaled, scale)) } // Decodes Cassandra `double` data (bytes) #[inline] pub fn decode_double(bytes: &[u8]) -> Result<f64, io::Error> { try_f64_from_bytes(bytes) } // Decodes Cassandra `float` data (bytes) #[inline] pub fn decode_float(bytes: &[u8]) -> Result<f32, io::Error> { try_f32_from_bytes(bytes) } // Decodes Cassandra `inet` data (bytes) #[allow(clippy::many_single_char_names)] pub fn decode_inet(bytes: &[u8]) -> Result<net::IpAddr, io::Error> { match bytes.len() { // IPv4 4 => { let array: [u8; 4] = bytes[0..4].try_into().unwrap(); Ok(net::IpAddr::V4(net::Ipv4Addr::from(array))) } // IPv6 16 => { let array: [u8; 16] = bytes[0..16].try_into().unwrap(); Ok(net::IpAddr::V6(net::Ipv6Addr::from(array))) } _ => Err(io::Error::other(format!("Invalid Ip address {bytes:?}"))), } } // Decodes Cassandra `timestamp` data (bytes) into an `i64` holding a // millisecond-precision offset from the Unix epoch (00:00:00, January 1st, // 1970). Negative values represent times before the epoch. #[inline] pub fn decode_timestamp(bytes: &[u8]) -> Result<i64, io::Error> { try_i64_from_bytes(bytes) } // Decodes Cassandra `list` data (bytes) pub fn decode_list(bytes: &[u8], version: Version) -> Result<Vec<CBytes>, io::Error> { let mut cursor = io::Cursor::new(bytes); let l = CInt::from_cursor(&mut cursor, version) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; // Don't pre-allocate based on the wire-stated count. The count is // attacker-controlled and could be near i32::MAX, in which case // Vec::with_capacity would request many gigabytes up-front before // reading a single element. Vec::new + push grows by doubling and // bounds memory at roughly 2x the data we actually receive. 
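// Worked example (illustrative, native protocol v3+ layout): a list<int>
// holding [1, 2] arrives as
//   00 00 00 02               element count (2)
//   00 00 00 04 00 00 00 01   first element as [bytes]: length 4, value 1
//   00 00 00 04 00 00 00 02   second element as [bytes]: length 4, value 2
// A payload that claims a count of 0x7FFFFFFF but carries no elements fails
// on the first CBytes read below instead of allocating up front.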
let mut list = Vec::new(); for _ in 0..l { let b = CBytes::from_cursor(&mut cursor, version) .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; list.push(b); } Ok(list) } pub fn decode_float_vector( bytes: &[u8], _version: Version, count: usize, ) -> Result, io::Error> { let type_size = 4; // validate up front so we can produce a clean error rather than panicking // on out-of-bounds slice indexing when the payload is truncated. We also // use checked_mul to defend against `count * type_size` wrapping on // pathological inputs (e.g. count near usize::MAX). let needed = count.checked_mul(type_size).ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidData, "float vector size overflowed usize", ) })?; if bytes.len() < needed { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, format!( "float vector of {count} elements needs {needed} bytes, got {}", bytes.len() ), )); } let mut vector = Vec::with_capacity(count); for i in (0..needed).step_by(type_size) { vector.push(CBytes::new(bytes[i..i + type_size].to_vec())); } Ok(vector) } // Decodes Cassandra `set` data (bytes) #[inline] pub fn decode_set(bytes: &[u8], version: Version) -> Result, io::Error> { decode_list(bytes, version) } // Decodes Cassandra `map` data (bytes) pub fn decode_map(bytes: &[u8], version: Version) -> Result, io::Error> { let mut cursor = io::Cursor::new(bytes); let l = CInt::from_cursor(&mut cursor, version) .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; // See decode_list - skip the wire-stated capacity reservation so a // hostile count cannot make us request gigabytes before reading any // entries. 
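// Worked example (illustrative): a map<int, int> holding {1: 2} arrives as
//   00 00 00 01               entry count (1)
//   00 00 00 04 00 00 00 01   key as [bytes]: length 4, value 1
//   00 00 00 04 00 00 00 02   value as [bytes]: length 4, value 2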
let mut map = Vec::new(); for _ in 0..l { let k = CBytes::from_cursor(&mut cursor, version) .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; let v = CBytes::from_cursor(&mut cursor, version) .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; map.push((k, v)); } Ok(map) } // Decodes Cassandra `smallint` data (bytes) #[inline] pub fn decode_smallint(bytes: &[u8]) -> Result { try_i16_from_bytes(bytes) } // Decodes Cassandra `tinyint` data (bytes) pub fn decode_tinyint(bytes: &[u8]) -> Result { // a tinyint is a single signed byte; bail with a descriptive error rather // than panicking when the server hands us an empty value bytes .first() .copied() .map(|b| b as i8) .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "tinyint requires 1 byte")) } // Decodes Cassandra `text` data (bytes) #[inline] pub fn decode_text(bytes: &[u8]) -> Result { // Same rationale as decode_varchar - actually return an error on // invalid UTF-8 instead of pretending success with replacement chars. 
String::from_utf8(bytes.to_vec()) } // Decodes Cassandra `time` data (bytes) #[inline] pub fn decode_time(bytes: &[u8]) -> Result { try_i64_from_bytes(bytes) } // Decodes Cassandra `timeuuid` data (bytes) #[inline] pub fn decode_timeuuid(bytes: &[u8]) -> Result { uuid::Uuid::from_slice(bytes) } // Decodes Cassandra `varint` data (bytes) #[inline] pub fn decode_varint(bytes: &[u8]) -> Result { Ok(BigInt::from_signed_bytes_be(bytes)) } // Decodes Cassandra `duration` data (bytes) #[inline] pub fn decode_duration(bytes: &[u8]) -> Result { let (months, month_bytes_read) = i32::decode_var(bytes).ok_or_else(|| io::Error::from(io::ErrorKind::InvalidData))?; let (days, day_bytes_read) = i32::decode_var(&bytes[month_bytes_read..]) .ok_or_else(|| io::Error::from(io::ErrorKind::InvalidData))?; let (nanoseconds, _) = i64::decode_var(&bytes[(month_bytes_read + day_bytes_read)..]) .ok_or_else(|| io::Error::from(io::ErrorKind::InvalidData))?; Duration::new(months, days, nanoseconds) .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) } // Decodes Cassandra `Udt` data (bytes) /// /// Note: when the input is shorter than `l` would require, fields beyond the /// available bytes are filled with `CBytes::new_null()`. This matches the /// "older driver, newer schema" tolerance described by the protocol spec /// (a server that has added fields can still serve clients that don't know /// about them). Callers that care about strict decoding should validate the /// returned UDT length matches the expected schema field count, since the /// crate cannot distinguish a truncated payload from an intentional schema /// mismatch here. 
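///
/// Illustrative example: decoding the five bytes `00 00 00 01 2a` with `l = 2`
/// yields two entries, a one-byte value `0x2a` for the first field and, per the
/// tolerance described above, `CBytes::new_null()` for the missing second field.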
pub fn decode_udt(bytes: &[u8], l: usize, version: Version) -> Result<Vec<CBytes>, io::Error> { let mut cursor = io::Cursor::new(bytes); let mut udt = Vec::with_capacity(l); for _ in 0..l { let v = CBytes::from_cursor(&mut cursor, version) .or_else(|err| match err { error::Error::Io(io_err) => { if io_err.kind() == io::ErrorKind::UnexpectedEof { // intentional: a value may omit trailing fields, which // decode as null (see the function-level doc) Ok(CBytes::new_null()) } else { Err(io_err.into()) } } _ => Err(err), }) .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; udt.push(v); } Ok(udt) } //noinspection DuplicatedCode // Decodes Cassandra `Tuple` data (bytes) pub fn decode_tuple(bytes: &[u8], l: usize, version: Version) -> Result<Vec<CBytes>, io::Error> { let mut cursor = io::Cursor::new(bytes); let mut tuple = Vec::with_capacity(l); for _ in 0..l { let v = CBytes::from_cursor(&mut cursor, version) .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; tuple.push(v); } Ok(tuple) } //noinspection DuplicatedCode #[cfg(test)] mod tests { use super::super::super::frame::message_result::*; use super::*; use crate::types::{to_float, to_float_big}; use float_eq::*; use std::net::IpAddr; #[test] fn decode_custom_test() { assert_eq!(decode_custom(b"abcd").unwrap(), "abcd".to_string()); } #[test] fn decode_ascii_test() { assert_eq!(decode_ascii(b"abcd").unwrap(), "abcd".to_string()); } #[test] fn decode_varchar_test() { assert_eq!(decode_varchar(b"abcd").unwrap(), "abcd".to_string()); } #[test] fn decode_bigint_test() { assert_eq!(decode_bigint(&[0, 0, 0, 0, 0, 0, 0, 3]).unwrap(), 3); } #[test] fn decode_blob_test() { assert_eq!( decode_blob(&[0, 0, 0, 3]).unwrap().into_vec(), vec![0, 0, 0, 3] ); } #[test] fn decode_boolean_test() { assert!(!decode_boolean(&[0]).unwrap()); assert!(decode_boolean(&[1]).unwrap()); assert!(decode_boolean(&[]).is_err()); } #[test] fn decode_int_test() { assert_eq!(decode_int(&[0, 0, 0, 3]).unwrap(), 3); } #[test] fn decode_date_test()
{ assert_eq!(decode_date(&[0, 0, 0, 3]).unwrap(), 3); } #[test] fn decode_double_test() { let bytes = to_float_big(0.3); assert_float_eq!( decode_double(bytes.as_slice()).unwrap(), 0.3, abs <= f64::EPSILON ); } #[test] fn decode_float_test() { let bytes = to_float(0.3); assert_float_eq!( decode_float(bytes.as_slice()).unwrap(), 0.3, abs <= f32::EPSILON ); } #[test] fn decode_inet_test() { let bytes_v4 = &[0, 0, 0, 0]; match decode_inet(bytes_v4) { Ok(IpAddr::V4(ref ip)) => assert_eq!(ip.octets(), [0, 0, 0, 0]), _ => panic!("wrong ip v4 address"), } let bytes_v6 = &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; match decode_inet(bytes_v6) { Ok(IpAddr::V6(ref ip)) => assert_eq!(ip.segments(), [0, 0, 0, 0, 0, 0, 0, 0]), _ => panic!("wrong ip v6 address"), }; } #[test] fn decode_timestamp_test() { assert_eq!(decode_timestamp(&[0, 0, 0, 0, 0, 0, 0, 3]).unwrap(), 3); } #[test] fn decode_list_test() { let results = decode_list(&[0, 0, 0, 1, 0, 0, 0, 2, 1, 2], Version::V4).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].as_slice().unwrap(), &[1, 2]); } #[test] fn decode_duration_test() { let result = decode_duration(&[200, 1, 144, 3, 216, 4]).unwrap(); assert_eq!(result, Duration::new(100, 200, 300).unwrap()); } #[test] fn decode_set_test() { let results = decode_set(&[0, 0, 0, 1, 0, 0, 0, 2, 1, 2], Version::V4).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].as_slice().unwrap(), &[1, 2]); } #[test] fn decode_map_test() { let results = decode_map( &[0, 0, 0, 1, 0, 0, 0, 2, 1, 2, 0, 0, 0, 2, 2, 1], Version::V4, ) .unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].0.as_slice().unwrap(), &[1, 2]); assert_eq!(results[0].1.as_slice().unwrap(), &[2, 1]); } #[test] fn decode_smallint_test() { assert_eq!(decode_smallint(&[0, 10]).unwrap(), 10); } #[test] fn decode_tinyint_test() { assert_eq!(decode_tinyint(&[10]).unwrap(), 10); } #[test] fn decode_tinyint_empty_returns_error_not_panic() { // a tinyint requires at least one byte; an empty 
slice must surface // as an error rather than panicking on `bytes[0]` assert!(decode_tinyint(&[]).is_err()); } #[test] fn decode_decimal_short_input_returns_error_not_panic() { // a decimal needs at least 4 bytes for the scale; less than that must // not panic when split_at(INT_LEN) would otherwise be out of bounds assert!(decode_decimal(&[]).is_err()); assert!(decode_decimal(&[0, 0, 0]).is_err()); } #[test] fn decode_float_vector_short_input_returns_error_not_panic() { // a vector of 4 floats requires 16 bytes; anything less must error // instead of panicking on out-of-bounds slice indexing assert!(decode_float_vector(&[], Version::V5, 4).is_err()); assert!(decode_float_vector(&[0; 8], Version::V5, 4).is_err()); } #[test] fn decode_decimal_test() { assert_eq!( decode_decimal(&[0, 0, 0, 0, 10u8]).unwrap(), Decimal::new(10.into(), 0) ); assert_eq!( decode_decimal(&[0, 0, 0, 0, 0x00, 0x81]).unwrap(), Decimal::new(129.into(), 0) ); assert_eq!( decode_decimal(&[0, 0, 0, 0, 0xFF, 0x7F]).unwrap(), Decimal::new(BigInt::from(-129), 0) ); assert_eq!( decode_decimal(&[0, 0, 0, 1, 0x00, 0x81]).unwrap(), Decimal::new(129.into(), 1) ); assert_eq!( decode_decimal(&[0, 0, 0, 1, 0xFF, 0x7F]).unwrap(), Decimal::new(BigInt::from(-129), 1) ); } #[test] fn decode_text_test() { assert_eq!(decode_text(b"abcba").unwrap(), "abcba"); } // The decode_* string functions return `Result<String, _>`, but when they // were built on String::from_utf8_lossy they could never return Err: // invalid UTF-8 was silently replaced with U+FFFD characters, hiding data // corruption. Invalid UTF-8 must surface as a real error so callers can // decide how to handle it.
#[test] fn decode_string_returns_error_on_invalid_utf8() { // 0xFF is not valid in any UTF-8 byte position let bad: &[u8] = &[0xFF]; assert!( decode_text(bad).is_err(), "decode_text must surface invalid UTF-8 as an error" ); assert!( decode_varchar(bad).is_err(), "decode_varchar must surface invalid UTF-8 as an error" ); assert!( decode_ascii(bad).is_err(), "decode_ascii must surface invalid UTF-8 as an error" ); assert!( decode_custom(bad).is_err(), "decode_custom must surface invalid UTF-8 as an error" ); } #[test] fn decode_time_test() { assert_eq!(decode_time(&[0, 0, 0, 0, 0, 0, 0, 10]).unwrap(), 10); } #[test] fn decode_timeuuid_test() { assert_eq!( decode_timeuuid(&[4, 54, 67, 12, 43, 2, 98, 76, 32, 50, 87, 5, 1, 33, 43, 87]) .unwrap() .as_bytes(), &[4, 54, 67, 12, 43, 2, 98, 76, 32, 50, 87, 5, 1, 33, 43, 87] ); } #[test] fn decode_varint_test() { assert_eq!(decode_varint(&[0x00]).unwrap(), 0.into()); assert_eq!(decode_varint(&[0x01]).unwrap(), 1.into()); assert_eq!(decode_varint(&[0x7F]).unwrap(), 127.into()); assert_eq!(decode_varint(&[0x00, 0x80]).unwrap(), 128.into()); assert_eq!(decode_varint(&[0x00, 0x81]).unwrap(), 129.into()); assert_eq!(decode_varint(&[0xFF]).unwrap(), BigInt::from(-1)); assert_eq!(decode_varint(&[0x80]).unwrap(), BigInt::from(-128)); assert_eq!(decode_varint(&[0xFF, 0x7F]).unwrap(), BigInt::from(-129)); } #[test] fn decode_udt_test() { let udt = decode_udt(&[0, 0, 0, 2, 1, 2], 1, Version::V4).unwrap(); assert_eq!(udt.len(), 1); assert_eq!(udt[0].as_slice().unwrap(), &[1, 2]); } #[test] fn as_rust_blob_test() { let d_type = ColTypeOption { id: ColType::Blob, value: None, }; let data = CBytes::new(vec![1, 2, 3]); assert_eq!( as_rust_type!(d_type, data, Blob) .unwrap() .unwrap() .into_vec(), vec![1, 2, 3] ); let wrong_type = ColTypeOption { id: ColType::Map, value: None, }; assert!(as_rust_type!(wrong_type, data, Blob).is_err()); } #[test] fn as_rust_v4_blob_test() { let d_type = ColTypeOption { id: ColType::Custom, value: 
Some(ColTypeOptionValue::CString( "org.apache.cassandra.db.marshal.BytesType".into(), )), }; let data = CBytes::new(vec![1, 2, 3]); assert_eq!( as_rust_type!(d_type, data, Blob) .unwrap() .unwrap() .into_vec(), vec![1, 2, 3] ); } #[test] fn as_rust_string_test() { let type_custom = ColTypeOption { id: ColType::Custom, value: None, }; let type_ascii = ColTypeOption { id: ColType::Ascii, value: None, }; let type_varchar = ColTypeOption { id: ColType::Varchar, value: None, }; let data = CBytes::new(b"abc".to_vec()); assert_eq!( as_rust_type!(type_custom, data, String).unwrap().unwrap(), "abc" ); assert_eq!( as_rust_type!(type_ascii, data, String).unwrap().unwrap(), "abc" ); assert_eq!( as_rust_type!(type_varchar, data, String).unwrap().unwrap(), "abc" ); let wrong_type = ColTypeOption { id: ColType::Map, value: None, }; assert!(as_rust_type!(wrong_type, data, String).is_err()); } #[test] fn as_rust_bool_test() { let type_boolean = ColTypeOption { id: ColType::Boolean, value: None, }; let data_true = CBytes::new(vec![1]); let data_false = CBytes::new(vec![0]); assert!(as_rust_type!(type_boolean, data_true, bool) .unwrap() .unwrap()); assert!(!as_rust_type!(type_boolean, data_false, bool) .unwrap() .unwrap()); let wrong_type = ColTypeOption { id: ColType::Map, value: None, }; assert!(as_rust_type!(wrong_type, data_false, bool).is_err()); } #[test] fn as_rust_v4_bool_test() { let type_boolean = ColTypeOption { id: ColType::Custom, value: Some(ColTypeOptionValue::CString( "org.apache.cassandra.db.marshal.BooleanType".into(), )), }; let data_true = CBytes::new(vec![1]); let data_false = CBytes::new(vec![0]); assert!(as_rust_type!(type_boolean, data_true, bool) .unwrap() .unwrap()); assert!(!as_rust_type!(type_boolean, data_false, bool) .unwrap() .unwrap()); } #[test] fn as_rust_i64_test() { let type_bigint = ColTypeOption { id: ColType::Bigint, value: None, }; let type_timestamp = ColTypeOption { id: ColType::Timestamp, value: None, }; let type_time = ColTypeOption { id: 
ColType::Time, value: None, }; let data = CBytes::new(vec![0, 0, 0, 0, 0, 0, 0, 100]); assert_eq!(as_rust_type!(type_bigint, data, i64).unwrap().unwrap(), 100); assert_eq!( as_rust_type!(type_timestamp, data, i64).unwrap().unwrap(), 100 ); assert_eq!(as_rust_type!(type_time, data, i64).unwrap().unwrap(), 100); let wrong_type = ColTypeOption { id: ColType::Map, value: None, }; assert!(as_rust_type!(wrong_type, data, i64).is_err()); } #[test] fn as_rust_v4_i64_test() { let type_bigint = ColTypeOption { id: ColType::Custom, value: Some(ColTypeOptionValue::CString( "org.apache.cassandra.db.marshal.LongType".into(), )), }; let type_timestamp = ColTypeOption { id: ColType::Custom, value: Some(ColTypeOptionValue::CString( "org.apache.cassandra.db.marshal.TimestampType".into(), )), }; let type_time = ColTypeOption { id: ColType::Custom, value: Some(ColTypeOptionValue::CString( "org.apache.cassandra.db.marshal.TimeType".into(), )), }; let data = CBytes::new(vec![0, 0, 0, 0, 0, 0, 0, 100]); assert_eq!(as_rust_type!(type_bigint, data, i64).unwrap().unwrap(), 100); assert_eq!( as_rust_type!(type_timestamp, data, i64).unwrap().unwrap(), 100 ); assert_eq!(as_rust_type!(type_time, data, i64).unwrap().unwrap(), 100); } #[test] fn as_rust_i32_test() { let type_int = ColTypeOption { id: ColType::Int, value: None, }; let type_date = ColTypeOption { id: ColType::Date, value: None, }; let data = CBytes::new(vec![0, 0, 0, 100]); assert_eq!(as_rust_type!(type_int, data, i32).unwrap().unwrap(), 100); assert_eq!(as_rust_type!(type_date, data, i32).unwrap().unwrap(), 100); let wrong_type = ColTypeOption { id: ColType::Map, value: None, }; assert!(as_rust_type!(wrong_type, data, i32).is_err()); } #[test] fn as_rust_v4_i32_test() { let type_int = ColTypeOption { id: ColType::Custom, value: Some(ColTypeOptionValue::CString( "org.apache.cassandra.db.marshal.Int32Type".into(), )), }; let type_date = ColTypeOption { id: ColType::Custom, value: Some(ColTypeOptionValue::CString( 
"org.apache.cassandra.db.marshal.SimpleDateType".into(), )), }; let data = CBytes::new(vec![0, 0, 0, 100]); assert_eq!(as_rust_type!(type_int, data, i32).unwrap().unwrap(), 100); assert_eq!(as_rust_type!(type_date, data, i32).unwrap().unwrap(), 100); } #[test] fn as_rust_i16_test() { let type_smallint = ColTypeOption { id: ColType::Smallint, value: None, }; let data = CBytes::new(vec![0, 100]); assert_eq!( as_rust_type!(type_smallint, data, i16).unwrap().unwrap(), 100 ); let wrong_type = ColTypeOption { id: ColType::Map, value: None, }; assert!(as_rust_type!(wrong_type, data, i16).is_err()); } #[test] fn as_rust_v4_i16_test() { let type_smallint = ColTypeOption { id: ColType::Custom, value: Some(ColTypeOptionValue::CString( "org.apache.cassandra.db.marshal.ShortType".into(), )), }; let data = CBytes::new(vec![0, 100]); assert_eq!( as_rust_type!(type_smallint, data, i16).unwrap().unwrap(), 100 ); } #[test] fn as_rust_i8_test() { let type_tinyint = ColTypeOption { id: ColType::Tinyint, value: None, }; let data = CBytes::new(vec![100]); assert_eq!(as_rust_type!(type_tinyint, data, i8).unwrap().unwrap(), 100); let wrong_type = ColTypeOption { id: ColType::Map, value: None, }; assert!(as_rust_type!(wrong_type, data, i8).is_err()); } #[test] fn as_rust_v4_i8_test() { let type_tinyint = ColTypeOption { id: ColType::Custom, value: Some(ColTypeOptionValue::CString( "org.apache.cassandra.db.marshal.ByteType".into(), )), }; let data = CBytes::new(vec![100]); assert_eq!(as_rust_type!(type_tinyint, data, i8).unwrap().unwrap(), 100); } #[test] fn as_rust_f64_test() { let type_double = ColTypeOption { id: ColType::Double, value: None, }; let data = CBytes::new(to_float_big(0.1_f64)); assert_float_eq!( as_rust_type!(type_double, data, f64).unwrap().unwrap(), 0.1, abs <= f64::EPSILON ); let wrong_type = ColTypeOption { id: ColType::Map, value: None, }; assert!(as_rust_type!(wrong_type, data, f64).is_err()); } #[test] fn as_rust_v4_f64_test() { let type_double = ColTypeOption { id: 
ColType::Custom, value: Some(ColTypeOptionValue::CString( "org.apache.cassandra.db.marshal.DoubleType".into(), )), }; let data = CBytes::new(to_float_big(0.1_f64)); assert_float_eq!( as_rust_type!(type_double, data, f64).unwrap().unwrap(), 0.1, abs <= f64::EPSILON ); } #[test] fn as_rust_f32_test() { let type_float = ColTypeOption { id: ColType::Float, value: None, }; let data = CBytes::new(to_float(0.1_f32)); assert_float_eq!( as_rust_type!(type_float, data, f32).unwrap().unwrap(), 0.1, abs <= f32::EPSILON ); let wrong_type = ColTypeOption { id: ColType::Map, value: None, }; assert!(as_rust_type!(wrong_type, data, f32).is_err()); } #[test] fn as_rust_v4_f32_test() { let type_float = ColTypeOption { id: ColType::Custom, value: Some(ColTypeOptionValue::CString( "org.apache.cassandra.db.marshal.FloatType".into(), )), }; let data = CBytes::new(to_float(0.1_f32)); assert_float_eq!( as_rust_type!(type_float, data, f32).unwrap().unwrap(), 0.1, abs <= f32::EPSILON ); } #[test] fn as_rust_inet_test() { let type_inet = ColTypeOption { id: ColType::Inet, value: None, }; let data = CBytes::new(vec![0, 0, 0, 0]); match as_rust_type!(type_inet, data, IpAddr) { Ok(Some(IpAddr::V4(ref ip))) => assert_eq!(ip.octets(), [0, 0, 0, 0]), _ => panic!("wrong ip v4 address"), } let wrong_type = ColTypeOption { id: ColType::Map, value: None, }; assert!(as_rust_type!(wrong_type, data, IpAddr).is_err()); } #[test] fn as_rust_v4_inet_test() { let type_inet = ColTypeOption { id: ColType::Custom, value: Some(ColTypeOptionValue::CString( "org.apache.cassandra.db.marshal.InetAddressType".into(), )), }; let data = CBytes::new(vec![0, 0, 0, 0]); match as_rust_type!(type_inet, data, IpAddr) { Ok(Some(IpAddr::V4(ref ip))) =>
assert_eq!(ip.octets(), [0, 0, 0, 0]), _ => panic!("wrong ip v4 address"), } } } ================================================ FILE: cassandra-protocol/src/types/decimal.rs ================================================ use derive_more::Constructor; use float_eq::*; use num_bigint::BigInt; use std::io::Cursor; use crate::frame::{Serialize, Version}; /// Cassandra Decimal type #[derive(Debug, Clone, PartialEq, Constructor, Ord, PartialOrd, Eq, Hash)] pub struct Decimal { pub unscaled: BigInt, pub scale: i32, } impl Decimal { /// Returns the integer part of this decimal as a plain `BigInt`. /// /// For a non-negative scale this is `unscaled / 10^scale`, truncated toward /// zero. A negative scale multiplies instead (`unscaled * 10^|scale|`) /// rather than casting the scale with `scale as u32`, which would wrap a /// negative value to a huge exponent. pub fn as_plain(&self) -> BigInt { let exponent = self.scale.unsigned_abs(); if self.scale >= 0 { // |unscaled| < 2^bits <= 10^scale whenever scale >= bits, so the // quotient is zero; returning early avoids materialising a huge // power of ten for out-of-range scales. Unlike an i64 divisor, this // stays correct for unscaled values wider than 64 bits. if u64::from(exponent) >= self.unscaled.bits() { return BigInt::from(0); } self.unscaled.clone() / BigInt::from(10).pow(exponent) } else { // negative scale: multiply by 10^|scale| to recover the // represented integer self.unscaled.clone() * BigInt::from(10).pow(exponent) } } } impl Serialize for Decimal { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) { self.scale.serialize(cursor, version); self.unscaled .to_signed_bytes_be() .serialize(cursor, version); } } macro_rules! 
impl_from_for_decimal { ($t:ty) => { impl From<$t> for Decimal { fn from(i: $t) -> Self { Decimal { unscaled: i.into(), scale: 0, } } } }; } impl_from_for_decimal!(i8); impl_from_for_decimal!(i16); impl_from_for_decimal!(i32); impl_from_for_decimal!(i64); impl_from_for_decimal!(u8); impl_from_for_decimal!(u16); impl From<f32> for Decimal { fn from(f: f32) -> Decimal { // Cap the loop just below the point where 10i64.pow(scale) overflows // (10^19 > i64::MAX). Without this guard, an input whose scaled value // never becomes integral would keep the loop running until the pow // call panics. In practice f32 precision makes the equality check // succeed long before this cap, so well-formed inputs are unaffected. const MAX_SCALE: u32 = 18; let mut scale: u32 = 0; loop { let unscaled = f * (10i64.pow(scale) as f32); if float_eq!(unscaled, unscaled.trunc(), abs <= f32::EPSILON) { return Decimal::new((unscaled as i64).into(), scale as i32); } if scale >= MAX_SCALE { // best-effort termination: snap to the truncated value at the // current scale rather than looping forever or panicking return Decimal::new((unscaled.trunc() as i64).into(), scale as i32); } scale += 1; } } } impl From<f64> for Decimal { fn from(f: f64) -> Decimal { // Same termination guard as the f32 conversion, bounded just below // i64 overflow on 10i64.pow.
const MAX_SCALE: u32 = 18; let mut scale: u32 = 0; loop { let unscaled = f * (10i64.pow(scale) as f64); if float_eq!(unscaled, unscaled.trunc(), abs <= f64::EPSILON) { return Decimal::new((unscaled as i64).into(), scale as i32); } if scale >= MAX_SCALE { return Decimal::new((unscaled.trunc() as i64).into(), scale as i32); } scale += 1; } } } impl From<Decimal> for BigInt { fn from(value: Decimal) -> Self { value.as_plain() } } #[cfg(test)] mod test { use super::*; #[test] fn serialize_test() { assert_eq!( Decimal::new(129.into(), 0).serialize_to_vec(Version::V4), vec![0, 0, 0, 0, 0x00, 0x81] ); assert_eq!( Decimal::new(BigInt::from(-129), 0).serialize_to_vec(Version::V4), vec![0, 0, 0, 0, 0xFF, 0x7F] ); let expected: Vec<u8> = vec![0, 0, 0, 1, 0x00, 0x81]; assert_eq!( Decimal::new(129.into(), 1).serialize_to_vec(Version::V4), expected ); let expected: Vec<u8> = vec![0, 0, 0, 1, 0xFF, 0x7F]; assert_eq!( Decimal::new(BigInt::from(-129), 1).serialize_to_vec(Version::V4), expected ); } #[test] fn from_f32() { assert_eq!( Decimal::from(12300001_f32), Decimal::new(12300001.into(), 0) ); assert_eq!( Decimal::from(1230000.1_f32), Decimal::new(12300001.into(), 1) ); assert_eq!( Decimal::from(0.12300001_f32), Decimal::new(12300001.into(), 8) ); } #[test] fn from_f64() { assert_eq!( Decimal::from(1230000000000001_f64), Decimal::new(1230000000000001i64.into(), 0) ); assert_eq!( Decimal::from(123000000000000.1f64), Decimal::new(1230000000000001i64.into(), 1) ); assert_eq!( Decimal::from(0.1230000000000001f64), Decimal::new(1230000000000001i64.into(), 16) ); } // Values such as 0.1 are not exactly representable in IEEE-754 floating // point, so the conversion cannot count on the scaled value ever becoming // exactly integral. Without the MAX_SCALE cap, `scale` would keep being // incremented until `10i64.pow(scale)` overflowed at scale 19 and panicked. // The conversion must terminate without panicking and produce a sensible // Decimal.
#[test] fn from_f32_tolerates_inexact_floats() { let _decimal = Decimal::from(0.1f32); let _decimal = Decimal::from(0.2f32); let _decimal = Decimal::from(1.0f32 / 3.0f32); } #[test] fn from_f64_tolerates_inexact_floats() { let _decimal = Decimal::from(0.1f64); let _decimal = Decimal::from(0.2f64); let _decimal = Decimal::from(1.0f64 / 3.0f64); } // as_plain divides by 10^scale; a negative scale must not be cast with // `scale as u32`, which wraps to a huge exponent and makes `10i64.pow` // panic. A negative scale multiplies by 10^|scale| instead. #[test] fn as_plain_does_not_panic_on_negative_scale() { let decimal = Decimal::new(5.into(), -3); // 5 * 10^3 assert_eq!(decimal.as_plain(), BigInt::from(5000)); } } ================================================ FILE: cassandra-protocol/src/types/duration.rs ================================================ use integer_encoding::VarInt; use std::io::{Cursor, Write}; use thiserror::Error; use crate::frame::{Serialize, Version}; /// Possible `Duration` creation error. #[derive(Debug, Error, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] pub enum DurationCreationError { #[error( "All values must be either negative or positive, got {months} months, {days} days, {nanoseconds} nanoseconds" )] MixedPositiveAndNegative { months: i32, days: i32, nanoseconds: i64, }, } /// Cassandra Duration type. A duration stores months, days, and nanoseconds separately because /// the number of days in a month varies, and a day can have 23 or 25 hours when daylight /// saving time is involved.
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] pub struct Duration { months: i32, days: i32, nanoseconds: i64, } impl Duration { pub fn new(months: i32, days: i32, nanoseconds: i64) -> Result<Self, DurationCreationError> { if (months < 0 || days < 0 || nanoseconds < 0) && (months > 0 || days > 0 || nanoseconds > 0) { Err(DurationCreationError::MixedPositiveAndNegative { months, days, nanoseconds, }) } else { Ok(Self { months, days, nanoseconds, }) } } pub fn months(&self) -> i32 { self.months } pub fn days(&self) -> i32 { self.days } pub fn nanoseconds(&self) -> i64 { self.nanoseconds } } impl Serialize for Duration { fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, _version: Version) { let month_space = self.months.required_space(); let day_space = self.days.required_space(); let mut buffer = vec![0u8; month_space + day_space + self.nanoseconds.required_space()]; self.months.encode_var(&mut buffer); self.days.encode_var(&mut buffer[month_space..]); self.nanoseconds .encode_var(&mut buffer[(month_space + day_space)..]); let _ = cursor.write(&buffer); } } #[cfg(test)] mod tests { use crate::frame::{Serialize, Version}; use crate::types::duration::Duration; #[test] fn should_serialize_duration() { let duration = Duration::new(100, 200, 300).unwrap(); assert_eq!( duration.serialize_to_vec(Version::V5), vec![200, 1, 144, 3, 216, 4] ); } } ================================================ FILE: cassandra-protocol/src/types/from_cdrs.rs ================================================ use std::net::IpAddr; use std::num::{NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8}; use chrono::prelude::*; use time::PrimitiveDateTime; use uuid::Uuid; use crate::error::Result as CdrsResult; use crate::types::blob::Blob; use crate::types::decimal::Decimal; use crate::types::list::List; use crate::types::map::Map; use crate::types::tuple::Tuple; use crate::types::udt::Udt; use crate::types::{AsRustType, ByName, IntoRustByName}; pub trait FromCdrs { fn from_cdrs<T>(cdrs_type: T) -> CdrsResult<Option<Self>> where
Self: Sized, T: AsRustType<Self>, { cdrs_type.as_rust_type() } fn from_cdrs_r<T>(cdrs_type: T) -> CdrsResult<Self> where Self: Sized, T: AsRustType<Self>, { cdrs_type.as_r_type() } } impl FromCdrs for Blob {} impl FromCdrs for String {} impl FromCdrs for bool {} impl FromCdrs for i64 {} impl FromCdrs for i32 {} impl FromCdrs for i16 {} impl FromCdrs for i8 {} impl FromCdrs for f64 {} impl FromCdrs for f32 {} impl FromCdrs for IpAddr {} impl FromCdrs for Uuid {} impl FromCdrs for List {} impl FromCdrs for Map {} impl FromCdrs for Udt {} impl FromCdrs for Tuple {} impl FromCdrs for PrimitiveDateTime {} impl FromCdrs for Decimal {} impl FromCdrs for NonZeroI8 {} impl FromCdrs for NonZeroI16 {} impl FromCdrs for NonZeroI32 {} impl FromCdrs for NonZeroI64 {} impl FromCdrs for NaiveDateTime {} impl FromCdrs for DateTime<Utc> {} pub trait FromCdrsByName { fn from_cdrs_by_name<T>(cdrs_type: &T, name: &str) -> CdrsResult<Option<Self>> where Self: Sized, T: ByName + IntoRustByName<Self>, { cdrs_type.by_name(name) } fn from_cdrs_r<T>(cdrs_type: &T, name: &str) -> CdrsResult<Self> where Self: Sized, T: ByName + IntoRustByName<Self> + ::std::fmt::Debug, { cdrs_type.r_by_name(name) } } impl FromCdrsByName for Blob {} impl FromCdrsByName for String {} impl FromCdrsByName for bool {} impl FromCdrsByName for i64 {} impl FromCdrsByName for i32 {} impl FromCdrsByName for i16 {} impl FromCdrsByName for i8 {} impl FromCdrsByName for f64 {} impl FromCdrsByName for f32 {} impl FromCdrsByName for IpAddr {} impl FromCdrsByName for Uuid {} impl FromCdrsByName for List {} impl FromCdrsByName for Map {} impl FromCdrsByName for Udt {} impl FromCdrsByName for Tuple {} impl FromCdrsByName for PrimitiveDateTime {} impl FromCdrsByName for Decimal {} impl FromCdrsByName for NonZeroI8 {} impl FromCdrsByName for NonZeroI16 {} impl FromCdrsByName for NonZeroI32 {} impl FromCdrsByName for NonZeroI64 {} impl FromCdrsByName for NaiveDateTime {} impl FromCdrsByName for DateTime<Utc> {} ================================================ FILE:
cassandra-protocol/src/types/list.rs ================================================ use derive_more::Constructor; use itertools::Itertools; use num_bigint::BigInt; use std::net::IpAddr; use uuid::Uuid; use crate::error::{Error, Result}; use crate::frame::message_result::{ColType, ColTypeOption, ColTypeOptionValue}; use crate::frame::Version; use crate::types::blob::Blob; use crate::types::data_serialization_types::*; use crate::types::decimal::Decimal; use crate::types::map::Map; use crate::types::tuple::Tuple; use crate::types::udt::Udt; use crate::types::{AsRust, AsRustType, CBytes}; // TODO: consider using pointers to ColTypeOption and Vec<CBytes> instead of owning them. #[derive(Debug, Constructor)] pub struct List { /// Column spec of the list: `id` should be `List` (since this is a list) and `value` /// should contain the type of the list items. metadata: ColTypeOption, data: Vec<CBytes>, protocol_version: Version, } impl List { fn map<T, F>(&self, f: F) -> Vec<T> where F: FnMut(&CBytes) -> T, { self.data.iter().map(f).collect() } fn try_map<T, F>(&self, f: F) -> Result<Vec<T>> where F: FnMut(&CBytes) -> Result<T>, { self.data.iter().map(f).try_collect() } } impl AsRust for List {} list_as_rust!(Blob); list_as_rust!(String); list_as_rust!(bool); list_as_rust!(i64); list_as_rust!(i32); list_as_rust!(i16); list_as_rust!(i8); list_as_rust!(f64); list_as_rust!(f32); list_as_rust!(IpAddr); list_as_rust!(Uuid); list_as_rust!(List); list_as_rust!(Map); list_as_rust!(Udt); list_as_rust!(Tuple); list_as_rust!(Decimal); list_as_rust!(BigInt); list_as_cassandra_type!(); ================================================ FILE: cassandra-protocol/src/types/map.rs ================================================ use std::collections::HashMap; use std::net::IpAddr; use time::PrimitiveDateTime; use uuid::Uuid; use crate::error::{Error, Result}; use crate::frame::message_result::{ColType, ColTypeOption, ColTypeOptionValue}; use crate::frame::Version; use crate::types::blob::Blob; use crate::types::data_serialization_types::*; use
crate::types::decimal::Decimal; use crate::types::list::List; use crate::types::tuple::Tuple; use crate::types::udt::Udt; use crate::types::{AsRust, AsRustType, CBytes}; use num_bigint::BigInt; #[derive(Debug)] pub struct Map { metadata: ColTypeOption, data: Vec<(CBytes, CBytes)>, protocol_version: Version, } impl Map { /// Creates new `Map` using the provided data and key and value types. pub fn new(data: Vec<(CBytes, CBytes)>, meta: ColTypeOption, protocol_version: Version) -> Map { Map { metadata: meta, data, protocol_version, } } } impl AsRust for Map {} // Generate `AsRustType` implementations for all kinds of map types. // The macro `map_as_rust!` takes the key and value types as lists of token trees. // This is needed because `as_rust_type!` is called by `map_as_rust!`. // In order to distinguish the key and value types, they are enclosed by curly braces. map_as_rust!({ Blob }, { Blob }); map_as_rust!({ Blob }, { String }); map_as_rust!({ Blob }, { bool }); map_as_rust!({ Blob }, { i64 }); map_as_rust!({ Blob }, { i32 }); map_as_rust!({ Blob }, { i16 }); map_as_rust!({ Blob }, { i8 }); map_as_rust!({ Blob }, { f64 }); map_as_rust!({ Blob }, { f32 }); map_as_rust!({ Blob }, { IpAddr }); map_as_rust!({ Blob }, { Uuid }); map_as_rust!({ Blob }, { PrimitiveDateTime }); map_as_rust!({ Blob }, { List }); map_as_rust!({ Blob }, { Map }); map_as_rust!({ Blob }, { Udt }); map_as_rust!({ Blob }, { Tuple }); map_as_rust!({ Blob }, { Decimal }); map_as_rust!({ Blob }, { BigInt }); map_as_rust!({ String }, { Blob }); map_as_rust!({ String }, { String }); map_as_rust!({ String }, { bool }); map_as_rust!({ String }, { i64 }); map_as_rust!({ String }, { i32 }); map_as_rust!({ String }, { i16 }); map_as_rust!({ String }, { i8 }); map_as_rust!({ String }, { f64 }); map_as_rust!({ String }, { f32 }); map_as_rust!({ String }, { IpAddr }); map_as_rust!({ String }, { Uuid }); map_as_rust!({ String }, { PrimitiveDateTime }); map_as_rust!({ String }, { List }); map_as_rust!({ 
String }, { Map }); map_as_rust!({ String }, { Udt }); map_as_rust!({ String }, { Tuple }); map_as_rust!({ String }, { Decimal }); map_as_rust!({ String }, { BigInt }); map_as_rust!({ bool }, { Blob }); map_as_rust!({ bool }, { String }); map_as_rust!({ bool }, { bool }); map_as_rust!({ bool }, { i64 }); map_as_rust!({ bool }, { i32 }); map_as_rust!({ bool }, { i16 }); map_as_rust!({ bool }, { i8 }); map_as_rust!({ bool }, { f64 }); map_as_rust!({ bool }, { f32 }); map_as_rust!({ bool }, { IpAddr }); map_as_rust!({ bool }, { Uuid }); map_as_rust!({ bool }, { PrimitiveDateTime }); map_as_rust!({ bool }, { List }); map_as_rust!({ bool }, { Map }); map_as_rust!({ bool }, { Udt }); map_as_rust!({ bool }, { Tuple }); map_as_rust!({ bool }, { Decimal }); map_as_rust!({ bool }, { BigInt }); map_as_rust!({ i64 }, { Blob }); map_as_rust!({ i64 }, { String }); map_as_rust!({ i64 }, { bool }); map_as_rust!({ i64 }, { i64 }); map_as_rust!({ i64 }, { i32 }); map_as_rust!({ i64 }, { i16 }); map_as_rust!({ i64 }, { i8 }); map_as_rust!({ i64 }, { f64 }); map_as_rust!({ i64 }, { f32 }); map_as_rust!({ i64 }, { IpAddr }); map_as_rust!({ i64 }, { Uuid }); map_as_rust!({ i64 }, { PrimitiveDateTime }); map_as_rust!({ i64 }, { List }); map_as_rust!({ i64 }, { Map }); map_as_rust!({ i64 }, { Udt }); map_as_rust!({ i64 }, { Tuple }); map_as_rust!({ i64 }, { Decimal }); map_as_rust!({ i64 }, { BigInt }); map_as_rust!({ i32 }, { Blob }); map_as_rust!({ i32 }, { String }); map_as_rust!({ i32 }, { bool }); map_as_rust!({ i32 }, { i64 }); map_as_rust!({ i32 }, { i32 }); map_as_rust!({ i32 }, { i16 }); map_as_rust!({ i32 }, { i8 }); map_as_rust!({ i32 }, { f64 }); map_as_rust!({ i32 }, { f32 }); map_as_rust!({ i32 }, { IpAddr }); map_as_rust!({ i32 }, { Uuid }); map_as_rust!({ i32 }, { PrimitiveDateTime }); map_as_rust!({ i32 }, { List }); map_as_rust!({ i32 }, { Map }); map_as_rust!({ i32 }, { Udt }); map_as_rust!({ i32 }, { Tuple }); map_as_rust!({ i32 }, { Decimal }); map_as_rust!({ i32 }, { 
BigInt }); map_as_rust!({ i16 }, { Blob }); map_as_rust!({ i16 }, { String }); map_as_rust!({ i16 }, { bool }); map_as_rust!({ i16 }, { i64 }); map_as_rust!({ i16 }, { i32 }); map_as_rust!({ i16 }, { i16 }); map_as_rust!({ i16 }, { i8 }); map_as_rust!({ i16 }, { f64 }); map_as_rust!({ i16 }, { f32 }); map_as_rust!({ i16 }, { IpAddr }); map_as_rust!({ i16 }, { Uuid }); map_as_rust!({ i16 }, { PrimitiveDateTime }); map_as_rust!({ i16 }, { List }); map_as_rust!({ i16 }, { Map }); map_as_rust!({ i16 }, { Udt }); map_as_rust!({ i16 }, { Tuple }); map_as_rust!({ i16 }, { Decimal }); map_as_rust!({ i16 }, { BigInt }); map_as_rust!({ i8 }, { Blob }); map_as_rust!({ i8 }, { String }); map_as_rust!({ i8 }, { bool }); map_as_rust!({ i8 }, { i64 }); map_as_rust!({ i8 }, { i32 }); map_as_rust!({ i8 }, { i16 }); map_as_rust!({ i8 }, { i8 }); map_as_rust!({ i8 }, { f64 }); map_as_rust!({ i8 }, { f32 }); map_as_rust!({ i8 }, { IpAddr }); map_as_rust!({ i8 }, { Uuid }); map_as_rust!({ i8 }, { PrimitiveDateTime }); map_as_rust!({ i8 }, { List }); map_as_rust!({ i8 }, { Map }); map_as_rust!({ i8 }, { Udt }); map_as_rust!({ i8 }, { Tuple }); map_as_rust!({ i8 }, { Decimal }); map_as_rust!({ i8 }, { BigInt }); map_as_rust!({ IpAddr }, { Blob }); map_as_rust!({ IpAddr }, { String }); map_as_rust!({ IpAddr }, { bool }); map_as_rust!({ IpAddr }, { i64 }); map_as_rust!({ IpAddr }, { i32 }); map_as_rust!({ IpAddr }, { i16 }); map_as_rust!({ IpAddr }, { i8 }); map_as_rust!({ IpAddr }, { f64 }); map_as_rust!({ IpAddr }, { f32 }); map_as_rust!({ IpAddr }, { IpAddr }); map_as_rust!({ IpAddr }, { Uuid }); map_as_rust!({ IpAddr }, { PrimitiveDateTime }); map_as_rust!({ IpAddr }, { List }); map_as_rust!({ IpAddr }, { Map }); map_as_rust!({ IpAddr }, { Udt }); map_as_rust!({ IpAddr }, { Tuple }); map_as_rust!({ IpAddr }, { Decimal }); map_as_rust!({ IpAddr }, { BigInt }); map_as_rust!({ Uuid }, { Blob }); map_as_rust!({ Uuid }, { String }); map_as_rust!({ Uuid }, { bool }); map_as_rust!({ Uuid }, { 
i64 }); map_as_rust!({ Uuid }, { i32 }); map_as_rust!({ Uuid }, { i16 }); map_as_rust!({ Uuid }, { i8 }); map_as_rust!({ Uuid }, { f64 }); map_as_rust!({ Uuid }, { f32 }); map_as_rust!({ Uuid }, { IpAddr }); map_as_rust!({ Uuid }, { Uuid }); map_as_rust!({ Uuid }, { PrimitiveDateTime }); map_as_rust!({ Uuid }, { List }); map_as_rust!({ Uuid }, { Map }); map_as_rust!({ Uuid }, { Udt }); map_as_rust!({ Uuid }, { Tuple }); map_as_rust!({ Uuid }, { Decimal }); map_as_rust!({ Uuid }, { BigInt }); map_as_rust!({ PrimitiveDateTime }, { Blob }); map_as_rust!({ PrimitiveDateTime }, { String }); map_as_rust!({ PrimitiveDateTime }, { bool }); map_as_rust!({ PrimitiveDateTime }, { i64 }); map_as_rust!({ PrimitiveDateTime }, { i32 }); map_as_rust!({ PrimitiveDateTime }, { i16 }); map_as_rust!({ PrimitiveDateTime }, { i8 }); map_as_rust!({ PrimitiveDateTime }, { f64 }); map_as_rust!({ PrimitiveDateTime }, { f32 }); map_as_rust!({ PrimitiveDateTime }, { IpAddr }); map_as_rust!({ PrimitiveDateTime }, { Uuid }); map_as_rust!({ PrimitiveDateTime }, { PrimitiveDateTime }); map_as_rust!({ PrimitiveDateTime }, { List }); map_as_rust!({ PrimitiveDateTime }, { Map }); map_as_rust!({ PrimitiveDateTime }, { Udt }); map_as_rust!({ PrimitiveDateTime }, { Tuple }); map_as_rust!({ PrimitiveDateTime }, { Decimal }); map_as_rust!({ PrimitiveDateTime }, { BigInt }); map_as_rust!({ Tuple }, { Blob }); map_as_rust!({ Tuple }, { String }); map_as_rust!({ Tuple }, { bool }); map_as_rust!({ Tuple }, { i64 }); map_as_rust!({ Tuple }, { i32 }); map_as_rust!({ Tuple }, { i16 }); map_as_rust!({ Tuple }, { i8 }); map_as_rust!({ Tuple }, { f64 }); map_as_rust!({ Tuple }, { f32 }); map_as_rust!({ Tuple }, { IpAddr }); map_as_rust!({ Tuple }, { Uuid }); map_as_rust!({ Tuple }, { PrimitiveDateTime }); map_as_rust!({ Tuple }, { List }); map_as_rust!({ Tuple }, { Map }); map_as_rust!({ Tuple }, { Udt }); map_as_rust!({ Tuple }, { Tuple }); map_as_rust!({ Tuple }, { Decimal }); map_as_rust!({ Tuple }, { BigInt 
}); map_as_cassandra_type!(); ================================================ FILE: cassandra-protocol/src/types/rows.rs ================================================ use std::net::IpAddr; use std::num::{NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8}; use std::sync::Arc; use chrono::prelude::*; use time::PrimitiveDateTime; use uuid::Uuid; use crate::error::{column_is_empty_err, Error, Result}; use crate::frame::message_result::{ BodyResResultRows, ColSpec, ColType, ColTypeOption, ColTypeOptionValue, RowsMetadata, }; use crate::frame::Version; use crate::types::blob::Blob; use crate::types::data_serialization_types::*; use crate::types::decimal::Decimal; use crate::types::list::List; use crate::types::map::Map; use crate::types::tuple::Tuple; use crate::types::udt::Udt; use crate::types::{ByIndex, ByName, CBytes, IntoRustByIndex, IntoRustByName}; use num_bigint::BigInt; #[derive(Clone, Debug)] pub struct Row { metadata: Arc, row_content: Vec, protocol_version: Version, } impl Row { pub fn from_body(body: BodyResResultRows) -> Vec { let metadata = Arc::new(body.metadata); let protocol_version = body.protocol_version; body.rows_content .into_iter() .map(|row| Row { metadata: metadata.clone(), row_content: row, protocol_version, }) .collect() } /// Checks if a column is present in the row. pub fn contains_column(&self, name: &str) -> bool { self.metadata .col_specs .iter() .any(|spec| spec.name.as_str() == name) } /// Checks whether the column at `index` is NULL or empty. Returns false if the column does not exist. pub fn is_empty(&self, index: usize) -> bool { self.row_content .get(index) .map(|data| data.is_null_or_empty()) .unwrap_or(false) } /// Checks whether the column named `name` is NULL or empty. Returns false if the column does not exist.
pub fn is_empty_by_name(&self, name: &str) -> bool { self.metadata .col_specs .iter() .position(|spec| spec.name.as_str() == name) .map(|index| self.is_empty(index)) .unwrap_or(false) } fn col_spec_by_name(&self, name: &str) -> Option<(&ColSpec, &CBytes)> { self.metadata .col_specs .iter() .position(|spec| spec.name.as_str() == name) .map(|i| { let col_spec = &self.metadata.col_specs[i]; let data = &self.row_content[i]; (col_spec, data) }) } fn col_spec_by_index(&self, index: usize) -> Option<(&ColSpec, &CBytes)> { let specs = self.metadata.col_specs.iter(); let values = self.row_content.iter(); specs.zip(values).nth(index) } } impl ByName for Row {} into_rust_by_name!(Row, Blob); into_rust_by_name!(Row, String); into_rust_by_name!(Row, bool); into_rust_by_name!(Row, i64); into_rust_by_name!(Row, i32); into_rust_by_name!(Row, i16); into_rust_by_name!(Row, i8); into_rust_by_name!(Row, f64); into_rust_by_name!(Row, f32); into_rust_by_name!(Row, IpAddr); into_rust_by_name!(Row, Uuid); into_rust_by_name!(Row, List); into_rust_by_name!(Row, Map); into_rust_by_name!(Row, Udt); into_rust_by_name!(Row, Tuple); into_rust_by_name!(Row, PrimitiveDateTime); into_rust_by_name!(Row, Decimal); into_rust_by_name!(Row, NonZeroI8); into_rust_by_name!(Row, NonZeroI16); into_rust_by_name!(Row, NonZeroI32); into_rust_by_name!(Row, NonZeroI64); into_rust_by_name!(Row, NaiveDateTime); into_rust_by_name!(Row, DateTime); into_rust_by_name!(Row, BigInt); impl ByIndex for Row {} into_rust_by_index!(Row, Blob); into_rust_by_index!(Row, String); into_rust_by_index!(Row, bool); into_rust_by_index!(Row, i64); into_rust_by_index!(Row, i32); into_rust_by_index!(Row, i16); into_rust_by_index!(Row, i8); into_rust_by_index!(Row, f64); into_rust_by_index!(Row, f32); into_rust_by_index!(Row, IpAddr); into_rust_by_index!(Row, Uuid); into_rust_by_index!(Row, List); into_rust_by_index!(Row, Map); into_rust_by_index!(Row, Udt); into_rust_by_index!(Row, Tuple); into_rust_by_index!(Row, PrimitiveDateTime); 
into_rust_by_index!(Row, Decimal); into_rust_by_index!(Row, NonZeroI8); into_rust_by_index!(Row, NonZeroI16); into_rust_by_index!(Row, NonZeroI32); into_rust_by_index!(Row, NonZeroI64); into_rust_by_index!(Row, NaiveDateTime); into_rust_by_index!(Row, DateTime); into_rust_by_index!(Row, BigInt); ================================================ FILE: cassandra-protocol/src/types/tuple.rs ================================================ use chrono::prelude::*; use num_bigint::BigInt; use std::hash::{Hash, Hasher}; use std::net::IpAddr; use time::PrimitiveDateTime; use uuid::Uuid; use crate::error::{column_is_empty_err, Error, Result}; use crate::frame::message_result::{CTuple, ColType, ColTypeOption, ColTypeOptionValue}; use crate::frame::Version; use crate::types::blob::Blob; use crate::types::data_serialization_types::*; use crate::types::decimal::Decimal; use crate::types::list::List; use crate::types::map::Map; use crate::types::udt::Udt; use crate::types::{ByIndex, CBytes, IntoRustByIndex}; #[derive(Debug)] pub struct Tuple { data: Vec<(ColTypeOption, CBytes)>, protocol_version: Version, } impl PartialEq for Tuple { fn eq(&self, other: &Tuple) -> bool { if self.data.len() != other.data.len() { return false; } for (s, o) in self.data.iter().zip(other.data.iter()) { if s.1 != o.1 { return false; } } true } } impl Eq for Tuple {} impl Hash for Tuple { fn hash(&self, state: &mut H) { for data in &self.data { data.1.hash(state); } } } impl Tuple { pub fn new(elements: Vec, metadata: &CTuple, protocol_version: Version) -> Tuple { Tuple { data: metadata .types .iter() .zip(elements) .map(|(val_type, val_b)| (val_type.clone(), val_b)) .collect(), protocol_version, } } } impl ByIndex for Tuple {} into_rust_by_index!(Tuple, Blob); into_rust_by_index!(Tuple, String); into_rust_by_index!(Tuple, bool); into_rust_by_index!(Tuple, i64); into_rust_by_index!(Tuple, i32); into_rust_by_index!(Tuple, i16); into_rust_by_index!(Tuple, i8); into_rust_by_index!(Tuple, f64); 
into_rust_by_index!(Tuple, f32); into_rust_by_index!(Tuple, IpAddr); into_rust_by_index!(Tuple, Uuid); into_rust_by_index!(Tuple, List); into_rust_by_index!(Tuple, Map); into_rust_by_index!(Tuple, Udt); into_rust_by_index!(Tuple, Tuple); into_rust_by_index!(Tuple, PrimitiveDateTime); into_rust_by_index!(Tuple, Decimal); into_rust_by_index!(Tuple, NaiveDateTime); into_rust_by_index!(Tuple, DateTime); into_rust_by_index!(Tuple, BigInt); tuple_as_cassandra_type!(); ================================================ FILE: cassandra-protocol/src/types/udt.rs ================================================ use std::collections::HashMap; use std::net::IpAddr; use std::num::{NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8}; use chrono::prelude::*; use time::PrimitiveDateTime; use uuid::Uuid; use crate::error::{column_is_empty_err, Error, Result}; use crate::frame::message_result::{CUdt, ColType, ColTypeOption, ColTypeOptionValue}; use crate::frame::Version; use crate::types::blob::Blob; use crate::types::data_serialization_types::*; use crate::types::decimal::Decimal; use crate::types::list::List; use crate::types::map::Map; use crate::types::tuple::Tuple; use crate::types::{ByName, CBytes, IntoRustByName}; use num_bigint::BigInt; #[derive(Clone, Debug)] pub struct Udt { data: HashMap, protocol_version: Version, } impl Udt { pub fn new(fields: Vec, metadata: &CUdt, protocol_version: Version) -> Udt { let mut data: HashMap = HashMap::with_capacity(metadata.descriptions.len()); for ((name, val_type), val_b) in metadata.descriptions.iter().zip(fields) { data.insert(name.clone(), (val_type.clone(), val_b)); } Udt { data, protocol_version, } } } impl ByName for Udt {} into_rust_by_name!(Udt, Blob); into_rust_by_name!(Udt, String); into_rust_by_name!(Udt, bool); into_rust_by_name!(Udt, i64); into_rust_by_name!(Udt, i32); into_rust_by_name!(Udt, i16); into_rust_by_name!(Udt, i8); into_rust_by_name!(Udt, f64); into_rust_by_name!(Udt, f32); into_rust_by_name!(Udt, IpAddr); 
into_rust_by_name!(Udt, Uuid); into_rust_by_name!(Udt, List); into_rust_by_name!(Udt, Map); into_rust_by_name!(Udt, Udt); into_rust_by_name!(Udt, Tuple); into_rust_by_name!(Udt, PrimitiveDateTime); into_rust_by_name!(Udt, Decimal); into_rust_by_name!(Udt, NonZeroI8); into_rust_by_name!(Udt, NonZeroI16); into_rust_by_name!(Udt, NonZeroI32); into_rust_by_name!(Udt, NonZeroI64); into_rust_by_name!(Udt, NaiveDateTime); into_rust_by_name!(Udt, DateTime); into_rust_by_name!(Udt, BigInt); udt_as_cassandra_type!(); ================================================ FILE: cassandra-protocol/src/types/value.rs ================================================ use std::cmp::Eq; use std::collections::{BTreeMap, HashMap}; use std::convert::Into; use std::fmt::Debug; use std::hash::Hash; use std::net::IpAddr; use std::num::{NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8}; use chrono::prelude::*; use num_bigint::BigInt; use time::PrimitiveDateTime; use uuid::Uuid; use super::blob::Blob; use super::decimal::Decimal; use super::duration::Duration; use super::*; use crate::Error; const NULL_INT_VALUE: i32 = -1; const NOT_SET_INT_VALUE: i32 = -2; /// Cassandra value, which can be an array of bytes, a null value, or a not-set value. #[derive(Debug, Clone, PartialEq, Ord, PartialOrd, Eq, Hash)] pub enum Value { Some(Vec), Null, NotSet, } impl Value { pub fn new(v: B) -> Value where B: Into, { Value::Some(v.into().0) } } impl Serialize for Value { fn serialize(&self, cursor: &mut Cursor<&mut Vec>, version: Version) { match self { Value::Null => NULL_INT_VALUE.serialize(cursor, version), Value::NotSet => NOT_SET_INT_VALUE.serialize(cursor, version), Value::Some(value) => { let len = value.len() as CInt; len.serialize(cursor, version); value.serialize(cursor, version); } } } } impl FromCursor for Value { fn from_cursor(cursor: &mut Cursor<&[u8]>, _version: Version) -> Result { // Per the protocol spec: a [value] is encoded as a [int] n followed by // n bytes when n is non-negative.
Special negative encodings stand for // null (-1) and "not set" (-2). A length of zero is therefore valid and // represents an empty value (e.g. an empty BLOB or empty TEXT). let value_size = { let mut buff = [0; INT_LEN]; cursor.read_exact(&mut buff)?; CInt::from_be_bytes(buff) }; if value_size >= 0 { // covers both positive lengths and the empty-value (length 0) case Ok(Value::Some(cursor_next_value(cursor, value_size as usize)?)) } else if value_size == NULL_INT_VALUE { Ok(Value::Null) } else if value_size == NOT_SET_INT_VALUE { Ok(Value::NotSet) } else { // any other negative value is not part of the protocol Err(Error::General("Could not decode query values".into())) } } } // We are assuming here primitive value serialization will not change across protocol versions, // which gives us simpler user API. impl> From for Value { fn from(b: T) -> Value { Value::new(b.into()) } } impl> From> for Value { fn from(b: Option) -> Value { match b { Some(b) => Value::new(b.into()), None => Value::Null, } } } #[derive(Debug, Clone, Constructor)] pub struct Bytes(Vec); impl Bytes { /// Consumes `Bytes` and returns the inner `Vec` pub fn into_inner(self) -> Vec { self.0 } } impl From for Bytes { #[inline] fn from(value: String) -> Self { Bytes(value.into_bytes()) } } impl From<&str> for Bytes { #[inline] fn from(value: &str) -> Self { Bytes(value.as_bytes().to_vec()) } } impl From for Bytes { #[inline] fn from(value: i8) -> Self { Bytes(vec![value as u8]) } } impl From for Bytes { #[inline] fn from(value: i16) -> Self { Bytes(to_short(value)) } } impl From for Bytes { #[inline] fn from(value: i32) -> Self { Bytes(to_int(value)) } } impl From for Bytes { #[inline] fn from(value: i64) -> Self { Bytes(to_bigint(value)) } } impl From for Bytes { #[inline] fn from(value: u8) -> Self { Bytes(vec![value]) } } impl From for Bytes { #[inline] fn from(value: u16) -> Self { Bytes(to_u_short(value)) } } impl From for Bytes { #[inline] fn from(value: u32) -> Self { Bytes(to_u_int(value)) 
} } impl From for Bytes { #[inline] fn from(value: u64) -> Self { Bytes(to_u_big(value)) } } impl From for Bytes { #[inline] fn from(value: NonZeroI8) -> Self { value.get().into() } } impl From for Bytes { #[inline] fn from(value: NonZeroI16) -> Self { value.get().into() } } impl From for Bytes { #[inline] fn from(value: NonZeroI32) -> Self { value.get().into() } } impl From for Bytes { #[inline] fn from(value: NonZeroI64) -> Self { value.get().into() } } impl From for Bytes { #[inline] fn from(value: bool) -> Self { if value { Bytes(vec![1]) } else { Bytes(vec![0]) } } } impl From for Bytes { #[inline] fn from(value: Uuid) -> Self { Bytes(value.as_bytes().to_vec()) } } impl From for Bytes { #[inline] fn from(value: IpAddr) -> Self { match value { IpAddr::V4(ip) => Bytes(ip.octets().to_vec()), IpAddr::V6(ip) => Bytes(ip.octets().to_vec()), } } } impl From for Bytes { #[inline] fn from(value: f32) -> Self { Bytes(to_float(value)) } } impl From for Bytes { #[inline] fn from(value: f64) -> Self { Bytes(to_float_big(value)) } } impl From for Bytes { #[inline] fn from(value: PrimitiveDateTime) -> Self { let ts: i64 = value.assume_utc().unix_timestamp() * 1_000 + value.nanosecond() as i64 / 1_000_000; Bytes(to_bigint(ts)) } } impl From for Bytes { #[inline] fn from(value: Blob) -> Self { Bytes(value.into_vec()) } } impl From for Bytes { #[inline] fn from(value: Decimal) -> Self { Bytes(value.serialize_to_vec(Version::V4)) } } impl From for Bytes { #[inline] fn from(value: NaiveDateTime) -> Self { value.and_utc().timestamp_millis().into() } } impl From> for Bytes { #[inline] fn from(value: DateTime) -> Self { value.timestamp_millis().into() } } impl From for Bytes { #[inline] fn from(value: Duration) -> Self { Bytes(value.serialize_to_vec(Version::V5)) } } impl> From> for Bytes { fn from(vec: Vec) -> Bytes { let mut bytes = Vec::with_capacity(INT_LEN); let len = vec.len() as CInt; bytes.extend_from_slice(&len.to_be_bytes()); let mut cursor = Cursor::new(&mut bytes); 
cursor.set_position(INT_LEN as u64); for v in vec { let b: Bytes = v.into(); Value::new(b).serialize(&mut cursor, Version::V4); } Bytes(bytes) } } impl From for Bytes { fn from(value: BigInt) -> Self { Self(value.serialize_to_vec(Version::V4)) } } impl From> for Bytes where K: Into + Hash + Eq, V: Into, { fn from(map: HashMap) -> Bytes { let mut bytes = Vec::with_capacity(INT_LEN); let len = map.len() as CInt; bytes.extend_from_slice(&len.to_be_bytes()); let mut cursor = Cursor::new(&mut bytes); cursor.set_position(INT_LEN as u64); for (k, v) in map { let key_bytes: Bytes = k.into(); let val_bytes: Bytes = v.into(); Value::new(key_bytes).serialize(&mut cursor, Version::V4); Value::new(val_bytes).serialize(&mut cursor, Version::V4); } Bytes(bytes) } } impl From> for Bytes where K: Into + Hash + Eq, V: Into, { fn from(map: BTreeMap) -> Bytes { let mut bytes = Vec::with_capacity(INT_LEN); let len = map.len() as CInt; bytes.extend_from_slice(&len.to_be_bytes()); let mut cursor = Cursor::new(&mut bytes); cursor.set_position(INT_LEN as u64); for (k, v) in map { let key_bytes: Bytes = k.into(); let val_bytes: Bytes = v.into(); Value::new(key_bytes).serialize(&mut cursor, Version::V4); Value::new(val_bytes).serialize(&mut cursor, Version::V4); } Bytes(bytes) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_value_serialization() { assert_eq!( Value::Some(vec![1]).serialize_to_vec(Version::V4), vec![0, 0, 0, 1, 1] ); assert_eq!( Value::Some(vec![1, 2, 3]).serialize_to_vec(Version::V4), vec![0, 0, 0, 3, 1, 2, 3] ); assert_eq!( Value::Null.serialize_to_vec(Version::V4), vec![255, 255, 255, 255] ); assert_eq!( Value::NotSet.serialize_to_vec(Version::V4), vec![255, 255, 255, 254] ) } #[test] fn test_value_from_cursor_handles_all_lengths() { // length 0 — a zero-length value is a valid Cassandra value (an empty // string, an empty blob, etc.). It must round-trip as `Value::Some(vec![])`, // not be rejected as malformed. 
let bytes = vec![0, 0, 0, 0]; let mut cursor = Cursor::new(bytes.as_slice()); assert_eq!( Value::from_cursor(&mut cursor, Version::V4).unwrap(), Value::Some(vec![]) ); // positive length — value bytes follow the 4-byte length let bytes = vec![0, 0, 0, 3, 1, 2, 3]; let mut cursor = Cursor::new(bytes.as_slice()); assert_eq!( Value::from_cursor(&mut cursor, Version::V4).unwrap(), Value::Some(vec![1, 2, 3]) ); // -1 (0xFFFFFFFF) means null let bytes = vec![255, 255, 255, 255]; let mut cursor = Cursor::new(bytes.as_slice()); assert_eq!( Value::from_cursor(&mut cursor, Version::V4).unwrap(), Value::Null ); // -2 (0xFFFFFFFE) means "not set" (unbound bind variable) let bytes = vec![255, 255, 255, 254]; let mut cursor = Cursor::new(bytes.as_slice()); assert_eq!( Value::from_cursor(&mut cursor, Version::V4).unwrap(), Value::NotSet ); // anything else (e.g. -3) is malformed and must error let bytes = vec![255, 255, 255, 253]; let mut cursor = Cursor::new(bytes.as_slice()); assert!(Value::from_cursor(&mut cursor, Version::V4).is_err()); } #[test] fn test_new_value_all_types() { assert_eq!( Value::new("hello"), Value::Some(vec!(104, 101, 108, 108, 111)) ); assert_eq!( Value::new("hello".to_string()), Value::Some(vec!(104, 101, 108, 108, 111)) ); assert_eq!(Value::new(1_u8), Value::Some(vec!(1))); assert_eq!(Value::new(1_u16), Value::Some(vec!(0, 1))); assert_eq!(Value::new(1_u32), Value::Some(vec!(0, 0, 0, 1))); assert_eq!(Value::new(1_u64), Value::Some(vec!(0, 0, 0, 0, 0, 0, 0, 1))); assert_eq!(Value::new(1_i8), Value::Some(vec!(1))); assert_eq!(Value::new(1_i16), Value::Some(vec!(0, 1))); assert_eq!(Value::new(1_i32), Value::Some(vec!(0, 0, 0, 1))); assert_eq!(Value::new(1_i64), Value::Some(vec!(0, 0, 0, 0, 0, 0, 0, 1))); assert_eq!(Value::new(true), Value::Some(vec!(1))); assert_eq!( Value::new(Duration::new(100, 200, 300).unwrap()), Value::Some(vec!(200, 1, 144, 3, 216, 4)) ); } } ================================================ FILE: cassandra-protocol/src/types/vector.rs 
================================================ use crate::error::{Error, Result}; use crate::frame::message_result::{ColType, ColTypeOption, ColTypeOptionValue}; use crate::frame::Version; use crate::types::data_serialization_types::*; use crate::types::{AsRust, AsRustType, CBytes}; use derive_more::Constructor; use itertools::Itertools; // TODO: consider using pointers to ColTypeOption and Vec instead of owning them. #[derive(Debug, Constructor)] pub struct Vector { /// column spec of the list, i.e. id should be List as it's a list and value should contain /// a type of list items. metadata: ColTypeOption, data: Vec, protocol_version: Version, } impl Vector { fn try_map(&self, f: F) -> Result> where F: FnMut(&CBytes) -> Result, { self.data.iter().map(f).try_collect() } } pub struct VectorInfo { pub internal_type: String, pub count: usize, } pub fn get_vector_type_info(option_value: &ColTypeOptionValue) -> Result { let input = match option_value { ColTypeOptionValue::CString(ref s) => s, _ => return Err(Error::General("Option value must be a string!".into())), }; let _custom_type = input.split('(').next().unwrap().rsplit('.').next().unwrap(); let vector_type = input .split('(') .nth(1) .and_then(|s| s.split(',').next()) .and_then(|s| s.rsplit('.').next()) .map(|s| s.trim()) .ok_or_else(|| Error::General("Cannot parse vector type!".into()))?; let count: usize = input .split('(') .nth(1) .and_then(|s| s.rsplit(',').next()) .and_then(|s| s.split(')').next()) .map(|s| s.trim().parse()) .transpose() .map_err(|_| Error::General("Cannot parse vector count!".to_string()))? 
.ok_or_else(|| Error::General("Cannot parse vector count!".into()))?; Ok(VectorInfo { internal_type: vector_type.to_string(), count, }) } impl AsRust for Vector {} vector_as_rust!(f32); vector_as_cassandra_type!(); ================================================ FILE: cassandra-protocol/src/types.rs ================================================ use self::cassandra_type::CassandraType; use crate::error::{column_is_empty_err, Error as CdrsError, Result as CDRSResult}; use crate::frame::traits::FromCursor; use crate::frame::{Serialize, Version}; use crate::types::data_serialization_types::*; use derive_more::Constructor; use std::convert::TryInto; use std::io::{self, Write}; use std::io::{Cursor, Read}; use std::net::{IpAddr, SocketAddr}; pub const SHORT_LEN: usize = 2; pub const INT_LEN: usize = 4; pub const LONG_LEN: usize = 8; pub const UUID_LEN: usize = 16; const NULL_INT_LEN: CInt = -1; const NULL_SHORT_LEN: CIntShort = -1; #[macro_use] pub mod blob; pub mod cassandra_type; pub mod data_serialization_types; pub mod decimal; pub mod duration; pub mod from_cdrs; pub mod list; pub mod map; pub mod rows; pub mod tuple; pub mod udt; pub mod value; pub mod vector; pub mod prelude { pub use crate::error::{Error, Result}; pub use crate::frame::{TryFromRow, TryFromUdt}; pub use crate::types::blob::Blob; pub use crate::types::decimal::Decimal; pub use crate::types::duration::Duration; pub use crate::types::list::List; pub use crate::types::map::Map; pub use crate::types::rows::Row; pub use crate::types::tuple::Tuple; pub use crate::types::udt::Udt; pub use crate::types::value::{Bytes, Value}; pub use crate::types::AsRustType; } pub trait AsCassandraType { fn as_cassandra_type(&self) -> CDRSResult>; } /// Should be used to represent a single column as a Rust value. 
pub trait AsRustType { fn as_rust_type(&self) -> CDRSResult>; fn as_r_type(&self) -> CDRSResult { self.as_rust_type() .and_then(|op| op.ok_or_else(|| CdrsError::from("Value is null or non-set"))) } } pub trait AsRust { fn as_rust(&self) -> CDRSResult> where Self: AsRustType, { self.as_rust_type() } fn as_r_rust(&self) -> CDRSResult where Self: AsRustType, { self.as_rust() .and_then(|op| op.ok_or_else(|| "Value is null or non-set".into())) } } /// Should be used to return a single column as a Rust value by its name. pub trait IntoRustByName { fn get_by_name(&self, name: &str) -> CDRSResult>; fn get_r_by_name(&self, name: &str) -> CDRSResult { self.get_by_name(name) .and_then(|op| op.ok_or_else(|| column_is_empty_err(name))) } } pub trait ByName { fn by_name(&self, name: &str) -> CDRSResult> where Self: IntoRustByName, { self.get_by_name(name) } fn r_by_name(&self, name: &str) -> CDRSResult where Self: IntoRustByName, { self.by_name(name) .and_then(|op| op.ok_or_else(|| column_is_empty_err(name))) } } /// Should be used to return a single column as a Rust value by its index.
pub trait IntoRustByIndex { fn get_by_index(&self, index: usize) -> CDRSResult>; fn get_r_by_index(&self, index: usize) -> CDRSResult { self.get_by_index(index) .and_then(|op| op.ok_or_else(|| column_is_empty_err(index))) } } pub trait ByIndex { fn by_index(&self, index: usize) -> CDRSResult> where Self: IntoRustByIndex, { self.get_by_index(index) } fn r_by_index(&self, index: usize) -> CDRSResult where Self: IntoRustByIndex, { self.by_index(index) .and_then(|op| op.ok_or_else(|| column_is_empty_err(index))) } } #[inline] fn convert_to_array(bytes: &[u8]) -> Result<[u8; S], io::Error> { bytes .try_into() .map_err(|error| io::Error::new(io::ErrorKind::UnexpectedEof, error)) } #[inline] pub fn try_u64_from_bytes(bytes: &[u8]) -> Result { convert_to_array(bytes).map(u64::from_be_bytes) } #[inline] pub fn try_i64_from_bytes(bytes: &[u8]) -> Result { convert_to_array(bytes).map(i64::from_be_bytes) } #[inline] pub fn try_i32_from_bytes(bytes: &[u8]) -> Result { convert_to_array(bytes).map(i32::from_be_bytes) } #[inline] pub fn try_i16_from_bytes(bytes: &[u8]) -> Result { convert_to_array(bytes).map(i16::from_be_bytes) } #[inline] pub fn try_f32_from_bytes(bytes: &[u8]) -> Result { convert_to_array(bytes).map(f32::from_be_bytes) } #[inline] pub fn try_f64_from_bytes(bytes: &[u8]) -> Result { convert_to_array(bytes).map(f64::from_be_bytes) } #[inline] pub fn u16_from_bytes(bytes: [u8; 2]) -> u16 { u16::from_be_bytes(bytes) } #[inline] pub fn to_short(int: i16) -> Vec { int.to_be_bytes().into() } #[inline] pub fn to_int(int: i32) -> Vec { int.to_be_bytes().into() } #[inline] pub fn to_bigint(int: i64) -> Vec { int.to_be_bytes().into() } #[inline] pub fn to_u_short(int: u16) -> Vec { int.to_be_bytes().into() } #[inline] pub fn to_u_int(int: u32) -> Vec { int.to_be_bytes().into() } #[inline] pub fn to_u_big(int: u64) -> Vec { int.to_be_bytes().into() } #[inline] pub fn to_float(f: f32) -> Vec { f.to_be_bytes().into() } #[inline] pub fn to_float_big(f: f64) -> Vec { 
f.to_be_bytes().into() } pub fn serialize_str(cursor: &mut Cursor<&mut Vec>, value: &str, version: Version) { let len = value.len() as CIntShort; len.serialize(cursor, version); let _ = cursor.write(value.as_bytes()); } pub(crate) fn serialize_str_long(cursor: &mut Cursor<&mut Vec>, value: &str, version: Version) { let len = value.len() as CInt; len.serialize(cursor, version); let _ = cursor.write(value.as_bytes()); } pub(crate) fn from_cursor_str<'a>(cursor: &mut Cursor<&'a [u8]>) -> CDRSResult<&'a str> { let mut buff = [0; SHORT_LEN]; cursor.read_exact(&mut buff)?; let len = CIntShort::from_be_bytes(buff); let body_bytes = cursor_next_value_ref(cursor, len as usize)?; std::str::from_utf8(body_bytes).map_err(Into::into) } pub(crate) fn from_cursor_str_long<'a>(cursor: &mut Cursor<&'a [u8]>) -> CDRSResult<&'a str> { let mut buff = [0; INT_LEN]; cursor.read_exact(&mut buff)?; let len = CInt::from_be_bytes(buff); let body_bytes = cursor_next_value_ref(cursor, len as usize)?; std::str::from_utf8(body_bytes).map_err(Into::into) } pub(crate) fn serialize_str_list<'a>( cursor: &mut Cursor<&mut Vec>, list: impl ExactSizeIterator, version: Version, ) { let len = list.len() as CIntShort; len.serialize(cursor, version); for string in list { serialize_str(cursor, string, version); } } pub fn from_cursor_string_list(cursor: &mut Cursor<&[u8]>) -> CDRSResult> { let mut buff = [0; SHORT_LEN]; cursor.read_exact(&mut buff)?; let len = i16::from_be_bytes(buff); let mut list = Vec::with_capacity(len as usize); for _ in 0..len { list.push(from_cursor_str(cursor)?.to_string()); } Ok(list) } #[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)] /// The structure that represents a Cassandra byte type. 
pub struct CBytes { bytes: Option>, } impl CBytes { #[inline] pub fn new(bytes: Vec) -> CBytes { CBytes { bytes: Some(bytes) } } /// Creates Cassandra bytes that represent null value #[inline] pub fn new_null() -> CBytes { CBytes { bytes: None } } /// Converts `CBytes` into a plain array of bytes #[inline] pub fn into_bytes(self) -> Option> { self.bytes } #[inline] pub fn as_slice(&self) -> Option<&[u8]> { self.bytes.as_deref() } #[inline] pub fn is_null_or_empty(&self) -> bool { match &self.bytes { None => true, Some(bytes) => bytes.is_empty(), } } } impl FromCursor for CBytes { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> CDRSResult { let len = CInt::from_cursor(cursor, version)?; // null or not set value if len < 0 { return Ok(CBytes { bytes: None }); } cursor_next_value(cursor, len as usize).map(CBytes::new) } } impl Serialize for CBytes { fn serialize(&self, cursor: &mut Cursor<&mut Vec>, version: Version) { match &self.bytes { Some(bytes) => { let len = bytes.len() as CInt; len.serialize(cursor, version); bytes.serialize(cursor, version); } None => NULL_INT_LEN.serialize(cursor, version), } } } /// Cassandra short bytes #[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Default)] pub struct CBytesShort { bytes: Option>, } impl CBytesShort { #[inline] pub fn new(bytes: Vec) -> CBytesShort { CBytesShort { bytes: Some(bytes) } } /// Converts `CBytesShort` into a plain vector of bytes #[inline] pub fn into_bytes(self) -> Option> { self.bytes } #[inline] pub fn serialized_len(&self) -> usize { SHORT_LEN + if let Some(bytes) = &self.bytes { bytes.len() } else { 0 } } } impl FromCursor for CBytesShort { /// `from_cursor` gets the Cursor whose position is set such that it should be a start of bytes. 
/// It reads the required number of bytes and returns a `CBytesShort` (with no bytes if the encoded length is negative). fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> CDRSResult { let len = CIntShort::from_cursor(cursor, version)?; if len < 0 { return Ok(CBytesShort { bytes: None }); } cursor_next_value(cursor, len as usize).map(CBytesShort::new) } } impl Serialize for CBytesShort { fn serialize(&self, cursor: &mut Cursor<&mut Vec>, version: Version) { match &self.bytes { Some(bytes) => { let len = bytes.len() as CIntShort; len.serialize(cursor, version); bytes.serialize(cursor, version); } None => NULL_SHORT_LEN.serialize(cursor, version), } } } /// Cassandra int type. pub type CInt = i32; impl FromCursor for CInt { fn from_cursor(cursor: &mut Cursor<&[u8]>, _version: Version) -> CDRSResult { let mut buff = [0; INT_LEN]; cursor.read_exact(&mut buff)?; Ok(CInt::from_be_bytes(buff)) } } /// Cassandra int short type. pub type CIntShort = i16; impl FromCursor for CIntShort { fn from_cursor(cursor: &mut Cursor<&[u8]>, _version: Version) -> CDRSResult { let mut buff = [0; SHORT_LEN]; cursor.read_exact(&mut buff)?; Ok(CIntShort::from_be_bytes(buff)) } } /// Cassandra long type.
pub type CLong = i64; impl FromCursor for CLong { fn from_cursor(cursor: &mut Cursor<&[u8]>, _version: Version) -> CDRSResult { let mut buff = [0; LONG_LEN]; cursor.read_exact(&mut buff)?; Ok(CLong::from_be_bytes(buff)) } } impl Serialize for SocketAddr { fn serialize(&self, cursor: &mut Cursor<&mut Vec>, version: Version) { match self.ip() { IpAddr::V4(v4) => { [4].serialize(cursor, version); v4.octets().serialize(cursor, version); } IpAddr::V6(v6) => { [16].serialize(cursor, version); v6.octets().serialize(cursor, version); } } to_int(self.port().into()).serialize(cursor, version); } } impl FromCursor for SocketAddr { fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> CDRSResult { let mut buff = [0]; cursor.read_exact(&mut buff)?; let n = buff[0]; let ip = decode_inet(cursor_next_value(cursor, n as usize)?.as_slice())?; let port = CInt::from_cursor(cursor, version)?; Ok(SocketAddr::new(ip, port as u16)) } } pub fn cursor_next_value(cursor: &mut Cursor<&[u8]>, len: usize) -> CDRSResult> { let mut buff = vec![0u8; len]; cursor.read_exact(&mut buff)?; Ok(buff) } pub fn cursor_next_value_ref<'a>( cursor: &mut Cursor<&'a [u8]>, len: usize, ) -> CDRSResult<&'a [u8]> { let start = cursor.position() as usize; let result = &cursor.get_ref()[start..start + len]; cursor.set_position(cursor.position() + len as u64); if result.len() != len { Err(CdrsError::General( "cursor_next_value_ref could not retrieve a full slice".into(), )) } else { Ok(result) } } #[cfg(test)] mod tests { use super::*; use crate::frame::traits::FromCursor; use num_bigint::BigInt; use std::io::Cursor; fn from_i_bytes(bytes: &[u8]) -> i64 { try_i64_from_bytes(bytes).unwrap() } fn try_u16_from_bytes(bytes: &[u8]) -> Result { Ok(u16::from_be_bytes(convert_to_array(bytes)?)) } fn to_varint(int: BigInt) -> Vec { int.to_signed_bytes_be() } #[test] fn test_from_cursor_str() { let a = &[0, 3, 102, 111, 111, 0]; let mut cursor: Cursor<&[u8]> = Cursor::new(a); let cstring = from_cursor_str(&mut 
cursor).unwrap(); assert_eq!(cstring, "foo"); } #[test] fn test_from_cursor_str_long() { let a = &[0, 0, 0, 3, 102, 111, 111, 0]; let mut cursor: Cursor<&[u8]> = Cursor::new(a); let cstring = from_cursor_str_long(&mut cursor).unwrap(); assert_eq!(cstring, "foo"); } #[test] fn test_serialize_str() { let input = "foo"; let mut buf = vec![]; serialize_str(&mut Cursor::new(&mut buf), input, Version::V4); assert_eq!(buf, &[0, 3, 102, 111, 111]); } #[test] fn test_serialize_str_long() { let input = "foo"; let mut buf = vec![]; serialize_str_long(&mut Cursor::new(&mut buf), input, Version::V4); assert_eq!(buf, &[0, 0, 0, 3, 102, 111, 111]); } #[test] fn test_cstringlist() { let a = &[0, 2, 0, 3, 102, 111, 111, 0, 3, 102, 111, 112]; let mut cursor: Cursor<&[u8]> = Cursor::new(a); let list = from_cursor_string_list(&mut cursor).unwrap(); assert_eq!(list, vec!("foo".to_string(), "fop".to_string())); } // CBytes #[test] fn test_cbytes_new() { let bytes_vec = vec![1, 2, 3]; let _ = CBytes::new(bytes_vec); } #[test] fn test_cbytes_into_bytes() { let cbytes = CBytes::new(vec![1, 2, 3]); assert_eq!(cbytes.into_bytes().unwrap(), &[1, 2, 3]); } #[test] fn test_cbytes_from_cursor() { let a = &[0, 0, 0, 3, 1, 2, 3]; let mut cursor: Cursor<&[u8]> = Cursor::new(a); let cbytes = CBytes::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(cbytes.into_bytes().unwrap(), vec![1, 2, 3]); } #[test] fn test_cbytes_serialize() { let bytes_vec = vec![1, 2, 3]; let cbytes = CBytes::new(bytes_vec); assert_eq!( cbytes.serialize_to_vec(Version::V4), vec![0, 0, 0, 3, 1, 2, 3] ); } // CBytesShort #[test] fn test_cbytesshort_new() { let bytes_vec = vec![1, 2, 3]; let _ = CBytesShort::new(bytes_vec); } #[test] fn test_cbytesshort_into_bytes() { let cbytes = CBytesShort::new(vec![1, 2, 3]); assert_eq!(cbytes.into_bytes().unwrap(), vec![1, 2, 3]); } #[test] fn test_cbytesshort_from_cursor() { let a = &[0, 3, 1, 2, 3]; let mut cursor: Cursor<&[u8]> = Cursor::new(a); let cbytes = 
CBytesShort::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(cbytes.into_bytes().unwrap(), vec![1, 2, 3]); } #[test] fn test_cbytesshort_serialize() { let bytes_vec: Vec = vec![1, 2, 3]; let cbytes = CBytesShort::new(bytes_vec); assert_eq!(cbytes.serialize_to_vec(Version::V4), vec![0, 3, 1, 2, 3]); } #[test] fn test_cint_from_cursor() { let a = &[0, 0, 0, 5]; let mut cursor: Cursor<&[u8]> = Cursor::new(a); let i = CInt::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(i, 5); } #[test] fn test_cintshort_from_cursor() { let a = &[0, 5]; let mut cursor: Cursor<&[u8]> = Cursor::new(a); let i = CIntShort::from_cursor(&mut cursor, Version::V4).unwrap(); assert_eq!(i, 5); } #[test] fn test_cursor_next_value() { let a = &[0, 1, 2, 3, 4]; let mut cursor: Cursor<&[u8]> = Cursor::new(a); let l = 3; let val = cursor_next_value(&mut cursor, l).unwrap(); assert_eq!(val, vec![0, 1, 2]); } #[test] fn test_try_u16_from_bytes() { let bytes: [u8; 2] = [0, 12]; // or .to_le() let val = try_u16_from_bytes(&bytes); assert_eq!(val.unwrap(), 12u16); } #[test] fn test_from_i_bytes() { let bytes: [u8; 8] = [0, 0, 0, 0, 0, 0, 0, 12]; // or .to_le() let val = from_i_bytes(&bytes); assert_eq!(val, 12i64); } #[test] fn test_to_varint() { assert_eq!(to_varint(0.into()), vec![0x00]); assert_eq!(to_varint(1.into()), vec![0x01]); assert_eq!(to_varint(127.into()), vec![0x7F]); assert_eq!(to_varint(128.into()), vec![0x00, 0x80]); assert_eq!(to_varint(129.into()), vec![0x00, 0x81]); assert_eq!(to_varint(BigInt::from(-1)), vec![0xFF]); assert_eq!(to_varint(BigInt::from(-128)), vec![0x80]); assert_eq!(to_varint(BigInt::from(-129)), vec![0xFF, 0x7F]); } } ================================================ FILE: cdrs-tokio/Cargo.toml ================================================ [package] name = "cdrs-tokio" version = "9.0.1" authors = ["Alex Pikalov ", "Kamil Rojewski "] edition = "2018" description = "Async Cassandra DB driver written in Rust" documentation = 
"https://docs.rs/cdrs-tokio" homepage = "https://github.com/krojew/cdrs-tokio" repository = "https://github.com/krojew/cdrs-tokio" readme = "../README.md" keywords = ["cassandra", "driver", "client", "cassandradb", "async"] license = "MIT/Apache-2.0" categories = ["asynchronous", "database"] rust-version = "1.80" [features] rust-tls = ["tokio-rustls", "webpki"] e2e-tests = [] derive = ["cdrs-tokio-helpers-derive"] http-proxy = ["async-http-proxy"] [dependencies] arc-swap.workspace = true atomic = "0.6.0" bytemuck = { version = "1.22.0", features = ["derive"] } cassandra-protocol = { path = "../cassandra-protocol", version = "4.0.0" } cdrs-tokio-helpers-derive = { path = "../cdrs-tokio-helpers-derive", version = "5.0.3", optional = true } derive_more.workspace = true derivative.workspace = true futures = { version = "0.3.28", default-features = false, features = ["alloc"] } fxhash = "0.2.1" itertools.workspace = true rand = "0.10.0" serde_json = "1.0.140" thiserror.workspace = true tokio = { version = "1.48.0", features = ["net", "io-util", "rt", "sync", "macros", "rt-multi-thread", "time"] } # note: default features for tokio-rustls include aws_lc_rs, which require clang on Windows => disable and let users # enable it explicitly tokio-rustls = { version = "0.26.0", optional = true, default-features = false, features = ["logging", "tls12"] } tracing = "0.1.41" uuid.workspace = true webpki = { version = "0.22.2", optional = true } [dependencies.async-http-proxy] version = "1.2.5" optional = true features = ["runtime-tokio", "basic-auth"] [dev-dependencies] float_eq = "1.0.1" maplit = "1.0.2" mockall = "0.14.0" regex = "1.11.1" uuid = { version = "1.19.0", features = ["v4"] } time = { version = "0.3.41", features = ["std", "macros"] } [[example]] name = "crud_operations" required-features = ["derive"] [[example]] name = "generic_connection" required-features = ["derive"] [[example]] name = "insert_collection" required-features = ["derive"] [[example]] name = 
"multiple_thread" required-features = ["derive"] [[example]] name = "paged_query" required-features = ["derive"] [[example]] name = "prepare_batch_execute" required-features = ["derive"] ================================================ FILE: cdrs-tokio/examples/README.md ================================================ # CDRS examples - [`crud_operations.rs`](./crud_operations.rs) demonstrates how to create keyspace, table and user defined type. As well basic CRUD (Create, Read, Update, Delete) operations are shown; - [`insert_collection.rs`](./insert_collection.rs) demonstrates how to insert items in lists, maps and sets; - [`multiple_thread.rs`](./multiple_thread.rs) shows how to use CDRS in multi thread applications; - [`paged_query.rs`](./paged_query.rs) uncovers query paging; - [`prepare_batch_execute.rs`](./prepare_batch_execute.rs) provides an example of query preparation and batching; - [`aws cassandra crud operations`](https://github.com/AERC18/cdrs-aws-cassandra) illustrates how to connect and do CRUD operations on Amazon Managed Apache Cassandra Service. 
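The `generic_connection.rs` example remaps node addresses before it connects. Each rewritten address keeps the bits outside `mask` from the configured ("virtual") address, takes the bits inside `mask` from `actual`, and keeps the original port. The sketch below isolates that arithmetic using only `std`. `rewrite_v4` is a hypothetical helper, not part of the CDRS API.

Two things differ from the example:
- It works on the whole address as a big-endian `u32` rather than octet by octet. The result is the same, because the four octets are exactly the big-endian bytes of that `u32`.
- It returns `None` for IPv6 addresses, where the example's `rewrite` panics.

```rust
use std::net::{IpAddr, Ipv4Addr, SocketAddr};

/// Hypothetical helper: keeps the bits of `addr` outside `mask`, takes the bits
/// inside `mask` from `actual`, and preserves the port. IPv6 yields `None`.
fn rewrite_v4(addr: SocketAddr, mask: Ipv4Addr, actual: Ipv4Addr) -> Option<SocketAddr> {
    match addr {
        SocketAddr::V4(v4) => {
            // Treat each IPv4 address as a big-endian u32 so one bitwise
            // expression covers all four octets at once.
            let virt = u32::from(*v4.ip());
            let mask = u32::from(mask);
            let actual = u32::from(actual);
            let ip = Ipv4Addr::from((virt & !mask) | (actual & mask));
            Some(SocketAddr::new(IpAddr::V4(ip), v4.port()))
        }
        SocketAddr::V6(_) => None,
    }
}
```

The example uses a `mask` of `255.255.255.0` and an `actual` of `127.0.0.0`. With those values, the node `192.168.1.1:9042` is rewritten to `127.0.0.1:9042`, and `192.168.1.1:9043` becomes `127.0.0.1:9043`.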
================================================
FILE: cdrs-tokio/examples/crud_operations.rs
================================================
#[macro_use]
extern crate maplit;

use cdrs_tokio::authenticators::StaticPasswordAuthenticatorProvider;
use cdrs_tokio::cluster::session::{Session, SessionBuilder, TcpSessionBuilder};
use cdrs_tokio::cluster::{NodeTcpConfigBuilder, TcpConnectionManager};
use cdrs_tokio::frame::TryFromRow;
use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy;
use cdrs_tokio::query::*;
use cdrs_tokio::query_values;
use cdrs_tokio::transport::TransportTcp;
use cdrs_tokio::{IntoCdrsValue, TryFromRow, TryFromUdt};
use std::collections::HashMap;
use std::sync::Arc;

type CurrentSession = Session<
    TransportTcp,
    TcpConnectionManager,
    RoundRobinLoadBalancingStrategy<TransportTcp, TcpConnectionManager>,
>;

#[tokio::main]
async fn main() {
    let user = "user";
    let password = "password";
    let auth = StaticPasswordAuthenticatorProvider::new(&user, &password);
    let config = NodeTcpConfigBuilder::new()
        .with_contact_point("localhost:9042".into())
        .with_authenticator_provider(Arc::new(auth))
        .build()
        .await
        .unwrap();

    let mut session: CurrentSession =
        TcpSessionBuilder::new(RoundRobinLoadBalancingStrategy::new(), config)
            .build()
            .await
            .unwrap();

    create_keyspace(&mut session).await;
    create_udt(&mut session).await;
    create_table(&mut session).await;
    insert_struct(&mut session).await;
    select_struct(&mut session).await;
    update_struct(&mut session).await;
    delete_struct(&mut session).await;
}

#[derive(Clone, Debug, IntoCdrsValue, TryFromRow, PartialEq)]
struct RowStruct {
    key: i32,
    user: User,
    map: HashMap<String, User>,
    list: Vec<User>,
}

impl RowStruct {
    fn into_query_values(self) -> QueryValues {
        query_values!("key" => self.key, "user" => self.user, "map" => self.map, "list" => self.list)
    }
}

#[derive(Debug, Clone, PartialEq, IntoCdrsValue, TryFromUdt)]
struct User {
    username: String,
}

async fn create_keyspace(session: &mut CurrentSession) {
    let create_ks: &'static str = "CREATE KEYSPACE IF NOT EXISTS test_ks WITH REPLICATION = { \
                                   'class' : 'SimpleStrategy', 'replication_factor' : 1 };";
    session
        .query(create_ks)
        .await
        .expect("Keyspace creation error");
}

async fn create_udt(session: &mut CurrentSession) {
    let create_type_cql = "CREATE TYPE IF NOT EXISTS test_ks.user (username text)";
    session
        .query(create_type_cql)
        .await
        .expect("UDT creation error");
}

async fn create_table(session: &mut CurrentSession) {
    let create_table_cql =
        "CREATE TABLE IF NOT EXISTS test_ks.my_test_table (key int PRIMARY KEY, \
         user frozen<test_ks.user>, map map<text, frozen<test_ks.user>>, list list<frozen<test_ks.user>>);";
    session
        .query(create_table_cql)
        .await
        .expect("Table creation error");
}

async fn insert_struct(session: &mut CurrentSession) {
    let row = RowStruct {
        key: 3i32,
        user: User {
            username: "John".to_string(),
        },
        map: hashmap! { "John".to_string() => User { username: "John".to_string() } },
        list: vec![User {
            username: "John".to_string(),
        }],
    };

    let insert_struct_cql = "INSERT INTO test_ks.my_test_table \
                             (key, user, map, list) VALUES (?, ?, ?, ?)";
    session
        .query_with_values(insert_struct_cql, row.into_query_values())
        .await
        .expect("insert");
}

async fn select_struct(session: &mut CurrentSession) {
    let select_struct_cql = "SELECT * FROM test_ks.my_test_table";
    let rows = session
        .query(select_struct_cql)
        .await
        .expect("query")
        .response_body()
        .expect("get body")
        .into_rows()
        .expect("into rows");

    for row in rows {
        let my_row: RowStruct = RowStruct::try_from_row(row).expect("into RowStruct");
        println!("struct got: {my_row:?}");
    }
}

async fn update_struct(session: &mut CurrentSession) {
    let update_struct_cql = "UPDATE test_ks.my_test_table SET user = ? WHERE key = ?";
    let upd_user = User {
        username: "Marry".to_string(),
    };
    // Same key as the row written by `insert_struct`.
    let user_key = 3i32;
    session
        .query_with_values(update_struct_cql, query_values!(upd_user, user_key))
        .await
        .expect("update");
}

async fn delete_struct(session: &mut CurrentSession) {
    let delete_struct_cql = "DELETE FROM test_ks.my_test_table WHERE key = ?";
    // Same key as the row written by `insert_struct`.
    let user_key = 3i32;
    session
        .query_with_values(delete_struct_cql, query_values!(user_key))
        .await
        .expect("delete");
}

================================================
FILE: cdrs-tokio/examples/generic_connection.rs
================================================
use cdrs_tokio::cluster::connection_pool::ConnectionPoolConfig;
use cdrs_tokio::cluster::session::{
    NodeDistanceEvaluatorWrapper, ReconnectionPolicyWrapper, RetryPolicyWrapper,
    DEFAULT_TRANSPORT_BUFFER_SIZE,
};
use cdrs_tokio::cluster::{ConnectionManager, KeyspaceHolder};
use cdrs_tokio::compression::Compression;
use cdrs_tokio::frame::{Envelope, Version};
use cdrs_tokio::frame_encoding::ProtocolFrameEncodingFactory;
use cdrs_tokio::future::BoxFuture;
use cdrs_tokio::load_balancing::node_distance_evaluator::AllLocalNodeDistanceEvaluator;
use cdrs_tokio::retry::ConstantReconnectionPolicy;
use cdrs_tokio::IntoCdrsValue;
use cdrs_tokio::{
    authenticators::{SaslAuthenticatorProvider, StaticPasswordAuthenticatorProvider},
    cluster::session::Session,
    cluster::{GenericClusterConfig, TcpConnectionManager},
    error::Result,
    load_balancing::RoundRobinLoadBalancingStrategy,
    query::*,
    query_values,
    retry::DefaultRetryPolicy,
    transport::TransportTcp,
    types::prelude::*,
    TryFromRow, TryFromUdt,
};
use futures::FutureExt;
use maplit::hashmap;
use std::{
    collections::HashMap,
    net::IpAddr,
    net::{Ipv4Addr, SocketAddr},
    sync::Arc,
};
use tokio::sync::mpsc::Sender;

type CurrentSession = Session<
    TransportTcp,
    VirtualConnectionManager,
    RoundRobinLoadBalancingStrategy<TransportTcp, VirtualConnectionManager>,
>;

/// Implements a cluster configuration where the addresses to
/// connect to are different from the ones configured by replacing
/// the masked part of the address with a different subnet.
///
/// This would allow running your connection through a proxy
/// or mock server while also using a production configuration
/// and having your load balancing configuration be aware of the
/// 'real' addresses.
///
/// This is just a simple use of the generic configuration. By
/// replacing the transport itself you can do much more.
struct VirtualClusterConfig {
    authenticator: Arc<dyn SaslAuthenticatorProvider + Send + Sync>,
    mask: Ipv4Addr,
    actual: Ipv4Addr,
    version: Version,
}

fn rewrite(addr: SocketAddr, mask: &Ipv4Addr, actual: &Ipv4Addr) -> SocketAddr {
    match addr {
        SocketAddr::V4(addr) => {
            let virt = addr.ip().octets();
            let mask = mask.octets();
            let actual = actual.octets();
            SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(
                    (virt[0] & !mask[0]) | (actual[0] & mask[0]),
                    (virt[1] & !mask[1]) | (actual[1] & mask[1]),
                    (virt[2] & !mask[2]) | (actual[2] & mask[2]),
                    (virt[3] & !mask[3]) | (actual[3] & mask[3]),
                )),
                addr.port(),
            )
        }
        SocketAddr::V6(_) => {
            panic!("IPv6 is unsupported!");
        }
    }
}

struct VirtualConnectionManager {
    inner: TcpConnectionManager,
    mask: Ipv4Addr,
    actual: Ipv4Addr,
}

impl ConnectionManager<TransportTcp> for VirtualConnectionManager {
    fn connection(
        &self,
        event_handler: Option<Sender<Envelope>>,
        error_handler: Option<Sender<Error>>,
        addr: SocketAddr,
    ) -> BoxFuture<'_, Result<TransportTcp>> {
        self.inner.connection(
            event_handler,
            error_handler,
            rewrite(addr, &self.mask, &self.actual),
        )
    }
}

impl VirtualConnectionManager {
    async fn new(
        config: &VirtualClusterConfig,
        keyspace_holder: Arc<KeyspaceHolder>,
    ) -> Result<Self> {
        Ok(VirtualConnectionManager {
            inner: TcpConnectionManager::new(
                config.authenticator.clone(),
                keyspace_holder,
                Box::<ProtocolFrameEncodingFactory>::default(),
                Compression::None,
                DEFAULT_TRANSPORT_BUFFER_SIZE,
                true,
                config.version,
                #[cfg(feature = "http-proxy")]
                None,
            ),
            mask: config.mask,
            actual: config.actual,
        })
    }
}

impl GenericClusterConfig<TransportTcp, VirtualConnectionManager> for VirtualClusterConfig {
    fn create_manager(
        &self,
        keyspace_holder: Arc<KeyspaceHolder>,
    ) -> BoxFuture<'_, Result<VirtualConnectionManager>> {
        // Create a connection manager that connects to the rewritten address, but
        // return a manager that keeps the 'virtual' address for internal purposes.
        VirtualConnectionManager::new(self, keyspace_holder).boxed()
    }

    fn event_channel_capacity(&self) -> usize {
        32
    }

    fn version(&self) -> Version {
        self.version
    }

    fn connection_pool_config(&self) -> ConnectionPoolConfig {
        Default::default()
    }
}

#[tokio::main]
async fn main() {
    let user = "user";
    let password = "password";
    let authenticator = Arc::new(StaticPasswordAuthenticatorProvider::new(&user, &password));
    let mask = Ipv4Addr::new(255, 255, 255, 0);
    let actual = Ipv4Addr::new(127, 0, 0, 0);

    let reconnection_policy = Arc::new(ConstantReconnectionPolicy::default());

    let cluster_config = VirtualClusterConfig {
        authenticator,
        mask,
        actual,
        version: Version::V5,
    };
    let nodes = [
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 9042),
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 9043),
    ];
    let load_balancing = RoundRobinLoadBalancingStrategy::new();

    let mut session = cdrs_tokio::cluster::connect_generic(
        &cluster_config,
        nodes,
        load_balancing,
        RetryPolicyWrapper(Box::<DefaultRetryPolicy>::default()),
        ReconnectionPolicyWrapper(reconnection_policy),
        NodeDistanceEvaluatorWrapper(Box::<AllLocalNodeDistanceEvaluator>::default()),
        None,
    )
    .await
    .expect("session should be created");

    create_keyspace(&mut session).await;
    create_udt(&mut session).await;
    create_table(&mut session).await;
    insert_struct(&mut session).await;
    select_struct(&mut session).await;
    update_struct(&mut session).await;
    delete_struct(&mut session).await;
}

#[derive(Clone, Debug, IntoCdrsValue, TryFromRow, PartialEq)]
struct RowStruct {
    key: i32,
    user: User,
    map: HashMap<String, User>,
    list: Vec<User>,
}

impl RowStruct {
    fn into_query_values(self) -> QueryValues {
        query_values!("key" => self.key, "user" => self.user, "map" => self.map, "list" => self.list)
    }
}

#[derive(Debug, Clone, PartialEq, IntoCdrsValue, TryFromUdt)]
struct User {
    username: String,
}

async fn create_keyspace(session: &mut CurrentSession) {
    let create_ks: &'static str = "CREATE KEYSPACE IF NOT EXISTS test_ks WITH REPLICATION = { \
                                   'class' : 'SimpleStrategy', 'replication_factor' : 1 };";
    session
        .query(create_ks)
        .await
        .expect("Keyspace creation error");
}

async fn create_udt(session: &mut CurrentSession) {
    let create_type_cql = "CREATE TYPE IF NOT EXISTS test_ks.user (username text)";
    session
        .query(create_type_cql)
        .await
        .expect("UDT creation error");
}

async fn create_table(session: &mut CurrentSession) {
    let create_table_cql =
        "CREATE TABLE IF NOT EXISTS test_ks.my_test_table (key int PRIMARY KEY, \
         user frozen<test_ks.user>, map map<text, frozen<test_ks.user>>, list list<frozen<test_ks.user>>);";
    session
        .query(create_table_cql)
        .await
        .expect("Table creation error");
}

//noinspection DuplicatedCode
async fn insert_struct(session: &mut CurrentSession) {
    let row = RowStruct {
        key: 3i32,
        user: User {
            username: "John".to_string(),
        },
        map: hashmap! { "John".to_string() => User { username: "John".to_string() } },
        list: vec![User {
            username: "John".to_string(),
        }],
    };

    let insert_struct_cql = "INSERT INTO test_ks.my_test_table \
                             (key, user, map, list) VALUES (?, ?, ?, ?)";
    session
        .query_with_values(insert_struct_cql, row.into_query_values())
        .await
        .expect("insert");
}

//noinspection DuplicatedCode
async fn select_struct(session: &mut CurrentSession) {
    let select_struct_cql = "SELECT * FROM test_ks.my_test_table";
    let rows = session
        .query(select_struct_cql)
        .await
        .expect("query")
        .response_body()
        .expect("get body")
        .into_rows()
        .expect("into rows");

    for row in rows {
        let my_row: RowStruct = RowStruct::try_from_row(row).expect("into RowStruct");
        println!("struct got: {my_row:?}");
    }
}

//noinspection DuplicatedCode
async fn update_struct(session: &mut CurrentSession) {
    let update_struct_cql = "UPDATE test_ks.my_test_table SET user = ? WHERE key = ?";
    let upd_user = User {
        username: "Marry".to_string(),
    };
    // Same key as the row written by `insert_struct`.
    let user_key = 3i32;
    session
        .query_with_values(update_struct_cql, query_values!(upd_user, user_key))
        .await
        .expect("update");
}

async fn delete_struct(session: &mut CurrentSession) {
    let delete_struct_cql = "DELETE FROM test_ks.my_test_table WHERE key = ?";
    // Same key as the row written by `insert_struct`.
    let user_key = 3i32;
    session
        .query_with_values(delete_struct_cql, query_values!(user_key))
        .await
        .expect("delete");
}

================================================
FILE: cdrs-tokio/examples/insert_collection.rs
================================================
#[macro_use]
extern crate maplit;

use cdrs_tokio::authenticators::StaticPasswordAuthenticatorProvider;
use cdrs_tokio::cluster::session::{Session, SessionBuilder, TcpSessionBuilder};
use cdrs_tokio::cluster::{NodeTcpConfigBuilder, TcpConnectionManager};
use cdrs_tokio::frame::TryFromRow;
use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy;
use cdrs_tokio::query::*;
use cdrs_tokio::query_values;
use cdrs_tokio::transport::TransportTcp;
use cdrs_tokio::{IntoCdrsValue, TryFromRow, TryFromUdt};
use std::collections::HashMap;
use std::sync::Arc;

type CurrentSession = Session<
    TransportTcp,
    TcpConnectionManager,
    RoundRobinLoadBalancingStrategy<TransportTcp, TcpConnectionManager>,
>;

#[tokio::main]
async fn main() {
    let user = "user";
    let password = "password";
    let auth = StaticPasswordAuthenticatorProvider::new(&user, &password);
    let cluster_config = NodeTcpConfigBuilder::new()
        .with_contact_point("127.0.0.1:9042".into())
        .with_authenticator_provider(Arc::new(auth))
        .build()
        .await
        .unwrap();

    let mut session: CurrentSession =
        TcpSessionBuilder::new(RoundRobinLoadBalancingStrategy::new(), cluster_config)
            .build()
            .await
            .unwrap();

    create_keyspace(&mut session).await;
    create_udt(&mut session).await;
    create_table(&mut session).await;
    insert_struct(&mut session).await;
    append_list(&mut session).await;
    prepend_list(&mut session).await;
    append_set(&mut session).await;
    append_map(&mut session).await;
select_struct(&mut session).await; delete_struct(&mut session).await; } #[derive(Clone, Debug, IntoCdrsValue, TryFromRow, PartialEq)] struct RowStruct { key: i32, map: HashMap, list: Vec, cset: Vec, } impl RowStruct { fn into_query_values(self) -> QueryValues { query_values!("key" => self.key, "map" => self.map, "list" => self.list, "cset" => self.cset) } } #[derive(Debug, Clone, PartialEq, IntoCdrsValue, TryFromUdt)] struct User { username: String, } async fn create_keyspace(session: &mut CurrentSession) { let create_ks: &'static str = "CREATE KEYSPACE IF NOT EXISTS test_ks WITH REPLICATION = { \ 'class' : 'SimpleStrategy', 'replication_factor' : 1 };"; session .query(create_ks) .await .expect("Keyspace creation error"); } async fn create_udt(session: &mut CurrentSession) { let create_type_cql = "CREATE TYPE IF NOT EXISTS test_ks.user (username text)"; session .query(create_type_cql) .await .expect("Keyspace creation error"); } async fn create_table(session: &mut CurrentSession) { let create_table_cql = "CREATE TABLE IF NOT EXISTS test_ks.collection_table (key int PRIMARY KEY, \ user frozen, map map>, \ list list>, cset set>);"; session .query(create_table_cql) .await .expect("Table creation error"); } async fn append_list(session: &mut CurrentSession) { let key = 3i32; let extra_values = vec![ User { username: "William".to_string(), }, User { username: "Averel".to_string(), }, ]; let append_list_cql = "UPDATE test_ks.collection_table SET list = list + ? \ WHERE key = ?"; session .query_with_values(append_list_cql, query_values!(extra_values, key)) .await .expect("append list"); } async fn prepend_list(session: &mut CurrentSession) { let key = 3i32; let extra_values = vec![ User { username: "Joe".to_string(), }, User { username: "Jack".to_string(), }, ]; let prepend_list_cql = "UPDATE test_ks.collection_table SET list = ? 
+ list \ WHERE key = ?"; session .query_with_values(prepend_list_cql, query_values!(extra_values, key)) .await .expect("prepend list"); } async fn append_set(session: &mut CurrentSession) { let key = 3i32; let extra_values = vec![ User { username: "William".to_string(), }, User { username: "Averel".to_string(), }, ]; let append_set_cql = "UPDATE test_ks.collection_table SET cset = cset + ? \ WHERE key = ?"; session .query_with_values(append_set_cql, query_values!(extra_values, key)) .await .expect("append set"); } async fn append_map(session: &mut CurrentSession) { let key = 3i32; let extra_values = hashmap![ "Joe".to_string() => User { username: "Joe".to_string() }, "Jack".to_string() => User { username: "Jack".to_string() }, ]; let append_map_cql = "UPDATE test_ks.collection_table SET map = map + ? \ WHERE key = ?"; session .query_with_values(append_map_cql, query_values!(extra_values, key)) .await .expect("append map"); } async fn insert_struct(session: &mut CurrentSession) { let row = RowStruct { key: 3i32, map: hashmap! 
{ "John".to_string() => User { username: "John".to_string() } }, list: vec![User { username: "John".to_string(), }], cset: vec![User { username: "John".to_string(), }], }; let insert_struct_cql = "INSERT INTO test_ks.collection_table \ (key, map, list, cset) VALUES (?, ?, ?, ?)"; session .query_with_values(insert_struct_cql, row.into_query_values()) .await .expect("insert"); } async fn select_struct(session: &mut CurrentSession) { let select_struct_cql = "SELECT * FROM test_ks.collection_table"; let rows = session .query(select_struct_cql) .await .expect("query") .response_body() .expect("get body") .into_rows() .expect("into rows"); for row in rows { let my_row: RowStruct = RowStruct::try_from_row(row).expect("into RowStruct"); println!("struct got: {my_row:#?}"); } } async fn delete_struct(session: &mut CurrentSession) { let delete_struct_cql = "DELETE FROM test_ks.collection_table WHERE key = ?"; let user_key = 3i32; session .query_with_values(delete_struct_cql, query_values!(user_key)) .await .expect("delete"); } ================================================ FILE: cdrs-tokio/examples/multiple_thread.rs ================================================ use cdrs_tokio::authenticators::NoneAuthenticatorProvider; use cdrs_tokio::cluster::session::{Session, SessionBuilder, TcpSessionBuilder}; use cdrs_tokio::cluster::{NodeTcpConfigBuilder, TcpConnectionManager}; use cdrs_tokio::frame::TryFromRow; use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy; use cdrs_tokio::query::*; use cdrs_tokio::query_values; use cdrs_tokio::transport::TransportTcp; use cdrs_tokio::{IntoCdrsValue, TryFromRow, TryFromUdt}; use std::sync::Arc; type CurrentSession = Session< TransportTcp, TcpConnectionManager, RoundRobinLoadBalancingStrategy, >; #[tokio::main] async fn main() { let cluster_config = NodeTcpConfigBuilder::new() .with_contact_point("127.0.0.1:9042".into()) .with_authenticator_provider(Arc::new(NoneAuthenticatorProvider)) .build() .await .unwrap(); let lb = 
RoundRobinLoadBalancingStrategy::new(); let session: Arc = Arc::new( TcpSessionBuilder::new(lb, cluster_config) .build() .await .unwrap(), ); create_keyspace(session.clone()).await; create_table(session.clone()).await; let futures: Vec> = (0..20) .map(|i| { let thread_session = session.clone(); tokio::spawn(insert_struct(thread_session, i)) }) .collect(); let _responses = futures::future::join_all(futures); select_struct(session).await; } #[derive(Clone, Debug, IntoCdrsValue, TryFromRow, PartialEq)] struct RowStruct { key: i32, } impl RowStruct { fn into_query_values(self) -> QueryValues { query_values!("key" => self.key) } } #[derive(Debug, Clone, PartialEq, IntoCdrsValue, TryFromUdt)] struct User { username: String, } async fn create_keyspace(session: Arc) { let create_ks: &'static str = "CREATE KEYSPACE IF NOT EXISTS test_ks WITH REPLICATION = { \ 'class' : 'SimpleStrategy', 'replication_factor' : 1 };"; session .query(create_ks) .await .expect("Keyspace creation error"); } async fn create_table(session: Arc) { let create_table_cql = "CREATE TABLE IF NOT EXISTS test_ks.multi_thread_table (key int PRIMARY KEY);"; session .query(create_table_cql) .await .expect("Table creation error"); } async fn insert_struct(session: Arc, key: i32) { let row = RowStruct { key }; let insert_struct_cql = "INSERT INTO test_ks.multi_thread_table (key) VALUES (?)"; session .query_with_values(insert_struct_cql, row.into_query_values()) .await .expect("insert"); } async fn select_struct(session: Arc) { let select_struct_cql = "SELECT * FROM test_ks.multi_thread_table"; let rows = session .query(select_struct_cql) .await .expect("query") .response_body() .expect("get body") .into_rows() .expect("into rows"); for row in rows { let my_row: RowStruct = RowStruct::try_from_row(row).expect("into RowStruct"); println!("struct got: {my_row:?}"); } } ================================================ FILE: cdrs-tokio/examples/paged_query.rs ================================================ use 
cdrs_tokio::authenticators::NoneAuthenticatorProvider; use cdrs_tokio::cluster::session::{Session, SessionBuilder, TcpSessionBuilder}; use cdrs_tokio::cluster::{NodeTcpConfigBuilder, PagerState, TcpConnectionManager}; use cdrs_tokio::frame::TryFromRow; use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy; use cdrs_tokio::query::*; use cdrs_tokio::query_values; use cdrs_tokio::transport::TransportTcp; use cdrs_tokio::{IntoCdrsValue, TryFromRow}; use std::sync::Arc; type CurrentSession = Session< TransportTcp, TcpConnectionManager, RoundRobinLoadBalancingStrategy, >; #[derive(Clone, Debug, IntoCdrsValue, TryFromRow, PartialEq)] struct RowStruct { key: i32, } impl RowStruct { fn into_query_values(self) -> QueryValues { query_values!("key" => self.key) } } #[derive(Clone, Debug, IntoCdrsValue, TryFromRow, PartialEq)] struct AnotherTestTable { a: i32, b: i32, c: i32, d: i32, e: i32, } impl AnotherTestTable { fn into_query_values(self) -> QueryValues { query_values!("a" => self.a, "b" => self.b, "c" => self.c, "d" => self.d, "e" => self.e) } } #[tokio::main] async fn main() { let cluster_config = NodeTcpConfigBuilder::new() .with_contact_point("127.0.0.1:9042".into()) .with_authenticator_provider(Arc::new(NoneAuthenticatorProvider)) .build() .await .unwrap(); let lb = RoundRobinLoadBalancingStrategy::new(); let session = TcpSessionBuilder::new(lb, cluster_config) .build() .await .unwrap(); create_keyspace(&session).await; create_udt(&session).await; create_table(&session).await; fill_table(&session).await; println!("Internal pager state\n"); paged_selection_query(&session).await; println!("\n\nExternal pager state for stateless executions\n"); paged_selection_query_with_state(&session, PagerState::new()).await; println!("\n\nPager with query values (list)\n"); paged_with_values_list(&session).await; println!("\n\nPager with query value (no list)\n"); paged_with_value(&session).await; println!("\n\nFinished paged query tests\n"); } async fn 
create_keyspace(session: &CurrentSession) { let create_ks: &'static str = "CREATE KEYSPACE IF NOT EXISTS test_ks WITH REPLICATION = { \ 'class' : 'SimpleStrategy', 'replication_factor' : 1 };"; session .query(create_ks) .await .expect("Keyspace creation error"); } async fn create_udt(session: &CurrentSession) { let create_type_cql = "CREATE TYPE IF NOT EXISTS test_ks.user (username text)"; session .query(create_type_cql) .await .expect("Keyspace creation error"); } async fn create_table(session: &CurrentSession) { let create_table_cql = "CREATE TABLE IF NOT EXISTS test_ks.my_test_table (key int PRIMARY KEY, \ user test_ks.user, map map>, list list>);"; session .query(create_table_cql) .await .expect("Table creation error"); } async fn fill_table(session: &CurrentSession) { let insert_struct_cql = "INSERT INTO test_ks.my_test_table (key) VALUES (?)"; for k in 100..110 { let row = RowStruct { key: k }; session .query_with_values(insert_struct_cql, row.into_query_values()) .await .expect("insert"); } } async fn paged_selection_query(session: &CurrentSession) { let q = "SELECT * FROM test_ks.my_test_table;"; let mut pager = session.paged(2); let mut query_pager = pager.query(q); loop { let rows = query_pager.next().await.expect("pager next"); for row in rows { let my_row = RowStruct::try_from_row(row).expect("decode row"); println!("row - {my_row:?}"); } if !query_pager.has_more() { break; } } } async fn paged_with_value(session: &CurrentSession) { let create_table_cql = "CREATE TABLE IF NOT EXISTS test_ks.another_test_table (a int, b int, c int, d int, e int, primary key((a, b), c, d));"; session .query(create_table_cql) .await .expect("Table creation error"); for v in 1..=10 { session .query_with_values( "INSERT INTO test_ks.another_test_table (a, b, c, d, e) VALUES (?, ?, ?, ?, ?)", AnotherTestTable { a: 1, b: 1, c: 2, d: v, e: v, } .into_query_values(), ) .await .unwrap(); } let q = "SELECT * FROM test_ks.another_test_table where a = ? 
and b = 1 and c = ?"; let mut pager = session.paged(3); let mut query_pager = pager.query_with_params( q, QueryParamsBuilder::new() .with_values(query_values!(1, 2)) .build(), ); // has_more() reports false until the first page has been fetched. assert!(!query_pager.has_more()); let rows = query_pager.next().await.expect("pager next"); assert_eq!(3, rows.len()); let rows = query_pager.next().await.expect("pager next"); assert_eq!(3, rows.len()); let rows = query_pager.next().await.expect("pager next"); assert_eq!(3, rows.len()); let rows = query_pager.next().await.expect("pager next"); assert_eq!(1, rows.len()); assert!(!query_pager.has_more()); } async fn paged_with_values_list(session: &CurrentSession) { let q = "SELECT * FROM test_ks.my_test_table where key in ?"; let mut pager = session.paged(2); let mut query_pager = pager.query_with_params( q, QueryParamsBuilder::new() .with_values(query_values!(vec![100, 101, 102, 103, 104])) .build(), ); // A macro is used instead of a function or closure to avoid lifetime problems with the mutably borrowed query_pager. macro_rules!
assert_amount_query_pager { ($row_amount: expr) => {{ let rows = query_pager.next().await.expect("pager next"); assert_eq!($row_amount, rows.len()); }}; } println!("Testing values 100 and 101"); assert_amount_query_pager!(2); assert!(query_pager.has_more()); assert!(!query_pager .pager_state() .cursor() .unwrap() .is_null_or_empty()); println!("Testing values 102 and 103"); assert_amount_query_pager!(2); assert!(query_pager.has_more()); assert!(!query_pager .pager_state() .cursor() .unwrap() .is_null_or_empty()); println!("Testing value 104"); assert_amount_query_pager!(1); // Now no more rows should be queried println!("Testing no more values are present"); assert!(!query_pager.has_more()); assert!(query_pager.pager_state().cursor().is_none()); } async fn paged_selection_query_with_state(session: &CurrentSession, state: PagerState) { let mut st = state; loop { let q = "SELECT * FROM test_ks.my_test_table;"; let mut pager = session.paged(2); let mut query_pager = pager.query_with_pager_state(q, st); let rows = query_pager.next().await.expect("pager next"); for row in rows { let my_row = RowStruct::try_from_row(row).expect("decode row"); println!("row - {my_row:?}"); } if !query_pager.has_more() { break; } st = query_pager.pager_state(); } } ================================================ FILE: cdrs-tokio/examples/prepare_batch_execute.rs ================================================ use cdrs_tokio::authenticators::NoneAuthenticatorProvider; use cdrs_tokio::cluster::session::{Session, SessionBuilder, TcpSessionBuilder}; use cdrs_tokio::cluster::{NodeTcpConfigBuilder, TcpConnectionManager}; use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy; use cdrs_tokio::query::*; use cdrs_tokio::query_values; use cdrs_tokio::transport::TransportTcp; use cdrs_tokio::{IntoCdrsValue, TryFromRow}; use std::sync::Arc; type CurrentSession = Session< TransportTcp, TcpConnectionManager, RoundRobinLoadBalancingStrategy, >; #[derive(Clone, Debug, IntoCdrsValue, TryFromRow, 
PartialEq)] struct RowStruct { key: i32, } impl RowStruct { fn into_query_values(self) -> QueryValues { // **IMPORTANT NOTE:** query values should be WITHOUT NAMES // https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L413 query_values!(self.key) } } #[tokio::main] async fn main() { let cluster_config = NodeTcpConfigBuilder::new() .with_contact_point("127.0.0.1:9042".into()) .with_authenticator_provider(Arc::new(NoneAuthenticatorProvider)) .build() .await .unwrap(); let lb = RoundRobinLoadBalancingStrategy::new(); let mut session = TcpSessionBuilder::new(lb, cluster_config) .build() .await .unwrap(); create_keyspace(&mut session).await; create_table(&mut session).await; let insert_struct_cql = "INSERT INTO test_ks.my_test_table (key) VALUES (?)"; let prepared_query = session .prepare(insert_struct_cql) .await .expect("Prepare query error"); for k in 100..110 { let row = RowStruct { key: k }; insert_row(&mut session, row, &prepared_query).await; } batch_few_queries(&mut session, insert_struct_cql).await; } async fn create_keyspace(session: &mut CurrentSession) { let create_ks: &'static str = "CREATE KEYSPACE IF NOT EXISTS test_ks WITH REPLICATION = { \ 'class' : 'SimpleStrategy', 'replication_factor' : 1 };"; session .query(create_ks) .await .expect("Keyspace creation error"); } async fn create_table(session: &mut CurrentSession) { let create_table_cql = "CREATE TABLE IF NOT EXISTS test_ks.my_test_table (key int PRIMARY KEY);"; session .query(create_table_cql) .await .expect("Table creation error"); } async fn insert_row(session: &mut CurrentSession, row: RowStruct, prepared_query: &PreparedQuery) { session .exec_with_values(prepared_query, row.into_query_values()) .await .expect("exec_with_values error"); } async fn batch_few_queries(session: &mut CurrentSession, query: &str) { let prepared_query = session.prepare(query).await.expect("Prepare query error"); let row_1 = RowStruct { key: 1001 }; let row_2 = RowStruct { key: 2001 }; let batch = 
BatchQueryBuilder::new() .add_query_prepared(&prepared_query, row_1.into_query_values()) .add_query(query, row_2.into_query_values()) .build() .expect("batch builder"); session.batch(batch).await.expect("batch query error"); } ================================================ FILE: cdrs-tokio/src/cluster/cluster_metadata_manager.rs ================================================ use arc_swap::ArcSwap; use cassandra_protocol::error::{Error, Result}; use cassandra_protocol::events::{SchemaChange, ServerEvent}; use cassandra_protocol::frame::events::{ SchemaChangeOptions, SchemaChangeType, StatusChange, StatusChangeType, TopologyChange, TopologyChangeType, }; use cassandra_protocol::frame::message_error::{ErrorBody, ErrorType}; use cassandra_protocol::frame::message_query::BodyReqQuery; use cassandra_protocol::frame::{Envelope, Flags, Version}; use cassandra_protocol::query::{QueryParams, QueryParamsBuilder, QueryValues}; use cassandra_protocol::types::list::List; use cassandra_protocol::types::rows::Row; use cassandra_protocol::types::{AsRustType, ByName, IntoRustByName}; use fxhash::FxHashMap; use itertools::Itertools; use rand::{rng, RngExt}; use serde_json::{Map, Value as JsonValue}; use std::convert::TryInto; use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::sync::broadcast::error::RecvError; use tokio::sync::broadcast::Receiver; use tracing::*; use crate::cluster::connection_pool::ConnectionPoolFactory; use crate::cluster::metadata_builder::{add_new_node, build_initial_metadata, refresh_metadata}; use crate::cluster::topology::{KeyspaceMetadata, Node, NodeState, ReplicationStrategy}; use crate::cluster::Murmur3Token; use crate::cluster::{ClusterMetadata, ConnectionManager}; use crate::cluster::{NodeInfo, SessionContext}; use crate::load_balancing::node_distance_evaluator::NodeDistanceEvaluator; use crate::transport::CdrsTransport; fn find_in_peers( peers: &[Row], 
broadcast_rpc_address: SocketAddr, control_addr: SocketAddr, ) -> Result> { peers .iter() .find_map(|peer| { broadcast_rpc_address_from_row(peer, control_addr) .filter(|peer_address| { *peer_address == broadcast_rpc_address && is_peer_row_valid(peer) }) .map(|peer_address| build_node_info(peer, peer_address)) }) .transpose() } async fn send_query( query: &str, transport: &T, version: Version, beta_protocol: bool, ) -> Result>> { let query_params = QueryParamsBuilder::new().build(); send_query_with_params(query, query_params, transport, version, beta_protocol).await } async fn send_query_with_values>( query: &str, values: V, transport: &T, version: Version, beta_protocol: bool, ) -> Result>> { let query_params = QueryParamsBuilder::new().with_values(values.into()).build(); send_query_with_params(query, query_params, transport, version, beta_protocol).await } async fn send_query_with_params( query: &str, query_params: QueryParams, transport: &T, version: Version, beta_protocol: bool, ) -> Result>> { let query = BodyReqQuery { query: query.to_string(), query_params, }; let flags = if beta_protocol { Flags::BETA } else { Flags::empty() }; let envelope = Envelope::new_query(query, flags, version); transport .write_envelope(&envelope, false) .await .and_then(|envelope| envelope.response_body()) .map(|body| body.into_rows()) } fn build_node_info(row: &Row, broadcast_rpc_address: SocketAddr) -> Result { row.get_r_by_name("host_id").and_then(move |host_id| { let broadcast_address: Option = row .get_by_name("broadcast_address") .or_else(|_| row.get_by_name("peer"))?; let broadcast_address = if let Some(broadcast_address) = broadcast_address { let port: Option = if row.contains_column("broadcast_port") { // system.local for Cassandra >= 4.0 row.get_by_name("broadcast_port")? } else if row.contains_column("peer_port") { // system.peers_v2 row.get_by_name("peer_port")? 
} else { None }; port.map(|port| SocketAddr::new(broadcast_address, port as u16)) } else { None }; let datacenter = row.get_r_by_name("data_center")?; let rack = row.get_r_by_name("rack")?; let tokens: List = row.get_r_by_name("tokens")?; let tokens: Vec = tokens.as_r_type()?; Ok(NodeInfo::new( host_id, broadcast_rpc_address, broadcast_address, datacenter, tokens .into_iter() .map(|token| { token.try_into().unwrap_or_else(|_| { warn!(%broadcast_rpc_address, "Unsupported token type - using a dummy value."); Murmur3Token::new(rng().random()) }) }) .collect(), rack, )) }) } fn build_node_broadcast_rpc_address( row: &Row, broadcast_rpc_address: Option, control_addr: SocketAddr, ) -> SocketAddr { if row.contains_column("peer") { // this can only happen when a misconfigured local node thinks it's also a peer broadcast_rpc_address.unwrap_or(control_addr) } else { // Don't rely on system.local.rpc_address for the control node, because it mistakenly // reports the normal RPC address instead of the broadcast one (CASSANDRA-11181). We // already know the endpoint anyway since we've just used it to query. 
control_addr } } fn broadcast_rpc_address_from_row(row: &Row, control_addr: SocketAddr) -> Option { // in system.peers or system.local let rpc_address: Result> = row.by_name("rpc_address").or_else(|_| { // in system.peers_v2 (Cassandra >= 4.0) row.by_name("native_address") }); let rpc_address = match rpc_address { Ok(Some(rpc_address)) => rpc_address, Ok(None) => return None, Err(error) => { // this could only happen if system tables are corrupted, but handle gracefully warn!(%error, "Error getting rpc address."); return None; } }; // system.local for Cassandra >= 4.0 let rpc_port: i32 = row .get_by_name("rpc_port") .or_else(|_| { // system.peers_v2 row.get_by_name("native_port") }) // use the default port if no port information was found in the row .map(|port| port.unwrap_or_else(|| control_addr.port() as i32)) .unwrap_or_else(|_| control_addr.port() as i32); let rpc_address = SocketAddr::new(rpc_address, rpc_port as u16); // if the peer is actually the control node, ignore that peer as it is likely a // misconfiguration problem if rpc_address == control_addr && row.contains_column("peer") { warn!( node = %rpc_address, control = %control_addr, "Control node has itself as a peer, thus will be ignored. This is likely due to a \ misconfiguration; please verify your rpc_address configuration in cassandra.yaml \ on all nodes in your cluster." 
); None } else { Some(rpc_address) } } fn is_peer_row_valid(row: &Row) -> bool { let has_peers_rpc_address = !row.is_empty_by_name("rpc_address"); let has_peers_v_2_rpc_address = !row.is_empty_by_name("native_address") && !row.is_empty_by_name("native_port"); let has_rpc_address = has_peers_rpc_address || has_peers_v_2_rpc_address; has_rpc_address && !row.is_empty_by_name("host_id") && !row.is_empty_by_name("data_center") && !row.is_empty_by_name("rack") && !row.is_empty_by_name("tokens") && !row.is_empty_by_name("schema_version") } async fn fetch_control_connection_info( control_transport: &T, control_addr: &SocketAddr, version: Version, beta_protocol: bool, ) -> Result { send_query( "SELECT * FROM system.local", control_transport, version, beta_protocol, ) .await? .and_then(|mut rows| rows.pop()) .ok_or_else(|| format!("Node {control_addr} failed to return info about itself!").into()) } fn build_keyspace(row: &Row) -> Result<(String, KeyspaceMetadata)> { let keyspace_name = row.get_r_by_name("keyspace_name")?; let replication: String = row.get_r_by_name("replication")?; let replication: JsonValue = serde_json::from_str(&replication).map_err(|error| { Error::General(format!( "Error parsing replication for {keyspace_name}: {error}" )) })?; let replication_strategy = match replication { JsonValue::Object(properties) => build_replication_strategy(properties)?, _ => { return Err(Error::InvalidReplicationFormat { keyspace: keyspace_name, }) } }; Ok((keyspace_name, KeyspaceMetadata::new(replication_strategy))) } fn build_replication_strategy( mut properties: Map, ) -> Result { match properties.remove("class") { Some(JsonValue::String(class)) => Ok(match class.as_str() { "org.apache.cassandra.locator.SimpleStrategy" | "SimpleStrategy" => { ReplicationStrategy::SimpleStrategy { replication_factor: extract_replication_factor( properties.get("replication_factor"), )?, } } "org.apache.cassandra.locator.NetworkTopologyStrategy" | "NetworkTopologyStrategy" => { 
ReplicationStrategy::NetworkTopologyStrategy { datacenter_replication_factor: extract_datacenter_replication_factor( properties, )?, } } _ => ReplicationStrategy::Other, }), _ => Err("Missing replication strategy class!".into()), } } fn extract_datacenter_replication_factor( properties: Map, ) -> Result> { properties .into_iter() .map(|(key, replication_factor)| { extract_replication_factor(Some(&replication_factor)) .map(move |replication_factor| (key, replication_factor)) }) .try_collect() } fn extract_replication_factor(value: Option<&JsonValue>) -> Result { match value { Some(JsonValue::String(replication_factor)) => { let result = if let Some(slash) = replication_factor.find('/') { usize::from_str(&replication_factor[..slash]) } else { usize::from_str(replication_factor) }; result.map_err(|error| { format!("Failed to parse ('{replication_factor}'): {error}").into() }) } _ => Err("Missing replication factor!".into()), } } pub(crate) struct ClusterMetadataManager< T: CdrsTransport + 'static, CM: ConnectionManager + 'static, > { metadata: ArcSwap>, contact_points: Vec>>, connection_pool_factory: Arc>, did_initial_refresh: AtomicBool, is_schema_v2: AtomicBool, session_context: Arc>, node_distance_evaluator: Box, version: Version, beta_protocol: bool, } impl + 'static> ClusterMetadataManager { pub(crate) fn new( contact_points: Vec>>, connection_pool_factory: Arc>, session_context: Arc>, node_distance_evaluator: Box, version: Version, beta_protocol: bool, ) -> Self { ClusterMetadataManager { metadata: ArcSwap::from_pointee(ClusterMetadata::default()), contact_points, connection_pool_factory, did_initial_refresh: AtomicBool::new(false), is_schema_v2: AtomicBool::new(true), session_context, node_distance_evaluator, version, beta_protocol, } } pub(crate) fn listen_to_events(self: &Arc, mut event_receiver: Receiver) { let cmm = Arc::downgrade(self); tokio::spawn(async move { loop { let event = event_receiver.recv().await; match event { Ok(event) => { if let Some(cmm) = 
cmm.upgrade() { cmm.process_event(event).await; } else { break; } } Err(RecvError::Lagged(n)) => { warn!("Skipped {} events.", n); } Err(RecvError::Closed) => break, } } }); } async fn process_event(&self, event: ServerEvent) { debug!(?event); match event { ServerEvent::TopologyChange(event) => self.process_topology_event(event).await, ServerEvent::StatusChange(event) => self.process_status_event(event).await, ServerEvent::SchemaChange(event) => self.process_schema_event(event).await, _ => warn!(?event, "Unrecognized event."), } } async fn process_schema_event(&self, event: SchemaChange) { if let SchemaChangeOptions::Keyspace(keyspace) = &event.options { match event.change_type { SchemaChangeType::Created | SchemaChangeType::Updated => { self.refresh_keyspace(keyspace).await } SchemaChangeType::Dropped => { self.remove_keyspace(keyspace); } _ => warn!(?event, "Unrecognized schema event."), } } } async fn process_topology_event(&self, event: TopologyChange) { match event.change_type { TopologyChangeType::NewNode => { // For NewNode we need an async metadata fetch (system.peers // lookup) before we can decide what to install, so we have to // start by snapshotting the current cluster state. The // subsequent add_new_node call performs its own atomic swap. let metadata = self.metadata.load().clone(); if metadata.has_node_by_rpc_address(event.addr) { debug!( broadcast_rpc_address = %event.addr, "Trying to add already existing node - ignoring." ); } else { self.add_new_node(event.addr, NodeState::Unknown, metadata) .await; } } TopologyChangeType::RemovedNode => { // RemovedNode is a pure transform - use rcu so a concurrent // metadata update (e.g. another event landing at the same // time) isn't lost between load() and store(). 
debug!(broadcast_rpc_address = %event.addr, "Removing node from cluster (if present)."); self.metadata .rcu(|metadata| Arc::new(metadata.clone_without_node(event.addr))); } _ => warn!(?event, "Unrecognized topology change type."), } } async fn process_status_event(&self, event: StatusChange) { match event.change_type { StatusChangeType::Up => { // We need an async fallback (add_new_node) when the node is // unknown to us, so the existing-node update is split out and // performed via rcu to avoid losing concurrent metadata // changes between load and store. let metadata_snapshot = self.metadata.load().clone(); if metadata_snapshot .find_node_by_rpc_address(event.addr) .is_some() { self.metadata.rcu(|metadata| { // Re-check inside the closure: a concurrent update may // have already moved the node to Up, in which case // leave the metadata untouched and return the same // Arc so rcu skips the swap. match metadata.find_node_by_rpc_address(event.addr) { Some(node) if node.state() != NodeState::Up => { debug!(?node, "Setting existing node state to up."); let new_node = node.clone_with_node_state(NodeState::Up); Arc::new(metadata.clone_with_node(new_node)) } _ => metadata.clone(), } }); } else { self.add_new_node(event.addr, NodeState::Up, metadata_snapshot) .await; } } StatusChangeType::Down => { // Capture the connection-state warning outside the rcu closure // because is_any_connection_up is async and rcu's closure is // synchronous (and may run multiple times under contention). if let Some(node) = self.metadata.load().find_node_by_rpc_address(event.addr) { let state = node.state(); if state != NodeState::Down && state != NodeState::ForcedDown && node.is_any_connection_up().await { warn!( ?node, "Marking node as down while there are established connections." ); } } else { debug!(broadcast_rpc_address = %event.addr, "Unknown node down."); return; } // Now atomically transition the node to Down via rcu. 
The // closure re-checks the state because an interleaved update // could have already moved it. self.metadata.rcu( |metadata| match metadata.find_node_by_rpc_address(event.addr) { Some(node) if node.state() != NodeState::Down && node.state() != NodeState::ForcedDown => { debug!(?node, "Setting existing node state to down."); let new_node = node.clone_with_node_state(NodeState::Down); Arc::new(metadata.clone_with_node(new_node)) } _ => metadata.clone(), }, ); } _ => warn!(?event, "Unrecognized status event."), } } fn remove_keyspace(&self, keyspace: &str) { // Use rcu so that a concurrent metadata update (e.g. another schema // event arriving in parallel) cannot be silently overwritten between // load() and store(). The closure may run more than once if the // ArcSwap loses a CAS race. self.metadata .rcu(|metadata| Arc::new(metadata.clone_without_keyspace(keyspace))); } async fn refresh_keyspace(&self, keyspace: &str) { if let Err(error) = self.try_refresh_keyspace(keyspace).await { error!(?error, %keyspace, "Error refreshing keyspace!"); } } async fn try_refresh_keyspace(&self, keyspace: &str) -> Result<()> { debug!(%keyspace, "Refreshing keyspace."); let control_transport = self.control_transport()?; send_query_with_values( "SELECT keyspace_name, toJson(replication) AS replication FROM system_schema.keyspaces WHERE keyspace_name = ?", QueryValues::SimpleValues(vec![keyspace.into()]), control_transport.as_ref(), self.version, self.beta_protocol, ) .await .map(|rows| { rows.and_then(|mut rows| rows.pop()) }) .and_then(|row| { match row { Some(row) => { let (keyspace_name, keyspace) = build_keyspace(&row)?; // Use rcu so that a concurrent metadata mutation (e.g. // another schema event landing in parallel, or the // status/topology event handlers) cannot be silently // overwritten between load() and store(). The closure // may run more than once if the ArcSwap loses a CAS // race, so we clone the freshly-built keyspace data // each iteration. 
self.metadata.rcu(|metadata| { Arc::new( metadata .clone_with_keyspace(keyspace_name.clone(), keyspace.clone()), ) }); } None => { warn!(%keyspace, "Keyspace to refresh disappeared."); self.remove_keyspace(keyspace); } } Ok(()) }) } async fn add_new_node( &self, broadcast_rpc_address: SocketAddr, state: NodeState, metadata: Arc>, ) { debug!(%broadcast_rpc_address, %state, "Adding new node to metadata."); let new_node_info = self.find_new_node_info(broadcast_rpc_address).await; match new_node_info { Ok(Some(new_node_info)) => { self.metadata.store(Arc::new(add_new_node( new_node_info, metadata.as_ref(), &self.connection_pool_factory, state, ))); } Ok(None) => { warn!(%broadcast_rpc_address, "Cannot find new node info. Ignoring new node."); } Err(error) => { error!(%error, %broadcast_rpc_address, "Error finding new node info!"); } } } async fn find_new_node_info( &self, broadcast_rpc_address: SocketAddr, ) -> Result> { debug!(%broadcast_rpc_address, "Fetching info about a new node."); let control_transport = self.control_transport()?; let control_addr = control_transport.address(); // in the awkward case we have the control connection node up, it won't be in peers if broadcast_rpc_address == control_addr { let local_info = fetch_control_connection_info( control_transport.as_ref(), &control_addr, self.version, self.beta_protocol, ) .await?; return build_node_info(&local_info, broadcast_rpc_address).map(Some); } send_query( &format!("SELECT * FROM {}", self.peer_table_name()), control_transport.as_ref(), self.version, self.beta_protocol, ) .await .map(|peers| { peers.and_then(|peers| { find_in_peers(&peers, broadcast_rpc_address, control_addr).transpose() }) })? 
.transpose() } #[inline] fn control_transport(&self) -> Result> { self.session_context .control_connection_transport .load() .clone() .ok_or_else(|| "Cannot fetch information without a control connection!".into()) } #[inline] fn peer_table_name(&self) -> &'static str { if self.is_schema_v2.load(Ordering::Relaxed) { "system.peers_v2" } else { "system.peers" } } #[inline] pub(crate) fn metadata(&self) -> Arc> { self.metadata.load().clone() } #[inline] pub(crate) fn find_node_by_rpc_address( &self, broadcast_rpc_address: SocketAddr, ) -> Option>> { self.metadata .load() .find_node_by_rpc_address(broadcast_rpc_address) } // Refreshes stored metadata. Note: it is expected to be called by the control connection. pub(crate) async fn refresh_metadata(&self, full_refresh: bool) -> Result<()> { let (node_infos, keyspaces) = tokio::try_join!(self.refresh_node_infos(), self.refresh_keyspaces())?; if self .did_initial_refresh .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) .is_ok() { self.metadata.store(Arc::new(build_initial_metadata( node_infos, keyspaces, &self.contact_points, &self.connection_pool_factory, self.node_distance_evaluator.as_ref(), ))); } else { self.metadata.rcu(move |old_metadata| { if full_refresh { build_initial_metadata( node_infos.clone(), keyspaces.clone(), &self.contact_points, &self.connection_pool_factory, self.node_distance_evaluator.as_ref(), ) } else { refresh_metadata( &node_infos, old_metadata.as_ref(), &self.connection_pool_factory, self.node_distance_evaluator.as_ref(), ) } }); }; Ok(()) } async fn refresh_keyspaces(&self) -> Result> { let control_transport = self.control_transport()?; send_query( "SELECT keyspace_name, toJson(replication) AS replication FROM system_schema.keyspaces", control_transport.as_ref(), self.version, self.beta_protocol, ) .await .and_then(|rows| { rows.map(|rows| rows.iter().map(build_keyspace).try_collect()) .transpose() }) .map(|keyspaces| keyspaces.unwrap_or_default()) } async fn 
refresh_node_infos(&self) -> Result> { let control_transport = self.control_transport()?; let control_addr = control_transport.address(); let local = fetch_control_connection_info( control_transport.as_ref(), &control_addr, self.version, self.beta_protocol, ) .await?; if !is_peer_row_valid(&local) { return Err("Invalid local row info!".into()); } let local_broadcast_rpc_address = broadcast_rpc_address_from_row(&local, control_addr); let local_broadcast_rpc_address = build_node_broadcast_rpc_address(&local, local_broadcast_rpc_address, control_addr); let mut node_infos = vec![build_node_info(&local, local_broadcast_rpc_address)?]; let peers = self.query_peers(control_transport.as_ref()).await?; if let Some(peers) = peers { node_infos.reserve(peers.len()); node_infos = peers .iter() .filter_map(|row| { if !is_peer_row_valid(row) { return None; } broadcast_rpc_address_from_row(row, control_addr) .map(|broadcast_rpc_address| build_node_info(row, broadcast_rpc_address)) }) .fold_ok(node_infos, |mut node_infos, node_info| { node_infos.push(node_info); node_infos })?; } Ok(node_infos) } async fn query_peers(&self, transport: &T) -> Result>> { if !self.is_schema_v2.load(Ordering::Relaxed) { // we've already checked for v2 before, so proceed with legacy peers return self.query_legacy_peers(transport).await; } let peers_v2_result = send_query( "SELECT * FROM system.peers_v2", transport, self.version, self.beta_protocol, ) .await; match peers_v2_result { Ok(result) => Ok(result), // peers_v2 does not exist Err(Error::Server { body: ErrorBody { ty: ErrorType::Invalid, .. }, .. }) => { self.is_schema_v2.store(false, Ordering::Relaxed); self.query_legacy_peers(transport).await } Err(Error::Server { body: ErrorBody { ty: ErrorType::Server, ref message, }, .. 
}) if message.contains("Unknown keyspace/cf pair (system.peers_v2)") => { self.is_schema_v2.store(false, Ordering::Relaxed); self.query_legacy_peers(transport).await } Err(error) => Err(error), } } #[inline] async fn query_legacy_peers(&self, transport: &T) -> Result>> { send_query( "SELECT * FROM system.peers", transport, self.version, self.beta_protocol, ) .await } } ================================================ FILE: cdrs-tokio/src/cluster/config_proxy.rs ================================================ #[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)] pub(crate) struct HttpBasicAuth { pub(crate) username: String, pub(crate) password: String, } #[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)] pub struct HttpProxyConfig { pub(crate) address: String, pub(crate) basic_auth: Option, } #[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)] pub struct HttpProxyConfigBuilder { address: String, basic_auth: Option, } impl HttpProxyConfigBuilder { /// Creates a new proxy configuration builder with the given proxy address. pub fn new(address: String) -> Self { Self { address, basic_auth: None, } } /// Adds HTTP basic authentication credentials. pub fn with_basic_auth(mut self, username: String, password: String) -> Self { self.basic_auth = Some(HttpBasicAuth { password, username }); self } /// Builds the resulting configuration.
pub fn build(self) -> HttpProxyConfig { HttpProxyConfig { basic_auth: self.basic_auth, address: self.address, } } } ================================================ FILE: cdrs-tokio/src/cluster/config_rustls.rs ================================================ use cassandra_protocol::authenticators::{NoneAuthenticatorProvider, SaslAuthenticatorProvider}; use cassandra_protocol::error::Result; use cassandra_protocol::frame::Version; use derivative::Derivative; use std::net::SocketAddr; use std::sync::Arc; use tokio_rustls::rustls::{pki_types::ServerName, ClientConfig}; #[cfg(feature = "http-proxy")] use crate::cluster::HttpProxyConfig; use crate::cluster::NodeAddress; /// Single node TLS connection config. See [NodeRustlsConfigBuilder]. #[derive(Derivative, Clone)] #[derivative(Debug)] pub struct NodeRustlsConfig { pub(crate) contact_points: Vec, pub(crate) dns_name: ServerName<'static>, #[derivative(Debug = "ignore")] pub(crate) authenticator_provider: Arc, pub(crate) config: Arc, pub(crate) version: Version, pub(crate) beta_protocol: bool, #[cfg(feature = "http-proxy")] pub(crate) http_proxy: Option, } /// Builder that helps configure a TLS connection to a node. #[derive(Derivative, Clone)] #[derivative(Debug)] pub struct NodeRustlsConfigBuilder { addrs: Vec, dns_name: ServerName<'static>, #[derivative(Debug = "ignore")] authenticator_provider: Arc, config: Arc, version: Version, beta_protocol: bool, #[cfg(feature = "http-proxy")] http_proxy: Option, } impl NodeRustlsConfigBuilder { pub fn new(dns_name: ServerName<'static>, config: Arc) -> Self { NodeRustlsConfigBuilder { addrs: vec![], dns_name, authenticator_provider: Arc::new(NoneAuthenticatorProvider), config, version: Version::V4, beta_protocol: false, #[cfg(feature = "http-proxy")] http_proxy: None, } } /// Sets a new authenticator provider.
#[must_use] pub fn with_authenticator_provider( mut self, authenticator_provider: Arc, ) -> Self { self.authenticator_provider = authenticator_provider; self } /// Adds initial node address (a contact point). Contact points are considered local to the /// driver until a topology refresh occurs. #[must_use] pub fn with_contact_point(mut self, addr: NodeAddress) -> Self { self.addrs.push(addr); self } /// Adds multiple initial node addresses (contact points). #[must_use] pub fn with_contact_points(mut self, addr: Vec) -> Self { self.addrs.extend(addr); self } /// Sets the Cassandra protocol version. #[must_use] pub fn with_version(mut self, version: Version) -> Self { self.version = version; self } /// Sets the beta protocol usage flag. #[must_use] pub fn with_beta_protocol(mut self, beta_protocol: bool) -> Self { self.beta_protocol = beta_protocol; self } /// Adds an HTTP proxy configuration. #[cfg(feature = "http-proxy")] #[must_use] pub fn with_http_proxy(mut self, config: HttpProxyConfig) -> Self { self.http_proxy = Some(config); self } /// Finalizes the building process, resolving all contact point addresses. pub async fn build(self) -> Result { // replace with map() when async lambdas become available let mut contact_points = Vec::with_capacity(self.addrs.len()); for contact_point in self.addrs { contact_points.append(&mut contact_point.resolve_address().await?); } Ok(NodeRustlsConfig { contact_points, dns_name: self.dns_name, authenticator_provider: self.authenticator_provider, config: self.config, version: self.version, beta_protocol: self.beta_protocol, #[cfg(feature = "http-proxy")] http_proxy: self.http_proxy, }) } } ================================================ FILE: cdrs-tokio/src/cluster/config_tcp.rs ================================================ use cassandra_protocol::authenticators::{NoneAuthenticatorProvider, SaslAuthenticatorProvider}; use cassandra_protocol::error::Result; use cassandra_protocol::frame::Version; use derivative::Derivative; use std::net::SocketAddr; use std::sync::Arc; #[cfg(feature = "http-proxy")] use
crate::cluster::HttpProxyConfig; use crate::cluster::NodeAddress; /// Single node TCP connection config. See [NodeTcpConfigBuilder]. #[derive(Derivative, Clone)] #[derivative(Debug)] pub struct NodeTcpConfig { pub(crate) contact_points: Vec, #[derivative(Debug = "ignore")] pub(crate) authenticator_provider: Arc, pub(crate) version: Version, pub(crate) beta_protocol: bool, #[cfg(feature = "http-proxy")] pub(crate) http_proxy: Option, } /// Builder that helps configure a TCP connection to a node. #[derive(Derivative, Clone)] #[derivative(Debug)] pub struct NodeTcpConfigBuilder { addrs: Vec, #[derivative(Debug = "ignore")] authenticator_provider: Arc, version: Version, beta_protocol: bool, #[cfg(feature = "http-proxy")] http_proxy: Option, } impl Default for NodeTcpConfigBuilder { fn default() -> Self { NodeTcpConfigBuilder { addrs: vec![], authenticator_provider: Arc::new(NoneAuthenticatorProvider), version: Version::V4, beta_protocol: false, #[cfg(feature = "http-proxy")] http_proxy: None, } } } impl NodeTcpConfigBuilder { pub fn new() -> NodeTcpConfigBuilder { Default::default() } /// Sets a new authenticator provider. #[must_use] pub fn with_authenticator_provider( mut self, authenticator_provider: Arc, ) -> Self { self.authenticator_provider = authenticator_provider; self } /// Adds initial node address (a contact point). Contact points are considered local to the /// driver until a topology refresh occurs.
#[must_use] pub fn with_contact_point(mut self, addr: NodeAddress) -> Self { self.addrs.push(addr); self } /// Adds initial node addresses #[must_use] pub fn with_contact_points(mut self, addr: Vec) -> Self { self.addrs.extend(addr); self } /// Sets the Cassandra protocol version #[must_use] pub fn with_version(mut self, version: Version) -> Self { self.version = version; self } /// Sets beta protocol usage flag #[must_use] pub fn with_beta_protocol(mut self, beta_protocol: bool) -> Self { self.beta_protocol = beta_protocol; self } /// Adds HTTP proxy configuration #[cfg(feature = "http-proxy")] #[must_use] pub fn with_http_proxy(mut self, config: HttpProxyConfig) -> Self { self.http_proxy = Some(config); self } /// Finalizes the building process pub async fn build(self) -> Result { // replace with map() when async lambdas become available let mut contact_points = Vec::with_capacity(self.addrs.len()); for contact_point in self.addrs { contact_points.append(&mut contact_point.resolve_address().await?); } Ok(NodeTcpConfig { contact_points, authenticator_provider: self.authenticator_provider, version: self.version, beta_protocol: self.beta_protocol, #[cfg(feature = "http-proxy")] http_proxy: self.http_proxy, }) } } ================================================ FILE: cdrs-tokio/src/cluster/connection_manager.rs ================================================ use std::io; use std::net::SocketAddr; use tokio::sync::mpsc::Sender; #[cfg(test)] use mockall::*; use crate::cluster::KeyspaceHolder; use crate::future::BoxFuture; use crate::transport::CdrsTransport; use cassandra_protocol::authenticators::SaslAuthenticatorProvider; use cassandra_protocol::compression::Compression; use cassandra_protocol::error::{Error, Result}; use cassandra_protocol::frame::message_response::ResponseBody; use cassandra_protocol::frame::{Envelope, Opcode, Version}; use cassandra_protocol::query::utils::quote; /// Manages establishing connections to 
nodes.
pub trait ConnectionManager: Send + Sync { /// Tries to establish a new, ready-to-use connection with optional server event and error /// handlers. fn connection( &self, event_handler: Option>, error_handler: Option>, addr: SocketAddr, ) -> BoxFuture<'_, Result>; } #[cfg(test)] mock! { pub ConnectionManager { } #[allow(dead_code)] impl ConnectionManager for ConnectionManager { fn connection<'a>( &'a self, event_handler: Option>, error_handler: Option>, addr: SocketAddr, ) -> BoxFuture<'a, Result>; } } /// Establishes a Cassandra connection with the given authentication, last used keyspace and compression. pub async fn startup< T: CdrsTransport + 'static, A: SaslAuthenticatorProvider + Send + Sync + ?Sized + 'static, >( transport: &T, authenticator_provider: &A, keyspace_holder: &KeyspaceHolder, compression: Compression, version: Version, ) -> Result<()> { let startup_envelope = Envelope::new_req_startup(compression.as_str().map(String::from), version); let start_response = match transport.write_envelope(&startup_envelope, true).await { Ok(response) => Ok(response), Err(Error::Server { body, .. }) if body.is_bad_protocol() => { Err(Error::InvalidProtocol(transport.address())) } Err(error) => Err(error), }?; if start_response.opcode == Opcode::Ready { return set_keyspace(transport, keyspace_holder, version).await; } if start_response.opcode == Opcode::Authenticate { let body = start_response.response_body()?; let authenticator = body.authenticator() .ok_or_else(|| Error::General("Cassandra server indicated that it needs authentication, but the auth scheme was missing from the response body".into()))?; // This creates a new scope (avoiding a clone) in which we check whether: // 1. an authenticator has been passed in by the client; if not, send an error back // 2. the authenticator provided by the client matches the `auth_scheme` presented by // the server; if not, send an error back // 3.
if it falls through, the preliminary conditions hold authenticator_provider .name() .ok_or_else(|| Error::General("No authenticator was provided".to_string())) .and_then(|auth| { if authenticator != auth { let io_err = io::Error::new( io::ErrorKind::NotFound, format!( "Unsupported authenticator type: the server requested {authenticator:?}, but {auth} is supported." ), ); return Err(Error::Io(io_err)); } Ok(()) })?; let authenticator = authenticator_provider.create_authenticator(); let response = authenticator.initial_response(); let mut envelope = transport .write_envelope(&Envelope::new_req_auth_response(response, version), false) .await?; loop { match envelope.response_body()? { ResponseBody::AuthChallenge(challenge) => { let response = authenticator.evaluate_challenge(challenge.data)?; envelope = transport .write_envelope(&Envelope::new_req_auth_response(response, version), false) .await?; } ResponseBody::AuthSuccess(success) => { authenticator.handle_success(success.data)?; break; } _ => return Err(Error::UnexpectedAuthResponse(envelope.opcode)), } } return set_keyspace(transport, keyspace_holder, version).await; } Err(Error::UnexpectedStartupResponse(start_response.opcode)) } async fn set_keyspace( transport: &T, keyspace_holder: &KeyspaceHolder, version: Version, ) -> Result<()> { if let Some(current_keyspace) = keyspace_holder.current_keyspace() { let use_envelope = Envelope::new_req_query( format!("USE {}", quote(current_keyspace.as_ref())), Default::default(), None, false, None, None, None, None, None, None, Default::default(), version, ); transport .write_envelope(&use_envelope, false) .await .map(|_| ()) } else { Ok(()) } } ================================================ FILE: cdrs-tokio/src/cluster/connection_pool.rs ================================================ use atomic::Atomic; use bytemuck::NoUninit; use cassandra_protocol::frame::{Envelope, Version}; use cassandra_protocol::query::utils::quote; use derive_more::Display; use
futures::future::join_all; use itertools::Itertools; use std::marker::PhantomData; use std::net::SocketAddr; use std::ops::{Deref, DerefMut}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Weak}; use std::time::Duration; use tokio::sync::watch::Receiver; use tokio::sync::{mpsc, RwLock}; use tokio::time::{interval_at, sleep, Instant, Interval, MissedTickBehavior}; use tracing::*; use crate::cluster::topology::{Node, NodeDistance, NodeState}; use crate::cluster::ConnectionManager; use crate::error::{Error, Result as CdrsResult}; use crate::retry::{ReconnectionPolicy, ReconnectionSchedule}; use crate::transport::CdrsTransport; #[derive(Copy, Clone, PartialEq, Eq, Display, NoUninit)] #[repr(u8)] enum ReconnectionState { NotRunning, InProgress, Disabled, } async fn new_connection>( connection_manager: &CM, broadcast_rpc_address: SocketAddr, timeout: Option, error_handler: mpsc::Sender, ) -> CdrsResult { if let Some(timeout) = timeout { tokio::time::timeout( timeout, connection_manager.connection(None, Some(error_handler), broadcast_rpc_address), ) .await .map_err(|_| { Error::Timeout(format!( "Timeout waiting for connection to: {broadcast_rpc_address}" )) }) .and_then(|result| result) } else { connection_manager .connection(None, Some(error_handler), broadcast_rpc_address) .await } } /// Configuration for node connection pools. By default, both local and remote pools hold a /// single connection per node, there is no connect timeout, and a heartbeat is sent every 30 /// seconds. If the distance to a given node is unknown, it is treated as remote. See /// [ConnectionPoolConfigBuilder]. #[derive(Clone, Copy, Debug)] pub struct ConnectionPoolConfig { local_size: usize, remote_size: usize, connect_timeout: Option, heartbeat_interval: Duration, } impl Default for ConnectionPoolConfig { fn default() -> Self { ConnectionPoolConfig { local_size: 1, remote_size: 1, connect_timeout: None, heartbeat_interval: Duration::from_secs(30), } } } /// A 
builder for [ConnectionPoolConfig].
#[derive(Default, Clone, Debug)] pub struct ConnectionPoolConfigBuilder { config: ConnectionPoolConfig, } impl ConnectionPoolConfigBuilder { pub fn new() -> Self { Default::default() } /// Sets local node pool size. #[must_use] pub fn with_local_size(mut self, local_size: usize) -> Self { self.config.local_size = local_size; self } /// Sets remote node pool size. #[must_use] pub fn with_remote_size(mut self, remote_size: usize) -> Self { self.config.remote_size = remote_size; self } /// Sets new connection timeout. #[must_use] pub fn with_connect_timeout(mut self, connect_timeout: Option) -> Self { self.config.connect_timeout = connect_timeout; self } /// Sets new heartbeat interval. #[must_use] pub fn with_heartbeat_interval(mut self, heartbeat_interval: Duration) -> Self { self.config.heartbeat_interval = heartbeat_interval; self } /// Build the resulting config. #[must_use] pub fn build(self) -> ConnectionPoolConfig { self.config } } pub(crate) struct ConnectionPoolFactory< T: CdrsTransport + 'static, CM: ConnectionManager + 'static, > { config: ConnectionPoolConfig, version: Version, connection_manager: Arc, keyspace_receiver: Receiver>, reconnection_policy: Arc, _transport: PhantomData, } impl + 'static> ConnectionPoolFactory { pub(crate) fn new( config: ConnectionPoolConfig, version: Version, connection_manager: CM, keyspace_receiver: Receiver>, reconnection_policy: Arc, ) -> Self { ConnectionPoolFactory { config, version, connection_manager: Arc::new(connection_manager), keyspace_receiver, reconnection_policy, _transport: Default::default(), } } #[inline] pub(crate) fn connection_manager(&self) -> &CM { self.connection_manager.as_ref() } pub(crate) async fn create( &self, node_distance: NodeDistance, broadcast_rpc_address: SocketAddr, node: Weak>, ) -> CdrsResult>> { let (error_sender, error_receiver) = mpsc::channel(if node_distance == NodeDistance::Local { self.config.local_size } else { self.config.remote_size }); // Clone the keyspace receiver BEFORE 
opening any connections. The // receiver remembers the watch's current value at clone time as // "seen". If we cloned it after pool initialisation, an update that // landed in between would be visible at clone time and changed() // would never fire for it - leaving the freshly-opened connections // stuck on the previous keyspace. let mut keyspace_receiver = self.keyspace_receiver.clone(); let pool = Arc::new( ConnectionPool::new( &self.connection_manager, broadcast_rpc_address, node_distance, self.config, error_sender, ) .await?, ); let weak_pool = Arc::downgrade(&pool); Self::monitor_connections( error_receiver, weak_pool.clone(), node.clone(), self.reconnection_policy.clone(), ); Self::start_heartbeat( weak_pool.clone(), node, self.config.heartbeat_interval, self.version, ); let weak_pool_for_keyspace = weak_pool.clone(); let version = self.version; tokio::spawn(async move { while let Ok(()) = keyspace_receiver.changed().await { let keyspace = keyspace_receiver.borrow().clone(); // Try to upgrade the weak reference to a strong Arc let pool = match weak_pool_for_keyspace.upgrade() { Some(pool) => pool, None => { debug!("Pool dropped, exiting keyspace watcher task."); break; } }; if let Some(keyspace) = keyspace { let use_envelope = Arc::new(Envelope::new_req_query( format!("USE {}", quote(&keyspace)), Default::default(), None, false, None, None, None, None, None, None, Default::default(), version, )); let pool_guard = pool.pool.read().await; join_all(pool_guard.iter() .filter(|connection| !connection.is_broken()) .map(|connection| { let use_envelope = use_envelope.clone(); async move { if let Err(error) = connection.write_envelope(use_envelope.as_ref(), false).await { error!(%error, ?broadcast_rpc_address, "Error setting keyspace for connection!"); } } })).await; } } }); Ok(pool) } fn start_heartbeat( pool: Weak>, node: Weak>, heartbeat_interval: Duration, version: Version, ) { let mut interval = create_heartbeat_interval(Instant::now(), heartbeat_interval);
tokio::spawn(async move { loop { interval.tick().await; if let Some(node) = node.upgrade() { let broadcast_rpc_address = node.broadcast_address(); let state = node.state(); if state == NodeState::ForcedDown { debug!( ?broadcast_rpc_address, "Stopping heartbeat due to node being forced down." ); break; } if state == NodeState::Up { if let Some(pool) = pool.upgrade() { let envelope = Envelope::new_req_options(version); let pool = pool.pool.read().await; for connection in pool.deref() { if let Err(error) = connection.write_envelope(&envelope, false).await { warn!(?broadcast_rpc_address, %error, "Error waiting for heartbeat response - the connection will probably go down."); } } } else { debug!( ?broadcast_rpc_address, "Stopping heartbeat due to pool being gone." ); break; } } } else { break; } } debug!("Stopped heartbeat."); }); } fn monitor_connections( mut receiver: mpsc::Receiver, pool: Weak>, node: Weak>, reconnection_policy: Arc, ) { tokio::spawn(async move { let reconnection_state = Arc::new(Atomic::new(ReconnectionState::NotRunning)); while receiver.recv().await.is_some() { if let Some(node) = node.upgrade() { let broadcast_rpc_address = node.broadcast_address(); if node.state() == NodeState::ForcedDown { debug!( ?broadcast_rpc_address, "Not starting reconnection for a forced down node." ); break; } { // check if the node is down (no active connections) if let Some(pool) = pool.upgrade() { if Self::are_all_connections_down(pool.deref()).await { debug!( ?broadcast_rpc_address, "All connections broken - marking node as down." 
); node.mark_down(); } } else { // the pool is gone - we're shutting down break; } } // when one connection goes down, all of them will most likely go down, so we need // to protect against many reconnection attempts let state = reconnection_state.load(Ordering::Relaxed); if state != ReconnectionState::NotRunning { if state == ReconnectionState::Disabled { break; } continue; } reconnection_state.store(ReconnectionState::InProgress, Ordering::Relaxed); warn!( ?broadcast_rpc_address, "Connection down. Starting reconnection." ); let reconnection_schedule = reconnection_policy.new_node_schedule(); let reconnecting = reconnection_state.clone(); let pool = pool.clone(); let node = Arc::downgrade(&node); tokio::spawn(async move { let new_state = Self::run_reconnection_loop(reconnection_schedule, pool.clone()).await; reconnecting.store(new_state, Ordering::Relaxed); debug!(?broadcast_rpc_address, %new_state, "Reconnection loop stopped."); if new_state == ReconnectionState::Disabled { if let Some(node) = node.upgrade() { warn!( ?broadcast_rpc_address, "Forcing node down, since no connection can be established." ); node.force_down(); } } else if new_state == ReconnectionState::NotRunning { if let Some(node) = node.upgrade() { debug!(?broadcast_rpc_address, "All connections reestablished."); node.mark_up(); } else { debug!( ?broadcast_rpc_address, "Node is discarded during reconnection." ); } } else if let Some(pool) = pool.upgrade() { if pool.is_any_connection_up().await { if let Some(node) = node.upgrade() { debug!( ?broadcast_rpc_address, "Marking node as up - some connections are established." ); node.mark_up(); } } } else if let Some(node) = node.upgrade() { debug!( ?broadcast_rpc_address, "Pool gone while in reconnection loop." 
); node.force_down(); } }); } else { warn!("Node not found when trying to reconnect!"); break; }; } debug!("Pool monitoring stopped."); }); } async fn are_all_connections_down(pool: &ConnectionPool) -> bool { let connections = pool.pool.read().await; for connection in connections.deref() { if !connection.is_broken() { return false; } } true } async fn run_reconnection_loop( mut reconnection_schedule: Box, pool: Weak>, ) -> ReconnectionState { while let Some(delay) = reconnection_schedule.next_delay() { sleep(delay).await; let pool = match pool.upgrade() { None => return ReconnectionState::Disabled, // the pool might be gone Some(pool) => pool, }; match pool.reconnect_broken().await { Ok(all_reconnected) if all_reconnected => return ReconnectionState::NotRunning, Err(Error::InvalidProtocol(_)) => return ReconnectionState::Disabled, _ => {} } } // the policy doesn't want to reconnect to this node ReconnectionState::Disabled } } pub(crate) struct ConnectionPool> { connection_manager: Weak, broadcast_rpc_address: SocketAddr, config: ConnectionPoolConfig, pool: RwLock>>, desired_size: usize, current_index: AtomicUsize, error_sender: mpsc::Sender, } impl> ConnectionPool { async fn new( connection_manager: &Arc, broadcast_rpc_address: SocketAddr, node_distance: NodeDistance, config: ConnectionPoolConfig, error_sender: mpsc::Sender, ) -> CdrsResult { let desired_size = if node_distance == NodeDistance::Local { config.local_size } else { config.remote_size }; // initialize the pool let pool: Vec<_> = join_all((0..desired_size).map(|_| { new_connection( connection_manager.as_ref(), broadcast_rpc_address, config.connect_timeout, error_sender.clone(), ) })) .await .into_iter() .filter_map(|connection| match connection { Ok(connection) => Some(Ok(connection)), // propagate unrecoverable error Err(Error::InvalidProtocol(addr)) => Some(Err(Error::InvalidProtocol(addr))), // skip invalid connections which can be established later Err(_) => None, }) .map_ok(Arc::new) 
.try_collect()?; if pool.len() != desired_size { // some connections have failed, but can be brought back up, so trigger reconnection match error_sender.try_send(Error::General( "Not all pool connections could be established!".to_string(), )) { Ok(_) => debug!("Error handler notified!"), Err(e) => warn!("Error handler failed to notify: {e}"), } } Ok(ConnectionPool { connection_manager: Arc::downgrade(connection_manager), broadcast_rpc_address, config, pool: RwLock::new(pool), desired_size, current_index: AtomicUsize::new(0), error_sender, }) } pub(crate) async fn connection(&self) -> CdrsResult> { fn create_no_connections_error(broadcast_rpc_address: SocketAddr) -> Error { warn!(%broadcast_rpc_address, "All connections down to node."); Error::General(format!( "No active connections to: {}", broadcast_rpc_address )) } let pool = self.pool.read().await; let pool_len = pool.len(); if pool_len == 0 { return Err(create_no_connections_error(self.broadcast_rpc_address)); } let mut index = self.current_index.fetch_add(1, Ordering::Relaxed) % pool_len; let first_index = index; loop { let connection = &pool[index]; if !connection.is_broken() { return Ok(connection.clone()); } index = (index + 1) % pool_len; if index == first_index { // we've checked the whole pool and everything's down return Err(create_no_connections_error(self.broadcast_rpc_address)); } } } pub(crate) async fn is_any_connection_up(&self) -> bool { let connections = self.pool.read().await; for connection in connections.deref() { if !connection.is_broken() { return true; } } false } async fn reconnect_broken(&self) -> CdrsResult { if let Some(connection_manager) = self.connection_manager.upgrade() { let mut pool = self.pool.write().await; // 1. try to reconnect broken for connection in pool.deref_mut() { if connection.is_broken() { *connection = Arc::new( new_connection( connection_manager.as_ref(), self.broadcast_rpc_address, self.config.connect_timeout, self.error_sender.clone(), ) .await?, ); } } // 2. 
try to fill missing for _ in pool.len()..self.desired_size { pool.push(Arc::new( new_connection( connection_manager.as_ref(), self.broadcast_rpc_address, self.config.connect_timeout, self.error_sender.clone(), ) .await?, )); } // at this point either all connections are up, or some might have died in the meantime, // which will trigger a new reconnection Ok(true) } else { // connection manager is gone - we're probably dropping the session Ok(false) } } } /// Builds the [`Interval`] used by the heartbeat loop. The first tick fires /// `period` after `now`, then every `period`. We explicitly set the missed /// tick behavior to `Skip` so that if a heartbeat round takes longer than /// the configured period (e.g. a slow node, a stalled connection), the /// runtime does not pile up a burst of catch-up ticks the moment we return /// to `tick().await`. Without this, tokio's default behavior would keep /// firing immediately until it had "caught up", which on a healthy cluster /// just means a thundering herd of OPTIONS messages. 
fn create_heartbeat_interval(now: Instant, period: Duration) -> Interval { let mut interval = interval_at(now + period, period); interval.set_missed_tick_behavior(MissedTickBehavior::Skip); interval } #[cfg(test)] mod heartbeat_interval_tests { use super::*; #[tokio::test] async fn create_heartbeat_interval_skips_missed_ticks() { let interval = create_heartbeat_interval(Instant::now(), Duration::from_secs(30)); assert_eq!(interval.missed_tick_behavior(), MissedTickBehavior::Skip); } } ================================================ FILE: cdrs-tokio/src/cluster/control_connection.rs ================================================ use derive_more::Constructor; use std::sync::Arc; use std::time::Duration; use tokio::sync::broadcast::Sender; use tokio::sync::mpsc::{channel, Receiver}; use tokio::time::sleep; use tracing::*; use crate::cluster::topology::Node; use crate::cluster::{ClusterMetadataManager, ConnectionManager, SessionContext}; use crate::load_balancing::LoadBalancingStrategy; use crate::retry::{ReconnectionPolicy, ReconnectionSchedule}; use crate::transport::CdrsTransport; use cassandra_protocol::events::{ServerEvent, SimpleServerEvent}; use cassandra_protocol::frame::{Envelope, Version}; const DEFAULT_RECONNECT_DELAY: Duration = Duration::from_secs(10); const EVENT_CHANNEL_CAPACITY: usize = 32; #[derive(Constructor)] pub(crate) struct ControlConnection< T: CdrsTransport + 'static, CM: ConnectionManager + 'static, LB: LoadBalancingStrategy + Send + Sync, > { load_balancing: Arc, contact_points: Vec>>, reconnection_policy: Arc, cluster_metadata_manager: Arc>, event_sender: Sender, session_context: Arc>, version: Version, } impl< T: CdrsTransport, CM: ConnectionManager, LB: LoadBalancingStrategy + Send + Sync, > ControlConnection { pub(crate) async fn run(self, init_complete_sender: tokio::sync::oneshot::Sender<()>) { let (event_envelope_sender, event_envelope_receiver) = channel(EVENT_CHANNEL_CAPACITY); let (error_sender, mut error_receiver) = channel(1); 
Self::process_events(event_envelope_receiver, self.event_sender.clone()); let mut init_complete_sender = Some(init_complete_sender); 'listen: loop { let current_connection = self .session_context .control_connection_transport .load() .clone(); if let Some(current_connection) = current_connection { let register_envelope = Envelope::new_req_register( vec![ SimpleServerEvent::SchemaChange, SimpleServerEvent::StatusChange, SimpleServerEvent::TopologyChange, ], self.version, ); // in case of error, simply reconnect let result = current_connection .write_envelope(&register_envelope, false) .await; if let Some(sender) = init_complete_sender.take() { sender.send(()).ok(); } match result { Ok(_) => { let error = error_receiver.recv().await; match error { Some(error) => { // show info and try to reconnect warn!(%error, "Error in control connection! Trying to reconnect."); } None => { // shouldn't happen, since the connection is shared, but bail out // anyway break; } } } Err(error) => { error!(%error, "Error subscribing to events!
Trying to reconnect."); } } self.session_context .control_connection_transport .store(None); } else { debug!("Establishing new control connection..."); let mut schedule = self.reconnection_policy.new_node_schedule(); loop { let mut full_refresh = false; let mut plan = self .load_balancing .query_plan(None, self.cluster_metadata_manager.metadata().as_ref()); if plan.nodes.is_empty() { warn!("No nodes found for control connection!"); Self::wait_for_reconnection(&mut schedule).await; // when the whole cluster goes down, there's nothing to update LB state, so // we're left with contact points plan.nodes.clone_from(&self.contact_points); // it means that we need to build metadata from scratch once we are // reconnected full_refresh = true; } for node in &plan.nodes { if let Ok(connection) = node .new_connection( Some(event_envelope_sender.clone()), Some(error_sender.clone()), ) .await { debug!("Established new control connection."); self.session_context .control_connection_transport .store(Some(Arc::new(connection))); if let Err(error) = self .cluster_metadata_manager .refresh_metadata(full_refresh) .await { error!(%error, "Error refreshing nodes! 
Trying to refresh control connection."); continue; } continue 'listen; } } // all nodes failed Self::wait_for_reconnection(&mut schedule).await; } } } } async fn wait_for_reconnection(schedule: &mut Box) { // as long as the session is alive, try establishing the control connection let delay = schedule.next_delay().unwrap_or(DEFAULT_RECONNECT_DELAY); sleep(delay).await; } fn process_events( mut event_envelope_receiver: Receiver, event_sender: Sender, ) { tokio::spawn(async move { while let Some(envelope) = event_envelope_receiver.recv().await { if let Ok(body) = envelope.response_body() { if let Some(event) = body.into_server_event() { let _ = event_sender.send(event.event); } } } }); } } ================================================ FILE: cdrs-tokio/src/cluster/keyspace_holder.rs ================================================ use std::sync::Arc; use arc_swap::ArcSwapOption; use tokio::sync::watch::Sender; /// Holds currently set global keyspace. #[derive(Debug)] pub struct KeyspaceHolder { current_keyspace: ArcSwapOption, keyspace_sender: Sender>, } impl KeyspaceHolder { pub fn new(keyspace_sender: Sender>) -> Self { KeyspaceHolder { current_keyspace: Default::default(), keyspace_sender, } } #[inline] pub fn current_keyspace(&self) -> Option> { self.current_keyspace.load().clone() } #[inline] pub fn update_current_keyspace(&self, keyspace: String) { let old_keyspace = self.current_keyspace.swap(Some(Arc::new(keyspace.clone()))); match &old_keyspace { None => { self.send_notification(keyspace); } Some(old_keyspace) if **old_keyspace != keyspace => { self.send_notification(keyspace); } _ => {} } } #[inline] pub fn update_current_keyspace_without_notification(&self, keyspace: String) { self.current_keyspace.store(Some(Arc::new(keyspace))); } #[inline] fn send_notification(&self, keyspace: String) { let _ = self.keyspace_sender.send(Some(keyspace)); } } ================================================ FILE: cdrs-tokio/src/cluster/metadata_builder.rs 
================================================ use fxhash::{FxHashMap, FxHashSet}; use std::collections::hash_map::Entry; use std::sync::Arc; use tracing::*; use crate::cluster::connection_pool::ConnectionPoolFactory; use crate::cluster::topology::{KeyspaceMetadata, Node, NodeState}; use crate::cluster::{ClusterMetadata, ConnectionManager, NodeInfo}; use crate::load_balancing::node_distance_evaluator::NodeDistanceEvaluator; use crate::transport::CdrsTransport; pub(crate) fn build_initial_metadata>( node_infos: Vec, keyspaces: FxHashMap, contact_points: &[Arc>], connection_pool_factory: &Arc>, node_distance_evaluator: &(dyn NodeDistanceEvaluator + Send + Sync), ) -> ClusterMetadata { let mut nodes = FxHashMap::with_capacity_and_hasher(node_infos.len(), Default::default()); for node_info in node_infos { if let Entry::Vacant(entry) = nodes.entry(node_info.host_id) { let contact_point = contact_points.iter().find(|contact_point| { contact_point.broadcast_rpc_address() == node_info.broadcast_rpc_address }); let node = if let Some(contact_point) = contact_point { debug!(?node_info, "Copying contact point."); Arc::new(contact_point.clone_as_contact_point(node_info)) } else { debug!(?node_info, "Adding new node."); Arc::new(Node::new_with_state( connection_pool_factory.clone(), node_info.broadcast_rpc_address, node_info.broadcast_address, Some(node_info.host_id), node_distance_evaluator.compute_distance(&node_info), NodeState::Up, node_info.tokens.clone(), node_info.rack, node_info.datacenter, )) }; entry.insert(node); } else { warn!( host_id = %node_info.host_id, "Found duplicate peer entries - keeping only the first one." 
); } } ClusterMetadata::new(nodes, keyspaces) } pub(crate) fn refresh_metadata>( node_infos: &[NodeInfo], old_metadata: &ClusterMetadata, connection_pool_factory: &Arc>, node_distance_evaluator: &dyn NodeDistanceEvaluator, ) -> ClusterMetadata { let old_nodes = old_metadata.nodes(); let mut seen_hosts = FxHashSet::default(); let mut added_or_updated = FxHashMap::default(); for node_info in node_infos { if seen_hosts.contains(&node_info.host_id) { warn!( host_id = %node_info.host_id, "Found duplicate peer entries - keeping only the first one." ); } else { seen_hosts.insert(node_info.host_id); let old_node = old_nodes.get(&node_info.host_id); if let Some(old_node) = old_node { debug!(?node_info, "Updating old node."); added_or_updated.insert( node_info.host_id, Arc::new(old_node.clone_with_node_info(node_info.clone())), ); } else { debug!(?node_info, "Adding new node."); let node = Arc::new(Node::new( connection_pool_factory.clone(), node_info.broadcast_rpc_address, node_info.broadcast_address, Some(node_info.host_id), node_distance_evaluator.compute_distance(node_info), node_info.tokens.clone(), node_info.rack.clone(), node_info.datacenter.clone(), )); added_or_updated.insert(node_info.host_id, node); } } } ClusterMetadata::new(added_or_updated, old_metadata.keyspaces().clone()) } pub(crate) fn add_new_node>( node_info: NodeInfo, old_metadata: &ClusterMetadata, connection_pool_factory: &Arc>, state: NodeState, ) -> ClusterMetadata { let old_node = old_metadata.find_node_by_host_id(&node_info.host_id); if let Some(old_node) = old_node { // If a node is restarted after changing its broadcast RPC address, Cassandra considers that // an addition, even though the host_id hasn't changed :( if old_node.broadcast_rpc_address() == node_info.broadcast_rpc_address { debug!(?old_node, "Ignoring adding an existing node."); return old_metadata.clone_with_node(old_node.clone_with_node_state(state)); } debug!(?old_node, "Updating old node with new info."); return old_metadata 
.clone_with_node(old_node.clone_with_node_info_and_state(node_info, state)); } old_metadata.clone_with_node(Node::with_state( connection_pool_factory.clone(), node_info.broadcast_rpc_address, node_info.broadcast_address, Some(node_info.host_id), state, node_info.tokens, node_info.rack, node_info.datacenter, )) } //noinspection DuplicatedCode #[cfg(test)] mod tests { use cassandra_protocol::frame::Version; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; use tokio::sync::watch; use uuid::Uuid; use crate::cluster::connection_manager::MockConnectionManager; use crate::cluster::connection_pool::ConnectionPoolFactory; use crate::cluster::metadata_builder::{ add_new_node, build_initial_metadata, refresh_metadata, }; use crate::cluster::topology::NodeMap; use crate::cluster::topology::{Node, NodeDistance, NodeState}; use crate::cluster::{ClusterMetadata, NodeInfo}; use crate::load_balancing::node_distance_evaluator::MockNodeDistanceEvaluator; use crate::retry::MockReconnectionPolicy; use crate::transport::MockCdrsTransport; fn create_connection_pool_factory( ) -> Arc>> { let (_, keyspace_receiver) = watch::channel(None); let connection_manager = MockConnectionManager::::new(); let reconnection_policy = MockReconnectionPolicy::new(); let connection_pool_factory = ConnectionPoolFactory::new( Default::default(), Version::V4, connection_manager, keyspace_receiver, Arc::new(reconnection_policy), ); Arc::new(connection_pool_factory) } #[test] fn should_create_initial_metadata_from_all_new_nodes() { let node_infos = vec![NodeInfo::new( Uuid::new_v4(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), None, "".into(), Default::default(), "".into(), )]; let connection_pool_factory = create_connection_pool_factory(); let mut node_distance_evaluator = MockNodeDistanceEvaluator::new(); node_distance_evaluator .expect_compute_distance() .return_const(None); let metadata = build_initial_metadata( node_infos.clone(), Default::default(), &[], 
&connection_pool_factory, &node_distance_evaluator, ); let nodes = metadata.nodes(); assert_eq!(nodes.len(), 1); assert_eq!( nodes .get(&node_infos[0].host_id) .unwrap() .broadcast_rpc_address(), node_infos[0].broadcast_rpc_address ); } #[test] fn should_copy_old_node() { let connection_pool_factory = create_connection_pool_factory(); let node_distance_evaluator = MockNodeDistanceEvaluator::new(); let node_infos = vec![NodeInfo::new( Uuid::new_v4(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), None, "".into(), Default::default(), "".into(), )]; let contact_points = [Arc::new( Node::new( connection_pool_factory.clone(), node_infos[0].broadcast_rpc_address, node_infos[0].broadcast_address, Some(node_infos[0].host_id), None, Default::default(), "".into(), "".into(), ) .clone_with_node_state(NodeState::Up), )]; let metadata = build_initial_metadata( node_infos.clone(), Default::default(), &contact_points, &connection_pool_factory, &node_distance_evaluator, ); let nodes = metadata.nodes(); assert_eq!(nodes.len(), 1); assert_eq!( nodes.get(&node_infos[0].host_id).unwrap().state(), NodeState::Up ); } #[test] fn should_replace_old_metadata_nodes_with_new() { let connection_pool_factory = create_connection_pool_factory(); let mut node_distance_evaluator = MockNodeDistanceEvaluator::new(); node_distance_evaluator .expect_compute_distance() .return_const(None); let node_infos = [NodeInfo::new( Uuid::new_v4(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), None, "".into(), Default::default(), "".into(), )]; let old_host_id = Uuid::new_v4(); let mut old_nodes = NodeMap::default(); old_nodes.insert( old_host_id, Arc::new(Node::new( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 8080), None, Some(old_host_id), None, Default::default(), "".into(), "".into(), )), ); let old_metadata = ClusterMetadata::new(old_nodes, Default::default()); let metadata = refresh_metadata( &node_infos, &old_metadata, 
&connection_pool_factory, &node_distance_evaluator, ); let nodes = metadata.nodes(); assert_eq!(nodes.len(), 1); assert_eq!( nodes .get(&node_infos[0].host_id) .unwrap() .broadcast_rpc_address(), node_infos[0].broadcast_rpc_address ); } #[test] fn should_update_old_metadata_nodes_with_new_info() { let connection_pool_factory = create_connection_pool_factory(); let node_distance_evaluator = MockNodeDistanceEvaluator::new(); let node_infos = [NodeInfo::new( Uuid::new_v4(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), None, "".into(), Default::default(), "".into(), )]; let mut old_nodes = NodeMap::default(); old_nodes.insert( node_infos[0].host_id, Arc::new(Node::new( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 8080), None, Some(node_infos[0].host_id), None, Default::default(), "".into(), "".into(), )), ); let old_metadata = ClusterMetadata::new(old_nodes, Default::default()); let metadata = refresh_metadata( &node_infos, &old_metadata, &connection_pool_factory, &node_distance_evaluator, ); let nodes = metadata.nodes(); assert_eq!(nodes.len(), 1); assert_eq!( nodes .get(&node_infos[0].host_id) .unwrap() .broadcast_rpc_address(), node_infos[0].broadcast_rpc_address ); } #[test] fn should_not_add_already_existing_node() { let connection_pool_factory = create_connection_pool_factory(); let node_info = NodeInfo::new( Uuid::new_v4(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), None, "".into(), Default::default(), "".into(), ); let old_node = Node::with_distance( connection_pool_factory.clone(), node_info.broadcast_rpc_address, None, Some(node_info.host_id), NodeDistance::Local, ); assert_eq!(old_node.state(), NodeState::Unknown); let mut old_nodes = NodeMap::default(); old_nodes.insert(node_info.host_id, Arc::new(old_node)); let old_metadata = ClusterMetadata::new(old_nodes, Default::default()); let metadata = add_new_node( node_info.clone(), &old_metadata, &connection_pool_factory, 
NodeState::Up, ); let nodes = metadata.nodes(); assert_eq!(nodes.len(), 1); assert_eq!( nodes .get(&node_info.host_id) .unwrap() .broadcast_rpc_address(), node_info.broadcast_rpc_address ); assert_eq!( nodes.get(&node_info.host_id).unwrap().state(), NodeState::Up ); assert_eq!( nodes.get(&node_info.host_id).unwrap().distance().unwrap(), NodeDistance::Local ); } #[test] fn should_update_existing_node() { let connection_pool_factory = create_connection_pool_factory(); let node_info = NodeInfo::new( Uuid::new_v4(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), None, "".into(), Default::default(), "".into(), ); let old_node = Node::with_distance( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 8080), None, Some(node_info.host_id), NodeDistance::Local, ); assert_eq!(old_node.state(), NodeState::Unknown); let mut old_nodes = NodeMap::default(); old_nodes.insert(node_info.host_id, Arc::new(old_node)); let old_metadata = ClusterMetadata::new(old_nodes, Default::default()); let metadata = add_new_node( node_info.clone(), &old_metadata, &connection_pool_factory, NodeState::Up, ); let nodes = metadata.nodes(); assert_eq!(nodes.len(), 1); assert_eq!( nodes .get(&node_info.host_id) .unwrap() .broadcast_rpc_address(), node_info.broadcast_rpc_address ); assert_eq!( nodes.get(&node_info.host_id).unwrap().state(), NodeState::Up ); assert!(nodes.get(&node_info.host_id).unwrap().distance().is_none()); } #[test] fn should_add_new_node() { let connection_pool_factory = create_connection_pool_factory(); let node_info = NodeInfo::new( Uuid::new_v4(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), None, "".into(), Default::default(), "".into(), ); let old_metadata = ClusterMetadata::new(Default::default(), Default::default()); let metadata = add_new_node( node_info.clone(), &old_metadata, &connection_pool_factory, NodeState::Up, ); let nodes = metadata.nodes(); assert_eq!(nodes.len(), 1); assert_eq!( nodes 
                .get(&node_info.host_id)
                .unwrap()
                .broadcast_rpc_address(),
            node_info.broadcast_rpc_address
        );
        assert_eq!(
            nodes.get(&node_info.host_id).unwrap().state(),
            NodeState::Up
        );
        assert!(nodes.get(&node_info.host_id).unwrap().distance().is_none());
    }
}

================================================
FILE: cdrs-tokio/src/cluster/node_address.rs
================================================
use derive_more::Display;
use std::net::SocketAddr;
use tokio::net::lookup_host;

use cassandra_protocol::error::Result;

/// Representation of a node address. Can be a direct socket address or a hostname. In the latter
/// case, the host can be resolved to multiple addresses, which could result in multiple node
/// configurations.
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Display, Debug)]
pub enum NodeAddress {
    Direct(SocketAddr),
    Hostname(String),
}

impl From<SocketAddr> for NodeAddress {
    fn from(addr: SocketAddr) -> Self {
        NodeAddress::Direct(addr)
    }
}

impl From<String> for NodeAddress {
    fn from(value: String) -> Self {
        NodeAddress::Hostname(value)
    }
}

impl From<&String> for NodeAddress {
    fn from(value: &String) -> Self {
        NodeAddress::Hostname(value.clone())
    }
}

impl From<&str> for NodeAddress {
    fn from(value: &str) -> Self {
        NodeAddress::Hostname(value.to_string())
    }
}

impl NodeAddress {
    /// Resolves this address to socket addresses. A direct address resolves to itself; a
    /// hostname is looked up via DNS and may yield several addresses. Hostnames are passed to
    /// `tokio::net::lookup_host`, so they must be given in `host:port` form
    /// (e.g. `"localhost:9042"`).
    pub async fn resolve_address(&self) -> Result<Vec<SocketAddr>> {
        match self {
            NodeAddress::Direct(addr) => Ok(vec![*addr]),
            NodeAddress::Hostname(hostname) => lookup_host(hostname)
                .await
                .map(|addrs| addrs.collect())
                .map_err(Into::into),
        }
    }
}

================================================
FILE: cdrs-tokio/src/cluster/node_info.rs
================================================
use derivative::Derivative;
use derive_more::Constructor;
use std::net::SocketAddr;
use uuid::Uuid;

use crate::cluster::Murmur3Token;

/// Information about a node.
#[derive(Constructor, Clone, Derivative)]
#[derivative(Debug)]
pub struct NodeInfo {
    pub host_id: Uuid,
    pub broadcast_rpc_address: SocketAddr,
    pub broadcast_address: Option<SocketAddr>,
    pub datacenter: String,
    #[derivative(Debug = "ignore")]
    pub tokens: Vec<Murmur3Token>,
    pub rack: String,
}

================================================
FILE: cdrs-tokio/src/cluster/pager.rs
================================================
use cassandra_protocol::consistency::Consistency;
use cassandra_protocol::error;
use cassandra_protocol::frame::message_result::RowsMetadataFlags;
use cassandra_protocol::query::{PreparedQuery, QueryParams, QueryParamsBuilder, QueryValues};
use cassandra_protocol::types::rows::Row;
use cassandra_protocol::types::CBytes;

use crate::cluster::session::Session;
use crate::cluster::ConnectionManager;
use crate::load_balancing::LoadBalancingStrategy;
use crate::statement::StatementParamsBuilder;
use crate::transport::CdrsTransport;

pub struct SessionPager<
    'a,
    T: CdrsTransport + 'static,
    CM: ConnectionManager<T> + 'static,
    LB: LoadBalancingStrategy<T, CM> + Send + Sync,
> {
    page_size: i32,
    session: &'a Session<T, CM, LB>,
}

impl<
        'a,
        T: CdrsTransport + 'static,
        CM: ConnectionManager<T>,
        LB: LoadBalancingStrategy<T, CM> + Send + Sync,
    > SessionPager<'a, T, CM, LB>
{
    pub fn new(session: &'a Session<T, CM, LB>, page_size: i32) -> SessionPager<'a, T, CM, LB> {
        SessionPager { session, page_size }
    }

    pub fn query_with_pager_state<Q>(
        &'a mut self,
        query: Q,
        state: PagerState,
    ) -> QueryPager<'a, Q, SessionPager<'a, T, CM, LB>>
    where
        Q: ToString,
    {
        self.query_with_pager_state_params(query, state, Default::default())
    }

    pub fn query_with_pager_state_params<Q>(
        &'a mut self,
        query: Q,
        state: PagerState,
        qp: QueryParams,
    ) -> QueryPager<'a, Q, SessionPager<'a, T, CM, LB>>
    where
        Q: ToString,
    {
        QueryPager {
            pager: self,
            pager_state: state,
            query,
            qv: qp.values,
            consistency: qp.consistency,
        }
    }

    pub fn query<Q>(&'a mut self, query: Q) -> QueryPager<'a, Q, SessionPager<'a, T, CM, LB>>
    where
        Q: ToString,
    {
        self.query_with_params(
            query,
            QueryParamsBuilder::new()
                .with_consistency(Consistency::One)
                .build(),
        )
    }

    pub fn query_with_params<Q>(
        &'a mut self,
        query: Q,
        qp: QueryParams,
    ) -> QueryPager<'a, Q, SessionPager<'a, T, CM, LB>>
    where
        Q: ToString,
    {
        self.query_with_pager_state_params(query, PagerState::new(), qp)
    }

    pub fn exec_with_pager_state(
        &'a mut self,
        query: &'a PreparedQuery,
        state: PagerState,
    ) -> ExecPager<'a, SessionPager<'a, T, CM, LB>> {
        ExecPager {
            pager: self,
            pager_state: state,
            query,
        }
    }

    pub fn exec(
        &'a mut self,
        query: &'a PreparedQuery,
    ) -> ExecPager<'a, SessionPager<'a, T, CM, LB>> {
        self.exec_with_pager_state(query, PagerState::new())
    }
}

pub struct QueryPager<'a, Q: ToString, P: 'a> {
    pager: &'a mut P,
    pager_state: PagerState,
    query: Q,
    qv: Option<QueryValues>,
    consistency: Consistency,
}

impl<
        'a,
        Q: ToString,
        T: CdrsTransport + 'static,
        CM: ConnectionManager<T> + Send + Sync + 'static,
        LB: LoadBalancingStrategy<T, CM> + Send + Sync + 'static,
    > QueryPager<'a, Q, SessionPager<'a, T, CM, LB>>
{
    pub fn into_pager_state(self) -> PagerState {
        self.pager_state
    }

    pub async fn next(&mut self) -> error::Result<Vec<Row>> {
        let mut params = StatementParamsBuilder::new()
            .with_consistency(self.consistency)
            .with_page_size(self.pager.page_size);

        if let Some(qv) = &self.qv {
            params = params.with_values(qv.clone());
        }

        if let Some(cursor) = &self.pager_state.cursor {
            params = params.with_paging_state(cursor.clone());
        }

        let query = self.query.to_string();

        let body = self
            .pager
            .session
            .query_with_params(query, params.build())
            .await
            .and_then(|envelope| envelope.response_body())?;

        let metadata = body
            .as_rows_metadata()
            .ok_or("Pager query should yield a vector of rows")?;

        self.pager_state.has_more_pages =
            Some(metadata.flags.contains(RowsMetadataFlags::HAS_MORE_PAGES));
        self.pager_state.cursor.clone_from(&metadata.paging_state);

        body.into_rows()
            .ok_or_else(|| "Pager query should yield a vector of rows".into())
    }

    pub fn has_more(&self) -> bool {
        self.pager_state.has_more_pages.unwrap_or(false)
    }

    /// Returns a copy of the current pager state, which can be used later to
    /// resume paging.
    pub fn pager_state(&self) -> PagerState {
        self.pager_state.clone()
    }
}

pub struct ExecPager<'a, P: 'a> {
    pager: &'a mut P,
    pager_state: PagerState,
    query: &'a PreparedQuery,
}

impl<
        'a,
        T: CdrsTransport + 'static,
        CM: ConnectionManager<T> + Send + Sync + 'static,
        LB: LoadBalancingStrategy<T, CM> + Send + Sync + 'static,
    > ExecPager<'a, SessionPager<'a, T, CM, LB>>
{
    pub fn into_pager_state(self) -> PagerState {
        self.pager_state
    }

    pub async fn next(&mut self) -> error::Result<Vec<Row>> {
        let mut params = StatementParamsBuilder::new().with_page_size(self.pager.page_size);

        if let Some(cursor) = &self.pager_state.cursor {
            params = params.with_paging_state(cursor.clone());
        }

        let body = self
            .pager
            .session
            .exec_with_params(self.query, &params.build())
            .await
            .and_then(|envelope| envelope.response_body())?;

        let metadata = body
            .as_rows_metadata()
            .ok_or("Pager query should yield a vector of rows")?;

        self.pager_state.has_more_pages =
            Some(metadata.flags.contains(RowsMetadataFlags::HAS_MORE_PAGES));
        self.pager_state.cursor.clone_from(&metadata.paging_state);

        body.into_rows()
            .ok_or_else(|| "Pager query should yield a vector of rows".into())
    }

    #[inline]
    pub fn has_more(&self) -> bool {
        self.pager_state.has_more_pages.unwrap_or(false)
    }

    /// Returns a copy of the current pager state, which can be used later to
    /// resume paging.
    #[inline]
    pub fn pager_state(&self) -> PagerState {
        self.pager_state.clone()
    }
}

#[derive(Clone, PartialEq, Eq, Debug, Default)]
pub struct PagerState {
    cursor: Option<CBytes>,
    has_more_pages: Option<bool>,
}

impl PagerState {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn new_with_cursor(cursor: CBytes) -> Self {
        PagerState {
            cursor: Some(cursor),
            has_more_pages: None,
        }
    }

    pub fn new_with_cursor_and_more_flag(cursor: CBytes, has_more: bool) -> Self {
        PagerState {
            cursor: Some(cursor),
            has_more_pages: Some(has_more),
        }
    }

    #[inline]
    pub fn has_more(&self) -> bool {
        self.has_more_pages.unwrap_or(false)
    }

    #[inline]
    pub fn cursor(&self) -> Option<CBytes> {
        self.cursor.clone()
    }

    #[inline]
    pub fn into_cursor(self) -> Option<CBytes> {
        self.cursor
    }
}

================================================
FILE: cdrs-tokio/src/cluster/rustls_connection_manager.rs
================================================
use crate::cluster::connection_manager::{startup, ConnectionManager};
#[cfg(feature = "http-proxy")]
use crate::cluster::HttpProxyConfig;
use crate::cluster::KeyspaceHolder;
use crate::frame_encoding::FrameEncodingFactory;
use crate::future::BoxFuture;
use crate::transport::TransportRustls;
#[cfg(feature = "http-proxy")]
use async_http_proxy::{http_connect_tokio, http_connect_tokio_with_basic_auth};
use cassandra_protocol::authenticators::SaslAuthenticatorProvider;
use cassandra_protocol::compression::Compression;
use cassandra_protocol::error::{Error, Result};
use cassandra_protocol::frame::{Envelope, Version};
use futures::FutureExt;
use std::io;
#[cfg(feature = "http-proxy")]
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::ops::Deref;
use std::sync::Arc;
#[cfg(feature = "http-proxy")]
use tokio::net::TcpStream;
use tokio::sync::mpsc::Sender;
use tokio_rustls::rustls::{pki_types::ServerName, ClientConfig};

pub struct RustlsConnectionManager {
    dns_name: ServerName<'static>,
    authenticator_provider: Arc<dyn SaslAuthenticatorProvider + Send + Sync>,
    config: Arc<ClientConfig>,
    keyspace_holder: Arc<KeyspaceHolder>,
    frame_encoder_factory: Box<dyn FrameEncodingFactory + Send + Sync>,
    compression:
Compression, buffer_size: usize, tcp_nodelay: bool, version: Version, #[cfg(feature = "http-proxy")] http_proxy: Option, } impl ConnectionManager for RustlsConnectionManager { //noinspection DuplicatedCode fn connection( &self, event_handler: Option>, error_handler: Option>, addr: SocketAddr, ) -> BoxFuture<'_, Result> { self.establish_connection(event_handler, error_handler, addr) .boxed() } } impl RustlsConnectionManager { #[allow(clippy::too_many_arguments)] pub fn new( dns_name: ServerName<'static>, authenticator_provider: Arc, config: Arc, keyspace_holder: Arc, frame_encoder_factory: Box, compression: Compression, buffer_size: usize, tcp_nodelay: bool, version: Version, #[cfg(feature = "http-proxy")] http_proxy: Option, ) -> Self { RustlsConnectionManager { dns_name, authenticator_provider, config, keyspace_holder, frame_encoder_factory, compression, buffer_size, tcp_nodelay, version, #[cfg(feature = "http-proxy")] http_proxy, } } //noinspection DuplicatedCode #[cfg(feature = "http-proxy")] async fn create_transport( &self, event_handler: Option>, error_handler: Option>, addr: SocketAddr, ) -> io::Result { if let Some(http_proxy) = &self.http_proxy { let mut stream = TcpStream::connect(&http_proxy.address).await?; if let Some(auth) = &http_proxy.basic_auth { http_connect_tokio_with_basic_auth( &mut stream, &addr.ip().to_string(), addr.port(), &auth.username, &auth.password, ) .await .map_err(|error| io::Error::new(ErrorKind::Other, error.to_string()))?; } else { http_connect_tokio(&mut stream, &addr.ip().to_string(), addr.port()) .await .map_err(|error| io::Error::new(ErrorKind::Other, error.to_string()))?; } stream.set_nodelay(self.tcp_nodelay)?; TransportRustls::with_stream( stream, addr, self.dns_name.clone(), self.config.clone(), self.keyspace_holder.clone(), event_handler, error_handler, self.compression, self.frame_encoder_factory .create_encoder(self.version, self.compression), self.frame_encoder_factory .create_decoder(self.version, self.compression), 
self.buffer_size, ) .await } else { TransportRustls::new( addr, self.dns_name.clone(), self.config.clone(), self.keyspace_holder.clone(), event_handler, error_handler, self.compression, self.frame_encoder_factory .create_encoder(self.version, self.compression), self.frame_encoder_factory .create_decoder(self.version, self.compression), self.buffer_size, self.tcp_nodelay, ) .await } } //noinspection DuplicatedCode #[cfg(not(feature = "http-proxy"))] async fn create_transport( &self, event_handler: Option>, error_handler: Option>, addr: SocketAddr, ) -> io::Result { TransportRustls::new( addr, self.dns_name.clone(), self.config.clone(), self.keyspace_holder.clone(), event_handler, error_handler, self.compression, self.frame_encoder_factory .create_encoder(self.version, self.compression), self.frame_encoder_factory .create_decoder(self.version, self.compression), self.buffer_size, self.tcp_nodelay, ) .await } async fn establish_connection( &self, event_handler: Option>, error_handler: Option>, addr: SocketAddr, ) -> Result { let transport = self .create_transport(event_handler, error_handler, addr) .await?; startup( &transport, self.authenticator_provider.deref(), self.keyspace_holder.deref(), self.compression, self.version, ) .await?; Ok(transport) } } ================================================ FILE: cdrs-tokio/src/cluster/send_envelope.rs ================================================ use cassandra_protocol::error; use cassandra_protocol::frame::Envelope; use std::sync::Arc; use crate::cluster::topology::Node; use crate::cluster::ConnectionManager; use crate::retry::{QueryInfo, RetryDecision, RetrySession}; use crate::transport::CdrsTransport; /// Mid-level interface for sending envelopes to the cluster. Uses a query plan to route the envelope /// to the appropriate node, and retry policy for error handling. Returns `None` if no nodes were /// present in the query plan. 
pub async fn send_envelope<T: CdrsTransport + 'static, CM: ConnectionManager<T> + 'static>(
    query_plan: impl Iterator<Item = Arc<Node<T, CM>>>,
    envelope: &Envelope,
    is_idempotent: bool,
    mut retry_session: Box<dyn RetrySession + Send + Sync>,
) -> Option<error::Result<Envelope>> {
    let mut result = None;

    'next_node: for node in query_plan {
        loop {
            let transport = node.persistent_connection().await;
            match transport {
                Ok(transport) => match transport.write_envelope(envelope, false).await {
                    Ok(envelope) => return Some(Ok(envelope)),
                    Err(error) => {
                        let query_info = QueryInfo {
                            error: &error,
                            is_idempotent,
                        };

                        match retry_session.decide(query_info) {
                            RetryDecision::RetrySameNode => continue,
                            RetryDecision::RetryNextNode => {
                                // keep the error, so it is reported if no later node
                                // succeeds, instead of returning `None` for a non-empty plan
                                result = Some(Err(error));
                                continue 'next_node;
                            }
                            RetryDecision::DontRetry => return Some(Err(error)),
                        }
                    }
                },
                // save the error, but keep trying, since another node might be up
                Err(error) => {
                    result = Some(Err(error));
                    continue 'next_node;
                }
            }
        }
    }

    result
}

================================================
FILE: cdrs-tokio/src/cluster/session.rs
================================================
use arc_swap::ArcSwapOption;
use cassandra_protocol::compression::Compression;
use cassandra_protocol::consistency::Consistency;
use cassandra_protocol::error;
use cassandra_protocol::events::ServerEvent;
use cassandra_protocol::frame::message_error::{ErrorType, UnpreparedError};
use cassandra_protocol::frame::message_query::BodyReqQuery;
use cassandra_protocol::frame::message_response::ResponseBody;
use cassandra_protocol::frame::message_result::{BodyResResultPrepared, TableSpec};
use cassandra_protocol::frame::{Envelope, Flags, Serialize, Version};
use cassandra_protocol::query::{PreparedQuery, QueryBatch, QueryValues};
use cassandra_protocol::types::value::Value;
use cassandra_protocol::types::{CBytesShort, CIntShort, SHORT_LEN};
use derivative::Derivative;
use futures::stream::FuturesUnordered;
use futures::{FutureExt, StreamExt};
use itertools::Itertools;
use std::io::{Cursor, Write};
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::sync::{Arc, LazyLock, Mutex};
use thiserror::Error;
use
tokio::sync::broadcast::{channel, Receiver, Sender}; use tokio::sync::watch; use tokio::task::JoinHandle; use tokio::time::sleep; use tokio::{pin, select}; use tracing::*; use crate::cluster::connection_manager::ConnectionManager; use crate::cluster::connection_pool::{ConnectionPoolConfig, ConnectionPoolFactory}; use crate::cluster::control_connection::ControlConnection; #[cfg(feature = "rust-tls")] use crate::cluster::rustls_connection_manager::RustlsConnectionManager; use crate::cluster::send_envelope::send_envelope; use crate::cluster::tcp_connection_manager::TcpConnectionManager; use crate::cluster::topology::{Node, NodeDistance, NodeState}; use crate::cluster::Murmur3Token; #[cfg(feature = "rust-tls")] use crate::cluster::NodeRustlsConfig; use crate::cluster::{ClusterMetadata, ClusterMetadataManager, SessionContext}; use crate::cluster::{GenericClusterConfig, KeyspaceHolder}; use crate::cluster::{NodeTcpConfig, SessionPager}; use crate::frame_encoding::{FrameEncodingFactory, ProtocolFrameEncodingFactory}; use crate::future::BoxFuture; use crate::load_balancing::node_distance_evaluator::AllLocalNodeDistanceEvaluator; use crate::load_balancing::node_distance_evaluator::NodeDistanceEvaluator; use crate::load_balancing::{ InitializingWrapperLoadBalancingStrategy, LoadBalancingStrategy, QueryPlan, Request, }; use crate::retry::{ DefaultRetryPolicy, ExponentialReconnectionPolicy, ReconnectionPolicy, RetryPolicy, }; use crate::speculative_execution::{Context, SpeculativeExecutionPolicy}; use crate::statement::{StatementParams, StatementParamsBuilder}; #[cfg(feature = "rust-tls")] use crate::transport::TransportRustls; use crate::transport::{CdrsTransport, TransportTcp}; pub const DEFAULT_TRANSPORT_BUFFER_SIZE: usize = 1024; const DEFAULT_EVENT_CHANNEL_CAPACITY: usize = 128; /// How many times execution will reprepare-and-retry an Unprepared statement /// before giving up. Without an upper bound a misbehaving cluster (e.g. 
/// repeatedly evicting prepared statements) could keep us looping
/// indefinitely. Five attempts is enough to ride out transient schema or
/// node restarts while still surfacing a real failure to the caller.
const MAX_REPREPARE_ATTEMPTS: usize = 5;

static DEFAULT_STATEMENT_PARAMETERS: LazyLock<StatementParams> =
    LazyLock::new(Default::default);

#[inline]
fn convert_to_prepared(body: ResponseBody) -> error::Result<BodyResResultPrepared> {
    body.into_prepared()
        .ok_or_else(|| "Cannot convert envelope into prepare response!".into())
}

#[inline]
fn prepare_flags(with_tracing: bool, with_warnings: bool, beta_protocol: bool) -> Flags {
    let mut flags = Flags::empty();

    if with_tracing {
        flags.insert(Flags::TRACING);
    }

    if with_warnings {
        flags.insert(Flags::WARNING);
    }

    if beta_protocol {
        flags.insert(Flags::BETA);
    }

    flags
}

fn create_keyspace_holder() -> (Arc<KeyspaceHolder>, watch::Receiver<Option<String>>) {
    let (keyspace_sender, keyspace_receiver) = watch::channel(None);
    (
        Arc::new(KeyspaceHolder::new(keyspace_sender)),
        keyspace_receiver,
    )
}

fn verify_compression_configuration(
    version: Version,
    compression: Compression,
) -> Result<(), SessionBuildError> {
    if version < Version::V5 || compression != Compression::Snappy {
        Ok(())
    } else {
        Err(SessionBuildError::CompressionTypeNotSupported)
    }
}

// https://github.com/apache/cassandra/blob/3a950b45c321e051a9744721408760c568c05617/src/java/org/apache/cassandra/db/marshal/CompositeType.java#L39
fn serialize_routing_value(cursor: &mut Cursor<&mut Vec<u8>>, value: &Vec<u8>, version: Version) {
    // Reserve 2 bytes for the value length, write the value, then come back
    // and patch the length in. The placeholder write advances the cursor by
    // exactly SHORT_LEN bytes so the rewind is always valid; the value write
    // advances by `value.len()`, giving us the size to back-fill.
    let temp_size: CIntShort = 0;
    temp_size.serialize(cursor, version);

    let before_value_pos = cursor.position();
    value.serialize(cursor, version);

    let after_value_pos = cursor.position();
    cursor.set_position(before_value_pos - SHORT_LEN as u64);

    let value_size: CIntShort = (after_value_pos - before_value_pos) as CIntShort;
    value_size.serialize(cursor, version);

    cursor.set_position(after_value_pos);
    // CompositeType end-of-component byte
    let _ = cursor.write(&[0]);
}

fn serialize_routing_key_with_indexes(
    values: &[Value],
    pk_indexes: &[i16],
    version: Version,
) -> Option<Vec<u8>> {
    match pk_indexes.len() {
        0 => None,
        1 => values
            .get(pk_indexes[0] as usize)
            .and_then(|value| match value {
                Value::Some(value) => Some(value.serialize_to_vec(version)),
                _ => None,
            }),
        _ => {
            let mut buf = vec![];
            if pk_indexes
                .iter()
                .map(|index| values.get(*index as usize))
                .fold_options(Cursor::new(&mut buf), |mut cursor, value| {
                    if let Value::Some(value) = value {
                        serialize_routing_value(&mut cursor, value, version)
                    }
                    cursor
                })
                .is_some()
            {
                Some(buf)
            } else {
                None
            }
        }
    }
}

fn serialize_routing_key(values: &[Value], version: Version) -> Vec<u8> {
    match values.len() {
        0 => vec![],
        1 => match &values[0] {
            Value::Some(value) => value.serialize_to_vec(version),
            _ => vec![],
        },
        _ => {
            let mut buf = vec![];
            let mut cursor = Cursor::new(&mut buf);

            for value in values {
                if let Value::Some(value) = value {
                    serialize_routing_value(&mut cursor, value, version);
                }
            }

            buf
        }
    }
}

/// CDRS session that holds a pool of connections to nodes and provides an interface for
/// interacting with the cluster.
#[derive(Derivative)] #[derivative(Debug)] pub struct Session< T: CdrsTransport + 'static, CM: ConnectionManager + 'static, LB: LoadBalancingStrategy + Send + Sync, > { #[derivative(Debug = "ignore")] load_balancing: Arc>, keyspace_holder: Arc, #[derivative(Debug = "ignore")] retry_policy: Box, #[derivative(Debug = "ignore")] speculative_execution_policy: Option>, control_connection_handle: JoinHandle<()>, event_sender: Sender, #[derivative(Debug = "ignore")] cluster_metadata_manager: Arc>, #[derivative(Debug = "ignore")] _transport: PhantomData, #[derivative(Debug = "ignore")] _connection_manager: PhantomData, version: Version, } impl< T: CdrsTransport + 'static, CM: ConnectionManager, LB: LoadBalancingStrategy + Send + Sync, > Drop for Session { fn drop(&mut self) { self.control_connection_handle.abort(); } } impl< T: CdrsTransport + 'static, CM: ConnectionManager + Send + Sync + 'static, LB: LoadBalancingStrategy + Send + Sync + 'static, > Session { /// Returns new `SessionPager` that can be used for performing paged queries. pub fn paged(&self, page_size: i32) -> SessionPager<'_, T, CM, LB> { SessionPager::new(self, page_size) } /// Executes given prepared query with query parameters. 
    pub fn exec_with_params<'a, 'b: 'a>(
        &'a self,
        prepared: &'b PreparedQuery,
        parameters: &'b StatementParams,
    ) -> BoxFuture<'a, error::Result<Envelope>> {
        async move {
            let consistency = parameters.query_params.consistency;
            let flags = prepare_flags(
                parameters.tracing,
                parameters.warnings,
                parameters.beta_protocol,
            );

            let result_metadata_id = prepared
                .result_metadata_id
                .load()
                .as_ref()
                .map(|metadata| (**metadata).clone());

            let envelope = Envelope::new_req_execute(
                &prepared.id,
                result_metadata_id.as_ref(),
                &parameters.query_params,
                flags,
                self.version,
            );

            let keyspace = prepared
                .keyspace
                .as_deref()
                .or(parameters.keyspace.as_deref());

            let routing_key = parameters
                .query_params
                .values
                .as_ref()
                .and_then(|values| match values {
                    QueryValues::SimpleValues(values) => serialize_routing_key_with_indexes(
                        values,
                        &prepared.pk_indexes,
                        self.version,
                    )
                    .or_else(|| {
                        parameters
                            .routing_key
                            .as_ref()
                            .map(|values| serialize_routing_key(values, self.version))
                    }),
                    QueryValues::NamedValues(_) => None,
                });

            // Bounded retry loop. Previously this was an unbounded recursive
            // self-call: if the cluster kept returning Unprepared even after
            // a successful reprepare (rare but possible during schema/cluster
            // instability), the client would recurse forever. Capping at
            // MAX_REPREPARE_ATTEMPTS gives the cluster a few chances to settle
            // and then surfaces the most recent error to the caller.
            let mut attempts_remaining = MAX_REPREPARE_ATTEMPTS;
            let result = loop {
                let result = self
                    .send_envelope(
                        // Borrowed: the bounded reprepare loop may run this
                        // multiple times; cloning the encoded body each
                        // iteration would be wasteful for any non-trivial
                        // EXECUTE payload.
                        &envelope,
                        parameters.is_idempotent,
                        keyspace,
                        parameters.token,
                        routing_key.as_deref(),
                        Some(consistency),
                        parameters.speculative_execution_policy.as_ref(),
                        parameters.retry_policy.as_ref(),
                    )
                    .await;

                // Try to identify an Unprepared error and the address of the
                // node that reported it.
If we have budget left, reprepare on // that node and retry. Otherwise fall through with the result // we have - good or bad. if let Err(error::Error::Server { body: error, addr }) = &result { if let ErrorType::Unprepared(_) = error.ty { if attempts_remaining > 0 { attempts_remaining -= 1; if self .reprepare( &prepared.id, prepared.query.clone(), keyspace.map(|keyspace| keyspace.to_string()), parameters, *addr, ) .await .is_ok() { continue; } } } } break result; }; let response = result .as_ref() .map_err(|error| error.clone()) .and_then(|result| result.response_body()); let new_metadata_id = response.as_ref().map(|result| { result .as_rows_metadata() .and_then(|metadata| metadata.new_metadata_id.as_ref()) }); if let Ok(Some(new_metadata_id)) = new_metadata_id { prepared .result_metadata_id .swap(Some(Arc::new(new_metadata_id.clone()))); } result } .boxed() } /// Executes given prepared query with query values. pub async fn exec_with_values>( &self, prepared: &PreparedQuery, values: V, ) -> error::Result { self.exec_with_params( prepared, &StatementParamsBuilder::new() .with_values(values.into()) .build(), ) .await } /// Executes the given prepared query. #[inline] pub async fn exec(&self, prepared: &PreparedQuery) -> error::Result { self.exec_with_params(prepared, &DEFAULT_STATEMENT_PARAMETERS) .await } /// Prepares a query for execution. Along with the query itself, the /// method takes `with_tracing` and `with_warnings` flags to get /// tracing information and warnings. Returns the raw prepared /// query result. pub async fn prepare_raw_tw( &self, query: Q, keyspace: Option, with_tracing: bool, with_warnings: bool, beta_protocol: bool, ) -> error::Result { self.prepare_raw_tw_with_query_plan( query, keyspace, with_tracing, with_warnings, beta_protocol, None, ) .await } /// Prepares a query for execution. Along with the query itself, the /// method takes `with_tracing` and `with_warnings` flags to get /// tracing information and warnings. 
Returns the raw prepared /// query result. Optional query plan can be provided to customize /// query preparation. pub async fn prepare_raw_tw_with_query_plan( &self, query: Q, keyspace: Option, with_tracing: bool, with_warnings: bool, beta_protocol: bool, query_plan: Option>, ) -> error::Result { let flags = prepare_flags(with_tracing, with_warnings, beta_protocol); let envelope = Envelope::new_req_prepare(query.to_string(), keyspace, flags, self.version); let response = match query_plan { None => { self.send_envelope(&envelope, true, None, None, None, None, None, None) .await } Some(query_plan) => send_envelope( query_plan.nodes.into_iter(), &envelope, true, self.retry_policy.as_ref().new_session(), ) .await .unwrap_or_else(|| Err("No response for prepare!".into())), }; response .and_then(|response| response.response_body()) .and_then(convert_to_prepared) } /// Prepares a query without additional tracing information and warnings. /// Returns the raw prepared query result. #[inline] pub async fn prepare_raw(&self, query: Q) -> error::Result { self.prepare_raw_tw(query, None, false, false, false).await } /// Prepares a query for execution. Along with the query itself, /// the method takes `with_tracing` and `with_warnings` flags /// to get tracing information and warnings. Returns the prepared /// query. pub async fn prepare_tw( &self, query: Q, keyspace: Option, with_tracing: bool, with_warnings: bool, beta_protocol: bool, ) -> error::Result { let s = query.to_string(); self.prepare_raw_tw(query, keyspace, with_tracing, with_warnings, beta_protocol) .await .map(|result| PreparedQuery { id: result.id, query: s, keyspace: result .metadata .global_table_spec .map(|TableSpec { ks_name, .. }| ks_name), pk_indexes: result.metadata.pk_indexes, result_metadata_id: ArcSwapOption::new(result.result_metadata_id.map(Arc::new)), }) } /// Prepares a query without additional tracing information and warnings. /// Returns the prepared query. 
#[inline] pub async fn prepare(&self, query: Q) -> error::Result { self.prepare_tw(query, None, false, false, false).await } /// Executes batch query. #[inline] pub async fn batch(&self, batch: QueryBatch) -> error::Result { self.batch_with_params(batch, &DEFAULT_STATEMENT_PARAMETERS) .await } /// Executes a batch query with parameters. pub fn batch_with_params<'a, 'b: 'a>( &'a self, batch: QueryBatch, parameters: &'b StatementParams, ) -> BoxFuture<'a, error::Result> { async move { let flags = prepare_flags( parameters.tracing, parameters.warnings, parameters.beta_protocol, ); let consistency = batch.request.consistency; let envelope = Envelope::new_req_batch(batch.request.clone(), flags, self.version); // Bounded retry loop, same rationale as exec_with_params: if the // cluster keeps reporting Unprepared we want to give up cleanly // rather than recurse without bound. let mut attempts_remaining = MAX_REPREPARE_ATTEMPTS; loop { let result = self .send_envelope( // See exec_with_params - same retry-loop hot path. &envelope, parameters.is_idempotent, parameters.keyspace.as_deref(), None, None, Some(consistency), parameters.speculative_execution_policy.as_ref(), parameters.retry_policy.as_ref(), ) .await; if let Err(error::Error::Server { body: error, addr }) = &result { if let ErrorType::Unprepared(UnpreparedError { id }) = &error.ty { if attempts_remaining == 0 { // out of retries - return the most recent error return result; } let query = match batch.prepared_queries.get(id) { None => { warn!( ?id, "Cannot find prepared query for unprepared statement in a batch!" 
); return result; } Some(query) => query, }; attempts_remaining -= 1; let prepare_result = self .reprepare( id, query.query.clone(), query.keyspace.clone(), parameters, *addr, ) .await; if prepare_result.is_ok() { // try the batch again with the freshly prepared statement continue; } } } return result; } } .boxed() } async fn reprepare( &self, id: &CBytesShort, query: String, keyspace: Option, parameters: &StatementParams, node_broadcast_rpc_address: SocketAddr, ) -> error::Result<()> { debug!("Re-preparing statement."); let flags = prepare_flags( parameters.tracing, parameters.warnings, parameters.beta_protocol, ); // The prepare request must go to the node that reported the statement as unprepared. let node = self .cluster_metadata_manager .find_node_by_rpc_address(node_broadcast_rpc_address) .ok_or_else(|| { error::Error::from(format!( "Cannot find node {node_broadcast_rpc_address} for statement re-preparation!" )) })?; let prepare_envelope = Envelope::new_req_prepare(query, keyspace, flags, self.version); let retry_policy = self.effective_retry_policy(parameters.retry_policy.as_ref()); let prepare_result = send_envelope( [node].iter().cloned(), &prepare_envelope, true, retry_policy.new_session(), ) .await .unwrap_or_else(|| Err("No response for re-prepare statement!".into())) .and_then(|response| response.response_body()) .and_then(convert_to_prepared)?; // The re-prepared statement's id should match the old one. It differs only if the // schema changed in the meantime, in which case the client has to decide how to // handle the new statement. // See: https://issues.apache.org/jira/browse/CASSANDRA-10786 if id != &prepare_result.id { return Err("Re-preparing an unprepared statement resulted in a different id - probably schema changed on the server.".into()); } Ok(()) } /// Executes a query.
#[inline] pub async fn query(&self, query: Q) -> error::Result { self.query_with_params(query, DEFAULT_STATEMENT_PARAMETERS.clone()) .await } /// Executes a query with bound values (either with or without names). #[inline] pub async fn query_with_values>( &self, query: Q, values: V, ) -> error::Result { self.query_with_params( query, StatementParamsBuilder::new() .with_values(values.into()) .build(), ) .await } /// Executes a query with query parameters. pub async fn query_with_params( &self, query: Q, parameters: StatementParams, ) -> error::Result { let is_idempotent = parameters.is_idempotent; let consistency = parameters.query_params.consistency; let keyspace = parameters.keyspace; let token = parameters.token; let routing_key = parameters .routing_key .as_ref() .map(|values| serialize_routing_key(values, self.version)); let query = BodyReqQuery { query: query.to_string(), query_params: parameters.query_params, }; let flags = prepare_flags( parameters.tracing, parameters.warnings, parameters.beta_protocol, ); let envelope = Envelope::new_query(query, flags, self.version); self.send_envelope( &envelope, is_idempotent, keyspace.as_deref(), token, routing_key.as_deref(), Some(consistency), parameters.speculative_execution_policy.as_ref(), parameters.retry_policy.as_ref(), ) .await } /// Returns the currently set global keyspace. #[inline] pub fn current_keyspace(&self) -> Option> { self.keyspace_holder.current_keyspace() } /// Returns the current cluster metadata. #[inline] pub fn cluster_metadata(&self) -> Arc> { self.cluster_metadata_manager.metadata() } /// Returns the query plan for the given request. If no request is given, returns a generic plan for /// establishing connection(s) to node(s). #[inline] pub fn query_plan(&self, request: Option) -> QueryPlan { self.load_balancing .query_plan(request, self.cluster_metadata().as_ref()) } /// Creates a new server event receiver. You can use multiple receivers at the same time.
#[inline] pub fn create_event_receiver(&self) -> Receiver { self.event_sender.subscribe() } /// Returns current retry policy. #[inline] pub fn retry_policy(&self) -> &dyn RetryPolicy { self.retry_policy.as_ref() } // Take envelope by reference: send_envelope dispatches it (potentially // multiple times across speculative execution and per-node retry) but // never needs ownership. Keeping it borrowed lets retry loops in callers // (e.g. the bounded reprepare loop) avoid an unnecessary `Vec` clone // of the encoded body on every iteration. #[allow(clippy::too_many_arguments)] async fn send_envelope( &self, envelope: &Envelope, is_idempotent: bool, keyspace: Option<&str>, token: Option, routing_key: Option<&[u8]>, consistency: Option, speculative_execution_policy: Option<&Arc>, retry_policy: Option<&Arc>, ) -> error::Result { let current_keyspace = self.current_keyspace(); let request = Request::new( keyspace.or_else(|| current_keyspace.as_ref().map(|keyspace| &***keyspace)), token, routing_key, consistency, ); let query_plan = self.query_plan(Some(request)); struct SharedQueryPlan< T: CdrsTransport + 'static, CM: ConnectionManager + 'static, I: Iterator>>, > { current_node: Mutex, } impl< T: CdrsTransport + 'static, CM: ConnectionManager + 'static, I: Iterator>>, > SharedQueryPlan { fn new(current_node: I) -> Self { SharedQueryPlan { current_node: Mutex::new(current_node), } } } impl< T: CdrsTransport + 'static, CM: ConnectionManager + 'static, I: Iterator>>, > Iterator for &SharedQueryPlan { type Item = Arc>; fn next(&mut self) -> Option { self.current_node.lock().unwrap().next() } } let speculative_execution_policy = speculative_execution_policy .map(|speculative_execution_policy| speculative_execution_policy.as_ref()) .or(self.speculative_execution_policy.as_deref()); let retry_policy = self.effective_retry_policy(retry_policy); match speculative_execution_policy { Some(speculative_execution_policy) if is_idempotent => { let shared_query_plan = 
SharedQueryPlan::new(query_plan.nodes.into_iter()); let mut context = Context::new(1); let mut async_tasks = FuturesUnordered::new(); async_tasks.push(send_envelope( &shared_query_plan, envelope, is_idempotent, retry_policy.new_session(), )); let sleep_fut = sleep( speculative_execution_policy .execution_interval(&context) .unwrap_or_default(), ) .fuse(); pin!(sleep_fut); let mut last_error = None; loop { select! { _ = &mut sleep_fut => { if let Some(interval) = speculative_execution_policy.execution_interval(&context) { context.running_executions += 1; async_tasks.push(send_envelope( &shared_query_plan, envelope, is_idempotent, retry_policy.new_session(), )); sleep_fut.set(sleep(interval).fuse()); } } result = async_tasks.select_next_some() => { match result { Some(result) => { match result { Err(error::Error::Io(_)) | Err(error::Error::Timeout(_)) => { last_error = Some(result); }, _ => return result, } } None => { if async_tasks.is_empty() { // at this point, we exhausted all available nodes and // there's no request in flight, which can potentially // reach a node return last_error.unwrap_or_else(|| Err("No nodes available in query plan!".into())); } } } } } } } _ => send_envelope( query_plan.nodes.into_iter(), envelope, is_idempotent, retry_policy.new_session(), ) .await .unwrap_or_else(|| Err("No nodes available in query plan!".into())), } } #[inline] fn effective_retry_policy<'a, 'b: 'a>( &'a self, retry_policy: Option<&'b Arc>, ) -> &'a (dyn RetryPolicy + Send + Sync) { retry_policy .map(|retry_policy| retry_policy.as_ref()) .unwrap_or_else(|| self.retry_policy.as_ref()) } #[allow(clippy::too_many_arguments)] async fn new( load_balancing: LB, keyspace_holder: Arc, keyspace_receiver: watch::Receiver>, retry_policy: Box, reconnection_policy: Arc, node_distance_evaluator: Box, speculative_execution_policy: Option>, contact_points: Vec, connection_manager: CM, event_channel_capacity: usize, version: Version, connection_pool_config: ConnectionPoolConfig, 
beta_protocol: bool, ) -> Result { let connection_pool_factory = Arc::new(ConnectionPoolFactory::new( connection_pool_config, version, connection_manager, keyspace_receiver, reconnection_policy.clone(), )); let contact_points = contact_points .into_iter() .map(|contact_point| { Arc::new(Node::new_with_state( connection_pool_factory.clone(), contact_point, None, None, // assume contact points are local until refresh Some(NodeDistance::Local), NodeState::Up, Default::default(), // as with distance, rack/dc is unknown until refresh "".into(), "".into(), )) }) .collect_vec(); let load_balancing = Arc::new(InitializingWrapperLoadBalancingStrategy::new( load_balancing, contact_points.clone(), )); let (event_sender, event_receiver) = channel(event_channel_capacity); let session_context = Arc::new(SessionContext::default()); let cluster_metadata_manager = Arc::new(ClusterMetadataManager::new( contact_points.clone(), connection_pool_factory, session_context.clone(), node_distance_evaluator, version, beta_protocol, )); cluster_metadata_manager.listen_to_events(event_receiver); let control_connection = ControlConnection::new( load_balancing.clone(), contact_points, reconnection_policy.clone(), cluster_metadata_manager.clone(), event_sender.clone(), session_context, version, ); let (init_complete_sender, init_complete_receiver) = tokio::sync::oneshot::channel(); // Wrap the JoinHandle in a guard so that if any error path below // returns early, the spawned control-connection task is aborted // rather than left running for the lifetime of the runtime. tokio's // JoinHandle does not abort on drop. Once we successfully reach the // Ok(Session { ... }) below, into_inner() releases the guard and // ownership passes to the Session struct (whose own Drop impl // aborts the handle on session shutdown). 
let control_connection_handle = AbortOnDropHandle::new(tokio::spawn(control_connection.run(init_complete_sender))); if init_complete_receiver.await.is_err() { // guard drops here -> task aborted, no leak return Err(SessionBuildError::SessionInitFailed); } Ok(Session { load_balancing, keyspace_holder, retry_policy, speculative_execution_policy, control_connection_handle: control_connection_handle.into_inner(), event_sender, cluster_metadata_manager, _transport: Default::default(), _connection_manager: Default::default(), version, }) } } /// Workaround for #[repr(transparent)] pub struct RetryPolicyWrapper(pub Box); /// Workaround for #[repr(transparent)] pub struct ReconnectionPolicyWrapper(pub Arc); /// Workaround for #[repr(transparent)] pub struct NodeDistanceEvaluatorWrapper(pub Box); /// Workaround for #[repr(transparent)] pub struct SpeculativeExecutionPolicyWrapper(pub Box); /// This function uses a user-supplied connection configuration to initialize all the /// connections in the session. It can be used to supply your own transport and load /// balancing mechanisms to support unusual node discovery mechanisms or configuration needs. /// /// The config object supplied differs from the [`NodeTcpConfig`] and [`NodeRustlsConfig`] /// objects in that it is not expected to include an address. Instead, the same configuration /// will be applied to all connections across the cluster. 
pub async fn connect_generic( config: &C, initial_nodes: A, load_balancing: LB, retry_policy: RetryPolicyWrapper, reconnection_policy: ReconnectionPolicyWrapper, node_distance_evaluator: NodeDistanceEvaluatorWrapper, speculative_execution_policy: Option, ) -> error::Result> where A: IntoIterator, T: CdrsTransport + 'static, CM: ConnectionManager + Send + Sync + 'static, C: GenericClusterConfig, LB: LoadBalancingStrategy + Sized + Send + Sync + 'static, { let (keyspace_holder, keyspace_receiver) = create_keyspace_holder(); let connection_manager = config.create_manager(keyspace_holder.clone()).await?; Session::new( load_balancing, keyspace_holder, keyspace_receiver, retry_policy.0, reconnection_policy.0, node_distance_evaluator.0, speculative_execution_policy.map(|policy| policy.0), initial_nodes.into_iter().collect(), connection_manager, config.event_channel_capacity(), config.version(), config.connection_pool_config(), config.beta_protocol(), ) .await .map_err(|e| error::Error::General(e.to_string())) } struct SessionConfig< T: CdrsTransport, CM: ConnectionManager, LB: LoadBalancingStrategy + Send + Sync, > { compression: Compression, transport_buffer_size: usize, tcp_nodelay: bool, load_balancing: LB, retry_policy: Box, reconnection_policy: Arc, node_distance_evaluator: Box, speculative_execution_policy: Option>, event_channel_capacity: usize, connection_pool_config: ConnectionPoolConfig, keyspace: Option, _connection_manager: PhantomData, _transport: PhantomData, } impl< T: CdrsTransport + 'static, CM: ConnectionManager + 'static, LB: LoadBalancingStrategy + Send + Sync + 'static, > SessionConfig { fn new(load_balancing: LB) -> Self { SessionConfig { compression: Compression::None, transport_buffer_size: DEFAULT_TRANSPORT_BUFFER_SIZE, tcp_nodelay: true, load_balancing, retry_policy: Box::::default(), reconnection_policy: Arc::new(ExponentialReconnectionPolicy::default()), node_distance_evaluator: Box::::default(), speculative_execution_policy: None, 
event_channel_capacity: DEFAULT_EVENT_CHANNEL_CAPACITY, connection_pool_config: Default::default(), keyspace: None, _connection_manager: Default::default(), _transport: Default::default(), } } async fn into_session( self, keyspace_holder: Arc, keyspace_receiver: watch::Receiver>, contact_points: Vec, connection_manager: CM, version: Version, beta_protocol: bool, ) -> Result, SessionBuildError> { if let Some(keyspace) = self.keyspace { keyspace_holder.update_current_keyspace_without_notification(keyspace); } Session::new( self.load_balancing, keyspace_holder, keyspace_receiver, self.retry_policy, self.reconnection_policy, self.node_distance_evaluator, self.speculative_execution_policy, contact_points, connection_manager, self.event_channel_capacity, version, self.connection_pool_config, beta_protocol, ) .await } } /// `Session` build error. #[derive(Error, Debug, Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] pub enum SessionBuildError { #[error("Given compression type is not supported for selected protocol!")] CompressionTypeNotSupported, #[error("Session control connection died before completing initialization")] SessionInitFailed, } /// Builder for easy `Session` creation. Requires a static `LoadBalancingStrategy`; all other /// configuration parameters can be set dynamically. Use concrete implementers to create specific /// sessions. pub trait SessionBuilder< T: CdrsTransport + 'static, CM: ConnectionManager, LB: LoadBalancingStrategy + Send + Sync + 'static, > { /// Sets new compression. #[must_use] fn with_compression(self, compression: Compression) -> Self; /// Sets new retry policy. #[must_use] fn with_retry_policy(self, retry_policy: Box) -> Self; /// Sets new reconnection policy. #[must_use] fn with_reconnection_policy( self, reconnection_policy: Arc, ) -> Self; /// Sets custom frame encoder factory. #[must_use] fn with_frame_encoder_factory( self, frame_encoder_factory: Box, ) -> Self; /// Sets new node distance evaluator.
Computing node distance is fundamental to proper /// topology-aware load balancing - see [`NodeDistanceEvaluator`]. #[must_use] fn with_node_distance_evaluator( self, node_distance_evaluator: Box, ) -> Self; /// Sets new speculative execution policy. #[must_use] fn with_speculative_execution_policy( self, speculative_execution_policy: Box, ) -> Self; /// Sets new transport buffer size. Larger values are recommended when many queries are in /// flight. #[must_use] fn with_transport_buffer_size(self, transport_buffer_size: usize) -> Self; /// Sets `TCP_NODELAY` on this session's connections. #[must_use] fn with_tcp_nodelay(self, tcp_nodelay: bool) -> Self; /// Sets the event channel capacity. If the driver receives more server events than the capacity, /// some events might get dropped, which can leave the driver operating sub-optimally. #[must_use] fn with_event_channel_capacity(self, event_channel_capacity: usize) -> Self; /// Sets the node connection pool configuration for this session. #[must_use] fn with_connection_pool_config(self, connection_pool_config: ConnectionPoolConfig) -> Self; /// Sets the keyspace to use. If queries do not name a keyspace explicitly, one should be set /// either by calling this function or by a `USE` statement. Due to the asynchronous nature of /// the driver and its use of connection pools, switching the current keyspace via /// `USE` might not propagate immediately to all active connections, resulting in queries /// using the wrong keyspace. If the keyspace is known upfront, it's safer to set it while building /// the [`Session`]. #[must_use] fn with_keyspace(self, keyspace: String) -> Self; /// Sets the beta protocol flag. The server responds with an ERROR if the protocol version is /// marked as beta on the server and the client does not provide this flag. #[must_use] fn with_beta_protocol(self, beta_protocol: bool) -> Self; /// Builds the resulting session.
fn build(self) -> BoxFuture<'static, Result, SessionBuildError>>; } /// Builder for non-TLS sessions. pub struct TcpSessionBuilder< LB: LoadBalancingStrategy + Send + Sync, > { config: SessionConfig, node_config: NodeTcpConfig, frame_encoder_factory: Box, } impl + Send + Sync + 'static> TcpSessionBuilder { //noinspection DuplicatedCode /// Creates a new builder with default session configuration. pub fn new(load_balancing: LB, node_config: NodeTcpConfig) -> Self { TcpSessionBuilder { config: SessionConfig::new(load_balancing), node_config, frame_encoder_factory: Box::::default(), } } } impl + Send + Sync + 'static> SessionBuilder for TcpSessionBuilder { fn with_compression(mut self, compression: Compression) -> Self { self.config.compression = compression; self } fn with_retry_policy(mut self, retry_policy: Box) -> Self { self.config.retry_policy = retry_policy; self } fn with_reconnection_policy( mut self, reconnection_policy: Arc, ) -> Self { self.config.reconnection_policy = reconnection_policy; self } fn with_frame_encoder_factory( mut self, frame_encoder_factory: Box, ) -> Self { self.frame_encoder_factory = frame_encoder_factory; self } fn with_node_distance_evaluator( mut self, node_distance_evaluator: Box, ) -> Self { self.config.node_distance_evaluator = node_distance_evaluator; self } fn with_speculative_execution_policy( mut self, speculative_execution_policy: Box, ) -> Self { self.config.speculative_execution_policy = Some(speculative_execution_policy); self } fn with_transport_buffer_size(mut self, transport_buffer_size: usize) -> Self { self.config.transport_buffer_size = transport_buffer_size; self } fn with_tcp_nodelay(mut self, tcp_nodelay: bool) -> Self { self.config.tcp_nodelay = tcp_nodelay; self } fn with_event_channel_capacity(mut self, event_channel_capacity: usize) -> Self { self.config.event_channel_capacity = event_channel_capacity; self } fn with_connection_pool_config(mut self, connection_pool_config: ConnectionPoolConfig) -> Self { 
self.config.connection_pool_config = connection_pool_config; self } fn with_keyspace(mut self, keyspace: String) -> Self { self.config.keyspace = Some(keyspace); self } fn with_beta_protocol(mut self, beta_protocol: bool) -> Self { self.node_config.beta_protocol = beta_protocol; self } fn build( self, ) -> BoxFuture< 'static, Result, SessionBuildError>, > { async move { match verify_compression_configuration( self.node_config.version, self.config.compression, ) { Ok(()) => { let (keyspace_holder, keyspace_receiver) = create_keyspace_holder(); let connection_manager = TcpConnectionManager::new( self.node_config.authenticator_provider, keyspace_holder.clone(), self.frame_encoder_factory, self.config.compression, self.config.transport_buffer_size, self.config.tcp_nodelay, self.node_config.version, #[cfg(feature = "http-proxy")] self.node_config.http_proxy, ); self.config .into_session( keyspace_holder, keyspace_receiver, self.node_config.contact_points, connection_manager, self.node_config.version, self.node_config.beta_protocol, ) .await } Err(err) => Err(err), } } .boxed() } } #[cfg(feature = "rust-tls")] /// Builder for TLS sessions. pub struct RustlsSessionBuilder< LB: LoadBalancingStrategy + Send + Sync + 'static, > { config: SessionConfig, node_config: NodeRustlsConfig, frame_encoder_factory: Box, } #[cfg(feature = "rust-tls")] impl + Send + Sync> RustlsSessionBuilder { //noinspection DuplicatedCode /// Creates a new builder with default session configuration. 
pub fn new(load_balancing: LB, node_config: NodeRustlsConfig) -> Self { RustlsSessionBuilder { config: SessionConfig::new(load_balancing), node_config, frame_encoder_factory: Box::::default(), } } } #[cfg(feature = "rust-tls")] impl< LB: LoadBalancingStrategy + Send + Sync + 'static, > SessionBuilder for RustlsSessionBuilder { fn with_compression(mut self, compression: Compression) -> Self { self.config.compression = compression; self } fn with_retry_policy(mut self, retry_policy: Box) -> Self { self.config.retry_policy = retry_policy; self } fn with_reconnection_policy( mut self, reconnection_policy: Arc, ) -> Self { self.config.reconnection_policy = reconnection_policy; self } fn with_frame_encoder_factory( mut self, frame_encoder_factory: Box, ) -> Self { self.frame_encoder_factory = frame_encoder_factory; self } fn with_node_distance_evaluator( mut self, node_distance_evaluator: Box, ) -> Self { self.config.node_distance_evaluator = node_distance_evaluator; self } fn with_speculative_execution_policy( mut self, speculative_execution_policy: Box, ) -> Self { self.config.speculative_execution_policy = Some(speculative_execution_policy); self } fn with_transport_buffer_size(mut self, transport_buffer_size: usize) -> Self { self.config.transport_buffer_size = transport_buffer_size; self } fn with_tcp_nodelay(mut self, tcp_nodelay: bool) -> Self { self.config.tcp_nodelay = tcp_nodelay; self } fn with_event_channel_capacity(mut self, event_channel_capacity: usize) -> Self { self.config.event_channel_capacity = event_channel_capacity; self } fn with_connection_pool_config(mut self, connection_pool_config: ConnectionPoolConfig) -> Self { self.config.connection_pool_config = connection_pool_config; self } fn with_keyspace(mut self, keyspace: String) -> Self { self.config.keyspace = Some(keyspace); self } fn with_beta_protocol(mut self, beta_protocol: bool) -> Self { self.node_config.beta_protocol = beta_protocol; self } fn build( self, ) -> BoxFuture< 'static, Result, 
SessionBuildError>, > { async move { match verify_compression_configuration( self.node_config.version, self.config.compression, ) { Ok(()) => { let (keyspace_holder, keyspace_receiver) = create_keyspace_holder(); let connection_manager = RustlsConnectionManager::new( self.node_config.dns_name, self.node_config.authenticator_provider, self.node_config.config, keyspace_holder.clone(), self.frame_encoder_factory, self.config.compression, self.config.transport_buffer_size, self.config.tcp_nodelay, self.node_config.version, #[cfg(feature = "http-proxy")] self.node_config.http_proxy, ); self.config .into_session( keyspace_holder, keyspace_receiver, self.node_config.contact_points, connection_manager, self.node_config.version, self.node_config.beta_protocol, ) .await } Err(err) => Err(err), } } .boxed() } } /// RAII guard that aborts a spawned task if the guard is dropped without /// being explicitly released via [`Self::into_inner`]. /// /// Used during session construction so that if Session::new returns Err /// before reaching the final `Ok(Session { ... })` (which moves the /// JoinHandle into the struct), the spawned background task is cleaned up /// instead of being leaked. tokio's JoinHandle does not abort on drop by /// default, so the dropped handle would otherwise let the task keep /// running for the lifetime of the runtime. struct AbortOnDropHandle(Option>); impl AbortOnDropHandle { fn new(handle: JoinHandle<()>) -> Self { Self(Some(handle)) } /// Releases ownership of the inner JoinHandle. The Drop impl becomes a /// no-op for this guard, so the caller is now responsible for the task's /// lifecycle. 
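/**
Illustrative sketch only, not part of this crate's API: the same release-or-cancel
pattern built from std types, so it runs without a tokio runtime. `FakeHandle` and
`CancelOnDrop` are hypothetical names, and an `AtomicBool` flag stands in (as an
assumption) for tokio's `JoinHandle::abort`. Keeping the handle in an `Option` is what
lets `into_inner` move it out so that the later `Drop` finds `None` and does nothing.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for a task handle: `cancel` flips a shared flag.
struct FakeHandle {
    cancelled: Arc<AtomicBool>,
}

impl FakeHandle {
    fn cancel(&self) {
        self.cancelled.store(true, Ordering::SeqCst);
    }
}

/// Same shape as `AbortOnDropHandle`: an `Option` so the handle can be taken out.
struct CancelOnDrop(Option<FakeHandle>);

impl CancelOnDrop {
    /// Moves the handle out; the guard's `Drop` then sees `None` and is a no-op.
    fn into_inner(mut self) -> FakeHandle {
        self.0.take().expect("guard already released")
    }
}

impl Drop for CancelOnDrop {
    fn drop(&mut self) {
        // Only a guard that still owns its handle cancels it.
        if let Some(handle) = self.0.take() {
            handle.cancel();
        }
    }
}

fn main() {
    // Error path: the guard is dropped while still holding the handle.
    let flag = Arc::new(AtomicBool::new(false));
    drop(CancelOnDrop(Some(FakeHandle { cancelled: flag.clone() })));
    assert!(flag.load(Ordering::SeqCst));

    // Success path: ownership is released, nothing is cancelled.
    let flag = Arc::new(AtomicBool::new(false));
    let handle = CancelOnDrop(Some(FakeHandle { cancelled: flag.clone() })).into_inner();
    assert!(!flag.load(Ordering::SeqCst));
    drop(handle);
    assert!(!flag.load(Ordering::SeqCst));
}
```

Dropping the released `FakeHandle` does not cancel anything because only the guard has a
`Drop` impl, which mirrors how the session takes over the control-connection handle.
*/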
fn into_inner(mut self) -> JoinHandle<()> { self.0 .take() .expect("AbortOnDropHandle inner cannot be None") } } impl Drop for AbortOnDropHandle { fn drop(&mut self) { if let Some(handle) = self.0.take() { handle.abort(); } } } #[cfg(test)] mod tests { use crate::cluster::session::{prepare_flags, AbortOnDropHandle}; use cassandra_protocol::frame::Flags; use tokio::task::JoinHandle; #[test] fn prepare_flags_test() { assert!(prepare_flags(true, false, false).contains(Flags::TRACING)); assert!(prepare_flags(false, true, false).contains(Flags::WARNING)); assert!(prepare_flags(false, false, true).contains(Flags::BETA)); let all = prepare_flags(true, true, true); assert!(all.contains(Flags::TRACING)); assert!(all.contains(Flags::WARNING)); assert!(all.contains(Flags::BETA)); } // The drop guard wraps a JoinHandle so that if Session::new returns // early (e.g. init never completes), the spawned control connection // task is aborted instead of left running indefinitely. #[tokio::test] async fn abort_on_drop_handle_aborts_when_dropped() { // Spawn a task that will run forever unless aborted. Hold an // AbortHandle on the side so we can observe whether the inner // JoinHandle was actually aborted after the guard drops. let task: JoinHandle<()> = tokio::spawn(async { std::future::pending::<()>().await; }); let abort_observer = task.abort_handle(); // Wrap the JoinHandle in a guard, then immediately drop the guard // without releasing it. The task should be aborted. { let _guard = AbortOnDropHandle::new(task); } // Give the runtime a chance to process the abort. tokio::task::yield_now().await; tokio::time::sleep(std::time::Duration::from_millis(10)).await; assert!( abort_observer.is_finished(), "task must be aborted after AbortOnDropHandle drops" ); } // The "happy path" - if the caller releases the guard via into_inner, // the task must NOT be aborted; it has been transferred to the caller's // ownership (typically stored in the constructed Session). 
#[tokio::test] async fn abort_on_drop_handle_does_not_abort_when_released() { let task: JoinHandle<()> = tokio::spawn(async { std::future::pending::<()>().await; }); let abort_observer = task.abort_handle(); let guard = AbortOnDropHandle::new(task); // Release the inner handle - guard's Drop must become a no-op now. let released = guard.into_inner(); // Task should still be running. tokio::task::yield_now().await; assert!( !abort_observer.is_finished(), "task must keep running after into_inner" ); // Cleanup: now we abort it explicitly. released.abort(); } } ================================================ FILE: cdrs-tokio/src/cluster/session_context.rs ================================================ use crate::transport::CdrsTransport; use arc_swap::ArcSwapOption; pub struct SessionContext { pub control_connection_transport: ArcSwapOption, } impl Default for SessionContext { fn default() -> Self { SessionContext { control_connection_transport: ArcSwapOption::empty(), } } } ================================================ FILE: cdrs-tokio/src/cluster/tcp_connection_manager.rs ================================================ use crate::cluster::connection_manager::{startup, ConnectionManager}; #[cfg(feature = "http-proxy")] use crate::cluster::HttpProxyConfig; use crate::cluster::KeyspaceHolder; use crate::frame_encoding::FrameEncodingFactory; use crate::future::BoxFuture; use crate::transport::TransportTcp; #[cfg(feature = "http-proxy")] use async_http_proxy::{http_connect_tokio, http_connect_tokio_with_basic_auth}; use cassandra_protocol::authenticators::SaslAuthenticatorProvider; use cassandra_protocol::compression::Compression; use cassandra_protocol::error::{Error, Result}; use cassandra_protocol::frame::{Envelope, Version}; use futures::FutureExt; use std::io; #[cfg(feature = "http-proxy")] use std::io::ErrorKind; use std::net::SocketAddr; use std::ops::Deref; use std::sync::Arc; #[cfg(feature = "http-proxy")] use tokio::net::TcpStream; use 
tokio::sync::mpsc::Sender; pub struct TcpConnectionManager { authenticator_provider: Arc, keyspace_holder: Arc, frame_encoder_factory: Box, compression: Compression, buffer_size: usize, tcp_nodelay: bool, version: Version, #[cfg(feature = "http-proxy")] http_proxy: Option, } impl ConnectionManager for TcpConnectionManager { //noinspection DuplicatedCode fn connection( &self, event_handler: Option>, error_handler: Option>, addr: SocketAddr, ) -> BoxFuture<'_, Result> { self.establish_connection(event_handler, error_handler, addr) .boxed() } } impl TcpConnectionManager { #[allow(clippy::too_many_arguments)] pub fn new( authenticator_provider: Arc, keyspace_holder: Arc, frame_encoder_factory: Box, compression: Compression, buffer_size: usize, tcp_nodelay: bool, version: Version, #[cfg(feature = "http-proxy")] http_proxy: Option, ) -> Self { Self { authenticator_provider, keyspace_holder, frame_encoder_factory, compression, buffer_size, tcp_nodelay, version, #[cfg(feature = "http-proxy")] http_proxy, } } //noinspection DuplicatedCode #[cfg(feature = "http-proxy")] async fn create_transport( &self, event_handler: Option>, error_handler: Option>, addr: SocketAddr, ) -> io::Result { if let Some(http_proxy) = &self.http_proxy { let mut stream = TcpStream::connect(&http_proxy.address).await?; if let Some(auth) = &http_proxy.basic_auth { http_connect_tokio_with_basic_auth( &mut stream, &addr.ip().to_string(), addr.port(), &auth.username, &auth.password, ) .await .map_err(|error| io::Error::new(ErrorKind::Other, error.to_string()))?; } else { http_connect_tokio(&mut stream, &addr.ip().to_string(), addr.port()) .await .map_err(|error| io::Error::new(ErrorKind::Other, error.to_string()))?; } stream.set_nodelay(self.tcp_nodelay)?; TransportTcp::with_stream( stream, addr, self.keyspace_holder.clone(), event_handler, error_handler, self.compression, self.frame_encoder_factory .create_encoder(self.version, self.compression), self.frame_encoder_factory .create_decoder(self.version, 
self.compression), self.buffer_size, ) } else { TransportTcp::new( addr, self.keyspace_holder.clone(), event_handler, error_handler, self.compression, self.frame_encoder_factory .create_encoder(self.version, self.compression), self.frame_encoder_factory .create_decoder(self.version, self.compression), self.buffer_size, self.tcp_nodelay, ) .await } } //noinspection DuplicatedCode #[cfg(not(feature = "http-proxy"))] async fn create_transport( &self, event_handler: Option>, error_handler: Option>, addr: SocketAddr, ) -> io::Result { TransportTcp::new( addr, self.keyspace_holder.clone(), event_handler, error_handler, self.compression, self.frame_encoder_factory .create_encoder(self.version, self.compression), self.frame_encoder_factory .create_decoder(self.version, self.compression), self.buffer_size, self.tcp_nodelay, ) .await } async fn establish_connection( &self, event_handler: Option>, error_handler: Option>, addr: SocketAddr, ) -> Result { let transport = self .create_transport(event_handler, error_handler, addr) .await?; startup( &transport, self.authenticator_provider.deref(), self.keyspace_holder.deref(), self.compression, self.version, ) .await?; Ok(transport) } } ================================================ FILE: cdrs-tokio/src/cluster/token_map.rs ================================================ use std::collections::BTreeMap; use std::fmt::{Debug, Formatter}; use std::net::SocketAddr; use std::sync::Arc; use crate::cluster::topology::{Node, NodeMap}; use crate::cluster::ConnectionManager; use crate::cluster::Murmur3Token; use crate::transport::CdrsTransport; /// Map of tokens to nodes. 
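/**
Illustrative sketch only, using std types instead of this crate's `Node` and
`Murmur3Token`: a token ring as a `BTreeMap` from token to owner, walked from a given
token towards higher tokens and wrapping around once, keeping only distinct owners the
way `nodes_for_token_capped` does. `replicas_for` and the string node names are
hypothetical.

```rust
use std::collections::{BTreeMap, HashSet};

/// Walks the ring from `token` towards higher tokens, wrapping around once,
/// and returns up to `replica_count` distinct owners.
fn replicas_for(
    ring: &BTreeMap<i64, &'static str>,
    token: i64,
    replica_count: usize,
) -> Vec<&'static str> {
    let mut seen = HashSet::new();
    ring.range(token..)
        .chain(ring.iter())
        // keep a node only the first time it is seen (vnodes repeat owners)
        .filter_map(|(_, node)| seen.insert(*node).then_some(*node))
        // `take` after the dedup counts distinct nodes, not ring positions
        .take(replica_count)
        .collect()
}

fn main() {
    // Three nodes owning interleaved vnode tokens.
    let ring: BTreeMap<i64, &'static str> = BTreeMap::from([
        (-30, "a"),
        (-20, "b"),
        (-10, "a"),
        (0, "c"),
        (10, "b"),
        (20, "a"),
    ]);
    assert_eq!(replicas_for(&ring, -10, 2), vec!["a", "c"]);
    // starts at token 20, then wraps to the beginning of the ring
    assert_eq!(replicas_for(&ring, 15, 3), vec!["a", "b", "c"]);
    // asking for more replicas than distinct nodes returns every node once
    assert_eq!(replicas_for(&ring, 25, 5), vec!["a", "b", "c"]);
}
```

Placing `take` after the dedup step bounds the walk by distinct nodes, so iteration stops
as soon as enough replicas have been collected.
*/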
pub struct TokenMap + 'static> { token_ring: BTreeMap>>, } impl> Clone for TokenMap { fn clone(&self) -> Self { TokenMap { token_ring: self.token_ring.clone(), } } } impl> Debug for TokenMap { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("TokenMap") .field("token_ring", &self.token_ring) .finish() } } impl> Default for TokenMap { fn default() -> Self { TokenMap { token_ring: Default::default(), } } } impl> TokenMap { pub fn new(nodes: &NodeMap) -> Self { TokenMap { token_ring: nodes .iter() .flat_map(|(_, node)| { node.tokens() .iter() .map(move |token| (*token, node.clone())) }) .collect(), } } /// Returns up to `replica_count` distinct nodes starting at the given token /// and walking the token ring in the direction of replicas. /// /// With vnodes, a single physical node owns many tokens, so a naive walk /// can return the same node many times in a row. Cassandra's replication /// algorithm picks DISTINCT endpoints, so we dedup on each node's /// broadcast RPC address as we walk and stop once we've collected enough. pub fn nodes_for_token_capped( &self, token: Murmur3Token, replica_count: usize, ) -> impl Iterator>> + '_ { // Iterate ring positions starting from `token` and wrap around. // `replica_count` counts unique nodes, not ring positions, so `take` is // applied after the dedup filter: the walk then stops as soon as enough // distinct nodes have been yielded. // // The set tracks broadcast RPC addresses we've already yielded; this // matches the identity used elsewhere in the load balancer // (`unique_by(|node| node.broadcast_rpc_address())`). let mut seen = std::collections::HashSet::with_capacity(replica_count); self.token_ring .range(token..)
.chain(self.token_ring.iter()) .filter_map(move |(_, node)| { if seen.insert(node.broadcast_rpc_address()) { Some(node.clone()) } else { None } }) .take(replica_count) } /// Returns all distinct nodes starting at the given token and walking the /// token ring in the direction of replicas. /// /// As with `nodes_for_token_capped`, the dedup is on broadcast RPC /// address so a node owning many vnode tokens still appears only once. pub fn nodes_for_token( &self, token: Murmur3Token, ) -> impl Iterator>> + '_ { let mut seen = std::collections::HashSet::new(); self.token_ring .range(token..) .chain(self.token_ring.iter()) .filter_map(move |(_, node)| { if seen.insert(node.broadcast_rpc_address()) { Some(node.clone()) } else { None } }) } /// Creates a new map with a new node inserted. #[must_use] pub fn clone_with_node(&self, node: Arc>) -> Self { let mut map = self.clone(); for token in node.tokens() { map.token_ring.insert(*token, node.clone()); } map } /// Creates a new map with a node removed.
#[must_use] pub fn clone_without_node(&self, broadcast_rpc_address: SocketAddr) -> Self { let token_ring = self .token_ring .iter() .filter_map(|(token, node)| { if node.broadcast_rpc_address() == broadcast_rpc_address { None } else { Some((*token, node.clone())) } }) .collect(); TokenMap { token_ring } } } #[cfg(test)] mod tests { use cassandra_protocol::frame::Version; use itertools::Itertools; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::{Arc, LazyLock}; use tokio::sync::watch; use uuid::Uuid; use crate::cluster::connection_manager::MockConnectionManager; use crate::cluster::connection_pool::ConnectionPoolFactory; use crate::cluster::topology::{Node, NodeMap}; use crate::cluster::Murmur3Token; use crate::cluster::TokenMap; use crate::retry::MockReconnectionPolicy; use crate::transport::MockCdrsTransport; static HOST_ID_1: LazyLock = LazyLock::new(Uuid::new_v4); static HOST_ID_2: LazyLock = LazyLock::new(Uuid::new_v4); static HOST_ID_3: LazyLock = LazyLock::new(Uuid::new_v4); fn prepare_nodes() -> NodeMap> { let (_, keyspace_receiver) = watch::channel(None); let connection_manager = MockConnectionManager::::new(); let reconnection_policy = MockReconnectionPolicy::new(); let connection_pool_factory = Arc::new(ConnectionPoolFactory::new( Default::default(), Version::V4, connection_manager, keyspace_receiver, Arc::new(reconnection_policy), )); // each node gets a distinct broadcast RPC address so that dedup by // endpoint inside the token map can distinguish them let mut nodes = NodeMap::default(); nodes.insert( *HOST_ID_1, Arc::new(Node::new( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), None, Some(*HOST_ID_1), None, vec![ Murmur3Token::new(-2), Murmur3Token::new(-1), Murmur3Token::new(0), ], "".into(), "".into(), )), ); nodes.insert( *HOST_ID_2, Arc::new(Node::new( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 8080), None, Some(*HOST_ID_2), None, 
vec![Murmur3Token::new(20)], "".into(), "".into(), )), ); nodes.insert( *HOST_ID_3, Arc::new(Node::new( connection_pool_factory, SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3)), 8080), None, Some(*HOST_ID_3), None, vec![ Murmur3Token::new(2), Murmur3Token::new(1), Murmur3Token::new(10), ], "".into(), "".into(), )), ); nodes } fn verify_tokens(host_ids: &[Uuid], token: Murmur3Token) { let token_map = TokenMap::new(&prepare_nodes()); let nodes = token_map .nodes_for_token_capped(token, host_ids.len()) .collect_vec(); assert_eq!(nodes.len(), host_ids.len()); for (index, node) in nodes.iter().enumerate() { assert_eq!(node.host_id().unwrap(), host_ids[index]); } } #[test] fn should_return_replicas_in_order() { // ring walk from token 0 visits HOST_ID_1's primary token (0), then // HOST_ID_3's tokens (1, 2, 10), then HOST_ID_2 (20). Replicas must be // distinct so we expect each node exactly once. verify_tokens(&[*HOST_ID_1, *HOST_ID_3, *HOST_ID_2], Murmur3Token::new(0)); } #[test] fn should_return_replicas_in_order_for_non_primary_token() { // no node owns token 3, so the walk starts at the next token (10, owned by // HOST_ID_3) and then reaches HOST_ID_2 (20). verify_tokens(&[*HOST_ID_3, *HOST_ID_2], Murmur3Token::new(3)); } // Cassandra's replication algorithm picks replicas as DISTINCT endpoints // walking the ring. With vnodes, a single physical node owns many tokens // and the same node can occupy several consecutive ring positions, so // simply taking the first N tokens would return the same physical replica // multiple times - which is wrong: a "replica" is a copy on a different // node, not on the same one. #[test] fn should_return_distinct_nodes_as_replicas() { let token_map = TokenMap::new(&prepare_nodes()); // there are only 3 distinct nodes in the ring - asking for 5 replicas // must yield those 3 nodes once each, not the same node padded out.
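The distinct-replica ring walk that these tests exercise can be sketched in isolation with only the standard library. This is a simplified model, not the crate's `TokenMap` API: tokens are plain `i64` values standing in for `Murmur3Token`, owners are string labels standing in for nodes, and `distinct_replicas` and `sample_ring` are hypothetical helpers whose ring mirrors `prepare_nodes`.

```rust
use std::collections::{BTreeMap, HashSet};

/// Walks the ring clockwise from `start`, wrapping around once, and returns
/// up to `count` distinct owners in ring order.
fn distinct_replicas<'a>(ring: &'a BTreeMap<i64, &'a str>, start: i64, count: usize) -> Vec<&'a str> {
    let mut seen = HashSet::new();
    ring.range(start..)
        .chain(ring.iter()) // wrap around to the lowest tokens
        .map(|(_, owner)| *owner)
        .filter(|owner| seen.insert(*owner)) // keep only the first sighting of each owner
        .take(count) // counts distinct owners, so the walk can stop early
        .collect()
}

/// Same layout as the unit tests: "a" owns -2, -1, 0; "c" owns 1, 2, 10; "b" owns 20.
fn sample_ring() -> BTreeMap<i64, &'static str> {
    let mut ring = BTreeMap::new();
    for token in [-2, -1, 0] {
        ring.insert(token, "a");
    }
    for token in [1, 2, 10] {
        ring.insert(token, "c");
    }
    ring.insert(20, "b");
    ring
}

fn main() {
    let ring = sample_ring();
    println!("{:?}", distinct_replicas(&ring, 0, 5)); // ["a", "c", "b"]
    println!("{:?}", distinct_replicas(&ring, 20, 2)); // ["b", "a"]
}
```

Because `take` sits after the dedup filter, it counts distinct owners rather than ring positions, so the walk can stop before visiting the whole ring.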
let nodes = token_map .nodes_for_token_capped(Murmur3Token::new(0), 5) .collect_vec(); let host_ids: Vec<_> = nodes.iter().map(|n| n.host_id().unwrap()).collect(); assert_eq!(host_ids, vec![*HOST_ID_1, *HOST_ID_3, *HOST_ID_2]); } #[test] fn should_cap_replica_count_when_smaller_than_distinct_nodes() { let token_map = TokenMap::new(&prepare_nodes()); // when fewer replicas than distinct nodes are requested, the iterator // should stop at exactly that count (no need to enumerate the rest). let nodes = token_map .nodes_for_token_capped(Murmur3Token::new(0), 2) .collect_vec(); let host_ids: Vec<_> = nodes.iter().map(|n| n.host_id().unwrap()).collect(); assert_eq!(host_ids, vec![*HOST_ID_1, *HOST_ID_3]); } #[test] fn should_return_replicas_in_a_ring() { // starting at token 20 we hit HOST_ID_2 first, then wrap around to the // smaller tokens which belong to HOST_ID_1, then to HOST_ID_3. verify_tokens(&[*HOST_ID_2, *HOST_ID_1, *HOST_ID_3], Murmur3Token::new(20)); } } ================================================ FILE: cdrs-tokio/src/cluster/topology/cluster_metadata.rs ================================================ use fxhash::FxHashMap; use itertools::Itertools; use std::net::SocketAddr; use std::sync::Arc; use uuid::Uuid; use crate::cluster::topology::keyspace_metadata::KeyspaceMetadata; use crate::cluster::topology::node::Node; use crate::cluster::topology::{DatacenterMetadata, NodeMap}; use crate::cluster::{ConnectionManager, TokenMap}; use crate::transport::CdrsTransport; fn build_datacenter_info>( nodes: &NodeMap, ) -> FxHashMap { let grouped_by_dc = nodes .values() .sorted_unstable_by_key(|node| node.datacenter()) .chunk_by(|node| node.datacenter()); (&grouped_by_dc) .into_iter() .map(|(dc, nodes)| { ( dc.into(), DatacenterMetadata::new(nodes.unique_by(|node| node.rack()).count()), ) }) .collect() } /// Immutable metadata of the Cassandra cluster that this driver instance is connected to. 
#[derive(Debug, Clone)] pub struct ClusterMetadata + 'static> { nodes: NodeMap, token_map: TokenMap, keyspaces: FxHashMap, datacenters: FxHashMap, } impl> ClusterMetadata { pub fn new(nodes: NodeMap, keyspaces: FxHashMap) -> Self { let token_map = TokenMap::new(&nodes); let datacenters = build_datacenter_info(&nodes); ClusterMetadata { nodes, token_map, keyspaces, datacenters, } } /// Returns the current token map. #[inline] pub fn token_map(&self) -> &TokenMap { &self.token_map } /// Creates new metadata with a keyspace added or replaced. #[must_use] pub fn clone_with_keyspace(&self, keyspace_name: String, keyspace: KeyspaceMetadata) -> Self { let mut keyspaces = self.keyspaces.clone(); keyspaces.insert(keyspace_name, keyspace); ClusterMetadata { nodes: self.nodes.clone(), token_map: self.token_map.clone(), keyspaces, datacenters: self.datacenters.clone(), } } /// Creates new metadata with a keyspace removed. #[must_use] pub fn clone_without_keyspace(&self, keyspace: &str) -> Self { let mut keyspaces = self.keyspaces.clone(); keyspaces.remove(keyspace); ClusterMetadata { nodes: self.nodes.clone(), token_map: self.token_map.clone(), keyspaces, datacenters: self.datacenters.clone(), } } /// Creates new metadata with a node added or replaced. The node must have a host id. #[must_use] pub fn clone_with_node(&self, node: Node) -> Self { let node = Arc::new(node); let token_map = self.token_map.clone_with_node(node.clone()); let mut nodes = self.nodes.clone(); nodes.insert( node.host_id().expect("Adding a node without host id!"), node, ); let datacenters = build_datacenter_info(&nodes); ClusterMetadata { nodes, token_map, keyspaces: self.keyspaces.clone(), datacenters, } } /// Creates new metadata with a node removed.
#[must_use] pub fn clone_without_node(&self, broadcast_rpc_address: SocketAddr) -> Self { let nodes = self .nodes .iter() .filter_map(|(host_id, node)| { if node.broadcast_rpc_address() != broadcast_rpc_address { Some((*host_id, node.clone())) } else { None } }) .collect(); Self::new(nodes, self.keyspaces.clone()) } /// Returns all known nodes. #[inline] pub fn nodes(&self) -> &NodeMap { &self.nodes } /// Returns known keyspaces. #[inline] pub fn keyspaces(&self) -> &FxHashMap { &self.keyspaces } /// Returns known keyspace, if present. #[inline] pub fn keyspace(&self, keyspace: &str) -> Option<&KeyspaceMetadata> { self.keyspaces.get(keyspace) } /// Returns known datacenters. #[inline] pub fn datacenters(&self) -> &FxHashMap { &self.datacenters } /// Returns known datacenter, if present. #[inline] pub fn datacenter(&self, name: &str) -> Option<&DatacenterMetadata> { self.datacenters.get(name) } /// Returns nodes that are not ignored for load balancing. #[inline] pub fn unignored_nodes(&self) -> Vec>> { self.nodes .iter() .filter_map(|(_, node)| { if node.is_ignored() { None } else { Some(node.clone()) } }) .collect() } /// Returns nodes that are not ignored for load balancing and are local. #[inline] pub fn unignored_local_nodes(&self) -> Vec>> { self.nodes .iter() .filter_map(|(_, node)| { if node.is_ignored() || !node.is_local() { None } else { Some(node.clone()) } }) .collect() } /// Returns up to `max_count` nodes that are not ignored for load balancing and are remote. #[inline] pub fn unignored_remote_nodes_capped(&self, max_count: usize) -> Vec>> { self.nodes .iter() .filter_map(|(_, node)| { if node.is_ignored() || !node.is_remote() { None } else { Some(node.clone()) } }) .take(max_count) .collect() } /// Checks if any nodes are known. #[inline] pub fn has_nodes(&self) -> bool { !self.nodes.is_empty() } /// Checks if a node with a given address is present.
#[inline] pub fn has_node_by_rpc_address(&self, broadcast_rpc_address: SocketAddr) -> bool { self.nodes .iter() .any(|(_, node)| node.broadcast_rpc_address() == broadcast_rpc_address) } /// Finds a node by its address. #[inline] pub fn find_node_by_rpc_address( &self, broadcast_rpc_address: SocketAddr, ) -> Option>> { self.nodes .iter() .find(|(_, node)| node.broadcast_rpc_address() == broadcast_rpc_address) .map(|(_, node)| node.clone()) } /// Finds a node by its host id. #[inline] pub fn find_node_by_host_id(&self, host_id: &Uuid) -> Option>> { self.nodes.get(host_id).cloned() } } impl> Default for ClusterMetadata { fn default() -> Self { ClusterMetadata { nodes: Default::default(), token_map: Default::default(), keyspaces: Default::default(), datacenters: Default::default(), } } } //noinspection DuplicatedCode #[cfg(test)] mod tests { use cassandra_protocol::frame::Version; use fxhash::FxHashMap; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; use tokio::sync::watch; use uuid::Uuid; use crate::cluster::connection_manager::MockConnectionManager; use crate::cluster::connection_pool::ConnectionPoolFactory; use crate::cluster::topology::cluster_metadata::build_datacenter_info; use crate::cluster::topology::Node; use crate::retry::MockReconnectionPolicy; use crate::transport::MockCdrsTransport; #[test] fn should_build_datacenter_info() { let (_, keyspace_receiver) = watch::channel(None); let connection_manager = MockConnectionManager::::new(); let reconnection_policy = MockReconnectionPolicy::new(); let connection_pool_factory = Arc::new(ConnectionPoolFactory::new( Default::default(), Version::V4, connection_manager, keyspace_receiver, Arc::new(reconnection_policy), )); let mut nodes = FxHashMap::default(); nodes.insert( Uuid::new_v4(), Arc::new(Node::new( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 8080), None, None, None, Default::default(), "r1".into(), "dc1".into(), )), ); nodes.insert( 
Uuid::new_v4(), Arc::new(Node::new( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 8080), None, None, None, Default::default(), "r1".into(), "dc1".into(), )), ); nodes.insert( Uuid::new_v4(), Arc::new(Node::new( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 8080), None, None, None, Default::default(), "r2".into(), "dc1".into(), )), ); nodes.insert( Uuid::new_v4(), Arc::new(Node::new( connection_pool_factory, SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 8080), None, None, None, Default::default(), "r1".into(), "dc2".into(), )), ); let dc_info = build_datacenter_info(&nodes); assert_eq!(dc_info.get("dc1").unwrap().rack_count, 2); assert_eq!(dc_info.get("dc2").unwrap().rack_count, 1); } } ================================================ FILE: cdrs-tokio/src/cluster/topology/datacenter_metadata.rs ================================================ use derive_more::Constructor; /// Information about a datacenter. #[derive(Clone, Debug, Constructor)] pub struct DatacenterMetadata { pub rack_count: usize, } ================================================ FILE: cdrs-tokio/src/cluster/topology/keyspace_metadata.rs ================================================ use derive_more::Constructor; use crate::cluster::topology::ReplicationStrategy; /// Keyspace metadata. 
#[derive(Clone, Debug, Constructor)] pub struct KeyspaceMetadata { pub replication_strategy: ReplicationStrategy, } ================================================ FILE: cdrs-tokio/src/cluster/topology/node.rs ================================================ use atomic::Atomic; use cassandra_protocol::error::{Error, Result}; use cassandra_protocol::frame::Envelope; use std::fmt::{Debug, Formatter}; use std::net::SocketAddr; use std::sync::atomic::Ordering; use std::sync::Arc; use tokio::sync::mpsc::Sender; use tokio::sync::OnceCell; use tracing::*; use uuid::Uuid; use crate::cluster::connection_pool::{ConnectionPool, ConnectionPoolFactory}; use crate::cluster::topology::{NodeDistance, NodeState}; use crate::cluster::Murmur3Token; use crate::cluster::{ConnectionManager, NodeInfo}; use crate::transport::CdrsTransport; /// Metadata about a Cassandra node in the cluster, along with a connection. pub struct Node + 'static> { connection_pool_factory: Arc>, connection_pool: OnceCell>>, broadcast_rpc_address: SocketAddr, broadcast_address: Option, distance: Option, state: Atomic, host_id: Option, tokens: Vec, rack: String, datacenter: String, } impl> Debug for Node { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("Node") .field("broadcast_rpc_address", &self.broadcast_rpc_address) .field("broadcast_address", &self.broadcast_address) .field("distance", &self.distance) .field("state", &self.state) .field("host_id", &self.host_id) .field("tokens", &self.tokens) .field("rack", &self.rack) .field("datacenter", &self.datacenter) .finish() } } impl> Node { #[allow(clippy::too_many_arguments)] pub(crate) fn new( connection_pool_factory: Arc>, broadcast_rpc_address: SocketAddr, broadcast_address: Option, host_id: Option, distance: Option, tokens: Vec, rack: String, datacenter: String, ) -> Self { Self { connection_pool_factory, connection_pool: Default::default(), broadcast_rpc_address, broadcast_address, distance, state: Atomic::new(NodeState::Unknown), 
host_id, tokens, rack, datacenter, } } #[allow(clippy::too_many_arguments)] pub(crate) fn new_with_state( connection_pool_factory: Arc>, broadcast_rpc_address: SocketAddr, broadcast_address: Option, host_id: Option, distance: Option, state: NodeState, tokens: Vec, rack: String, datacenter: String, ) -> Self { Self { connection_pool_factory, connection_pool: Default::default(), broadcast_rpc_address, broadcast_address, distance, state: Atomic::new(state), host_id, tokens, rack, datacenter, } } #[allow(clippy::too_many_arguments)] pub(crate) fn with_state( connection_pool_factory: Arc>, broadcast_rpc_address: SocketAddr, broadcast_address: Option, host_id: Option, state: NodeState, tokens: Vec, rack: String, datacenter: String, ) -> Self { Self { connection_pool_factory, connection_pool: Default::default(), broadcast_rpc_address, broadcast_address, distance: None, state: Atomic::new(state), host_id, tokens, rack, datacenter, } } #[cfg(test)] pub(crate) fn with_distance( connection_pool_factory: Arc>, broadcast_rpc_address: SocketAddr, broadcast_address: Option, host_id: Option, distance: NodeDistance, ) -> Self { Self { connection_pool_factory, connection_pool: Default::default(), broadcast_rpc_address, broadcast_address, distance: Some(distance), state: Atomic::new(NodeState::Unknown), host_id, tokens: Default::default(), rack: Default::default(), datacenter: Default::default(), } } /// Returns the current state of the node, as seen by the driver. #[inline] pub fn state(&self) -> NodeState { self.state.load(Ordering::Relaxed) } /// The host ID that is assigned to this node by Cassandra. This value can be used to uniquely /// identify a node even when the underlying IP address changes. #[inline] pub fn host_id(&self) -> Option { self.host_id } /// The node's broadcast RPC address. That is, the address that the node expects clients to /// connect to. #[inline] pub fn broadcast_rpc_address(&self) -> SocketAddr { self.broadcast_rpc_address } /// The node's broadcast address.
That is, the address that other nodes use to communicate with /// that node. #[inline] pub fn broadcast_address(&self) -> Option { self.broadcast_address } /// Returns tokens associated with the node. #[inline] pub fn tokens(&self) -> &[Murmur3Token] { &self.tokens } /// Returns the datacenter the node is in. #[inline] pub fn datacenter(&self) -> &str { &self.datacenter } /// Returns the rack the node is in. #[inline] pub fn rack(&self) -> &str { &self.rack } /// Returns a pooled connection to the node, creating the connection pool on first use. #[inline] pub async fn persistent_connection(self: &Arc) -> Result> { let pool = self .connection_pool .get_or_try_init(|| { debug!(?self.host_id, "Creating connection pool"); self.connection_pool_factory.create( self.distance.unwrap_or(NodeDistance::Remote), self.broadcast_rpc_address, Arc::downgrade(self), ) }) .await; let pool = match pool { Ok(pool) => pool, Err(Error::InvalidProtocol(addr)) => { // we can't connect to this node even if it's up self.force_down(); return Err(Error::InvalidProtocol(addr)); } Err(error) => return Err(error), }; pool.connection().await } /// Checks if any connection is still available. pub async fn is_any_connection_up(&self) -> bool { if let Some(pool) = self.connection_pool.get() { pool.is_any_connection_up().await } else { false } } /// Creates a new connection to the node with optional event and error handlers. pub async fn new_connection( &self, event_handler: Option>, error_handler: Option>, ) -> Result { debug!("Establishing new connection to node..."); self.connection_pool_factory .connection_manager() .connection(event_handler, error_handler, self.broadcast_rpc_address) .await } /// Returns node distance in relation to the driver, if available. #[inline] pub fn distance(&self) -> Option { self.distance } /// Checks if the node is local in relation to the driver. #[inline] pub fn is_local(&self) -> bool { self.distance == Some(NodeDistance::Local) } /// Checks if the node is remote in relation to the driver.
#[inline] pub fn is_remote(&self) -> bool { self.distance == Some(NodeDistance::Remote) } /// Checks if the node should be excluded when establishing connections, i.e. its distance is unknown or it is not `Up`. #[inline] pub fn is_ignored(&self) -> bool { self.distance.is_none() || self.state.load(Ordering::Relaxed) != NodeState::Up } pub(crate) fn force_down(&self) { self.state.store(NodeState::ForcedDown, Ordering::Relaxed); } pub(crate) fn mark_down(&self) { self.state.store(NodeState::Down, Ordering::Relaxed); } pub(crate) fn mark_up(&self) { self.state.store(NodeState::Up, Ordering::Relaxed); } #[inline] pub(crate) fn clone_with_node_info(&self, node_info: NodeInfo) -> Self { let address_changed = self .broadcast_address .map(|address| address != node_info.broadcast_rpc_address) // if we don't know the previous address, we'll trust whoever inserted the node to // know its state .unwrap_or(false); let mut new_node_state = self.state.load(Ordering::Relaxed); if address_changed { new_node_state = NodeState::Unknown; } // If we recreate the node with the status Down, it will be removed from the load-balancing strategy. // The only way to promote the node back to the Up state is to receive an error in // connection_pool.rs::monitor_connections and schedule a reconnection. That function // runs only after persistent_connection has been called on the node, which means the node // must have been part of the load-balancing strategy at least once.
// Unknown and ForcedDown need no special handling here, because they are taken care of // on topology events in control_connection.rs::process_events or in load balancers if new_node_state == NodeState::Down { debug!( ?node_info.broadcast_rpc_address, "Cloned the node with Down state", ); new_node_state = NodeState::Up; } Self { connection_pool_factory: self.connection_pool_factory.clone(), connection_pool: Default::default(), broadcast_rpc_address: node_info.broadcast_rpc_address, broadcast_address: node_info.broadcast_address, // since address could change, we can't be sure of distance or state distance: if address_changed { None } else { self.distance }, state: Atomic::new(new_node_state), host_id: Some(node_info.host_id), tokens: node_info.tokens, rack: node_info.rack, datacenter: node_info.datacenter, } } #[inline] pub(crate) fn clone_as_contact_point(&self, node_info: NodeInfo) -> Self { // contact points might have a valid state already, so no need to reset it Self { connection_pool_factory: self.connection_pool_factory.clone(), connection_pool: self.connection_pool.clone(), broadcast_rpc_address: self.broadcast_rpc_address, broadcast_address: node_info.broadcast_address, distance: self.distance, state: Atomic::new(self.state.load(Ordering::Relaxed)), host_id: Some(node_info.host_id), tokens: node_info.tokens, rack: node_info.rack, datacenter: node_info.datacenter, } } #[inline] pub(crate) fn clone_with_node_info_and_state( &self, node_info: NodeInfo, state: NodeState, ) -> Self { Self { connection_pool_factory: self.connection_pool_factory.clone(), connection_pool: Default::default(), broadcast_rpc_address: node_info.broadcast_rpc_address, broadcast_address: node_info.broadcast_address, // since address could change, we can't be sure of distance distance: None, state: Atomic::new(state), host_id: Some(node_info.host_id), tokens: node_info.tokens, rack: node_info.rack, datacenter: node_info.datacenter, } } #[inline] pub(crate) fn clone_with_node_state(&self,
state: NodeState) -> Self { Self { connection_pool_factory: self.connection_pool_factory.clone(), connection_pool: Default::default(), broadcast_rpc_address: self.broadcast_rpc_address, broadcast_address: self.broadcast_address, distance: self.distance, state: Atomic::new(state), host_id: self.host_id, tokens: self.tokens.clone(), rack: self.rack.clone(), datacenter: self.datacenter.clone(), } } } ================================================ FILE: cdrs-tokio/src/cluster/topology/node_distance.rs ================================================ use derive_more::Display; /// Determines how the driver will manage connections to a Cassandra node. #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Display)] pub enum NodeDistance { /// An "active" distance that indicates that the driver should maintain connections to the /// node; it also marks it as "preferred", meaning that the node may have priority for /// some tasks (for example, being chosen as the control connection host). Local, /// An "active" distance that indicates that the driver should maintain connections to the /// node; it also marks it as "less preferred", meaning that other nodes may have a higher /// priority for some tasks (for example, being chosen as the control connection host). Remote, } ================================================ FILE: cdrs-tokio/src/cluster/topology/node_state.rs ================================================ use bytemuck::NoUninit; use derive_more::Display; /// The state of a node, as viewed from the driver. #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Display, NoUninit)] #[repr(u8)] pub enum NodeState { /// The driver has never tried to connect to the node, nor received any topology events about it. /// /// This happens when nodes are first added to the cluster, and will persist if your /// [`LoadBalancingStrategy`](crate::load_balancing::LoadBalancingStrategy) decides to ignore /// them.
Since the driver does not connect to them, the only way it can assess their state is /// from topology events. Unknown, /// A node is considered up in either of the following situations: 1) the driver has at least /// one active connection to the node, or 2) the driver is not actively trying to connect to the /// node (because it's ignored by the /// [`LoadBalancingStrategy`](crate::load_balancing::LoadBalancingStrategy)), but it has /// received a topology event indicating that the node is up. Up, /// A node is considered down in either of the following situations: 1) the driver has lost all /// connections to the node (and is currently trying to reconnect), or 2) the driver is not /// actively trying to connect to the node (because it's ignored by the /// [`LoadBalancingStrategy`](crate::load_balancing::LoadBalancingStrategy)), but it has received /// a topology event indicating that the node is down. Down, /// The node was forced down externally; the driver will never try to reconnect to it. This can /// happen when an unrecoverable error occurs while connecting to the node (e.g. an invalid /// protocol version) or when another component decides the node should not be contacted (e.g. /// [`LoadBalancingStrategy`](crate::load_balancing::LoadBalancingStrategy)). ForcedDown, } ================================================ FILE: cdrs-tokio/src/cluster/topology/replication_strategy.rs ================================================ use fxhash::FxHashMap; /// A replication strategy determines the nodes where replicas are placed.
#[derive(Debug, Clone)] pub enum ReplicationStrategy { SimpleStrategy { replication_factor: usize, }, NetworkTopologyStrategy { datacenter_replication_factor: FxHashMap, }, Other, } ================================================ FILE: cdrs-tokio/src/cluster/topology.rs ================================================ use std::sync::Arc; use fxhash::FxHashMap; use uuid::Uuid; pub mod cluster_metadata; mod datacenter_metadata; mod keyspace_metadata; mod node; mod node_distance; mod node_state; mod replication_strategy; pub use self::datacenter_metadata::DatacenterMetadata; pub use self::keyspace_metadata::KeyspaceMetadata; pub use self::node::Node; pub use self::node_distance::NodeDistance; pub use self::node_state::NodeState; pub use self::replication_strategy::ReplicationStrategy; /// Map from host id to a node. pub type NodeMap = FxHashMap>>; ================================================ FILE: cdrs-tokio/src/cluster.rs ================================================ pub(crate) use self::cluster_metadata_manager::ClusterMetadataManager; #[cfg(feature = "http-proxy")] pub use self::config_proxy::{HttpProxyConfig, HttpProxyConfigBuilder}; #[cfg(feature = "rust-tls")] pub use self::config_rustls::{NodeRustlsConfig, NodeRustlsConfigBuilder}; pub use self::config_tcp::{NodeTcpConfig, NodeTcpConfigBuilder}; pub use self::connection_manager::{startup, ConnectionManager}; pub use self::keyspace_holder::KeyspaceHolder; pub use self::node_address::NodeAddress; pub use self::node_info::NodeInfo; pub use self::pager::{ExecPager, PagerState, QueryPager, SessionPager}; #[cfg(feature = "rust-tls")] pub use self::rustls_connection_manager::RustlsConnectionManager; pub use self::session::connect_generic; pub(crate) use self::session_context::SessionContext; pub use self::tcp_connection_manager::TcpConnectionManager; pub use self::token_map::TokenMap; pub use self::topology::cluster_metadata::ClusterMetadata; use crate::cluster::connection_pool::ConnectionPoolConfig; use 
crate::future::BoxFuture; use crate::transport::CdrsTransport; use cassandra_protocol::error; use cassandra_protocol::frame::Version; pub use cassandra_protocol::token::Murmur3Token; use std::sync::Arc; mod cluster_metadata_manager; #[cfg(feature = "http-proxy")] mod config_proxy; #[cfg(feature = "rust-tls")] mod config_rustls; mod config_tcp; #[cfg(not(test))] mod connection_manager; #[cfg(test)] pub mod connection_manager; pub mod connection_pool; mod control_connection; mod keyspace_holder; mod metadata_builder; mod node_address; mod node_info; mod pager; #[cfg(feature = "rust-tls")] mod rustls_connection_manager; pub mod send_envelope; pub mod session; mod session_context; mod tcp_connection_manager; mod token_map; pub mod topology; /// Generic connection configuration trait that can be used to create user-supplied /// connection objects that can be used with the `session::connect()` function. pub trait GenericClusterConfig>: Send + Sync { fn create_manager( &self, keyspace_holder: Arc, ) -> BoxFuture<'_, error::Result>; /// Returns desired event channel capacity. Take a look at /// [`Session`](session::Session) builders for more info. fn event_channel_capacity(&self) -> usize; /// Cassandra protocol version to use. fn version(&self) -> Version; /// Connection pool configuration. fn connection_pool_config(&self) -> ConnectionPoolConfig; /// Enable beta protocol support. 
fn beta_protocol(&self) -> bool { false } } ================================================ FILE: cdrs-tokio/src/envelope_parser.rs ================================================ use std::convert::TryFrom; use std::io; use std::io::Cursor; use std::net::SocketAddr; use tokio::io::AsyncReadExt; use cassandra_protocol::compression::Compression; use cassandra_protocol::error; use cassandra_protocol::frame::message_response::ResponseBody; use cassandra_protocol::frame::{ Direction, Envelope, Flags, Opcode, Version, LENGTH_LEN, STREAM_LEN, }; use cassandra_protocol::types::data_serialization_types::decode_timeuuid; use cassandra_protocol::types::{ from_cursor_string_list, try_i16_from_bytes, try_i32_from_bytes, UUID_LEN, }; // Upper bound for an envelope body: the native protocol limits a frame to // 256 MiB, and pre-V5 (unframed) connections may legitimately send bodies up to // that size. Enforcing it prevents a hostile or malfunctioning server from making // us allocate gigabytes from a single 4-byte length field. const MAX_ENVELOPE_BODY_SIZE: usize = 256 * 1024 * 1024; async fn parse_raw_envelope( cursor: &mut T, compressor: Compression, ) -> error::Result { let mut version_bytes = [0; Version::BYTE_LENGTH]; let mut flag_bytes = [0; Flags::BYTE_LENGTH]; let mut opcode_bytes = [0; Opcode::BYTE_LENGTH]; let mut stream_bytes = [0; STREAM_LEN]; let mut length_bytes = [0; LENGTH_LEN]; // NOTE: order of reads matters; it follows the header layout: version, flags, stream, opcode, length cursor.read_exact(&mut version_bytes).await?; cursor.read_exact(&mut flag_bytes).await?; cursor.read_exact(&mut stream_bytes).await?; cursor.read_exact(&mut opcode_bytes).await?; cursor.read_exact(&mut length_bytes).await?; let version = Version::try_from(version_bytes[0])?; let direction = Direction::from(version_bytes[0]); let flags = Flags::from_bits_truncate(flag_bytes[0]); let stream_id = try_i16_from_bytes(&stream_bytes)?; let opcode = Opcode::try_from(opcode_bytes[0])?; // The wire format encodes the body length as a signed 32-bit int.
Without // validation a negative value would wrap around to a multi-gigabyte usize, // and even a valid but very large positive value would allocate up to // 2 GiB before reading any body bytes. Reject both before allocating. let length_signed = try_i32_from_bytes(&length_bytes)?; if length_signed < 0 { return Err(error::Error::Io(io::Error::new( io::ErrorKind::InvalidData, format!("negative envelope body length {length_signed}"), ))); } let length = length_signed as usize; if length > MAX_ENVELOPE_BODY_SIZE { return Err(error::Error::Io(io::Error::new( io::ErrorKind::InvalidData, format!("envelope body length {length} exceeds maximum {MAX_ENVELOPE_BODY_SIZE}"), ))); } let mut body_bytes = vec![0; length]; cursor.read_exact(&mut body_bytes).await?; let full_body = if flags.contains(Flags::COMPRESSION) { compressor.decode(body_bytes)? } else { Compression::None.decode(body_bytes)? }; let body_len = full_body.len(); // Use cursor to get tracing id, warnings and actual body let mut body_cursor = Cursor::new(full_body.as_slice()); // The TRACING flag has different semantics in each direction: in a request // it is the client asking the server to enable tracing; in a response it // signals that the body starts with a 16-byte tracing UUID. Reading that // UUID for a request (as an earlier version of this code did) would silently // consume the first 16 bytes of the request body, so gate on direction too, // matching the canonical Envelope::from_buffer parser. let tracing_id = if flags.contains(Flags::TRACING) && direction == Direction::Response { let mut tracing_bytes = [0; UUID_LEN]; std::io::Read::read_exact(&mut body_cursor, &mut tracing_bytes)?; decode_timeuuid(&tracing_bytes).ok() } else { None }; let warnings = if flags.contains(Flags::WARNING) { from_cursor_string_list(&mut body_cursor)?
} else { vec![] }; let mut body = Vec::with_capacity(body_len - body_cursor.position() as usize); std::io::Read::read_to_end(&mut body_cursor, &mut body)?; let envelope = Envelope { version, direction, flags, opcode, stream_id, body, tracing_id, warnings, }; Ok(envelope) } pub async fn parse_envelope( cursor: &mut T, compressor: Compression, addr: SocketAddr, ) -> error::Result { let envelope = parse_raw_envelope(cursor, compressor).await?; convert_envelope_into_result(envelope, addr) } pub(crate) fn convert_envelope_into_result( envelope: Envelope, addr: SocketAddr, ) -> error::Result { match envelope.opcode { Opcode::Error => envelope.response_body().and_then(|err| match err { ResponseBody::Error(err) => Err(error::Error::Server { body: err, addr }), // ResponseBody::try_from is expected to always return Error for // Opcode::Error, but if a future protocol change ever broke that // invariant we don't want to crash the reader task with a panic // - surface it as a normal error instead. other => Err(error::Error::General(format!( "Expected ResponseBody::Error for Opcode::Error envelope, got {other:?}" ))), }), _ => Ok(envelope), } } #[cfg(test)] mod tests { use super::*; use cassandra_protocol::frame::Version; // Build a minimal valid envelope header byte sequence with the supplied // 4-byte length field. The header layout is: // 1 byte version | 1 byte flags | 2 bytes stream | 1 byte opcode | 4 bytes length fn header_with_length(length_bytes: [u8; 4]) -> Vec { let mut buf = vec![ u8::from(Version::V4), // version 0, // flags 0, 0, // stream id u8::from(Opcode::Ready), // opcode (any valid one will do) ]; buf.extend_from_slice(&length_bytes); buf } #[tokio::test] async fn parse_envelope_rejects_negative_body_length() { // a length field of 0xFFFFFFFF reads as -1 in i32. Without validation // this would be cast to usize::MAX-ish and immediately OOM the process. 
let mut payload = header_with_length([0xff, 0xff, 0xff, 0xff]); let mut cursor = std::io::Cursor::new(&mut payload); let mut bytes = vec![]; tokio::io::AsyncReadExt::read_to_end(&mut cursor, &mut bytes) .await .unwrap(); let mut reader = bytes.as_slice(); assert!(parse_raw_envelope(&mut reader, Compression::None) .await .is_err()); } #[tokio::test] async fn parse_envelope_rejects_oversized_body_length() { // i32::MAX (~2 GiB) is well past any plausible Cassandra envelope. // We expect a clean error rather than allocation of a giant buffer // up-front before even reading the body. let payload = header_with_length(i32::MAX.to_be_bytes()); let mut reader = payload.as_slice(); assert!(parse_raw_envelope(&mut reader, Compression::None) .await .is_err()); } // Per the Cassandra protocol spec, the TRACING flag in a REQUEST asks the // server to enable tracing and carries no payload metadata. The // tracing_id UUID only appears in RESPONSES. parse_raw_envelope used to // attempt to read 16 bytes of tracing UUID whenever the TRACING flag was // set regardless of direction, which is inconsistent with the canonical // Envelope::from_buffer parser and would corrupt the body on a request. #[tokio::test] async fn parse_envelope_does_not_read_tracing_id_for_request_direction() { // Build a request envelope (direction bit clear in byte 0) with // TRACING set and exactly 16 bytes of body. After parsing, the body // must come back intact - no bytes consumed as a tracing UUID. 
let body: Vec = (0..16u8).collect(); let mut wire = vec![ // version 4, direction = Request (0x80 bit clear) u8::from(Version::V4), // TRACING flag (0x02) Flags::TRACING.bits(), // stream id 0, 0, // opcode - any valid request opcode (Query) u8::from(Opcode::Query), ]; wire.extend_from_slice(&(body.len() as i32).to_be_bytes()); wire.extend_from_slice(&body); let mut reader = wire.as_slice(); let envelope = parse_raw_envelope(&mut reader, Compression::None) .await .expect("a request envelope with TRACING flag must still parse"); assert_eq!(envelope.direction, Direction::Request); assert!( envelope.tracing_id.is_none(), "request envelopes must not carry a tracing UUID" ); assert_eq!( envelope.body, body, "request body should be preserved verbatim, got {:?}", envelope.body ); } } ================================================ FILE: cdrs-tokio/src/frame_encoding.rs ================================================ use cassandra_protocol::compression::Compression; use cassandra_protocol::frame::frame_decoder::{ FrameDecoder, LegacyFrameDecoder, Lz4FrameDecoder, UncompressedFrameDecoder, }; use cassandra_protocol::frame::frame_encoder::{ FrameEncoder, LegacyFrameEncoder, Lz4FrameEncoder, UncompressedFrameEncoder, }; use cassandra_protocol::frame::Version; /// A factory for frame encoders and decoders. pub trait FrameEncodingFactory { /// Creates a new frame encoder based on the given protocol settings. fn create_encoder( &self, version: Version, compression: Compression, ) -> Box; /// Creates a new frame decoder based on the given protocol settings. fn create_decoder( &self, version: Version, compression: Compression, ) -> Box; } /// Frame encoding factory based on protocol settings.
#[derive(Copy, Clone, Debug, Default, Ord, PartialOrd, Eq, PartialEq, Hash)] pub struct ProtocolFrameEncodingFactory; impl FrameEncodingFactory for ProtocolFrameEncodingFactory { fn create_encoder( &self, version: Version, compression: Compression, ) -> Box { if version >= Version::V5 { match compression { Compression::Lz4 => Box::::default(), // >= v5 supports only lz4 => fall back to uncompressed _ => Box::::default(), } } else { Box::::default() } } fn create_decoder( &self, version: Version, compression: Compression, ) -> Box { if version >= Version::V5 { match compression { Compression::Lz4 => Box::::default(), // >= v5 supports only lz4 => fall back to uncompressed _ => Box::::default(), } } else { Box::::default() } } } ================================================ FILE: cdrs-tokio/src/future.rs ================================================ /// An owned dynamically typed `Future` for use in cases where you can't /// statically type your result or need to add some indirection. pub type BoxFuture<'a, T> = futures::future::BoxFuture<'a, T>; ================================================ FILE: cdrs-tokio/src/lib.rs ================================================ //! **cdrs** is a native Cassandra DB client written in Rust. //! //! ## Getting started //! //! This example configures a cluster consisting of a single node, and uses round-robin load balancing. //! //! ```no_run //! use cdrs_tokio::cluster::session::{TcpSessionBuilder, SessionBuilder}; //! use cdrs_tokio::cluster::NodeTcpConfigBuilder; //! use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy; //! use std::sync::Arc; //! //! #[tokio::main] //! async fn main() { //! let cluster_config = NodeTcpConfigBuilder::new() //! .with_contact_point("127.0.0.1:9042".into()) //! .build() //! .await //! .unwrap(); //! let session = TcpSessionBuilder::new(RoundRobinLoadBalancingStrategy::new(), cluster_config) //! .build() //! .await //! .unwrap(); //! //! 
let create_ks = "CREATE KEYSPACE IF NOT EXISTS test_ks WITH REPLICATION = { \ //! 'class' : 'SimpleStrategy', 'replication_factor' : 1 };"; //! session //! .query(create_ks) //! .await //! .expect("Keyspace create error"); //! } //! ``` //! //! ## Nodes and load balancing //! //! To maximize efficiency, the driver needs to be configured appropriately for a given use //! case. Please look at the available [load balancers](crate::load_balancing) and //! [node distance evaluators](crate::load_balancing::node_distance_evaluator) to pick the optimal //! solution when building the [`Session`](crate::cluster::session::Session). Topology-aware load //! balancing is preferred when dealing with multi-node clusters; otherwise, simpler strategies //! might prove more efficient. #[macro_use] mod macros; pub mod cluster; pub mod envelope_parser; pub mod load_balancing; pub mod frame_encoding; pub mod future; pub mod retry; pub mod speculative_execution; pub mod statement; pub mod transport; pub use cassandra_protocol::authenticators; pub use cassandra_protocol::compression; pub use cassandra_protocol::consistency; pub use cassandra_protocol::error; pub use cassandra_protocol::frame; pub use cassandra_protocol::query; pub use cassandra_protocol::types; pub type Error = error::Error; pub type Result = error::Result; #[cfg(feature = "derive")] pub use cdrs_tokio_helpers_derive::{DbMirror, IntoCdrsValue, TryFromRow, TryFromUdt}; ================================================ FILE: cdrs-tokio/src/load_balancing/initializing_wrapper.rs ================================================ use std::sync::Arc; use crate::cluster::topology::Node; use crate::cluster::{ClusterMetadata, ConnectionManager}; use crate::load_balancing::{LoadBalancingStrategy, QueryPlan, Request}; use crate::transport::CdrsTransport; // Wrapper strategy which returns contact points until cluster metadata gets populated.
pub struct InitializingWrapperLoadBalancingStrategy< T: CdrsTransport + 'static, CM: ConnectionManager + 'static, LB: LoadBalancingStrategy, > { inner: LB, contact_points_query_plan: QueryPlan, } impl, LB: LoadBalancingStrategy> LoadBalancingStrategy for InitializingWrapperLoadBalancingStrategy { fn query_plan( &self, request: Option, cluster: &ClusterMetadata, ) -> QueryPlan { if cluster.has_nodes() { self.inner.query_plan(request, cluster) } else { self.contact_points_query_plan.clone() } } } impl, LB: LoadBalancingStrategy> InitializingWrapperLoadBalancingStrategy { pub fn new(inner: LB, contact_points: Vec>>) -> Self { InitializingWrapperLoadBalancingStrategy { inner, contact_points_query_plan: QueryPlan::new(contact_points), } } } ================================================ FILE: cdrs-tokio/src/load_balancing/node_distance_evaluator.rs ================================================ #[cfg(test)] use mockall::*; use crate::cluster::topology::NodeDistance; use crate::cluster::NodeInfo; /// A node distance evaluator computes the distance of a given node relative to the driver. #[cfg_attr(test, automock)] pub trait NodeDistanceEvaluator { /// Tries to compute the distance to a given node. Can return `None` if the distance cannot be /// determined. In such a case, nodes without a distance are expected to be ignored by load /// balancers. fn compute_distance(&self, node: &NodeInfo) -> Option; } /// A simple evaluator which treats all nodes as local. #[derive(Default, Debug)] pub struct AllLocalNodeDistanceEvaluator; impl NodeDistanceEvaluator for AllLocalNodeDistanceEvaluator { fn compute_distance(&self, _node: &NodeInfo) -> Option { Some(NodeDistance::Local) } } /// An evaluator which is aware of node location relative to the local DC. Built-in /// [`TopologyAwareLoadBalancingStrategy`](crate::load_balancing::TopologyAwareLoadBalancingStrategy) /// can use this information to properly identify which nodes to use in query plans.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)] pub struct TopologyAwareNodeDistanceEvaluator { local_dc: String, } impl NodeDistanceEvaluator for TopologyAwareNodeDistanceEvaluator { fn compute_distance(&self, node: &NodeInfo) -> Option { Some(if node.datacenter == self.local_dc { NodeDistance::Local } else { NodeDistance::Remote }) } } impl TopologyAwareNodeDistanceEvaluator { /// Local DC name represents the datacenter local to where the driver is running. pub fn new(local_dc: String) -> Self { TopologyAwareNodeDistanceEvaluator { local_dc } } } //noinspection DuplicatedCode #[cfg(test)] mod tests { use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use uuid::Uuid; use crate::cluster::topology::NodeDistance; use crate::cluster::NodeInfo; use crate::load_balancing::node_distance_evaluator::NodeDistanceEvaluator; use crate::load_balancing::node_distance_evaluator::TopologyAwareNodeDistanceEvaluator; #[test] fn should_return_topology_aware_distance() { let local_dc = "test"; let evaluator = TopologyAwareNodeDistanceEvaluator::new(local_dc.into()); assert_eq!( evaluator .compute_distance(&NodeInfo::new( Uuid::new_v4(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), None, "".into(), Default::default(), "".into(), )) .unwrap(), NodeDistance::Remote ); assert_eq!( evaluator .compute_distance(&NodeInfo::new( Uuid::new_v4(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), None, local_dc.into(), Default::default(), "".into(), )) .unwrap(), NodeDistance::Local ); } } ================================================ FILE: cdrs-tokio/src/load_balancing/random.rs ================================================ use derivative::Derivative; use std::marker::PhantomData; use rand::prelude::*; use rand::rng; use crate::cluster::{ClusterMetadata, ConnectionManager}; use crate::load_balancing::{LoadBalancingStrategy, QueryPlan, Request}; use crate::transport::CdrsTransport; /// Pure random load balancing. 
#[derive(Default, Derivative)] #[derivative(Debug)] pub struct RandomLoadBalancingStrategy> { #[derivative(Debug = "ignore")] _transport: PhantomData, #[derivative(Debug = "ignore")] _connection_manager: PhantomData, } impl> RandomLoadBalancingStrategy { pub fn new() -> Self { RandomLoadBalancingStrategy { _transport: Default::default(), _connection_manager: Default::default(), } } } impl> LoadBalancingStrategy for RandomLoadBalancingStrategy { fn query_plan( &self, _request: Option, cluster: &ClusterMetadata, ) -> QueryPlan { let mut result = cluster.unignored_nodes(); result.shuffle(&mut rng()); QueryPlan::new(result) } } ================================================ FILE: cdrs-tokio/src/load_balancing/request.rs ================================================ use cassandra_protocol::consistency::Consistency; use derive_more::Constructor; use crate::cluster::Murmur3Token; /// A request executed by a `Session`. #[derive(Constructor, Clone, Debug)] pub struct Request<'a> { pub keyspace: Option<&'a str>, pub token: Option, pub routing_key: Option<&'a [u8]>, pub consistency: Option, } ================================================ FILE: cdrs-tokio/src/load_balancing/round_robin.rs ================================================ use derivative::Derivative; use std::marker::PhantomData; use std::sync::atomic::{AtomicUsize, Ordering}; use crate::cluster::{ClusterMetadata, ConnectionManager}; use crate::load_balancing::{LoadBalancingStrategy, QueryPlan, Request}; use crate::transport::CdrsTransport; /// Round-robin load balancing. 
#[derive(Derivative, Default)] #[derivative(Debug)] pub struct RoundRobinLoadBalancingStrategy> { prev_idx: AtomicUsize, #[derivative(Debug = "ignore")] _transport: PhantomData, #[derivative(Debug = "ignore")] _connection_manager: PhantomData, } impl> RoundRobinLoadBalancingStrategy { pub fn new() -> Self { RoundRobinLoadBalancingStrategy { prev_idx: AtomicUsize::new(0), _transport: Default::default(), _connection_manager: Default::default(), } } } impl> LoadBalancingStrategy for RoundRobinLoadBalancingStrategy { fn query_plan( &self, _request: Option, cluster: &ClusterMetadata, ) -> QueryPlan { let mut nodes = cluster.unignored_nodes(); if nodes.is_empty() { return QueryPlan::new(nodes); } let cur_idx = self.prev_idx.fetch_add(1, Ordering::SeqCst) % nodes.len(); nodes.rotate_left(cur_idx); QueryPlan::new(nodes) } } ================================================ FILE: cdrs-tokio/src/load_balancing/topology_aware.rs ================================================ use crate::cluster::topology::{KeyspaceMetadata, Node, NodeDistance, ReplicationStrategy}; use crate::cluster::Murmur3Token; use crate::cluster::{ClusterMetadata, ConnectionManager}; use crate::load_balancing::{LoadBalancingStrategy, QueryPlan, Request}; use crate::transport::CdrsTransport; use cassandra_protocol::consistency::Consistency; use derivative::Derivative; use fxhash::{FxHashMap, FxHashSet}; use itertools::Itertools; use rand::prelude::*; use rand::rng; use std::cmp::Ordering as CmpOrdering; use std::marker::PhantomData; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; /// Topology-aware load balancing strategy. Depends on up-to-date topology information, which is /// constantly monitored in the background by a control connection. For best results, a /// topology-aware [`NodeDistanceEvaluator`](crate::load_balancing::node_distance_evaluator::NodeDistanceEvaluator) (e.g. 
/// [`TopologyAwareNodeDistanceEvaluator`](crate::load_balancing::node_distance_evaluator::TopologyAwareNodeDistanceEvaluator)) /// should also be used. /// /// This implementation prioritizes replica nodes over non-replica ones; if more than one replica /// is available, the replicas will be shuffled. Non-replica nodes will be included in a round-robin /// fashion. If local nodes are present in the cluster, only those will be used, unless a remote /// failover DC is allowed. /// /// Note: if a referenced keyspace doesn't use `NetworkTopologyStrategy`, replica nodes will be /// chosen ignoring distance information. #[derive(Derivative)] #[derivative(Debug)] pub struct TopologyAwareLoadBalancingStrategy> { max_nodes_per_remote_dc: Option, allow_dc_failover_for_local_cl: bool, prev_idx: AtomicUsize, #[derivative(Debug = "ignore")] _transport: PhantomData, #[derivative(Debug = "ignore")] _connection_manager: PhantomData, } impl> LoadBalancingStrategy for TopologyAwareLoadBalancingStrategy { fn query_plan( &self, request: Option, cluster: &ClusterMetadata, ) -> QueryPlan { if let Some(request) = request { self.replicas_for_request(request, cluster) } else { QueryPlan::new(self.round_robin_unignored_local_nodes(cluster)) } } } impl> TopologyAwareLoadBalancingStrategy { /// Creates a new strategy. `max_nodes_per_remote_dc` determines whether (and how many per /// remote DC) remote non-replica nodes should be considered in addition to replica ones when /// `NetworkTopologyStrategy` is used for the queried keyspace. `allow_dc_failover_for_local_cl` /// determines whether remote non-replicas should also be used for local-only consistency levels.
pub fn new( max_nodes_per_remote_dc: Option, allow_dc_failover_for_local_cl: bool, ) -> Self { TopologyAwareLoadBalancingStrategy { max_nodes_per_remote_dc, allow_dc_failover_for_local_cl, prev_idx: AtomicUsize::new(0), _transport: Default::default(), _connection_manager: Default::default(), } } fn replicas_for_request( &self, request: Request, cluster: &ClusterMetadata, ) -> QueryPlan { let token = request .token .or_else(|| request.routing_key.map(Murmur3Token::generate)); if let Some(token) = token { self.replicas_for_token(token, request.keyspace, request.consistency, cluster) } else { QueryPlan::new(self.round_robin_unignored_local_nodes(cluster)) } } fn replicas_for_token( &self, token: Murmur3Token, keyspace: Option<&str>, consistency: Option, cluster: &ClusterMetadata, ) -> QueryPlan { keyspace .and_then(|keyspace| cluster.keyspace(keyspace)) .map(|keyspace| self.replicas_for_keyspace(token, keyspace, consistency, cluster)) .unwrap_or_else(|| QueryPlan::new(self.round_robin_unignored_local_nodes(cluster))) } fn replicas_for_keyspace( &self, token: Murmur3Token, keyspace: &KeyspaceMetadata, consistency: Option, cluster: &ClusterMetadata, ) -> QueryPlan { match &keyspace.replication_strategy { ReplicationStrategy::SimpleStrategy { replication_factor } => { self.simple_strategy_replicas(token, *replication_factor, cluster) } ReplicationStrategy::NetworkTopologyStrategy { datacenter_replication_factor, } => self.network_topology_strategy_replicas( token, datacenter_replication_factor.clone(), consistency, cluster, ), ReplicationStrategy::Other => self.simple_strategy_replicas(token, 1, cluster), } } fn network_topology_strategy_replicas( &self, token: Murmur3Token, mut datacenter_replication_factor: FxHashMap, consistency: Option, cluster: &ClusterMetadata, ) -> QueryPlan { // similar to Datastax BasicLoadBalancingPolicy: // 1. fetch all replicas // 2. extract unignored local // 3. shuffle replicas // 4. append round-robin unignored local non-replicas // 5. 
optionally, add shuffled remote unignored non-replicas // Filter ignored nodes BEFORE the per-DC/per-rack counting loop. The // previous code stripped them only at the end via `retain`, so an // ignored node could legitimately consume one of the per-DC quota // slots and then disappear, leaving us with fewer real replicas than // the configured replication factor. let replicas = cluster .token_map() .nodes_for_token(token) .filter(|node| !node.is_ignored()) .collect_vec(); let desired_replica_count = datacenter_replication_factor.values().sum(); let mut same_rack_replicas: FxHashMap = datacenter_replication_factor .iter() .map(|(dc, replication_factor)| { let rack_count = cluster.datacenter(dc).map(|dc| dc.rack_count).unwrap_or(0); (dc.into(), replication_factor.saturating_sub(rack_count)) }) .collect(); let mut result = Vec::with_capacity(desired_replica_count); let mut used_dc_racks: FxHashSet<(&str, &str)> = Default::default(); for replica in &replicas { if let Some(datacenter_replication_factor) = datacenter_replication_factor.get_mut(replica.datacenter()) { if *datacenter_replication_factor == 0 { // found enough nodes in this datacenter continue; } let current_node_dc = replica.datacenter(); let current_node_rack = replica.rack(); if used_dc_racks.contains(&(current_node_dc, current_node_rack)) { // check if we need to put nodes from the same rack multiple times to meet // the replication factor if let Some(same_rack_replicas) = same_rack_replicas.get_mut(current_node_dc) { if *same_rack_replicas > 0 { *same_rack_replicas -= 1; *datacenter_replication_factor -= 1; result.push(replica.clone()); } } } else { *datacenter_replication_factor -= 1; used_dc_racks.insert((current_node_dc, current_node_rack)); result.push(replica.clone()); } if result.len() == desired_replica_count { break; } } } // the result now contains only unignored mixed local/remote nodes - // put local in front result.sort_unstable_by(|a, b| { let a_distance = a.distance(); let b_distance = 
b.distance(); if a_distance == b_distance { return CmpOrdering::Equal; } if a_distance == Some(NodeDistance::Local) { return CmpOrdering::Less; } if b_distance == Some(NodeDistance::Local) { return CmpOrdering::Greater; } CmpOrdering::Equal }); let mut rng = rng(); // find how many local nodes we have - if no replica is remote, all of them are local let local_count = result.iter().position(|node| node.is_remote()).unwrap_or(result.len()); if local_count > 0 { result[..local_count].shuffle(&mut rng); } // add unignored non-replicas let unignored_nodes = self.round_robin_unignored_local_nodes(cluster); let replicas = result.into_iter().chain(unignored_nodes); // now the result contains (in order): local replicas, remote replicas, local non-replicas if let Some(max_nodes_per_remote_dc) = self.max_nodes_per_remote_dc { if let Some(consistency) = consistency { if !self.allow_dc_failover_for_local_cl && consistency.is_dc_local() { return QueryPlan::new(replicas.collect()); } } let mut remote_nodes = cluster.unignored_remote_nodes_capped(max_nodes_per_remote_dc); remote_nodes.shuffle(&mut rng); QueryPlan::new( replicas .chain(remote_nodes) .unique_by(|node| node.broadcast_rpc_address()) .collect(), ) } else { QueryPlan::new( replicas .unique_by(|node| node.broadcast_rpc_address()) .collect(), ) } } fn simple_strategy_replicas( &self, token: Murmur3Token, replica_count: usize, cluster: &ClusterMetadata, ) -> QueryPlan { // Walk the ring filtering ignored nodes BEFORE taking replica_count. // Filtering after a cap (the previous behaviour) silently shrinks the // replica set whenever the natural ring positions for `token` happen // to land on ignored nodes, leaving the query plan to fall back to // unrelated round-robin nodes for the leading slots.
let mut replicas = cluster .token_map() .nodes_for_token(token) .filter(|node| !node.is_ignored()) .take(replica_count) .collect_vec(); replicas.shuffle(&mut rng()); let unignored_nodes = self.round_robin_unignored_nodes(cluster); QueryPlan::new( replicas .into_iter() .chain(unignored_nodes) .unique_by(|node| node.broadcast_rpc_address()) .collect(), ) } fn round_robin_unignored_nodes( &self, cluster: &ClusterMetadata, ) -> Vec>> { let mut nodes = cluster.unignored_nodes(); if nodes.is_empty() { return nodes; } let cur_idx = self.prev_idx.fetch_add(1, Ordering::SeqCst) % nodes.len(); nodes.rotate_left(cur_idx); nodes } fn round_robin_unignored_local_nodes( &self, cluster: &ClusterMetadata, ) -> Vec>> { let mut nodes = cluster.unignored_local_nodes(); if nodes.is_empty() { return nodes; } let cur_idx = self.prev_idx.fetch_add(1, Ordering::SeqCst) % nodes.len(); nodes.rotate_left(cur_idx); nodes } } //noinspection DuplicatedCode #[cfg(test)] mod tests { use cassandra_protocol::frame::Version; use fxhash::FxHashMap; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::{Arc, LazyLock}; use tokio::sync::watch; use uuid::Uuid; use crate::cluster::connection_manager::MockConnectionManager; use crate::cluster::connection_pool::ConnectionPoolFactory; use crate::cluster::topology::{ KeyspaceMetadata, Node, NodeDistance, NodeState, ReplicationStrategy, }; use crate::cluster::ClusterMetadata; use crate::cluster::Murmur3Token; use crate::load_balancing::{ LoadBalancingStrategy, Request, TopologyAwareLoadBalancingStrategy, }; use crate::retry::MockReconnectionPolicy; use crate::transport::MockCdrsTransport; static HOST_ID_1: LazyLock = LazyLock::new(Uuid::new_v4); static HOST_ID_2: LazyLock = LazyLock::new(Uuid::new_v4); static HOST_ID_3: LazyLock = LazyLock::new(Uuid::new_v4); static HOST_ID_4: LazyLock = LazyLock::new(Uuid::new_v4); static HOST_ID_5: LazyLock = LazyLock::new(Uuid::new_v4); fn create_cluster( ) -> ClusterMetadata> { let (_, keyspace_receiver) = 
watch::channel(None); let connection_manager = MockConnectionManager::::new(); let reconnection_policy = MockReconnectionPolicy::new(); let connection_pool_factory = Arc::new(ConnectionPoolFactory::new( Default::default(), Version::V4, connection_manager, keyspace_receiver, Arc::new(reconnection_policy), )); let mut nodes = FxHashMap::default(); nodes.insert( *HOST_ID_1, Arc::new(Node::new_with_state( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 1), None, Some(*HOST_ID_1), Some(NodeDistance::Local), NodeState::Up, vec![Murmur3Token::new(1), Murmur3Token::new(2)], "r1".into(), "dc1".into(), )), ); nodes.insert( *HOST_ID_2, Arc::new(Node::new_with_state( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 2), None, Some(*HOST_ID_2), Some(NodeDistance::Local), NodeState::Up, vec![Murmur3Token::new(3), Murmur3Token::new(4)], "r1".into(), "dc1".into(), )), ); nodes.insert( *HOST_ID_3, Arc::new(Node::new_with_state( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 3), None, Some(*HOST_ID_3), Some(NodeDistance::Local), NodeState::Up, vec![Murmur3Token::new(7)], "r2".into(), "dc1".into(), )), ); nodes.insert( Uuid::new_v4(), Arc::new(Node::new_with_state( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 4), None, None, None, NodeState::Up, vec![Murmur3Token::new(8)], "r2".into(), "dc1".into(), )), ); nodes.insert( *HOST_ID_4, Arc::new(Node::new_with_state( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 5), None, Some(*HOST_ID_4), Some(NodeDistance::Remote), NodeState::Up, vec![Murmur3Token::new(5), Murmur3Token::new(6)], "r1".into(), "dc2".into(), )), ); nodes.insert( Uuid::new_v4(), Arc::new(Node::new_with_state( connection_pool_factory.clone(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 6), None, None, None, NodeState::Up, vec![Murmur3Token::new(9)], 
"r1".into(), "dc2".into(), )), ); nodes.insert( *HOST_ID_5, Arc::new(Node::new_with_state( connection_pool_factory, SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 7), None, Some(*HOST_ID_5), Some(NodeDistance::Remote), NodeState::Up, vec![Murmur3Token::new(0)], "r2".into(), "dc2".into(), )), ); let mut datacenter_replication_factor_2 = FxHashMap::default(); datacenter_replication_factor_2.insert("dc1".into(), 3); datacenter_replication_factor_2.insert("dc2".into(), 1); let mut datacenter_replication_factor_4 = FxHashMap::default(); datacenter_replication_factor_4.insert("dc1".into(), 2); datacenter_replication_factor_4.insert("dc2".into(), 1); let mut keyspaces = FxHashMap::default(); keyspaces.insert( "k1".into(), KeyspaceMetadata::new(ReplicationStrategy::SimpleStrategy { replication_factor: 2, }), ); keyspaces.insert( "k2".into(), KeyspaceMetadata::new(ReplicationStrategy::NetworkTopologyStrategy { datacenter_replication_factor: datacenter_replication_factor_2, }), ); keyspaces.insert( "k3".into(), KeyspaceMetadata::new(ReplicationStrategy::Other), ); keyspaces.insert( "k4".into(), KeyspaceMetadata::new(ReplicationStrategy::NetworkTopologyStrategy { datacenter_replication_factor: datacenter_replication_factor_4, }), ); ClusterMetadata::new(nodes, keyspaces) } #[test] fn should_return_nodes_without_request() { let cluster = create_cluster(); let lb = TopologyAwareLoadBalancingStrategy::new(None, false); let query_plan = lb.query_plan(None, &cluster); assert_eq!(query_plan.nodes.len(), 3); for node in &query_plan.nodes { assert!(node.is_local()); } } #[test] fn should_return_local_nodes_without_request() { let cluster = create_cluster(); let lb = TopologyAwareLoadBalancingStrategy::new(None, false); let query_plan = lb.query_plan(None, &cluster); assert_eq!(query_plan.nodes.len(), 3); for node in &query_plan.nodes { assert!(node.is_local()); } } #[test] fn should_return_local_nodes_without_token() { let cluster = create_cluster(); let lb = 
TopologyAwareLoadBalancingStrategy::new(None, false); let query_plan = lb.query_plan(Some(Request::new(None, None, None, None)), &cluster); assert_eq!(query_plan.nodes.len(), 3); for node in &query_plan.nodes { assert!(node.is_local()); } } #[test] fn should_return_local_nodes_without_keyspace() { let cluster = create_cluster(); let lb = TopologyAwareLoadBalancingStrategy::new(None, false); let query_plan = lb.query_plan( Some(Request::new(None, Some(Murmur3Token::new(4)), None, None)), &cluster, ); assert_eq!(query_plan.nodes.len(), 3); for node in &query_plan.nodes { assert!(node.is_local()); } } #[test] fn should_return_all_nodes_with_unknown_strategy() { let cluster = create_cluster(); let lb = TopologyAwareLoadBalancingStrategy::new(None, false); let query_plan = lb.query_plan( Some(Request::new( Some("k3"), Some(Murmur3Token::new(4)), None, None, )), &cluster, ); assert_eq!(query_plan.nodes.len(), 5); // 1 replica + 4 unignored assert_eq!(query_plan.nodes[0].host_id().unwrap(), *HOST_ID_2); } #[test] fn should_return_replica_nodes_with_simple_strategy() { let cluster = create_cluster(); let lb = TopologyAwareLoadBalancingStrategy::new(None, false); let query_plan = lb.query_plan( Some(Request::new( Some("k1"), Some(Murmur3Token::new(4)), None, None, )), &cluster, ); assert_eq!(query_plan.nodes.len(), 5); // 2 replicas + 3 unignored assert!( query_plan.nodes[0].host_id().unwrap() == *HOST_ID_2 || query_plan.nodes[0].host_id().unwrap() == *HOST_ID_4 ); assert!( query_plan.nodes[1].host_id().unwrap() == *HOST_ID_2 || query_plan.nodes[1].host_id().unwrap() == *HOST_ID_4 ); assert!(query_plan.nodes.iter().all(|node| !node.is_ignored())); } // Same bug shape that network_topology_strategy_replicas had, but in // simple_strategy_replicas: when the natural ring walk starts on ignored // nodes, capping the iteration at replica_count BEFORE the unignored // filter strips the ignored ones away leaves us with fewer real replicas // than the configured replication factor. 
The fix must take replicas // among unignored nodes, so the query plan still leads with real // candidates. // // The test cluster has two ignored nodes at tokens 8 and 9. Walking from // token 8 with RF=2 (keyspace k1) hits those ignored nodes first; the // pre-fix code would drop them and the query plan would start with // round-robin unignored fallbacks instead of the actual ring replicas // (HOST_ID_5 at token 0 and HOST_ID_1 at token 1). #[test] fn simple_strategy_uses_unignored_nodes_when_ring_starts_on_ignored() { let cluster = create_cluster(); let lb = TopologyAwareLoadBalancingStrategy::new(None, false); let query_plan = lb.query_plan( Some(Request::new( Some("k1"), // SimpleStrategy with replication_factor: 2 Some(Murmur3Token::new(8)), None, None, )), &cluster, ); // Every node returned must be unignored assert!(query_plan.nodes.iter().all(|node| !node.is_ignored())); // The first two slots are the natural replicas - HOST_ID_5 (the // unignored node at the next ring position past 8) and HOST_ID_1 // (the unignored node after that). Neither should be displaced by // a round-robin fallback just because two ignored nodes happened to // sit in the way. 
let leading_two: Vec<_> = query_plan .nodes .iter() .take(2) .map(|node| node.host_id().unwrap()) .collect(); assert!( leading_two.contains(&*HOST_ID_5), "expected HOST_ID_5 in the leading replicas, got {:?}", leading_two ); assert!( leading_two.contains(&*HOST_ID_1), "expected HOST_ID_1 in the leading replicas, got {:?}", leading_two ); } #[test] fn should_return_topology_aware_nodes_with_network_topology_strategy_with_repeated_racks() { let cluster = create_cluster(); let lb = TopologyAwareLoadBalancingStrategy::new(None, false); let query_plan = lb.query_plan( Some(Request::new( Some("k2"), Some(Murmur3Token::new(2)), None, None, )), &cluster, ); assert_eq!(query_plan.nodes.len(), 4); assert!( query_plan.nodes[0].host_id().unwrap() == *HOST_ID_1 || query_plan.nodes[0].host_id().unwrap() == *HOST_ID_2 || query_plan.nodes[0].host_id().unwrap() == *HOST_ID_3 ); assert!( query_plan.nodes[1].host_id().unwrap() == *HOST_ID_1 || query_plan.nodes[1].host_id().unwrap() == *HOST_ID_2 || query_plan.nodes[1].host_id().unwrap() == *HOST_ID_3 ); assert!( query_plan.nodes[2].host_id().unwrap() == *HOST_ID_1 || query_plan.nodes[2].host_id().unwrap() == *HOST_ID_2 || query_plan.nodes[2].host_id().unwrap() == *HOST_ID_3 ); assert_eq!(query_plan.nodes[3].host_id().unwrap(), *HOST_ID_4); } #[test] fn should_return_topology_aware_nodes_with_network_topology_strategy() { let cluster = create_cluster(); let lb = TopologyAwareLoadBalancingStrategy::new(None, false); let query_plan = lb.query_plan( Some(Request::new( Some("k4"), Some(Murmur3Token::new(2)), None, None, )), &cluster, ); assert_eq!(query_plan.nodes.len(), 4); assert!( query_plan.nodes[0].host_id().unwrap() == *HOST_ID_1 || query_plan.nodes[0].host_id().unwrap() == *HOST_ID_3 ); assert!( query_plan.nodes[1].host_id().unwrap() == *HOST_ID_1 || query_plan.nodes[1].host_id().unwrap() == *HOST_ID_3 ); assert_eq!(query_plan.nodes[2].host_id().unwrap(), *HOST_ID_4); assert_eq!(query_plan.nodes[3].host_id().unwrap(), *HOST_ID_2); 
} #[test] fn should_return_topology_aware_nodes_with_network_topology_strategy_with_remote() { let cluster = create_cluster(); let lb = TopologyAwareLoadBalancingStrategy::new(Some(5), false); let query_plan = lb.query_plan( Some(Request::new( Some("k4"), Some(Murmur3Token::new(2)), None, None, )), &cluster, ); assert_eq!(query_plan.nodes.len(), 5); assert!( query_plan.nodes[0].host_id().unwrap() == *HOST_ID_1 || query_plan.nodes[0].host_id().unwrap() == *HOST_ID_3 ); assert!( query_plan.nodes[1].host_id().unwrap() == *HOST_ID_1 || query_plan.nodes[1].host_id().unwrap() == *HOST_ID_3 ); assert_eq!(query_plan.nodes[2].host_id().unwrap(), *HOST_ID_4); assert_eq!(query_plan.nodes[3].host_id().unwrap(), *HOST_ID_2); assert_eq!(query_plan.nodes[4].host_id().unwrap(), *HOST_ID_5); } } ================================================ FILE: cdrs-tokio/src/load_balancing.rs ================================================ mod initializing_wrapper; pub mod node_distance_evaluator; mod random; mod request; mod round_robin; mod topology_aware; pub(crate) use self::initializing_wrapper::InitializingWrapperLoadBalancingStrategy; pub use self::random::RandomLoadBalancingStrategy; pub use self::request::Request; pub use self::round_robin::RoundRobinLoadBalancingStrategy; pub use self::topology_aware::TopologyAwareLoadBalancingStrategy; use crate::cluster::topology::Node; use crate::cluster::{ClusterMetadata, ConnectionManager}; use crate::transport::CdrsTransport; use derive_more::Constructor; use std::sync::Arc; #[derive(Debug, Constructor, Default)] pub struct QueryPlan + 'static> { pub nodes: Vec>>, } impl> Clone for QueryPlan { fn clone(&self) -> Self { QueryPlan::new(self.nodes.clone()) } } /// Load balancing strategy, usually used for managing target node connections. pub trait LoadBalancingStrategy> { /// Returns the query plan for the given request. If no request is given, returns a generic plan for /// establishing connection(s) to node(s).
fn query_plan( &self, request: Option, cluster: &ClusterMetadata, ) -> QueryPlan; } ================================================ FILE: cdrs-tokio/src/macros.rs ================================================ #[macro_export] /// Transforms arguments to values consumed by queries. macro_rules! query_values { ($($value:expr),*) => { { use cdrs_tokio::types::value::Value; use cdrs_tokio::query::QueryValues; let mut values: Vec = Vec::new(); $( values.push($value.into()); )* QueryValues::SimpleValues(values) } }; ($($name:expr => $value:expr),*) => { { use cdrs_tokio::types::value::Value; use cdrs_tokio::query::QueryValues; use std::collections::HashMap; let mut values: HashMap = HashMap::new(); $( values.insert($name.to_string(), $value.into()); )* QueryValues::NamedValues(values) } }; } ================================================ FILE: cdrs-tokio/src/retry/reconnection_policy.rs ================================================ use derive_more::Constructor; #[cfg(test)] use mockall::automock; use rand::{rng, RngExt}; use std::time::Duration; const DEFAULT_BASE_DELAY: Duration = Duration::from_secs(1); const DEFAULT_MAX_DELAY: Duration = Duration::from_secs(60); /// Determines the time for the next reconnection attempt when trying to reconnect to a node. pub trait ReconnectionSchedule { /// Returns the next reconnection delay, or `None` if no attempt should be made. fn next_delay(&mut self) -> Option; } /// Creates reconnection schedules when trying to re-establish connections. #[cfg_attr(test, automock)] pub trait ReconnectionPolicy { /// Creates a new schedule when a connection needs to be re-established. fn new_node_schedule(&self) -> Box; } /// Schedules reconnections at a constant interval.
#[derive(Copy, Clone, Constructor, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)] pub struct ConstantReconnectionPolicy { base_delay: Duration, } impl Default for ConstantReconnectionPolicy { fn default() -> Self { ConstantReconnectionPolicy::new(DEFAULT_BASE_DELAY) } } impl ReconnectionPolicy for ConstantReconnectionPolicy { fn new_node_schedule(&self) -> Box { Box::new(ConstantReconnectionSchedule::new(self.base_delay)) } } #[derive(Constructor)] struct ConstantReconnectionSchedule { base_delay: Duration, } impl ReconnectionSchedule for ConstantReconnectionSchedule { fn next_delay(&mut self) -> Option { Some(self.base_delay) } } /// Never schedules reconnections. #[derive(Default, Copy, Clone, Debug, PartialEq, Ord, PartialOrd, Eq, Hash)] pub struct NeverReconnectionPolicy; impl ReconnectionPolicy for NeverReconnectionPolicy { fn new_node_schedule(&self) -> Box { Box::new(NeverReconnectionSchedule) } } struct NeverReconnectionSchedule; impl ReconnectionSchedule for NeverReconnectionSchedule { fn next_delay(&mut self) -> Option { None } } /// A reconnection policy that waits exponentially longer between each reconnection attempt /// (but keeps a constant delay once a maximum delay is reached). The delay increases /// exponentially, with roughly +/-15% jitter added. /// /// Note: by design this policy retries forever. The delay grows by a factor of two on each /// attempt up to `max_delay`, then stays at `max_delay` indefinitely. The `max_attempts` /// field only bounds the doubling: once that many attempts have been made, every further call /// returns exactly `max_delay` (the computed delay is already capped at `max_delay` before /// then). It is NOT a hard cap on the number of reconnection attempts. /// Use [`NeverReconnectionPolicy`] if you want to avoid reconnections entirely. #[derive(Copy, Clone, Constructor, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)] pub struct ExponentialReconnectionPolicy { base_delay: Duration, max_delay: Duration, /// Attempts after which the delay has saturated at [`Self::max_delay`].
Subsequent /// reconnection attempts continue to be scheduled at `max_delay` indefinitely. max_attempts: usize, } impl ReconnectionPolicy for ExponentialReconnectionPolicy { fn new_node_schedule(&self) -> Box { Box::new(ExponentialReconnectionSchedule::new( self.base_delay, self.max_delay, self.max_attempts, )) } } impl Default for ExponentialReconnectionPolicy { fn default() -> Self { let base_delay = DEFAULT_BASE_DELAY.as_millis() as i64; let ceil = u32::from((base_delay & (base_delay - 1)) != 0); ExponentialReconnectionPolicy::new( DEFAULT_BASE_DELAY, DEFAULT_MAX_DELAY, (64 - (i64::MAX / base_delay).leading_zeros() - ceil) as usize, ) } } struct ExponentialReconnectionSchedule { base_delay: Duration, max_delay: Duration, max_attempts: usize, attempt: usize, } impl ReconnectionSchedule for ExponentialReconnectionSchedule { fn next_delay(&mut self) -> Option { // Once we've reached the saturation point, keep returning the max // delay forever. Use `>=` rather than `==` so that nothing - including // a future caller mutating fields directly - can accidentally skip // this branch and re-enter the doubling logic on a saturated counter. if self.attempt >= self.max_attempts { return Some(self.max_delay); } self.attempt += 1; let delay = self .base_delay .saturating_mul(1u32.checked_shl(self.attempt as u32).unwrap_or(u32::MAX)) .min(self.max_delay); // Apply +/-15% jitter so a flock of clients reconnecting at the same // time don't all retry in lockstep. 
let jitter = rng().random_range(85..116); Some( (delay / 100) .saturating_mul(jitter) .clamp(self.base_delay, self.max_delay), ) } } impl ExponentialReconnectionSchedule { pub fn new(base_delay: Duration, max_delay: Duration, max_attempts: usize) -> Self { ExponentialReconnectionSchedule { base_delay, max_delay, max_attempts, attempt: 0, } } } #[cfg(test)] mod tests { use crate::retry::reconnection_policy::ExponentialReconnectionSchedule; use crate::retry::ReconnectionSchedule; #[test] fn should_reach_max_exponential_delay_without_panic() { let mut schedule = ExponentialReconnectionSchedule { base_delay: Default::default(), max_delay: Default::default(), max_attempts: usize::MAX, attempt: usize::MAX - 1, }; schedule.next_delay(); } // Once the schedule's attempt counter has reached or surpassed // max_attempts, every subsequent call must return Some(max_delay) - // never None and never something larger than max_delay. This is the // documented saturation behaviour of [`ExponentialReconnectionPolicy`]. #[test] fn saturated_schedule_keeps_returning_max_delay() { use std::time::Duration; let max_delay = Duration::from_secs(60); let mut schedule = ExponentialReconnectionSchedule { base_delay: Duration::from_secs(1), max_delay, // arrange the attempt counter slightly past max_attempts so the // saturation branch is the one we exercise. Without `>=` (only // `==`) the schedule would skip this branch and try to double // the delay again - which we explicitly do not want once we're // already at the cap. 
max_attempts: 5, attempt: 6, }; for _ in 0..3 { assert_eq!(schedule.next_delay(), Some(max_delay)); } } } ================================================ FILE: cdrs-tokio/src/retry/retry_policy.rs ================================================ use derive_more::Display; use cassandra_protocol::error::Error; use cassandra_protocol::frame::message_error::{ ErrorBody, ErrorType, ReadTimeoutError, WriteTimeoutError, WriteType, }; #[derive(Debug, PartialEq, Eq, Ord, PartialOrd, Hash, Copy, Clone, Display)] pub enum RetryDecision { RetrySameNode, RetryNextNode, DontRetry, } /// Information about a failed query. pub struct QueryInfo<'a> { pub error: &'a Error, pub is_idempotent: bool, } /// Query-specific information about current state of retrying. pub trait RetrySession { /// Decide what to do with the failing query. fn decide(&mut self, query_info: QueryInfo) -> RetryDecision; } /// Retry policy determines what to do in case of communication error. pub trait RetryPolicy { /// Called for each new query, starts a session of deciding about retries. fn new_session(&self) -> Box; } /// Forwards all errors directly to the user, never retries #[derive(Default, Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Hash)] pub struct FallthroughRetryPolicy; impl RetryPolicy for FallthroughRetryPolicy { fn new_session(&self) -> Box { Box::::default() } } #[derive(Default, Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Hash)] pub struct FallthroughRetrySession; impl RetrySession for FallthroughRetrySession { fn decide(&mut self, _query_info: QueryInfo) -> RetryDecision { RetryDecision::DontRetry } } /// Default retry policy - retries when there is a high chance that a retry might help. 
/// Behaviour based on [DataStax Java Driver](https://docs.datastax.com/en/developer/java-driver/4.10/manual/core/retries/) #[derive(Default, Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Hash)] pub struct DefaultRetryPolicy; impl RetryPolicy for DefaultRetryPolicy { fn new_session(&self) -> Box { Box::::default() } } #[derive(Default, Debug, Clone, Copy)] pub struct DefaultRetrySession { was_unavailable_retry: bool, was_read_timeout_retry: bool, was_write_timeout_retry: bool, } impl RetrySession for DefaultRetrySession { fn decide(&mut self, query_info: QueryInfo) -> RetryDecision { match query_info.error { Error::Io(_) | Error::General(_) | Error::Server { body: ErrorBody { ty: ErrorType::Overloaded, .. }, .. } | Error::Server { body: ErrorBody { ty: ErrorType::Server, .. }, .. } | Error::Server { body: ErrorBody { ty: ErrorType::Truncate, .. }, .. } => { if query_info.is_idempotent { RetryDecision::RetryNextNode } else { RetryDecision::DontRetry } } Error::Server { body: ErrorBody { ty: ErrorType::Unavailable(_), .. }, .. } => { if !self.was_unavailable_retry { self.was_unavailable_retry = true; RetryDecision::RetryNextNode } else { RetryDecision::DontRetry } } Error::Server { body: ErrorBody { ty: ErrorType::ReadTimeout(error @ ReadTimeoutError { .. }), .. }, .. } => { if !self.was_read_timeout_retry && error.received >= error.block_for && error.replica_has_responded() { self.was_read_timeout_retry = true; RetryDecision::RetrySameNode } else { RetryDecision::DontRetry } } Error::Server { body: ErrorBody { ty: ErrorType::WriteTimeout(error @ WriteTimeoutError { .. }), .. }, .. } => { if !self.was_write_timeout_retry && query_info.is_idempotent && error.write_type == WriteType::BatchLog { self.was_write_timeout_retry = true; RetryDecision::RetrySameNode } else { RetryDecision::DontRetry } } Error::Server { body: ErrorBody { ty: ErrorType::IsBootstrapping, .. }, .. 
} => RetryDecision::RetryNextNode, _ => RetryDecision::DontRetry, } } } ================================================ FILE: cdrs-tokio/src/retry.rs ================================================ mod reconnection_policy; mod retry_policy; pub use reconnection_policy::*; pub use retry_policy::*; ================================================ FILE: cdrs-tokio/src/speculative_execution.rs ================================================ //! Pre-emptively query another node if the current one takes too long to respond. //! //! Sometimes a Cassandra node might be experiencing difficulties (ex: long GC pause) and take //! longer than usual to reply. Queries sent to that node will experience bad latency. //! //! One thing we can do to improve that is pre-emptively start a second execution of the query //! against another node, before the first node has replied or errored out. If that second node //! replies faster, we can send the response back to the client. We also cancel the first execution. //! //! Turning on speculative executions doesn't change the driver's retry behavior. Each parallel //! execution will trigger retries independently. use derive_more::Constructor; use std::time::Duration; /// Current speculative execution context. #[derive(Constructor, Debug)] pub struct Context { pub running_executions: usize, } /// The policy that decides if the driver will send speculative queries to the next nodes when the /// current node takes too long to respond. If a query is not idempotent, the driver will never /// schedule speculative executions for it, because there is no way to guarantee that only one node /// will apply the mutation. pub trait SpeculativeExecutionPolicy { /// Returns the time until a speculative request is sent to the next node. `None` means there /// should not be another execution. 
fn execution_interval(&self, context: &Context) -> Option; } /// A policy that schedules a configurable number of speculative executions, separated by a fixed /// delay. #[derive(Debug, Clone, Copy, Constructor)] pub struct ConstantSpeculativeExecutionPolicy { max_executions: usize, delay: Duration, } impl SpeculativeExecutionPolicy for ConstantSpeculativeExecutionPolicy { fn execution_interval(&self, context: &Context) -> Option { if context.running_executions < self.max_executions { Some(self.delay) } else { None } } } ================================================ FILE: cdrs-tokio/src/statement/statement_params.rs ================================================ use cassandra_protocol::query::QueryParams; use cassandra_protocol::types::value::Value; use derivative::Derivative; use std::sync::Arc; use crate::cluster::Murmur3Token; use crate::retry::RetryPolicy; use crate::speculative_execution::SpeculativeExecutionPolicy; /// Parameters of a statement used for query operations. #[derive(Default, Clone, Derivative)] #[derivative(Debug)] pub struct StatementParams { /// Protocol-level parameters. pub query_params: QueryParams, /// Whether the query is idempotent. pub is_idempotent: bool, /// Query keyspace. If not using a global one, setting it explicitly might help the load /// balancer use more appropriate nodes. Note: prepared statements with keyspace information /// take precedence over this field. pub keyspace: Option, /// The token to use for token-aware routing. A load balancer may use this information to /// determine which nodes to contact. Takes precedence over `routing_key`. pub token: Option, /// The partition key to use for token-aware routing. A load balancer may use this information /// to determine which nodes to contact. Alternative to `token`. Note: prepared statements /// with bound primary key values take precedence over this field. pub routing_key: Option>, /// Whether tracing should be enabled. 
pub tracing: bool, /// Whether warnings should be enabled.
pub warnings: bool, /// Custom statement speculative execution policy. #[derivative(Debug = "ignore")] pub speculative_execution_policy: Option>, /// Custom statement retry policy. #[derivative(Debug = "ignore")] pub retry_policy: Option>, /// Enables beta protocol features. The server responds with an ERROR if the protocol version is /// marked as beta on the server and the client does not set this flag. pub beta_protocol: bool, } ================================================ FILE: cdrs-tokio/src/statement/statement_params_builder.rs ================================================ use cassandra_protocol::consistency::Consistency; use cassandra_protocol::query::{QueryFlags, QueryParams, QueryValues}; use cassandra_protocol::types::value::Value; use cassandra_protocol::types::{CBytes, CInt, CLong}; use derivative::Derivative; use std::sync::Arc; use crate::cluster::Murmur3Token; use crate::retry::RetryPolicy; use crate::speculative_execution::SpeculativeExecutionPolicy; use crate::statement::StatementParams; #[derive(Default, Derivative)] #[derivative(Debug)] pub struct StatementParamsBuilder { consistency: Consistency, flags: Option, values: Option, with_names: bool, page_size: Option, paging_state: Option, serial_consistency: Option, timestamp: Option, is_idempotent: bool, keyspace: Option, now_in_seconds: Option, token: Option, routing_key: Option>, tracing: bool, warnings: bool, #[derivative(Debug = "ignore")] speculative_execution_policy: Option>, #[derivative(Debug = "ignore")] retry_policy: Option>, beta_protocol: bool, } impl StatementParamsBuilder { pub fn new() -> StatementParamsBuilder { Default::default() } /// Sets new statement consistency. #[must_use] pub fn with_consistency(mut self, consistency: Consistency) -> Self { self.consistency = consistency; self } /// Sets new flags. #[must_use] pub fn with_flags(mut self, flags: QueryFlags) -> Self { self.flags = Some(flags); self } /// Sets new statement values.
#[must_use] pub fn with_values(mut self, values: QueryValues) -> Self { self.with_names = values.has_names(); self.values = Some(values); self.flags = self.flags.or_else(|| { let mut flags = QueryFlags::VALUE; if self.with_names { flags.insert(QueryFlags::WITH_NAMES_FOR_VALUES); } Some(flags) }); self } /// Sets the "with names for values" flag. #[must_use] pub fn with_names(mut self, with_names: bool) -> Self { self.with_names = with_names; self } /// Sets new page size. #[must_use] pub fn with_page_size(mut self, size: i32) -> Self { self.page_size = Some(size); self.flags = self.flags.or(Some(QueryFlags::PAGE_SIZE)); self } /// Sets new paging state. #[must_use] pub fn with_paging_state(mut self, state: CBytes) -> Self { self.paging_state = Some(state); self.flags = self.flags.or(Some(QueryFlags::WITH_PAGING_STATE)); self } /// Sets new serial consistency. #[must_use] pub fn with_serial_consistency(mut self, serial_consistency: Consistency) -> Self { self.serial_consistency = Some(serial_consistency); self } /// Sets new timestamp. #[must_use] pub fn with_timestamp(mut self, timestamp: i64) -> Self { self.timestamp = Some(timestamp); self } /// Sets new keyspace. #[must_use] pub fn with_keyspace(mut self, keyspace: String) -> Self { self.keyspace = Some(keyspace); self } /// Sets new token for routing. #[must_use] pub fn with_token(mut self, token: Murmur3Token) -> Self { self.token = Some(token); self } /// Sets new explicit routing key. #[must_use] pub fn with_routing_key(mut self, routing_key: Vec) -> Self { self.routing_key = Some(routing_key); self } /// Marks the statement as idempotent or not. #[must_use] pub fn idempotent(mut self, value: bool) -> Self { self.is_idempotent = value; self } /// Sets custom statement speculative execution policy.
#[must_use] pub fn with_speculative_execution_policy( mut self, speculative_execution_policy: Arc, ) -> Self { self.speculative_execution_policy = Some(speculative_execution_policy); self } /// Sets custom statement retry policy. #[must_use] pub fn with_retry_policy(mut self, retry_policy: Arc) -> Self { self.retry_policy = Some(retry_policy); self } /// Sets beta protocol usage flag. #[must_use] pub fn with_beta_protocol(mut self, beta_protocol: bool) -> Self { self.beta_protocol = beta_protocol; self } /// Sets "now" in seconds. #[must_use] pub fn with_now_in_seconds(mut self, now_in_seconds: CInt) -> Self { self.now_in_seconds = Some(now_in_seconds); self } #[must_use] pub fn build(self) -> StatementParams { StatementParams { query_params: QueryParams { consistency: self.consistency, values: self.values, with_names: self.with_names, page_size: self.page_size, paging_state: self.paging_state, serial_consistency: self.serial_consistency, timestamp: self.timestamp, keyspace: self.keyspace.clone(), now_in_seconds: self.now_in_seconds, }, is_idempotent: self.is_idempotent, keyspace: self.keyspace, token: self.token, routing_key: self.routing_key, tracing: self.tracing, warnings: self.warnings, speculative_execution_policy: self.speculative_execution_policy, retry_policy: self.retry_policy, beta_protocol: self.beta_protocol, } } } ================================================ FILE: cdrs-tokio/src/statement.rs ================================================ mod statement_params; mod statement_params_builder; pub use statement_params::*; pub use statement_params_builder::*; ================================================ FILE: cdrs-tokio/src/transport.rs ================================================ //! This module declares the `CdrsTransport` trait, which a transport must implement //! to be usable by the CDRS client. //! //! Currently, CDRS provides concrete transports which implement the `CdrsTransport` trait.
//! These are: //! //! * [`TransportTcp`] is the default TCP transport, used to establish a //! connection and exchange frames. //! //! * [`TransportRustls`] is a transport used to establish a TLS-encrypted connection //! with an Apache Cassandra server. **Note:** this option is available only when CDRS is built //! with the `rust-tls` feature enabled. use cassandra_protocol::compression::Compression; use cassandra_protocol::frame::frame_decoder::FrameDecoder; use cassandra_protocol::frame::frame_encoder::FrameEncoder; use cassandra_protocol::frame::message_result::ResultKind; use cassandra_protocol::frame::{Envelope, StreamId, MAX_FRAME_SIZE}; use cassandra_protocol::frame::{FromBytes, Opcode, EVENT_STREAM_ID}; use cassandra_protocol::types::INT_LEN; use derive_more::Constructor; use futures::FutureExt; use fxhash::FxHashMap; use itertools::Itertools; use std::io; use std::net::SocketAddr; use std::sync::atomic::{AtomicBool, AtomicI16, Ordering}; use std::sync::{Arc, Mutex}; use tokio::io::{ split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, ReadHalf, WriteHalf, }; use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinHandle; #[cfg(feature = "rust-tls")] use tokio_rustls::rustls::{pki_types::ServerName, ClientConfig}; #[cfg(feature = "rust-tls")] use tokio_rustls::TlsConnector as RustlsConnector; use tracing::*; #[cfg(test)] use mockall::*; use crate::cluster::KeyspaceHolder; use crate::envelope_parser::{convert_envelope_into_result, parse_envelope}; use crate::future::BoxFuture; use crate::Error; use crate::Result; const INITIAL_STREAM_ID: i16 = 1; /// General CDRS transport trait. pub trait CdrsTransport: Send + Sync { /// Schedules a data envelope for writing and waits for a response. Handshake envelopes need to /// be marked as such, since their wire representation is different.
fn write_envelope<'a>( &'a self, envelope: &'a Envelope, handshake: bool, ) -> BoxFuture<'a, Result>; /// Checks if the connection is broken (e.g. after read or write errors). fn is_broken(&self) -> bool; /// Returns associated node address. fn address(&self) -> SocketAddr; } #[cfg(test)] mock! { pub CdrsTransport { } impl CdrsTransport for CdrsTransport { fn write_envelope( &self, envelope: &Envelope, handshake: bool, ) -> BoxFuture<'static, Result>; fn is_broken(&self) -> bool; fn address(&self) -> SocketAddr; } } /// Default Tcp transport. #[derive(Debug)] pub struct TransportTcp { inner: AsyncTransport, } impl TransportTcp { #[allow(clippy::too_many_arguments)] pub async fn new( addr: SocketAddr, keyspace_holder: Arc, event_handler: Option>, error_handler: Option>, compression: Compression, frame_encoder: Box, frame_decoder: Box, buffer_size: usize, tcp_nodelay: bool, ) -> io::Result { TcpStream::connect(addr).await.and_then(move |socket| { socket.set_nodelay(tcp_nodelay)?; Self::with_stream( socket, addr, keyspace_holder, event_handler, error_handler, compression, frame_encoder, frame_decoder, buffer_size, ) }) } #[allow(clippy::too_many_arguments)] pub fn with_stream( stream: T, addr: SocketAddr, keyspace_holder: Arc, event_handler: Option>, error_handler: Option>, compression: Compression, frame_encoder: Box, frame_decoder: Box, buffer_size: usize, ) -> io::Result { let (read_half, write_half) = split(stream); Ok(TransportTcp { inner: AsyncTransport::new( addr, compression, frame_encoder, frame_decoder, buffer_size, read_half, write_half, event_handler, error_handler, keyspace_holder, ), }) } } impl CdrsTransport for TransportTcp { //noinspection DuplicatedCode #[inline] fn write_envelope<'a>( &'a self, envelope: &'a Envelope, handshake: bool, ) -> BoxFuture<'a, Result> { self.inner.write_envelope(envelope, handshake).boxed() } #[inline] fn is_broken(&self) -> bool { self.inner.is_broken() } #[inline] fn address(&self) -> SocketAddr { self.inner.addr() } } 
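`TransportTcp` and `TransportRustls` both delegate to `AsyncTransport`, which multiplexes many in-flight requests over one connection. Its writer task tags each request with a stream id and registers the request's response handler under that id. Its reader task then routes each incoming response envelope back using the same id.

The sketch below illustrates that routing idea using only the standard library. It makes simplifying assumptions: `HandlerMap`, `register` and `dispatch` are hypothetical names and do not reproduce the `ResponseHandlerMap` API used in this file. The wrap-around rule in `next_stream_id` is also an assumption of the sketch. The one property it shares with the code here is that ids stay non-negative, because this file only routes envelopes with `stream_id >= 0` as responses.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::Mutex;

/// Hypothetical, simplified stand-in for a response handler map: it hands out
/// stream ids and routes each response to the caller waiting on that id.
struct HandlerMap {
    next_id: Mutex<i16>,
    handlers: Mutex<HashMap<i16, Sender<String>>>,
}

impl HandlerMap {
    fn new() -> Self {
        HandlerMap {
            // Start at 1, like INITIAL_STREAM_ID in the transport code.
            next_id: Mutex::new(1),
            handlers: Mutex::new(HashMap::new()),
        }
    }

    /// Returns the next stream id, wrapping from i16::MAX back to 1 so ids
    /// never become negative (an assumption of this sketch).
    fn next_stream_id(&self) -> i16 {
        let mut next = self.next_id.lock().unwrap();
        let current = *next;
        *next = if current == i16::MAX { 1 } else { current + 1 };
        current
    }

    /// Registers a pending request; its response will arrive on the receiver.
    fn register(&self) -> (i16, Receiver<String>) {
        let id = self.next_stream_id();
        let (tx, rx) = channel();
        self.handlers.lock().unwrap().insert(id, tx);
        (id, rx)
    }

    /// Routes a response to the waiter registered under `stream_id`.
    /// Returns false if nobody is waiting on that id.
    fn dispatch(&self, stream_id: i16, body: String) -> bool {
        match self.handlers.lock().unwrap().remove(&stream_id) {
            Some(tx) => tx.send(body).is_ok(),
            None => false,
        }
    }
}

fn main() {
    let map = HandlerMap::new();
    let (id_a, rx_a) = map.register();
    let (id_b, rx_b) = map.register();

    // Responses may come back in any order; the stream id tells them apart.
    assert!(map.dispatch(id_b, "rows for B".to_string()));
    assert!(map.dispatch(id_a, "rows for A".to_string()));
    assert_eq!(rx_a.recv().unwrap(), "rows for A");
    assert_eq!(rx_b.recv().unwrap(), "rows for B");

    // The handler was removed on dispatch, so a stray duplicate is rejected.
    assert!(!map.dispatch(id_a, "duplicate".to_string()));
}
```

In the sketch, a `std::sync::mpsc` sender stands in for the `tokio::sync::oneshot` sender that `write_envelope` creates for each request. Removing the handler on dispatch makes each id single-use until it is handed out again.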
#[cfg(feature = "rust-tls")] #[derive(Debug)] pub struct TransportRustls { inner: AsyncTransport, } #[cfg(feature = "rust-tls")] impl TransportRustls { #[allow(clippy::too_many_arguments)] pub async fn new( addr: SocketAddr, dns_name: ServerName<'static>, config: Arc, keyspace_holder: Arc, event_handler: Option>, error_handler: Option>, compression: Compression, frame_encoder: Box, frame_decoder: Box, buffer_size: usize, tcp_nodelay: bool, ) -> io::Result { let stream = TcpStream::connect(addr).await?; stream.set_nodelay(tcp_nodelay)?; Self::with_stream( stream, addr, dns_name, config, keyspace_holder, event_handler, error_handler, compression, frame_encoder, frame_decoder, buffer_size, ) .await } #[allow(clippy::too_many_arguments)] pub async fn with_stream( stream: T, addr: SocketAddr, dns_name: ServerName<'static>, config: Arc, keyspace_holder: Arc, event_handler: Option>, error_handler: Option>, compression: Compression, frame_encoder: Box, frame_decoder: Box, buffer_size: usize, ) -> io::Result { let connector = RustlsConnector::from(config.clone()); let stream = connector.connect(dns_name, stream).await?; let (read_half, write_half) = split(stream); Ok(Self { inner: AsyncTransport::new( addr, compression, frame_encoder, frame_decoder, buffer_size, read_half, write_half, event_handler, error_handler, keyspace_holder, ), }) } } #[cfg(feature = "rust-tls")] impl CdrsTransport for TransportRustls { //noinspection DuplicatedCode #[inline] fn write_envelope<'a>( &'a self, envelope: &'a Envelope, handshake: bool, ) -> BoxFuture<'a, Result> { self.inner.write_envelope(envelope, handshake).boxed() } #[inline] fn is_broken(&self) -> bool { self.inner.is_broken() } #[inline] fn address(&self) -> SocketAddr { self.inner.addr() } } #[derive(Debug)] struct AsyncTransport { addr: SocketAddr, compression: Compression, write_sender: mpsc::Sender, is_broken: Arc, processing_handle: JoinHandle<()>, } impl Drop for AsyncTransport { fn drop(&mut self) { 
self.processing_handle.abort(); } } impl AsyncTransport { #[allow(clippy::too_many_arguments)] fn new( addr: SocketAddr, compression: Compression, frame_encoder: Box, frame_decoder: Box, buffer_size: usize, read_half: ReadHalf, write_half: WriteHalf, event_handler: Option>, error_handler: Option>, keyspace_holder: Arc, ) -> Self { let (write_sender, write_receiver) = mpsc::channel(buffer_size); let is_broken = Arc::new(AtomicBool::new(false)); let processing_handle = tokio::spawn(Self::start_processing( write_receiver, event_handler, error_handler, read_half, write_half, keyspace_holder, is_broken.clone(), compression, addr, frame_encoder, frame_decoder, )); AsyncTransport { addr, compression, write_sender, is_broken, processing_handle, } } #[inline] fn is_broken(&self) -> bool { self.is_broken.load(Ordering::Relaxed) } #[inline] fn addr(&self) -> SocketAddr { self.addr } async fn write_envelope(&self, envelope: &Envelope, handshake: bool) -> Result { let (sender, receiver) = oneshot::channel(); // leave stream id empty for now and generate it later // handshake messages are never compressed let data = if handshake { envelope.encode_with(Compression::None)? } else { envelope.encode_with(self.compression)? }; self.write_sender .send(Request::new(data, sender, handshake)) .await .map_err(|_| Error::General("Connection closed when writing data!".into()))?; receiver .await .map_err(|_| Error::General("Connection closed while waiting for response!".into()))? 
} #[allow(clippy::too_many_arguments)] async fn start_processing( write_receiver: mpsc::Receiver, event_handler: Option>, error_handler: Option>, read_half: ReadHalf, write_half: WriteHalf, keyspace_holder: Arc, is_broken: Arc, compression: Compression, addr: SocketAddr, frame_encoder: Box, frame_decoder: Box, ) { let response_handler_map = ResponseHandlerMap::new(); let writer = Self::start_writing( write_receiver, BufWriter::new(write_half), &response_handler_map, frame_encoder, ); let reader = Self::start_reading_handshake_frames( BufReader::with_capacity(MAX_FRAME_SIZE, read_half), event_handler, compression, addr, keyspace_holder, &response_handler_map, frame_decoder, ); let result = tokio::try_join!(writer, reader); if let Err(error) = result { error!(%error, "Transport error!"); is_broken.store(true, Ordering::Relaxed); response_handler_map.signal_general_error(&error); if let Some(error_handler) = error_handler { match error_handler.try_send(error) { Ok(_) => debug!("Error handler notified!"), Err(e) => warn!("Error handler failed to notify: {e}"), } } } } async fn start_reading_handshake_frames( mut read_half: impl AsyncRead + Unpin, event_handler: Option>, compression: Compression, addr: SocketAddr, keyspace_holder: Arc, response_handler_map: &ResponseHandlerMap, frame_decoder: Box, ) -> Result<()> { // before Authenticate or Ready, envelopes are unframed loop { let result = parse_envelope(&mut read_half, compression, addr).await; match result { Ok(envelope) => { if envelope.stream_id >= 0 { let opcode = envelope.opcode; response_handler_map.send_response(envelope.stream_id, Ok(envelope))?; if opcode == Opcode::Authenticate || opcode == Opcode::Ready { // all frames should now be encoded return Self::start_reading_normal_frames( read_half, event_handler, compression, addr, keyspace_holder, response_handler_map, frame_decoder, ) .await; } } else if envelope.stream_id == EVENT_STREAM_ID { // server event if let Some(event_handler) = &event_handler { let _ = 
event_handler.send(envelope).await; } } } Err(error) => return Err(error), } } } async fn start_reading_normal_frames( mut read_half: impl AsyncRead + Unpin, event_handler: Option>, compression: Compression, addr: SocketAddr, keyspace_holder: Arc, response_handler_map: &ResponseHandlerMap, mut frame_decoder: Box, ) -> Result<()> { let mut buffer = Vec::with_capacity(MAX_FRAME_SIZE); loop { let num_read = read_half.read_buf(&mut buffer).await?; if num_read == 0 { break Err(Error::Io(io::Error::new( io::ErrorKind::UnexpectedEof, "EOF", ))); } let envelopes = frame_decoder.consume(&mut buffer, compression)?; for envelope in envelopes { if envelope.stream_id >= 0 { // in case we get a SetKeyspace result, we need to store current keyspace // checks are done manually for speed if envelope.opcode == Opcode::Result { let result_kind = ResultKind::from_bytes(&envelope.body[..INT_LEN])?; if result_kind == ResultKind::SetKeyspace { let response_body = envelope.response_body()?; let set_keyspace = response_body.into_set_keyspace().ok_or_else(|| { Error::General( "SetKeyspace not found with SetKeyspace opcode!".into(), ) })?; keyspace_holder.update_current_keyspace(set_keyspace.body); } } // normal response to query response_handler_map.send_response( envelope.stream_id, convert_envelope_into_result(envelope, addr), )?; } else if envelope.stream_id == EVENT_STREAM_ID { // server event if let Some(event_handler) = &event_handler { let _ = event_handler.send(envelope).await; } } } } } async fn start_writing( mut write_receiver: mpsc::Receiver, mut write_half: impl AsyncWrite + Unpin, response_handler_map: &ResponseHandlerMap, mut frame_encoder: Box, ) -> Result<()> { let mut frame_stream_ids = Vec::with_capacity(1); while let Some(mut request) = write_receiver.recv().await { frame_stream_ids.clear(); loop { let stream_id = response_handler_map.next_stream_id(); frame_stream_ids.push(stream_id); request.set_stream_id(stream_id); response_handler_map.add_handler(stream_id, 
request.handler); if request.handshake { // handshake messages are not framed, so let's write them directly if let Err(error) = write_half.write_all(&request.data).await { response_handler_map.send_response(stream_id, Err(error.into()))?; return Err(Error::General("Write channel failure!".into())); } } else { // post-handshake messages can be aggregated in frames by the encoder loop { if frame_encoder.can_fit(request.data.len()) { frame_encoder.add_envelope(request.data); break; } // flush the previous frame or create a non-self-contained one if frame_encoder.has_envelopes() { // we have some envelopes => flush current frame Self::write_self_contained_frame( &mut write_half, response_handler_map, &mut frame_stream_ids, frame_encoder.as_mut(), ) .await?; } else { // non-self-contained let data_len = request.data.len(); let mut data_start = 0; while data_start < data_len { let (data_start_offset, frame) = frame_encoder .finalize_non_self_contained(&request.data[data_start..]); data_start += data_start_offset; Self::write_frame( &mut write_half, response_handler_map, &mut frame_stream_ids, frame, ) .await?; frame_encoder.reset(); } break; } } } request = match write_receiver.try_recv() { Ok(request) => request, Err(_) => { if frame_encoder.has_envelopes() { Self::write_self_contained_frame( &mut write_half, response_handler_map, &mut frame_stream_ids, frame_encoder.as_mut(), ) .await?; } if let Err(error) = write_half.flush().await { Self::notify_error_handlers( response_handler_map, &mut frame_stream_ids, error.into(), )?; return Err(Error::General("Write channel failure!".into())); } break; } } } } Ok(()) } async fn write_self_contained_frame( write_half: &mut (impl AsyncWrite + Unpin), response_handler_map: &ResponseHandlerMap, frame_stream_ids: &mut Vec, frame_encoder: &mut (dyn FrameEncoder + Send + Sync), ) -> Result<()> { Self::write_frame( write_half, response_handler_map, frame_stream_ids, frame_encoder.finalize_self_contained(), ) .await?; 
frame_encoder.reset(); // Drop the stream ids that just left over the wire. Their handlers // are now waiting for actual server responses (or the connection // tearing down), and must NOT be notified of write failures from a // *subsequent* frame in the same batch. Without this clear we'd // accumulate stream ids across every flushed frame and a single late // write error would falsely fail every previously-sent envelope. frame_stream_ids.clear(); Ok(()) } async fn write_frame( write_half: &mut (impl AsyncWrite + Unpin), response_handler_map: &ResponseHandlerMap, frame_stream_ids: &mut Vec, frame: &[u8], ) -> Result<()> { // If the underlying socket write fails, fan the error out to every // handler whose envelope is in this frame, and then propagate the // failure so the writer task can shut down cleanly. The previous // implementation returned Ok(()) here, swallowing the original write // error, so the writer would happily keep looping on a dead socket. if let Err(error) = write_half.write_all(frame).await { let propagated: Error = error.into(); Self::notify_error_handlers( response_handler_map, frame_stream_ids, propagated.clone(), )?; return Err(propagated); } Ok(()) } fn notify_error_handlers( response_handler_map: &ResponseHandlerMap, frame_stream_ids: &mut Vec, error: Error, ) -> Result<()> { frame_stream_ids .drain(..) 
.map(|stream_id| response_handler_map.send_response(stream_id, Err(error.clone()))) .try_collect() } } type ResponseHandler = oneshot::Sender>; struct ResponseHandlerMap { stream_handlers: Mutex>, available_stream_id: AtomicI16, } impl ResponseHandlerMap { #[inline] pub fn new() -> Self { ResponseHandlerMap { stream_handlers: Default::default(), available_stream_id: AtomicI16::new(INITIAL_STREAM_ID), } } #[inline] pub fn add_handler(&self, stream_id: StreamId, handler: ResponseHandler) { self.stream_handlers .lock() .unwrap() .insert(stream_id, handler); } pub fn send_response(&self, stream_id: StreamId, response: Result) -> Result<()> { match self.stream_handlers.lock().unwrap().remove(&stream_id) { Some(handler) => { let _ = handler.send(response); Ok(()) } // unmatched stream - probably a bug somewhere None => Err(Error::General(format!("Unmatched stream id: {stream_id}"))), } } pub fn signal_general_error(&self, error: &Error) { for (_, handler) in self.stream_handlers.lock().unwrap().drain() { let _ = handler.send(Err(error.clone())); } } pub fn next_stream_id(&self) -> StreamId { // We allocate stream ids in [INITIAL_STREAM_ID, i16::MAX] inclusive, // wrapping back to INITIAL_STREAM_ID once the maximum has been used. // // The previous implementation called `fetch_add` and then tried to // compare-and-swap the *pre-increment* value back to INITIAL when it // saw a negative result. By construction the post-increment value is // already on the counter at that point, so the CAS expected the wrong // value and almost never succeeded under contention - the counter // would walk through ~32K negative values before yielding usable ids // again, and could yield stream id 0 (which we never want). // // Using a CAS loop that observes the current value and atomically // computes both the value to return AND the next counter state keeps // every returned id strictly inside the allowed range and skips no // ids on wrap. 
// // Memory ordering: the caller publishes the freshly-allocated // stream id by writing it into the request body and then registering // a handler in the Mutex-protected stream_handlers map. The Mutex // already provides the necessary happens-before for the reader, so // Relaxed would be sufficient today; we use Release on the success // leg as defence in depth so any future change to a lock-free // handler map keeps the publish ordering intact. loop { let current = self.available_stream_id.load(Ordering::Relaxed); let (return_value, new_value) = if current < INITIAL_STREAM_ID { // defensive: counter somehow ended up below INITIAL (e.g. a // future caller stored a bad value). Recover by snapping back // to the start of the range. (INITIAL_STREAM_ID, INITIAL_STREAM_ID + 1) } else if current == i16::MAX { // last id in the range - return it and wrap the counter // straight back to INITIAL_STREAM_ID for the next caller. (i16::MAX, INITIAL_STREAM_ID) } else { (current, current + 1) }; if self .available_stream_id .compare_exchange_weak(current, new_value, Ordering::Release, Ordering::Relaxed) .is_ok() { return return_value; } // CAS failed: either another thread updated the counter, or the weak // CAS failed spuriously. Reload and try again; the loop still yields // the next id in sequence. } } } #[derive(Constructor)] struct Request { data: Vec<u8>, handler: ResponseHandler, handshake: bool, } impl Request { #[inline] fn set_stream_id(&mut self, stream_id: StreamId) { self.data[2..4].copy_from_slice(&stream_id.to_be_bytes()); } } #[cfg(test)] mod write_buffer_tests { use super::*; use cassandra_protocol::frame::frame_encoder::UncompressedFrameEncoder; use tokio::io::sink; // Verifies that once a frame has been successfully written, the stream // ids that were aggregated into that frame are dropped from the // bookkeeping vector. Otherwise a later write failure would notify those // handlers - whose envelopes have already been sent and may yet receive // a real server response - with a spurious "Write channel failure".
#[tokio::test] async fn write_self_contained_frame_clears_stream_ids_on_success() { let map = ResponseHandlerMap::new(); let mut sink = sink(); let mut encoder: Box<dyn FrameEncoder + Send + Sync> = Box::new(UncompressedFrameEncoder::default()); // Add one envelope to the encoder and pretend that two stream ids are // pending for the outgoing frame. We don't need to register handlers // on the map for this test - we're checking that the frame_stream_ids // buffer is emptied once the frame goes out, regardless of map state. encoder.add_envelope(vec![0; 16]); let mut frame_stream_ids: Vec<StreamId> = vec![1, 2]; AsyncTransport::write_self_contained_frame( &mut sink, &map, &mut frame_stream_ids, encoder.as_mut(), ) .await .expect("sink writes always succeed"); // After a successful frame write the vector must be empty so the // *next* frame's failure can't accidentally notify handlers from // already-sent frames. assert!( frame_stream_ids.is_empty(), "frame_stream_ids should be empty after successful write, was {:?}", frame_stream_ids ); } } #[cfg(test)] mod stream_id_tests { use super::*; #[test] fn next_stream_id_starts_at_initial_value() { let map = ResponseHandlerMap::new(); assert_eq!(map.next_stream_id(), INITIAL_STREAM_ID); assert_eq!(map.next_stream_id(), INITIAL_STREAM_ID + 1); } #[test] fn next_stream_id_wraps_to_initial_after_overflow() { let map = ResponseHandlerMap::new(); // park the counter at the top of the range so the next call takes the // wrap-around branch map.available_stream_id.store(i16::MAX, Ordering::Relaxed); // the call right at the boundary still returns the last positive id assert_eq!(map.next_stream_id(), i16::MAX); // after the wrap, all returned ids must remain valid (positive) and // the sequence must restart at INITIAL_STREAM_ID. An allocator that let // the counter overflow would either burn ~32K negative ids before // yielding 0, or return 0 itself, neither of which is correct for our // protocol.
assert_eq!(map.next_stream_id(), INITIAL_STREAM_ID); assert_eq!(map.next_stream_id(), INITIAL_STREAM_ID + 1); assert_eq!(map.next_stream_id(), INITIAL_STREAM_ID + 2); } } ================================================ FILE: cdrs-tokio/tests/collection_types.rs ================================================ mod common; #[cfg(feature = "e2e-tests")] use cassandra_protocol::frame::Version; #[cfg(feature = "e2e-tests")] use cdrs_tokio::query_values; #[cfg(feature = "e2e-tests")] use cdrs_tokio::types::blob::Blob; #[cfg(feature = "e2e-tests")] use cdrs_tokio::types::list::List; #[cfg(feature = "e2e-tests")] use cdrs_tokio::types::map::Map; #[cfg(feature = "e2e-tests")] use cdrs_tokio::types::AsRust; #[cfg(feature = "e2e-tests")] use cdrs_tokio::types::ByName; #[cfg(feature = "e2e-tests")] use common::*; #[cfg(feature = "e2e-tests")] use maplit::hashmap; #[cfg(feature = "e2e-tests")] use std::collections::HashMap; #[cfg(feature = "e2e-tests")] use std::str::FromStr; #[cfg(feature = "e2e-tests")] use uuid::Uuid; #[tokio::test] #[cfg(feature = "e2e-tests")] async fn list_v4() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_lists \ (my_text_list frozen> PRIMARY KEY, \ my_nested_list list>>)"; let session = setup(cql, Version::V4).await.expect("setup"); list_test(session).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn list_v5() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_lists \ (my_text_list frozen> PRIMARY KEY, \ my_nested_list list>>)"; let session = setup(cql, Version::V5).await.expect("setup"); list_test(session).await; } #[cfg(feature = "e2e-tests")] async fn list_test(session: CurrentSession) { let my_text_list = vec!["text1", "text2", "text3"]; let my_nested_list: Vec> = vec![vec![1, 2, 3], vec![999, 888, 777, 666, 555], vec![-1, -2]]; let values = query_values!(my_text_list.clone(), my_nested_list.clone()); let cql = "INSERT INTO cdrs_test.test_lists \ (my_text_list, my_nested_list) VALUES (?, ?)"; session 
.query_with_values(cql, values) .await .expect("insert lists error"); let cql = "SELECT * FROM cdrs_test.test_lists"; let rows = session .query(cql) .await .expect("query lists error") .response_body() .expect("get body with lists error") .into_rows() .expect("converting body with lists into rows error"); assert_eq!(rows.len(), 1); for row in rows { let my_text_list_row: Vec = row .r_by_name::("my_text_list") .expect("my_text_list") .as_r_rust() .expect("my_text_list as rust"); let my_nested_list_outer_row: Vec = row .r_by_name::("my_nested_list") .expect("my_nested_list") .as_r_rust() .expect("my_nested_list (outer) as rust"); let mut my_nested_list_row = Vec::with_capacity(my_nested_list_outer_row.len()); for my_nested_list_inner_row in my_nested_list_outer_row { let my_nested_list_inner_row: Vec = my_nested_list_inner_row .as_r_rust() .expect("my_nested_list (inner) as rust"); my_nested_list_row.push(my_nested_list_inner_row); } assert_eq!(my_text_list_row, vec!["text1", "text2", "text3"]); assert_eq!(my_nested_list_row, my_nested_list); } } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn list_advanced_v4() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_lists_v4 \ (my_text_list frozen> PRIMARY KEY, \ my_nested_list list>>)"; let session = setup(cql, Version::V4).await.expect("setup"); list_advanced_test(session).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn list_advanced_v5() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_lists_v4 \ (my_text_list frozen> PRIMARY KEY, \ my_nested_list list>>)"; let session = setup(cql, Version::V5).await.expect("setup"); list_advanced_test(session).await; } #[cfg(feature = "e2e-tests")] async fn list_advanced_test(session: CurrentSession) { let my_text_list = vec![ "text1".to_string(), "text2".to_string(), "text3".to_string(), ]; let my_nested_list: Vec> = vec![vec![1, 2, 3], vec![999, 888, 777, 666, 555], vec![-1, -2]]; let values = query_values!(my_text_list.clone(), 
my_nested_list.clone()); let cql = "INSERT INTO cdrs_test.test_lists_v4 \ (my_text_list, my_nested_list) VALUES (?, ?)"; session .query_with_values(cql, values) .await .expect("insert"); let cql = "SELECT * FROM cdrs_test.test_lists_v4"; let rows = session .query(cql) .await .expect("query") .response_body() .expect("get body") .into_rows() .expect("into rows"); assert_eq!(rows.len(), 1); for row in rows { let my_text_list_row: Vec = row .r_by_name::("my_text_list") .expect("my_text_list") .as_r_rust() .expect("my_text_list as rust"); let my_nested_list_outer_row: Vec = row .r_by_name::("my_nested_list") .expect("my_nested_list") .as_r_rust() .expect("my_nested_list (outer) as rust"); let mut my_nested_list_row = Vec::with_capacity(my_nested_list_outer_row.len()); for my_nested_list_inner_row in my_nested_list_outer_row { let my_nested_list_inner_row: Vec = my_nested_list_inner_row .as_r_rust() .expect("my_nested_list (inner) as rust"); my_nested_list_row.push(my_nested_list_inner_row); } assert_eq!(my_text_list_row, my_text_list); assert_eq!(my_nested_list_row, my_nested_list); } } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn set_v4() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_sets \ (my_text_set frozen> PRIMARY KEY, \ my_nested_set set>>)"; let session = setup(cql, Version::V4).await.expect("setup"); set_test(session).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn set_v5() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_sets \ (my_text_set frozen> PRIMARY KEY, \ my_nested_set set>>)"; let session = setup(cql, Version::V5).await.expect("setup"); set_test(session).await; } #[cfg(feature = "e2e-tests")] async fn set_test(session: CurrentSession) { let my_text_set = vec![ "text1".to_string(), "text2".to_string(), "text3".to_string(), ]; let my_nested_set: Vec> = vec![vec![-2, -1], vec![1, 2, 3], vec![555, 666, 777, 888, 999]]; let values = query_values!(my_text_set.clone(), my_nested_set.clone()); let cql = "INSERT INTO 
cdrs_test.test_sets \ (my_text_set, my_nested_set) VALUES (?, ?)"; session .query_with_values(cql, values) .await .expect("insert"); let cql = "SELECT * FROM cdrs_test.test_sets"; let rows = session .query(cql) .await .expect("query") .response_body() .expect("get body") .into_rows() .expect("into rows"); assert_eq!(rows.len(), 1); for row in rows { let my_text_set_row: Vec = row .r_by_name::("my_text_set") .expect("my_text_set") .as_r_rust() .expect("my_text_set as rust"); let my_nested_set_outer_row: Vec = row .r_by_name::("my_nested_set") .expect("my_nested_set") .as_r_rust() .expect("my_nested_set (outer) as rust"); let mut my_nested_set_row = Vec::with_capacity(my_nested_set_outer_row.len()); for my_nested_set_inner_row in my_nested_set_outer_row { let my_nested_set_inner_row: Vec = my_nested_set_inner_row .as_r_rust() .expect("my_nested_set (inner) as rust"); my_nested_set_row.push(my_nested_set_inner_row); } assert_eq!(my_text_set_row, my_text_set); assert_eq!(my_nested_set_row, my_nested_set); } } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn set_advanced_v4() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_sets_v4 \ (my_text_set frozen> PRIMARY KEY, \ my_nested_set set>>)"; let session = setup(cql, Version::V4).await.expect("setup"); set_advanced_test(session).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn set_advanced_v5() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_sets_v4 \ (my_text_set frozen> PRIMARY KEY, \ my_nested_set set>>)"; let session = setup(cql, Version::V5).await.expect("setup"); set_advanced_test(session).await; } #[cfg(feature = "e2e-tests")] async fn set_advanced_test(session: CurrentSession) { let my_text_set = vec![ "text1".to_string(), "text2".to_string(), "text3".to_string(), ]; let my_nested_set: Vec> = vec![vec![-2, -1], vec![1, 2, 3], vec![555, 666, 777, 888, 999]]; let values = query_values!(my_text_set.clone(), my_nested_set.clone()); let cql = "INSERT INTO cdrs_test.test_sets_v4 \ 
(my_text_set, my_nested_set) VALUES (?, ?)"; session .query_with_values(cql, values) .await .expect("insert"); let cql = "SELECT * FROM cdrs_test.test_sets_v4"; let rows = session .query(cql) .await .expect("query") .response_body() .expect("get body") .into_rows() .expect("into rows"); assert_eq!(rows.len(), 1); for row in rows { let my_text_set_row: Vec = row .r_by_name::("my_text_set") .expect("my_text_set") .as_r_rust() .expect("my_text_set as rust"); let my_nested_set_outer_row: Vec = row .r_by_name::("my_nested_set") .expect("my_nested_set") .as_r_rust() .expect("my_nested_set (outer) as rust"); let mut my_nested_set_row = Vec::with_capacity(my_nested_set_outer_row.len()); for my_nested_set_inner_row in my_nested_set_outer_row { let my_nested_set_inner_row: Vec = my_nested_set_inner_row .as_r_rust() .expect("my_nested_set (inner) as rust"); my_nested_set_row.push(my_nested_set_inner_row); } assert_eq!(my_text_set_row, my_text_set); assert_eq!(my_nested_set_row, my_nested_set); } } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn map_without_blob_v4() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_maps_without_blob \ (my_key int PRIMARY KEY, \ my_text_map map, \ my_nested_map map>>)"; let session = setup(cql, Version::V4).await.expect("setup"); map_without_blob_test(session).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn map_without_blob_v5() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_maps_without_blob \ (my_key int PRIMARY KEY, \ my_text_map map, \ my_nested_map map>>)"; let session = setup(cql, Version::V5).await.expect("setup"); map_without_blob_test(session).await; } #[cfg(feature = "e2e-tests")] async fn map_without_blob_test(session: CurrentSession) { let my_text_map = hashmap! { "key1".to_string() => "value1".to_string(), "key2".to_string() => "value2".to_string(), "key3".to_string() => "value3".to_string(), }; let my_nested_map: HashMap> = hashmap! 
{ Uuid::from_str("bb16106a-10bc-4a07-baa3-126ffe208c43").unwrap() => hashmap!{ 1 => 1, 2 => 2, }, Uuid::from_str("687d7677-dbf0-4d25-8cf3-e5d9185bba0b").unwrap() => hashmap!{ 1 => 1, }, Uuid::from_str("c4dc6e8b-758a-4af4-ab00-ec356fb688d9").unwrap() => hashmap!{ 1 => 1, 2 => 2, 3 => 3, }, }; let values = query_values!(0i32, my_text_map.clone(), my_nested_map.clone()); let cql = "INSERT INTO cdrs_test.test_maps_without_blob \ (my_key, my_text_map, my_nested_map) VALUES (?, ?, ?)"; session .query_with_values(cql, values) .await .expect("insert"); let cql = "SELECT * FROM cdrs_test.test_maps_without_blob"; let rows = session .query(cql) .await .expect("query") .response_body() .expect("get body") .into_rows() .expect("into rows"); assert_eq!(rows.len(), 1); for row in rows { let my_text_map_row: HashMap = row .r_by_name::("my_text_map") .expect("my_text_map") .as_r_rust() .expect("my_text_map as rust"); let my_nested_map_outer_row: HashMap = row .r_by_name::("my_nested_map") .expect("my_nested_map") .as_r_rust() .expect("my_nested_map (outer) as rust"); let mut my_nested_map_row = HashMap::with_capacity(my_nested_map_outer_row.len()); for (index, my_nested_map_inner_row) in my_nested_map_outer_row { let my_nested_map_inner_row: HashMap = my_nested_map_inner_row .as_r_rust() .expect("my_nested_map (inner) as rust"); my_nested_map_row.insert(index, my_nested_map_inner_row); } assert_eq!(my_text_map_row, my_text_map); assert_eq!(my_nested_map_row, my_nested_map); } } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn map_without_blob_advanced_v4() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_maps_without_blob_v4 \ (my_text_map frozen> PRIMARY KEY, \ my_nested_map map>>)"; let session = setup(cql, Version::V4).await.expect("setup"); map_without_blob_advanced_test(session).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn map_without_blob_advanced_v5() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_maps_without_blob_v4 \ (my_text_map 
frozen> PRIMARY KEY, \ my_nested_map map>>)"; let session = setup(cql, Version::V5).await.expect("setup"); map_without_blob_advanced_test(session).await; } #[cfg(feature = "e2e-tests")] async fn map_without_blob_advanced_test(session: CurrentSession) { let my_text_map = hashmap! { "key1".to_string() => "value1".to_string(), "key2".to_string() => "value2".to_string(), "key3".to_string() => "value3".to_string(), }; let my_nested_map: HashMap> = hashmap! { Uuid::from_str("bb16106a-10bc-4a07-baa3-126ffe208c43").unwrap() => hashmap!{ 1 => 1, 2 => 2, }, Uuid::from_str("687d7677-dbf0-4d25-8cf3-e5d9185bba0b").unwrap() => hashmap!{ 1 => 1, }, Uuid::from_str("c4dc6e8b-758a-4af4-ab00-ec356fb688d9").unwrap() => hashmap!{ 1 => 1, 2 => 2, 3 => 3, }, }; let values = query_values!(my_text_map.clone(), my_nested_map.clone()); let cql = "INSERT INTO cdrs_test.test_maps_without_blob_v4 \ (my_text_map, my_nested_map) VALUES (?, ?)"; session .query_with_values(cql, values) .await .expect("insert"); let cql = "SELECT * FROM cdrs_test.test_maps_without_blob_v4"; let rows = session .query(cql) .await .expect("query") .response_body() .expect("get body") .into_rows() .expect("into rows"); assert_eq!(rows.len(), 1); for row in rows { let my_text_map_row: HashMap = row .r_by_name::("my_text_map") .expect("my_text_map") .as_r_rust() .expect("my_text_map as rust"); let my_nested_map_outer_row: HashMap = row .r_by_name::("my_nested_map") .expect("my_nested_map") .as_r_rust() .expect("my_nested_map (outer) as rust"); let mut my_nested_map_row = HashMap::with_capacity(my_nested_map_outer_row.len()); for (index, my_nested_map_inner_row) in my_nested_map_outer_row { let my_nested_map_inner_row: HashMap = my_nested_map_inner_row .as_r_rust() .expect("my_nested_map (inner) as rust"); my_nested_map_row.insert(index, my_nested_map_inner_row); } assert_eq!(my_text_map_row, my_text_map); assert_eq!(my_nested_map_row, my_nested_map); } } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn map_v4() { let 
cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_maps \ (my_text_map frozen> PRIMARY KEY, \ my_nested_map map>>)"; let session = setup(cql, Version::V4).await.expect("setup"); map_test(session).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn map_v5() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_maps \ (my_text_map frozen> PRIMARY KEY, \ my_nested_map map>>)"; let session = setup(cql, Version::V5).await.expect("setup"); map_test(session).await; } #[cfg(feature = "e2e-tests")] async fn map_test(session: CurrentSession) { let my_text_map = hashmap! { "key1".to_string() => "value1".to_string(), "key2".to_string() => "value2".to_string(), "key3".to_string() => "value3".to_string(), }; let my_nested_map: HashMap> = hashmap! { Uuid::from_str("bb16106a-10bc-4a07-baa3-126ffe208c43").unwrap() => hashmap!{ 1 => vec![52, 121, 209, 200, 81, 118, 181, 17].into(), 2 => vec![226, 90, 51, 10, 26, 87, 141, 61].into(), }, Uuid::from_str("687d7677-dbf0-4d25-8cf3-e5d9185bba0b").unwrap() => hashmap!{ 1 => vec![224, 155, 148, 6, 217, 96, 120, 38].into(), }, Uuid::from_str("c4dc6e8b-758a-4af4-ab00-ec356fb688d9").unwrap() => hashmap!{ 1 => vec![164, 238, 196, 10, 149, 169, 145, 239].into(), 2 => vec![250, 87, 119, 134, 105, 236, 240, 64].into(), 3 => vec![72, 81, 26, 173, 107, 96, 38, 91].into(), }, }; let values = query_values!(my_text_map.clone(), my_nested_map.clone()); let cql = "INSERT INTO cdrs_test.test_maps \ (my_text_map, my_nested_map) VALUES (?, ?)"; session .query_with_values(cql, values) .await .expect("insert"); let cql = "SELECT * FROM cdrs_test.test_maps"; let rows = session .query(cql) .await .expect("query") .response_body() .expect("get body") .into_rows() .expect("into rows"); assert_eq!(rows.len(), 1); for row in rows { let my_text_map_row: HashMap = row .r_by_name::("my_text_map") .expect("my_text_map") .as_r_rust() .expect("my_text_map as rust"); let my_nested_map_outer_row: HashMap = row .r_by_name::("my_nested_map") .expect("my_nested_map") 
.as_r_rust() .expect("my_nested_map (outer) as rust"); let mut my_nested_map_row = HashMap::with_capacity(my_nested_map_outer_row.len()); for (index, my_nested_map_inner_row) in my_nested_map_outer_row { let my_nested_map_inner_row: HashMap = my_nested_map_inner_row .as_r_rust() .expect("my_nested_map (inner) as rust"); my_nested_map_row.insert(index, my_nested_map_inner_row); } assert_eq!(my_text_map_row, my_text_map); assert_eq!(my_nested_map_row, my_nested_map); } } ================================================ FILE: cdrs-tokio/tests/common.rs ================================================ #[cfg(feature = "e2e-tests")] use std::sync::Arc; #[cfg(feature = "e2e-tests")] use cassandra_protocol::frame::Version; #[cfg(feature = "e2e-tests")] use cdrs_tokio::cluster::session::Session; #[cfg(feature = "e2e-tests")] use cdrs_tokio::cluster::session::{SessionBuilder, TcpSessionBuilder}; #[cfg(feature = "e2e-tests")] use cdrs_tokio::cluster::NodeTcpConfigBuilder; #[cfg(feature = "e2e-tests")] use cdrs_tokio::cluster::TcpConnectionManager; #[cfg(feature = "e2e-tests")] use cdrs_tokio::error::Result; #[cfg(feature = "e2e-tests")] use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy; #[cfg(feature = "e2e-tests")] use cdrs_tokio::retry::NeverReconnectionPolicy; #[cfg(feature = "e2e-tests")] use cdrs_tokio::transport::TransportTcp; #[cfg(feature = "e2e-tests")] use regex::Regex; #[cfg(feature = "e2e-tests")] pub const ADDR: &str = "127.0.0.1:9042"; #[cfg(feature = "e2e-tests")] pub type CurrentSession = Session< TransportTcp, TcpConnectionManager, RoundRobinLoadBalancingStrategy, >; #[cfg(feature = "e2e-tests")] #[allow(dead_code)] pub async fn setup(create_table_cql: &'static str, version: Version) -> Result { setup_multiple(&[create_table_cql], version).await } #[cfg(feature = "e2e-tests")] pub async fn setup_multiple( create_cqls: &[&'static str], version: Version, ) -> Result { let cluster_config = NodeTcpConfigBuilder::new() 
.with_contact_point(ADDR.into()) .with_version(version) .build() .await .unwrap(); let session = TcpSessionBuilder::new(RoundRobinLoadBalancingStrategy::new(), cluster_config) .with_reconnection_policy(Arc::new(NeverReconnectionPolicy)) .build() .await .unwrap(); let re_table_name = Regex::new(r"CREATE TABLE IF NOT EXISTS (\w+\.\w+)").unwrap(); let create_keyspace_query = "CREATE KEYSPACE IF NOT EXISTS cdrs_test WITH \ replication = {'class': 'SimpleStrategy', 'replication_factor': 1} \ AND durable_writes = false"; session.query(create_keyspace_query).await?; for create_cql in create_cqls.iter() { let table_name = re_table_name .captures(create_cql) .map(|cap| cap.get(1).unwrap().as_str()); // Re-using tables is a lot faster than creating/dropping them for every test. // But if table definitions change while editing tests // the old tables need to be dropped. For example by uncommenting the following lines. // if let Some(table_name) = table_name { // let cql = format!("DROP TABLE IF EXISTS {}", table_name); // let query = QueryBuilder::new(cql).finalize(); // session.query(query, true, true)?; // } session.query(create_cql.to_owned()).await?; if let Some(table_name) = table_name { let cql = format!("TRUNCATE TABLE {table_name}"); session.query(cql).await?; } } Ok(session) } ================================================ FILE: cdrs-tokio/tests/compression.rs ================================================ mod common; #[cfg(feature = "e2e-tests")] use cassandra_protocol::compression::Compression; #[cfg(feature = "e2e-tests")] use cassandra_protocol::frame::Version; #[cfg(feature = "e2e-tests")] use cassandra_protocol::types::blob::Blob; #[cfg(feature = "e2e-tests")] use cassandra_protocol::types::ByIndex; #[cfg(feature = "e2e-tests")] use cdrs_tokio::cluster::session::{SessionBuilder, TcpSessionBuilder}; #[cfg(feature = "e2e-tests")] use cdrs_tokio::cluster::NodeTcpConfigBuilder; #[cfg(feature = "e2e-tests")] use 
cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy; #[cfg(feature = "e2e-tests")] use cdrs_tokio::query_values; #[cfg(feature = "e2e-tests")] use cdrs_tokio::retry::NeverReconnectionPolicy; #[cfg(feature = "e2e-tests")] use common::*; #[cfg(feature = "e2e-tests")] use rand::prelude::*; #[cfg(feature = "e2e-tests")] use rand::rng; #[cfg(feature = "e2e-tests")] use std::sync::Arc; #[cfg(feature = "e2e-tests")] async fn encode_decode_test(version: Version) { let cluster_config = NodeTcpConfigBuilder::new() .with_contact_point(ADDR.into()) .with_version(version) .build() .await .unwrap(); let session = TcpSessionBuilder::new(RoundRobinLoadBalancingStrategy::new(), cluster_config) .with_reconnection_policy(Arc::new(NeverReconnectionPolicy)) .with_compression(Compression::Lz4) .build() .await .unwrap(); session .query( "CREATE KEYSPACE IF NOT EXISTS cdrs_test WITH \ replication = {'class': 'SimpleStrategy', 'replication_factor': 1} \ AND durable_writes = false", ) .await .unwrap(); session .query( "CREATE TABLE IF NOT EXISTS cdrs_test.test_compression (pk int PRIMARY KEY, data blob)", ) .await .unwrap(); let mut rng = rng(); let mut data = vec![0u8; 5 * 1024 * 1024]; for elem in &mut data { *elem = rng.random(); } let blob = Blob::new(data); session .query_with_values( "INSERT INTO cdrs_test.test_compression (pk, data) VALUES (1, ?)", query_values!(blob.clone()), ) .await .unwrap(); let stored_data: Blob = session .query("SELECT data FROM cdrs_test.test_compression WHERE pk = 1") .await .unwrap() .response_body() .unwrap() .into_rows() .unwrap() .first() .unwrap() .r_by_index::(0) .unwrap(); assert_eq!(stored_data, blob); } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn encode_decode_test_v4() { encode_decode_test(Version::V4).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn encode_decode_test_v5() { encode_decode_test(Version::V5).await; } ================================================ FILE: cdrs-tokio/tests/derive_traits.rs 
================================================ #![cfg(feature = "derive")] mod common; #[cfg(feature = "e2e-tests")] use cassandra_protocol::frame::Version; #[cfg(feature = "e2e-tests")] use cdrs_tokio::consistency::Consistency; #[cfg(feature = "e2e-tests")] use cdrs_tokio::frame::TryFromRow; #[cfg(feature = "e2e-tests")] use cdrs_tokio::query::QueryValues; #[cfg(feature = "e2e-tests")] use cdrs_tokio::query_values; #[cfg(feature = "e2e-tests")] use cdrs_tokio::statement::StatementParamsBuilder; #[cfg(feature = "e2e-tests")] use cdrs_tokio::types::blob::Blob; #[cfg(feature = "e2e-tests")] use cdrs_tokio::IntoCdrsValue; #[cfg(feature = "e2e-tests")] use cdrs_tokio::{TryFromRow, TryFromUdt}; #[cfg(feature = "e2e-tests")] use common::*; #[cfg(feature = "e2e-tests")] use maplit::hashmap; #[cfg(feature = "e2e-tests")] use std::collections::HashMap; #[cfg(feature = "e2e-tests")] use std::str::FromStr; #[cfg(feature = "e2e-tests")] use time::PrimitiveDateTime; #[cfg(feature = "e2e-tests")] use uuid::Uuid; #[tokio::test] #[cfg(feature = "e2e-tests")] async fn simple_udt_v4() { let create_type_cql = "CREATE TYPE IF NOT EXISTS cdrs_test.derive_udt (my_text text)"; let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_derived_udt \ (my_key int PRIMARY KEY, my_udt derive_udt, my_uuid uuid, my_blob blob)"; let session = setup_multiple(&[create_type_cql, create_table_cql], Version::V4) .await .expect("setup"); simple_udt_test(session).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn simple_udt_v5() { let create_type_cql = "CREATE TYPE IF NOT EXISTS cdrs_test.derive_udt (my_text text)"; let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_derived_udt \ (my_key int PRIMARY KEY, my_udt derive_udt, my_uuid uuid, my_blob blob)"; let session = setup_multiple(&[create_type_cql, create_table_cql], Version::V5) .await .expect("setup"); simple_udt_test(session).await; } #[cfg(feature = "e2e-tests")] async fn simple_udt_test(session: CurrentSession) 
{ #[derive(Clone, Debug, IntoCdrsValue, TryFromRow, PartialEq)] struct RowStruct { my_key: i32, my_udt: MyUdt, my_uuid: Uuid, my_blob: Blob, } impl RowStruct { fn into_query_values(self) -> QueryValues { query_values!("my_key" => self.my_key, "my_udt" => self.my_udt, "my_uuid" => self.my_uuid, "my_blob" => self.my_blob) } } #[derive(Debug, Clone, PartialEq, IntoCdrsValue, TryFromUdt)] struct MyUdt { pub my_text: String, } let row_struct = RowStruct { my_key: 1i32, my_udt: MyUdt { my_text: "my_text".to_string(), }, my_uuid: Uuid::from_str("bb16106a-10bc-4a07-baa3-126ffe208c43").unwrap(), my_blob: Blob::new(vec![]), }; let cql = "INSERT INTO cdrs_test.test_derived_udt \ (my_key, my_udt, my_uuid, my_blob) VALUES (?, ?, ?, ?)"; session .query_with_values(cql, row_struct.clone().into_query_values()) .await .expect("insert"); let cql = "SELECT * FROM cdrs_test.test_derived_udt"; let rows = session .query(cql) .await .expect("query") .response_body() .expect("get body") .into_rows() .expect("into rows"); assert_eq!(rows.len(), 1); for row in rows { let my_udt_row: RowStruct = RowStruct::try_from_row(row).expect("into RowStruct"); assert_eq!(my_udt_row, row_struct); } } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn nested_udt_v4() { let create_type1_cql = "CREATE TYPE IF NOT EXISTS cdrs_test.nested_inner_udt (my_text text)"; let create_type2_cql = "CREATE TYPE IF NOT EXISTS cdrs_test.nested_outer_udt \ (my_inner_udt frozen<nested_inner_udt>)"; let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_nested_udt \ (my_key int PRIMARY KEY, my_outer_udt nested_outer_udt)"; let session = setup_multiple( &[create_type1_cql, create_type2_cql, create_table_cql], Version::V4, ) .await .expect("setup"); nested_udt_test(session).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn nested_udt_v5() { let create_type1_cql = "CREATE TYPE IF NOT EXISTS cdrs_test.nested_inner_udt (my_text text)"; let create_type2_cql = "CREATE TYPE IF NOT EXISTS cdrs_test.nested_outer_udt \
(my_inner_udt frozen<nested_inner_udt>)"; let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_nested_udt \ (my_key int PRIMARY KEY, my_outer_udt nested_outer_udt)"; let session = setup_multiple( &[create_type1_cql, create_type2_cql, create_table_cql], Version::V5, ) .await .expect("setup"); nested_udt_test(session).await; } #[cfg(feature = "e2e-tests")] async fn nested_udt_test(session: CurrentSession) { #[derive(Clone, Debug, IntoCdrsValue, TryFromRow, PartialEq)] struct RowStruct { my_key: i32, my_outer_udt: MyOuterUdt, } impl RowStruct { fn into_query_values(self) -> QueryValues { query_values!("my_key" => self.my_key, "my_outer_udt" => self.my_outer_udt) } } #[derive(Clone, Debug, IntoCdrsValue, TryFromUdt, PartialEq)] struct MyInnerUdt { pub my_text: String, } #[derive(Clone, Debug, IntoCdrsValue, TryFromUdt, PartialEq)] struct MyOuterUdt { pub my_inner_udt: MyInnerUdt, } let row_struct = RowStruct { my_key: 0, my_outer_udt: MyOuterUdt { my_inner_udt: MyInnerUdt { my_text: "my_text".to_string(), }, }, }; let cql = "INSERT INTO cdrs_test.test_nested_udt \ (my_key, my_outer_udt) VALUES (?, ?)"; session .query_with_values(cql, row_struct.clone().into_query_values()) .await .expect("insert"); let cql = "SELECT * FROM cdrs_test.test_nested_udt"; let rows = session .query(cql) .await .expect("query") .response_body() .expect("get body") .into_rows() .expect("into rows"); assert_eq!(rows.len(), 1); for row in rows { let my_row_struct: RowStruct = RowStruct::try_from_row(row).expect("into RowStruct"); assert_eq!(my_row_struct, row_struct); } } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn alter_udt_add_v4() { let drop_table_cql = "DROP TABLE IF EXISTS cdrs_test.test_alter_udt_add"; let drop_type_cql = "DROP TYPE IF EXISTS cdrs_test.alter_udt_add_udt"; let create_type_cql = "CREATE TYPE cdrs_test.alter_udt_add_udt (my_text text)"; let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_alter_udt_add \ (my_key int PRIMARY KEY, my_map frozen<map<text, alter_udt_add_udt>>)"; let session
= setup_multiple( &[ drop_table_cql, drop_type_cql, create_type_cql, create_table_cql, ], Version::V4, ) .await .expect("setup"); alter_udt_add_test(session).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn alter_udt_add_v5() { let drop_table_cql = "DROP TABLE IF EXISTS cdrs_test.test_alter_udt_add"; let drop_type_cql = "DROP TYPE IF EXISTS cdrs_test.alter_udt_add_udt"; let create_type_cql = "CREATE TYPE cdrs_test.alter_udt_add_udt (my_text text)"; let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_alter_udt_add \ (my_key int PRIMARY KEY, my_map frozen<map<text, alter_udt_add_udt>>)"; let session = setup_multiple( &[ drop_table_cql, drop_type_cql, create_type_cql, create_table_cql, ], Version::V5, ) .await .expect("setup"); alter_udt_add_test(session).await; } #[cfg(feature = "e2e-tests")] async fn alter_udt_add_test(session: CurrentSession) { #[derive(Clone, Debug, IntoCdrsValue, TryFromRow, PartialEq)] struct RowStruct { my_key: i32, my_map: HashMap<String, MyUdtA>, } impl RowStruct { fn into_query_values(self) -> QueryValues { query_values!("my_key" => self.my_key, "my_map" => self.my_map) } } #[derive(Clone, Debug, IntoCdrsValue, TryFromUdt, PartialEq)] struct MyUdtA { pub my_text: String, } #[derive(Clone, Debug, IntoCdrsValue, TryFromRow, PartialEq)] struct RowStructB { my_key: i32, my_map: HashMap<String, MyUdtB>, } #[derive(Clone, Debug, IntoCdrsValue, TryFromUdt, PartialEq)] struct MyUdtB { pub my_text: String, pub my_timestamp: Option<PrimitiveDateTime>, } let row_struct = RowStruct { my_key: 0, my_map: hashmap!
{ "1".to_string() => MyUdtA {my_text: "my_text".to_string()} }, }; let cql = "INSERT INTO cdrs_test.test_alter_udt_add \ (my_key, my_map) VALUES (?, ?)"; session .query_with_values(cql, row_struct.clone().into_query_values()) .await .expect("insert"); let cql = "ALTER TYPE cdrs_test.alter_udt_add_udt ADD my_timestamp timestamp"; session.query(cql).await.expect("alter type"); let expected_nested_udt = MyUdtB { my_text: "my_text".to_string(), my_timestamp: None, }; let cql = "SELECT * FROM cdrs_test.test_alter_udt_add"; let rows = session .query(cql) .await .expect("query") .response_body() .expect("get body") .into_rows() .expect("into rows"); assert_eq!(rows.len(), 1); for row in rows { let altered_row: RowStructB = RowStructB::try_from_row(row).expect("into RowStructB"); assert_eq!( altered_row.my_map, hashmap! { "1".to_string() => expected_nested_udt.clone() } ); } } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn update_list_with_udt_v4() { let drop_table_cql = "DROP TABLE IF EXISTS cdrs_test.update_list_with_udt"; let drop_type_cql = "DROP TYPE IF EXISTS cdrs_test.update_list_with_udt"; let create_type_cql = "CREATE TYPE cdrs_test.update_list_with_udt (id uuid, text text)"; let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.update_list_with_udt \ (id uuid PRIMARY KEY, udts_set set<frozen<update_list_with_udt>>)"; let session = setup_multiple( &[ drop_table_cql, drop_type_cql, create_type_cql, create_table_cql, ], Version::V4, ) .await .expect("setup"); update_list_with_udt_test(session).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn update_list_with_udt_v5() { let drop_table_cql = "DROP TABLE IF EXISTS cdrs_test.update_list_with_udt"; let drop_type_cql = "DROP TYPE IF EXISTS cdrs_test.update_list_with_udt"; let create_type_cql = "CREATE TYPE cdrs_test.update_list_with_udt (id uuid, text text)"; let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.update_list_with_udt \ (id uuid PRIMARY KEY, udts_set set<frozen<update_list_with_udt>>)"; let session = setup_multiple( &[
drop_table_cql, drop_type_cql, create_type_cql, create_table_cql, ], Version::V5, ) .await .expect("setup"); update_list_with_udt_test(session).await; } #[cfg(feature = "e2e-tests")] async fn update_list_with_udt_test(session: CurrentSession) { #[derive(Clone, Debug, IntoCdrsValue, TryFromRow, PartialEq)] struct RowStruct { id: Uuid, udts_set: Vec<MyUdt>, } impl RowStruct { fn into_query_values(self) -> QueryValues { query_values!("id" => self.id, "udts_set" => self.udts_set) } } #[derive(Clone, Debug, IntoCdrsValue, TryFromUdt, PartialEq)] struct MyUdt { pub id: Uuid, pub text: String, } let row_struct = RowStruct { id: Uuid::parse_str("5bd8877a-e2b2-4d6f-aafd-c3f72a6964cf").expect("row id"), udts_set: vec![MyUdt { id: Uuid::parse_str("08f49fa5-934b-4aff-8a87-f3a3287296ba").expect("udt id"), text: "text".into(), }], }; let cql = "INSERT INTO cdrs_test.update_list_with_udt \ (id, udts_set) VALUES (?, ?)"; session .query_with_values(cql, row_struct.clone().into_query_values()) .await .expect("insert"); let query = session .prepare("UPDATE cdrs_test.update_list_with_udt SET udts_set = udts_set + ?
WHERE id = ?") .await .expect("prepare query"); let params = StatementParamsBuilder::new() .with_consistency(Consistency::Quorum) .with_values(query_values!( vec![MyUdt { id: Uuid::parse_str("68f49fa5-934b-4aff-8a87-f3a32872a6ba").expect("udt id"), text: "abc".into(), }], Uuid::parse_str("5bd8877a-e2b2-4d6f-aafd-c3f72a6964cf").unwrap() )); session .exec_with_params(&query, &params.build()) .await .expect("update set"); let expected_row_struct = RowStruct { id: Uuid::parse_str("5bd8877a-e2b2-4d6f-aafd-c3f72a6964cf").expect("row id"), udts_set: vec![ MyUdt { id: Uuid::parse_str("08f49fa5-934b-4aff-8a87-f3a3287296ba").expect("udt id"), text: "text".into(), }, MyUdt { id: Uuid::parse_str("68f49fa5-934b-4aff-8a87-f3a32872a6ba").expect("udt id"), text: "abc".into(), }, ], }; let cql = "SELECT * FROM cdrs_test.update_list_with_udt"; let rows = session .query(cql) .await .expect("query") .response_body() .expect("get body") .into_rows() .expect("into rows"); assert_eq!(rows.len(), 1); for row in rows { let altered_row: RowStruct = RowStruct::try_from_row(row).expect("into RowStruct"); assert_eq!(altered_row, expected_row_struct); } } ================================================ FILE: cdrs-tokio/tests/keyspace.rs ================================================ #[cfg(feature = "e2e-tests")] use std::collections::HashMap; #[cfg(feature = "e2e-tests")] use std::sync::Arc; #[cfg(feature = "e2e-tests")] use cdrs_tokio::authenticators::NoneAuthenticatorProvider; #[cfg(feature = "e2e-tests")] use cdrs_tokio::cluster::session::SessionBuilder; #[cfg(feature = "e2e-tests")] use cdrs_tokio::cluster::session::TcpSessionBuilder; #[cfg(feature = "e2e-tests")] use cdrs_tokio::cluster::NodeTcpConfigBuilder; #[cfg(feature = "e2e-tests")] use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy; #[cfg(feature = "e2e-tests")] use cdrs_tokio::retry::NeverReconnectionPolicy; #[cfg(feature = "e2e-tests")] use cdrs_tokio::types::map::Map; #[cfg(feature = "e2e-tests")] use
cdrs_tokio::types::{AsRust, ByName, IntoRustByName}; #[cfg(feature = "e2e-tests")] #[tokio::test] async fn create_keyspace() { let cluster_config = NodeTcpConfigBuilder::new() .with_contact_point("127.0.0.1:9042".into()) .with_authenticator_provider(Arc::new(NoneAuthenticatorProvider)) .build() .await .unwrap(); let lb = RoundRobinLoadBalancingStrategy::new(); let session = TcpSessionBuilder::new(lb, cluster_config) .with_reconnection_policy(Arc::new(NeverReconnectionPolicy)) .build() .await .unwrap(); let drop_query = "DROP KEYSPACE IF EXISTS create_ks_test"; let keyspace_dropped = session.query(drop_query).await.is_ok(); assert!(keyspace_dropped, "Should drop new keyspace without errors"); let create_query = "CREATE KEYSPACE IF NOT EXISTS create_ks_test WITH \ replication = {'class': 'SimpleStrategy', 'replication_factor': 1} \ AND durable_writes = false"; let keyspace_created = session.query(create_query).await.is_ok(); assert!( keyspace_created, "Should create new keyspace without errors" ); let select_query = "SELECT * FROM system_schema.keyspaces WHERE keyspace_name = 'create_ks_test'"; let keyspace_selected = session .query(select_query) .await .expect("select keyspace query") .response_body() .expect("get select keyspace query body") .into_rows() .expect("convert keyspaces results into rows"); assert_eq!(keyspace_selected.len(), 1); let keyspace = &keyspace_selected[0]; let keyspace_name: String = keyspace .get_r_by_name("keyspace_name") .expect("keyspace name into rust error"); assert_eq!( keyspace_name, "create_ks_test".to_string(), "wrong keyspace name" ); let durable_writes: bool = keyspace .get_r_by_name("durable_writes") .expect("durable writes into rust error"); assert!(!durable_writes, "wrong durable writes"); let mut expected_strategy_options: HashMap<String, String> = HashMap::new(); expected_strategy_options.insert("replication_factor".to_string(), "1".to_string()); expected_strategy_options.insert( "class".to_string(),
"org.apache.cassandra.locator.SimpleStrategy".to_string(), ); let strategy_options: HashMap<String, String> = keyspace .r_by_name::<Map>("replication") .expect("strategy options into rust error") .as_r_rust() .expect("uuid_key_map"); assert_eq!( expected_strategy_options, strategy_options, "wrong strategy options" ); } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn alter_keyspace() { let cluster_config = NodeTcpConfigBuilder::new() .with_contact_point("127.0.0.1:9042".into()) .with_authenticator_provider(Arc::new(NoneAuthenticatorProvider)) .build() .await .unwrap(); let lb = RoundRobinLoadBalancingStrategy::new(); let session = TcpSessionBuilder::new(lb, cluster_config) .with_reconnection_policy(Arc::new(NeverReconnectionPolicy)) .build() .await .unwrap(); let drop_query = "DROP KEYSPACE IF EXISTS alter_ks_test"; let keyspace_dropped = session.query(drop_query).await.is_ok(); assert!(keyspace_dropped, "Should drop new keyspace without errors"); let create_query = "CREATE KEYSPACE IF NOT EXISTS alter_ks_test WITH \ replication = {'class': 'SimpleStrategy', 'replication_factor': 1} \ AND durable_writes = false"; let keyspace_created = session.query(create_query).await.is_ok(); assert!( keyspace_created, "Should create new keyspace without errors" ); let alter_query = "ALTER KEYSPACE alter_ks_test WITH \ replication = {'class': 'SimpleStrategy', 'replication_factor': 3} \ AND durable_writes = false"; assert!( session.query(alter_query).await.is_ok(), "alter should be without errors" ); let select_query = "SELECT * FROM system_schema.keyspaces WHERE keyspace_name = 'alter_ks_test'"; let keyspace_selected = session .query(select_query) .await .expect("select keyspace query") .response_body() .expect("get select keyspace query body") .into_rows() .expect("convert keyspaces results into rows"); assert_eq!(keyspace_selected.len(), 1); let keyspace = &keyspace_selected[0]; let strategy_options: HashMap<String, String> = keyspace .r_by_name::<Map>("replication") .expect("strategy options into rust error")
.as_r_rust() .expect("uuid_key_map"); assert_eq!( strategy_options .get("replication_factor") .expect("replication_factor unwrap"), &"3".to_string() ); } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn use_keyspace() { let cluster_config = NodeTcpConfigBuilder::new() .with_contact_point("127.0.0.1:9042".into()) .with_authenticator_provider(Arc::new(NoneAuthenticatorProvider)) .build() .await .unwrap(); let lb = RoundRobinLoadBalancingStrategy::new(); let session = TcpSessionBuilder::new(lb, cluster_config) .with_reconnection_policy(Arc::new(NeverReconnectionPolicy)) .build() .await .unwrap(); let create_query = "CREATE KEYSPACE IF NOT EXISTS use_ks_test WITH \ replication = {'class': 'SimpleStrategy', 'replication_factor': 1} \ AND durable_writes = false"; let keyspace_created = session.query(create_query).await.is_ok(); assert!( keyspace_created, "Should create new keyspace without errors" ); let use_query = "USE use_ks_test"; let keyspace_used = session .query(use_query) .await .expect("should use selected") .response_body() .expect("should get body") .into_set_keyspace() .expect("set keyspace") .body; assert_eq!(keyspace_used.as_str(), "use_ks_test", "wrong keyspace used"); } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn drop_keyspace() { let cluster_config = NodeTcpConfigBuilder::new() .with_contact_point("127.0.0.1:9042".into()) .with_authenticator_provider(Arc::new(NoneAuthenticatorProvider)) .build() .await .unwrap(); let lb = RoundRobinLoadBalancingStrategy::new(); let session = TcpSessionBuilder::new(lb, cluster_config) .with_reconnection_policy(Arc::new(NeverReconnectionPolicy)) .build() .await .unwrap(); let create_query = "CREATE KEYSPACE IF NOT EXISTS drop_ks_test WITH \ replication = {'class': 'SimpleStrategy', 'replication_factor': 1} \ AND durable_writes = false"; let keyspace_created = session.query(create_query).await.is_ok(); assert!( keyspace_created, "Should create new keyspace without errors" ); let drop_query = "DROP KEYSPACE 
drop_ks_test"; let keyspace_dropped = session.query(drop_query).await.is_ok(); assert!(keyspace_dropped, "Should drop new keyspace without errors"); } ================================================ FILE: cdrs-tokio/tests/multi_node_speculative_execution.rs ================================================ mod common; #[cfg(feature = "e2e-tests")] use cdrs_tokio::cluster::session::{SessionBuilder, TcpSessionBuilder}; #[cfg(feature = "e2e-tests")] use cdrs_tokio::cluster::NodeTcpConfigBuilder; #[cfg(feature = "e2e-tests")] use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy; #[cfg(feature = "e2e-tests")] use common::*; #[cfg(feature = "e2e-tests")] use std::sync::Arc; #[cfg(feature = "e2e-tests")] use std::time::Duration; #[cfg(feature = "e2e-tests")] use cdrs_tokio::query_values; #[cfg(feature = "e2e-tests")] use cdrs_tokio::retry::NeverReconnectionPolicy; #[cfg(feature = "e2e-tests")] use cdrs_tokio::speculative_execution::ConstantSpeculativeExecutionPolicy; #[cfg(feature = "e2e-tests")] use cdrs_tokio::types::IntoRustByName; #[tokio::test] #[cfg(feature = "e2e-tests")] async fn multi_node_speculative_execution() { let cluster_config = NodeTcpConfigBuilder::new() .with_contact_point(ADDR.into()) .with_contact_point(ADDR.into()) .with_contact_point(ADDR.into()) .build() .await .unwrap(); let session = TcpSessionBuilder::new(RoundRobinLoadBalancingStrategy::new(), cluster_config) .with_reconnection_policy(Arc::new(NeverReconnectionPolicy)) .with_speculative_execution_policy(Box::new(ConstantSpeculativeExecutionPolicy::new( 5, Duration::from_secs(0), ))) .build() .await .unwrap(); let create_keyspace_query = "CREATE KEYSPACE IF NOT EXISTS cdrs_test WITH \ replication = {'class': 'SimpleStrategy', 'replication_factor': 1} \ AND durable_writes = false"; session .query(create_keyspace_query) .await .expect("create keyspace error"); let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.single_node_speculative_execution \ (id text PRIMARY KEY)"; 
session.query(cql).await.expect("create table error"); let query_insert = "INSERT INTO cdrs_test.single_node_speculative_execution \ (id) VALUES (?)"; let items = vec!["1".to_string(), "2".to_string(), "3".to_string()]; for item in items { let values = query_values!(item); session .query_with_values(query_insert, values) .await .expect("insert item error"); } let cql = "SELECT * FROM cdrs_test.single_node_speculative_execution WHERE id IN ?"; let criteria = vec!["1".to_string(), "3".to_string()]; let rows = session .query_with_values(cql, query_values!(criteria.clone())) .await .expect("select values query error") .response_body() .expect("get body error") .into_rows() .expect("converting into rows error"); assert_eq!(rows.len(), criteria.len()); let found_all_matching_criteria = criteria.iter().all(|criteria_item: &String| { rows.iter().any(|row| { let id: String = row.get_r_by_name("id").expect("id"); criteria_item.clone() == id }) }); assert!( found_all_matching_criteria, "should find at least one element for each criteria" ); } ================================================ FILE: cdrs-tokio/tests/multithread.rs ================================================ #[cfg(feature = "e2e-tests")] use cdrs_tokio::authenticators::NoneAuthenticatorProvider; #[cfg(feature = "e2e-tests")] use cdrs_tokio::cluster::session::{SessionBuilder, TcpSessionBuilder}; #[cfg(feature = "e2e-tests")] use cdrs_tokio::cluster::NodeTcpConfigBuilder; #[cfg(feature = "e2e-tests")] use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy; #[cfg(feature = "e2e-tests")] use cdrs_tokio::retry::NeverReconnectionPolicy; #[cfg(feature = "e2e-tests")] use std::sync::Arc; #[tokio::test] #[cfg(feature = "e2e-tests")] async fn multithread() { let cluster_config = NodeTcpConfigBuilder::new() .with_contact_point("127.0.0.1:9042".into()) .with_authenticator_provider(Arc::new(NoneAuthenticatorProvider)) .build() .await .unwrap(); let no_compression = 
TcpSessionBuilder::new(RoundRobinLoadBalancingStrategy::new(), cluster_config) .with_reconnection_policy(Arc::new(NeverReconnectionPolicy)) .build() .await .unwrap(); no_compression.query("CREATE KEYSPACE IF NOT EXISTS test_ks WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };").await.expect("Could not create ks"); no_compression .query("use test_ks;") .await .expect("Keyspace create error"); no_compression.query("create table if not exists user (user_id int primary key) WITH compaction = { 'class' : 'LeveledCompactionStrategy' };").await.expect("Could not create table"); let arc = Arc::new(no_compression); let mut handles = vec![]; for _ in 0..100 { let c = Arc::clone(&arc); handles.push(tokio::spawn( async move { c.query("select * from user").await }, )); } for task in handles { let result = task.await.unwrap(); match result { Ok(_) => { println!("Query went OK"); } Err(e) => { panic!("Query error: {:#?}", e); } } } } ================================================ FILE: cdrs-tokio/tests/native_types.rs ================================================ mod common; #[cfg(feature = "e2e-tests")] use cassandra_protocol::frame::Version; #[cfg(feature = "e2e-tests")] use cdrs_tokio::query_values; #[cfg(feature = "e2e-tests")] use cdrs_tokio::types::blob::Blob; #[cfg(feature = "e2e-tests")] use cdrs_tokio::types::decimal::Decimal; #[cfg(feature = "e2e-tests")] use cdrs_tokio::types::map::Map; #[cfg(feature = "e2e-tests")] use cdrs_tokio::types::value::Bytes; #[cfg(feature = "e2e-tests")] use cdrs_tokio::types::{AsRust, ByName, IntoRustByName}; #[cfg(feature = "e2e-tests")] use common::*; #[cfg(feature = "e2e-tests")] use float_eq::*; #[cfg(feature = "e2e-tests")] use std::collections::HashMap; #[cfg(feature = "e2e-tests")] use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; #[cfg(feature = "e2e-tests")] use std::str::FromStr; #[cfg(feature = "e2e-tests")] use time::PrimitiveDateTime; #[cfg(feature = "e2e-tests")] use uuid::Uuid; #[tokio::test] 
#[cfg(feature = "e2e-tests")] async fn string_v4() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_string \ (my_ascii ascii PRIMARY KEY, my_text text, my_varchar varchar)"; let session = setup(cql, Version::V4).await.expect("setup"); string_test(session).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn string_v5() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_string \ (my_ascii ascii PRIMARY KEY, my_text text, my_varchar varchar)"; let session = setup(cql, Version::V5).await.expect("setup"); string_test(session).await; } #[cfg(feature = "e2e-tests")] async fn string_test(session: CurrentSession) { let my_ascii = "my_ascii"; let my_text = "my_text"; let my_varchar = "my_varchar"; let values = query_values!(my_ascii, my_text, my_varchar); let query = "INSERT INTO cdrs_test.test_string \ (my_ascii, my_text, my_varchar) VALUES (?, ?, ?)"; session .query_with_values(query, values) .await .expect("insert strings error"); let cql = "SELECT * FROM cdrs_test.test_string"; let rows = session .query(cql) .await .expect("select strings query error") .response_body() .expect("get body error") .into_rows() .expect("converting into rows error"); assert_eq!(rows.len(), 1); for row in rows { let my_ascii_row: String = row.get_r_by_name("my_ascii").expect("my_ascii"); let my_text_row: String = row.get_r_by_name("my_text").expect("my_text"); let my_varchar_row: String = row.get_r_by_name("my_varchar").expect("my_varchar"); assert_eq!(my_ascii_row, my_ascii); assert_eq!(my_text_row, my_text); assert_eq!(my_varchar_row, my_varchar); } } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn counter_v4() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_counter \ (my_bigint bigint PRIMARY KEY, my_counter counter)"; let session = setup(cql, Version::V4).await.expect("setup"); counter_test(session).await; } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn counter_v5() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_counter \ (my_bigint bigint
PRIMARY KEY, my_counter counter)"; let session = setup(cql, Version::V5).await.expect("setup"); counter_test(session).await; } #[cfg(feature = "e2e-tests")] async fn counter_test(session: CurrentSession) { let my_bigint: i64 = 10_000_000_000_000_000; let my_counter: i64 = 100_000_000; let values = query_values!(my_counter, my_bigint); let query = "UPDATE cdrs_test.test_counter SET my_counter = my_counter + ? \ WHERE my_bigint = ?"; session .query_with_values(query, values) .await .expect("insert"); let cql = "SELECT * FROM cdrs_test.test_counter"; let rows = session .query(cql) .await .expect("select counter query error") .response_body() .expect("get counter body error") .into_rows() .expect("converting counter body into rows error"); assert_eq!(rows.len(), 1); for row in rows { let my_bigint_row: i64 = row.get_r_by_name("my_bigint").expect("my_bigint"); let my_counter_row: i64 = row.get_r_by_name("my_counter").expect("my_counter"); assert_eq!(my_bigint_row, my_bigint); assert_eq!(my_counter_row, my_counter); } } // TODO varint #[tokio::test] #[cfg(feature = "e2e-tests")] async fn integer_v4() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_integer \ (my_bigint bigint PRIMARY KEY, my_int int, my_boolean boolean)"; let session = setup(cql, Version::V4).await.expect("setup"); integer_test(session).await; } // TODO varint #[tokio::test] #[cfg(feature = "e2e-tests")] async fn integer_v5() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_integer \ (my_bigint bigint PRIMARY KEY, my_int int, my_boolean boolean)"; let session = setup(cql, Version::V5).await.expect("setup"); integer_test(session).await; } #[cfg(feature = "e2e-tests")] async fn integer_test(session: CurrentSession) { let my_bigint: i64 = 10_000_000_000_000_000; let my_int: i32 = 100_000_000; let my_boolean: bool = true; let values = query_values!(my_bigint, my_int, my_boolean); let query = "INSERT INTO cdrs_test.test_integer \ (my_bigint, my_int, my_boolean) VALUES (?, ?, ?)"; session
.query_with_values(query, values) .await .expect("insert integers error"); let cql = "SELECT * FROM cdrs_test.test_integer"; let rows = session .query(cql) .await .expect("select integers query error") .response_body() .expect("get body with integers error") .into_rows() .expect("converting body with integers into rows error"); assert_eq!(rows.len(), 1); for row in rows { let my_bigint_row: i64 = row.get_r_by_name("my_bigint").expect("my_bigint"); let my_int_row: i32 = row.get_r_by_name("my_int").expect("my_int"); let my_boolean_row: bool = row.get_r_by_name("my_boolean").expect("my_boolean"); assert_eq!(my_bigint_row, my_bigint); assert_eq!(my_int_row, my_int); assert_eq!(my_boolean_row, my_boolean); } } // TODO counter, varint #[tokio::test] #[cfg(feature = "e2e-tests")] async fn integer_advanced_v4() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_integer_v4 \ (my_bigint bigint PRIMARY KEY, my_int int, my_smallint smallint, \ my_tinyint tinyint, my_boolean boolean)"; let session = setup(cql, Version::V4).await.expect("setup"); integer_advanced_test(session).await; } // TODO counter, varint #[tokio::test] #[cfg(feature = "e2e-tests")] async fn integer_advanced_v5() { let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_integer_v4 \ (my_bigint bigint PRIMARY KEY, my_int int, my_smallint smallint, \ my_tinyint tinyint, my_boolean boolean)"; let session = setup(cql, Version::V5).await.expect("setup"); integer_advanced_test(session).await; } #[cfg(feature = "e2e-tests")] async fn integer_advanced_test(session: CurrentSession) { let my_bigint: i64 = 10_000_000_000_000_000; let my_int: i32 = 100_000_000; let my_smallint: i16 = 10_000; let my_tinyint: i8 = 100; let my_boolean: bool = true; let values = query_values!(my_bigint, my_int, my_smallint, my_tinyint, my_boolean); let query = "INSERT INTO cdrs_test.test_integer_v4 \ (my_bigint, my_int, my_smallint, my_tinyint, my_boolean) VALUES (?, ?, ?, ?, ?)"; session .query_with_values(query, values) .await 
        .expect("insert integers error");

    let cql = "SELECT * FROM cdrs_test.test_integer_v4";
    let rows = session
        .query(cql)
        .await
        .expect("query integers error")
        .response_body()
        .expect("get body with integers error")
        .into_rows()
        .expect("converting body with integers into rows error");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_bigint_row: i64 = row.get_r_by_name("my_bigint").expect("my_bigint");
        let my_int_row: i32 = row.get_r_by_name("my_int").expect("my_int");
        let my_smallint_row: i16 = row.get_r_by_name("my_smallint").expect("my_smallint");
        let my_tinyint_row: i8 = row.get_r_by_name("my_tinyint").expect("my_tinyint");
        let my_boolean_row: bool = row.get_r_by_name("my_boolean").expect("my_boolean");
        assert_eq!(my_bigint_row, my_bigint);
        assert_eq!(my_int_row, my_int);
        assert_eq!(my_smallint_row, my_smallint);
        assert_eq!(my_tinyint_row, my_tinyint);
        assert_eq!(my_boolean_row, my_boolean);
    }
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn float_v4() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_float \
               (my_float float PRIMARY KEY, my_double double, my_decimal_a decimal, my_decimal_b decimal)";
    let session = setup(cql, Version::V4).await.expect("setup");

    float_test(session).await;
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn float_v5() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_float \
               (my_float float PRIMARY KEY, my_double double, my_decimal_a decimal, my_decimal_b decimal)";
    let session = setup(cql, Version::V5).await.expect("setup");

    float_test(session).await;
}

#[cfg(feature = "e2e-tests")]
async fn float_test(session: CurrentSession) {
    let my_float: f32 = 123.456;
    let my_double: f64 = 987.654;
    let my_decimal_b = i64::MAX;
    let values = query_values!(
        my_float,
        my_double,
        Decimal::new(12001.into(), 2),
        Decimal::from(my_decimal_b)
    );

    let query = "INSERT INTO cdrs_test.test_float (my_float, my_double, my_decimal_a, my_decimal_b) VALUES (?, ?, ?, ?)";
    session
        .query_with_values(query, values)
        .await
        .expect("insert floats error");

    let cql = "SELECT * FROM cdrs_test.test_float";
    let rows = session
        .query(cql)
        .await
        .expect("query floats error")
        .response_body()
        .expect("get body with floats error")
        .into_rows()
        .expect("converting body with floats into rows error");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_float_row: f32 = row.get_r_by_name("my_float").expect("my_float");
        let my_double_row: f64 = row.get_r_by_name("my_double").expect("my_double");
        let my_decimal_row_a: Decimal = row.get_r_by_name("my_decimal_a").expect("my_decimal_a");
        let my_decimal_row_b: Decimal = row.get_r_by_name("my_decimal_b").expect("my_decimal_b");
        assert_float_eq!(my_float_row, my_float, abs <= f32::EPSILON);
        assert_float_eq!(my_double_row, my_double, abs <= f64::EPSILON);
        assert_eq!(my_decimal_row_a, Decimal::new(12001.into(), 2));
        assert_eq!(my_decimal_row_b, Decimal::from(my_decimal_b));
    }
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn blob_v4() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_blob \
               (my_blob blob PRIMARY KEY, my_mapblob map<text, blob>)";
    let session = setup(cql, Version::V4).await.expect("setup");

    blob_test(session).await;
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn blob_v5() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_blob \
               (my_blob blob PRIMARY KEY, my_mapblob map<text, blob>)";
    let session = setup(cql, Version::V5).await.expect("setup");

    blob_test(session).await;
}

#[cfg(feature = "e2e-tests")]
async fn blob_test(session: CurrentSession) {
    let my_blob: Blob = vec![0, 1, 2, 4, 8, 16, 32, 64, 128, 255].into();
    let my_map: HashMap<String, Blob> = [
        ("a".to_owned(), b"aaaaa".to_vec().into()),
        ("b".to_owned(), b"bbbbb".to_vec().into()),
        ("c".to_owned(), b"ccccc".to_vec().into()),
        ("d".to_owned(), b"ddddd".to_vec().into()),
    ]
    .iter()
    .cloned()
    .collect();
    let val_map: HashMap<String, Bytes> = my_map
        .clone()
        .into_iter()
        .map(|(k, v)| (k, Bytes::new(v.into_vec())))
        .collect();
    let values = query_values!(my_blob.clone(), val_map);

    let query = "INSERT INTO cdrs_test.test_blob (my_blob, my_mapblob) VALUES (?,?)";
    session
        .query_with_values(query, values)
        .await
        .expect("insert blob error");

    let cql = "SELECT * FROM cdrs_test.test_blob";
    let rows = session
        .query(cql)
        .await
        .expect("query blobs error")
        .response_body()
        .expect("get body with blobs error")
        .into_rows()
        .expect("converting body with blobs into rows error");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_blob_row: Blob = row.get_r_by_name("my_blob").expect("my_blob");
        assert_eq!(my_blob_row, my_blob);
        let my_map_row: HashMap<String, Blob> = row
            .r_by_name::<Map>("my_mapblob")
            .expect("my_mapblob by name")
            .as_r_rust()
            .expect("my_mapblob as r rust");
        assert_eq!(my_map_row, my_map);
    }
}

// TODO timeuuid
#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn uuid_v4() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_uuid \
               (my_uuid uuid PRIMARY KEY)";
    let session = setup(cql, Version::V4).await.expect("setup");

    uuid_test(session).await;
}

// TODO timeuuid
#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn uuid_v5() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_uuid \
               (my_uuid uuid PRIMARY KEY)";
    let session = setup(cql, Version::V5).await.expect("setup");

    uuid_test(session).await;
}

#[cfg(feature = "e2e-tests")]
async fn uuid_test(session: CurrentSession) {
    let my_uuid = Uuid::from_str("bb16106a-10bc-4a07-baa3-126ffe208c43").unwrap();
    let values = query_values!(my_uuid);

    let query = "INSERT INTO cdrs_test.test_uuid (my_uuid) VALUES (?)";
    session
        .query_with_values(query, values)
        .await
        .expect("insert UUID error");

    let cql = "SELECT * FROM cdrs_test.test_uuid";
    let rows = session
        .query(cql)
        .await
        .expect("query UUID error")
        .response_body()
        .expect("get body with UUID error")
        .into_rows()
        .expect("conversion body with UUID into rows error");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_uuid_row: Uuid = row.get_r_by_name("my_uuid").expect("my_uuid");
        assert_eq!(my_uuid_row, my_uuid);
    }
}

// TODO date, time, duration
#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn time_v4() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_time \
               (my_timestamp timestamp PRIMARY KEY)";
    let session = setup(cql, Version::V4).await.expect("setup");

    time_test(session).await;
}

// TODO date, time, duration
#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn time_v5() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_time \
               (my_timestamp timestamp PRIMARY KEY)";
    let session = setup(cql, Version::V5).await.expect("setup");

    time_test(session).await;
}

#[cfg(feature = "e2e-tests")]
async fn time_test(session: CurrentSession) {
    let my_timestamp: PrimitiveDateTime = time::macros::datetime!(2019-01-01 0:00);
    let values = query_values!(my_timestamp);

    let query = "INSERT INTO cdrs_test.test_time (my_timestamp) VALUES (?)";
    session
        .query_with_values(query, values)
        .await
        .expect("insert timestamp error");

    let cql = "SELECT * FROM cdrs_test.test_time";
    let rows = session
        .query(cql)
        .await
        .expect("query with time error")
        .response_body()
        .expect("get body with time error")
        .into_rows()
        .expect("converting body with time into rows error");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_timestamp_row: time::PrimitiveDateTime =
            row.get_r_by_name("my_timestamp").expect("my_timestamp");
        assert_eq!(my_timestamp_row.second(), my_timestamp.second());
        assert_eq!(
            my_timestamp_row.nanosecond() / 1_000_000,
            my_timestamp.nanosecond() / 1_000_000
        ); // C* `timestamp` has millisecond precision
    }
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn inet_v4() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_inet \
               (my_inet_v4 inet PRIMARY KEY, my_inet_v6 inet)";
    let session = setup(cql, Version::V4).await.expect("setup");

    inet_test(session).await;
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn inet_v5() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_inet \
               (my_inet_v4 inet PRIMARY KEY, my_inet_v6 inet)";
    let session = setup(cql, Version::V5).await.expect("setup");

    inet_test(session).await;
}

#[cfg(feature = "e2e-tests")]
async fn inet_test(session: CurrentSession) {
    let my_inet_v4 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
    let my_inet_v6 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
    let values = query_values!(my_inet_v4, my_inet_v6);

    let query = "INSERT INTO cdrs_test.test_inet (my_inet_v4, my_inet_v6) VALUES (?, ?)";
    session
        .query_with_values(query, values)
        .await
        .expect("insert inet error");

    let query = "SELECT * FROM cdrs_test.test_inet";
    let rows = session
        .query(query)
        .await
        .expect("query inet error")
        .response_body()
        .expect("get body with inet error")
        .into_rows()
        .expect("converting body with inet into rows error");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_inet_v4_row: IpAddr = row.get_r_by_name("my_inet_v4").expect("my_inet_v4");
        let my_inet_v6_row: IpAddr = row.get_r_by_name("my_inet_v6").expect("my_inet_v6");
        assert_eq!(my_inet_v4_row, my_inet_v4);
        assert_eq!(my_inet_v6_row, my_inet_v6);
    }
}

================================================
FILE: cdrs-tokio/tests/paged_query.rs
================================================
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::authenticators::NoneAuthenticatorProvider;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::cluster::session::{SessionBuilder, TcpSessionBuilder};
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::cluster::NodeTcpConfigBuilder;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::retry::NeverReconnectionPolicy;
#[cfg(feature = "e2e-tests")]
use std::sync::Arc;

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn paged_query() {
    let cluster_config = NodeTcpConfigBuilder::new()
        .with_contact_point("127.0.0.1:9042".into())
        .with_authenticator_provider(Arc::new(NoneAuthenticatorProvider))
        .build()
        .await
        .unwrap();
    let lb = RoundRobinLoadBalancingStrategy::new();
    let session = TcpSessionBuilder::new(lb, cluster_config)
        .with_reconnection_policy(Arc::new(NeverReconnectionPolicy))
        .build()
        .await
        .unwrap();

    session
        .query(
            "CREATE KEYSPACE IF NOT EXISTS test_ks WITH REPLICATION = { \
             'class' : 'SimpleStrategy', 'replication_factor' : 1 };",
        )
        .await
        .expect("Keyspace creation error");

    session
        .query("use test_ks")
        .await
        .expect("Using keyspace went wrong");

    session.query("create table if not exists user (user_id int primary key) WITH compaction = { 'class' : 'LeveledCompactionStrategy' };").await.expect("Could not create table");

    for i in 0..=9 {
        session
            .query(format!("insert into user(user_id) values ({i})"))
            .await
            .expect("Could not insert row");
    }

    let mut pager = session.paged(3);
    let mut query_pager = pager.query("SELECT * FROM user");

    // has_more() always returns false before the first page has been fetched
    assert!(!query_pager.has_more());

    let rows = query_pager.next().await.expect("pager next");
    assert_eq!(3, rows.len());
    assert!(query_pager.has_more());

    let rows = query_pager.next().await.expect("pager next");
    assert_eq!(3, rows.len());
    assert!(query_pager.has_more());

    let rows = query_pager.next().await.expect("pager next");
    assert_eq!(3, rows.len());
    assert!(query_pager.has_more());

    let rows = query_pager.next().await.expect("pager next");
    assert_eq!(1, rows.len());
    assert!(!query_pager.has_more());
}

================================================
FILE: cdrs-tokio/tests/query_values.rs
================================================
mod common;

#[cfg(feature = "e2e-tests")]
use cassandra_protocol::frame::Version;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::query_values;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::types::IntoRustByName;
#[cfg(feature = "e2e-tests")]
use common::*;

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn query_values_in_v4() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_query_values_in \
               (id text PRIMARY KEY)";
    let session = setup(cql, Version::V4).await.expect("setup");

    query_values_in_test(cql, session).await;
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn query_values_in_v5() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_query_values_in \
               (id text PRIMARY KEY)";
    let session = setup(cql, Version::V5).await.expect("setup");

    query_values_in_test(cql, session).await;
}

#[cfg(feature = "e2e-tests")]
async fn query_values_in_test(cql: &str, session: CurrentSession) {
    session.query(cql).await.expect("create table error");

    let query_insert = "INSERT INTO cdrs_test.test_query_values_in \
                        (id) VALUES (?)";

    let items = vec!["1".to_string(), "2".to_string(), "3".to_string()];

    for item in items {
        let values = query_values!(item);
        session
            .query_with_values(query_insert, values)
            .await
            .expect("insert item error");
    }

    let cql = "SELECT * FROM cdrs_test.test_query_values_in WHERE id IN ?";
    let criteria = vec!["1".to_string(), "3".to_string()];

    let rows = session
        .query_with_values(cql, query_values!(criteria.clone()))
        .await
        .expect("select values query error")
        .response_body()
        .expect("get body error")
        .into_rows()
        .expect("converting into rows error");

    assert_eq!(rows.len(), criteria.len());

    let found_all_matching_criteria = criteria.iter().all(|criteria_item: &String| {
        rows.iter().any(|row| {
            let id: String = row.get_r_by_name("id").expect("id");

            criteria_item.clone() == id
        })
    });

    assert!(
        found_all_matching_criteria,
        "should find at least one element for each criteria"
    );
}

================================================
FILE: cdrs-tokio/tests/single_node_speculative_execution.rs
================================================
mod common;

#[cfg(feature = "e2e-tests")]
use cdrs_tokio::cluster::session::{SessionBuilder, TcpSessionBuilder};
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::cluster::NodeTcpConfigBuilder;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::load_balancing::RoundRobinLoadBalancingStrategy;
#[cfg(feature = "e2e-tests")]
use common::*;
#[cfg(feature = "e2e-tests")]
use std::sync::Arc;
#[cfg(feature = "e2e-tests")]
use std::time::Duration;

#[cfg(feature = "e2e-tests")]
use cdrs_tokio::query_values;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::retry::NeverReconnectionPolicy;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::speculative_execution::ConstantSpeculativeExecutionPolicy;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::types::IntoRustByName;

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn single_node_speculative_execution() {
    let cluster_config = NodeTcpConfigBuilder::new()
        .with_contact_point(ADDR.into())
        .build()
        .await
        .unwrap();
    let session = TcpSessionBuilder::new(RoundRobinLoadBalancingStrategy::new(), cluster_config)
        .with_reconnection_policy(Arc::new(NeverReconnectionPolicy))
        .with_speculative_execution_policy(Box::new(ConstantSpeculativeExecutionPolicy::new(
            5,
            Duration::from_secs(1),
        )))
        .build()
        .await
        .unwrap();

    let create_keyspace_query = "CREATE KEYSPACE IF NOT EXISTS cdrs_test WITH \
                                 replication = {'class': 'SimpleStrategy', 'replication_factor': 1} \
                                 AND durable_writes = false";
    session
        .query(create_keyspace_query)
        .await
        .expect("create keyspace error");

    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.single_node_speculative_execution \
               (id text PRIMARY KEY)";

    session.query(cql).await.expect("create table error");

    let query_insert = "INSERT INTO cdrs_test.single_node_speculative_execution \
                        (id) VALUES (?)";

    let items = vec!["1".to_string(), "2".to_string(), "3".to_string()];

    for item in items {
        let values = query_values!(item);
        session
            .query_with_values(query_insert, values)
            .await
            .expect("insert item error");
    }

    let cql = "SELECT * FROM cdrs_test.single_node_speculative_execution WHERE id IN ?";
    let criteria = vec!["1".to_string(), "3".to_string()];

    let rows = session
        .query_with_values(cql, query_values!(criteria.clone()))
        .await
        .expect("select values query error")
        .response_body()
        .expect("get body error")
        .into_rows()
        .expect("converting into rows error");

    assert_eq!(rows.len(), criteria.len());

    let found_all_matching_criteria = criteria.iter().all(|criteria_item: &String| {
        rows.iter().any(|row| {
            let id: String = row.get_r_by_name("id").expect("id");

            criteria_item.clone() == id
        })
    });

    assert!(
        found_all_matching_criteria,
        "should find at least one element for each criteria"
    );
}

================================================
FILE: cdrs-tokio/tests/topology_aware.rs
================================================
mod common;

#[cfg(feature = "e2e-tests")]
use cdrs_tokio::authenticators::NoneAuthenticatorProvider;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::cluster::session::SessionBuilder;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::cluster::session::TcpSessionBuilder;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::cluster::NodeTcpConfigBuilder;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::load_balancing::node_distance_evaluator::TopologyAwareNodeDistanceEvaluator;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::load_balancing::TopologyAwareLoadBalancingStrategy;
#[cfg(feature = "e2e-tests")]
use std::sync::Arc;

#[cfg(feature = "e2e-tests")]
use cdrs_tokio::query_values;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::retry::NeverReconnectionPolicy;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::types::IntoRustByName;

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn query_topology_aware() {
    // this test is essentially the same as the query values test, but checks that topology
    // aware load balancing works
    let cluster_config = NodeTcpConfigBuilder::new()
        .with_contact_point("127.0.0.1:9042".into())
        .with_authenticator_provider(Arc::new(NoneAuthenticatorProvider))
        .build()
        .await
        .unwrap();
    let session = TcpSessionBuilder::new(
        TopologyAwareLoadBalancingStrategy::new(None, false),
        cluster_config,
    )
    .with_reconnection_policy(Arc::new(NeverReconnectionPolicy))
    .with_node_distance_evaluator(Box::new(TopologyAwareNodeDistanceEvaluator::new(
        "datacenter1".into(),
    )))
    .build()
    .await
    .unwrap();

    let create_keyspace_query = "CREATE KEYSPACE IF NOT EXISTS cdrs_test WITH \
                                 replication = {'class': 'NetworkTopologyStrategy', 'datacenter1': 1} \
                                 AND durable_writes = false";
    session.query(create_keyspace_query).await.unwrap();

    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_query_values_in \
               (id text PRIMARY KEY)";

    session.query(cql).await.unwrap();

    let query_insert = "INSERT INTO cdrs_test.test_query_values_in \
                        (id) VALUES (?)";

    let items = vec!["1".to_string(), "2".to_string(), "3".to_string()];

    for item in items {
        let values = query_values!(item);
        session
            .query_with_values(query_insert, values)
            .await
            .expect("insert item error");
    }

    let cql = "SELECT * FROM cdrs_test.test_query_values_in WHERE id IN ?";
    let criteria = vec!["1".to_string(), "3".to_string()];

    let rows = session
        .query_with_values(cql, query_values!(criteria.clone()))
        .await
        .expect("select values query error")
        .response_body()
        .expect("get body error")
        .into_rows()
        .expect("converting into rows error");

    assert_eq!(rows.len(), criteria.len());

    let found_all_matching_criteria = criteria.iter().all(|criteria_item: &String| {
        rows.iter().any(|row| {
            let id: String = row.get_r_by_name("id").expect("id");

            criteria_item.clone() == id
        })
    });

    assert!(
        found_all_matching_criteria,
        "should find at least one element for each criteria"
    );
}

================================================
FILE: cdrs-tokio/tests/tuple_types.rs
================================================
mod common;

#[cfg(feature = "e2e-tests")]
use common::*;

#[cfg(feature = "e2e-tests")]
use cassandra_protocol::frame::Version;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::error::Result;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::frame::Serialize;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::query_values;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::types::blob::Blob;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::types::tuple::Tuple;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::types::value::{Bytes, Value};
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::types::{IntoRustByIndex, IntoRustByName};
#[cfg(feature = "e2e-tests")]
use std::io::Cursor;
#[cfg(feature = "e2e-tests")]
use std::str::FromStr;
#[cfg(feature = "e2e-tests")]
use time::{
    macros::{date, time},
    PrimitiveDateTime,
};
#[cfg(feature = "e2e-tests")]
use uuid::Uuid;

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn simple_tuple_v4() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.simple_tuple \
               (my_tuple tuple<text, int> PRIMARY KEY)";
    let session = setup(cql, Version::V4).await.expect("setup");

    #[derive(Debug, Clone, PartialEq)]
    struct MyTuple {
        pub my_text: String,
        pub my_int: i32,
    }

    impl MyTuple {
        pub fn try_from(tuple: Tuple) -> Result<MyTuple> {
            let my_text: String = tuple.get_r_by_index(0)?;
            let my_int: i32 = tuple.get_r_by_index(1)?;
            Ok(MyTuple { my_text, my_int })
        }
    }

    impl From<MyTuple> for Bytes {
        fn from(value: MyTuple) -> Bytes {
            let mut bytes = Vec::new();
            let mut cursor = Cursor::new(&mut bytes);
            let val_bytes: Bytes = value.my_text.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V4);
            let val_bytes: Bytes = value.my_int.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V4);
            Bytes::new(bytes)
        }
    }

    let my_tuple = MyTuple {
        my_text: "my_text".to_string(),
        my_int: 0,
    };
    let values = query_values!(my_tuple.clone());

    let cql = "INSERT INTO cdrs_test.simple_tuple \
               (my_tuple) VALUES (?)";
    session
        .query_with_values(cql, values)
        .await
        .expect("insert");

    let cql = "SELECT * FROM cdrs_test.simple_tuple";
    let rows = session
        .query(cql)
        .await
        .expect("query")
        .response_body()
        .expect("get body")
        .into_rows()
        .expect("into rows");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_tuple_row: Tuple = row.get_r_by_name("my_tuple").expect("my_tuple");
        let my_tuple_row = MyTuple::try_from(my_tuple_row).expect("my_tuple as rust");
        assert_eq!(my_tuple_row, my_tuple);
    }
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn nested_tuples_v4() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_nested_tuples \
               (my_key int PRIMARY KEY, \
               my_outer_tuple tuple<uuid, blob, tuple<text, int, timestamp>>)";
    let session = setup(cql, Version::V4).await.expect("setup");

    #[derive(Debug, Clone, PartialEq)]
    struct MyInnerTuple {
        pub my_text: String,
        pub my_int: i32,
        pub my_timestamp: PrimitiveDateTime,
    }

    impl MyInnerTuple {
        pub fn try_from(tuple: Tuple) -> Result<MyInnerTuple> {
            let my_text: String = tuple.get_r_by_index(0)?;
            let my_int: i32 = tuple.get_r_by_index(1)?;
            let my_timestamp: PrimitiveDateTime = tuple.get_r_by_index(2)?;
            Ok(MyInnerTuple {
                my_text,
                my_int,
                my_timestamp,
            })
        }
    }

    impl From<MyInnerTuple> for Bytes {
        fn from(value: MyInnerTuple) -> Bytes {
            let mut bytes = Vec::new();
            let mut cursor = Cursor::new(&mut bytes);
            let val_bytes: Bytes = value.my_text.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V4);
            let val_bytes: Bytes = value.my_int.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V4);
            let val_bytes: Bytes = value.my_timestamp.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V4);
            Bytes::new(bytes)
        }
    }

    #[derive(Debug, Clone, PartialEq)]
    struct MyOuterTuple {
        pub my_uuid: Uuid,
        pub my_blob: Vec<u8>,
        pub my_inner_tuple: MyInnerTuple,
    }

    impl MyOuterTuple {
        pub fn try_from(tuple: Tuple) -> Result<MyOuterTuple> {
            let my_uuid: Uuid = tuple.get_r_by_index(0)?;
            let my_blob: Blob = tuple.get_r_by_index(1)?;
            let my_inner_tuple: Tuple = tuple.get_r_by_index(2)?;
            let my_inner_tuple = MyInnerTuple::try_from(my_inner_tuple).expect("from tuple");
            Ok(MyOuterTuple {
                my_uuid,
                my_blob: my_blob.into_vec(),
                my_inner_tuple,
            })
        }
    }

    impl From<MyOuterTuple> for Bytes {
        fn from(value: MyOuterTuple) -> Bytes {
            let mut bytes = Vec::new();
            let mut cursor = Cursor::new(&mut bytes);
            let val_bytes: Bytes = value.my_uuid.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V4);
            let val_bytes: Bytes = Bytes::new(value.my_blob);
            Value::new(val_bytes).serialize(&mut cursor, Version::V4);
            let val_bytes: Bytes = value.my_inner_tuple.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V4);
            Bytes::new(bytes)
        }
    }

    let my_uuid = Uuid::from_str("bb16106a-10bc-4a07-baa3-126ffe208c43").unwrap();
    let my_blob: Vec<u8> = vec![0, 1, 2, 4, 8, 16, 32, 64, 128, 255];
    let timestamp = PrimitiveDateTime::new(date!(2019 - 01 - 01), time!(3:01));
    let my_inner_tuple = MyInnerTuple {
        my_text: "my_text".to_string(),
        my_int: 1_000,
        my_timestamp: timestamp,
    };
    let my_outer_tuple = MyOuterTuple {
        my_uuid,
        my_blob,
        my_inner_tuple,
    };
    let values = query_values!(0i32, my_outer_tuple.clone());

    let cql = "INSERT INTO cdrs_test.test_nested_tuples \
               (my_key, my_outer_tuple) VALUES (?, ?)";
    session
        .query_with_values(cql, values)
        .await
        .expect("insert");

    let cql = "SELECT * FROM cdrs_test.test_nested_tuples";
    let rows = session
        .query(cql)
        .await
        .expect("query")
        .response_body()
        .expect("get body")
        .into_rows()
        .expect("into rows");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_outer_tuple_row: Tuple =
            row.get_r_by_name("my_outer_tuple").expect("my_outer_tuple");
        let my_outer_tuple_row = MyOuterTuple::try_from(my_outer_tuple_row).expect("from tuple");
        assert_eq!(my_outer_tuple_row, my_outer_tuple);
    }
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn simple_tuple_v5() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.simple_tuple \
               (my_tuple tuple<text, int> PRIMARY KEY)";
    let session = setup(cql, Version::V5).await.expect("setup");

    #[derive(Debug, Clone, PartialEq)]
    struct MyTuple {
        pub my_text: String,
        pub my_int: i32,
    }

    impl MyTuple {
        pub fn try_from(tuple: Tuple) -> Result<MyTuple> {
            let my_text: String = tuple.get_r_by_index(0)?;
            let my_int: i32 = tuple.get_r_by_index(1)?;
            Ok(MyTuple { my_text, my_int })
        }
    }

    impl From<MyTuple> for Bytes {
        fn from(value: MyTuple) -> Bytes {
            let mut bytes = Vec::new();
            let mut cursor = Cursor::new(&mut bytes);
            let val_bytes: Bytes = value.my_text.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V5);
            let val_bytes: Bytes = value.my_int.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V5);
            Bytes::new(bytes)
        }
    }

    let my_tuple = MyTuple {
        my_text: "my_text".to_string(),
        my_int: 0,
    };
    let values = query_values!(my_tuple.clone());

    let cql = "INSERT INTO cdrs_test.simple_tuple \
               (my_tuple) VALUES (?)";
    session
        .query_with_values(cql, values)
        .await
        .expect("insert");

    let cql = "SELECT * FROM cdrs_test.simple_tuple";
    let rows = session
        .query(cql)
        .await
        .expect("query")
        .response_body()
        .expect("get body")
        .into_rows()
        .expect("into rows");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_tuple_row: Tuple = row.get_r_by_name("my_tuple").expect("my_tuple");
        let my_tuple_row = MyTuple::try_from(my_tuple_row).expect("my_tuple as rust");
        assert_eq!(my_tuple_row, my_tuple);
    }
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn nested_tuples_v5() {
    let cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_nested_tuples \
               (my_key int PRIMARY KEY, \
               my_outer_tuple tuple<uuid, blob, tuple<text, int, timestamp>>)";
    let session = setup(cql, Version::V5).await.expect("setup");

    #[derive(Debug, Clone, PartialEq)]
    struct MyInnerTuple {
        pub my_text: String,
        pub my_int: i32,
        pub my_timestamp: PrimitiveDateTime,
    }

    impl MyInnerTuple {
        pub fn try_from(tuple: Tuple) -> Result<MyInnerTuple> {
            let my_text: String = tuple.get_r_by_index(0)?;
            let my_int: i32 = tuple.get_r_by_index(1)?;
            let my_timestamp: PrimitiveDateTime = tuple.get_r_by_index(2)?;
            Ok(MyInnerTuple {
                my_text,
                my_int,
                my_timestamp,
            })
        }
    }

    impl From<MyInnerTuple> for Bytes {
        fn from(value: MyInnerTuple) -> Bytes {
            let mut bytes = Vec::new();
            let mut cursor = Cursor::new(&mut bytes);
            let val_bytes: Bytes = value.my_text.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V5);
            let val_bytes: Bytes = value.my_int.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V5);
            let val_bytes: Bytes = value.my_timestamp.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V5);
            Bytes::new(bytes)
        }
    }

    #[derive(Debug, Clone, PartialEq)]
    struct MyOuterTuple {
        pub my_uuid: Uuid,
        pub my_blob: Vec<u8>,
        pub my_inner_tuple: MyInnerTuple,
    }

    impl MyOuterTuple {
        pub fn try_from(tuple: Tuple) -> Result<MyOuterTuple> {
            let my_uuid: Uuid = tuple.get_r_by_index(0)?;
            let my_blob: Blob = tuple.get_r_by_index(1)?;
            let my_inner_tuple: Tuple = tuple.get_r_by_index(2)?;
            let my_inner_tuple = MyInnerTuple::try_from(my_inner_tuple).expect("from tuple");
            Ok(MyOuterTuple {
                my_uuid,
                my_blob: my_blob.into_vec(),
                my_inner_tuple,
            })
        }
    }

    impl From<MyOuterTuple> for Bytes {
        fn from(value: MyOuterTuple) -> Bytes {
            let mut bytes = Vec::new();
            let mut cursor = Cursor::new(&mut bytes);
            let val_bytes: Bytes = value.my_uuid.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V5);
            let val_bytes: Bytes = Bytes::new(value.my_blob);
            Value::new(val_bytes).serialize(&mut cursor, Version::V5);
            let val_bytes: Bytes = value.my_inner_tuple.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V5);
            Bytes::new(bytes)
        }
    }

    let my_uuid = Uuid::from_str("bb16106a-10bc-4a07-baa3-126ffe208c43").unwrap();
    let my_blob: Vec<u8> = vec![0, 1, 2, 4, 8, 16, 32, 64, 128, 255];
    let timestamp = PrimitiveDateTime::new(date!(2019 - 01 - 01), time!(3:01));
    let my_inner_tuple = MyInnerTuple {
        my_text: "my_text".to_string(),
        my_int: 1_000,
        my_timestamp: timestamp,
    };
    let my_outer_tuple = MyOuterTuple {
        my_uuid,
        my_blob,
        my_inner_tuple,
    };
    let values = query_values!(0i32, my_outer_tuple.clone());

    let cql = "INSERT INTO cdrs_test.test_nested_tuples \
               (my_key, my_outer_tuple) VALUES (?, ?)";
    session
        .query_with_values(cql, values)
        .await
        .expect("insert");

    let cql = "SELECT * FROM cdrs_test.test_nested_tuples";
    let rows = session
        .query(cql)
        .await
        .expect("query")
        .response_body()
        .expect("get body")
        .into_rows()
        .expect("into rows");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_outer_tuple_row: Tuple =
            row.get_r_by_name("my_outer_tuple").expect("my_outer_tuple");
        let my_outer_tuple_row = MyOuterTuple::try_from(my_outer_tuple_row).expect("from tuple");
        assert_eq!(my_outer_tuple_row, my_outer_tuple);
    }
}

================================================
FILE: cdrs-tokio/tests/user_defined_types.rs
================================================
mod common;

#[cfg(feature = "e2e-tests")]
use cassandra_protocol::frame::Version;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::error::Result;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::frame::Serialize;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::query_values;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::types::map::Map;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::types::udt::Udt;
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::types::value::{Bytes, Value};
#[cfg(feature = "e2e-tests")]
use cdrs_tokio::types::{AsRust, IntoRustByName};
#[cfg(feature = "e2e-tests")]
use cdrs_tokio_helpers_derive::IntoCdrsValue;
#[cfg(feature = "e2e-tests")]
use common::*;
#[cfg(feature = "e2e-tests")]
use maplit::hashmap;
#[cfg(feature = "e2e-tests")]
use std::collections::HashMap;
#[cfg(feature = "e2e-tests")]
use std::io::Cursor;
#[cfg(feature = "e2e-tests")]
use time::PrimitiveDateTime;

#[cfg(feature = "e2e-tests")]
#[derive(IntoCdrsValue)]
#[allow(dead_code)]
// this caused stack overflow in rustc
pub struct TestStaticStrReference {
    pub event_name: &'static str,
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn simple_udt_v4() {
    let create_type_cql = "CREATE TYPE IF NOT EXISTS cdrs_test.simple_udt (my_text text)";
    let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_simple_udt \
                            (my_key int PRIMARY KEY, my_udt simple_udt)";
    let session = setup_multiple(&[create_type_cql, create_table_cql], Version::V4)
        .await
        .expect("setup");

    #[derive(Debug, Clone, PartialEq)]
    struct MyUdt {
        pub my_text: String,
    }

    impl MyUdt {
        pub fn try_from(udt: Udt) -> Result<MyUdt> {
            let my_text: String = udt.get_r_by_name("my_text")?;
            Ok(MyUdt { my_text })
        }
    }

    impl From<MyUdt> for Bytes {
        fn from(value: MyUdt) -> Bytes {
            let mut bytes = Vec::new();
            let mut cursor = Cursor::new(&mut bytes);
            let val_bytes: Bytes = value.my_text.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V4);
            Bytes::new(bytes)
        }
    }

    let my_udt = MyUdt {
        my_text: "my_text".to_string(),
    };
    let values = query_values!(0i32, my_udt.clone());

    let cql = "INSERT INTO cdrs_test.test_simple_udt \
               (my_key, my_udt) VALUES (?, ?)";
    session
        .query_with_values(cql, values)
        .await
        .expect("insert");

    let cql = "SELECT * FROM cdrs_test.test_simple_udt";
    let rows = session
        .query(cql)
        .await
        .expect("query")
        .response_body()
.expect("get body") .into_rows() .expect("into rows"); assert_eq!(rows.len(), 1); for row in rows { let my_udt_row: Udt = row.get_r_by_name("my_udt").expect("my_udt"); let my_udt_row = MyUdt::try_from(my_udt_row).expect("from udt"); assert_eq!(my_udt_row, my_udt); } } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn nested_udt_v4() { let create_type1_cql = "CREATE TYPE IF NOT EXISTS cdrs_test.nested_inner_udt (my_text text)"; let create_type2_cql = "CREATE TYPE IF NOT EXISTS cdrs_test.nested_outer_udt \ (my_inner_udt frozen)"; let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_nested_udt \ (my_key int PRIMARY KEY, my_outer_udt nested_outer_udt)"; let session = setup_multiple( &[create_type1_cql, create_type2_cql, create_table_cql], Version::V4, ) .await .expect("setup"); #[derive(Debug, Clone, PartialEq)] struct MyInnerUdt { pub my_text: String, } impl MyInnerUdt { pub fn try_from(udt: Udt) -> Result { let my_text: String = udt.get_r_by_name("my_text")?; Ok(MyInnerUdt { my_text }) } } impl From for Bytes { fn from(value: MyInnerUdt) -> Bytes { let mut bytes = Vec::new(); let mut cursor = Cursor::new(&mut bytes); let val_bytes: Bytes = value.my_text.into(); Value::new(val_bytes).serialize(&mut cursor, Version::V4); Bytes::new(bytes) } } #[derive(Debug, Clone, PartialEq)] struct MyOuterUdt { pub my_inner_udt: MyInnerUdt, } impl MyOuterUdt { pub fn try_from(udt: Udt) -> Result { let my_inner_udt: Udt = udt.get_r_by_name("my_inner_udt")?; let my_inner_udt = MyInnerUdt::try_from(my_inner_udt).expect("from udt"); Ok(MyOuterUdt { my_inner_udt }) } } impl From for Bytes { fn from(value: MyOuterUdt) -> Bytes { let mut bytes = Vec::new(); let mut cursor = Cursor::new(&mut bytes); let val_bytes: Bytes = value.my_inner_udt.into(); Value::new(val_bytes).serialize(&mut cursor, Version::V4); Bytes::new(bytes) } } let my_inner_udt = MyInnerUdt { my_text: "my_text".to_string(), }; let my_outer_udt = MyOuterUdt { my_inner_udt }; let values = query_values!(0i32, 
my_outer_udt.clone()); let cql = "INSERT INTO cdrs_test.test_nested_udt \ (my_key, my_outer_udt) VALUES (?, ?)"; session .query_with_values(cql, values) .await .expect("insert"); let cql = "SELECT * FROM cdrs_test.test_nested_udt"; let rows = session .query(cql) .await .expect("query") .response_body() .expect("get body") .into_rows() .expect("into rows"); assert_eq!(rows.len(), 1); for row in rows { let my_outer_udt_row: Udt = row.get_r_by_name("my_outer_udt").expect("my_outer_udt"); let my_outer_udt_row = MyOuterUdt::try_from(my_outer_udt_row).expect("from udt"); assert_eq!(my_outer_udt_row, my_outer_udt); } } #[tokio::test] #[cfg(feature = "e2e-tests")] async fn alter_udt_add_v4() { let drop_table_cql = "DROP TABLE IF EXISTS cdrs_test.test_alter_udt_add"; let drop_type_cql = "DROP TYPE IF EXISTS cdrs_test.alter_udt_add_udt"; let create_type_cql = "CREATE TYPE cdrs_test.alter_udt_add_udt (my_text text)"; let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_alter_udt_add \ (my_key int PRIMARY KEY, my_map frozen>)"; let session = setup_multiple( &[ drop_table_cql, drop_type_cql, create_type_cql, create_table_cql, ], Version::V4, ) .await .expect("setup"); #[derive(Debug, Clone, PartialEq)] struct MyUdtA { pub my_text: String, } impl From for Bytes { fn from(value: MyUdtA) -> Bytes { let mut bytes = Vec::new(); let mut cursor = Cursor::new(&mut bytes); let val_bytes: Bytes = value.my_text.into(); Value::new(val_bytes).serialize(&mut cursor, Version::V4); Bytes::new(bytes) } } #[derive(Debug, Clone, PartialEq)] struct MyUdtB { pub my_text: String, pub my_timestamp: Option, } impl MyUdtB { pub fn try_from(udt: Udt) -> Result { let my_text: String = udt.get_r_by_name("my_text")?; let my_timestamp: Option = udt.get_by_name("my_timestamp")?; Ok(MyUdtB { my_text, my_timestamp, }) } } let my_udt_a = MyUdtA { my_text: "my_text".to_string(), }; let my_map_a = hashmap! 
    { "1" => my_udt_a.clone() };
    let values = query_values!(0i32, my_map_a.clone());
    let cql = "INSERT INTO cdrs_test.test_alter_udt_add \
               (my_key, my_map) VALUES (?, ?)";
    session
        .query_with_values(cql, values)
        .await
        .expect("insert");

    let cql = "ALTER TYPE cdrs_test.alter_udt_add_udt ADD my_timestamp timestamp";
    session.query(cql).await.expect("alter type");

    let my_udt_b = MyUdtB {
        my_text: my_udt_a.my_text,
        my_timestamp: None,
    };

    let cql = "SELECT * FROM cdrs_test.test_alter_udt_add";
    let rows = session
        .query(cql)
        .await
        .expect("query")
        .response_body()
        .expect("get body")
        .into_rows()
        .expect("into rows");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_map_row: Map = row.get_r_by_name("my_map").expect("my_map");
        let my_map_row: HashMap<String, Udt> = my_map_row.as_r_rust().expect("my_map as rust");
        for (key, my_udt_row) in my_map_row {
            let my_udt_row = MyUdtB::try_from(my_udt_row).expect("from udt");
            assert_eq!(key, "1");
            assert_eq!(my_udt_row, my_udt_b);
        }
    }
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn simple_udt_v5() {
    let create_type_cql = "CREATE TYPE IF NOT EXISTS cdrs_test.simple_udt (my_text text)";
    let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_simple_udt \
                            (my_key int PRIMARY KEY, my_udt simple_udt)";
    let session = setup_multiple(&[create_type_cql, create_table_cql], Version::V5)
        .await
        .expect("setup");

    #[derive(Debug, Clone, PartialEq)]
    struct MyUdt {
        pub my_text: String,
    }

    impl MyUdt {
        pub fn try_from(udt: Udt) -> Result<MyUdt> {
            let my_text: String = udt.get_r_by_name("my_text")?;
            Ok(MyUdt { my_text })
        }
    }

    impl From<MyUdt> for Bytes {
        fn from(value: MyUdt) -> Bytes {
            let mut bytes = Vec::new();
            let mut cursor = Cursor::new(&mut bytes);
            let val_bytes: Bytes = value.my_text.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V5);
            Bytes::new(bytes)
        }
    }

    let my_udt = MyUdt {
        my_text: "my_text".to_string(),
    };
    let values = query_values!(0i32, my_udt.clone());
    let cql = "INSERT INTO cdrs_test.test_simple_udt \
               (my_key, my_udt) VALUES (?,
?)";
    session
        .query_with_values(cql, values)
        .await
        .expect("insert");

    let cql = "SELECT * FROM cdrs_test.test_simple_udt";
    let rows = session
        .query(cql)
        .await
        .expect("query")
        .response_body()
        .expect("get body")
        .into_rows()
        .expect("into rows");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_udt_row: Udt = row.get_r_by_name("my_udt").expect("my_udt");
        let my_udt_row = MyUdt::try_from(my_udt_row).expect("from udt");
        assert_eq!(my_udt_row, my_udt);
    }
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn nested_udt_v5() {
    let create_type1_cql = "CREATE TYPE IF NOT EXISTS cdrs_test.nested_inner_udt (my_text text)";
    let create_type2_cql = "CREATE TYPE IF NOT EXISTS cdrs_test.nested_outer_udt \
                            (my_inner_udt frozen<nested_inner_udt>)";
    let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_nested_udt \
                            (my_key int PRIMARY KEY, my_outer_udt nested_outer_udt)";
    let session = setup_multiple(
        &[create_type1_cql, create_type2_cql, create_table_cql],
        Version::V5,
    )
    .await
    .expect("setup");

    #[derive(Debug, Clone, PartialEq)]
    struct MyInnerUdt {
        pub my_text: String,
    }

    impl MyInnerUdt {
        pub fn try_from(udt: Udt) -> Result<MyInnerUdt> {
            let my_text: String = udt.get_r_by_name("my_text")?;
            Ok(MyInnerUdt { my_text })
        }
    }

    impl From<MyInnerUdt> for Bytes {
        fn from(value: MyInnerUdt) -> Bytes {
            let mut bytes = Vec::new();
            let mut cursor = Cursor::new(&mut bytes);
            let val_bytes: Bytes = value.my_text.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V5);
            Bytes::new(bytes)
        }
    }

    #[derive(Debug, Clone, PartialEq)]
    struct MyOuterUdt {
        pub my_inner_udt: MyInnerUdt,
    }

    impl MyOuterUdt {
        pub fn try_from(udt: Udt) -> Result<MyOuterUdt> {
            let my_inner_udt: Udt = udt.get_r_by_name("my_inner_udt")?;
            let my_inner_udt = MyInnerUdt::try_from(my_inner_udt).expect("from udt");
            Ok(MyOuterUdt { my_inner_udt })
        }
    }

    impl From<MyOuterUdt> for Bytes {
        fn from(value: MyOuterUdt) -> Bytes {
            let mut bytes = Vec::new();
            let mut cursor = Cursor::new(&mut bytes);
            let val_bytes: Bytes = value.my_inner_udt.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V5);
            Bytes::new(bytes)
        }
    }

    let my_inner_udt = MyInnerUdt {
        my_text: "my_text".to_string(),
    };
    let my_outer_udt = MyOuterUdt { my_inner_udt };
    let values = query_values!(0i32, my_outer_udt.clone());
    let cql = "INSERT INTO cdrs_test.test_nested_udt \
               (my_key, my_outer_udt) VALUES (?, ?)";
    session
        .query_with_values(cql, values)
        .await
        .expect("insert");

    let cql = "SELECT * FROM cdrs_test.test_nested_udt";
    let rows = session
        .query(cql)
        .await
        .expect("query")
        .response_body()
        .expect("get body")
        .into_rows()
        .expect("into rows");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_outer_udt_row: Udt = row.get_r_by_name("my_outer_udt").expect("my_outer_udt");
        let my_outer_udt_row = MyOuterUdt::try_from(my_outer_udt_row).expect("from udt");
        assert_eq!(my_outer_udt_row, my_outer_udt);
    }
}

#[tokio::test]
#[cfg(feature = "e2e-tests")]
async fn alter_udt_add_v5() {
    let drop_table_cql = "DROP TABLE IF EXISTS cdrs_test.test_alter_udt_add";
    let drop_type_cql = "DROP TYPE IF EXISTS cdrs_test.alter_udt_add_udt";
    let create_type_cql = "CREATE TYPE cdrs_test.alter_udt_add_udt (my_text text)";
    let create_table_cql = "CREATE TABLE IF NOT EXISTS cdrs_test.test_alter_udt_add \
                            (my_key int PRIMARY KEY, my_map frozen<map<text, alter_udt_add_udt>>)";
    let session = setup_multiple(
        &[
            drop_table_cql,
            drop_type_cql,
            create_type_cql,
            create_table_cql,
        ],
        Version::V5,
    )
    .await
    .expect("setup");

    #[derive(Debug, Clone, PartialEq)]
    struct MyUdtA {
        pub my_text: String,
    }

    impl From<MyUdtA> for Bytes {
        fn from(value: MyUdtA) -> Bytes {
            let mut bytes = Vec::new();
            let mut cursor = Cursor::new(&mut bytes);
            let val_bytes: Bytes = value.my_text.into();
            Value::new(val_bytes).serialize(&mut cursor, Version::V5);
            Bytes::new(bytes)
        }
    }

    #[derive(Debug, Clone, PartialEq)]
    struct MyUdtB {
        pub my_text: String,
        pub my_timestamp: Option,
    }

    impl MyUdtB {
        pub fn try_from(udt: Udt) -> Result<MyUdtB> {
            let my_text: String = udt.get_r_by_name("my_text")?;
            let my_timestamp: Option =
                udt.get_by_name("my_timestamp")?;
            Ok(MyUdtB {
                my_text,
                my_timestamp,
            })
        }
    }

    let my_udt_a = MyUdtA {
        my_text: "my_text".to_string(),
    };
    let my_map_a = hashmap! { "1" => my_udt_a.clone() };
    let values = query_values!(0i32, my_map_a.clone());
    let cql = "INSERT INTO cdrs_test.test_alter_udt_add \
               (my_key, my_map) VALUES (?, ?)";
    session
        .query_with_values(cql, values)
        .await
        .expect("insert");

    let cql = "ALTER TYPE cdrs_test.alter_udt_add_udt ADD my_timestamp timestamp";
    session.query(cql).await.expect("alter type");

    let my_udt_b = MyUdtB {
        my_text: my_udt_a.my_text,
        my_timestamp: None,
    };

    let cql = "SELECT * FROM cdrs_test.test_alter_udt_add";
    let rows = session
        .query(cql)
        .await
        .expect("query")
        .response_body()
        .expect("get body")
        .into_rows()
        .expect("into rows");

    assert_eq!(rows.len(), 1);
    for row in rows {
        let my_map_row: Map = row.get_r_by_name("my_map").expect("my_map");
        let my_map_row: HashMap<String, Udt> = my_map_row.as_r_rust().expect("my_map as rust");
        for (key, my_udt_row) in my_map_row {
            let my_udt_row = MyUdtB::try_from(my_udt_row).expect("from udt");
            assert_eq!(key, "1");
            assert_eq!(my_udt_row, my_udt_b);
        }
    }
}

================================================
FILE: cdrs-tokio-helpers-derive/Cargo.toml
================================================
[package]
name = "cdrs-tokio-helpers-derive"
version = "5.0.3"
authors = ["Alex Pikalov ", "Kamil Rojewski "]
description = "Derive CDRS helper traits"
license = "MIT/Apache-2.0"
repository = "https://github.com/krojew/cdrs-tokio"
edition = "2018"

[lib]
proc-macro = true

[dependencies]
itertools = "0.14.0"
proc-macro2 = "1.0.103"
syn = "2.0.111"
quote = "1.0.42"

================================================
FILE: cdrs-tokio-helpers-derive/README.md
================================================
# cdrs-tokio-helpers-derive

Procedural macros that derive CDRS helper traits for converting between Cassandra and Rust types in both directions.

Features:

* convert Cassandra primitive types (not lists, sets, maps, UDTs) into Rust
* recursively convert Cassandra "collection" types (lists, sets, maps) into Rust
* recursively convert Cassandra UDTs into Rust
* recursively convert optional fields into Rust
* convert Rust primitive types into Cassandra query values
* convert Rust "collection" types into Cassandra query values
* convert Rust structures into Cassandra query values
* convert `Option` into Cassandra query value
* generate an insert method for a Rust struct type

================================================
FILE: cdrs-tokio-helpers-derive/src/common.rs
================================================
use itertools::Itertools;
use proc_macro2::{Literal, TokenStream};
use quote::*;
use syn::spanned::Spanned;
use syn::{
    parse_str, Data, DataStruct, DeriveInput, Error, Field, Fields, FieldsNamed, GenericArgument,
    Ident, Path, PathArguments, PathSegment, Result, Type, TypePath, TypeReference,
};

pub fn get_struct_fields(ast: &DeriveInput) -> Result<Vec<TokenStream>> {
    struct_fields(ast)?
        .named
        .iter()
        .map(|field| {
            let name = field
                .ident
                .clone()
                .ok_or_else(|| Error::new(field.span(), "Expected a named field!"))?;
            let value = convert_field_into_rust(field.clone())?;

            Ok(quote! {
                #name: #value
            })
        })
        .try_collect()
}

pub fn struct_fields(ast: &DeriveInput) -> Result<&FieldsNamed> {
    if let Data::Struct(DataStruct {
        fields: Fields::Named(fields),
        ..
    }) = &ast.data
    {
        Ok(fields)
    } else {
        Err(Error::new(ast.span(), "The derive macro is defined for structs with named fields, not for enums or unit structs"))
    }
}

fn extract_type(arg: &GenericArgument) -> Result<Type> {
    match arg {
        GenericArgument::Type(ty) => Ok(ty.clone()),
        _ => Err(Error::new(arg.span(), "Expected type argument!")),
    }
}

pub fn get_map_params_string(ty: &Type, name: &str) -> Result<(Type, Type)> {
    match ty {
        Type::Path(TypePath {
            path: Path { segments, .. },
            ..
        }) => {
            match segments.last() {
                Some(&PathSegment {
                    arguments: PathArguments::AngleBracketed(ref angle_bracketed_data),
                    ..
                }) => {
                    Ok((
                        extract_type(angle_bracketed_data.args.first().ok_or_else(|| {
                            Error::new(ty.span(), "Cannot extract map key type")
                        })?)?,
                        extract_type(angle_bracketed_data.args.last().ok_or_else(|| {
                            Error::new(ty.span(), "Cannot extract map value type")
                        })?)?,
                    ))
                }
                _ => Err(Error::new(ty.span(), "Cannot infer field type")),
            }
        }
        _ => Err(Error::new(
            ty.span(),
            format!("Cannot infer field type {}", get_ident_string(ty, name)?),
        )),
    }
}

fn remove_r(s: String) -> String {
    if let Some(s) = s.strip_prefix("r#") {
        s.to_string()
    } else {
        s
    }
}

fn convert_field_into_rust(field: Field) -> Result<TokenStream> {
    let mut string_name = quote! {};
    let span = field.span();
    let s = remove_r(
        field
            .ident
            .ok_or_else(|| Error::new(span, "Expected named field!"))?
            .to_string(),
    );
    string_name.append(Literal::string(s.trim()));
    let arguments = get_arguments(string_name);

    into_rust_with_args(&field.ty, arguments, &s)
}

fn get_arguments(name: TokenStream) -> TokenStream {
    quote! {
        &cdrs, #name
    }
}

fn into_rust_with_args(
    field_type: &Type,
    arguments: TokenStream,
    name: &str,
) -> Result<TokenStream> {
    let field_type_ident = get_cdrs_type(field_type, name)?;
    Ok(match get_ident_string(&field_type_ident, name)?.as_str() {
        "Blob" | "String" | "bool" | "i64" | "i32" | "i16" | "i8" | "f64" | "f32" | "Decimal"
        | "IpAddr" | "Uuid" | "Timespec" | "PrimitiveDateTime" | "NaiveDateTime" | "DateTime" => {
            quote! {
                #field_type_ident::from_cdrs_r(#arguments)?
            }
        }
        "List" => {
            let list_as_rust = as_rust(field_type, quote! {list}, name)?;

            quote! {
                match cdrs_tokio::types::list::List::from_cdrs_r(#arguments) {
                    Ok(ref list) => {
                        #list_as_rust
                    },
                    _ => return Err("List should not be empty".into())
                }
            }
        }
        "Map" => {
            let map_as_rust = as_rust(field_type, quote! {map}, name)?;
            quote!
            {
                match cdrs_tokio::types::map::Map::from_cdrs_r(#arguments) {
                    Ok(map) => {
                        #map_as_rust
                    },
                    _ => return Err("Map should not be empty".into())
                }
            }
        }
        "Option" => {
            let opt_type = get_ident_params_string(field_type, name)?;
            let opt_type_rustified = get_cdrs_type(&opt_type, name)?;
            let opt_value_as_rust = as_rust(&opt_type, quote! {opt_value}, name)?;

            if is_non_zero_primitive(&opt_type_rustified, name)? {
                quote! {
                    #opt_type_rustified::from_cdrs_by_name(#arguments)?
                }
            } else {
                quote! {
                    {
                        match #opt_type_rustified::from_cdrs_by_name(#arguments)? {
                            Some(opt_value) => {
                                let decoded = #opt_value_as_rust;
                                Some(decoded)
                            },
                            _ => None
                        }
                    }
                }
            }
        }
        _ => quote! {
            #field_type::try_from_udt(cdrs_tokio::types::udt::Udt::from_cdrs_r(#arguments)?)?
        },
    })
}

fn is_non_zero_primitive(ty: &Type, name: &str) -> Result<bool> {
    get_ident_string(ty, name).map(|ident| {
        matches!(
            ident.as_str(),
            "NonZeroI8" | "NonZeroI16" | "NonZeroI32" | "NonZeroI64"
        )
    })
}

fn get_cdrs_type(ty: &Type, name: &str) -> Result<Type> {
    let type_string = get_ident_string(ty, name)?;
    Ok(match type_string.as_str() {
        "Blob" => parse_str("Blob").unwrap(),
        "String" => parse_str("String").unwrap(),
        "bool" => parse_str("bool").unwrap(),
        "i64" => parse_str("i64").unwrap(),
        "i32" => parse_str("i32").unwrap(),
        "i16" => parse_str("i16").unwrap(),
        "i8" => parse_str("i8").unwrap(),
        "f64" => parse_str("f64").unwrap(),
        "f32" => parse_str("f32").unwrap(),
        "Decimal" => parse_str("Decimal").unwrap(),
        "IpAddr" => parse_str("IpAddr").unwrap(),
        "Uuid" => parse_str("Uuid").unwrap(),
        "Timespec" => parse_str("Timespec").unwrap(),
        "PrimitiveDateTime" => parse_str("PrimitiveDateTime").unwrap(),
        "Vec" => parse_str("cdrs_tokio::types::list::List").unwrap(),
        "HashMap" => parse_str("cdrs_tokio::types::map::Map").unwrap(),
        "Option" => parse_str("Option").unwrap(),
        "NonZeroI8" => parse_str("NonZeroI8").unwrap(),
        "NonZeroI16" => parse_str("NonZeroI16").unwrap(),
        "NonZeroI32" => parse_str("NonZeroI32").unwrap(),
        "NonZeroI64" => parse_str("NonZeroI64").unwrap(),
        "NaiveDateTime" => parse_str("NaiveDateTime").unwrap(),
        "DateTime" => parse_str("DateTime").unwrap(),
        _ => parse_str("cdrs_tokio::types::udt::Udt").unwrap(),
    })
}

fn get_ident<'a>(ty: &'a Type, name: &str) -> Result<&'a Ident> {
    match ty {
        Type::Reference(TypeReference { elem, .. }) => get_ident(elem, name),
        Type::Path(TypePath {
            path: Path { segments, .. },
            ..
        }) => match segments.last() {
            Some(PathSegment { ident, .. }) => Ok(ident),
            _ => Err(Error::new(
                ty.span(),
                format!("Cannot infer field type: {}", name),
            )),
        },
        _ => Err(Error::new(
            ty.span(),
            format!("Cannot infer field type: {}", name),
        )),
    }
}

// Returns the expression for a single decoded value; for collections, wraps it in
// an iterative mapping that decodes every item.
fn as_rust(ty: &Type, val: TokenStream, name: &str) -> Result<TokenStream> {
    let cdrs_type = get_cdrs_type(ty, name)?;
    Ok(match get_ident_string(&cdrs_type, name)?.as_str() {
        "Blob" | "String" | "bool" | "i64" | "i32" | "i16" | "i8" | "f64" | "f32" | "IpAddr"
        | "Uuid" | "Timespec" | "Decimal" | "PrimitiveDateTime" => val,
        "List" => {
            let vec_type = get_ident_params_string(ty, name)?;
            let inter_rust_type = get_cdrs_type(&vec_type, name)?;
            let decoded_item = as_rust(&vec_type, quote! {item}, name)?;
            quote! {
                {
                    let inner: Vec<#inter_rust_type> = #val.as_r_type()?;
                    let mut decoded: Vec<#vec_type> = Vec::with_capacity(inner.len());

                    for item in inner {
                        decoded.push(#decoded_item);
                    }

                    decoded
                }
            }
        }
        "Map" => {
            let (map_key_type, map_value_type) = get_map_params_string(ty, name)?;
            let inter_rust_type = get_cdrs_type(&map_value_type, name)?;
            let decoded_item = as_rust(&map_value_type, quote! {val}, name)?;
            quote!
            {
                {
                    let inner: std::collections::HashMap<#map_key_type, #inter_rust_type> = #val.as_r_type()?;
                    let mut decoded: std::collections::HashMap<#map_key_type, #map_value_type> = std::collections::HashMap::with_capacity(inner.len());

                    for (key, val) in inner {
                        decoded.insert(key, #decoded_item);
                    }

                    decoded
                }
            }
        }
        "Option" => {
            let opt_type = get_ident_params_string(ty, name)?;
            as_rust(&opt_type, val, name)?
        }
        _ => {
            quote! {
                #ty::try_from_udt(#val)?
            }
        }
    })
}

pub fn get_ident_string(ty: &Type, name: &str) -> Result<String> {
    get_ident(ty, name).map(|ident| ident.to_string())
}

pub fn get_ident_params_string(ty: &Type, name: &str) -> Result<Type> {
    match ty {
        Type::Path(TypePath {
            path: Path { segments, .. },
            ..
        }) => match segments.last() {
            Some(&PathSegment {
                arguments: PathArguments::AngleBracketed(ref angle_bracketed_data),
                ..
            }) => match angle_bracketed_data.args.last() {
                Some(GenericArgument::Type(ty)) => Ok(ty.clone()),
                _ => Err(Error::new(ty.span(), "Cannot infer field type")),
            },
            _ => Err(Error::new(ty.span(), "Cannot infer field type")),
        },
        _ => Err(Error::new(
            ty.span(),
            format!("Cannot infer field type {}", get_ident_string(ty, name)?),
        )),
    }
}

================================================
FILE: cdrs-tokio-helpers-derive/src/db_mirror.rs
================================================
use itertools::Itertools;
use proc_macro2::TokenStream;
use quote::*;
use syn::spanned::Spanned;
use syn::{DeriveInput, Error, Result};

use crate::common::struct_fields;

pub fn impl_db_mirror(ast: &DeriveInput) -> Result<TokenStream> {
    let name = &ast.ident;

    let idents: Vec<_> = struct_fields(ast)?
        .named
        .iter()
        .map(|f| {
            f.ident
                .clone()
                .ok_or_else(|| Error::new(f.span(), "Expected a named field!"))
        })
        .try_collect()?;
    let idents_copy = idents.clone();

    let fields = idents
        .iter()
        .map(|i| i.to_string())
        .collect::<Vec<String>>();
    let names = fields.join(", ");
    let question_marks = fields
        .iter()
        .map(|_| "?".to_string())
        .collect::<Vec<String>>()
        .join(", ");

    Ok(quote!
    {
        impl #name {
            pub fn insert_query() -> &'static str {
                concat!("insert into ", stringify!(#name), "(",
                    #names,
                    ") values (",
                    #question_marks,
                    ")")
            }

            pub fn into_query_values(self) -> cdrs_tokio::query::QueryValues {
                use std::collections::HashMap;
                let mut values: HashMap<String, cdrs_tokio::types::value::Value> = HashMap::new();

                #(
                    values.insert(stringify!(#idents).to_string(), self.#idents_copy.into());
                )*

                cdrs_tokio::query::QueryValues::NamedValues(values)
            }
        }
    })
}

================================================
FILE: cdrs-tokio-helpers-derive/src/into_cdrs_value.rs
================================================
use itertools::Itertools;
use proc_macro2::TokenStream;
use quote::*;
use syn::spanned::Spanned;
use syn::{Data, DataStruct, DeriveInput, Error, Result};

use crate::common::get_ident_string;

pub fn impl_into_cdrs_value(ast: &DeriveInput) -> Result<TokenStream> {
    let name = &ast.ident;
    if let Data::Struct(DataStruct { ref fields, .. }) = ast.data {
        let convert_into_bytes: Vec<_> = fields.iter().map(|field| {
            let field_ident = field.ident.clone().ok_or_else(|| Error::new(field.span(), "IntoCdrsValue requires all fields be named!"))?;
            get_ident_string(&field.ty, &field_ident.to_string()).map(|ident| {
                if ident == "Option" {
                    // We assume here that primitive value serialization will not change across
                    // protocol versions, which gives a simpler user API.
                    quote! {
                        match value.#field_ident {
                            Some(ref val) => {
                                let field_bytes: Self = val.clone().into();
                                cdrs_tokio::types::value::Value::new(field_bytes).serialize(&mut cursor, cdrs_tokio::frame::Version::V4);
                            },
                            None => {
                                cdrs_tokio::types::value::Value::NotSet.serialize(&mut cursor, cdrs_tokio::frame::Version::V4);
                            }
                        }
                    }
                } else {
                    quote!
                    {
                        let field_bytes: Self = value.#field_ident.into();
                        cdrs_tokio::types::value::Value::new(field_bytes).serialize(&mut cursor, cdrs_tokio::frame::Version::V4);
                    }
                }
            })
        }).try_collect()?;

        // Value has the implementation `impl<T: Into<Bytes>> From<T> for Value`, so for a
        // struct it is enough to implement `Into<Bytes>` to make it convertible into the
        // Value used for making queries.
        Ok(quote! {
            #[automatically_derived]
            impl From<#name> for cdrs_tokio::types::value::Bytes {
                fn from(value: #name) -> Self {
                    use cdrs_tokio::frame::Serialize;

                    let mut bytes: Vec<u8> = Vec::new();
                    let mut cursor = std::io::Cursor::new(&mut bytes);
                    #(#convert_into_bytes)*
                    Self::new(bytes)
                }
            }
        })
    } else {
        Err(Error::new(
            ast.span(),
            "#[derive(IntoCdrsValue)] can only be defined for structs!",
        ))
    }
}

================================================
FILE: cdrs-tokio-helpers-derive/src/lib.rs
================================================
//! Procedural derive macros for CDRS helpers: `DbMirror`, `IntoCdrsValue`, `TryFromRow`
//! and `TryFromUdt`.

use proc_macro::TokenStream;
use syn::{parse_macro_input, DeriveInput, Error};

mod common;
mod db_mirror;
mod into_cdrs_value;
mod try_from_row;
mod try_from_udt;

use crate::db_mirror::impl_db_mirror;
use crate::into_cdrs_value::impl_into_cdrs_value;
use crate::try_from_row::impl_try_from_row;
use crate::try_from_udt::impl_try_from_udt;

#[proc_macro_derive(DbMirror)]
pub fn db_mirror(input: TokenStream) -> TokenStream {
    // Parse the string representation
    let ast = parse_macro_input!(input as DeriveInput);

    // Build the impl
    impl_db_mirror(&ast)
        .unwrap_or_else(Error::into_compile_error)
        .into()
}

#[proc_macro_derive(IntoCdrsValue)]
pub fn into_cdrs_value(input: TokenStream) -> TokenStream {
    // Parse the string representation
    let ast = parse_macro_input!(input as DeriveInput);

    // Build the impl
    impl_into_cdrs_value(&ast)
        .unwrap_or_else(Error::into_compile_error)
        .into()
}

#[proc_macro_derive(TryFromRow)]
pub fn try_from_row(input: TokenStream) -> TokenStream {
    // Parse the
    // string representation
    let ast = parse_macro_input!(input as DeriveInput);

    // Build the impl
    impl_try_from_row(&ast)
        .unwrap_or_else(Error::into_compile_error)
        .into()
}

#[proc_macro_derive(TryFromUdt)]
pub fn try_from_udt(input: TokenStream) -> TokenStream {
    // Parse the string representation
    let ast = parse_macro_input!(input as DeriveInput);

    // Build the impl
    impl_try_from_udt(&ast)
        .unwrap_or_else(Error::into_compile_error)
        .into()
}

================================================
FILE: cdrs-tokio-helpers-derive/src/try_from_row.rs
================================================
use proc_macro2::TokenStream;
use quote::*;
use syn::{DeriveInput, Result};

use crate::common::get_struct_fields;

pub fn impl_try_from_row(ast: &DeriveInput) -> Result<TokenStream> {
    let name = &ast.ident;
    let fields = get_struct_fields(ast)?;

    Ok(quote! {
        #[automatically_derived]
        impl cdrs_tokio::frame::TryFromRow for #name {
            fn try_from_row(cdrs: cdrs_tokio::types::rows::Row) -> cdrs_tokio::Result<Self> {
                use cdrs_tokio::frame::TryFromUdt;
                use cdrs_tokio::types::from_cdrs::FromCdrsByName;
                use cdrs_tokio::types::IntoRustByName;
                use cdrs_tokio::types::AsRustType;

                Ok(#name {
                    #(#fields),*
                })
            }
        }
    })
}

================================================
FILE: cdrs-tokio-helpers-derive/src/try_from_udt.rs
================================================
use proc_macro2::TokenStream;
use quote::*;
use syn::{DeriveInput, Result};

use crate::common::get_struct_fields;

pub fn impl_try_from_udt(ast: &DeriveInput) -> Result<TokenStream> {
    let name = &ast.ident;
    let fields = get_struct_fields(ast)?;

    Ok(quote!
    {
        #[automatically_derived]
        impl cdrs_tokio::frame::TryFromUdt for #name {
            fn try_from_udt(cdrs: cdrs_tokio::types::udt::Udt) -> cdrs_tokio::Result<Self> {
                use cdrs_tokio::frame::TryFromUdt;
                use cdrs_tokio::types::from_cdrs::FromCdrsByName;
                use cdrs_tokio::types::IntoRustByName;
                use cdrs_tokio::types::AsRustType;

                Ok(#name {
                    #(#fields),*
                })
            }
        }
    })
}

================================================
FILE: changelog.md
================================================
## 9.0.1

### Fixed

* Fixed querying peers in `yugabytedb` (by Andries Hiemstra).

## 9.0.0

### Fixed

* Fixed not re-preparing statements in batch queries.

### New

* New `Session::prepare_raw_tw_with_query_plan()` function.

### Changed

* Removed deprecated functions.

## 8.1.9

### Fixed

* Fixed pool leak (by jojoxhsieh).

## 8.1.8

### Fixed

* Fixed deadlock on transport error (by jojoxhsieh).

## 8.1.7

### Fixed

* Not recreating connections on down event if there are still apparently open ones.

## 8.1.6

### Fixed

* Refreshing node information can preserve invalid `Down` state (by Denis Kosenkov).

## 8.1.5

### Fixed

* Race condition when reconnecting to a cluster with all nodes down (by Denis Kosenkov).

## 8.1.4

### Fixed

* CPU spike after some time running.

## 8.1.3

* Dependency updates.

## 8.1.2

### Changed

* Dependency updates.

## 8.1.1

### Fixed

* Non-fatal errors closing connections.

## 8.1.0

### Fixed

* Sending envelopes now properly jumps to the next node in the query plan if the current one is unreachable.

### New

* `InvalidProtocol` special error for the case when a node doesn't accept the requested protocol during handshake.
* `ConnectionPoolConfigBuilder` for building configuration easily.
* Configurable heartbeat messages to keep connections alive in the pool.

### Changed

* Due to an edge case with reconnecting to a seemingly downed node, the internal reconnection handling mechanism has been improved.
* Hidden internal structures, which were public but not usable in any way.
## 8.0.0 (unavailable)

### Changed

* Removed `Ord`, `PartialOrd` from `QueryFlags`.
* Using `rustls` types exported from `tokio-rustls`, rather than depending on `rustls` directly.

## 8.0.0-beta.1

### Fixed

* Fixed stack overflow when the field type cannot be determined during struct serialization.
* Properly supporting references during struct serialization.

### New

* Many types are now `Debug`.
* HTTP proxy support via the `http-proxy` feature.

### Changed

* Made protocol enums non-exhaustive for future compatibility.
* Session builders are now async and wait for the control connection to be ready before returning a session.
* `CBytes::new_empty()` -> `CBytes::new_null()`, `CBytes::is_empty()` -> `CBytes::is_null_or_empty()`.

## 7.0.4

### Fixed

* Invalid Murmur3 hash for keys longer than 15 bytes.

## 7.0.3

### Fixed

* Fixed serialization of routing key with known indexes.

### Changed

* Deprecated `query_with_param()` in `Pager`, in favor of `query_with_params()`.

## 7.0.2

### Fixed

* Serializing single PK routing keys by index.
* Encoding envelopes with tracing/warning flags.

## 7.0.1

### Fixed

* Overflow when compressed envelope payload exceeds max payload size.
* Integer overflow when fewer than 8 header bytes have been received.

## 7.0.0

### New

* `Clone` implemented for `BodyResReady` and `BodyReqExecute`.

### Changed

* Control connection errors are now logged as warnings, since they're recoverable.
* Exposed fields of `BodyReqAuthResponse` and `BodyReqExecute`.
* Replaced `CInet` type with `SocketAddr`, since it was nothing more than a wrapper.

## 7.0.0-beta.2

### Fixed

* Constant control connection re-establishing with legacy clusters.

### New

* `ResponseBody::into_error` function.

## 7.0.0-beta.1

### Fixed

* `ExponentialReconnectionSchedule` duration overflow.
* Forgetting real error type in certain transport error situations.
* Not sending re-preparation statements to correct nodes.
* Infinite set keyspace notification loop.

### New

* Protocol V5 support.
  Please look at the official changelog for more information: https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v5.spec#L1419.
* Support for beta protocols - possibility to connect to beta clusters.
* `From` for `BigInt`.
* `check_envelope_size` for `Envelope`.
* `Error` is now `Clone`.
* `FrameEncoder`, `FrameDecoder` and `FrameEncodingFactory` responsible for encoding/decoding frames on the wire.
* `with_frame_encoder_factory` Session build option.
* `Error` impl for `CheckEnvelopeSizeError` and `ParseEnvelopeError`.
* New `Error` variants for more granular error handling.
* Node address in `Error::Server` variant.

### Changed

* Due to naming changes in V5, frames have been renamed to messages, `Frame` to `Envelope`, and a frame now corresponds to wrapped envelopes, as defined by the protocol.
* `Serialize` and `FromCursor` traits now pass protocol version to implementations.
* `Row::from_frame_body` renamed to `from_body`.
* `ClusterMetadataManager::find_node` renamed to `find_node_by_rpc_address` for consistency.
* `QueryFlags` got extended for V5 and now supports `Serialize` and `FromCursor`.
* Session builders now validate given configuration and return a `Result`.
* Transport startup now fails gracefully on unexpected server response.
* `CdrsTransport` now requires explicit information if messages are a part of initial handshake.
* `ResResultBody::as_rows_metadata` and `ResponseBody::as_rows_metadata` now return a reference to the data.
* `Hash`, `PartialEq` and `PartialOrd` for `PreparedQuery` only take `id` and `result_metadata_id` into account, since those define equivalence.
* Updated `chrono` dependency to work around a discovered CVE.

## 6.2.0

### New

* `derive` feature built into the main crate - no need to explicitly `use cdrs_tokio_helpers_derive::*` anymore.

## 6.1.0

### New

* `#[must_use]` on some functions.

### Fixed

* Fixed parsing `NetworkTopologyStrategy`.
## 6.0.0

This version is a departure from legacy API design, stemming from the sync version migration. Due to large performance issues and lack of dynamic topology handling in earlier versions, a decision has been made to cut the ties and focus on delivering the best functionality without legacy burden. The API surface changes are quite large, but everyone is encouraged to update - the performance improvements and new features cannot be overstated.

### New

* Topology-aware load balancing: `TopologyAwareNodeDistanceEvaluator` and `TopologyAwareLoadBalancingStrategy`.
* New `ReconnectionPolicy` used when trying to re-establish connections to downed nodes.
* `Error` now implements standard `Error`.
* `SessionBuilder` introduced as the preferred way to create a session.
* Added missing traits for `BatchType` and `QueryFlags`.
* `ToString` implementation for `SimpleServerEvent`.
* Standard trait implementations for event frames.
* `contains_column`, `is_empty_by_name` and `is_empty` functions for `Row`.
* `Display` implementation for public enums.
* Missing traits for `PreparedMetadata`, `Value`, `Consistency` and `ColType`.
* New `PreparedMetadataFlags`.
* New `ClusterMetadata` representing information about a cluster.
* Extracted protocol functionality to separate `cassandra-protocol` crate.
* Passing final auth data from the server to `SaslAuthenticator`.
* `SpeculativeExecutionPolicy` for speculative execution control.

### Changed

* All `with_name` fields or args in the query API are now `bool` instead of `Option`.
* `flags` field removed from `QueryParams` (flags are now derived from the other fields at serialization time).
* Rewritten transport layer for massive performance improvements (including removing `bb8`). This involves changing a large portion of public API related to transport and server events.
* Rewritten event mechanism - now you can subscribe to server events via `create_event_receiver()` in `Session`.
* Replaced `RowsMetadataFlag`, `QueryFlags` and `frame::Flags` vectors with bitflags.
* Changed `Target` and `ChangeType` enums to `SchemaChangeTarget` and `SchemaChangeType`.
* The `varint` type now uses `num::BigInt` representation (this implies `Decimal` also uses "big" types).
* Removed `unstable-dynamic-cluster` feature, since it wasn't working as expected and introduced performance penalty. Dynamic topology handling is now built-in.
* Removed `AsBytes` in favor of new `Serialize` trait due to performance penalty.
* Removed `FromSingleByte` and `AsByte` in favor of `From`/`TryFrom`.
* Removed traits along with `async-trait` dependency: `BatchExecutor`, `ExecExecutor`, `PrepareExecutor`, `QueryExecutor`, `GetConnection` and `CdrsSession`. Everything is now embedded directly in `Session`.
* Load balancing strategy now returns query plans, rather than individual nodes, and operates on cluster metadata.
* Removed `SingleNode` load balancing strategy.
* Removed empty `SimpleError`.
* Renamed `connect_generic_static` to `connect_generic`.
* Removed `GetRetryPolicy`.
* Renamed `ChangeSchemeOptions` to `SchemaChangeOptions`.
* Protocol version can now be selected at run time.
* `Value` now directly contains the value in the `Some` variant instead of a separate body field.
* Consistent naming convention in all builders.
* Split protocol-level parameters from high-level statement parameters (`QueryParams` vs `StatementParams`) and simplified API.
* `add_query_prepared` for batch queries now takes `PreparedQuery` by reference.

## 5.0.0

### New

* Support for stateful SASL authenticators.

### Changed

* Using up-to-date lz4 crate (no more unmaintained dependency alerts).

## 4.0.0

### Fixed

* Build problems with Rustls.
* TLS connections sometimes not flushing all data.
* Not setting current namespace when not using an authenticator.

### New

* New `connect_generic_*` functions allowing custom connection configurations (see `generic_connection.rs` for example usage).
* Possibility to use custom error types which implement `FromCdrsError` throughout the crate.
* `Consistency` now implements `FromStr`.
* Pagers can be converted into `PagerState`.
* Support for v4 marshaled types.
* `Copy`, `Clone`, `Ord`, `PartialOrd`, `Eq`, `Hash` for `Opcode`.
* Customizable query retry policies with built-in `FallthroughRetryPolicy` and `DefaultRetryPolicy`.

### Changed

* TCP configuration now owns contained data - no need to keep it alive while the config is alive.
* `ExecPager` is now public.
* `Bytes` now implements `From` for supported types, instead of `Into`.
* Moved some generic types to associated types, thus removing a lot of type passing.
* `SessionPager` no longer needs a mutable session.
* A lot of names have been migrated to idiomatic Rust (mainly upper camel case abbreviations).

## 3.0.0

### Fixed

* Remembering `USE`d keyspaces across connections.
* Race condition on query id overflow.

### Changed

* Removed deprecated `PasswordAuthenticator`.
* Removed unused `Compressor` trait.
* Large API cleanup.
* Renamed `IntoBytes` to `AsBytes`.
* `Authenticator` can now be created at runtime - removed static type parameter.
* Removed unneeded memory allocations when parsing data.

## 2.1.0

### Fixed

* Recreation of forgotten prepared statements.

### New

* `rustls` sessions constructors.

### Changed

* Updated `tokio` to 1.1.

## 2.0.0

### New

* Support for `NonZero*` types.
* Support for `chrono` `NaiveDateTime` and `DateTime`.
* Update `tokio` to 1.0.
* `Pager` supporting `QueryValues` and consistency.

## 1.0.0

* Initial release.
================================================
FILE: clippy.toml
================================================
# the main Error enum is a bit above the default limit, but contains very useful data for pattern matching
large-error-threshold = 256

================================================
FILE: documentation/README.md
================================================
# Definitive guide

- [Cluster configuration](./cluster-configuration.md)
- [CDRS session](./cdrs-session.md):
  - [Query values](./query-values.md)
  - [Cassandra-to-Rust deserialization](./deserialization.md)
  - [Preparing and executing queries](./preparing-and-executing-queries.md)
  - [Batching multiple queries](./batching-multiple-queries.md)

================================================
FILE: documentation/batching-multiple-queries.md
================================================
### Batch queries

CDRS `Session` supports batching several queries in a single request to Apache Cassandra:

```rust
use cdrs_tokio::query::BatchQueryBuilder;
use cdrs_tokio::query_values;

// `session` is an existing `Session` and `prepared_query` was obtained via
// `session.prepare("INSERT INTO my.store (my_int) VALUES (?)")`
let batch = BatchQueryBuilder::new()
    // a prepared statement, taken by reference, with its bound values
    .add_query_prepared(&prepared_query, query_values!(1 as i32))
    // a simple query string with its bound values
    .add_query("INSERT INTO my.store (my_int) VALUES (?)", query_values!(2 as i32))
    .build()
    .expect("failed to build batch");

session.batch(batch).await.expect("failed to execute batch");
```

================================================
FILE: documentation/cdrs-session.md
================================================
# CDRS Session

`Session` is a structure that holds a set of authenticated connection pools to the nodes of a cluster. It also provides the data compression and load balancing mechanisms used for exchanging Cassandra frames or, in other words, for querying.

In order to create a new session, a [cluster config](./cluster-configuration.md) and a load balancing strategy must be provided. The load balancing strategy is consulted whenever the driver needs to perform a query: it returns a query plan, i.e. the nodes picked in accordance with the strategy, and the driver sends the request over a connection to one of those nodes.
Such logic keeps the load balanced across nodes and avoids establishing a new connection when an existing one has been released after a previous query.

## Load balancing

Any structure that implements the `LoadBalancingStrategy` trait can be used in `Session` as a load balancer. CDRS provides a few strategies out of the box, so additional development may not be needed:

- `RandomLoadBalancingStrategy` randomly picks a node from the cluster.
- `RoundRobinLoadBalancingStrategy` is a thread-safe round-robin balancing strategy.
- `TopologyAwareLoadBalancingStrategy` is a policy that takes dynamic cluster topology into account.

Besides these, any custom load balancing strategy may be implemented and used with CDRS. The only requirement is that the structure implements the `LoadBalancingStrategy` trait.

## Data compression

The CQL binary protocol allows using LZ4 and Snappy (for protocol versions < 5) data compression in order to reduce traffic between nodes and clients. CDRS provides methods for creating a `Session` with different compression contexts: LZ4 and Snappy.

### Reference

1. LZ4 compression algorithm https://en.wikipedia.org/wiki/LZ4_(compression_algorithm).
2. Snappy compression algorithm https://en.wikipedia.org/wiki/Snappy_(compression).

================================================
FILE: documentation/cluster-configuration.md
================================================
### Cluster configuration

Apache Cassandra is designed to be a scalable and highly available database, so developers most often work with multi-node Cassandra clusters. For instance, Apple's setup includes 75k nodes, Netflix 2.5k nodes, eBay >100 nodes. That's why the CDRS driver was designed with multi-node support in mind.
In order to connect to a Cassandra cluster via CDRS, a connection configuration should be provided:

```rust
use std::sync::Arc;

use cdrs_tokio::authenticators::NoneAuthenticatorProvider;
use cdrs_tokio::cluster::NodeTcpConfigBuilder;

fn main() {
    // contact point address plus the authenticator provider used for this node
    let cluster_config = NodeTcpConfigBuilder::new(
        "127.0.0.1:9042".parse().unwrap(),
        Arc::new(NoneAuthenticatorProvider),
    )
    .build();
    // ...
}
```

For each node configuration, a `SaslAuthenticatorProvider` should be provided. `SaslAuthenticatorProvider` is a trait that a structure must implement so it can be used by the CDRS session for authentication. Out of the box, CDRS provides two types of authenticators:

- `cdrs_tokio::authenticators::NoneAuthenticatorProvider` should be used if authentication is disabled on the server (the [Cassandra authenticator](http://cassandra.apache.org/doc/latest/configuration/cassandra_config_file.html#authenticator) is set to `AllowAllAuthenticator`).
- `cdrs_tokio::authenticators::StaticPasswordAuthenticatorProvider` should be used if authentication is enabled on the server and the [authenticator](http://cassandra.apache.org/doc/latest/configuration/cassandra_config_file.html#authenticator) is `PasswordAuthenticator`.

```rust
use cdrs_tokio::authenticators::StaticPasswordAuthenticatorProvider;

let authenticator = StaticPasswordAuthenticatorProvider::new("user", "pass");
```

If a node uses a custom authentication strategy, a corresponding `SaslAuthenticatorProvider` should be implemented by the developer and then passed to `NodeTcpConfigBuilder`.

### Reference

1. Cassandra cluster configuration https://docs.datastax.com/en/cassandra/3.0/cassandra/initialize/initTOC.html.
2. ScyllaDB cluster configuration https://docs.scylladb.com/operating-scylla/ (see Cluster Management section).
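As a point of reference when writing a custom `SaslAuthenticatorProvider`: Cassandra's `PasswordAuthenticator` uses the SASL PLAIN mechanism (RFC 4616), whose initial response is an empty authorization identity, the username and the password, separated by NUL bytes. The standalone sketch below only builds those bytes. `plain_initial_response` is a hypothetical helper, not a CDRS API, and it deliberately does not implement the CDRS authenticator traits.

```rust
/// Hypothetical helper (not part of CDRS): builds a SASL PLAIN initial
/// response (RFC 4616) with an empty authorization identity, i.e.
/// `\0` + username + `\0` + password.
fn plain_initial_response(username: &str, password: &str) -> Vec<u8> {
    let mut response = Vec::with_capacity(username.len() + password.len() + 2);
    response.push(0); // separator after the (empty) authorization identity
    response.extend_from_slice(username.as_bytes());
    response.push(0); // separator between username and password
    response.extend_from_slice(password.as_bytes());
    response
}

fn main() {
    let response = plain_initial_response("user", "pass");
    // the NUL separators show up as 0 bytes in the debug output
    println!("{:?}", response);
}
```

A custom authenticator would send bytes like these as its initial response; mechanisms with challenge/response rounds need the stateful authenticator support mentioned in the 5.0.0 changelog.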
================================================
FILE: documentation/deserialization.md
================================================
### Mapping results into Rust structures

In order to query information from Cassandra DB and transform the results into Rust types and structures, each row in a query result should be transformed using one of the following traits provided by CDRS: `cdrs_tokio::types::{AsRustType, AsRust, IntoRustByName, ByName, IntoRustByIndex, ByIndex}`.

- `AsRustType` may be used to transform complex structures such as Cassandra lists, sets and tuples. The Cassandra value in this case may be non-set or null.
- `AsRust` may be used for similar purposes as `AsRustType`, but it assumes that the Cassandra value is neither non-set nor null. Otherwise, it panics.
- `IntoRustByName` may be used to access a value as a Rust structure/type by name, as in the case of rows (where each column has its own name) and maps. These values may also be non-set or null.
- `ByName` is the same as `IntoRustByName`, but the value must be neither non-set nor null. Otherwise, it panics.
- `IntoRustByIndex` is the same as `IntoRustByName`, but values are accessed by column index, based on their order in the query. These values may also be non-set or null.
- `ByIndex` is the same as `IntoRustByIndex`, but the value must be neither non-set nor null. Otherwise, it panics.

Relations between Cassandra and Rust types are described in [type-mapping](type-mapping.md). For details, see the examples.

================================================
FILE: documentation/preparing-and-executing-queries.md
================================================
### Preparing queries

When a query is prepared, the server parses it, caches the parsing result and returns to the client an ID that can later be used to execute the prepared statement with different parameters (such as values, consistency etc.).
When the server executes a prepared query, it does not need to parse it again, so the parsing step is skipped.

```rust
let prepared_query = session
    .prepare("INSERT INTO my.store (my_int, my_bigint) VALUES (?, ?)")
    .await
    .unwrap();
```

### Executing prepared queries

When a query is prepared on the server, the client receives a prepared query ID of type `cdrs_tokio::query::PreparedQuery`. With that ID it is possible to execute the prepared query using session methods:

```rust
use cdrs_tokio::query_values;

// the query above has two bind markers, so values must be supplied
session
    .exec_with_values(&prepared_query, query_values!(1 as i32, 1 as i64))
    .await
    .unwrap();

// to execute a prepared query that takes no values, use exec()
// to execute a prepared query with advanced parameters, use exec_with_params()
```

================================================
FILE: documentation/query-values.md
================================================
# Query `Value`

Query `Value`-s can be used along with query string templates. A query string template is a query string that contains `?` bind markers. The driver sends the query `Values` together with the query string, and Cassandra binds them to the `?` markers. For instance:

```rust
const INSERT_NUMBERS_QUERY: &str = "INSERT INTO my.numbers (my_int, my_bigint) VALUES (?, ?)";

let values = query_values!(1 as i32, 1 as i64);
session.query_with_values(INSERT_NUMBERS_QUERY, values).await.unwrap();
```

`INSERT_NUMBERS_QUERY` is a typical query template. The `Session::query_with_values` method provides an API for using such query strings along with query values.
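In the CQL native protocol the values travel alongside the query string and the server binds them, so the driver does not rewrite the query text. The sketch below models only the positional rule for simple values: each value is matched to a `?` marker by index, and the number of values is expected to equal the number of markers. `count_bind_markers` and `bind_positional` are hypothetical helpers, not CDRS APIs, and the literal handling is deliberately simplified (no `$$` strings, comments or named `:markers`).

```rust
/// Hypothetical helper: counts `?` bind markers in a CQL query template,
/// ignoring any `?` inside a single-quoted string literal.
fn count_bind_markers(template: &str) -> usize {
    let mut in_literal = false;
    let mut count = 0;
    for c in template.chars() {
        match c {
            // CQL escapes a quote inside a literal as '', which toggles the
            // state twice and therefore leaves it unchanged.
            '\'' => in_literal = !in_literal,
            '?' if !in_literal => count += 1,
            _ => {}
        }
    }
    count
}

/// Hypothetical helper: pairs each positional value with the index of the
/// `?` marker it binds to, or reports a marker/value count mismatch.
fn bind_positional<'a, T>(template: &str, values: &'a [T]) -> Result<Vec<(usize, &'a T)>, String> {
    let markers = count_bind_markers(template);
    if markers != values.len() {
        return Err(format!(
            "template has {} bind markers but {} values were supplied",
            markers,
            values.len()
        ));
    }
    Ok(values.iter().enumerate().collect())
}

fn main() {
    let template = "INSERT INTO my.numbers (my_int, my_bigint) VALUES (?, ?)";
    match bind_positional(template, &[1i64, 1]) {
        Ok(bound) => println!("bound: {:?}", bound),
        Err(error) => eprintln!("{}", error),
    }
}
```

Named values avoid the ordering constraint because each value carries the column name it belongs to.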
The following `Session` methods allow using values and query templates:

- `exec_with_values` - executes a previously prepared query with provided values (see [example](../examples/prepare_batch_execute.rs) and/or the [Preparing and Executing](./preparing-and-executing-queries.md) section);
- `query_with_values` - immediately executes a query using provided values (see [example](../examples/crud_operations.rs)).

## Simple `Value` and `Value` with names

There are two types of query values supported by CDRS:

- Simple `Value`-s may be imagined as a tuple of actual values. Each value is bound to the `?` that has the same index as the value within the tuple. To easily create `Value`-s, CDRS provides the `query_values!` macro:

  ```rust
  let values = query_values!(1 as i32, 1 as i64);
  ```

- `Value`-s with names may be imagined as a `Map` that links a table column name with the value that should be inserted into that column. This means that named `Value`-s do not need to be in the same order as the corresponding `?` markers in the query template:

  ```rust
  const INSERT_NUMBERS_QUERY: &str = "INSERT INTO my.numbers (my_int, my_bigint) VALUES (?, ?)";

  // order differs from the template; names link each value to its column
  let values = query_values!(my_bigint => 1 as i64, my_int => 1 as i32);
  session.query_with_values(INSERT_NUMBERS_QUERY, values).await.unwrap();
  ```

What kinds of values can be used as `query_values!` arguments? All types that implement `Into<Value>`. For Rust structs represented by [Cassandra User Defined types](http://cassandra.apache.org/doc/4.0/cql/types.html#grammar-token-user_defined_type), `#[derive(IntoCdrsValue)]` can be used for a recursive implementation. See the [CRUD example](../examples/crud_operations.rs).

### Reference

1. Cassandra official docs - User Defined Types http://cassandra.apache.org/doc/4.0/cql/types.html#grammar-token-user_defined_type.
2. Datastax - User Defined Types https://docs.datastax.com/en/cql/3.3/cql/cql_using/useCreateUDT.html.
3. ScyllaDB - User Defined Types https://docs.scylladb.com/getting-started/types/
4. [CDRS CRUD Example](../examples/crud_operations.rs)

================================================
FILE: documentation/type-mapping.md
================================================
### Type relations between Rust (in CDRS approach) and Apache Cassandra

#### primitive types (`T`)

| Cassandra | Rust | Feature |
|-----------|------|---------|
| tinyint | i8 | v4, v5 |
| smallint | i16 | v4, v5 |
| int | i32 | all |
| bigint | i64 | all |
| ascii | String | all |
| text | String | all |
| varchar | String | all |
| boolean | bool | all |
| time | i64 | all |
| timestamp | i64 | all |
| float | f32 | all |
| double | f64 | all |
| uuid | [Uuid](https://doc.rust-lang.org/uuid/uuid/struct.Uuid.html) | all |
| counter | i64 | all |

#### complex types

| Cassandra | Rust + CDRS |
|-----------|-------------|
| blob | `Blob -> Vec<u8>` |
| list | `List -> Vec<T>` |
| set | `List -> Vec<T>` |
| map | `Map -> HashMap<K, V>` |
| udt | Rust struct |

================================================
FILE: rustfmt.toml
================================================
use_field_init_shorthand = true
edition = "2018"