Repository: zonyitoo/mqtt-rs Branch: master Commit: 02fa0d44b07f Files: 43 Total size: 127.0 KB Directory structure: gitextract_2bhvi23i/ ├── .github/ │ └── workflows/ │ └── build-and-test.yml ├── .gitignore ├── .travis.yml ├── Cargo.toml ├── LICENSE ├── README.md ├── examples/ │ ├── pub-client.rs │ ├── simple.rs │ ├── sub-client-async.rs │ └── sub-client.rs ├── rustfmt.toml └── src/ ├── control/ │ ├── fixed_header.rs │ ├── mod.rs │ ├── packet_type.rs │ └── variable_header/ │ ├── connect_ack_flags.rs │ ├── connect_flags.rs │ ├── connect_ret_code.rs │ ├── keep_alive.rs │ ├── mod.rs │ ├── packet_identifier.rs │ ├── protocol_level.rs │ ├── protocol_name.rs │ └── topic_name.rs ├── encodable.rs ├── lib.rs ├── packet/ │ ├── connack.rs │ ├── connect.rs │ ├── disconnect.rs │ ├── mod.rs │ ├── pingreq.rs │ ├── pingresp.rs │ ├── puback.rs │ ├── pubcomp.rs │ ├── publish.rs │ ├── pubrec.rs │ ├── pubrel.rs │ ├── suback.rs │ ├── subscribe.rs │ ├── unsuback.rs │ └── unsubscribe.rs ├── qos.rs ├── topic_filter.rs └── topic_name.rs ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/build-and-test.yml ================================================ name: Build & Test on: push: branches: [master] pull_request: branches: [master] env: CARGO_TERM_COLOR: always jobs: build-and-test: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Build run: cargo build --verbose - name: Run tests run: cargo test --verbose ================================================ FILE: .gitignore ================================================ target Cargo.lock .vscode ================================================ FILE: .travis.yml ================================================ language: rust rust: - stable - nightly script: - cargo test -v - cargo test --features "tokio-codec" ================================================ FILE: Cargo.toml 
================================================ [package] authors = ["Y. T. Chung "] name = "mqtt-protocol" version = "0.12.0" license = "MIT/Apache-2.0" description = "MQTT Protocol Library" keywords = ["mqtt", "protocol"] repository = "https://github.com/zonyitoo/mqtt-rs" documentation = "https://docs.rs/mqtt-protocol" edition = "2018" [dependencies] byteorder = "1.3" log = "0.4" tokio = { version = "1", optional = true } tokio-util = { version = "0.6", features = ["codec"], optional = true } bytes = { version = "1.0", optional = true } thiserror = "1.0" [dev-dependencies] clap = "2" env_logger = "0.8" tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net", "time", "io-util"] } futures = { version = "0.3" } uuid = { version = "0.8", features = ["v4"] } [features] tokio-codec = ["tokio", "tokio-util", "bytes"] default = [] [lib] name = "mqtt" [[example]] name = "sub-client-async" required-features = ["tokio"] ================================================ FILE: LICENSE ================================================ The MIT License (MIT) Copyright (c) 2015 Y. T. CHUNG Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # MQTT-rs [![Build Status](https://img.shields.io/travis/zonyitoo/mqtt-rs.svg)](https://travis-ci.org/zonyitoo/mqtt-rs) ![Build & Test](https://github.com/zonyitoo/mqtt-rs/workflows/Build%20&%20Test/badge.svg) [![License](https://img.shields.io/github/license/zonyitoo/mqtt-rs.svg)](https://github.com/zonyitoo/mqtt-rs) [![crates.io](https://img.shields.io/crates/v/mqtt-protocol.svg)](https://crates.io/crates/mqtt-protocol) [![dependency status](https://deps.rs/repo/github/zonyitoo/mqtt-rs/status.svg)](https://deps.rs/repo/github/zonyitoo/mqtt-rs) MQTT protocol library for Rust ```toml [dependencies] mqtt-protocol = "0.12" ``` ## Usage ```rust extern crate mqtt; use std::io::Cursor; use mqtt::{Encodable, Decodable}; use mqtt::packet::{VariablePacket, PublishPacket, QoSWithPacketIdentifier}; use mqtt::TopicName; fn main() { // Create a new Publish packet let packet = PublishPacket::new(TopicName::new("mqtt/learning").unwrap(), QoSWithPacketIdentifier::Level2(10), "Hello MQTT!"); // Encode let mut buf = Vec::new(); packet.encode(&mut buf).unwrap(); println!("Encoded: {:?}", buf); // Decode it with known type let mut dec_buf = Cursor::new(&buf[..]); let decoded = PublishPacket::decode(&mut dec_buf).unwrap(); println!("Decoded: {:?}", decoded); assert_eq!(packet, decoded); // Auto decode by the fixed header let mut dec_buf = Cursor::new(&buf[..]); let auto_decode = VariablePacket::decode(&mut dec_buf).unwrap(); println!("Variable packet decode: {:?}", auto_decode); assert_eq!(VariablePacket::PublishPacket(packet), auto_decode); } ``` ## Note * Based on [MQTT 
3.1.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html)

================================================
FILE: examples/pub-client.rs
================================================
//! Interactive MQTT publisher.
//!
//! Connects to `--server`, sends a SUBSCRIBE for every `--subscribe` filter, then reads lines
//! from stdin and publishes each non-empty line (prefixed with `--username`) at QoS 0 to every
//! `--subscribe` value interpreted as a topic name.

#[macro_use]
extern crate log;

use std::env;
use std::io::{self, Write};
use std::net::TcpStream;
use std::thread;

use clap::{App, Arg};
use uuid::Uuid;

use mqtt::control::variable_header::ConnectReturnCode;
use mqtt::packet::*;
use mqtt::{Decodable, Encodable, QualityOfService};
use mqtt::{TopicFilter, TopicName};

/// Builds a random client identifier of the form `/MQTT/rust/<uuid v4>`.
fn generate_client_id() -> String {
    format!("/MQTT/rust/{}", Uuid::new_v4())
}

fn main() {
    // configure logging
    env::set_var("RUST_LOG", env::var_os("RUST_LOG").unwrap_or_else(|| "info".into()));
    env_logger::init();

    let matches = App::new("pub-client")
        .author("Y. T. Chung ")
        .arg(
            Arg::with_name("SERVER")
                .short("S")
                .long("server")
                .takes_value(true)
                .required(true)
                .help("MQTT server address (host:port)"),
        )
        .arg(
            Arg::with_name("SUBSCRIBE")
                .short("s")
                .long("subscribe")
                .takes_value(true)
                .multiple(true)
                .required(true)
                .help("Channel filter to subscribe"),
        )
        .arg(
            Arg::with_name("USER_NAME")
                .short("u")
                .long("username")
                .takes_value(true)
                .help("Login user name"),
        )
        .arg(
            Arg::with_name("PASSWORD")
                .short("p")
                .long("password")
                .takes_value(true)
                .help("Password"),
        )
        .arg(
            Arg::with_name("CLIENT_ID")
                .short("i")
                .long("client-identifier")
                .takes_value(true)
                .help("Client identifier"),
        )
        .get_matches();

    let server_addr = matches.value_of("SERVER").unwrap();
    let client_id = matches
        .value_of("CLIENT_ID")
        .map(|x| x.to_owned())
        .unwrap_or_else(generate_client_id);
    let channel_filters: Vec<(TopicFilter, QualityOfService)> = matches
        .values_of("SUBSCRIBE")
        .unwrap()
        .map(|c| (TopicFilter::new(c.to_string()).unwrap(), QualityOfService::Level0))
        .collect();

    info!("Connecting to {:?} ... ", server_addr);
    let mut stream = TcpStream::connect(server_addr).unwrap();
    info!("Connected!");

    info!("Client identifier {:?}", client_id);
    let mut conn = ConnectPacket::new(client_id);
    conn.set_clean_session(true);
    let mut buf = Vec::new();
    conn.encode(&mut buf).unwrap();
    stream.write_all(&buf[..]).unwrap();

    let connack = ConnackPacket::decode(&mut stream).unwrap();
    trace!("CONNACK {:?}", connack);

    if connack.connect_return_code() != ConnectReturnCode::ConnectionAccepted {
        panic!(
            "Failed to connect to server, return code {:?}",
            connack.connect_return_code()
        );
    }

    info!("Applying channel filters {:?} ...", channel_filters);
    let sub = SubscribePacket::new(10, channel_filters);
    let mut buf = Vec::new();
    sub.encode(&mut buf).unwrap();
    stream.write_all(&buf[..]).unwrap();

    // The same command-line values double as the topics we publish to.
    let channels: Vec<TopicName> = matches
        .values_of("SUBSCRIBE")
        .unwrap()
        .map(|c| TopicName::new(c.to_string()).unwrap())
        .collect();

    let user_name = matches.value_of("USER_NAME").unwrap_or("");

    // Background reader: answers PINGREQ and stops on DISCONNECT or on any receive error.
    let mut cloned_stream = stream.try_clone().unwrap();
    thread::spawn(move || {
        loop {
            let packet = match VariablePacket::decode(&mut cloned_stream) {
                Ok(pk) => pk,
                Err(err) => {
                    // After a failed decode the stream position is unknown, and an I/O error such
                    // as EOF would simply repeat, so retrying would spin forever. Stop instead.
                    error!("Error in receiving packet {:?}", err);
                    break;
                }
            };
            trace!("PACKET {:?}", packet);

            match packet {
                VariablePacket::PingreqPacket(..) => {
                    let pingresp = PingrespPacket::new();
                    info!("Sending Ping response {:?}", pingresp);
                    pingresp.encode(&mut cloned_stream).unwrap();
                }
                VariablePacket::DisconnectPacket(..) => {
                    break;
                }
                _ => {
                    // Ignore other packets in pub client
                }
            }
        }
    });

    let stdin = io::stdin();
    loop {
        print!("{}: ", user_name);
        io::stdout().flush().unwrap();

        let mut line = String::new();
        // `read_line` returns Ok(0) at end of input; without this check the loop would spin
        // forever printing the prompt once stdin is closed.
        if stdin.read_line(&mut line).unwrap() == 0 {
            break;
        }

        if line.trim_end() == "" {
            continue;
        }

        let message = format!("{}: {}", user_name, line.trim_end());

        for chan in &channels {
            // Borrowing variant: avoids cloning the topic name and the message per channel.
            let publish_packet = PublishPacketRef::new(chan, QoSWithPacketIdentifier::Level0, message.as_bytes());
            let mut buf = Vec::new();
            publish_packet.encode(&mut buf).unwrap();
            stream.write_all(&buf[..]).unwrap();
        }
    }
}

================================================
FILE: examples/simple.rs
================================================
//! Encodes a PUBLISH packet and decodes it back, both as a concrete type and via the fixed header.

use std::io::Cursor;

use mqtt::packet::{PublishPacket, QoSWithPacketIdentifier, VariablePacket};
use mqtt::TopicName;
use mqtt::{Decodable, Encodable};

fn main() {
    // Create a new Publish packet
    let packet = PublishPacket::new(
        TopicName::new("mqtt/learning").unwrap(),
        QoSWithPacketIdentifier::Level2(10),
        "Hello MQTT!",
    );

    // Encode
    let mut buf = Vec::new();
    packet.encode(&mut buf).unwrap();
    println!("Encoded: {:?}", buf);

    // Decode it with known type
    let mut dec_buf = Cursor::new(&buf[..]);
    let decoded = PublishPacket::decode(&mut dec_buf).unwrap();
    println!("Decoded: {:?}", decoded);
    assert_eq!(packet, decoded);

    // Auto decode by the fixed header
    let mut dec_buf = Cursor::new(&buf[..]);
    let auto_decode = VariablePacket::decode(&mut dec_buf).unwrap();
    println!("Variable packet decode: {:?}", auto_decode);
    assert_eq!(VariablePacket::PublishPacket(packet), auto_decode);
}

================================================
FILE: examples/sub-client-async.rs
================================================
use std::env;
use std::io::Write;
use std::net;
use std::str;
use std::time::Duration;

use clap::{App, Arg};
use log::{error, info, trace};
use uuid::Uuid;

use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream; use mqtt::control::variable_header::ConnectReturnCode; use mqtt::packet::*; use mqtt::TopicFilter; use mqtt::{Decodable, Encodable, QualityOfService}; fn generate_client_id() -> String { format!("/MQTT/rust/{}", Uuid::new_v4()) } #[tokio::main] async fn main() { // configure logging env::set_var("RUST_LOG", env::var_os("RUST_LOG").unwrap_or_else(|| "info".into())); env_logger::init(); let matches = App::new("sub-client") .author("Y. T. Chung ") .arg( Arg::with_name("SERVER") .short("S") .long("server") .takes_value(true) .required(true) .help("MQTT server address (host:port)"), ) .arg( Arg::with_name("SUBSCRIBE") .short("s") .long("subscribe") .takes_value(true) .multiple(true) .required(true) .help("Channel filter to subscribe"), ) .arg( Arg::with_name("USER_NAME") .short("u") .long("username") .takes_value(true) .help("Login user name"), ) .arg( Arg::with_name("PASSWORD") .short("p") .long("password") .takes_value(true) .help("Password"), ) .arg( Arg::with_name("CLIENT_ID") .short("i") .long("client-identifier") .takes_value(true) .help("Client identifier"), ) .get_matches(); let server_addr = matches.value_of("SERVER").unwrap(); let client_id = matches .value_of("CLIENT_ID") .map(|x| x.to_owned()) .unwrap_or_else(generate_client_id); let channel_filters: Vec<(TopicFilter, QualityOfService)> = matches .values_of("SUBSCRIBE") .unwrap() .map(|c| (TopicFilter::new(c.to_string()).unwrap(), QualityOfService::Level0)) .collect(); let keep_alive = 10; info!("Connecting to {:?} ... 
", server_addr); let mut stream = net::TcpStream::connect(server_addr).unwrap(); info!("Connected!"); info!("Client identifier {:?}", client_id); let mut conn = ConnectPacket::new(client_id); conn.set_clean_session(true); conn.set_keep_alive(keep_alive); let mut buf = Vec::new(); conn.encode(&mut buf).unwrap(); stream.write_all(&buf[..]).unwrap(); let connack = ConnackPacket::decode(&mut stream).unwrap(); trace!("CONNACK {:?}", connack); if connack.connect_return_code() != ConnectReturnCode::ConnectionAccepted { panic!( "Failed to connect to server, return code {:?}", connack.connect_return_code() ); } // const CHANNEL_FILTER: &'static str = "typing-speed-test.aoeu.eu"; info!("Applying channel filters {:?} ...", channel_filters); let sub = SubscribePacket::new(10, channel_filters); let mut buf = Vec::new(); sub.encode(&mut buf).unwrap(); stream.write_all(&buf[..]).unwrap(); loop { let packet = match VariablePacket::decode(&mut stream) { Ok(pk) => pk, Err(err) => { error!("Error in receiving packet {:?}", err); continue; } }; trace!("PACKET {:?}", packet); if let VariablePacket::SubackPacket(ref ack) = packet { if ack.packet_identifier() != 10 { panic!("SUBACK packet identifier not match"); } info!("Subscribed!"); break; } } // connection made, start the async work stream.set_nonblocking(true).unwrap(); let mut stream = TcpStream::from_std(stream).unwrap(); let (mut mqtt_read, mut mqtt_write) = stream.split(); let ping_sender = async move { loop { info!("Sending PINGREQ to broker"); let pingreq_packet = PingreqPacket::new(); let mut buf = Vec::new(); pingreq_packet.encode(&mut buf).unwrap(); mqtt_write.write_all(&buf).await.unwrap(); tokio::time::sleep(Duration::from_secs(keep_alive as u64 / 2)).await; } }; let receiver = async move { while let Ok(packet) = VariablePacket::parse(&mut mqtt_read).await { trace!("PACKET {:?}", packet); match packet { VariablePacket::PingrespPacket(..) 
=> { info!("Received PINGRESP from broker .."); } VariablePacket::PublishPacket(ref publ) => { let msg = match str::from_utf8(publ.payload()) { Ok(msg) => msg, Err(err) => { error!("Failed to decode publish message {:?}", err); continue; } }; info!("PUBLISH ({}): {}", publ.topic_name(), msg); } _ => {} } } }; tokio::pin!(ping_sender); tokio::pin!(receiver); tokio::join!(ping_sender, receiver); } ================================================ FILE: examples/sub-client.rs ================================================ extern crate mqtt; #[macro_use] extern crate log; extern crate clap; extern crate env_logger; extern crate uuid; use std::env; use std::io::Write; use std::net::TcpStream; use std::str; use std::thread; use std::time::{Duration, Instant}; use clap::{App, Arg}; use uuid::Uuid; use mqtt::control::variable_header::ConnectReturnCode; use mqtt::packet::*; use mqtt::TopicFilter; use mqtt::{Decodable, Encodable, QualityOfService}; fn generate_client_id() -> String { format!("/MQTT/rust/{}", Uuid::new_v4()) } fn main() { // configure logging env::set_var("RUST_LOG", env::var_os("RUST_LOG").unwrap_or_else(|| "info".into())); env_logger::init(); let matches = App::new("sub-client") .author("Y. T. 
Chung ") .arg( Arg::with_name("SERVER") .short("S") .long("server") .takes_value(true) .required(true) .help("MQTT server address (host:port)"), ) .arg( Arg::with_name("SUBSCRIBE") .short("s") .long("subscribe") .takes_value(true) .multiple(true) .required(true) .help("Channel filter to subscribe"), ) .arg( Arg::with_name("USER_NAME") .short("u") .long("username") .takes_value(true) .help("Login user name"), ) .arg( Arg::with_name("PASSWORD") .short("p") .long("password") .takes_value(true) .help("Password"), ) .arg( Arg::with_name("CLIENT_ID") .short("i") .long("client-identifier") .takes_value(true) .help("Client identifier"), ) .get_matches(); let server_addr = matches.value_of("SERVER").unwrap(); let client_id = matches .value_of("CLIENT_ID") .map(|x| x.to_owned()) .unwrap_or_else(generate_client_id); let channel_filters: Vec<(TopicFilter, QualityOfService)> = matches .values_of("SUBSCRIBE") .unwrap() .map(|c| (TopicFilter::new(c.to_string()).unwrap(), QualityOfService::Level0)) .collect(); let keep_alive = 10; info!("Connecting to {:?} ... 
", server_addr); let mut stream = TcpStream::connect(server_addr).unwrap(); info!("Connected!"); info!("Client identifier {:?}", client_id); let mut conn = ConnectPacket::new(client_id); conn.set_clean_session(true); conn.set_keep_alive(keep_alive); let mut buf = Vec::new(); conn.encode(&mut buf).unwrap(); stream.write_all(&buf[..]).unwrap(); let connack = ConnackPacket::decode(&mut stream).unwrap(); trace!("CONNACK {:?}", connack); if connack.connect_return_code() != ConnectReturnCode::ConnectionAccepted { panic!( "Failed to connect to server, return code {:?}", connack.connect_return_code() ); } // const CHANNEL_FILTER: &'static str = "typing-speed-test.aoeu.eu"; info!("Applying channel filters {:?} ...", channel_filters); let sub = SubscribePacket::new(10, channel_filters); let mut buf = Vec::new(); sub.encode(&mut buf).unwrap(); stream.write_all(&buf[..]).unwrap(); loop { let packet = match VariablePacket::decode(&mut stream) { Ok(pk) => pk, Err(err) => { error!("Error in receiving packet {:?}", err); continue; } }; trace!("PACKET {:?}", packet); if let VariablePacket::SubackPacket(ref ack) = packet { if ack.packet_identifier() != 10 { panic!("SUBACK packet identifier not match"); } info!("Subscribed!"); break; } } let mut stream_clone = stream.try_clone().unwrap(); thread::spawn(move || { let mut last_ping_time = Instant::now(); let mut next_ping_time = last_ping_time + Duration::from_secs((keep_alive as f32 * 0.9) as u64); loop { let current_timestamp = Instant::now(); if keep_alive > 0 && current_timestamp >= next_ping_time { info!("Sending PINGREQ to broker"); let pingreq_packet = PingreqPacket::new(); let mut buf = Vec::new(); pingreq_packet.encode(&mut buf).unwrap(); stream_clone.write_all(&buf[..]).unwrap(); last_ping_time = current_timestamp; next_ping_time = last_ping_time + Duration::from_secs((keep_alive as f32 * 0.9) as u64); thread::sleep(Duration::new((keep_alive / 2) as u64, 0)); } } }); loop { let packet = match VariablePacket::decode(&mut 
stream) { Ok(pk) => pk, Err(err) => { error!("Error in receiving packet {}", err); continue; } }; trace!("PACKET {:?}", packet); match packet { VariablePacket::PingrespPacket(..) => { info!("Receiving PINGRESP from broker .."); } VariablePacket::PublishPacket(ref publ) => { let msg = match str::from_utf8(publ.payload()) { Ok(msg) => msg, Err(err) => { error!("Failed to decode publish message {:?}", err); continue; } }; info!("PUBLISH ({}): {}", publ.topic_name(), msg); } _ => {} } } } ================================================ FILE: rustfmt.toml ================================================ edition = "2018" max_width = 120 reorder_imports = true use_try_shorthand = true ================================================ FILE: src/control/fixed_header.rs ================================================ //! Fixed header in MQTT use std::io::{self, Read, Write}; use byteorder::{ReadBytesExt, WriteBytesExt}; #[cfg(feature = "tokio")] use tokio::io::{AsyncRead, AsyncReadExt}; use crate::control::packet_type::{PacketType, PacketTypeError}; use crate::{Decodable, Encodable}; /// Fixed header for each MQTT control packet /// /// Format: /// /// ```plain /// 7 3 0 /// +--------------------------+--------------------------+ /// | MQTT Control Packet Type | Flags for each type | /// +--------------------------+--------------------------+ /// | Remaining Length ... | /// +-----------------------------------------------------+ /// ``` #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub struct FixedHeader { /// Packet Type pub packet_type: PacketType, /// The Remaining Length is the number of bytes remaining within the current packet, /// including data in the variable header and the payload. The Remaining Length does /// not include the bytes used to encode the Remaining Length. 
pub remaining_length: u32, } impl FixedHeader { pub fn new(packet_type: PacketType, remaining_length: u32) -> FixedHeader { debug_assert!(remaining_length <= 0x0FFF_FFFF); FixedHeader { packet_type, remaining_length, } } #[cfg(feature = "tokio")] /// Asynchronously parse a single fixed header from an AsyncRead type, such as a network /// socket. /// /// This requires mqtt-rs to be built with `feature = "tokio"` pub async fn parse(rdr: &mut A) -> Result { let type_val = rdr.read_u8().await?; let mut remaining_len = 0; let mut i = 0; loop { let byte = rdr.read_u8().await?; remaining_len |= (u32::from(byte) & 0x7F) << (7 * i); if i >= 4 { return Err(FixedHeaderError::MalformedRemainingLength); } if byte & 0x80 == 0 { break; } else { i += 1; } } match PacketType::from_u8(type_val) { Ok(packet_type) => Ok(FixedHeader::new(packet_type, remaining_len)), Err(PacketTypeError::ReservedType(ty, _)) => Err(FixedHeaderError::ReservedType(ty, remaining_len)), Err(err) => Err(From::from(err)), } } } impl Encodable for FixedHeader { fn encode(&self, wr: &mut W) -> Result<(), io::Error> { wr.write_u8(self.packet_type.to_u8())?; let mut cur_len = self.remaining_length; loop { let mut byte = (cur_len & 0x7F) as u8; cur_len >>= 7; if cur_len > 0 { byte |= 0x80; } wr.write_u8(byte)?; if cur_len == 0 { break; } } Ok(()) } fn encoded_length(&self) -> u32 { let rem_size = if self.remaining_length >= 2_097_152 { 4 } else if self.remaining_length >= 16_384 { 3 } else if self.remaining_length >= 128 { 2 } else { 1 }; 1 + rem_size } } impl Decodable for FixedHeader { type Error = FixedHeaderError; type Cond = (); fn decode_with(rdr: &mut R, _rest: ()) -> Result { let type_val = rdr.read_u8()?; let remaining_len = { let mut cur = 0u32; for i in 0.. 
{ let byte = rdr.read_u8()?; cur |= ((byte as u32) & 0x7F) << (7 * i); if i >= 4 { return Err(FixedHeaderError::MalformedRemainingLength); } if byte & 0x80 == 0 { break; } } cur }; match PacketType::from_u8(type_val) { Ok(packet_type) => Ok(FixedHeader::new(packet_type, remaining_len)), Err(PacketTypeError::ReservedType(ty, _)) => Err(FixedHeaderError::ReservedType(ty, remaining_len)), Err(err) => Err(From::from(err)), } } } #[derive(Debug, thiserror::Error)] pub enum FixedHeaderError { #[error("malformed remaining length")] MalformedRemainingLength, #[error("reserved header ({0}, {1})")] ReservedType(u8, u32), #[error(transparent)] PacketTypeError(#[from] PacketTypeError), #[error(transparent)] IoError(#[from] io::Error), } #[cfg(test)] mod test { use super::*; use crate::control::packet_type::{ControlType, PacketType}; use crate::{Decodable, Encodable}; use std::io::Cursor; #[test] fn test_encode_fixed_header() { let header = FixedHeader::new(PacketType::with_default(ControlType::Connect), 321); let mut buf = Vec::new(); header.encode(&mut buf).unwrap(); let expected = b"\x10\xc1\x02"; assert_eq!(&expected[..], &buf[..]); } #[test] fn test_decode_fixed_header() { let stream = b"\x10\xc1\x02"; let mut cursor = Cursor::new(&stream[..]); let header = FixedHeader::decode(&mut cursor).unwrap(); assert_eq!(header.packet_type, PacketType::with_default(ControlType::Connect)); assert_eq!(header.remaining_length, 321); } #[test] #[should_panic] fn test_decode_too_long_fixed_header() { let stream = b"\x10\x80\x80\x80\x80\x02"; let mut cursor = Cursor::new(&stream[..]); FixedHeader::decode(&mut cursor).unwrap(); } } ================================================ FILE: src/control/mod.rs ================================================ //! 
Control packets pub use self::fixed_header::FixedHeader; pub use self::packet_type::{ControlType, PacketType}; pub use self::variable_header::*; pub mod fixed_header; pub mod packet_type; pub mod variable_header; ================================================ FILE: src/control/packet_type.rs ================================================ //! Packet types use crate::qos::QualityOfService; /// Packet type // INVARIANT: the high 4 bits of the byte must be a valid control type #[derive(Debug, Eq, PartialEq, Copy, Clone)] pub struct PacketType(u8); /// Defined control types #[rustfmt::skip] #[repr(u8)] #[derive(Debug, Eq, PartialEq, Copy, Clone)] pub enum ControlType { /// Client request to connect to Server Connect = value::CONNECT, /// Connect acknowledgment ConnectAcknowledgement = value::CONNACK, /// Publish message Publish = value::PUBLISH, /// Publish acknowledgment PublishAcknowledgement = value::PUBACK, /// Publish received (assured delivery part 1) PublishReceived = value::PUBREC, /// Publish release (assured delivery part 2) PublishRelease = value::PUBREL, /// Publish complete (assured delivery part 3) PublishComplete = value::PUBCOMP, /// Client subscribe request Subscribe = value::SUBSCRIBE, /// Subscribe acknowledgment SubscribeAcknowledgement = value::SUBACK, /// Unsubscribe request Unsubscribe = value::UNSUBSCRIBE, /// Unsubscribe acknowledgment UnsubscribeAcknowledgement = value::UNSUBACK, /// PING request PingRequest = value::PINGREQ, /// PING response PingResponse = value::PINGRESP, /// Client is disconnecting Disconnect = value::DISCONNECT, } impl ControlType { #[inline] fn default_flags(self) -> u8 { match self { ControlType::Connect => 0, ControlType::ConnectAcknowledgement => 0, ControlType::Publish => 0, ControlType::PublishAcknowledgement => 0, ControlType::PublishReceived => 0, ControlType::PublishRelease => 0b0010, ControlType::PublishComplete => 0, ControlType::Subscribe => 0b0010, ControlType::SubscribeAcknowledgement => 0, 
ControlType::Unsubscribe => 0b0010, ControlType::UnsubscribeAcknowledgement => 0, ControlType::PingRequest => 0, ControlType::PingResponse => 0, ControlType::Disconnect => 0, } } } impl PacketType { /// Creates a packet type. Returns None if `flags` is an invalid value for the given /// ControlType as defined by the [MQTT spec]. /// /// [MQTT spec]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Table_2.2_- pub fn new(t: ControlType, flags: u8) -> Result { let flags_ok = match t { ControlType::Publish => { let qos = (flags & 0b0110) >> 1; matches!(qos, 0 | 1 | 2) } _ => t.default_flags() == flags, }; if flags_ok { Ok(PacketType::new_unchecked(t, flags)) } else { Err(InvalidFlag(t, flags)) } } #[inline] fn new_unchecked(t: ControlType, flags: u8) -> PacketType { let byte = (t as u8) << 4 | (flags & 0x0F); #[allow(unused_unsafe)] unsafe { // SAFETY: just constructed from a valid ControlType PacketType(byte) } } /// Creates a packet type with default flags /// /// #[inline] pub fn with_default(t: ControlType) -> PacketType { let flags = t.default_flags(); PacketType::new_unchecked(t, flags) } pub(crate) fn publish(qos: QualityOfService) -> PacketType { PacketType::new_unchecked(ControlType::Publish, (qos as u8) << 1) } #[inline] pub(crate) fn update_flags(&mut self, upd: impl FnOnce(u8) -> u8) { let flags = upd(self.flags()); self.0 = (self.0 & !0x0F) | (flags & 0x0F) } /// To code #[inline] pub fn to_u8(self) -> u8 { self.0 } /// From code pub fn from_u8(val: u8) -> Result { let type_val = val >> 4; let flags = val & 0x0F; let control_type = get_control_type(type_val).ok_or(PacketTypeError::ReservedType(type_val, flags))?; Ok(PacketType::new(control_type, flags)?) 
} #[inline] pub fn control_type(self) -> ControlType { get_control_type(self.0 >> 4).unwrap_or_else(|| { // SAFETY: this is maintained by the invariant for PacketType unsafe { std::hint::unreachable_unchecked() } }) } #[inline] pub fn flags(self) -> u8 { self.0 & 0x0F } } #[inline] fn get_control_type(val: u8) -> Option { let typ = match val { value::CONNECT => ControlType::Connect, value::CONNACK => ControlType::ConnectAcknowledgement, value::PUBLISH => ControlType::Publish, value::PUBACK => ControlType::PublishAcknowledgement, value::PUBREC => ControlType::PublishReceived, value::PUBREL => ControlType::PublishRelease, value::PUBCOMP => ControlType::PublishComplete, value::SUBSCRIBE => ControlType::Subscribe, value::SUBACK => ControlType::SubscribeAcknowledgement, value::UNSUBSCRIBE => ControlType::Unsubscribe, value::UNSUBACK => ControlType::UnsubscribeAcknowledgement, value::PINGREQ => ControlType::PingRequest, value::PINGRESP => ControlType::PingResponse, value::DISCONNECT => ControlType::Disconnect, _ => return None, }; Some(typ) } /// Parsing packet type errors #[derive(Debug, thiserror::Error)] pub enum PacketTypeError { #[error("reserved type {0:?} (flags {1:#X})")] ReservedType(u8, u8), #[error(transparent)] InvalidFlag(#[from] InvalidFlag), } #[derive(Debug, thiserror::Error)] #[error("invalid flag for {0:?} ({1:#X})")] pub struct InvalidFlag(pub ControlType, pub u8); #[rustfmt::skip] mod value { pub const CONNECT: u8 = 1; pub const CONNACK: u8 = 2; pub const PUBLISH: u8 = 3; pub const PUBACK: u8 = 4; pub const PUBREC: u8 = 5; pub const PUBREL: u8 = 6; pub const PUBCOMP: u8 = 7; pub const SUBSCRIBE: u8 = 8; pub const SUBACK: u8 = 9; pub const UNSUBSCRIBE: u8 = 10; pub const UNSUBACK: u8 = 11; pub const PINGREQ: u8 = 12; pub const PINGRESP: u8 = 13; pub const DISCONNECT: u8 = 14; } ================================================ FILE: src/control/variable_header/connect_ack_flags.rs ================================================ use std::io::{self, 
Read, Write};

use byteorder::{ReadBytesExt, WriteBytesExt};

use crate::control::variable_header::VariableHeaderError;
use crate::{Decodable, Encodable};

/// Flags in `CONNACK` packet
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub struct ConnackFlags {
    pub session_present: bool,
}

impl ConnackFlags {
    pub fn empty() -> ConnackFlags {
        ConnackFlags { session_present: false }
    }
}

impl Encodable for ConnackFlags {
    fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
        // Only bit 0 carries data; every other bit is written as 0.
        let code = self.session_present as u8;
        writer.write_u8(code)
    }

    fn encoded_length(&self) -> u32 {
        1
    }
}

impl Decodable for ConnackFlags {
    type Error = VariableHeaderError;
    type Cond = ();

    /// Decodes the single acknowledge-flags byte.
    ///
    /// # Errors
    ///
    /// `InvalidReservedFlag` if any bit other than bit 0 is set; an I/O error if the byte
    /// cannot be read.
    fn decode_with<R: Read>(reader: &mut R, _rest: ()) -> Result<ConnackFlags, VariableHeaderError> {
        let code = reader.read_u8()?;
        if code & !1 != 0 {
            return Err(VariableHeaderError::InvalidReservedFlag);
        }

        Ok(ConnackFlags {
            session_present: code == 1,
        })
    }
}

================================================
FILE: src/control/variable_header/connect_flags.rs
================================================
use std::io::{self, Read, Write};

use byteorder::{ReadBytesExt, WriteBytesExt};

use crate::control::variable_header::VariableHeaderError;
use crate::{Decodable, Encodable};

/// Flags for `CONNECT` packet
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub struct ConnectFlags {
    pub user_name: bool,
    pub password: bool,
    pub will_retain: bool,
    pub will_qos: u8,
    pub will_flag: bool,
    pub clean_session: bool,
    // We never use this, but must decode because brokers must verify it's zero per [MQTT-3.1.2-3]
    pub reserved: bool,
}

impl ConnectFlags {
    pub fn empty() -> ConnectFlags {
        ConnectFlags {
            user_name: false,
            password: false,
            will_retain: false,
            will_qos: 0,
            will_flag: false,
            clean_session: false,
            reserved: false,
        }
    }
}

impl Encodable for ConnectFlags {
    /// Packs the flags into one byte: bit 7 user name, bit 6 password, bit 5 will retain,
    /// bits 4-3 will QoS, bit 2 will flag, bit 1 clean session. Bit 0 (reserved) is always 0.
    #[rustfmt::skip]
    fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
        // `will_qos` is a public `u8`: mask it to its 2-bit field so an out-of-range value
        // cannot spill into the will-retain / password / user-name bits.
        let code = ((self.user_name as u8) << 7)
                 | ((self.password as u8) << 6)
                 | ((self.will_retain as u8) << 5)
                 | ((self.will_qos & 0b11) << 3)
                 | ((self.will_flag as u8) << 2)
                 | ((self.clean_session as u8) << 1);

        writer.write_u8(code)
    }

    fn encoded_length(&self) -> u32 {
        1
    }
}

impl Decodable for ConnectFlags {
    type Error = VariableHeaderError;
    type Cond = ();

    /// Decodes the connect-flags byte.
    ///
    /// # Errors
    ///
    /// `InvalidReservedFlag` if the reserved bit 0 is set; an I/O error if the byte cannot be
    /// read.
    fn decode_with<R: Read>(reader: &mut R, _rest: ()) -> Result<ConnectFlags, VariableHeaderError> {
        let code = reader.read_u8()?;
        if code & 1 != 0 {
            return Err(VariableHeaderError::InvalidReservedFlag);
        }

        Ok(ConnectFlags {
            user_name: (code & 0b1000_0000) != 0,
            password: (code & 0b0100_0000) != 0,
            will_retain: (code & 0b0010_0000) != 0,
            will_qos: (code & 0b0001_1000) >> 3,
            will_flag: (code & 0b0000_0100) != 0,
            clean_session: (code & 0b0000_0010) != 0,
            reserved: (code & 0b0000_0001) != 0,
        })
    }
}

#[cfg(test)]
mod test {
    use super::*;

    use crate::{Decodable, Encodable};

    #[test]
    fn test_connect_flags_roundtrip() {
        let flags = ConnectFlags {
            user_name: true,
            password: true,
            will_retain: false,
            will_qos: 2,
            will_flag: true,
            clean_session: true,
            reserved: false,
        };
        let mut buf = Vec::new();
        flags.encode(&mut buf).unwrap();
        assert_eq!(buf, [0b1101_0110]);

        let decoded = ConnectFlags::decode(&mut &buf[..]).unwrap();
        assert_eq!(decoded, flags);
    }

    #[test]
    fn test_connect_flags_masks_will_qos() {
        let mut flags = ConnectFlags::empty();
        // Out of range: must not leak into the will-retain bit.
        flags.will_qos = 4;
        let mut buf = Vec::new();
        flags.encode(&mut buf).unwrap();
        assert_eq!(buf, [0]);
    }
}

================================================
FILE: src/control/variable_header/connect_ret_code.rs
================================================
use std::io::{self, Read, Write};

use byteorder::{ReadBytesExt, WriteBytesExt};

use crate::control::variable_header::VariableHeaderError;
use crate::{Decodable, Encodable};

pub const CONNECTION_ACCEPTED: u8 = 0x00;
pub const UNACCEPTABLE_PROTOCOL_VERSION: u8 = 0x01;
pub const IDENTIFIER_REJECTED: u8 = 0x02;
pub const SERVICE_UNAVAILABLE: u8 = 0x03;
pub const BAD_USER_NAME_OR_PASSWORD: u8 = 0x04;
pub const NOT_AUTHORIZED: u8 = 0x05;

/// Return code for `CONNACK` packet
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub enum ConnectReturnCode {
    ConnectionAccepted,
    UnacceptableProtocolVersion,
    IdentifierRejected,
    ServiceUnavailable,
    BadUserNameOrPassword,
    NotAuthorized,
    /// Any code not listed above, kept verbatim.
    Reserved(u8),
}

impl ConnectReturnCode {
    /// Get the code
    pub fn to_u8(self) -> u8 {
        match self {
            ConnectReturnCode::ConnectionAccepted => CONNECTION_ACCEPTED,
            ConnectReturnCode::UnacceptableProtocolVersion => UNACCEPTABLE_PROTOCOL_VERSION,
            ConnectReturnCode::IdentifierRejected => IDENTIFIER_REJECTED,
            ConnectReturnCode::ServiceUnavailable => SERVICE_UNAVAILABLE,
            ConnectReturnCode::BadUserNameOrPassword => BAD_USER_NAME_OR_PASSWORD,
            ConnectReturnCode::NotAuthorized => NOT_AUTHORIZED,
            ConnectReturnCode::Reserved(r) => r,
        }
    }

    /// Create `ConnectReturnCode` from code
    pub fn from_u8(code: u8) -> ConnectReturnCode {
        match code {
            CONNECTION_ACCEPTED => ConnectReturnCode::ConnectionAccepted,
            UNACCEPTABLE_PROTOCOL_VERSION => ConnectReturnCode::UnacceptableProtocolVersion,
            IDENTIFIER_REJECTED => ConnectReturnCode::IdentifierRejected,
            SERVICE_UNAVAILABLE => ConnectReturnCode::ServiceUnavailable,
            BAD_USER_NAME_OR_PASSWORD => ConnectReturnCode::BadUserNameOrPassword,
            NOT_AUTHORIZED => ConnectReturnCode::NotAuthorized,
            _ => ConnectReturnCode::Reserved(code),
        }
    }
}

impl Encodable for ConnectReturnCode {
    fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
        writer.write_u8(self.to_u8())
    }

    fn encoded_length(&self) -> u32 {
        1
    }
}

impl Decodable for ConnectReturnCode {
    type Error = VariableHeaderError;
    type Cond = ();

    /// Decodes one byte; unknown values become `Reserved` rather than an error.
    fn decode_with<R: Read>(reader: &mut R, _rest: ()) -> Result<ConnectReturnCode, VariableHeaderError> {
        reader.read_u8().map(ConnectReturnCode::from_u8).map_err(From::from)
    }
}

================================================
FILE: src/control/variable_header/keep_alive.rs
================================================
use std::io::{self, Read, Write};

use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};

use crate::control::variable_header::VariableHeaderError;
use crate::{Decodable, Encodable};

/// Keep alive time interval
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub struct KeepAlive(pub u16);

impl Encodable for KeepAlive {
    fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
        writer.write_u16::<BigEndian>(self.0)
    }

    fn encoded_length(&self) -> u32 {
        2
    }
}

impl Decodable for KeepAlive {
    type Error = VariableHeaderError;
    type Cond = ();

    fn decode_with<R: Read>(reader: &mut R, _rest: ()) -> Result<KeepAlive, VariableHeaderError> {
        reader.read_u16::<BigEndian>().map(KeepAlive).map_err(From::from)
    }
}

================================================
FILE: src/control/variable_header/mod.rs
================================================
//!
Variable header in MQTT use std::io; use std::string::FromUtf8Error; use crate::topic_name::{TopicNameDecodeError, TopicNameError}; pub use self::connect_ack_flags::ConnackFlags; pub use self::connect_flags::ConnectFlags; pub use self::connect_ret_code::ConnectReturnCode; pub use self::keep_alive::KeepAlive; pub use self::packet_identifier::PacketIdentifier; pub use self::protocol_level::ProtocolLevel; pub use self::protocol_name::ProtocolName; pub use self::topic_name::TopicNameHeader; mod connect_ack_flags; mod connect_flags; mod connect_ret_code; mod keep_alive; mod packet_identifier; pub mod protocol_level; mod protocol_name; mod topic_name; /// Errors while decoding variable header #[derive(Debug, thiserror::Error)] pub enum VariableHeaderError { #[error(transparent)] IoError(#[from] io::Error), #[error("invalid reserved flags")] InvalidReservedFlag, #[error(transparent)] FromUtf8Error(#[from] FromUtf8Error), #[error(transparent)] TopicNameError(#[from] TopicNameError), #[error("invalid protocol version")] InvalidProtocolVersion, } impl From for VariableHeaderError { fn from(err: TopicNameDecodeError) -> VariableHeaderError { match err { TopicNameDecodeError::IoError(e) => Self::IoError(e), TopicNameDecodeError::InvalidTopicName(e) => Self::TopicNameError(e), } } } ================================================ FILE: src/control/variable_header/packet_identifier.rs ================================================ use std::io::{self, Read, Write}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use crate::control::variable_header::VariableHeaderError; use crate::{Decodable, Encodable}; /// Packet identifier #[derive(Debug, Eq, PartialEq, Copy, Clone)] pub struct PacketIdentifier(pub u16); impl Encodable for PacketIdentifier { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { writer.write_u16::(self.0) } fn encoded_length(&self) -> u32 { 2 } } impl Decodable for PacketIdentifier { type Error = VariableHeaderError; type Cond = (); fn 
decode_with(reader: &mut R, _rest: ()) -> Result { reader.read_u16::().map(PacketIdentifier).map_err(From::from) } } ================================================ FILE: src/control/variable_header/protocol_level.rs ================================================ //! Protocol level header use std::io::{self, Read, Write}; use byteorder::{ReadBytesExt, WriteBytesExt}; use crate::control::variable_header::VariableHeaderError; use crate::{Decodable, Encodable}; pub const SPEC_3_1_0: u8 = 0x03; pub const SPEC_3_1_1: u8 = 0x04; pub const SPEC_5_0: u8 = 0x05; /// Protocol level in MQTT (`0x04` in v3.1.1) #[derive(Debug, Eq, PartialEq, Copy, Clone)] #[repr(u8)] pub enum ProtocolLevel { Version310 = SPEC_3_1_0, Version311 = SPEC_3_1_1, Version50 = SPEC_5_0, } impl Encodable for ProtocolLevel { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { writer.write_u8(*self as u8) } fn encoded_length(&self) -> u32 { 1 } } impl Decodable for ProtocolLevel { type Error = VariableHeaderError; type Cond = (); fn decode_with(reader: &mut R, _rest: ()) -> Result { reader .read_u8() .map_err(From::from) .map(ProtocolLevel::from_u8) .and_then(|x| x.ok_or(VariableHeaderError::InvalidProtocolVersion)) } } impl ProtocolLevel { pub fn from_u8(n: u8) -> Option { match n { SPEC_3_1_0 => Some(ProtocolLevel::Version310), SPEC_3_1_1 => Some(ProtocolLevel::Version311), SPEC_5_0 => Some(ProtocolLevel::Version50), _ => None, } } } ================================================ FILE: src/control/variable_header/protocol_name.rs ================================================ use std::io::{self, Read, Write}; use crate::control::variable_header::VariableHeaderError; use crate::{Decodable, Encodable}; /// Protocol name in variable header /// /// # Example /// /// ```plain /// 7 3 0 /// +--------------------------+--------------------------+ /// | Length MSB (0) | /// | Length LSB (4) | /// | 0100 | 1101 | 'M' /// | 0101 | 0001 | 'Q' /// | 0101 | 0100 | 'T' /// | 0101 | 0100 | 'T' /// 
+--------------------------+--------------------------+ /// ``` #[derive(Debug, Eq, PartialEq, Clone)] pub struct ProtocolName(pub String); impl Encodable for ProtocolName { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { (&self.0[..]).encode(writer) } fn encoded_length(&self) -> u32 { (&self.0[..]).encoded_length() } } impl Decodable for ProtocolName { type Error = VariableHeaderError; type Cond = (); fn decode_with(reader: &mut R, _rest: ()) -> Result { Ok(ProtocolName(Decodable::decode(reader)?)) } } ================================================ FILE: src/control/variable_header/topic_name.rs ================================================ use std::io::{self, Read, Write}; use crate::control::variable_header::VariableHeaderError; use crate::topic_name::TopicName; use crate::{Decodable, Encodable}; /// Topic name wrapper #[derive(Debug, Eq, PartialEq, Clone)] pub struct TopicNameHeader(TopicName); impl TopicNameHeader { pub fn new(topic_name: String) -> Result { match TopicName::new(topic_name) { Ok(h) => Ok(TopicNameHeader(h)), Err(err) => Err(VariableHeaderError::TopicNameError(err)), } } } impl From for TopicName { fn from(hdr: TopicNameHeader) -> Self { hdr.0 } } impl Encodable for TopicNameHeader { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { (&self.0[..]).encode(writer) } fn encoded_length(&self) -> u32 { (&self.0[..]).encoded_length() } } impl Decodable for TopicNameHeader { type Error = VariableHeaderError; type Cond = (); fn decode_with(reader: &mut R, _rest: ()) -> Result { TopicNameHeader::new(Decodable::decode(reader)?) } } ================================================ FILE: src/encodable.rs ================================================ //! 
Encodable traits use std::convert::Infallible; use std::error::Error; use std::io::{self, Read, Write}; use std::marker::Sized; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; /// Methods for encoding an Object to bytes according to MQTT specification pub trait Encodable { /// Encodes to writer fn encode(&self, writer: &mut W) -> io::Result<()>; /// Length of bytes after encoded fn encoded_length(&self) -> u32; } // impl Encodable for &T { // fn encode(&self, writer: &mut W) -> io::Result<()> { // (**self).encode(writer) // } // fn encoded_length(&self) -> u32 { // (**self).encoded_length() // } // } impl Encodable for Option { fn encode(&self, writer: &mut W) -> io::Result<()> { if let Some(this) = self { this.encode(writer)? } Ok(()) } fn encoded_length(&self) -> u32 { self.as_ref().map_or(0, |x| x.encoded_length()) } } /// Methods for decoding bytes to an Object according to MQTT specification pub trait Decodable: Sized { type Error: Error; type Cond; /// Decodes object from reader fn decode(reader: &mut R) -> Result where Self::Cond: Default, { Self::decode_with(reader, Default::default()) } /// Decodes object with additional data (or hints) fn decode_with(reader: &mut R, cond: Self::Cond) -> Result; } impl<'a> Encodable for &'a str { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { assert!(self.as_bytes().len() <= u16::max_value() as usize); writer .write_u16::(self.as_bytes().len() as u16) .and_then(|_| writer.write_all(self.as_bytes())) } fn encoded_length(&self) -> u32 { 2 + self.as_bytes().len() as u32 } } impl<'a> Encodable for &'a [u8] { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { writer.write_all(self) } fn encoded_length(&self) -> u32 { self.len() as u32 } } impl Encodable for String { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { (&self[..]).encode(writer) } fn encoded_length(&self) -> u32 { (&self[..]).encoded_length() } } impl Decodable for String { type Error = io::Error; type Cond = (); fn 
decode_with(reader: &mut R, _rest: ()) -> Result { let VarBytes(buf) = VarBytes::decode(reader)?; String::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) } } impl Encodable for Vec { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { (&self[..]).encode(writer) } fn encoded_length(&self) -> u32 { (&self[..]).encoded_length() } } impl Decodable for Vec { type Error = io::Error; type Cond = Option; fn decode_with(reader: &mut R, length: Option) -> Result, io::Error> { match length { Some(length) => { let mut buf = Vec::with_capacity(length as usize); reader.take(length.into()).read_to_end(&mut buf)?; Ok(buf) } None => { let mut buf = Vec::new(); reader.read_to_end(&mut buf)?; Ok(buf) } } } } impl Encodable for () { fn encode(&self, _: &mut W) -> Result<(), io::Error> { Ok(()) } fn encoded_length(&self) -> u32 { 0 } } impl Decodable for () { type Error = Infallible; type Cond = (); fn decode_with(_: &mut R, _: ()) -> Result<(), Self::Error> { Ok(()) } } /// Bytes that encoded with length #[derive(Debug, Eq, PartialEq, Clone)] pub struct VarBytes(pub Vec); impl Encodable for VarBytes { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { assert!(self.0.len() <= u16::max_value() as usize); let len = self.0.len() as u16; writer.write_u16::(len)?; writer.write_all(&self.0)?; Ok(()) } fn encoded_length(&self) -> u32 { 2 + self.0.len() as u32 } } impl Decodable for VarBytes { type Error = io::Error; type Cond = (); fn decode_with(reader: &mut R, _: ()) -> Result { let length = reader.read_u16::()?; let mut buf = Vec::with_capacity(length as usize); reader.take(length.into()).read_to_end(&mut buf)?; Ok(VarBytes(buf)) } } #[cfg(test)] mod test { use super::*; use std::io::Cursor; #[test] fn varbyte_encode() { let test_var = vec![0, 1, 2, 3, 4, 5]; let bytes = VarBytes(test_var); assert_eq!(bytes.encoded_length() as usize, 2 + 6); let mut buf = Vec::new(); bytes.encode(&mut buf).unwrap(); assert_eq!(&buf, &[0, 6, 0, 1, 2, 3, 4, 5]); let 
mut reader = Cursor::new(buf); let decoded = VarBytes::decode(&mut reader).unwrap(); assert_eq!(decoded, bytes); } } ================================================ FILE: src/lib.rs ================================================ //! MQTT protocol utilities library //! //! Strictly implements protocol of [MQTT v3.1.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html) //! //! ## Usage //! //! ```rust //! use std::io::Cursor; //! //! use mqtt::{Encodable, Decodable}; //! use mqtt::packet::{VariablePacket, PublishPacket, QoSWithPacketIdentifier}; //! use mqtt::TopicName; //! //! // Create a new Publish packet //! let packet = PublishPacket::new(TopicName::new("mqtt/learning").unwrap(), //! QoSWithPacketIdentifier::Level2(10), //! b"Hello MQTT!".to_vec()); //! //! // Encode //! let mut buf = Vec::new(); //! packet.encode(&mut buf).unwrap(); //! println!("Encoded: {:?}", buf); //! //! // Decode it with known type //! let mut dec_buf = Cursor::new(&buf[..]); //! let decoded = PublishPacket::decode(&mut dec_buf).unwrap(); //! println!("Decoded: {:?}", decoded); //! assert_eq!(packet, decoded); //! //! // Auto decode by the fixed header //! let mut dec_buf = Cursor::new(&buf[..]); //! let auto_decode = VariablePacket::decode(&mut dec_buf).unwrap(); //! println!("Variable packet decode: {:?}", auto_decode); //! assert_eq!(VariablePacket::PublishPacket(packet), auto_decode); //! ``` pub use self::encodable::{Decodable, Encodable}; pub use self::qos::QualityOfService; pub use self::topic_filter::{TopicFilter, TopicFilterRef}; pub use self::topic_name::{TopicName, TopicNameRef}; pub mod control; pub mod encodable; pub mod packet; pub mod qos; pub mod topic_filter; pub mod topic_name; ================================================ FILE: src/packet/connack.rs ================================================ //! 
CONNACK

use std::io::Read;

use crate::control::variable_header::{ConnackFlags, ConnectReturnCode};
use crate::control::{ControlType, FixedHeader, PacketType};
use crate::packet::{DecodablePacket, PacketError};
use crate::Decodable;

/// `CONNACK` packet
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct ConnackPacket {
    fixed_header: FixedHeader,
    flags: ConnackFlags,
    ret_code: ConnectReturnCode,
}

encodable_packet!(ConnackPacket(flags, ret_code));

impl ConnackPacket {
    /// Creates a `CONNACK`. The remaining length is fixed at 2 (flags byte + return code byte).
    pub fn new(session_present: bool, ret_code: ConnectReturnCode) -> ConnackPacket {
        ConnackPacket {
            fixed_header: FixedHeader::new(PacketType::with_default(ControlType::ConnectAcknowledgement), 2),
            flags: ConnackFlags { session_present },
            ret_code,
        }
    }

    pub fn connack_flags(&self) -> ConnackFlags {
        self.flags
    }

    pub fn connect_return_code(&self) -> ConnectReturnCode {
        self.ret_code
    }
}

impl DecodablePacket for ConnackPacket {
    type DecodePacketError = std::convert::Infallible;

    /// Reads the flags byte then the return-code byte that follow `fixed_header`.
    fn decode_packet<R: Read>(reader: &mut R, fixed_header: FixedHeader) -> Result<Self, PacketError<Self>> {
        let flags: ConnackFlags = Decodable::decode(reader)?;
        let ret_code: ConnectReturnCode = Decodable::decode(reader)?;
        Ok(ConnackPacket {
            fixed_header,
            flags,
            ret_code,
        })
    }
}

#[cfg(test)]
mod test {
    use super::*;

    use std::io::Cursor;

    use crate::control::variable_header::ConnectReturnCode;
    use crate::{Decodable, Encodable};

    #[test]
    pub fn test_connack_packet_basic() {
        let packet = ConnackPacket::new(false, ConnectReturnCode::IdentifierRejected);

        let mut buf = Vec::new();
        packet.encode(&mut buf).unwrap();

        let mut decode_buf = Cursor::new(buf);
        let decoded = ConnackPacket::decode(&mut decode_buf).unwrap();

        assert_eq!(packet, decoded);
    }

    #[test]
    pub fn test_connack_packet_session_present() {
        let packet = ConnackPacket::new(true, ConnectReturnCode::ConnectionAccepted);

        let mut buf = Vec::new();
        packet.encode(&mut buf).unwrap();
        assert_eq!(&buf[..], &[0x20, 0x02, 0x01, 0x00]);

        let decoded = ConnackPacket::decode(&mut Cursor::new(buf)).unwrap();
        assert!(decoded.connack_flags().session_present);
        assert_eq!(decoded.connect_return_code(), ConnectReturnCode::ConnectionAccepted);
    }
}

================================================
FILE: src/packet/connect.rs
================================================
//!
CONNECT use std::io::{self, Read, Write}; use crate::control::variable_header::protocol_level::SPEC_3_1_1; use crate::control::variable_header::{ConnectFlags, KeepAlive, ProtocolLevel, ProtocolName, VariableHeaderError}; use crate::control::{ControlType, FixedHeader, PacketType}; use crate::encodable::VarBytes; use crate::packet::{DecodablePacket, PacketError}; use crate::topic_name::{TopicName, TopicNameDecodeError, TopicNameError}; use crate::{Decodable, Encodable}; /// `CONNECT` packet #[derive(Debug, Eq, PartialEq, Clone)] pub struct ConnectPacket { fixed_header: FixedHeader, protocol_name: ProtocolName, protocol_level: ProtocolLevel, flags: ConnectFlags, keep_alive: KeepAlive, payload: ConnectPacketPayload, } encodable_packet!(ConnectPacket(protocol_name, protocol_level, flags, keep_alive, payload)); impl ConnectPacket { pub fn new(client_identifier: C) -> ConnectPacket where C: Into, { ConnectPacket::with_level("MQTT", client_identifier, SPEC_3_1_1).expect("SPEC_3_1_1 should always be valid") } pub fn with_level(protoname: P, client_identifier: C, level: u8) -> Result where P: Into, C: Into, { let protocol_level = ProtocolLevel::from_u8(level).ok_or(VariableHeaderError::InvalidProtocolVersion)?; let mut pk = ConnectPacket { fixed_header: FixedHeader::new(PacketType::with_default(ControlType::Connect), 0), protocol_name: ProtocolName(protoname.into()), protocol_level, flags: ConnectFlags::empty(), keep_alive: KeepAlive(0), payload: ConnectPacketPayload::new(client_identifier.into()), }; pk.fix_header_remaining_len(); Ok(pk) } pub fn set_keep_alive(&mut self, keep_alive: u16) { self.keep_alive = KeepAlive(keep_alive); } pub fn set_user_name(&mut self, name: Option) { self.flags.user_name = name.is_some(); self.payload.user_name = name; self.fix_header_remaining_len(); } pub fn set_will(&mut self, topic_message: Option<(TopicName, Vec)>) { self.flags.will_flag = topic_message.is_some(); self.payload.will = topic_message.map(|(t, m)| (t, VarBytes(m))); 
self.fix_header_remaining_len(); } pub fn set_password(&mut self, password: Option) { self.flags.password = password.is_some(); self.payload.password = password; self.fix_header_remaining_len(); } pub fn set_client_identifier>(&mut self, id: I) { self.payload.client_identifier = id.into(); self.fix_header_remaining_len(); } pub fn set_will_retain(&mut self, will_retain: bool) { self.flags.will_retain = will_retain; } pub fn set_will_qos(&mut self, will_qos: u8) { assert!(will_qos <= 2); self.flags.will_qos = will_qos; } pub fn set_clean_session(&mut self, clean_session: bool) { self.flags.clean_session = clean_session; } pub fn user_name(&self) -> Option<&str> { self.payload.user_name.as_ref().map(|x| &x[..]) } pub fn password(&self) -> Option<&str> { self.payload.password.as_ref().map(|x| &x[..]) } pub fn will(&self) -> Option<(&str, &[u8])> { self.payload.will.as_ref().map(|(topic, msg)| (&topic[..], &*msg.0)) } pub fn will_retain(&self) -> bool { self.flags.will_retain } pub fn will_qos(&self) -> u8 { self.flags.will_qos } pub fn client_identifier(&self) -> &str { &self.payload.client_identifier[..] } pub fn protocol_name(&self) -> &str { &self.protocol_name.0 } pub fn protocol_level(&self) -> ProtocolLevel { self.protocol_level } pub fn clean_session(&self) -> bool { self.flags.clean_session } pub fn keep_alive(&self) -> u16 { self.keep_alive.0 } /// Read back the "reserved" Connect flag bit 0. For compliant implementations this should /// always be false. 
pub fn reserved_flag(&self) -> bool { self.flags.reserved } } impl DecodablePacket for ConnectPacket { type DecodePacketError = ConnectPacketError; fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { let protoname: ProtocolName = Decodable::decode(reader)?; let protocol_level: ProtocolLevel = Decodable::decode(reader)?; let flags: ConnectFlags = Decodable::decode(reader)?; let keep_alive: KeepAlive = Decodable::decode(reader)?; let payload: ConnectPacketPayload = Decodable::decode_with(reader, Some(flags)).map_err(PacketError::PayloadError)?; Ok(ConnectPacket { fixed_header, protocol_name: protoname, protocol_level, flags, keep_alive, payload, }) } } /// Payloads for connect packet #[derive(Debug, Eq, PartialEq, Clone)] struct ConnectPacketPayload { client_identifier: String, will: Option<(TopicName, VarBytes)>, user_name: Option, password: Option, } impl ConnectPacketPayload { pub fn new(client_identifier: String) -> ConnectPacketPayload { ConnectPacketPayload { client_identifier, will: None, user_name: None, password: None, } } } impl Encodable for ConnectPacketPayload { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { self.client_identifier.encode(writer)?; if let Some((will_topic, will_message)) = &self.will { will_topic.encode(writer)?; will_message.encode(writer)?; } if let Some(ref user_name) = self.user_name { user_name.encode(writer)?; } if let Some(ref password) = self.password { password.encode(writer)?; } Ok(()) } fn encoded_length(&self) -> u32 { self.client_identifier.encoded_length() + self .will .as_ref() .map(|(a, b)| a.encoded_length() + b.encoded_length()) .unwrap_or(0) + self.user_name.as_ref().map(|t| t.encoded_length()).unwrap_or(0) + self.password.as_ref().map(|t| t.encoded_length()).unwrap_or(0) } } impl Decodable for ConnectPacketPayload { type Error = ConnectPacketError; type Cond = Option; fn decode_with( reader: &mut R, rest: Option, ) -> Result { let mut need_will = false; let mut need_user_name = false; 
let mut need_password = false; if let Some(r) = rest { need_will = r.will_flag; need_user_name = r.user_name; need_password = r.password; } let ident = String::decode(reader)?; let will = if need_will { let topic = TopicName::decode(reader).map_err(|e| match e { TopicNameDecodeError::IoError(e) => ConnectPacketError::from(e), TopicNameDecodeError::InvalidTopicName(e) => e.into(), })?; let msg = VarBytes::decode(reader)?; Some((topic, msg)) } else { None }; let uname = if need_user_name { Some(String::decode(reader)?) } else { None }; let pwd = if need_password { Some(String::decode(reader)?) } else { None }; Ok(ConnectPacketPayload { client_identifier: ident, will, user_name: uname, password: pwd, }) } } #[derive(Debug, thiserror::Error)] #[error(transparent)] pub enum ConnectPacketError { IoError(#[from] io::Error), TopicNameError(#[from] TopicNameError), } #[cfg(test)] mod test { use super::*; use std::io::Cursor; use crate::{Decodable, Encodable}; #[test] fn test_connect_packet_encode_basic() { let packet = ConnectPacket::new("12345".to_owned()); let expected = b"\x10\x11\x00\x04MQTT\x04\x00\x00\x00\x00\x0512345"; let mut buf = Vec::new(); packet.encode(&mut buf).unwrap(); assert_eq!(&expected[..], &buf[..]); } #[test] fn test_connect_packet_decode_basic() { let encoded_data = b"\x10\x11\x00\x04MQTT\x04\x00\x00\x00\x00\x0512345"; let mut buf = Cursor::new(&encoded_data[..]); let packet = ConnectPacket::decode(&mut buf).unwrap(); let expected = ConnectPacket::new("12345".to_owned()); assert_eq!(expected, packet); } #[test] fn test_connect_packet_user_name() { let mut packet = ConnectPacket::new("12345".to_owned()); packet.set_user_name(Some("mqtt_player".to_owned())); let mut buf = Vec::new(); packet.encode(&mut buf).unwrap(); let mut decode_buf = Cursor::new(buf); let decoded_packet = ConnectPacket::decode(&mut decode_buf).unwrap(); assert_eq!(packet, decoded_packet); } } ================================================ FILE: src/packet/disconnect.rs 
================================================
//! DISCONNECT

use std::io::Read;

use crate::control::{ControlType, FixedHeader, PacketType};
use crate::packet::{DecodablePacket, PacketError};

/// `DISCONNECT` packet
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct DisconnectPacket {
    fixed_header: FixedHeader,
}

encodable_packet!(DisconnectPacket());

impl DisconnectPacket {
    pub fn new() -> DisconnectPacket {
        DisconnectPacket {
            fixed_header: FixedHeader::new(PacketType::with_default(ControlType::Disconnect), 0),
        }
    }
}

impl Default for DisconnectPacket {
    fn default() -> DisconnectPacket {
        DisconnectPacket::new()
    }
}

impl DecodablePacket for DisconnectPacket {
    type DecodePacketError = std::convert::Infallible;

    /// `DISCONNECT` has no variable header or payload, so nothing is read from `_reader`.
    fn decode_packet<R: Read>(_reader: &mut R, fixed_header: FixedHeader) -> Result<Self, PacketError<Self>> {
        Ok(DisconnectPacket { fixed_header })
    }
}

================================================
FILE: src/packet/mod.rs
================================================
//! Specific packets

use std::error::Error;
use std::fmt::{self, Debug};
use std::io::{self, Read, Write};

#[cfg(feature = "tokio")]
use tokio::io::{AsyncRead, AsyncReadExt};

use crate::control::fixed_header::FixedHeaderError;
use crate::control::variable_header::VariableHeaderError;
use crate::control::ControlType;
use crate::control::FixedHeader;
use crate::topic_name::{TopicNameDecodeError, TopicNameError};
use crate::{Decodable, Encodable};

/// Implements `EncodablePacket` for a packet struct by encoding the listed fields in order
/// after the fixed header, and adds `fix_header_remaining_len` to recompute the header's
/// remaining length from those fields.
macro_rules! encodable_packet {
    ($typ:ident($($field:ident),* $(,)?)) => {
        impl $crate::packet::EncodablePacket for $typ {
            fn fixed_header(&self) -> &$crate::control::fixed_header::FixedHeader {
                &self.fixed_header
            }

            #[allow(unused)]
            fn encode_packet<W: ::std::io::Write>(&self, writer: &mut W) -> ::std::io::Result<()> {
                $($crate::encodable::Encodable::encode(&self.$field, writer)?;)*
                Ok(())
            }

            fn encoded_packet_length(&self) -> u32 {
                $($crate::encodable::Encodable::encoded_length(&self.$field) +)*
                    0
            }
        }

        impl $typ {
            #[allow(unused)]
            #[inline(always)]
            fn fix_header_remaining_len(&mut self) {
                self.fixed_header.remaining_length = $crate::packet::EncodablePacket::encoded_packet_length(self);
            }
        }
    };
}

pub use self::connack::ConnackPacket;
pub use self::connect::ConnectPacket;
pub use self::disconnect::DisconnectPacket;
pub use self::pingreq::PingreqPacket;
pub use self::pingresp::PingrespPacket;
pub use self::puback::PubackPacket;
pub use self::pubcomp::PubcompPacket;
pub use self::publish::{PublishPacket, PublishPacketRef};
pub use self::pubrec::PubrecPacket;
pub use self::pubrel::PubrelPacket;
pub use self::suback::SubackPacket;
pub use self::subscribe::SubscribePacket;
pub use self::unsuback::UnsubackPacket;
pub use self::unsubscribe::UnsubscribePacket;

pub use self::publish::QoSWithPacketIdentifier;

pub mod connack;
pub mod connect;
pub mod disconnect;
pub mod pingreq;
pub mod pingresp;
pub mod puback;
pub mod pubcomp;
pub mod publish;
pub mod pubrec;
pub mod pubrel;
pub mod suback;
pub mod subscribe;
pub mod unsuback;
pub mod unsubscribe;

/// A trait representing a packet that can be encoded, when passed as `FooPacket` or as
/// `&FooPacket`. Different from [`Encodable`] in that it prevents you from accidentally passing
/// a type intended to be encoded only as a part of a packet and doesn't have a header, e.g.
/// `Vec<u8>`.
pub trait EncodablePacket {
    /// Get a reference to `FixedHeader`. All MQTT packet must have a fixed header.
    fn fixed_header(&self) -> &FixedHeader;

    /// Encodes packet data after fixed header, including variable headers and payload
    fn encode_packet<W: Write>(&self, _writer: &mut W) -> io::Result<()> {
        Ok(())
    }

    /// Length in bytes for data after fixed header, including variable headers and payload
    fn encoded_packet_length(&self) -> u32 {
        0
    }
}

impl<T: EncodablePacket> Encodable for T {
    /// Writes the fixed header followed by the packet body.
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        self.fixed_header().encode(writer)?;
        self.encode_packet(writer)
    }

    fn encoded_length(&self) -> u32 {
        self.fixed_header().encoded_length() + self.encoded_packet_length()
    }
}

pub trait DecodablePacket: EncodablePacket + Sized {
    type DecodePacketError: Error + 'static;

    /// Decode packet given a `FixedHeader`
    fn decode_packet<R: Read>(reader: &mut R, fixed_header: FixedHeader) -> Result<Self, PacketError<Self>>;
}

impl<T: DecodablePacket> Decodable for T {
    type Error = PacketError<T>;
    type Cond = Option<FixedHeader>;

    /// Decodes a packet, first reading its fixed header from `reader` unless one is supplied.
    fn decode_with<R: Read>(reader: &mut R, fixed_header: Self::Cond) -> Result<Self, Self::Error> {
        let fixed_header: FixedHeader = if let Some(hdr) = fixed_header {
            hdr
        } else {
            Decodable::decode(reader)?
        };

        <Self as DecodablePacket>::decode_packet(reader, fixed_header)
    }
}

/// Parsing errors for packet
#[derive(thiserror::Error)]
#[error(transparent)]
pub enum PacketError<P>
where
    P: DecodablePacket,
{
    FixedHeaderError(#[from] FixedHeaderError),
    VariableHeaderError(#[from] VariableHeaderError),
    PayloadError(<P as DecodablePacket>::DecodePacketError),
    IoError(#[from] io::Error),
    TopicNameError(#[from] TopicNameError),
}

// Written by hand rather than derived so that `P` itself is not required to be `Debug`.
impl<P> Debug for PacketError<P>
where
    P: DecodablePacket,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            PacketError::FixedHeaderError(ref e) => f.debug_tuple("FixedHeaderError").field(e).finish(),
            PacketError::VariableHeaderError(ref e) => f.debug_tuple("VariableHeaderError").field(e).finish(),
            PacketError::PayloadError(ref e) => f.debug_tuple("PayloadError").field(e).finish(),
            PacketError::IoError(ref e) => f.debug_tuple("IoError").field(e).finish(),
            PacketError::TopicNameError(ref e) => f.debug_tuple("TopicNameError").field(e).finish(),
        }
    }
}

impl<P: DecodablePacket> From<TopicNameDecodeError> for PacketError<P> {
    fn from(e: TopicNameDecodeError) -> Self {
        match e {
            TopicNameDecodeError::IoError(e) => e.into(),
            TopicNameDecodeError::InvalidTopicName(e) => e.into(),
        }
    }
}

macro_rules! impl_variable_packet {
    ($($name:ident & $errname:ident => $hdr:ident,)+) => {
        /// Variable packet
        #[derive(Debug, Eq, PartialEq, Clone)]
        pub enum VariablePacket {
            $(
                $name($name),
            )+
        }

        #[cfg(feature = "tokio")]
        impl VariablePacket {
            /// Asynchronously parse a packet from a `tokio::io::AsyncRead`
            ///
            /// This requires mqtt-rs to be built with `feature = "tokio"`
            pub async fn parse<A: AsyncRead + Unpin>(rdr: &mut A) -> Result<Self, VariablePacketError> {
                use std::io::Cursor;
                let fixed_header = FixedHeader::parse(rdr).await?;

                let mut buffer = vec![0u8; fixed_header.remaining_length as usize];
                rdr.read_exact(&mut buffer).await?;

                decode_with_header(&mut Cursor::new(buffer), fixed_header)
            }
        }

        /// Dispatches on the control type in `fixed_header` to the matching packet decoder.
        #[inline]
        fn decode_with_header<R: io::Read>(rdr: &mut R, fixed_header: FixedHeader) -> Result<VariablePacket, VariablePacketError> {
            match fixed_header.packet_type.control_type() {
                $(
                    ControlType::$hdr => {
                        let pk = <$name as DecodablePacket>::decode_packet(rdr, fixed_header)?;
                        Ok(VariablePacket::$name(pk))
                    }
                )+
            }
        }

        $(
            impl From<$name> for VariablePacket {
                fn from(pk: $name) -> VariablePacket {
                    VariablePacket::$name(pk)
                }
            }
        )+

        // `Encodable` for `VariablePacket` comes from the blanket impl over `EncodablePacket`.
        impl EncodablePacket for VariablePacket {
            fn fixed_header(&self) -> &FixedHeader {
                match *self {
                    $(
                        VariablePacket::$name(ref pk) => pk.fixed_header(),
                    )+
                }
            }

            fn encode_packet<W: Write>(&self, writer: &mut W) -> io::Result<()> {
                match *self {
                    $(
                        VariablePacket::$name(ref pk) => pk.encode_packet(writer),
                    )+
                }
            }

            fn encoded_packet_length(&self) -> u32 {
                match *self {
                    $(
                        VariablePacket::$name(ref pk) => pk.encoded_packet_length(),
                    )+
                }
            }
        }

        impl Decodable for VariablePacket {
            type Error = VariablePacketError;
            type Cond = Option<FixedHeader>;

            /// Decodes one packet, limiting the body read to the header's remaining length.
            /// A reserved packet type has its body read and returned inside `ReservedPacket`.
            fn decode_with<R: Read>(reader: &mut R, fixed_header: Self::Cond) -> Result<VariablePacket, Self::Error> {
                let fixed_header = match fixed_header {
                    Some(fh) => fh,
                    None => {
                        match FixedHeader::decode(reader) {
                            Ok(header) => header,
                            Err(FixedHeaderError::ReservedType(code, length)) => {
                                let reader = &mut reader.take(length as u64);
                                let mut buf = Vec::with_capacity(length as usize);
                                reader.read_to_end(&mut buf)?;
                                return Err(VariablePacketError::ReservedPacket(code, buf));
                            },
                            Err(err) => return Err(From::from(err))
                        }
                    }
                };
                let reader = &mut reader.take(fixed_header.remaining_length as u64);

                decode_with_header(reader, fixed_header)
            }
        }

        /// Parsing errors for variable packet
        #[derive(Debug, thiserror::Error)]
        pub enum VariablePacketError {
            #[error(transparent)]
            FixedHeaderError(#[from] FixedHeaderError),
            #[error("reserved packet type ({0}), [u8, ..{}]", .1.len())]
            ReservedPacket(u8, Vec<u8>),
            #[error(transparent)]
            IoError(#[from] io::Error),
            $(
                #[error(transparent)]
                $errname(#[from] PacketError<$name>),
            )+
        }
    }
}

impl_variable_packet! {
    ConnectPacket & ConnectPacketError => Connect,
    ConnackPacket & ConnackPacketError => ConnectAcknowledgement,

    PublishPacket & PublishPacketError => Publish,
    PubackPacket & PubackPacketError => PublishAcknowledgement,
    PubrecPacket & PubrecPacketError => PublishReceived,
    PubrelPacket & PubrelPacketError => PublishRelease,
    PubcompPacket & PubcompPacketError => PublishComplete,

    PingreqPacket & PingreqPacketError => PingRequest,
    PingrespPacket & PingrespPacketError => PingResponse,

    SubscribePacket & SubscribePacketError => Subscribe,
    SubackPacket & SubackPacketError => SubscribeAcknowledgement,

    UnsubscribePacket & UnsubscribePacketError => Unsubscribe,
    UnsubackPacket & UnsubackPacketError => UnsubscribeAcknowledgement,

    DisconnectPacket & DisconnectPacketError => Disconnect,
}

impl VariablePacket {
    pub fn new<T>(t: T) -> VariablePacket
    where
        VariablePacket: From<T>,
    {
        From::from(t)
    }
}

#[cfg(feature = "tokio-codec")]
mod tokio_codec {
    use super::*;
    use crate::control::packet_type::{PacketType, PacketTypeError};
    use bytes::{Buf, BufMut, BytesMut};
    use tokio_util::codec;

    pub struct MqttDecoder {
        state: DecodeState,
    }

    /// Either waiting for a fixed header, or holding a parsed header while the body buffers.
    enum DecodeState {
        Start,
        Packet { length: u32, typ: DecodePacketType },
    }

    #[derive(Copy, Clone)]
    enum DecodePacketType {
        Standard(PacketType),
        Reserved(u8),
    }

    impl MqttDecoder {
        pub const fn new() -> Self {
            MqttDecoder {
                state: DecodeState::Start,
            }
        }
    }

    /// Like FixedHeader::decode(), but on a buffer instead of a stream. Returns None if it reaches
    /// the end of the buffer before it finishes decoding the header.
    #[inline]
    fn decode_header(mut data: &[u8]) -> Option<Result<(DecodePacketType, u32, usize), FixedHeaderError>> {
        let mut header_size = 0;
        macro_rules! read_u8 {
            () => {{
                let (&x, rest) = data.split_first()?;
                data = rest;
                header_size += 1;
                x
            }};
        }

        let type_val = read_u8!();
        // Remaining length: at most 4 bytes of 7 data bits each, least-significant group
        // first, with the high bit set while more bytes follow.
        let remaining_len = {
            let mut cur = 0u32;
            let mut shift = 0u32;
            loop {
                let byte = read_u8!();
                cur |= u32::from(byte & 0x7F) << shift;
                if byte & 0x80 == 0 {
                    break cur;
                }
                shift += 7;
                // A continuation bit on the 4th byte is malformed: reject it now instead of
                // waiting for a 5th byte that may never arrive.
                if shift >= 28 {
                    return Some(Err(FixedHeaderError::MalformedRemainingLength));
                }
            }
        };

        let packet_type = match PacketType::from_u8(type_val) {
            Ok(ty) => DecodePacketType::Standard(ty),
            Err(PacketTypeError::ReservedType(ty, _)) => DecodePacketType::Reserved(ty),
            Err(err) => return Some(Err(err.into())),
        };
        Some(Ok((packet_type, remaining_len, header_size)))
    }

    impl codec::Decoder for MqttDecoder {
        type Item = VariablePacket;
        type Error = VariablePacketError;

        /// Decodes at most one packet from `src`.
        ///
        /// Once the whole body (`remaining_length` bytes) is buffered, exactly that many bytes
        /// are removed from `src` before the packet is decoded. This happens whether decoding
        /// succeeds or fails, so the next call always starts at a frame boundary.
        fn decode(&mut self, src: &mut BytesMut) -> Result<Option<VariablePacket>, VariablePacketError> {
            loop {
                match &mut self.state {
                    DecodeState::Start => match decode_header(&src[..]) {
                        Some(Ok((typ, length, header_size))) => {
                            src.advance(header_size);
                            self.state = DecodeState::Packet { length, typ };
                            continue;
                        }
                        Some(Err(e)) => return Err(e.into()),
                        None => return Ok(None),
                    },
                    DecodeState::Packet { length, typ } => {
                        let length = *length;
                        if src.remaining() < length as usize {
                            return Ok(None);
                        }
                        let typ = *typ;

                        self.state = DecodeState::Start;

                        // Detach exactly this packet's body, so a packet decoder can neither read
                        // into the next frame nor leave unread bytes behind in `src`.
                        let body = src.split_to(length as usize);

                        match typ {
                            DecodePacketType::Standard(typ) => {
                                let header = FixedHeader {
                                    packet_type: typ,
                                    remaining_length: length,
                                };
                                return decode_with_header(&mut body.reader(), header).map(Some);
                            }
                            DecodePacketType::Reserved(code) => {
                                return Err(VariablePacketError::ReservedPacket(code, body.to_vec()));
                            }
                        }
                    }
                }
            }
        }
    }

    pub struct MqttEncoder {
        _priv: (),
    }

    impl MqttEncoder {
        pub const fn new() -> Self {
            MqttEncoder { _priv: () }
        }
    }

    impl<T: EncodablePacket> codec::Encoder<T> for MqttEncoder {
        type Error = io::Error;

        /// Appends the fully encoded packet to `dst`, reserving its exact size up front.
        fn encode(&mut self, packet: T, dst: &mut BytesMut) -> Result<(), io::Error> {
            dst.reserve(packet.encoded_length() as usize);
            packet.encode(&mut dst.writer())
        }
    }

    pub struct MqttCodec {
        decode: MqttDecoder,
        encode: MqttEncoder,
    }

    impl MqttCodec {
        pub const fn new() -> Self {
            MqttCodec {
                decode: MqttDecoder::new(),
                encode: MqttEncoder::new(),
            }
        }
    }

    impl codec::Decoder for MqttCodec {
        type Item = VariablePacket;
        type Error = VariablePacketError;

        #[inline]
        fn decode(&mut self, src: &mut BytesMut) -> Result<Option<VariablePacket>, VariablePacketError> {
            self.decode.decode(src)
        }
    }

    impl<T: EncodablePacket> codec::Encoder<T> for MqttCodec {
        type Error = io::Error;

        #[inline]
        fn encode(&mut self, packet: T, dst: &mut BytesMut) -> Result<(), io::Error> {
            self.encode.encode(packet, dst)
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use tokio_util::codec::Decoder;

        #[test]
        fn decoder_consumes_exactly_one_frame() {
            // A DISCONNECT whose remaining length (1) exceeds what its decoder reads, followed by
            // a PINGREQ. Each frame must be consumed exactly, so the second packet still parses.
            let mut src = BytesMut::from(&[0xE0u8, 0x01, 0xFF, 0xC0, 0x00][..]);
            let mut decoder = MqttDecoder::new();

            let first = decoder.decode(&mut src).unwrap();
            assert!(matches!(first, Some(VariablePacket::DisconnectPacket(_))));

            let second = decoder.decode(&mut src).unwrap();
            assert!(matches!(second, Some(VariablePacket::PingreqPacket(_))));

            assert!(src.is_empty());
        }
    }
}

#[cfg(feature = "tokio-codec")]
pub use tokio_codec::{MqttCodec, MqttDecoder, MqttEncoder};

#[cfg(test)]
mod test {
    use super::*;

    use std::io::Cursor;

    use crate::{Decodable, Encodable};

    #[test]
    fn test_variable_packet_basic() {
        let packet = ConnectPacket::new("1234".to_owned());

        // Wrap it
        let var_packet = VariablePacket::new(packet);

        // Encode
        let mut buf = Vec::new();
        var_packet.encode(&mut buf).unwrap();

        // Decode
        let mut decode_buf = Cursor::new(buf);
        let decoded_packet = VariablePacket::decode(&mut decode_buf).unwrap();

        assert_eq!(var_packet, decoded_packet);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_variable_packet_async_parse() {
        let packet = ConnectPacket::new("1234".to_owned());

        // Wrap it
        let var_packet = VariablePacket::new(packet);

        // Encode
        let mut buf = Vec::new();
        var_packet.encode(&mut buf).unwrap();

        // Parse
        let mut async_buf = buf.as_slice();
        let decoded_packet = VariablePacket::parse(&mut async_buf).await.unwrap();

        assert_eq!(var_packet, decoded_packet);
    }

    #[cfg(feature = "tokio-codec")]
    #[tokio::test]
    async fn test_variable_packet_framed() {
        use crate::{QualityOfService, TopicFilter};
        use futures::{SinkExt, StreamExt};
        use tokio_util::codec::{FramedRead, FramedWrite};

        let conn_packet = ConnectPacket::new("1234".to_owned());
        let sub_packet = SubscribePacket::new(1, vec![(TopicFilter::new("foo/#").unwrap(), QualityOfService::Level0)]);

        // small, to make sure buffering and stuff works
        let (reader, writer) = tokio::io::duplex(8);

        let task = tokio::spawn({
            let (conn_packet, sub_packet) = (conn_packet.clone(), sub_packet.clone());
            async move {
                let mut
sink = FramedWrite::new(writer, MqttEncoder::new()); sink.send(conn_packet).await.unwrap(); sink.send(sub_packet).await.unwrap(); SinkExt::::flush(&mut sink).await.unwrap(); } }); let mut stream = FramedRead::new(reader, MqttDecoder::new()); let decoded_conn = stream.next().await.unwrap().unwrap(); let decoded_sub = stream.next().await.unwrap().unwrap(); task.await.unwrap(); assert!(stream.next().await.is_none()); assert_eq!(decoded_conn, conn_packet.into()); assert_eq!(decoded_sub, sub_packet.into()); } } ================================================ FILE: src/packet/pingreq.rs ================================================ //! PINGREQ use std::io::Read; use crate::control::{ControlType, FixedHeader, PacketType}; use crate::packet::{DecodablePacket, PacketError}; /// `PINGREQ` packet #[derive(Debug, Eq, PartialEq, Clone)] pub struct PingreqPacket { fixed_header: FixedHeader, } encodable_packet!(PingreqPacket()); impl PingreqPacket { pub fn new() -> PingreqPacket { PingreqPacket { fixed_header: FixedHeader::new(PacketType::with_default(ControlType::PingRequest), 0), } } } impl Default for PingreqPacket { fn default() -> PingreqPacket { PingreqPacket::new() } } impl DecodablePacket for PingreqPacket { type DecodePacketError = std::convert::Infallible; fn decode_packet(_reader: &mut R, fixed_header: FixedHeader) -> Result> { Ok(PingreqPacket { fixed_header }) } } ================================================ FILE: src/packet/pingresp.rs ================================================ //! 
PINGRESP use std::io::Read; use crate::control::{ControlType, FixedHeader, PacketType}; use crate::packet::{DecodablePacket, PacketError}; /// `PINGRESP` packet #[derive(Debug, Eq, PartialEq, Clone)] pub struct PingrespPacket { fixed_header: FixedHeader, } encodable_packet!(PingrespPacket()); impl PingrespPacket { pub fn new() -> PingrespPacket { PingrespPacket { fixed_header: FixedHeader::new(PacketType::with_default(ControlType::PingResponse), 0), } } } impl Default for PingrespPacket { fn default() -> PingrespPacket { PingrespPacket::new() } } impl DecodablePacket for PingrespPacket { type DecodePacketError = std::convert::Infallible; fn decode_packet(_reader: &mut R, fixed_header: FixedHeader) -> Result> { Ok(PingrespPacket { fixed_header }) } } ================================================ FILE: src/packet/puback.rs ================================================ //! PUBACK use std::io::Read; use crate::control::variable_header::PacketIdentifier; use crate::control::{ControlType, FixedHeader, PacketType}; use crate::packet::{DecodablePacket, PacketError}; use crate::Decodable; /// `PUBACK` packet #[derive(Debug, Eq, PartialEq, Clone)] pub struct PubackPacket { fixed_header: FixedHeader, packet_identifier: PacketIdentifier, } encodable_packet!(PubackPacket(packet_identifier)); impl PubackPacket { pub fn new(pkid: u16) -> PubackPacket { PubackPacket { fixed_header: FixedHeader::new(PacketType::with_default(ControlType::PublishAcknowledgement), 2), packet_identifier: PacketIdentifier(pkid), } } pub fn packet_identifier(&self) -> u16 { self.packet_identifier.0 } pub fn set_packet_identifier(&mut self, pkid: u16) { self.packet_identifier.0 = pkid; } } impl DecodablePacket for PubackPacket { type DecodePacketError = std::convert::Infallible; fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?; Ok(PubackPacket { fixed_header, packet_identifier, }) } } 
================================================ FILE: src/packet/pubcomp.rs ================================================
//! PUBCOMP

use std::io::Read;

use crate::control::variable_header::PacketIdentifier;
use crate::control::{ControlType, FixedHeader, PacketType};
use crate::packet::{DecodablePacket, PacketError};
use crate::Decodable;

/// `PUBCOMP` packet
///
/// After the fixed header it carries only a packet identifier.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct PubcompPacket {
    fixed_header: FixedHeader,
    packet_identifier: PacketIdentifier,
}

encodable_packet!(PubcompPacket(packet_identifier));

impl PubcompPacket {
    /// Creates a PUBCOMP carrying packet identifier `pkid`.
    ///
    /// The remaining length is fixed at 2, which covers the packet identifier.
    pub fn new(pkid: u16) -> PubcompPacket {
        PubcompPacket {
            fixed_header: FixedHeader::new(PacketType::with_default(ControlType::PublishComplete), 2),
            packet_identifier: PacketIdentifier(pkid),
        }
    }

    /// Returns the packet identifier.
    pub fn packet_identifier(&self) -> u16 {
        self.packet_identifier.0
    }

    /// Replaces the packet identifier.
    pub fn set_packet_identifier(&mut self, pkid: u16) {
        self.packet_identifier.0 = pkid;
    }
}

impl DecodablePacket for PubcompPacket {
    type DecodePacketError = std::convert::Infallible;

    // Reads the packet identifier that follows the already-decoded fixed header.
    fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> {
        let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?;
        Ok(PubcompPacket {
            fixed_header,
            packet_identifier,
        })
    }
}
================================================ FILE: src/packet/publish.rs ================================================
//! 
PUBLISH use std::io::{self, Read, Write}; use crate::control::{FixedHeader, PacketType}; use crate::packet::{DecodablePacket, PacketError}; use crate::qos::QualityOfService; use crate::topic_name::TopicName; use crate::{control::variable_header::PacketIdentifier, TopicNameRef}; use crate::{Decodable, Encodable}; use super::EncodablePacket; /// QoS with identifier pairs #[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Copy, Clone)] pub enum QoSWithPacketIdentifier { Level0, Level1(u16), Level2(u16), } impl QoSWithPacketIdentifier { pub fn new(qos: QualityOfService, id: u16) -> QoSWithPacketIdentifier { match (qos, id) { (QualityOfService::Level0, _) => QoSWithPacketIdentifier::Level0, (QualityOfService::Level1, id) => QoSWithPacketIdentifier::Level1(id), (QualityOfService::Level2, id) => QoSWithPacketIdentifier::Level2(id), } } pub fn split(self) -> (QualityOfService, Option) { match self { QoSWithPacketIdentifier::Level0 => (QualityOfService::Level0, None), QoSWithPacketIdentifier::Level1(pkid) => (QualityOfService::Level1, Some(pkid)), QoSWithPacketIdentifier::Level2(pkid) => (QualityOfService::Level2, Some(pkid)), } } } /// `PUBLISH` packet #[derive(Debug, Eq, PartialEq, Clone)] pub struct PublishPacket { fixed_header: FixedHeader, topic_name: TopicName, packet_identifier: Option, payload: Vec, } encodable_packet!(PublishPacket(topic_name, packet_identifier, payload)); impl PublishPacket { pub fn new>>(topic_name: TopicName, qos: QoSWithPacketIdentifier, payload: P) -> PublishPacket { let (qos, pkid) = qos.split(); let mut pk = PublishPacket { fixed_header: FixedHeader::new(PacketType::publish(qos), 0), topic_name, packet_identifier: pkid.map(PacketIdentifier), payload: payload.into(), }; pk.fix_header_remaining_len(); pk } pub fn set_dup(&mut self, dup: bool) { self.fixed_header .packet_type .update_flags(|flags| (flags & !(1 << 3)) | (dup as u8) << 3) } pub fn dup(&self) -> bool { self.fixed_header.packet_type.flags() & 0x80 != 0 } pub fn set_qos(&mut self, qos: 
QoSWithPacketIdentifier) { let (qos, pkid) = qos.split(); self.fixed_header .packet_type .update_flags(|flags| (flags & !0b0110) | (qos as u8) << 1); self.packet_identifier = pkid.map(PacketIdentifier); self.fix_header_remaining_len(); } pub fn qos(&self) -> QoSWithPacketIdentifier { match self.packet_identifier { None => QoSWithPacketIdentifier::Level0, Some(pkid) => { let qos_val = (self.fixed_header.packet_type.flags() & 0b0110) >> 1; match qos_val { 1 => QoSWithPacketIdentifier::Level1(pkid.0), 2 => QoSWithPacketIdentifier::Level2(pkid.0), _ => unreachable!(), } } } } pub fn set_retain(&mut self, ret: bool) { self.fixed_header .packet_type .update_flags(|flags| (flags & !0b0001) | (ret as u8)) } pub fn retain(&self) -> bool { self.fixed_header.packet_type.flags() & 0b0001 != 0 } pub fn set_topic_name(&mut self, topic_name: TopicName) { self.topic_name = topic_name; self.fix_header_remaining_len(); } pub fn topic_name(&self) -> &str { &self.topic_name[..] } pub fn payload(&self) -> &[u8] { &self.payload } pub fn set_payload>>(&mut self, payload: P) { self.payload = payload.into(); self.fix_header_remaining_len(); } } impl DecodablePacket for PublishPacket { type DecodePacketError = std::convert::Infallible; fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { let topic_name = TopicName::decode(reader)?; let qos = (fixed_header.packet_type.flags() & 0b0110) >> 1; let packet_identifier = if qos > 0 { Some(PacketIdentifier::decode(reader)?) 
} else { None }; let vhead_len = topic_name.encoded_length() + packet_identifier.as_ref().map(|x| x.encoded_length()).unwrap_or(0); let payload_len = fixed_header.remaining_length - vhead_len; let payload = Vec::::decode_with(reader, Some(payload_len))?; Ok(PublishPacket { fixed_header, topic_name, packet_identifier, payload, }) } } /// `PUBLISH` packet by reference, for encoding only pub struct PublishPacketRef<'a> { fixed_header: FixedHeader, topic_name: &'a TopicNameRef, packet_identifier: Option, payload: &'a [u8], } impl<'a> PublishPacketRef<'a> { pub fn new(topic_name: &'a TopicNameRef, qos: QoSWithPacketIdentifier, payload: &'a [u8]) -> PublishPacketRef<'a> { let (qos, pkid) = qos.split(); let mut pk = PublishPacketRef { fixed_header: FixedHeader::new(PacketType::publish(qos), 0), topic_name, packet_identifier: pkid.map(PacketIdentifier), payload, }; pk.fix_header_remaining_len(); pk } fn fix_header_remaining_len(&mut self) { self.fixed_header.remaining_length = self.topic_name.encoded_length() + self.packet_identifier.encoded_length() + self.payload.encoded_length(); } } impl EncodablePacket for PublishPacketRef<'_> { fn fixed_header(&self) -> &FixedHeader { &self.fixed_header } fn encode_packet(&self, writer: &mut W) -> io::Result<()> { self.topic_name.encode(writer)?; self.packet_identifier.encode(writer)?; self.payload.encode(writer) } fn encoded_packet_length(&self) -> u32 { self.topic_name.encoded_length() + self.packet_identifier.encoded_length() + self.payload.encoded_length() } } #[cfg(test)] mod test { use super::*; use std::io::Cursor; use crate::topic_name::TopicName; use crate::{Decodable, Encodable}; #[test] fn test_publish_packet_basic() { let packet = PublishPacket::new( TopicName::new("a/b".to_owned()).unwrap(), QoSWithPacketIdentifier::Level2(10), b"Hello world!".to_vec(), ); let mut buf = Vec::new(); packet.encode(&mut buf).unwrap(); let mut decode_buf = Cursor::new(buf); let decoded = PublishPacket::decode(&mut decode_buf).unwrap(); 
assert_eq!(packet, decoded); } #[test] fn issue56() { let mut packet = PublishPacket::new( TopicName::new("topic").unwrap(), QoSWithPacketIdentifier::Level0, Vec::new(), ); assert_eq!(packet.fixed_header().remaining_length, 7); packet.set_qos(QoSWithPacketIdentifier::Level1(1)); assert_eq!(packet.fixed_header().remaining_length, 9); } } ================================================ FILE: src/packet/pubrec.rs ================================================ //! PUBREC use std::io::Read; use crate::control::variable_header::PacketIdentifier; use crate::control::{ControlType, FixedHeader, PacketType}; use crate::packet::{DecodablePacket, PacketError}; use crate::Decodable; /// `PUBREC` packet #[derive(Debug, Eq, PartialEq, Clone)] pub struct PubrecPacket { fixed_header: FixedHeader, packet_identifier: PacketIdentifier, } encodable_packet!(PubrecPacket(packet_identifier)); impl PubrecPacket { pub fn new(pkid: u16) -> PubrecPacket { PubrecPacket { fixed_header: FixedHeader::new(PacketType::with_default(ControlType::PublishReceived), 2), packet_identifier: PacketIdentifier(pkid), } } pub fn packet_identifier(&self) -> u16 { self.packet_identifier.0 } pub fn set_packet_identifier(&mut self, pkid: u16) { self.packet_identifier.0 = pkid; } } impl DecodablePacket for PubrecPacket { type DecodePacketError = std::convert::Infallible; fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?; Ok(PubrecPacket { fixed_header, packet_identifier, }) } } ================================================ FILE: src/packet/pubrel.rs ================================================ //! 
PUBREL

use std::io::Read;

use crate::control::variable_header::PacketIdentifier;
use crate::control::{ControlType, FixedHeader, PacketType};
use crate::packet::{DecodablePacket, PacketError};
use crate::Decodable;

/// `PUBREL` packet
///
/// After the fixed header it carries only a packet identifier.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct PubrelPacket {
    fixed_header: FixedHeader,
    packet_identifier: PacketIdentifier,
}

encodable_packet!(PubrelPacket(packet_identifier));

impl PubrelPacket {
    /// Creates a PUBREL carrying packet identifier `pkid`.
    ///
    /// The remaining length is fixed at 2, which covers the packet identifier.
    pub fn new(pkid: u16) -> PubrelPacket {
        PubrelPacket {
            fixed_header: FixedHeader::new(PacketType::with_default(ControlType::PublishRelease), 2),
            packet_identifier: PacketIdentifier(pkid),
        }
    }

    /// Returns the packet identifier.
    pub fn packet_identifier(&self) -> u16 {
        self.packet_identifier.0
    }

    /// Replaces the packet identifier.
    pub fn set_packet_identifier(&mut self, pkid: u16) {
        self.packet_identifier.0 = pkid;
    }
}

impl DecodablePacket for PubrelPacket {
    type DecodePacketError = std::convert::Infallible;

    // Reads the packet identifier that follows the already-decoded fixed header.
    fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> {
        let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?;
        Ok(PubrelPacket {
            fixed_header,
            packet_identifier,
        })
    }
}
================================================ FILE: src/packet/suback.rs ================================================
//! 
SUBACK use std::cmp::Ordering; use std::io::{self, Read, Write}; use byteorder::{ReadBytesExt, WriteBytesExt}; use crate::control::variable_header::PacketIdentifier; use crate::control::{ControlType, FixedHeader, PacketType}; use crate::packet::{DecodablePacket, PacketError}; use crate::qos::QualityOfService; use crate::{Decodable, Encodable}; /// Subscribe code #[repr(u8)] #[derive(Debug, Eq, PartialEq, Copy, Clone)] pub enum SubscribeReturnCode { MaximumQoSLevel0 = 0x00, MaximumQoSLevel1 = 0x01, MaximumQoSLevel2 = 0x02, Failure = 0x80, } impl PartialOrd for SubscribeReturnCode { fn partial_cmp(&self, other: &Self) -> Option { use self::SubscribeReturnCode::*; match (self, other) { (&Failure, _) => None, (_, &Failure) => None, (&MaximumQoSLevel0, &MaximumQoSLevel0) => Some(Ordering::Equal), (&MaximumQoSLevel1, &MaximumQoSLevel1) => Some(Ordering::Equal), (&MaximumQoSLevel2, &MaximumQoSLevel2) => Some(Ordering::Equal), (&MaximumQoSLevel0, _) => Some(Ordering::Less), (&MaximumQoSLevel1, &MaximumQoSLevel0) => Some(Ordering::Greater), (&MaximumQoSLevel1, &MaximumQoSLevel2) => Some(Ordering::Less), (&MaximumQoSLevel2, _) => Some(Ordering::Greater), } } } impl From for SubscribeReturnCode { fn from(qos: QualityOfService) -> Self { match qos { QualityOfService::Level0 => SubscribeReturnCode::MaximumQoSLevel0, QualityOfService::Level1 => SubscribeReturnCode::MaximumQoSLevel1, QualityOfService::Level2 => SubscribeReturnCode::MaximumQoSLevel2, } } } /// `SUBACK` packet #[derive(Debug, Eq, PartialEq, Clone)] pub struct SubackPacket { fixed_header: FixedHeader, packet_identifier: PacketIdentifier, payload: SubackPacketPayload, } encodable_packet!(SubackPacket(packet_identifier, payload)); impl SubackPacket { pub fn new(pkid: u16, subscribes: Vec) -> SubackPacket { let mut pk = SubackPacket { fixed_header: FixedHeader::new(PacketType::with_default(ControlType::SubscribeAcknowledgement), 0), packet_identifier: PacketIdentifier(pkid), payload: 
SubackPacketPayload::new(subscribes), }; pk.fix_header_remaining_len(); pk } pub fn packet_identifier(&self) -> u16 { self.packet_identifier.0 } pub fn set_packet_identifier(&mut self, pkid: u16) { self.packet_identifier.0 = pkid; } pub fn subscribes(&self) -> &[SubscribeReturnCode] { &self.payload.subscribes[..] } } impl DecodablePacket for SubackPacket { type DecodePacketError = SubackPacketError; fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { let packet_identifier = PacketIdentifier::decode(reader)?; let payload: SubackPacketPayload = SubackPacketPayload::decode_with( reader, fixed_header.remaining_length - packet_identifier.encoded_length(), ) .map_err(PacketError::PayloadError)?; Ok(SubackPacket { fixed_header, packet_identifier, payload, }) } } #[derive(Debug, Eq, PartialEq, Clone)] struct SubackPacketPayload { subscribes: Vec, } impl SubackPacketPayload { pub fn new(subs: Vec) -> SubackPacketPayload { SubackPacketPayload { subscribes: subs } } } impl Encodable for SubackPacketPayload { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { for code in self.subscribes.iter() { writer.write_u8(*code as u8)?; } Ok(()) } fn encoded_length(&self) -> u32 { self.subscribes.len() as u32 } } impl Decodable for SubackPacketPayload { type Error = SubackPacketError; type Cond = u32; fn decode_with(reader: &mut R, payload_len: u32) -> Result { let mut subs = Vec::new(); for _ in 0..payload_len { let retcode = match reader.read_u8()? 
{ 0x00 => SubscribeReturnCode::MaximumQoSLevel0, 0x01 => SubscribeReturnCode::MaximumQoSLevel1, 0x02 => SubscribeReturnCode::MaximumQoSLevel2, 0x80 => SubscribeReturnCode::Failure, code => return Err(SubackPacketError::InvalidSubscribeReturnCode(code)), }; subs.push(retcode); } Ok(SubackPacketPayload::new(subs)) } } #[derive(Debug, thiserror::Error)] pub enum SubackPacketError { #[error(transparent)] IoError(#[from] io::Error), #[error("invalid subscribe return code {0}")] InvalidSubscribeReturnCode(u8), } ================================================ FILE: src/packet/subscribe.rs ================================================ //! SUBSCRIBE use std::io::{self, Read, Write}; use std::string::FromUtf8Error; use byteorder::{ReadBytesExt, WriteBytesExt}; use crate::control::variable_header::PacketIdentifier; use crate::control::{ControlType, FixedHeader, PacketType}; use crate::packet::{DecodablePacket, PacketError}; use crate::topic_filter::{TopicFilter, TopicFilterDecodeError, TopicFilterError}; use crate::{Decodable, Encodable, QualityOfService}; /// `SUBSCRIBE` packet #[derive(Debug, Eq, PartialEq, Clone)] pub struct SubscribePacket { fixed_header: FixedHeader, packet_identifier: PacketIdentifier, payload: SubscribePacketPayload, } encodable_packet!(SubscribePacket(packet_identifier, payload)); impl SubscribePacket { pub fn new(pkid: u16, subscribes: Vec<(TopicFilter, QualityOfService)>) -> SubscribePacket { let mut pk = SubscribePacket { fixed_header: FixedHeader::new(PacketType::with_default(ControlType::Subscribe), 0), packet_identifier: PacketIdentifier(pkid), payload: SubscribePacketPayload::new(subscribes), }; pk.fix_header_remaining_len(); pk } pub fn packet_identifier(&self) -> u16 { self.packet_identifier.0 } pub fn set_packet_identifier(&mut self, pkid: u16) { self.packet_identifier.0 = pkid; } pub fn subscribes(&self) -> &[(TopicFilter, QualityOfService)] { &self.payload.subscribes[..] 
} } impl DecodablePacket for SubscribePacket { type DecodePacketError = SubscribePacketError; fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?; let payload: SubscribePacketPayload = SubscribePacketPayload::decode_with( reader, fixed_header.remaining_length - packet_identifier.encoded_length(), ) .map_err(PacketError::PayloadError)?; Ok(SubscribePacket { fixed_header, packet_identifier, payload, }) } } /// Payload of subscribe packet #[derive(Debug, Eq, PartialEq, Clone)] struct SubscribePacketPayload { subscribes: Vec<(TopicFilter, QualityOfService)>, } impl SubscribePacketPayload { pub fn new(subs: Vec<(TopicFilter, QualityOfService)>) -> SubscribePacketPayload { SubscribePacketPayload { subscribes: subs } } } impl Encodable for SubscribePacketPayload { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { for &(ref filter, ref qos) in self.subscribes.iter() { filter.encode(writer)?; writer.write_u8(*qos as u8)?; } Ok(()) } fn encoded_length(&self) -> u32 { self.subscribes.iter().fold(0, |b, a| b + a.0.encoded_length() + 1) } } impl Decodable for SubscribePacketPayload { type Error = SubscribePacketError; type Cond = u32; fn decode_with( reader: &mut R, mut payload_len: u32, ) -> Result { let mut subs = Vec::new(); while payload_len > 0 { let filter = TopicFilter::decode(reader)?; let qos = match reader.read_u8()? 
{ 0 => QualityOfService::Level0, 1 => QualityOfService::Level1, 2 => QualityOfService::Level2, _ => return Err(SubscribePacketError::InvalidQualityOfService), }; payload_len -= filter.encoded_length() + 1; subs.push((filter, qos)); } Ok(SubscribePacketPayload::new(subs)) } } #[derive(Debug, thiserror::Error)] pub enum SubscribePacketError { #[error(transparent)] IoError(#[from] io::Error), #[error(transparent)] FromUtf8Error(#[from] FromUtf8Error), #[error("invalid quality of service")] InvalidQualityOfService, #[error(transparent)] TopicFilterError(#[from] TopicFilterError), } impl From for SubscribePacketError { fn from(e: TopicFilterDecodeError) -> Self { match e { TopicFilterDecodeError::IoError(e) => e.into(), TopicFilterDecodeError::InvalidTopicFilter(e) => e.into(), } } } ================================================ FILE: src/packet/unsuback.rs ================================================ //! UNSUBACK use std::io::Read; use crate::control::variable_header::PacketIdentifier; use crate::control::{ControlType, FixedHeader, PacketType}; use crate::packet::{DecodablePacket, PacketError}; use crate::Decodable; /// `UNSUBACK` packet #[derive(Debug, Eq, PartialEq, Clone)] pub struct UnsubackPacket { fixed_header: FixedHeader, packet_identifier: PacketIdentifier, } encodable_packet!(UnsubackPacket(packet_identifier)); impl UnsubackPacket { pub fn new(pkid: u16) -> UnsubackPacket { UnsubackPacket { fixed_header: FixedHeader::new(PacketType::with_default(ControlType::UnsubscribeAcknowledgement), 2), packet_identifier: PacketIdentifier(pkid), } } pub fn packet_identifier(&self) -> u16 { self.packet_identifier.0 } pub fn set_packet_identifier(&mut self, pkid: u16) { self.packet_identifier.0 = pkid; } } impl DecodablePacket for UnsubackPacket { type DecodePacketError = std::convert::Infallible; fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?; Ok(UnsubackPacket { 
fixed_header, packet_identifier, }) } } ================================================ FILE: src/packet/unsubscribe.rs ================================================ //! UNSUBSCRIBE use std::io::{self, Read, Write}; use std::string::FromUtf8Error; use crate::control::variable_header::PacketIdentifier; use crate::control::{ControlType, FixedHeader, PacketType}; use crate::packet::{DecodablePacket, PacketError}; use crate::topic_filter::{TopicFilter, TopicFilterDecodeError, TopicFilterError}; use crate::{Decodable, Encodable}; /// `UNSUBSCRIBE` packet #[derive(Debug, Eq, PartialEq, Clone)] pub struct UnsubscribePacket { fixed_header: FixedHeader, packet_identifier: PacketIdentifier, payload: UnsubscribePacketPayload, } encodable_packet!(UnsubscribePacket(packet_identifier, payload)); impl UnsubscribePacket { pub fn new(pkid: u16, subscribes: Vec) -> UnsubscribePacket { let mut pk = UnsubscribePacket { fixed_header: FixedHeader::new(PacketType::with_default(ControlType::Unsubscribe), 0), packet_identifier: PacketIdentifier(pkid), payload: UnsubscribePacketPayload::new(subscribes), }; pk.fix_header_remaining_len(); pk } pub fn packet_identifier(&self) -> u16 { self.packet_identifier.0 } pub fn set_packet_identifier(&mut self, pkid: u16) { self.packet_identifier.0 = pkid; } pub fn subscribes(&self) -> &[TopicFilter] { &self.payload.subscribes[..] 
} } impl DecodablePacket for UnsubscribePacket { type DecodePacketError = UnsubscribePacketError; fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?; let payload: UnsubscribePacketPayload = UnsubscribePacketPayload::decode_with( reader, fixed_header.remaining_length - packet_identifier.encoded_length(), ) .map_err(PacketError::PayloadError)?; Ok(UnsubscribePacket { fixed_header, packet_identifier, payload, }) } } #[derive(Debug, Eq, PartialEq, Clone)] struct UnsubscribePacketPayload { subscribes: Vec, } impl UnsubscribePacketPayload { pub fn new(subs: Vec) -> UnsubscribePacketPayload { UnsubscribePacketPayload { subscribes: subs } } } impl Encodable for UnsubscribePacketPayload { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { for filter in self.subscribes.iter() { filter.encode(writer)?; } Ok(()) } fn encoded_length(&self) -> u32 { self.subscribes.iter().fold(0, |b, a| b + a.encoded_length()) } } impl Decodable for UnsubscribePacketPayload { type Error = UnsubscribePacketError; type Cond = u32; fn decode_with( reader: &mut R, mut payload_len: u32, ) -> Result { let mut subs = Vec::new(); while payload_len > 0 { let filter = TopicFilter::decode(reader)?; payload_len -= filter.encoded_length(); subs.push(filter); } Ok(UnsubscribePacketPayload::new(subs)) } } #[derive(Debug, thiserror::Error)] #[error(transparent)] pub enum UnsubscribePacketError { IoError(#[from] io::Error), FromUtf8Error(#[from] FromUtf8Error), TopicFilterError(#[from] TopicFilterError), } impl From for UnsubscribePacketError { fn from(e: TopicFilterDecodeError) -> Self { match e { TopicFilterDecodeError::IoError(e) => e.into(), TopicFilterDecodeError::InvalidTopicFilter(e) => e.into(), } } } ================================================ FILE: src/qos.rs ================================================ //! 
QoS (Quality of Services)

use crate::packet::publish::QoSWithPacketIdentifier;

/// MQTT quality-of-service level.
///
/// The discriminants are the numeric levels 0-2 (`#[repr(u8)]`). The derived `Ord` follows
/// declaration order, so `Level0 < Level1 < Level2` and `min` picks the lower level.
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Copy, Clone)]
pub enum QualityOfService {
    Level0 = 0,
    Level1 = 1,
    Level2 = 2,
}

/// Keeps only the QoS level, discarding any packet identifier.
impl From for QualityOfService {
    fn from(qos: QoSWithPacketIdentifier) -> Self {
        match qos {
            QoSWithPacketIdentifier::Level0 => QualityOfService::Level0,
            QoSWithPacketIdentifier::Level1(_) => QualityOfService::Level1,
            QoSWithPacketIdentifier::Level2(_) => QualityOfService::Level2,
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    use std::cmp::min;

    // `min` over QoS levels selects the lower level, whatever the packet identifier.
    #[test]
    fn min_qos() {
        let q1 = QoSWithPacketIdentifier::Level1(0).into();
        let q2 = QualityOfService::Level2;
        assert_eq!(min(q1, q2), q1);

        let q1 = QoSWithPacketIdentifier::Level0.into();
        let q2 = QualityOfService::Level2;
        assert_eq!(min(q1, q2), q1);

        let q1 = QoSWithPacketIdentifier::Level2(0).into();
        let q2 = QualityOfService::Level1;
        assert_eq!(min(q1, q2), q2);
    }
}
================================================ FILE: src/topic_filter.rs ================================================
//! 
Topic filter

use std::io::{self, Read, Write};
use std::ops::Deref;

use crate::topic_name::TopicNameRef;
use crate::{Decodable, Encodable};

/// Returns `true` if `topic` is NOT a well-formed topic filter.
///
/// Rejected cases:
/// - an empty string, or one longer than 65535 bytes (presumably the limit of the u16 length
///   prefix used on the wire; confirm against the string encoder);
/// - any level that follows a `#` level;
/// - any level that contains `#` or `+` without being exactly that single character.
#[inline]
fn is_invalid_topic_filter(topic: &str) -> bool {
    if topic.is_empty() || topic.as_bytes().len() > 65535 {
        return true;
    }

    let mut found_hash = false;
    for member in topic.split('/') {
        // A `#` level must be the final level.
        if found_hash {
            return true;
        }

        match member {
            "#" => found_hash = true,
            "+" => {}
            _ => {
                // Wildcards are only allowed as an entire level.
                if member.contains(['#', '+']) {
                    return true;
                }
            }
        }
    }

    false
}

/// Topic filter
///
/// Syntax is defined in the [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106).
///
/// ```rust
/// use mqtt::{TopicFilter, TopicNameRef};
///
/// let topic_filter = TopicFilter::new("sport/+/player1").unwrap();
/// let matcher = topic_filter.get_matcher();
/// assert!(matcher.is_match(TopicNameRef::new("sport/abc/player1").unwrap()));
/// ```
#[derive(Debug, Eq, PartialEq, Clone, Hash, Ord, PartialOrd)]
pub struct TopicFilter(String);

impl TopicFilter {
    /// Creates a new topic filter from a string.
    ///
    /// Returns `TopicFilterError` (holding the rejected string) if it is not a valid topic filter.
    pub fn new>(topic: S) -> Result {
        let topic = topic.into();
        if is_invalid_topic_filter(&topic) {
            Err(TopicFilterError(topic))
        } else {
            Ok(TopicFilter(topic))
        }
    }

    /// Creates a new topic filter from string without validation
    ///
    /// # Safety
    ///
    /// Topic filters' syntax is defined in [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106).
    /// Creating a filter from raw string may cause errors
    pub unsafe fn new_unchecked>(topic: S) -> TopicFilter {
        TopicFilter(topic.into())
    }
}

impl From for String {
    fn from(topic: TopicFilter) -> String {
        topic.0
    }
}

// Encoded exactly like the underlying string.
impl Encodable for TopicFilter {
    fn encode(&self, writer: &mut W) -> Result<(), io::Error> {
        (&self.0[..]).encode(writer)
    }

    fn encoded_length(&self) -> u32 {
        (&self.0[..]).encoded_length()
    }
}

// Decodes a string, then validates it with `TopicFilter::new`.
impl Decodable for TopicFilter {
    type Error = TopicFilterDecodeError;
    type Cond = ();

    fn decode_with(reader: &mut R, _rest: ()) -> Result {
        let topic_filter = String::decode(reader)?;
        Ok(TopicFilter::new(topic_filter)?)
    }
}

impl Deref for TopicFilter {
    type Target = TopicFilterRef;

    fn deref(&self) -> &TopicFilterRef {
        // SAFETY: the inner string came from `TopicFilter::new` (validated), or from the unsafe
        // `TopicFilter::new_unchecked`, whose caller accepted the validity obligation.
        unsafe { TopicFilterRef::new_unchecked(&self.0) }
    }
}

/// Reference to a `TopicFilter`
#[derive(Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
#[repr(transparent)]
pub struct TopicFilterRef(str);

impl TopicFilterRef {
    /// Creates a new topic filter reference from a string.
    ///
    /// Returns `TopicFilterError` (holding an owned copy of the input) if it is not a valid
    /// topic filter.
    pub fn new + ?Sized>(topic: &S) -> Result<&TopicFilterRef, TopicFilterError> {
        let topic = topic.as_ref();
        if is_invalid_topic_filter(topic) {
            Err(TopicFilterError(topic.to_owned()))
        } else {
            // SAFETY: `TopicFilterRef` is `#[repr(transparent)]` over `str`, so the pointer cast
            // keeps the layout and the borrow of `topic`; the contents were validated above.
            Ok(unsafe { &*(topic as *const str as *const TopicFilterRef) })
        }
    }

    /// Creates a new topic filter from string without validation
    ///
    /// # Safety
    ///
    /// Topic filters' syntax is defined in [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106).
    /// Creating a filter from raw string may cause errors
    pub unsafe fn new_unchecked + ?Sized>(topic: &S) -> &TopicFilterRef {
        let topic = topic.as_ref();
        // Layout-preserving cast: `TopicFilterRef` is `#[repr(transparent)]` over `str`.
        &*(topic as *const str as *const TopicFilterRef)
    }

    /// Get a matcher
    pub fn get_matcher(&self) -> TopicFilterMatcher<'_> {
        TopicFilterMatcher::new(&self.0)
    }
}

impl Deref for TopicFilterRef {
    type Target = str;

    fn deref(&self) -> &str {
        &self.0
    }
}

/// Error returned when a string is not a valid topic filter; holds the rejected string.
#[derive(Debug, thiserror::Error)]
#[error("invalid topic filter ({0})")]
pub struct TopicFilterError(pub String);

/// Errors while parsing topic filters
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub enum TopicFilterDecodeError {
    IoError(#[from] io::Error),
    InvalidTopicFilter(#[from] TopicFilterError),
}

/// Matcher for matching topic names with this filter
#[derive(Debug, Copy, Clone)]
pub struct TopicFilterMatcher<'a> {
    topic_filter: &'a str,
}

impl<'a> TopicFilterMatcher<'a> {
    fn new(filter: &'a str) -> TopicFilterMatcher<'a> {
        TopicFilterMatcher { topic_filter: filter }
    }

    /// Check if this filter can match the `topic_name`
    pub fn is_match(&self, topic_name: &TopicNameRef)
-> bool { let mut tn_itr = topic_name.split('/'); let mut ft_itr = self.topic_filter.split('/'); // The Server MUST NOT match Topic Filters starting with a wildcard character (# or +) // with Topic Names beginning with a $ character [MQTT-4.7.2-1]. let first_ft = ft_itr.next().unwrap(); let first_tn = tn_itr.next().unwrap(); if first_tn.starts_with('$') { if first_tn != first_ft { return false; } } else { match first_ft { // Matches the whole topic "#" => return true, "+" => {} _ => { if first_tn != first_ft { return false; } } } } loop { match (ft_itr.next(), tn_itr.next()) { (Some(ft), Some(tn)) => match ft { "#" => break, "+" => {} _ => { if ft != tn { return false; } } }, (Some(ft), None) => { if ft != "#" { return false; } else { break; } } (None, Some(..)) => return false, (None, None) => break, } } true } } #[cfg(test)] mod test { use super::*; #[test] fn topic_filter_validate() { let topic = "#".to_owned(); TopicFilter::new(topic).unwrap(); let topic = "sport/tennis/player1".to_owned(); TopicFilter::new(topic).unwrap(); let topic = "sport/tennis/player1/ranking".to_owned(); TopicFilter::new(topic).unwrap(); let topic = "sport/tennis/player1/#".to_owned(); TopicFilter::new(topic).unwrap(); let topic = "#".to_owned(); TopicFilter::new(topic).unwrap(); let topic = "sport/tennis/#".to_owned(); TopicFilter::new(topic).unwrap(); let topic = "sport/tennis#".to_owned(); assert!(TopicFilter::new(topic).is_err()); let topic = "sport/tennis/#/ranking".to_owned(); assert!(TopicFilter::new(topic).is_err()); let topic = "+".to_owned(); TopicFilter::new(topic).unwrap(); let topic = "+/tennis/#".to_owned(); TopicFilter::new(topic).unwrap(); let topic = "sport+".to_owned(); assert!(TopicFilter::new(topic).is_err()); let topic = "sport/+/player1".to_owned(); TopicFilter::new(topic).unwrap(); let topic = "+/+".to_owned(); TopicFilter::new(topic).unwrap(); let topic = "$SYS/#".to_owned(); TopicFilter::new(topic).unwrap(); let topic = "$SYS".to_owned(); 
TopicFilter::new(topic).unwrap(); } #[test] fn topic_filter_matcher() { let filter = TopicFilter::new("sport/#").unwrap(); let matcher = filter.get_matcher(); assert!(matcher.is_match(TopicNameRef::new("sport").unwrap())); let filter = TopicFilter::new("#").unwrap(); let matcher = filter.get_matcher(); assert!(matcher.is_match(TopicNameRef::new("sport").unwrap())); assert!(matcher.is_match(TopicNameRef::new("/").unwrap())); assert!(matcher.is_match(TopicNameRef::new("abc/def").unwrap())); assert!(!matcher.is_match(TopicNameRef::new("$SYS").unwrap())); assert!(!matcher.is_match(TopicNameRef::new("$SYS/abc").unwrap())); let filter = TopicFilter::new("+/monitor/Clients").unwrap(); let matcher = filter.get_matcher(); assert!(!matcher.is_match(TopicNameRef::new("$SYS/monitor/Clients").unwrap())); let filter = TopicFilter::new("$SYS/#").unwrap(); let matcher = filter.get_matcher(); assert!(matcher.is_match(TopicNameRef::new("$SYS/monitor/Clients").unwrap())); assert!(matcher.is_match(TopicNameRef::new("$SYS").unwrap())); let filter = TopicFilter::new("$SYS/monitor/+").unwrap(); let matcher = filter.get_matcher(); assert!(matcher.is_match(TopicNameRef::new("$SYS/monitor/Clients").unwrap())); } } ================================================ FILE: src/topic_name.rs ================================================ //! 
Topic name use std::{ borrow::{Borrow, BorrowMut}, io::{self, Read, Write}, ops::{Deref, DerefMut}, }; use crate::{Decodable, Encodable}; #[inline] fn is_invalid_topic_name(topic_name: &str) -> bool { topic_name.is_empty() || topic_name.as_bytes().len() > 65535 || topic_name.chars().any(|ch| ch == '#' || ch == '+') } /// Topic name /// /// #[derive(Debug, Eq, PartialEq, Clone, Hash, Ord, PartialOrd)] pub struct TopicName(String); impl TopicName { /// Creates a new topic name from string /// Return error if the string is not a valid topic name pub fn new>(topic_name: S) -> Result { let topic_name = topic_name.into(); if is_invalid_topic_name(&topic_name) { Err(TopicNameError(topic_name)) } else { Ok(TopicName(topic_name)) } } /// Creates a new topic name from string without validation /// /// # Safety /// /// Topic names' syntax is defined in [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106). /// Creating a name from raw string may cause errors pub unsafe fn new_unchecked(topic_name: String) -> TopicName { TopicName(topic_name) } } impl From for String { fn from(topic_name: TopicName) -> String { topic_name.0 } } impl Deref for TopicName { type Target = TopicNameRef; fn deref(&self) -> &TopicNameRef { unsafe { TopicNameRef::new_unchecked(&self.0) } } } impl DerefMut for TopicName { fn deref_mut(&mut self) -> &mut Self::Target { unsafe { TopicNameRef::new_mut_unchecked(&mut self.0) } } } impl Borrow for TopicName { fn borrow(&self) -> &TopicNameRef { Deref::deref(self) } } impl BorrowMut for TopicName { fn borrow_mut(&mut self) -> &mut TopicNameRef { DerefMut::deref_mut(self) } } impl Encodable for TopicName { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { (&self.0[..]).encode(writer) } fn encoded_length(&self) -> u32 { (&self.0[..]).encoded_length() } } impl Decodable for TopicName { type Error = TopicNameDecodeError; type Cond = (); fn decode_with(reader: &mut R, _rest: ()) -> Result { let topic_name 
= String::decode(reader)?; Ok(TopicName::new(topic_name)?) } } #[derive(Debug, thiserror::Error)] #[error("invalid topic filter ({0})")] pub struct TopicNameError(pub String); /// Errors while parsing topic names #[derive(Debug, thiserror::Error)] #[error(transparent)] pub enum TopicNameDecodeError { IoError(#[from] io::Error), InvalidTopicName(#[from] TopicNameError), } /// Reference to a topic name #[derive(Debug, Eq, PartialEq, Hash, Ord, PartialOrd)] #[repr(transparent)] pub struct TopicNameRef(str); impl TopicNameRef { /// Creates a new topic name from string /// Return error if the string is not a valid topic name pub fn new + ?Sized>(topic_name: &S) -> Result<&TopicNameRef, TopicNameError> { let topic_name = topic_name.as_ref(); if is_invalid_topic_name(topic_name) { Err(TopicNameError(topic_name.to_owned())) } else { Ok(unsafe { &*(topic_name as *const str as *const TopicNameRef) }) } } /// Creates a new topic name from string /// Return error if the string is not a valid topic name pub fn new_mut + ?Sized>(topic_name: &mut S) -> Result<&mut TopicNameRef, TopicNameError> { let topic_name = topic_name.as_mut(); if is_invalid_topic_name(topic_name) { Err(TopicNameError(topic_name.to_owned())) } else { Ok(unsafe { &mut *(topic_name as *mut str as *mut TopicNameRef) }) } } /// Creates a new topic name from string without validation /// /// # Safety /// /// Topic names' syntax is defined in [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106). /// Creating a name from raw string may cause errors pub unsafe fn new_unchecked + ?Sized>(topic_name: &S) -> &TopicNameRef { let topic_name = topic_name.as_ref(); &*(topic_name as *const str as *const TopicNameRef) } /// Creates a new topic name from string without validation /// /// # Safety /// /// Topic names' syntax is defined in [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106). 
/// Creating a name from raw string may cause errors pub unsafe fn new_mut_unchecked + ?Sized>(topic_name: &mut S) -> &mut TopicNameRef { let topic_name = topic_name.as_mut(); &mut *(topic_name as *mut str as *mut TopicNameRef) } /// Check if this topic name is only for server. /// /// Topic names that beginning with a '$' character are reserved for servers pub fn is_server_specific(&self) -> bool { self.0.starts_with('$') } } impl Deref for TopicNameRef { type Target = str; fn deref(&self) -> &str { &self.0 } } impl ToOwned for TopicNameRef { type Owned = TopicName; fn to_owned(&self) -> Self::Owned { TopicName(self.0.to_owned()) } } impl Encodable for TopicNameRef { fn encode(&self, writer: &mut W) -> Result<(), io::Error> { (&self.0[..]).encode(writer) } fn encoded_length(&self) -> u32 { (&self.0[..]).encoded_length() } } #[cfg(test)] mod test { use super::*; #[test] fn topic_name_sys() { let topic_name = "$SYS".to_owned(); TopicName::new(topic_name).unwrap(); let topic_name = "$SYS/broker/connection/test.cosm-energy/state".to_owned(); TopicName::new(topic_name).unwrap(); } #[test] fn topic_name_slash() { TopicName::new("/").unwrap(); } #[test] fn topic_name_basic() { TopicName::new("/finance").unwrap(); TopicName::new("/finance//def").unwrap(); } }